In [1]:
import matplotlib.pyplot as plt

In [2]:
import torch

In [3]:
import numpy as np
import pandas as pd
from rdkit import Chem

import openpyxl
import os
from pathlib import Path

In [4]:
TOP = Path.cwd().as_posix().replace('notebooks','')
raw_dir = Path(TOP) / 'data'/'raw'
interim_dir = Path(TOP) / 'data'/'interim'
external_dir = Path(TOP) / 'data'/'external'
figures_dir = Path(TOP) / 'reports'/'figures/'
processed_dir = Path(TOP) / 'data'/'processed'

In [5]:
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.transforms import Compose

In [8]:
from torch_geometric.utils import subgraph

class RandomNodeDrop:
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, data):
        # Randomly mask nodes
        num_nodes = data.num_nodes
        mask = torch.rand(num_nodes) > self.p  # Keep nodes with probability (1-p)
        edge_index, _ = subgraph(mask, data.edge_index, relabel_nodes=True)
        
        # Update data object
        data.edge_index = edge_index
        if data.x is not None:  # Update node features if present
            data.x = data.x[mask]
        return data

In [9]:
class RandomEdgeDelete:
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, data):
        num_edges = data.edge_index.size(1)
        mask = torch.rand(num_edges) > self.p

        #print(f"Original edge_index shape: {data.edge_index.shape}")
        #print(f"Original edge_attr shape: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
        #print(f"Mask shape: {mask.shape}")

        # Apply mask to edge_index
        data.edge_index = data.edge_index[:, mask]

        # Apply mask to edge_attr (if present)
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr[mask]

        #print(f"New edge_index shape: {data.edge_index.shape}")
        #print(f"New edge_attr shape: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")

        return data


In [26]:
# Define a simple GCN model
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


In [7]:
def mol_to_graph(mol):
        """
        Converts an RDKit mol object into a graph representation with atom and bond features.
        
        :param mol: RDKit mol object representing a molecule.
        :return: PyTorch Geometric Data object containing graph representation of the molecule.
        """
        atom_features = []
        bond_features = []
        edge_index = []
        
        # Atom feature extraction
        for atom in mol.GetAtoms():
            atom_feature = [
                atom.GetAtomicNum(),                         # Atomic number
                atom.GetFormalCharge(),                      # Formal charge
                atom.GetHybridization().real,                # Hybridization
                atom.GetIsAromatic(),                        # Aromaticity
                atom.GetImplicitValence(),                   # Implicit valence
                atom.GetDegree(),                            # Number of bonds to other atoms
                atom.IsInRing(),                             # Is in ring
                atom.GetChiralTag(),                         # Chirality
                atom.GetExplicitValence(),                   # Explicit valence
                atom.GetNumRadicalElectrons(),               # Number of radical electrons
                atom.GetTotalNumHs(),                        # Total number of hydrogens
                atom.GetIsotope(),                           # Isotope (if any)
                atom.GetMass()                               # Atomic mass
            ]
            atom_features.append(atom_feature)
        
        # Bond feature extraction
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_index.extend([[i, j], [j, i]])  # Add bidirectional edges
            
            bond_feature = [
                bond.GetBondTypeAsDouble(),                  # Bond type (single, double, triple, aromatic)
                bond.GetIsConjugated(),                      # Is conjugated
                bond.IsInRing(),                             # Is in ring
                bond.GetStereo(),                            # Bond stereochemistry (cis/trans)
                bond.GetBondDir()                            # Bond direction (up/down in 3D)
            ]
            bond_features.extend([bond_feature, bond_feature])  # Add twice for bidirectional edges
        
        # Convert to tensors
        x = torch.tensor(atom_features, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(bond_features, dtype=torch.float)
        
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [8]:
# Graph contrastive learning loss
def contrastive_loss(z1, z2, temperature=0.5, hard_negatives=None):
    # Normalize embeddings
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)

    # Compute cosine similarity
    similarity_matrix = torch.mm(z1, z2.t())

    # Stabilize exp by subtracting max similarity
    max_similarity = similarity_matrix.max(dim=1, keepdim=True)[0]
    sim_exp = torch.exp((similarity_matrix - max_similarity) / temperature)

    # Contrastive loss
    if hard_negatives is not None:
        # Optionally focus on hard negatives
        hard_mask = hard_negatives.bool()
        negative_samples = sim_exp.masked_select(hard_mask).sum(dim=1)
    else:
        all_samples = sim_exp.sum(dim=1)
        positive_samples = torch.diag(sim_exp)
        loss = -torch.log(positive_samples / all_samples).mean()

    return loss




In [9]:
class GraphData(Dataset):
    def __init__(self, df):
        """
        GraphData class inheriting from the Dataset class in PyTorch.

        Parameters
        ----------
        df : pandas.DataFrame
            The dataframe containing the SMILES strings.
        """
        self.df = df
        self.valid_indices = self._get_valid_indices()

    def _get_valid_indices(self):
        valid_indices = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            smiles = row['smiles']
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    raise ValueError(f"Invalid SMILES: {smiles}")
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error processing SMILES {smiles}: {e}")
                continue  # Skip invalid SMILES
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """
        Returns the graph representation of the molecule.

        Parameters
        ----------
        idx : int
            Index of the valid sample.

        Returns
        -------
        graph : torch_geometric.data.Data
            The graph representation of the molecule.
        """
        # Use valid index to get the actual row from the dataframe
        valid_idx = self.valid_indices[idx]
        row = self.df.iloc[valid_idx]
        smiles = row['smiles']
        
        # Process the valid SMILES string
        mol = Chem.MolFromSmiles(smiles)
        graph = mol_to_graph(mol)
        
        return graph

In [10]:
df = pd.read_csv(external_dir/'dsstox_smiles_dec24.csv', index_col = [0])

In [10]:
df1 = pd.read_csv(external_dir/'toxcast_cleaned.csv', index_col = [0])

In [11]:
df1.head()

Unnamed: 0,INPUT,FOUND_BY,DTXSID,PREFERRED_NAME,CASRN,SMILES,QSAR_READY_SMILES,smiles
0,DTXSID3042423,DSSTox_Substance_Id,DTXSID3042423,Sucrose octaacetate,126-14-7,[H][C@]1(O[C@]2(COC(C)=O)O[C@H](COC(C)=O)[C@@H...,CC(=O)OCC1OC(COC(C)=O)(OC2OC(COC(C)=O)C(OC(C)=...,CC(=O)OCC1OC(COC(C)=O)(OC2OC(COC(C)=O)C(OC(C)=...
1,DTXSID3042477,DSSTox_Substance_Id,DTXSID3042477,Tolnaftate,2398-96-1,CN(C(=S)OC1=CC2=CC=CC=C2C=C1)C1=CC=CC(C)=C1,CN(C(=S)OC1=CC2=CC=CC=C2C=C1)C1=CC=CC(C)=C1,CN(C(=S)OC1=CC2=CC=CC=C2C=C1)C1=CC=CC(C)=C1
3,DTXSID3042508,DSSTox_Substance_Id,DTXSID3042508,Uric acid,69-93-2,O=C1NC2=C(N1)C(=O)NC(=O)N2,[H]N1C2C(=NC1=O)N([H])C(=O)N([H])C2=O,[H]N1C2C(=NC1=O)N([H])C(=O)N([H])C2=O
4,DTXSID3042631,DSSTox_Substance_Id,DTXSID3042631,"(1R,3S)-3-(3,4-Dichlorophenyl)-N-methyl-2,3-di...",96850-13-4,Cl.CN[C@@H]1C[C@H](C2=CC=CC=C12)C1=CC=C(Cl)C(C...,CNC1CC(C2=CC=CC=C12)C1=CC(Cl)=C(Cl)C=C1,CNC1CC(C2=CC=CC=C12)C1=CC(Cl)=C(Cl)C=C1
5,DTXSID3042633,DSSTox_Substance_Id,DTXSID3042633,Clomipramine hydrochloride,17321-77-6,Cl.CN(C)CCCN1C2=C(CCC3=C1C=C(Cl)C=C3)C=CC=C2,CN(C)CCCN1C2=CC=CC=C2CCC2=CC=C(Cl)C=C12,CN(C)CCCN1C2=CC=CC=C2CCC2=CC=C(Cl)C=C12


In [12]:
from sklearn.model_selection import train_test_split

In [13]:
train, test = train_test_split(df1, test_size = 0.2)

In [14]:
train.info()

<class 'pandas.core.frame.DataFrame'>
Index: 6847 entries, 4088 to 9476
Data columns (total 8 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   INPUT              6847 non-null   object
 1   FOUND_BY           6847 non-null   object
 2   DTXSID             6847 non-null   object
 3   PREFERRED_NAME     6847 non-null   object
 4   CASRN              6847 non-null   object
 5   SMILES             6847 non-null   object
 6   QSAR_READY_SMILES  6847 non-null   object
 7   smiles             6847 non-null   object
dtypes: object(8)
memory usage: 481.4+ KB


In [15]:
train_data = GraphData(train)

[22:07:40] Explicit valence for atom # 5 N, 5, is greater than permitted


Error processing SMILES OCC=C1C[NH]2(CC=C)CCC34C2CC1C1=CN2C5C(=CN(C31)C1=C4C=CC=C1)C1CC3C5(CC[NH]3(CC=C)CC1=CCO)C1=C2C=CC=C1: Invalid SMILES: OCC=C1C[NH]2(CC=C)CCC34C2CC1C1=CN2C5C(=CN(C31)C1=C4C=CC=C1)C1CC3C5(CC[NH]3(CC=C)CC1=CCO)C1=C2C=CC=C1


In [16]:
train_data

GraphData(6846)

In [17]:
test.shape

(1712, 8)

In [18]:
test_data = GraphData(test)

In [19]:
test_data

GraphData(1712)

In [20]:
train_data[0]

Data(x=[14, 13], edge_index=[2, 32], edge_attr=[32, 5])

In [45]:
class RandomSubgraph:
    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes

    def __call__(self, data):
        import random
        from torch_geometric.utils import subgraph

        # Select random nodes from the graph
        num_nodes = data.num_nodes
        if self.num_nodes >= num_nodes:
            return data  # Return the original graph if too few nodes to sample

        random_nodes = random.sample(range(num_nodes), self.num_nodes)

        # Generate subgraph
        edge_index, edge_mask = subgraph(
            random_nodes, edge_index=data.edge_index, relabel_nodes=True, num_nodes=num_nodes
        )
        # Create a new `Data` object to avoid modifying the input `data` in place
        subgraph_data = Data()

        # Update node features
        if data.x is not None:
            subgraph_data.x = data.x[random_nodes]

        # Update edge index
        subgraph_data.edge_index = edge_index

        # Update edge attributes if present
        if data.edge_attr is not None:
            subgraph_data.edge_attr = data.edge_attr[edge_mask]

        # Copy additional attributes (e.g., labels, targets)
        for key in data.keys():  # Corrected here
            if key not in ['x', 'edge_index', 'edge_attr']:
                subgraph_data[key] = data[key]

        return subgraph_data

In [21]:
from torch_geometric.transforms import RandomLinkSplit, RandomNodeSplit

In [25]:
transform = Compose([RandomNodeSplit(split="train_rest", num_train_per_class=5),RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)])

In [62]:
'''for i, graph in enumerate(train_data):
    print(f"Graph {i} before transforms:", graph)
    try:
        transformed_graph = transform(graph)
        print(f"Graph {i} after transforms:", transformed_graph)
    except Exception as e:
        print(f"Error with Graph {i}:", e)
        break

SyntaxError: incomplete input (3158717337.py, line 1)

In [69]:
for graph in train_data:
    g1 = RandomNodeSplit(split="train_rest", num_train_per_class=5)(graph)
    g2 = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)(graph)
    print("g1 type:", type(g1))
    print("g2 type:", type(g2))
    break

g1 type: <class 'torch_geometric.data.data.Data'>
g2 type: <class 'tuple'>


In [73]:
for graph in train_data:
    # Apply RandomNodeSplit to get one graph
   # g1 = RandomNodeSplit(split="train_rest", num_train_per_class=5)(graph)

    # Apply RandomLinkSplit to get a tuple of graphs
    split_graphs = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)(graph)
    
    # Access the positive and negative graphs from the tuple
    g2_pos = split_graphs[0]
    g2_neg = split_graphs[1]

    print("g2_pos type:", type(g2_pos))  # Positive links graph
    print("g2_neg type:", type(g2_neg))  # Negative links graph
    
    break

g2_pos type: <class 'torch_geometric.data.data.Data'>
g2_neg type: <class 'torch_geometric.data.data.Data'>


In [33]:
def collate_fn(batch):
    # Filter out None values from the batch
    batch = [item for item in batch if item is not None]
    
    if len(batch) == 0:
        return None  # If no valid data left, return None (or you could return an empty batch)
    
    # Otherwise, proceed with the batch processing (you might need to adapt for your graph data)
    return batch


In [22]:
train_loader = DataLoader(train_data, batch_size=300, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=300, shuffle=False, drop_last=True)




In [36]:
for data in train_loader:
    print(data)
    break

DataBatch(x=[5769, 13], edge_index=[2, 12084], edge_attr=[12084, 5], batch=[5769], ptr=[301])


In [23]:
def validate_model(model, val_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            g1 = node_split_transform(batch)
            g2_pos, g2_neg, _ = link_split_transform(batch)
            z1 = model(g1.x, g1.edge_index)
            z2 = model(g2_pos.x, g2_pos.edge_index)
            loss = contrastive_loss(z1, z2)
            total_loss += loss.item()
    return total_loss / len(val_loader)

In [28]:
# Define transforms
epochs = 100
model = GCN(input_dim=train_data[0].num_features, hidden_dim=32, output_dim=16)
node_split_transform = RandomNodeSplit(split="train_rest", num_train_per_class=5)
link_split_transform = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Model and optimizer



# Contrastive learning loop


for epoch in range(epochs):
    total_loss = 0

    for batch in train_loader:
        # Apply RandomNodeSplit for the first augmented graph
        
        g1 = node_split_transform(batch)

        # Apply RandomLinkSplit to get positive and negative graphs
        g2_pos, g2_neg, g2_val = link_split_transform(batch)

        # Forward pass with positive graph
        z1 = model(g1.x, g1.edge_index)  # Forward pass on the first graph (g1)
        z2 = model(g2_pos.x, g2_pos.edge_index)  # Forward pass on the second graph (g2_pos)

        # Compute contrastive loss with the positive graphs
        loss = contrastive_loss(z1, z2)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Validation step
    val_loss = validate_model(model, test_loader)

    print(f"Epoch {epoch+1}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
    
    scheduler.step()
# Backpropagation
        

    

Epoch 1, Train Loss: 8.5105, Val Loss: 8.4104
Epoch 2, Train Loss: 8.0451, Val Loss: 7.6298
Epoch 3, Train Loss: 7.6068, Val Loss: 7.5699
Epoch 4, Train Loss: 7.5448, Val Loss: 7.5250
Epoch 5, Train Loss: 7.5012, Val Loss: 7.4835
Epoch 6, Train Loss: 7.4615, Val Loss: 7.4409
Epoch 7, Train Loss: 7.4201, Val Loss: 7.4046
Epoch 8, Train Loss: 7.3823, Val Loss: 7.3703
Epoch 9, Train Loss: 7.3494, Val Loss: 7.3451
Epoch 10, Train Loss: 7.3274, Val Loss: 7.3299
Epoch 11, Train Loss: 7.3093, Val Loss: 7.3064
Epoch 12, Train Loss: 7.2936, Val Loss: 7.2866
Epoch 13, Train Loss: 7.2832, Val Loss: 7.2788
Epoch 14, Train Loss: 7.2794, Val Loss: 7.2844
Epoch 15, Train Loss: 7.2673, Val Loss: 7.2618
Epoch 16, Train Loss: 7.2580, Val Loss: 7.2576
Epoch 17, Train Loss: 7.2508, Val Loss: 7.2528
Epoch 18, Train Loss: 7.2436, Val Loss: 7.2471
Epoch 19, Train Loss: 7.2396, Val Loss: 7.2381
Epoch 20, Train Loss: 7.2337, Val Loss: 7.2344
Epoch 21, Train Loss: 7.2319, Val Loss: 7.2272
Epoch 22, Train Loss: 

RuntimeError: index 5291 is out of bounds for dimension 0 with size 5261

In [29]:
torch.save(model.state_dict(), 'model_weights_txcst.pth') 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import optuna

# Define the model
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Training function
def train_model(model, criterion, optimizer, train_loader, device):
    model.train()
    total_loss = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Validation function
def validate_model(model, criterion, val_loader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
    return total_loss / len(val_loader)

# Define the Optuna objective function
def objective(trial):
    # Hyperparameters to optimize
    hidden_size = trial.suggest_int("hidden_size", 16, 128)
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 64)

    # Data preparation
    input_size = 10  # Example input features
    output_size = 1  # Example output size (e.g., regression task)
    dataset_size = 1000
    x = torch.randn(dataset_size, input_size)
    y = torch.randn(dataset_size, output_size)
    dataset = TensorDataset(x, y)
    train_size = int(0.8 * dataset_size)
    val_size = dataset_size - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Model, loss, optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet(input_size, hidden_size, output_size).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training and validation
    for epoch in range(10):  # Small number of epochs for testing
        train_model(model, criterion, optimizer, train_loader, device)
    val_loss = validate_model(model, criterion, val_loader, device)

    return val_loss

# Create an Optuna study and optimize
if __name__ == "__main__":
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=50)

    # Print the best parameters
    print("Best hyperparameters:", study.best_params)