In [None]:
# augmenting with MoleculeNet

# https://proceedings.neurips.cc//paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf

In [1]:
import pandas as pd
train_df =  pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')

train_df.head()

Unnamed: 0,Smiles,label
0,CC(C)Oc1cc(Oc2ccc(S(C)(=O)=O)cc2)cc(-c2ncc(Cl)...,2.0
1,CNS(=O)(=O)c1ccc(Oc2cc(OC(C)C)cc(-c3nccc(=O)[n...,2.6
2,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)c(F)c2)cc(-c2nccc...,1.5
3,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)cn2)cc(-c2nccc(=O...,1.3
4,CCC(CC)Oc1cc(Oc2cnc(C(=O)N(C)C)nc2)cc(-c2nccc(...,1.9


In [2]:
train_df

Unnamed: 0,Smiles,label
0,CC(C)Oc1cc(Oc2ccc(S(C)(=O)=O)cc2)cc(-c2ncc(Cl)...,2.00
1,CNS(=O)(=O)c1ccc(Oc2cc(OC(C)C)cc(-c3nccc(=O)[n...,2.60
2,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)c(F)c2)cc(-c2nccc...,1.50
3,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)cn2)cc(-c2nccc(=O...,1.30
4,CCC(CC)Oc1cc(Oc2cnc(C(=O)N(C)C)nc2)cc(-c2nccc(...,1.90
...,...,...
3824,CCC(C)C(N)C(=O)NC(C(=O)NCC(=O)NC(Cc1ccccc1)C(=...,-0.99
3825,CCC(C)C(N)C(=O)NC(C)C(=O)NC(C)C(=O)NC(C(=O)O)C...,-2.82
3826,NC(Cc1ccccc1)C(=O)NC(Cc1ccccc1)C(=O)NCC(=O)NC(...,0.17
3827,CC(C)CC(NC(=O)C(NC(=O)C(CC(C)C)NC(=O)C(N)C(C)C...,-1.23


In [3]:
import rdkit
from tqdm.notebook import trange, tqdm
from ogb.utils import smiles2graph

tqdm.pandas('Converting SMILES to molecular graph...')
train_df['graph'] = train_df.progress_apply(lambda row: smiles2graph(row['Smiles']),axis=1)

HBox(children=(FloatProgress(value=0.0, max=3829.0), HTML(value='')))




In [4]:
test_df['graph']=test_df.progress_apply(lambda row: smiles2graph(row['Smiles']),axis=1)

HBox(children=(FloatProgress(value=0.0, max=141.0), HTML(value='')))




In [5]:
row = train_df.sample(1)
row

Unnamed: 0,Smiles,label,graph
471,CCCSc1nc(NC2CC2c2ccccc2)c2nnn(C3CC(C(N)=O)C(O)...,2.05,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [6]:
extra_df = pd.read_csv('data/extra_Lipophilicity.csv', usecols=['exp','smiles']).rename(columns={'smiles':'Smiles','exp': 'label'})

In [7]:
extra_df['graph']=extra_df.progress_apply(lambda row: smiles2graph(row['Smiles']),axis=1)
extra_df

HBox(children=(FloatProgress(value=0.0, max=4200.0), HTML(value='')))




Unnamed: 0,label,Smiles,graph
0,3.54,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
1,-1.18,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
2,3.69,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 4, 5,..."
3,3.37,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 4, 5,..."
4,3.10,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
...,...,...,...
4195,3.85,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4196,3.21,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4197,2.10,COc1cccc2[nH]ncc12,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4198,2.65,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [8]:
extra_df=extra_df[train_df.columns] #reordering to match.
extra_df

Unnamed: 0,Smiles,label,graph
0,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14,3.54,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
1,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...,-1.18,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
2,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl,3.69,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 4, 5,..."
3,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...,3.37,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 4, 5,..."
4,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...,3.10,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
...,...,...,...
4195,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1,3.85,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4196,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...,3.21,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4197,COc1cccc2[nH]ncc12,2.10,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4198,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3,2.65,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [9]:
sum(extra_df['Smiles'].isin(train_df['Smiles']).astype(int))
#number of duplicate SMILES.

285

In [10]:
augmented_train_df = pd.concat([train_df,extra_df])
augmented_train_df

Unnamed: 0,Smiles,label,graph
0,CC(C)Oc1cc(Oc2ccc(S(C)(=O)=O)cc2)cc(-c2ncc(Cl)...,2.00,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
1,CNS(=O)(=O)c1ccc(Oc2cc(OC(C)C)cc(-c3nccc(=O)[n...,2.60,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 2, 5,..."
2,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)c(F)c2)cc(-c2nccc...,1.50,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
3,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)cn2)cc(-c2nccc(=O...,1.30,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
4,CCC(CC)Oc1cc(Oc2cnc(C(=O)N(C)C)nc2)cc(-c2nccc(...,1.90,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 2, 5,..."
...,...,...,...
4195,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1,3.85,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4196,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...,3.21,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4197,COc1cccc2[nH]ncc12,2.10,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4198,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3,2.65,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [11]:
#checking if logD values match, then drop those duplicates
augmented_train_df[augmented_train_df.duplicated(['Smiles','label'],keep=False)].sort_values('Smiles')

Unnamed: 0,Smiles,label,graph
2623,CC(=N)NCc1ccccc1,-0.28,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
3107,CC(=N)NCc1ccccc1,-0.28,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
2502,CC(=O)Nc1ccc(CNc2[nH]nc3cccc(Oc4ccc(F)cc4)c23)cc1,4.28,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
2525,CC(=O)Nc1ccc(CNc2[nH]nc3cccc(Oc4ccc(F)cc4)c23)cc1,4.28,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
1551,CC(=O)Nc1ccc(CNc2[nH]nc3ccnc(Oc4ccccc4)c23)cc1,3.03,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
...,...,...,...
497,c1ccc2ncccc2c1,2.09,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
721,c1ccc2ncncc2c1,1.00,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
3013,c1ccc2ncncc2c1,1.00,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
3753,c1cnc2cccnc2c1,0.89,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [12]:
augmented_train_df = augmented_train_df.drop_duplicates(['Smiles','label'])
augmented_train_df

Unnamed: 0,Smiles,label,graph
0,CC(C)Oc1cc(Oc2ccc(S(C)(=O)=O)cc2)cc(-c2ncc(Cl)...,2.00,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
1,CNS(=O)(=O)c1ccc(Oc2cc(OC(C)C)cc(-c3nccc(=O)[n...,2.60,"{'edge_index': [[0, 1, 1, 2, 2, 3, 2, 4, 2, 5,..."
2,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)c(F)c2)cc(-c2nccc...,1.50,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
3,CC(C)Oc1cc(Oc2cnc(C(=O)N(C)C)cn2)cc(-c2nccc(=O...,1.30,"{'edge_index': [[0, 1, 1, 2, 1, 3, 3, 4, 4, 5,..."
4,CCC(CC)Oc1cc(Oc2cnc(C(=O)N(C)C)nc2)cc(-c2nccc(...,1.90,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 2, 5,..."
...,...,...,...
4194,CC1CC(N(C(=O)C)c2ccccc2)c3ccccc3N1S(=O)(=O)c4c...,3.68,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4195,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1,3.85,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4196,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...,3.21,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."
4198,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3,2.65,"{'edge_index': [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5,..."


In [1]:
import os
import os.path as osp
import re

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_gz)

try:
    from rdkit import Chem
except ImportError:
    Chem = None

x_map = {
    'atomic_num':
    list(range(0, 119)),
    'chirality': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'degree':
    list(range(0, 11)),
    'formal_charge':
    list(range(-5, 7)),
    'num_hs':
    list(range(0, 9)),
    'num_radical_electrons':
    list(range(0, 5)),
    'hybridization': [
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

e_map = {
    'bond_type': [
        'misc',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
    ],
    'stereo': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'is_conjugated': [False, True],
}

#http://moleculenet.ai/datasets-1
class MoleculeNet(InMemoryDataset):
    r"""The `MoleculeNet <http://moleculenet.ai/datasets-1>`_ benchmark
    collection  from the `"MoleculeNet: A Benchmark for Molecular Machine
    Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets
    from physical chemistry, biophysics and physiology.
    All datasets come with the additional node and edge features introduced by
    the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"ESOL"`,
            :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`,
            :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBPB"`, :obj:`"Tox21"`,
            :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    # Format: name: [display_name, url_name, csv_name, smiles_idx, y_idx]
    names = {
        'train': ['Train', 'train.csv', 'train', 0, 1],
        'alt': ['Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1],
        'test':['Test','60170fffa2720fa4d0b9067a_holdout_set.csv','60170fffa2720fa4d0b9067a_holdout_set',0,1]
    }

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):

        if Chem is None:
            raise ImportError('`MoleculeNet` requires `rdkit`.')

        self.name = name.lower()
        assert self.name in self.names.keys()
        super(MoleculeNet, self).__init__(root, transform, pre_transform,
                                          pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.names[self.name][2]}.csv'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        with open(self.raw_paths[0], 'r') as f:
            dataset = f.read().split('\n')[1:-1]
            dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.

        data_list = []
        for line in dataset:
            line = re.sub(r'\".*\"', '', line)  # Replace ".*" strings.
            line = line.split(',')

            smiles = line[self.names[self.name][3]]
            ys = line[self.names[self.name][4]]
            ys = ys if isinstance(ys, list) else [ys]

            ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
            y = torch.tensor(ys, dtype=torch.float).view(1, -1)

            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            xs = []
            for atom in mol.GetAtoms():
                x = []
                x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
                x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
                x.append(x_map['degree'].index(atom.GetTotalDegree()))
                x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
                x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
                x.append(x_map['num_radical_electrons'].index(
                    atom.GetNumRadicalElectrons()))
                x.append(x_map['hybridization'].index(
                    str(atom.GetHybridization())))
                x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
                x.append(x_map['is_in_ring'].index(atom.IsInRing()))
                xs.append(x)

            x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

            edge_indices, edge_attrs = [], []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()

                e = []
                e.append(e_map['bond_type'].index(str(bond.GetBondType())))
                e.append(e_map['stereo'].index(str(bond.GetStereo())))
                e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

                edge_indices += [[i, j], [j, i]]
                edge_attrs += [e, e]

            edge_index = torch.tensor(edge_indices)
            edge_index = edge_index.t().to(torch.long).view(2, -1)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

            # Sort indices.
            if edge_index.numel() > 0:
                perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                        smiles=smiles)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.names[self.name][0], len(self))


In [2]:
# # Install required packages.
# !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
# !pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
# !pip install -q torch-geometric

In [3]:
dataset=MoleculeNet(root='./',name='lipo')
dataset.process()

In [4]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[14]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: Train(3828):
Number of graphs: 3828
Number of features: 9
Number of classes: 1

Data(edge_attr=[70, 3], edge_index=[2, 70], smiles="Cc1n[nH]c2c(C)cc(C(=O)N3CCC4(CC3)Cc3cn(C(C)C)nc3C(=O)N4)cc12", x=[31, 9], y=[1, 1])
Number of nodes: 31
Number of edges: 70
Average node degree: 2.26
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


In [5]:
torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:500]
test_dataset = dataset[500:550]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 500
Number of test graphs: 50


In [6]:
from torch_geometric.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
Batch(batch=[1871], edge_attr=[4042, 3], edge_index=[2, 4042], smiles=[64], x=[1871, 9], y=[64, 1])

Step 2:
Number of graphs in the current batch: 64
Batch(batch=[1779], edge_attr=[3894, 3], edge_index=[2, 3894], smiles=[64], x=[1779, 9], y=[64, 1])

Step 3:
Number of graphs in the current batch: 64
Batch(batch=[1660], edge_attr=[3590, 3], edge_index=[2, 3590], smiles=[64], x=[1660, 9], y=[64, 1])

Step 4:
Number of graphs in the current batch: 64
Batch(batch=[1732], edge_attr=[3752, 3], edge_index=[2, 3752], smiles=[64], x=[1732, 9], y=[64, 1])

Step 5:
Number of graphs in the current batch: 64
Batch(batch=[1654], edge_attr=[3574, 3], edge_index=[2, 3574], smiles=[64], x=[1654, 9], y=[64, 1])

Step 6:
Number of graphs in the current batch: 64
Batch(batch=[1624], edge_attr=[3496, 3], edge_index=[2, 3496], smiles=[64], x=[1624, 9], y=[64, 1])

Step 7:
Number of graphs in the current batch: 64
Batch(batch=[1740], edge_attr=[3802, 3], edg

In [7]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class AtomEncoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(AtomEncoder, self).__init__()

        self.embeddings = torch.nn.ModuleList()

        for i in range(9):
            self.embeddings.append(torch.nn.Embedding(100, hidden_channels))

    def reset_parameters(self):
        for embedding in self.embeddings:
            embedding.reset_parameters()

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(1)

        out = 0
        for i in range(x.size(1)):
            out += self.embeddings[i](x[:, i])
        return out

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.emb = AtomEncoder(dataset.num_node_features)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.emb(x)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=64)
print(model)

GCN(
  (emb): AtomEncoder(
    (embeddings): ModuleList(
      (0): Embedding(100, 9)
      (1): Embedding(100, 9)
      (2): Embedding(100, 9)
      (3): Embedding(100, 9)
      (4): Embedding(100, 9)
      (5): Embedding(100, 9)
      (6): Embedding(100, 9)
      (7): Embedding(100, 9)
      (8): Embedding(100, 9)
    )
  )
  (conv1): GCNConv(9, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=1, bias=True)
)


In [8]:

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.L1Loss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

         print(loss)

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    # print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

.6302, grad_fn=<L1LossBackward>)
tensor(0.5455, grad_fn=<L1LossBackward>)
tensor(0.6531, grad_fn=<L1LossBackward>)
tensor(0.5422, grad_fn=<L1LossBackward>)
tensor(0.5945, grad_fn=<L1LossBackward>)
tensor(0.6934, grad_fn=<L1LossBackward>)
tensor(0.6040, grad_fn=<L1LossBackward>)
tensor(0.6509, grad_fn=<L1LossBackward>)
tensor(0.6151, grad_fn=<L1LossBackward>)
tensor(0.7476, grad_fn=<L1LossBackward>)
tensor(0.6899, grad_fn=<L1LossBackward>)
tensor(0.5574, grad_fn=<L1LossBackward>)
tensor(0.5906, grad_fn=<L1LossBackward>)
tensor(0.5593, grad_fn=<L1LossBackward>)
tensor(0.5387, grad_fn=<L1LossBackward>)
tensor(0.6250, grad_fn=<L1LossBackward>)
tensor(0.6511, grad_fn=<L1LossBackward>)
tensor(0.6491, grad_fn=<L1LossBackward>)
tensor(0.4846, grad_fn=<L1LossBackward>)
tensor(0.5077, grad_fn=<L1LossBackward>)
tensor(0.7213, grad_fn=<L1LossBackward>)
tensor(0.5448, grad_fn=<L1LossBackward>)
tensor(0.5753, grad_fn=<L1LossBackward>)
tensor(0.6201, grad_fn=<L1LossBackward>)
tensor(0.5545, grad_fn=<