In [37]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd
from torch.utils.data import random_split

# Creating a Custom Dataset 

In [38]:
## Define a class for the dataset

class MoleculeDataset(Dataset):
    """Map-style dataset serving (features, target) tensor pairs from a DataFrame.

    The selected columns are converted to float32 tensors once, at
    construction, so indexing is a plain tensor lookup instead of a pandas
    row extraction per sample. Converting column-wise also keeps the dataset
    usable when the frame carries extra non-numeric columns: a single
    ``iloc`` row of such a frame is object-typed and cannot be handed to
    ``torch.tensor``.

    Args:
        dataframe: pandas DataFrame with one sample per row.
        features: list of column names used as model inputs.
        target: name of the column used as the label.

    Raises:
        KeyError: if a requested column is missing from ``dataframe``.

    Note:
        Values are snapshotted at construction; later edits to ``dataframe``
        are not reflected in the tensors returned by indexing.
    """

    def __init__(self, dataframe, features, target):
        # Public attributes are kept so existing callers can still read them.
        self.dataframe = dataframe
        self.features = features
        self.target = target

        # Select columns first, then cast: only the requested columns need to
        # be numeric. torch.tensor copies, so the cache never aliases the frame.
        self._feature_tensor = torch.tensor(
            dataframe[list(features)].to_numpy(dtype="float32"),
            dtype=torch.float32,
        )
        self._target_tensor = torch.tensor(
            dataframe[target].to_numpy(dtype="float32"),
            dtype=torch.float32,
        )

    def __len__(self):
        """Return the number of rows (samples) in the underlying DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(features, target)`` for the sample at position ``index``.

        For an integer index, ``features`` is a 1-D float32 tensor with one
        entry per feature column and ``target`` is a 0-D float32 tensor.
        Both are views into the cached tensors; clone them before modifying
        in place.
        """
        return self._feature_tensor[index], self._target_tensor[index]


In [39]:
## Create a dataset
# Column used as the label, and the columns fed to the model as inputs.
target = 'Active'
features = [
    'HDonors_norm',
    'HAcceptors_norm',
    'MolWt_norm',
    'LogP_norm',
]

# Load the table and wrap it so each row becomes a (features, target) pair.
df = pd.read_csv('./set1.csv')
dataset = MoleculeDataset(dataframe=df, features=features, target=target)

In [40]:
# Split into testing and training sets
# A dedicated, seeded generator drives the split.
generator1 = torch.Generator()
generator1.manual_seed(42)

# Fractions map positionally onto the unpacked names:
# 0.3 -> test_data, 0.7 -> training_data.
test_data, training_data = random_split(
    dataset,
    [0.3, 0.7],
    generator=generator1,
)

In [46]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 64
# Create data loaders
# NOTE(review): the training loader is not shuffled (no shuffle=True), so every
# pass sees the same batch order — confirm this is intended before training.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at the first training batch only, to sanity-check values, shapes and dtype.
for X, y in train_dataloader:
    print("X, Features:", X)
    print("y, Target:", y)
    # X is a 2-D batch of tabular feature rows, not an image batch, so the
    # label names the two axes that actually exist.
    print(f"Shape of X [N, n_features]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

X, Features: tensor([[ 1.2623e+00,  1.3939e+00, -4.4449e-01, -2.0503e+00],
        [ 5.6572e-01,  1.6393e-01, -4.1462e-01, -1.9575e-01],
        [-8.2753e-01, -2.4607e-01,  2.2288e-01, -4.9018e-01],
        [ 5.6572e-01,  1.6393e-01, -3.9164e-01, -7.0297e-01],
        [-8.2753e-01, -2.4607e-01,  2.8261e-01,  1.0865e+00],
        [ 5.6572e-01,  5.7392e-01, -3.0709e-03, -5.5378e-01],
        [ 5.6572e-01,  5.7392e-01,  2.2218e-01, -1.6329e-01],
        [-1.3091e-01, -6.5606e-01, -6.6163e-01, -1.9785e-01],
        [ 5.6572e-01,  9.8392e-01,  1.2124e+00,  5.0344e-01],
        [-8.2753e-01,  1.6393e-01, -8.4161e-01, -9.9788e-01],
        [-1.3091e-01,  9.8392e-01,  1.7290e+00,  1.2450e+00],
        [-1.3091e-01,  5.7392e-01,  1.7696e-01,  1.0148e-01],
        [-1.3091e-01, -6.5606e-01, -5.3441e-01, -4.2121e-01],
        [-8.2753e-01, -6.5606e-01, -7.4415e-01, -2.2083e-01],
        [ 1.2623e+00,  5.7392e-01,  3.4147e-01, -3.5509e-01],
        [-1.3091e-01,  1.6393e-01,  2.8247e-01,  1.1552e+