In [1]:
import torch
from rdkit.Chem import AllChem, Draw
from rdkit import Chem
from rdkit.Chem.rdchem import BondType, HybridizationType
# from torch_scatter import scatter


# general tools
import numpy as np
# RDkit
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

# Pytorch and Pytorch Geometric
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader

from tqdm import tqdm

In [2]:
""" Load the Data """

SMILE_PATH = './sample_smiles.pt'
LATENT_PATH = './sample_z.pt'

# smile is a list of str (e.g. 'Cn1ncc2c3ncncc3n(CC3CC4OC3Cc3oncc34)c21')
# len of smile : 100000
smile = torch.load(SMILE_PATH)

# [:250] in latent vector is corresponding to 2D part
# shape = [100, 1000, 500]
latent = torch.load(LATENT_PATH)
emb2d = latent.reshape(-1,500)[:,:250]

## Feature Encoding
From this on, the credits of the code all goes to: https://www.blopig.com/blog/2022/02/how-to-turn-a-smiles-string-into-a-molecular-graph-for-pytorch-geometric/
This is NOT my work.

In [4]:
def one_hot_encoding(x, permitted_list):
    """
    Maps input elements x which are not in the permitted list to the last element
    of the permitted list.
    """

    if x not in permitted_list:
        x = permitted_list[-1]

    binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]

    return binary_encoding

def get_atom_features(atom, 
                      use_chirality = True, 
                      hydrogens_implicit = True):
    """
    Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.
    """

    # define list of permitted atoms
    
    permitted_list_of_atoms =  ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']
    
    if hydrogens_implicit == False:
        permitted_list_of_atoms = ['H'] + permitted_list_of_atoms
    
    # compute atom features
    
    atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
    
    n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
    
    formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
    
    hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
    
    is_in_a_ring_enc = [int(atom.IsInRing())]
    
    is_aromatic_enc = [int(atom.GetIsAromatic())]
    
    atomic_mass_scaled = [float((atom.GetMass() - 10.812)/116.092)]
    
    vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)]
    
    covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]

    atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled
                                    
    if use_chirality == True:
        chirality_type_enc = one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
        atom_feature_vector += chirality_type_enc
    
    if hydrogens_implicit == True:
        n_hydrogens_enc = one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])
        atom_feature_vector += n_hydrogens_enc

    return np.array(atom_feature_vector)

def get_bond_features(bond, 
                      use_stereochemistry = True):
    """
    Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.
    """

    permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

    bond_type_enc = one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)
    
    bond_is_conj_enc = [int(bond.GetIsConjugated())]
    
    bond_is_in_ring_enc = [int(bond.IsInRing())]
    
    bond_feature_vector = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc
    
    if use_stereochemistry == True:
        stereo_type_enc = one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
        bond_feature_vector += stereo_type_enc

    return np.array(bond_feature_vector)

In [5]:
def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y):
    """
    Inputs:
    
    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerial labels for the SMILES strings (such as associated pKi values)
    
    Outputs:
    
    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning
    
    """
    
    data_list = []
    
    for (smiles, y_val) in zip(x_smiles, y):
        
        # convert SMILES to RDKit mol object
        mol = Chem.MolFromSmiles(smiles)

        # get feature dimensions
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()
        unrelated_smiles = "O=O"
        unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
        n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
        n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)
            
        X = torch.tensor(X, dtype = torch.float)
        
        # construct edge index array E of shape (2, n_edges)
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
        torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
        E = torch.stack([torch_rows, torch_cols], dim = 0)
        
        # construct edge feature array EF of shape (n_edges, n_edge_features)
        EF = np.zeros((n_edges, n_edge_features))
        
        for (k, (i,j)) in enumerate(zip(rows, cols)):
            
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))
        
        EF = torch.tensor(EF, dtype = torch.float)
        
        # construct label tensor
#         y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)
        y_tensor = y_val
        
        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list

## GNN Model

In [13]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

class GCNRegression(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, output_channels):
        super(GCNRegression, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, output_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = F.log_softmax(x, dim=1)
        
        return x

## Experiment

In [31]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [6]:
# create list of molecular graph objects from list of SMILES x_smiles and list of labels y
# x, edge_index, edge_attr, y
data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(smile, emb2d)

print(data_list[0])
print(data_list[1].x.shape)

Data(x=[25, 79], edge_index=[2, 60], edge_attr=[60, 10], y=[250])


In [19]:
from torch_geometric.data import Data, DataLoader
def custom_collate(batch):  
    # Assuming all data objects have the same attributes and properties
    keys = batch[0].keys
    batched_data = Data.from_data_list(batch)
    
    return batched_data

# create dataloader for training
dataloader = DataLoader(dataset = data_list, batch_size = 10000, collate_fn=custom_collate)



In [14]:
""" Create Model """

num_features = 79

# Define model
gnn_model = GCNRegression(num_features, hidden_channels=16, output_channels=250)

optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.01, weight_decay=5e-4)
loss_function = torch.nn.MSELoss()

In [15]:
# test model: 1 instance
batch = torch.tensor(np.zeros(data_list[1].x.shape[0]), dtype=torch.long)
o = gnn_model(data_list[1].x, data_list[1].edge_index, batch)
print(o.shape)

torch.Size([1, 250])


In [26]:
gnn_model.to(device)

# loop over 10 training epochs
for epoch in range(10):

    # set model to training mode
    gnn_model.train()

    # initialize tqdm for progress visualization
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch + 1}/10', unit='batch')
    # loop over minibatches for training
    for k, batch in pbar:
        
        # compute current value of loss function via forward pass
        b = batch.batch
        output = gnn_model(batch.x, batch.edge_index, b)
        loss_function_value = loss_function(output, torch.tensor(batch.y.view(-1, 250), dtype=torch.float32))

        # set past gradient to zero
        optimizer.zero_grad()

        # compute current gradient via backward pass
        loss_function_value.backward()

        # update model weights using gradient and optimisation method
        optimizer.step()

        # update progress bar description
        pbar.set_postfix(loss=loss_function_value.item())
        
        print(f'Training loss at epoch {epoch+1}: {loss_function_value}.')

  loss_function_value = loss_function(output, torch.tensor(batch.y.view(-1, 250), dtype=torch.float32))
Epoch 1/10:  10%|█         | 1/10 [00:02<00:19,  2.12s/batch, loss=30.7]

Training loss at epoch 1: 30.721059799194336.


Epoch 1/10:  20%|██        | 2/10 [00:04<00:16,  2.10s/batch, loss=30.7]

Training loss at epoch 1: 30.709997177124023.


Epoch 1/10:  30%|███       | 3/10 [00:06<00:14,  2.07s/batch, loss=30.7]

Training loss at epoch 1: 30.712749481201172.


Epoch 1/10:  40%|████      | 4/10 [00:08<00:12,  2.15s/batch, loss=30.7]

Training loss at epoch 1: 30.713510513305664.


Epoch 1/10:  50%|█████     | 5/10 [00:11<00:11,  2.34s/batch, loss=30.7]

Training loss at epoch 1: 30.711008071899414.


Epoch 1/10:  60%|██████    | 6/10 [00:13<00:09,  2.34s/batch, loss=30.7]

Training loss at epoch 1: 30.709936141967773.


Epoch 1/10:  70%|███████   | 7/10 [00:15<00:06,  2.27s/batch, loss=30.7]

Training loss at epoch 1: 30.709157943725586.


Epoch 1/10:  80%|████████  | 8/10 [00:17<00:04,  2.23s/batch, loss=30.7]

Training loss at epoch 1: 30.71107292175293.


Epoch 1/10:  90%|█████████ | 9/10 [00:19<00:02,  2.21s/batch, loss=30.7]

Training loss at epoch 1: 30.712221145629883.


Epoch 1/10: 100%|██████████| 10/10 [00:22<00:00,  2.21s/batch, loss=30.7]


Training loss at epoch 1: 30.706851959228516.


Epoch 2/10:  10%|█         | 1/10 [00:02<00:19,  2.15s/batch, loss=30.7]

Training loss at epoch 2: 30.71082305908203.


Epoch 2/10:  20%|██        | 2/10 [00:04<00:16,  2.09s/batch, loss=30.7]

Training loss at epoch 2: 30.705904006958008.


Epoch 2/10:  30%|███       | 3/10 [00:06<00:14,  2.08s/batch, loss=30.7]

Training loss at epoch 2: 30.709890365600586.


Epoch 2/10:  40%|████      | 4/10 [00:08<00:12,  2.11s/batch, loss=30.7]

Training loss at epoch 2: 30.710664749145508.


Epoch 2/10:  50%|█████     | 5/10 [00:10<00:10,  2.12s/batch, loss=30.7]

Training loss at epoch 2: 30.70797348022461.


Epoch 2/10:  60%|██████    | 6/10 [00:12<00:08,  2.12s/batch, loss=30.7]

Training loss at epoch 2: 30.706768035888672.


Epoch 2/10:  70%|███████   | 7/10 [00:14<00:06,  2.13s/batch, loss=30.7]

Training loss at epoch 2: 30.7060546875.


Epoch 2/10:  80%|████████  | 8/10 [00:16<00:04,  2.14s/batch, loss=30.7]

Training loss at epoch 2: 30.7082462310791.


Epoch 2/10:  90%|█████████ | 9/10 [00:19<00:02,  2.14s/batch, loss=30.7]

Training loss at epoch 2: 30.70983123779297.


Epoch 2/10: 100%|██████████| 10/10 [00:21<00:00,  2.14s/batch, loss=30.7]


Training loss at epoch 2: 30.704946517944336.


Epoch 3/10:  10%|█         | 1/10 [00:03<00:29,  3.31s/batch, loss=30.7]

Training loss at epoch 3: 30.709426879882812.


Epoch 3/10:  20%|██        | 2/10 [00:05<00:21,  2.75s/batch, loss=30.7]

Training loss at epoch 3: 30.70493507385254.


Epoch 3/10:  30%|███       | 3/10 [00:08<00:20,  2.94s/batch, loss=30.7]

Training loss at epoch 3: 30.70913314819336.


Epoch 3/10:  40%|████      | 4/10 [00:10<00:15,  2.58s/batch, loss=30.7]

Training loss at epoch 3: 30.709917068481445.


Epoch 3/10:  50%|█████     | 5/10 [00:12<00:12,  2.41s/batch, loss=30.7]

Training loss at epoch 3: 30.707286834716797.


Epoch 3/10:  60%|██████    | 6/10 [00:15<00:09,  2.31s/batch, loss=30.7]

Training loss at epoch 3: 30.706239700317383.


Epoch 3/10:  70%|███████   | 7/10 [00:17<00:06,  2.23s/batch, loss=30.7]

Training loss at epoch 3: 30.705705642700195.


Epoch 3/10:  80%|████████  | 8/10 [00:19<00:04,  2.20s/batch, loss=30.7]

Training loss at epoch 3: 30.708003997802734.


Epoch 3/10:  90%|█████████ | 9/10 [00:21<00:02,  2.17s/batch, loss=30.7]

Training loss at epoch 3: 30.709632873535156.


Epoch 3/10: 100%|██████████| 10/10 [00:23<00:00,  2.35s/batch, loss=30.7]


Training loss at epoch 3: 30.70476722717285.


Epoch 4/10:  10%|█         | 1/10 [00:02<00:18,  2.03s/batch, loss=30.7]

Training loss at epoch 4: 30.709274291992188.


Epoch 4/10:  20%|██        | 2/10 [00:04<00:16,  2.04s/batch, loss=30.7]

Training loss at epoch 4: 30.70480728149414.


Epoch 4/10:  30%|███       | 3/10 [00:06<00:14,  2.02s/batch, loss=30.7]

Training loss at epoch 4: 30.709020614624023.


Epoch 4/10:  40%|████      | 4/10 [00:08<00:12,  2.03s/batch, loss=30.7]

Training loss at epoch 4: 30.70981788635254.


Epoch 4/10:  50%|█████     | 5/10 [00:10<00:10,  2.03s/batch, loss=30.7]

Training loss at epoch 4: 30.707197189331055.


Epoch 4/10:  60%|██████    | 6/10 [00:12<00:08,  2.02s/batch, loss=30.7]

Training loss at epoch 4: 30.706157684326172.


Epoch 4/10:  70%|███████   | 7/10 [00:14<00:06,  2.02s/batch, loss=30.7]

Training loss at epoch 4: 30.70562171936035.


Epoch 4/10:  80%|████████  | 8/10 [00:16<00:04,  2.03s/batch, loss=30.7]

Training loss at epoch 4: 30.70792007446289.


Epoch 4/10:  90%|█████████ | 9/10 [00:18<00:02,  2.26s/batch, loss=30.7]

Training loss at epoch 4: 30.709543228149414.


Epoch 4/10: 100%|██████████| 10/10 [00:21<00:00,  2.10s/batch, loss=30.7]


Training loss at epoch 4: 30.70467185974121.


Epoch 5/10:  10%|█         | 1/10 [00:02<00:18,  2.05s/batch, loss=30.7]

Training loss at epoch 5: 30.709177017211914.


Epoch 5/10:  20%|██        | 2/10 [00:04<00:19,  2.42s/batch, loss=30.7]

Training loss at epoch 5: 30.704713821411133.


Epoch 5/10:  30%|███       | 3/10 [00:06<00:15,  2.24s/batch, loss=30.7]

Training loss at epoch 5: 30.70893096923828.


Epoch 5/10:  40%|████      | 4/10 [00:09<00:13,  2.30s/batch, loss=30.7]

Training loss at epoch 5: 30.709733963012695.


Epoch 5/10:  50%|█████     | 5/10 [00:11<00:10,  2.19s/batch, loss=30.7]

Training loss at epoch 5: 30.707109451293945.


Epoch 5/10:  60%|██████    | 6/10 [00:13<00:08,  2.14s/batch, loss=30.7]

Training loss at epoch 5: 30.706073760986328.


Epoch 5/10:  70%|███████   | 7/10 [00:15<00:06,  2.13s/batch, loss=30.7]

Training loss at epoch 5: 30.705541610717773.


Epoch 5/10:  80%|████████  | 8/10 [00:17<00:04,  2.13s/batch, loss=30.7]

Training loss at epoch 5: 30.707839965820312.


Epoch 5/10:  90%|█████████ | 9/10 [00:19<00:02,  2.13s/batch, loss=30.7]

Training loss at epoch 5: 30.70947265625.


Epoch 5/10: 100%|██████████| 10/10 [00:21<00:00,  2.17s/batch, loss=30.7]


Training loss at epoch 5: 30.70461082458496.


Epoch 6/10:  10%|█         | 1/10 [00:02<00:18,  2.11s/batch, loss=30.7]

Training loss at epoch 6: 30.70911407470703.


Epoch 6/10:  20%|██        | 2/10 [00:04<00:16,  2.11s/batch, loss=30.7]

Training loss at epoch 6: 30.70465850830078.


Epoch 6/10:  30%|███       | 3/10 [00:06<00:14,  2.11s/batch, loss=30.7]

Training loss at epoch 6: 30.708879470825195.


Epoch 6/10:  40%|████      | 4/10 [00:08<00:12,  2.11s/batch, loss=30.7]

Training loss at epoch 6: 30.709686279296875.


Epoch 6/10:  50%|█████     | 5/10 [00:10<00:10,  2.08s/batch, loss=30.7]

Training loss at epoch 6: 30.70707130432129.


Epoch 6/10:  60%|██████    | 6/10 [00:12<00:08,  2.06s/batch, loss=30.7]

Training loss at epoch 6: 30.706035614013672.


Epoch 6/10:  70%|███████   | 7/10 [00:14<00:06,  2.04s/batch, loss=30.7]

Training loss at epoch 6: 30.705507278442383.


Epoch 6/10:  80%|████████  | 8/10 [00:17<00:04,  2.20s/batch, loss=30.7]

Training loss at epoch 6: 30.707805633544922.


Epoch 6/10:  90%|█████████ | 9/10 [00:19<00:02,  2.14s/batch, loss=30.7]

Training loss at epoch 6: 30.709440231323242.


Epoch 6/10: 100%|██████████| 10/10 [00:21<00:00,  2.10s/batch, loss=30.7]


Training loss at epoch 6: 30.704578399658203.


Epoch 7/10:  10%|█         | 1/10 [00:02<00:18,  2.00s/batch, loss=30.7]

Training loss at epoch 7: 30.70908546447754.


Epoch 7/10:  20%|██        | 2/10 [00:04<00:16,  2.02s/batch, loss=30.7]

Training loss at epoch 7: 30.704633712768555.


Epoch 7/10:  30%|███       | 3/10 [00:06<00:14,  2.01s/batch, loss=30.7]

Training loss at epoch 7: 30.70885467529297.


Epoch 7/10:  40%|████      | 4/10 [00:08<00:12,  2.02s/batch, loss=30.7]

Training loss at epoch 7: 30.709657669067383.


Epoch 7/10:  50%|█████     | 5/10 [00:10<00:10,  2.02s/batch, loss=30.7]

Training loss at epoch 7: 30.707042694091797.


Epoch 7/10:  60%|██████    | 6/10 [00:12<00:08,  2.02s/batch, loss=30.7]

Training loss at epoch 7: 30.706016540527344.


Epoch 7/10:  70%|███████   | 7/10 [00:14<00:06,  2.02s/batch, loss=30.7]

Training loss at epoch 7: 30.705488204956055.


Epoch 7/10:  80%|████████  | 8/10 [00:16<00:04,  2.03s/batch, loss=30.7]

Training loss at epoch 7: 30.707796096801758.


Epoch 7/10:  90%|█████████ | 9/10 [00:18<00:02,  2.03s/batch, loss=30.7]

Training loss at epoch 7: 30.709423065185547.


Epoch 7/10: 100%|██████████| 10/10 [00:20<00:00,  2.03s/batch, loss=30.7]


Training loss at epoch 7: 30.70456314086914.


Epoch 8/10:  10%|█         | 1/10 [00:02<00:18,  2.01s/batch, loss=30.7]

Training loss at epoch 8: 30.70906639099121.


Epoch 8/10:  20%|██        | 2/10 [00:04<00:16,  2.02s/batch, loss=30.7]

Training loss at epoch 8: 30.704614639282227.


Epoch 8/10:  30%|███       | 3/10 [00:06<00:14,  2.00s/batch, loss=30.7]

Training loss at epoch 8: 30.70883560180664.


Epoch 8/10:  40%|████      | 4/10 [00:08<00:12,  2.02s/batch, loss=30.7]

Training loss at epoch 8: 30.709644317626953.


Epoch 8/10:  50%|█████     | 5/10 [00:10<00:10,  2.01s/batch, loss=30.7]

Training loss at epoch 8: 30.707027435302734.


Epoch 8/10:  60%|██████    | 6/10 [00:12<00:08,  2.17s/batch, loss=30.7]

Training loss at epoch 8: 30.70599365234375.


Epoch 8/10:  70%|███████   | 7/10 [00:14<00:06,  2.11s/batch, loss=30.7]

Training loss at epoch 8: 30.70546531677246.


Epoch 8/10:  80%|████████  | 8/10 [00:16<00:04,  2.08s/batch, loss=30.7]

Training loss at epoch 8: 30.707773208618164.


Epoch 8/10:  90%|█████████ | 9/10 [00:18<00:02,  2.05s/batch, loss=30.7]

Training loss at epoch 8: 30.70940399169922.


Epoch 8/10: 100%|██████████| 10/10 [00:20<00:00,  2.05s/batch, loss=30.7]


Training loss at epoch 8: 30.704544067382812.


Epoch 9/10:  10%|█         | 1/10 [00:02<00:18,  2.00s/batch, loss=30.7]

Training loss at epoch 9: 30.709049224853516.


Epoch 9/10:  20%|██        | 2/10 [00:04<00:16,  2.01s/batch, loss=30.7]

Training loss at epoch 9: 30.7045955657959.


Epoch 9/10:  30%|███       | 3/10 [00:06<00:14,  2.00s/batch, loss=30.7]

Training loss at epoch 9: 30.708816528320312.


Epoch 9/10:  40%|████      | 4/10 [00:08<00:12,  2.01s/batch, loss=30.7]

Training loss at epoch 9: 30.709623336791992.


Epoch 9/10:  50%|█████     | 5/10 [00:10<00:10,  2.01s/batch, loss=30.7]

Training loss at epoch 9: 30.707008361816406.


Epoch 9/10:  60%|██████    | 6/10 [00:12<00:08,  2.00s/batch, loss=30.7]

Training loss at epoch 9: 30.705974578857422.


Epoch 9/10:  70%|███████   | 7/10 [00:14<00:06,  2.01s/batch, loss=30.7]

Training loss at epoch 9: 30.7054500579834.


Epoch 9/10:  80%|████████  | 8/10 [00:16<00:04,  2.03s/batch, loss=30.7]

Training loss at epoch 9: 30.707754135131836.


Epoch 9/10:  90%|█████████ | 9/10 [00:18<00:02,  2.02s/batch, loss=30.7]

Training loss at epoch 9: 30.709392547607422.


Epoch 9/10: 100%|██████████| 10/10 [00:20<00:00,  2.03s/batch, loss=30.7]


Training loss at epoch 9: 30.704530715942383.


Epoch 10/10:  10%|█         | 1/10 [00:02<00:19,  2.13s/batch, loss=30.7]

Training loss at epoch 10: 30.709033966064453.


Epoch 10/10:  20%|██        | 2/10 [00:04<00:19,  2.43s/batch, loss=30.7]

Training loss at epoch 10: 30.704586029052734.


Epoch 10/10:  30%|███       | 3/10 [00:06<00:15,  2.28s/batch, loss=30.7]

Training loss at epoch 10: 30.708803176879883.


Epoch 10/10:  40%|████      | 4/10 [00:08<00:13,  2.22s/batch, loss=30.7]

Training loss at epoch 10: 30.709611892700195.


Epoch 10/10:  50%|█████     | 5/10 [00:11<00:10,  2.18s/batch, loss=30.7]

Training loss at epoch 10: 30.706995010375977.


Epoch 10/10:  60%|██████    | 6/10 [00:13<00:08,  2.16s/batch, loss=30.7]

Training loss at epoch 10: 30.705970764160156.


Epoch 10/10:  70%|███████   | 7/10 [00:15<00:06,  2.14s/batch, loss=30.7]

Training loss at epoch 10: 30.705440521240234.


Epoch 10/10:  80%|████████  | 8/10 [00:17<00:04,  2.14s/batch, loss=30.7]

Training loss at epoch 10: 30.707746505737305.


Epoch 10/10:  90%|█████████ | 9/10 [00:19<00:02,  2.13s/batch, loss=30.7]

Training loss at epoch 10: 30.709379196166992.


Epoch 10/10: 100%|██████████| 10/10 [00:21<00:00,  2.17s/batch, loss=30.7]

Training loss at epoch 10: 30.70452117919922.





In [27]:
# save model
TIME = '05_06__17_53'
torch.save(gnn_model.state_dict(), f'./gnn_model_weights_{TIME}.pth')

In [32]:
# Inference all the instances
gnn_model.eval()
emb_g2d = []

for item in data_list:
    b = torch.tensor(np.zeros(item.x.shape[0]), dtype=torch.long)
    emb_g2d.append(gnn_model(item.x, item.edge_index, b))

print(len(emb_g2d))

100000


In [34]:
torch.save(emb_g2d, f"./emb_g2d_{TIME}.pt")