In [1]:
import torch

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
import numpy as np
import pandas as pd
from rdkit import Chem

import openpyxl
import os
from pathlib import Path


In [3]:
TOP = Path.cwd().as_posix().replace('notebooks','')
raw_dir = Path(TOP) / 'data'/'raw'
interim_dir = Path(TOP) / 'data'/'interim'
external_dir = Path(TOP) / 'data'/'external'
figures_dir = Path(TOP) / 'reports'/'figures/'
processed_dir = Path(TOP) / 'data'/'processed'


In [4]:
from torch_geometric.data import Data, Dataset, DataLoader

In [5]:
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms

from torch_scatter import scatter
from torch_geometric.data import Data, Dataset, DataLoader

import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem
import math
import random
from copy import deepcopy
import networkx as nx

In [6]:
ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    BT.SINGLE, 
    BT.DOUBLE, 
    BT.TRIPLE, 
    BT.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]


In [7]:
def remove_subgraph(Graph, center, percent=0.2):
    assert percent <= 1
    G = Graph.copy()
    num = int(np.floor(len(G.nodes)*percent))
    removed = []
    temp = [center]
    
    while len(removed) < num:
        neighbors = []
        if len(temp) < 1:
            break

        for n in temp:
            neighbors.extend([i for i in G.neighbors(n) if i not in temp])      
        for n in temp:
            if len(removed) < num:
                G.remove_node(n)
                removed.append(n)
            else:
                break

        temp = list(set(neighbors))
    return G, removed

In [8]:
class GraphData(Dataset):
    def __init__(self, df):
        """
        GraphData class inheriting from the Dataset class in PyTorch.

        Parameters
        ----------
        df : pandas.DataFrame
            The dataframe containing the SMILES strings.
        """
        self.df = df
        self.valid_indices = self._get_valid_indices()

    def _get_valid_indices(self):
        valid_indices = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            smiles = row['smiles']
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    raise ValueError(f"Invalid SMILES: {smiles}")
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error processing SMILES {smiles}: {e}")
                continue  # Skip invalid SMILES
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """
        Returns the graph representation of the molecule.

        Parameters
        ----------
        idx : int
            Index of the valid sample.

        Returns
        -------
        graph : torch_geometric.data.Data
            The graph representation of the molecule.
        """
        # Use valid index to get the actual row from the dataframe
        valid_idx = self.valid_indices[idx]
        row = self.df.iloc[valid_idx]
        smiles = row['smiles']
        
        # Process the valid SMILES string
        mol = Chem.MolFromSmiles(smiles)
        mol = Chem.AddHs(mol)
        
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()
        atoms = mol.GetAtoms()
        bonds = mol.GetBonds()

        #########################
        # Get the molecule info #
        #########################
        type_idx = []
        chirality_idx = []
        atomic_number = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        x = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

        ####################
        # Subgraph Masking #
        ####################

        # Construct the original molecular graph from edges (bonds)
        edges = []
        for bond in bonds:
            edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
        molGraph = nx.Graph(edges)
        
        # Get the graph for i and j after removing subgraphs
        start_i, start_j = random.sample(list(range(N)), 2)
        percent_i, percent_j = random.uniform(0, 0.2), random.uniform(0, 0.2)
        G_i, removed_i = remove_subgraph(molGraph, start_i, percent=percent_i)
        G_j, removed_j = remove_subgraph(molGraph, start_j, percent=percent_j)

        atom_remain_indices_i = [i for i in range(N) if i not in removed_i]
        atom_remain_indices_j = [i for i in range(N) if i not in removed_j]
        
        # Only consider bond still exist after removing subgraph
        row_i, col_i, row_j, col_j = [], [], [], []
        edge_feat_i, edge_feat_j = [], []
        G_i_edges = list(G_i.edges)
        G_j_edges = list(G_j.edges)

        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feature = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            if (start, end) in G_i_edges or (end, start) in G_i_edges:
                row_i += [start, end]
                col_i += [end, start]
                edge_feat_i.append(feature)
                edge_feat_i.append(feature)
            if (start, end) in G_j_edges or (end, start) in G_j_edges:
                row_j += [start, end]
                col_j += [end, start]
                edge_feat_j.append(feature)
                edge_feat_j.append(feature)
        
        edge_index_i = torch.tensor([row_i, col_i], dtype=torch.long)
        edge_attr_i = torch.tensor(np.array(edge_feat_i), dtype=torch.long)
        edge_index_j = torch.tensor([row_j, col_j], dtype=torch.long)
        edge_attr_j = torch.tensor(np.array(edge_feat_j), dtype=torch.long)

        ############################
        # Random Atom/Edge Masking #
        ############################

        num_mask_nodes_i = max([0, math.floor(0.25*N)-len(removed_i)])
        num_mask_edges_i = max([0, edge_attr_i.size(0)//2 - math.ceil(0.75*M)])
        num_mask_nodes_j = max([0, math.floor(0.25*N)-len(removed_j)])
        num_mask_edges_j = max([0, edge_attr_j.size(0)//2 - math.ceil(0.75*M)])
        mask_nodes_i = random.sample(atom_remain_indices_i, num_mask_nodes_i)
        mask_nodes_j = random.sample(atom_remain_indices_j, num_mask_nodes_j)
        mask_edges_i_single = random.sample(list(range(edge_attr_i.size(0)//2)), num_mask_edges_i)
        mask_edges_j_single = random.sample(list(range(edge_attr_j.size(0)//2)), num_mask_edges_j)
        mask_edges_i = [2*i for i in mask_edges_i_single] + [2*i+1 for i in mask_edges_i_single]
        mask_edges_j = [2*i for i in mask_edges_j_single] + [2*i+1 for i in mask_edges_j_single]

        x_i = deepcopy(x)
        for atom_idx in range(N):
            if (atom_idx in mask_nodes_i) or (atom_idx in removed_i):
                x_i[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_final_i = torch.zeros((2, edge_attr_i.size(0) - 2*num_mask_edges_i), dtype=torch.long)
        edge_attr_final_i = torch.zeros((edge_attr_i.size(0) - 2*num_mask_edges_i, 2), dtype=torch.long)
        count = 0
        for bond_idx in range(edge_attr_i.size(0)):
            if bond_idx not in mask_edges_i:
                edge_index_final_i[:,count] = edge_index_i[:,bond_idx]
                edge_attr_final_i[count,:] = edge_attr_i[bond_idx,:]
                count += 1
        data_i = Data(x=x_i, edge_index=edge_index_final_i, edge_attr=edge_attr_final_i)

        x_j = deepcopy(x)
        for atom_idx in range(N):
            if (atom_idx in mask_nodes_j) or (atom_idx in removed_j):
                x_j[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_final_j = torch.zeros((2, edge_attr_j.size(0) - 2*num_mask_edges_j), dtype=torch.long)
        edge_attr_final_j = torch.zeros((edge_attr_j.size(0) - 2*num_mask_edges_j, 2), dtype=torch.long)
        count = 0
        for bond_idx in range(edge_attr_j.size(0)):
            if bond_idx not in mask_edges_j:
                edge_index_final_j[:,count] = edge_index_j[:,bond_idx]
                edge_attr_final_j[count,:] = edge_attr_j[bond_idx,:]
                count += 1
        data_j = Data(x=x_j, edge_index=edge_index_final_j, edge_attr=edge_attr_final_j)
        
        return data_i, data_j
        

In [9]:
class GraphData_unaug(Dataset):
    def __init__(self, df):
        """
        GraphData class inheriting from the Dataset class in PyTorch.

        Parameters
        ----------
        df : pandas.DataFrame
            The dataframe containing the SMILES strings.
        """
        self.df = df
        self.valid_indices = self._get_valid_indices()

    def _get_valid_indices(self):
        valid_indices = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            smiles = row['smiles']
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    raise ValueError(f"Invalid SMILES: {smiles}")
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error processing SMILES {smiles}: {e}")
                continue  # Skip invalid SMILES
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """
        Returns the graph representation of the molecule.

        Parameters
        ----------
        idx : int
            Index of the valid sample.

        Returns
        -------
        graph : torch_geometric.data.Data
            The graph representation of the molecule.
        """
        # Use valid index to get the actual row from the dataframe
        valid_idx = self.valid_indices[idx]
        row = self.df.iloc[valid_idx]
        smiles = row['smiles']
        
        # Process the valid SMILES string
        mol = Chem.MolFromSmiles(smiles)
        mol = Chem.AddHs(mol)
        
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()

        type_idx = []
        chirality_idx = []
        atomic_number = []
        # aromatic = []
        # sp, sp2, sp3, sp3d = [], [], [], []
        # num_hs = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
            # aromatic.append(1 if atom.GetIsAromatic() else 0)
            # hybridization = atom.GetHybridization()
            # sp.append(1 if hybridization == HybridizationType.SP else 0)
            # sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
            # sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
            # sp3d.append(1 if hybridization == HybridizationType.SP3D else 0)

        # z = torch.tensor(atomic_number, dtype=torch.long)
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        x = torch.cat([x1, x2], dim=-1)
        # x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, sp3d, num_hs],
        #                     dtype=torch.float).t().contiguous()
        # x = torch.cat([x1.to(torch.float), x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge_type += 2 * [MOL_BONDS[bond.GetBondType()]]
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

        # random mask a subgraph of the molecule
        num_mask_nodes = max([1, math.floor(0.25*N)])
        num_mask_edges = max([0, math.floor(0.25*M)])
        mask_nodes_i = random.sample(list(range(N)), num_mask_nodes)
        mask_nodes_j = random.sample(list(range(N)), num_mask_nodes)
        mask_edges_i_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges_j_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges_i = [2*i for i in mask_edges_i_single] + [2*i+1 for i in mask_edges_i_single]
        mask_edges_j = [2*i for i in mask_edges_j_single] + [2*i+1 for i in mask_edges_j_single]

        x_i = deepcopy(x)
        for atom_idx in mask_nodes_i:
            x_i[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_i = torch.zeros((2, 2*(M-num_mask_edges)), dtype=torch.long)
        edge_attr_i = torch.zeros((2*(M-num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2*M):
            if bond_idx not in mask_edges_i:
                edge_index_i[:,count] = edge_index[:,bond_idx]
                edge_attr_i[count,:] = edge_attr[bond_idx,:]
                count += 1
        data_i = Data(x=x_i, edge_index=edge_index_i, edge_attr=edge_attr_i)

        x_j = deepcopy(x)
        for atom_idx in mask_nodes_j:
            x_j[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_j = torch.zeros((2, 2*(M-num_mask_edges)), dtype=torch.long)
        edge_attr_j = torch.zeros((2*(M-num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2*M):
            if bond_idx not in mask_edges_j:
                edge_index_j[:,count] = edge_index[:,bond_idx]
                edge_attr_j[count,:] = edge_attr[bond_idx,:]
                count += 1
        data_j = Data(x=x_j, edge_index=edge_index_j, edge_attr=edge_attr_j)
        
        return data_i, data_j

In [10]:
from torch.optim.lr_scheduler import CosineAnnealingLR

In [11]:
import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

In [12]:
class NTXentLoss(torch.nn.Module):

    def __init__(self, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        labels = torch.zeros(2 * self.batch_size).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)

In [13]:
criterion = NTXentLoss(batch_size = 512, temperature = 0.1, use_cosine_similarity=True )

In [14]:
num_atom_type = 119 # including the extra mask tokens
num_chirality_tag = 3

num_bond_type = 5 # including aromatic and self-loop edge
num_bond_direction = 3 

In [15]:
class GINEConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GINEConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2*emb_dim), 
            nn.ReLU(), 
            nn.Linear(2*emb_dim, emb_dim)
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]

        # add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


class GINet(nn.Module):
    """
    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat
    Output:
        node representations
    """
    def __init__(self, num_layer=5, emb_dim=300, feat_dim=512, drop_ratio=0, pool='mean'):
        super(GINet, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # List of MLPs
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))
        
        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)

        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim), 
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim//2)
        )

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        for layer in range(self.num_layer):
            h = self.gnns[layer](h, edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        h = self.pool(h, data.batch)
        h = self.feat_lin(h)
        out = self.out_lin(h)
        
        return h, out

In [16]:
df1 = pd.read_csv(external_dir/'toxcast_cleaned.csv', index_col = [0])

In [17]:
from sklearn.model_selection import train_test_split

In [18]:
train, test = train_test_split(df1, test_size = 0.2)

In [19]:
train_data = GraphData(train)

[10:28:20] Explicit valence for atom # 5 N, 5, is greater than permitted


Error processing SMILES OCC=C1C[NH]2(CC=C)CCC34C2CC1C1=CN2C5C(=CN(C31)C1=C4C=CC=C1)C1CC3C5(CC[NH]3(CC=C)CC1=CCO)C1=C2C=CC=C1: Invalid SMILES: OCC=C1C[NH]2(CC=C)CCC34C2CC1C1=CN2C5C(=CN(C31)C1=C4C=CC=C1)C1CC3C5(CC[NH]3(CC=C)CC1=CCO)C1=C2C=CC=C1


In [20]:
train_loader = DataLoader(train_data, batch_size=512, shuffle=True, drop_last=True)



In [21]:
test_data = GraphData(test)
test_loader = DataLoader(test_data, batch_size=512, shuffle=False, drop_last=True)

In [22]:
rax = (pd.read_csv(interim_dir/'icf_processed_220125.csv')
       .rename(columns = {'QSAR Ready SMILES':'smiles'})
 .query('smiles.notnull()')
)

In [23]:
rax.info()

rax_data = GraphData_unaug(rax)
rax_loader = DataLoader(rax_data, batch_size=1, shuffle=False, drop_last=True)

<class 'pandas.core.frame.DataFrame'>
Index: 755 entries, 0 to 796
Data columns (total 12 columns):
 #   Column                        Non-Null Count  Dtype 
---  ------                        --------------  ----- 
 0   Unnamed: 0                    755 non-null    int64 
 1   Substance Name in Assessment  755 non-null    object
 2   substance_role                755 non-null    object
 3   Approach                      755 non-null    object
 4   dtxsid                        755 non-null    object
 5   Preferred Name                750 non-null    object
 6   CASRN                         753 non-null    object
 7   SMILES                        755 non-null    object
 8   smiles                        755 non-null    object
 9   analogue_evidence_stream      678 non-null    object
 10  use_case                      755 non-null    object
 11  source                        697 non-null    object
dtypes: int64(1), object(11)
memory usage: 76.7+ KB




In [24]:
model = GINet(feat_dim=512)

In [25]:
print(f"Feature Dimension: {model.feat_dim}")

Feature Dimension: 512


In [26]:
optimizer = torch.optim.Adam(model.parameters(), lr= 0.0005, weight_decay=1e-5   )

In [27]:
from torch.optim.lr_scheduler import SequentialLR, CosineAnnealingLR, LambdaLR

# Warm-up for the first 10 epochs
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: epoch / 10.0 if epoch < 10 else 1.0)

# Cosine decay for the remaining epochs
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

# Combine the warm-up and cosine schedulers
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[10])

In [28]:
def validate_model(model, val_loader):
    
    with torch.no_grad():
        model.eval()
        total_loss = 0
        for bn, (xis, xjs) in enumerate(val_loader):
            # Forward pass
            ris, zis = model(xis)
            rjs, zjs = model(xjs)

            # Normalize projections
            zis = F.normalize(zis, dim=1)
            zjs = F.normalize(zjs, dim=1)
            # Compute loss
            loss = criterion(zis, zjs)
            total_loss += loss.item()
        return total_loss / len(val_loader)

In [None]:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0, path='best_model.pt', verbose=False):
        """
        Args:
            patience (int): How many epochs to wait after last improvement.
            min_delta (float): Minimum change in monitored metric to qualify as an improvement.
            path (str): Path to save the best model.
            verbose (bool): Whether to print updates about early stopping.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.path = path
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
            if self.verbose:
                print(f"Validation loss improved to {val_loss:.4f}. Model saved!")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement in validation loss for {self.counter}/{self.patience} epochs.")
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, model):
        """Saves the current best model."""
        torch.save(model.state_dict(), self.path)

In [None]:
# Initialize early stopping
early_stopping = EarlyStopping(patience=10, path="best_model.pt", verbose=True)

for epoch in range(epochs):
    # Training step
    model.train()
    train_loss = 0.0
    for (xis, xjs) in train_loader:
        optimizer.zero_grad()
        xis, xjs = xis.to(device), xjs.to(device)
        loss = model._step(xis, xjs)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for (xis, xjs) in valid_loader:
            xis, xjs = xis.to(device), xjs.to(device)
            loss = model._step(xis, xjs)
            val_loss += loss.item()
    val_loss /= len(valid_loader)

    print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Validation Loss = {val_loss:.4f}")

    # Check early stopping
    early_stopping(val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered. Training terminated.")
        break

# Load the best model after early stopping
model.load_state_dict(torch.load("best_model.pt"))
print("Loaded the best model based on validation loss.")

In [29]:
epochs = 50
embeddings_list = []
for epoch in range(epochs):
    model.train()  # Ensure model is in training mode

    for bn, (xis, xjs) in enumerate(train_loader):
        # Forward pass
        ris, zis = model(xis)
        rjs, zjs = model(xjs)
        
        # Normalize projections
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        embeddings_list.append((F.normalize(ris, dim=1), F.normalize(rjs, dim=1))) 
        # Compute loss
        loss = criterion(zis, zjs)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    val_loss = validate_model(model, test_loader)

    # Step the learning rate scheduler at the end of the epoch
    scheduler.step()

    # Logging (optional, for monitoring purposes)
    current_lr = scheduler.get_last_lr()[0]  # Get current learning rate
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}, LR: {current_lr:.6f}, Val Loss: {val_loss:.4f}")

Epoch [1/50], Loss: 6.0916, LR: 0.000050, Val Loss: 6.7507
Epoch [2/50], Loss: 4.7915, LR: 0.000100, Val Loss: 5.7389
Epoch [3/50], Loss: 3.9567, LR: 0.000150, Val Loss: 4.3158
Epoch [4/50], Loss: 3.4778, LR: 0.000200, Val Loss: 4.0385
Epoch [5/50], Loss: 3.3598, LR: 0.000250, Val Loss: 3.7401
Epoch [6/50], Loss: 3.0365, LR: 0.000300, Val Loss: 3.8215
Epoch [7/50], Loss: 2.8599, LR: 0.000350, Val Loss: 3.7542
Epoch [8/50], Loss: 2.8086, LR: 0.000400, Val Loss: 3.2139
Epoch [9/50], Loss: 2.6557, LR: 0.000450, Val Loss: 3.4889




Epoch [10/50], Loss: 2.6599, LR: 0.000500, Val Loss: 3.0659
Epoch [11/50], Loss: 2.6467, LR: 0.000500, Val Loss: 2.9623
Epoch [12/50], Loss: 2.4279, LR: 0.000498, Val Loss: 3.1351
Epoch [13/50], Loss: 2.4089, LR: 0.000496, Val Loss: 2.6694
Epoch [14/50], Loss: 2.3276, LR: 0.000492, Val Loss: 2.6286
Epoch [15/50], Loss: 2.3061, LR: 0.000488, Val Loss: 2.5881
Epoch [16/50], Loss: 2.2423, LR: 0.000482, Val Loss: 2.5448
Epoch [17/50], Loss: 2.1230, LR: 0.000476, Val Loss: 2.5555
Epoch [18/50], Loss: 2.1612, LR: 0.000469, Val Loss: 2.3313
Epoch [19/50], Loss: 2.1016, LR: 0.000461, Val Loss: 2.3231
Epoch [20/50], Loss: 2.1994, LR: 0.000452, Val Loss: 2.3563
Epoch [21/50], Loss: 2.1208, LR: 0.000443, Val Loss: 2.3305
Epoch [22/50], Loss: 2.1769, LR: 0.000432, Val Loss: 2.1813
Epoch [23/50], Loss: 2.0791, LR: 0.000421, Val Loss: 2.2195
Epoch [24/50], Loss: 1.9244, LR: 0.000409, Val Loss: 2.2032
Epoch [25/50], Loss: 1.8743, LR: 0.000397, Val Loss: 2.1274
Epoch [26/50], Loss: 1.9570, LR: 0.00038

In [30]:
embeddings_list[0][1].shape

torch.Size([512, 512])

In [31]:
torch.save(model.state_dict(), 'model_weights_txcst.pth') 

In [32]:
def extract_embeddings(model, test_loader):
    model.eval()
    all_embeddings = []

    with torch.no_grad():
        for batch_idx, (xis, _) in enumerate(test_loader):
            
            
            # Forward pass
            ris, _ = model(xis)  # Extract embeddings
            
            # Normalize
            
            ris = F.normalize(ris, dim=1)

            
            all_embeddings.append(ris.cpu().numpy())  

    # Stack all embeddings
    all_embeddings = np.vstack(all_embeddings)
    print(f"Final embeddings shape: {all_embeddings.shape}")  
    
    return all_embeddings



In [33]:
rax_embeddings = extract_embeddings(model, rax_loader)

Final embeddings shape: (755, 512)


In [34]:
len(rax_embeddings)

755

In [35]:
len(rax_loader)

755

In [36]:
print(len(rax_loader.dataset))  #

755


In [39]:
print(rax_loader.dataset)

GraphData_unaug(755)


In [37]:
print(rax_loader.batch_size) 

1


In [42]:
rax.head(1)

Unnamed: 0.1,Unnamed: 0,Substance Name in Assessment,substance_role,Approach,dtxsid,Preferred Name,CASRN,SMILES,smiles,analogue_evidence_stream,use_case,source
0,0,Chlorobenzene,Target,Category,DTXSID4020298,Chlorobenzene,108-90-7,ClC1=CC=CC=C1,ClC1=CC=CC=C1,Structural_ChemMine-tools_MCS-Tanimoto | Physc...,technical_guidance,OECD IATA case study


In [49]:
rax_embeddings_df = (pd.DataFrame(rax_embeddings, index = rax.dtxsid)
 .rename(columns =lambda x: f'gcn{x}')

        )


In [50]:
rax_embeddings_df.to_csv(interim_dir/'rax_txcst_embeddings.csv')