In [1]:
import matplotlib.pyplot as plt

In [2]:
import torch

In [3]:
import numpy as np
import pandas as pd
from rdkit import Chem

import openpyxl
import os
from pathlib import Path

In [4]:
TOP = Path.cwd().as_posix().replace('notebooks','')
raw_dir = Path(TOP) / 'data'/'raw'
interim_dir = Path(TOP) / 'data'/'interim'
external_dir = Path(TOP) / 'data'/'external'
figures_dir = Path(TOP) / 'reports'/'figures/'
processed_dir = Path(TOP) / 'data'/'processed'

In [5]:
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.transforms import Compose

In [8]:
from torch_geometric.utils import subgraph

class RandomNodeDrop:
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, data):
        # Randomly mask nodes
        num_nodes = data.num_nodes
        mask = torch.rand(num_nodes) > self.p  # Keep nodes with probability (1-p)
        edge_index, _ = subgraph(mask, data.edge_index, relabel_nodes=True)
        
        # Update data object
        data.edge_index = edge_index
        if data.x is not None:  # Update node features if present
            data.x = data.x[mask]
        return data

In [9]:
class RandomEdgeDelete:
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, data):
        num_edges = data.edge_index.size(1)
        mask = torch.rand(num_edges) > self.p

        #print(f"Original edge_index shape: {data.edge_index.shape}")
        #print(f"Original edge_attr shape: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
        #print(f"Mask shape: {mask.shape}")

        # Apply mask to edge_index
        data.edge_index = data.edge_index[:, mask]

        # Apply mask to edge_attr (if present)
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr[mask]

        #print(f"New edge_index shape: {data.edge_index.shape}")
        #print(f"New edge_attr shape: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")

        return data


In [6]:
# Define a simple GCN model
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


In [7]:
def mol_to_graph(mol):
        """
        Converts an RDKit mol object into a graph representation with atom and bond features.
        
        :param mol: RDKit mol object representing a molecule.
        :return: PyTorch Geometric Data object containing graph representation of the molecule.
        """
        atom_features = []
        bond_features = []
        edge_index = []
        
        # Atom feature extraction
        for atom in mol.GetAtoms():
            atom_feature = [
                atom.GetAtomicNum(),                         # Atomic number
                atom.GetFormalCharge(),                      # Formal charge
                atom.GetHybridization().real,                # Hybridization
                atom.GetIsAromatic(),                        # Aromaticity
                atom.GetImplicitValence(),                   # Implicit valence
                atom.GetDegree(),                            # Number of bonds to other atoms
                atom.IsInRing(),                             # Is in ring
                atom.GetChiralTag(),                         # Chirality
                atom.GetExplicitValence(),                   # Explicit valence
                atom.GetNumRadicalElectrons(),               # Number of radical electrons
                atom.GetTotalNumHs(),                        # Total number of hydrogens
                atom.GetIsotope(),                           # Isotope (if any)
                atom.GetMass()                               # Atomic mass
            ]
            atom_features.append(atom_feature)
        
        # Bond feature extraction
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_index.extend([[i, j], [j, i]])  # Add bidirectional edges
            
            bond_feature = [
                bond.GetBondTypeAsDouble(),                  # Bond type (single, double, triple, aromatic)
                bond.GetIsConjugated(),                      # Is conjugated
                bond.IsInRing(),                             # Is in ring
                bond.GetStereo(),                            # Bond stereochemistry (cis/trans)
                bond.GetBondDir()                            # Bond direction (up/down in 3D)
            ]
            bond_features.extend([bond_feature, bond_feature])  # Add twice for bidirectional edges
        
        # Convert to tensors
        x = torch.tensor(atom_features, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(bond_features, dtype=torch.float)
        
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [8]:
# Graph contrastive learning loss
def contrastive_loss(z1, z2, temperature=0.5):
    # Normalize embeddings
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)

    # Compute cosine similarity
    similarity_matrix = torch.mm(z1, z2.t())

    # Exponential scaling
    sim_exp = torch.exp(similarity_matrix / temperature)

    # Contrastive loss
    positive_samples = torch.diag(sim_exp)
    all_samples = sim_exp.sum(dim=1)
    loss = -torch.log(positive_samples / all_samples).mean()
    return loss





In [9]:
class GraphData(Dataset):
    def __init__(self, df):
        """
        GraphData class inheriting from the Dataset class in PyTorch.

        Parameters
        ----------
        df : pandas.DataFrame
            The dataframe containing the SMILES strings.
        """
        self.df = df
        self.valid_indices = self._get_valid_indices()

    def _get_valid_indices(self):
        valid_indices = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            smiles = row['smiles']
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    raise ValueError(f"Invalid SMILES: {smiles}")
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error processing SMILES {smiles}: {e}")
                continue  # Skip invalid SMILES
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """
        Returns the graph representation of the molecule.

        Parameters
        ----------
        idx : int
            Index of the valid sample.

        Returns
        -------
        graph : torch_geometric.data.Data
            The graph representation of the molecule.
        """
        # Use valid index to get the actual row from the dataframe
        valid_idx = self.valid_indices[idx]
        row = self.df.iloc[valid_idx]
        smiles = row['smiles']
        
        # Process the valid SMILES string
        mol = Chem.MolFromSmiles(smiles)
        graph = mol_to_graph(mol)
        
        return graph

In [10]:
df = pd.read_csv(external_dir/'dsstox_smiles_dec24.csv', index_col = [0])

In [11]:
df.head()

Unnamed: 0,dtxsid,smiles
1,DTXSID70419842,CC1=NC2=C(C(C)=C1)C(=O)N=C(N2)SCC(=O)NC1=CC=CC=C1
2,DTXSID30419843,COC(=O)CC1=CC(=O)OC2=C1C=CC(O)=C2
3,DTXSID90419844,CC1=C(C(=O)C2=C(O)C=C(O)C=C2O1)C1=CC=CC=C1
4,DTXSID50419845,CCC1=CC(=O)OC2=C1C=CC(O)=C2
5,DTXSID10419846,COC1=CC=C(C=C1)C1=CC(=O)OC2=C1C=CC(O)=C2C


In [12]:
from sklearn.model_selection import train_test_split

In [13]:
train, test = train_test_split(df, test_size = 0.2)

In [14]:
train.info()

<class 'pandas.core.frame.DataFrame'>
Index: 434397 entries, 34870 to 66164
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype 
---  ------  --------------   ----- 
 0   dtxsid  434397 non-null  object
 1   smiles  434397 non-null  object
dtypes: object(2)
memory usage: 9.9+ MB


In [15]:
train_data = GraphData(train)

[17:17:56] Explicit valence for atom # 17 N, 5, is greater than permitted


Error processing SMILES CC1(C)CC(CC(C)(C)N1O)O[P-]([O-])(N1CC1)#[N]1CC1: Invalid SMILES: CC1(C)CC(CC(C)(C)N1O)O[P-]([O-])(N1CC1)#[N]1CC1


[17:17:56] Explicit valence for atom # 5 C, 6, is greater than permitted
[17:17:56] Explicit valence for atom # 11 N, 5, is greater than permitted


Error processing SMILES CC(=O)OC[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(COC(C)=O)[BH]475%11: Invalid SMILES: CC(=O)OC[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(COC(C)=O)[BH]475%11
Error processing SMILES O=C(C(C(=O)C1=CC=CC=C1)=[N]1=CC=CC=C1)C1=CC=CC=C1: Invalid SMILES: O=C(C(C(=O)C1=CC=CC=C1)=[N]1=CC=CC=C1)C1=CC=CC=C1


[17:18:00] Explicit valence for atom # 1 N, 5, is greater than permitted


Error processing SMILES C[N](C)(C)=NC(O)=NC1=CC=CC=C1: Invalid SMILES: C[N](C)(C)=NC(O)=NC1=CC=CC=C1


[17:18:01] Explicit valence for atom # 4 Cl, 3, is greater than permitted


Error processing SMILES C[Se](C)(C)#[Cl]: Invalid SMILES: C[Se](C)(C)#[Cl]


[17:18:02] Explicit valence for atom # 6 N, 5, is greater than permitted


Error processing SMILES CN(C)C1=CC=[N](C=C1)=C1C(=O)N(C2=CC=CC=C2)C2=CC=CC=C2C1=O: Invalid SMILES: CN(C)C1=CC=[N](C=C1)=C1C(=O)N(C2=CC=CC=C2)C2=CC=CC=C2C1=O


[17:18:05] Explicit valence for atom # 4 Cl, 5, is greater than permitted


Error processing SMILES CC(CO[Cl-](COCC(F)([N+]([O-])=O)[N+]([O-])=O)(OCC(C)([N+]([O-])=O)[N+]([O-])=O)C(Cl)(Cl)SSC(Cl)(Cl)[Cl-](COCC(F)([N+]([O-])=O)[N+]([O-])=O)(OCC(C)([N+]([O-])=O)[N+]([O-])=O)OCC(C)([N+]([O-])=O)[N+]([O-])=O)([N+]([O-])=O)[N+]([O-])=O: Invalid SMILES: CC(CO[Cl-](COCC(F)([N+]([O-])=O)[N+]([O-])=O)(OCC(C)([N+]([O-])=O)[N+]([O-])=O)C(Cl)(Cl)SSC(Cl)(Cl)[Cl-](COCC(F)([N+]([O-])=O)[N+]([O-])=O)(OCC(C)([N+]([O-])=O)[N+]([O-])=O)OCC(C)([N+]([O-])=O)[N+]([O-])=O)([N+]([O-])=O)[N+]([O-])=O


[17:18:05] Explicit valence for atom # 1 N, 5, is greater than permitted


Error processing SMILES C[N](C)(C)=NC(=O)C1=CC=CC=C1C(=O)N=[N](C)(C)C: Invalid SMILES: C[N](C)(C)=NC(=O)C1=CC=CC=C1C(=O)N=[N](C)(C)C


[17:18:06] Explicit valence for atom # 17 N, 5, is greater than permitted


Error processing SMILES [S-][P-](OC1=C(Cl)C=C(Cl)C(Cl)=C1)(N1CCCC1)#[N]1CCCC1: Invalid SMILES: [S-][P-](OC1=C(Cl)C=C(Cl)C(Cl)=C1)(N1CCCC1)#[N]1CCCC1


[17:18:08] Explicit valence for atom # 1 B, 4, is greater than permitted
[17:18:08] Explicit valence for atom # 4 C, 6, is greater than permitted


Error processing SMILES F[B]1(F)N2C=CC=C2C=C2C=CC=[N+]12: Invalid SMILES: F[B]1(F)N2C=CC=C2C=C2C=CC=[N+]12
Error processing SMILES OC(=O)C[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11: Invalid SMILES: OC(=O)C[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11


[17:18:09] Explicit valence for atom # 16 N, 5, is greater than permitted
[17:18:09] Explicit valence for atom # 0 O, 3, is greater than permitted


Error processing SMILES [O-][P-](OC1=CC2=CC=CC=C2C=C1)(N1CC1)#[N]1CC1: Invalid SMILES: [O-][P-](OC1=CC2=CC=CC=C2C=C1)(N1CC1)#[N]1CC1
Error processing SMILES [O-]([Si](C1=CC=CC=C1)(C1=CC=CC=C1)C1=CC=CC=C1)[Si+4]123([O-][Si](C4=CC=CC=C4)(C4=CC=CC=C4)C4=CC=CC=C4)[N-]4C5=NC6=[N]1C(=NC1=C7C=CC=CC7=C(N=C7C8=C(C=CC=C8)C(N=C4C4=C5C=CC=C4)=[N]27)[N-]31)C1=C6C=CC=C1: Invalid SMILES: [O-]([Si](C1=CC=CC=C1)(C1=CC=CC=C1)C1=CC=CC=C1)[Si+4]123([O-][Si](C4=CC=CC=C4)(C4=CC=CC=C4)C4=CC=CC=C4)[N-]4C5=NC6=[N]1C(=NC1=C7C=CC=CC7=C(N=C7C8=C(C=CC=C8)C(N=C4C4=C5C=CC=C4)=[N]27)[N-]31)C1=C6C=CC=C1


[17:18:11] Explicit valence for atom # 7 N, 5, is greater than permitted


Error processing SMILES CC(C)(C)C(=O)N=[N](C)(C)C: Invalid SMILES: CC(C)(C)C(=O)N=[N](C)(C)C


[17:18:14] Explicit valence for atom # 1 Cl, 7, is greater than permitted


Error processing SMILES O=[Cl](=O)(=O)C1=CC=CC=C1: Invalid SMILES: O=[Cl](=O)(=O)C1=CC=CC=C1


[17:18:15] Explicit valence for atom # 9 Cl, 3, is greater than permitted


Error processing SMILES ClC1CCCCC1[Se](Cl)(#[Cl])C1CCCCC1Cl: Invalid SMILES: ClC1CCCCC1[Se](Cl)(#[Cl])C1CCCCC1Cl


[17:18:15] Explicit valence for atom # 3 B, 5, is greater than permitted


Error processing SMILES CN1C(=[BH3])N(C)C2=CC=CC=C12: Invalid SMILES: CN1C(=[BH3])N(C)C2=CC=CC=C12


[17:18:19] Explicit valence for atom # 0 Cl, 3, is greater than permitted
[17:18:19] Explicit valence for atom # 6 Cl, 7, is greater than permitted


Error processing SMILES [Cl-]1C23[Cl-]C12[Cl-]3: Invalid SMILES: [Cl-]1C23[Cl-]C12[Cl-]3
Error processing SMILES CCC[Si](Cl)(Cl)[Cl](F)(F)(F)(F)(F)C1=CC=CC=C1: Invalid SMILES: CCC[Si](Cl)(Cl)[Cl](F)(F)(F)(F)(F)C1=CC=CC=C1


[17:18:19] Explicit valence for atom # 10 B, 6, is greater than permitted


Error processing SMILES O=NN(CC1=CC=CC=C1)[B]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[CH]2%129[CH]38%14[BH]475%11: Invalid SMILES: O=NN(CC1=CC=CC=C1)[B]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[CH]2%129[CH]38%14[BH]475%11


[17:18:22] Explicit valence for atom # 4 Br, 3, is greater than permitted


Error processing SMILES FC(F)(F)[Br]=O: Invalid SMILES: FC(F)(F)[Br]=O


[17:18:23] Explicit valence for atom # 1 N, 5, is greater than permitted
[17:18:23] Explicit valence for atom # 7 C, 6, is greater than permitted


Error processing SMILES C[N](C)(C)=NC(=O)C1=CC(=CC=C1)C(=O)N=[N](C)(C)C: Invalid SMILES: C[N](C)(C)=NC(=O)C1=CC(=CC=C1)C(=O)N=[N](C)(C)C
Error processing SMILES C(C1=CC=CC=C1)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(CC1=CC=CC=C1)[BH]475%11: Invalid SMILES: C(C1=CC=CC=C1)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(CC1=CC=CC=C1)[BH]475%11


[17:18:29] Explicit valence for atom # 2 N, 5, is greater than permitted


Error processing SMILES CC[N](CC)(CC)(CC(F)(F)C(F)F)OS(=O)(=O)C(F)(F)F: Invalid SMILES: CC[N](CC)(CC)(CC(F)(F)C(F)F)OS(=O)(=O)C(F)(F)F


[17:18:30] Explicit valence for atom # 15 N, 5, is greater than permitted


Error processing SMILES FC(F)(F)C1=NC(=O)C(C(=N1)C(F)(F)F)=[N]1=CC=CC=C1: Invalid SMILES: FC(F)(F)C1=NC(=O)C(C(=N1)C(F)(F)F)=[N]1=CC=CC=C1


[17:18:31] Explicit valence for atom # 0 B, 5, is greater than permitted


Error processing SMILES [H][B]([H])([H])=C=O: Invalid SMILES: [H][B]([H])([H])=C=O


[17:18:31] Explicit valence for atom # 9 Cl, 3, is greater than permitted


Error processing SMILES CCOC(CC[Si](Cl)(Cl)[Cl-]OC)(OCC)OCC: Invalid SMILES: CCOC(CC[Si](Cl)(Cl)[Cl-]OC)(OCC)OCC


[17:18:33] Explicit valence for atom # 2 O, 3, is greater than permitted


Error processing SMILES CC[O]1CCO[B]1(F)F: Invalid SMILES: CC[O]1CCO[B]1(F)F


[17:18:35] Explicit valence for atom # 8 Br, 3, is greater than permitted
[17:18:35] Explicit valence for atom # 6 Cl, 3, is greater than permitted


Error processing SMILES [H]C(=C([H])C1=CC=C(C=C1)[Br]=O)C(=O)C1=CC=CC=C1: Invalid SMILES: [H]C(=C([H])C1=CC=C(C=C1)[Br]=O)C(=O)C1=CC=CC=C1
Error processing SMILES OC(=O)COO[Cl](Cl)Cl: Invalid SMILES: OC(=O)COO[Cl](Cl)Cl


[17:18:37] Explicit valence for atom # 11 Cl, 3, is greater than permitted


Error processing SMILES CCOC1CCCCC1[Se](Cl)(#[Cl])C1CCCCC1OCC: Invalid SMILES: CCOC1CCCCC1[Se](Cl)(#[Cl])C1CCCCC1OCC


[17:18:37] Explicit valence for atom # 0 C, 6, is greater than permitted


Error processing SMILES [H][C-]1(I)(CC(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)F)OCCN2CCO[Si+]1OCC2: Invalid SMILES: [H][C-]1(I)(CC(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)F)OCCN2CCO[Si+]1OCC2


[17:18:38] Explicit valence for atom # 6 C, 6, is greater than permitted


Error processing SMILES C1=CC=C(C=C1)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11: Invalid SMILES: C1=CC=C(C=C1)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11


[17:18:39] Explicit valence for atom # 2 B, 6, is greater than permitted


Error processing SMILES C(=N[B]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]%13%14%15[BH]11([BH]269[CH]%108%131)[BH]3%141[BH]475[CH]%11%12%151)C1=CC=CC=C1: Invalid SMILES: C(=N[B]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]%13%14%15[BH]11([BH]269[CH]%108%131)[BH]3%141[BH]475[CH]%11%12%151)C1=CC=CC=C1


[17:18:40] Explicit valence for atom # 3 Br, 3, is greater than permitted


Error processing SMILES C1CC[Br-]OC1: Invalid SMILES: C1CC[Br-]OC1


[17:18:40] Explicit valence for atom # 1 N, 5, is greater than permitted


Error processing SMILES C[N](C)(C)=NC(=O)C1=CC=C(C=C1)C(=O)N=[N](C)(C)C: Invalid SMILES: C[N](C)(C)=NC(=O)C1=CC=C(C=C1)C(=O)N=[N](C)(C)C


[17:18:42] Explicit valence for atom # 8 C, 6, is greater than permitted
[17:18:42] Explicit valence for atom # 8 C, 6, is greater than permitted


Error processing SMILES C[Si](C)(O)O[Si](C)(C)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]475[BH]%1182[C]%12%1313[Si](C)(C)O[Si](C)(C)O[Si](C)(O)C1=CC=CC=C1: Invalid SMILES: C[Si](C)(O)O[Si](C)(C)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]475[BH]%1182[C]%12%1313[Si](C)(C)O[Si](C)(C)O[Si](C)(O)C1=CC=CC=C1
Error processing SMILES O=C(C1=CC=CC=C1)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]58([BH]47%112)P%12%1313: Invalid SMILES: O=C(C1=CC=CC=C1)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]58([BH]47%112)P%12%1313


[17:18:42] Explicit valence for atom # 7 Cl, 7, is greater than permitted


Error processing SMILES ClC1=CC=C(C=C1)[Cl](=O)(=O)=O: Invalid SMILES: ClC1=CC=C(C=C1)[Cl](=O)(=O)=O


[17:18:43] Explicit valence for atom # 4 N, 5, is greater than permitted


Error processing SMILES CC(O)C[N](C)(C)=NC(=O)CCCCC(=O)N=[N](C)(C)CC(C)O: Invalid SMILES: CC(O)C[N](C)(C)=NC(=O)CCCCC(=O)N=[N](C)(C)CC(C)O


[17:18:44] Explicit valence for atom # 2 N, 4, is greater than permitted


Error processing SMILES [N-]=[N+]=[N-][B+3]([N-]=[N+]=[N-])([C-]1=CC=CC=C1)[N]1=CC=CC=C1: Invalid SMILES: [N-]=[N+]=[N-][B+3]([N-]=[N+]=[N-])([C-]1=CC=CC=C1)[N]1=CC=CC=C1


[17:18:46] Explicit valence for atom # 7 Cl, 5, is greater than permitted


Error processing SMILES [O-][N+](=O)C(F)(CO[Cl-](COCC(F)(F)F)(OCC(F)([N+]([O-])=O)[N+]([O-])=O)C(Cl)(Cl)SSC(Cl)(Cl)[Cl-](COCC(F)(F)F)(OCC(F)([N+]([O-])=O)[N+]([O-])=O)OCC(F)([N+]([O-])=O)[N+]([O-])=O)[N+]([O-])=O: Invalid SMILES: [O-][N+](=O)C(F)(CO[Cl-](COCC(F)(F)F)(OCC(F)([N+]([O-])=O)[N+]([O-])=O)C(Cl)(Cl)SSC(Cl)(Cl)[Cl-](COCC(F)(F)F)(OCC(F)([N+]([O-])=O)[N+]([O-])=O)OCC(F)([N+]([O-])=O)[N+]([O-])=O)[N+]([O-])=O


[17:18:51] Explicit valence for atom # 3 B, 4, is greater than permitted


Error processing SMILES CN(C)[B](F)(F)F: Invalid SMILES: CN(C)[B](F)(F)F


[17:18:52] Explicit valence for atom # 4 B, 4, is greater than permitted


Error processing SMILES CC(C)C[B]1([H][B]([H]1)(CC(C)C)CC(C)C)CC(C)C: Invalid SMILES: CC(C)C[B]1([H][B]([H]1)(CC(C)C)CC(C)C)CC(C)C


[17:18:52] Explicit valence for atom # 20 N, 5, is greater than permitted


Error processing SMILES [O-][P-](OC1=CC=C(OCC2=CC=CC=C2)C=C1)(N1CC1)#[N]1CC1: Invalid SMILES: [O-][P-](OC1=CC=C(OCC2=CC=CC=C2)C=C1)(N1CC1)#[N]1CC1


[17:18:53] Explicit valence for atom # 12 C, 6, is greater than permitted


Error processing SMILES C[Si](C)(O)O[Si](C)(C)O[Si](C)(C)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]58([BH]47%112)[C]%12%1313[Si](C)(C)O: Invalid SMILES: C[Si](C)(O)O[Si](C)(C)O[Si](C)(C)[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]58([BH]47%112)[C]%12%1313[Si](C)(C)O


In [16]:
train_data

GraphData(434351)

In [17]:
test.shape

(108600, 2)

In [18]:
test_data = GraphData(test)

[17:18:59] Explicit valence for atom # 22 N, 5, is greater than permitted


Error processing SMILES COC(=O)C1=CC=C(OC(=O)NC2=CC=C(O[P-]([O-])(N3CC3)#[N]3CC3)C=C2)C=C1: Invalid SMILES: COC(=O)C1=CC=C(OC(=O)NC2=CC=C(O[P-]([O-])(N3CC3)#[N]3CC3)C=C2)C=C1


[17:19:00] Explicit valence for atom # 1 C, 6, is greater than permitted


Error processing SMILES C[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]475[BH]%1182[C]%12%1313C1=CC=CC(F)=C1: Invalid SMILES: C[C]1234[BH]567[BH]89%10[BH]55%11[BH]88%12[BH]99%13[BH]16%10[BH]291[BH]323[BH]475[BH]%1182[C]%12%1313C1=CC=CC(F)=C1


[17:19:03] Explicit valence for atom # 1 Cl, 3, is greater than permitted


Error processing SMILES C[Cl-]OOC(=O)CCC1=CC=CC=C1: Invalid SMILES: C[Cl-]OOC(=O)CCC1=CC=CC=C1


[17:19:03] Explicit valence for atom # 11 N, 5, is greater than permitted


Error processing SMILES CCOC(=O)C(C(=O)OCC)=[N]1=CC=CC=C1: Invalid SMILES: CCOC(=O)C(C(=O)OCC)=[N]1=CC=CC=C1


[17:19:04] Explicit valence for atom # 7 C, 6, is greater than permitted


Error processing SMILES O=C1N(CCCC[C]2345[BH]678[BH]9%10%11[BH]%12%13%14[BH]696[BH]%129%12[BH]%13%13%15[BH]%10%14%10[BH]27%11[BH]3%13%10[CH]49%15[BH]586%12)C(=O)C2=CC=CC=C12: Invalid SMILES: O=C1N(CCCC[C]2345[BH]678[BH]9%10%11[BH]%12%13%14[BH]696[BH]%129%12[BH]%13%13%15[BH]%10%14%10[BH]27%11[BH]3%13%10[CH]49%15[BH]586%12)C(=O)C2=CC=CC=C12


[17:19:07] Explicit valence for atom # 7 C, 6, is greater than permitted


Error processing SMILES CC(C)P(C(C)C)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11: Invalid SMILES: CC(C)P(C(C)C)[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11


[17:19:08] Explicit valence for atom # 3 C, 6, is greater than permitted


Error processing SMILES [N-]=[N+]=N[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11: Invalid SMILES: [N-]=[N+]=N[C]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[C]38%14(C1=CC=CC=C1)[BH]475%11


[17:19:09] Explicit valence for atom # 1 N, 4, is greater than permitted


Error processing SMILES C[N](C)(CCCN1C2=CC=CC=C2SC2=C1C=C(Cl)C=C2)[B]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[BH]38%14[BH]475%11: Invalid SMILES: C[N](C)(CCCN1C2=CC=CC=C2SC2=C1C=C(Cl)C=C2)[B]1234[BH]567[BH]89%10[BH]%11%12%13[BH]585[BH]%118%11[BH]%12%12%14[BH]9%139[BH]16%10[BH]2%129[BH]38%14[BH]475%11


[17:19:10] Explicit valence for atom # 7 Cl, 7, is greater than permitted
[17:19:10] Explicit valence for atom # 7 Cl, 3, is greater than permitted


Error processing SMILES FC1=CC=C(C=C1)[Cl](=O)(=O)=O: Invalid SMILES: FC1=CC=C(C=C1)[Cl](=O)(=O)=O
Error processing SMILES [O-][N+](=O)C(F)(CO[Cl](C(OCC(F)(F)F)OCC(F)(F)F)C(Cl)(Cl)SSC(Cl)(Cl)[Cl](OCC(F)([N+]([O-])=O)[N+]([O-])=O)C(OCC(F)(F)F)OCC(F)(F)F)[N+]([O-])=O: Invalid SMILES: [O-][N+](=O)C(F)(CO[Cl](C(OCC(F)(F)F)OCC(F)(F)F)C(Cl)(Cl)SSC(Cl)(Cl)[Cl](OCC(F)([N+]([O-])=O)[N+]([O-])=O)C(OCC(F)(F)F)OCC(F)(F)F)[N+]([O-])=O


In [19]:
test_data

GraphData(108590)

In [20]:
train_data[0]

Data(x=[10, 13], edge_index=[2, 20], edge_attr=[20, 5])

In [45]:
class RandomSubgraph:
    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes

    def __call__(self, data):
        import random
        from torch_geometric.utils import subgraph

        # Select random nodes from the graph
        num_nodes = data.num_nodes
        if self.num_nodes >= num_nodes:
            return data  # Return the original graph if too few nodes to sample

        random_nodes = random.sample(range(num_nodes), self.num_nodes)

        # Generate subgraph
        edge_index, edge_mask = subgraph(
            random_nodes, edge_index=data.edge_index, relabel_nodes=True, num_nodes=num_nodes
        )
        # Create a new `Data` object to avoid modifying the input `data` in place
        subgraph_data = Data()

        # Update node features
        if data.x is not None:
            subgraph_data.x = data.x[random_nodes]

        # Update edge index
        subgraph_data.edge_index = edge_index

        # Update edge attributes if present
        if data.edge_attr is not None:
            subgraph_data.edge_attr = data.edge_attr[edge_mask]

        # Copy additional attributes (e.g., labels, targets)
        for key in data.keys():  # Corrected here
            if key not in ['x', 'edge_index', 'edge_attr']:
                subgraph_data[key] = data[key]

        return subgraph_data

In [21]:
from torch_geometric.transforms import RandomLinkSplit, RandomNodeSplit

In [25]:
transform = Compose([RandomNodeSplit(split="train_rest", num_train_per_class=5),RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)])

In [62]:
'''for i, graph in enumerate(train_data):
    print(f"Graph {i} before transforms:", graph)
    try:
        transformed_graph = transform(graph)
        print(f"Graph {i} after transforms:", transformed_graph)
    except Exception as e:
        print(f"Error with Graph {i}:", e)
        break

SyntaxError: incomplete input (3158717337.py, line 1)

In [69]:
for graph in train_data:
    g1 = RandomNodeSplit(split="train_rest", num_train_per_class=5)(graph)
    g2 = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)(graph)
    print("g1 type:", type(g1))
    print("g2 type:", type(g2))
    break

g1 type: <class 'torch_geometric.data.data.Data'>
g2 type: <class 'tuple'>


In [73]:
for graph in train_data:
    # Apply RandomNodeSplit to get one graph
   # g1 = RandomNodeSplit(split="train_rest", num_train_per_class=5)(graph)

    # Apply RandomLinkSplit to get a tuple of graphs
    split_graphs = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)(graph)
    
    # Access the positive and negative graphs from the tuple
    g2_pos = split_graphs[0]
    g2_neg = split_graphs[1]

    print("g2_pos type:", type(g2_pos))  # Positive links graph
    print("g2_neg type:", type(g2_neg))  # Negative links graph
    
    break

g2_pos type: <class 'torch_geometric.data.data.Data'>
g2_neg type: <class 'torch_geometric.data.data.Data'>


In [33]:
def collate_fn(batch):
    # Filter out None values from the batch
    batch = [item for item in batch if item is not None]
    
    if len(batch) == 0:
        return None  # If no valid data left, return None (or you could return an empty batch)
    
    # Otherwise, proceed with the batch processing (you might need to adapt for your graph data)
    return batch


In [22]:
train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False, drop_last=True)




In [None]:
for data in train_loader:
    print(data)
    break

In [None]:
# Define transforms
node_split_transform = RandomNodeSplit(split="train_rest", num_train_per_class=5)
link_split_transform = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)

# Model and optimizer
model = GCN(input_dim=train_data[0].num_features, hidden_dim=32, output_dim=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Contrastive learning loop
epochs = 5

for epoch in range(epochs):
    total_loss = 0

    for batch in train_loader:
        # Apply RandomNodeSplit for the first augmented graph
        
        g1 = node_split_transform(batch)

        # Apply RandomLinkSplit to get positive and negative graphs
        g2_pos, g2_neg, _ = link_split_transform(batch)

        # Forward pass with positive graph
        z1 = model(g1.x, g1.edge_index)  # Forward pass on the first graph (g1)
        z2 = model(g2_pos.x, g2_pos.edge_index)  # Forward pass on the second graph (g2_pos)

        # Compute contrastive loss with the positive graphs
        loss = contrastive_loss(z1, z2)

        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")



Epoch 1, Loss: 6.0885




Epoch 2, Loss: 6.0454




Epoch 3, Loss: 6.0359




Epoch 4, Loss: 6.0294




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import optuna

# Define the model
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Training function
def train_model(model, criterion, optimizer, train_loader, device):
    model.train()
    total_loss = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Validation function
def validate_model(model, criterion, val_loader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
    return total_loss / len(val_loader)

# Define the Optuna objective function
def objective(trial):
    # Hyperparameters to optimize
    hidden_size = trial.suggest_int("hidden_size", 16, 128)
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 64)

    # Data preparation
    input_size = 10  # Example input features
    output_size = 1  # Example output size (e.g., regression task)
    dataset_size = 1000
    x = torch.randn(dataset_size, input_size)
    y = torch.randn(dataset_size, output_size)
    dataset = TensorDataset(x, y)
    train_size = int(0.8 * dataset_size)
    val_size = dataset_size - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Model, loss, optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet(input_size, hidden_size, output_size).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training and validation
    for epoch in range(10):  # Small number of epochs for testing
        train_model(model, criterion, optimizer, train_loader, device)
    val_loss = validate_model(model, criterion, val_loader, device)

    return val_loss

# Create an Optuna study and optimize
if __name__ == "__main__":
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=50)

    # Print the best parameters
    print("Best hyperparameters:", study.best_params)