In [1]:
from rdkit import Chem
from dgllife.utils import mol_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer

def MolToGraph(smi):
    """Convert a SMILES string into a featurized molecular graph.

    The molecule is parsed with RDKit and handed to ``mol_to_bigraph`` together
    with a ``CanonicalAtomFeaturizer`` (node features) and a
    ``CanonicalBondFeaturizer`` (edge features).

    Parameters
    ----------
    smi : str
        SMILES representation of the molecule.

    Returns
    -------
    The graph object produced by ``mol_to_bigraph``.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None instead of
        # raising; fail here with the offending input rather than letting the
        # None travel into graph construction and batching.
        raise ValueError('Could not parse SMILES string: {!r}'.format(smi))
    return mol_to_bigraph(
        mol,
        node_featurizer=CanonicalAtomFeaturizer(),
        edge_featurizer=CanonicalBondFeaturizer(),
    )

In [2]:
import pandas as pd
# Load the full odor table from the CSV in the working directory; the cells
# below read its 'SMILES', 'Odor', 'Class' and 'Split' columns.
df = pd.read_csv('odor_dataset.csv')
# Preview the first rows as the cell output.
df.head()

Unnamed: 0,SMILES,Odor,Class,Split
0,CC(O)CN,fishy,79,train
1,CCC(=O)C(=O)O,fatty,6,train
2,O=C(O)CCc1ccccc1,rose,15,train
3,OCc1ccc(O)cc1,medicinal,88,train
4,O=Cc1ccc(O)cc1,phenolic,33,train


In [4]:
# Lookup from integer class id to odor name (e.g. 79 -> 'fishy').
# If a class id appears on several rows, the odor from the last row wins.
label_dict = dict(zip(df['Class'], df['Odor']))

label_dict

{79: 'fishy',
 6: 'fatty',
 15: 'rose',
 88: 'medicinal',
 33: 'phenolic',
 14: 'nutty',
 29: 'pungent',
 7: 'fresh',
 46: 'pear',
 86: 'sour',
 76: 'cherry',
 32: 'burnt',
 0: 'fruity',
 2: 'sweet',
 31: 'cheesy',
 65: 'clean',
 9: 'spicy',
 1: 'green',
 35: 'powdery',
 58: 'sharp',
 16: 'earthy',
 18: 'roasted',
 41: 'buttery',
 4: 'herbal',
 24: 'mint',
 28: 'odorless',
 94: 'bitter',
 23: 'caramellic',
 10: 'sulfurous',
 47: 'savory',
 70: 'rummy',
 8: 'waxy',
 90: 'chocolate',
 73: 'cooked',
 77: 'cooling',
 55: 'vanilla',
 27: 'musty',
 74: 'anisic',
 12: 'tropical',
 5: 'woody',
 11: 'oily',
 75: 'ripe',
 52: 'garlic',
 57: 'alcoholic',
 42: 'leafy',
 3: 'floral',
 25: 'winey',
 97: 'plum',
 38: 'berry',
 81: 'apricot',
 44: 'camphoreous',
 50: 'animal',
 63: 'musk',
 60: 'tobacco',
 36: 'dry',
 87: 'smoky',
 66: 'warm',
 68: 'coconut',
 43: 'metallic',
 48: 'banana',
 99: 'solvent',
 39: 'fermented',
 61: 'amber',
 45: 'melon',
 62: 'mushroom',
 22: 'vegetable',
 91: 'lactonic'

In [29]:

class DGLDataset(object):
    """Molecule dataset read from a CSV file and restricted to one split.

    Rows whose ``Split`` column equals ``split`` are kept. Their ``SMILES``
    strings are stored in ``self.X`` and their ``Class`` values in ``self.Y``,
    in file order. Graphs are built on access via ``MolToGraph``, not up front.

    Parameters
    ----------
    file_path : str
        Path of the CSV file; it must contain ``SMILES``, ``Class`` and
        ``Split`` columns.
    split : str
        Value of the ``Split`` column to select (e.g. ``'train'``).
    """

    def __init__(self, file_path='odor_dataset.csv', split='train'):
        table = pd.read_csv(file_path)
        in_split = table['Split'] == split
        self.X = table.loc[in_split, 'SMILES'].tolist()
        self.Y = table.loc[in_split, 'Class'].tolist()

    def __len__(self):
        """Number of molecules in the selected split."""
        return len(self.X)

    def __getitem__(self, item):
        """Return ``(graph, label)`` for the sample at position ``item``."""
        return MolToGraph(self.X[item]), self.Y[item]

In [37]:
import dgl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def collate_molgraphs(data):
    """Collate a list of ``(graph, label)`` samples into a single batch.

    Intended as the ``collate_fn`` of a ``DataLoader``.

    Parameters
    ----------
    data : list of tuple
        Samples as ``(graph, label)`` pairs, one per molecule.

    Returns
    -------
    tuple
        ``(batch_graph, labels)`` where ``batch_graph`` is the result of
        ``dgl.batch`` over the sample graphs and ``labels`` is a 1-D int64
        tensor holding the labels in the same order.
    """
    graphs, labels = zip(*data)
    batch_graph = dgl.batch(list(graphs))
    # torch.tensor with an explicit dtype is the recommended replacement for
    # the legacy torch.LongTensor(...) constructor; the result is still int64.
    return batch_graph, torch.tensor(list(labels), dtype=torch.long)

# Training data pipeline: molecules of the 'train' split, shuffled and
# collated into batched graphs of `batch_size` samples each.
batch_size = 10
train_set = DGLDataset(split='train')
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_molgraphs,
)

In [38]:
# Sanity check: print only the first batch yielded by the loader — a
# (batched graph, label tensor) pair — then stop iterating.
for data in train_loader:
    print (data)
    break

(Graph(num_nodes=109, num_edges=208,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(12,), dtype=torch.float32)}), tensor([50,  3,  3,  6,  1, 12, 31, 59, 47,  6]))


# Tasks
### 1. Write a function called get_dataloaders() so that both the train dataloader and the test dataloader can be obtained in one line
### 2. Instead of (graph, label), make the dataloader yield (smiles, graph, label)
### 3. Create a dataset.py file from which the get_dataloaders() function can be imported