<a href="https://colab.research.google.com/github/yufanlili211/master_thesis/blob/main/GNN_atom_smiles_level.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

peptide sequence to smiles

In [None]:
import sys
!{sys.executable} -m pip install rdkit # install rdkit
!pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
!pip install --no-cache-dir torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
!pip install --no-cache-dir torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

import pandas as pd
import torch
from rdkit import Chem
from torch_geometric.data import Data

# -----------------------------
# 1) Load residue dictionary
# -----------------------------
# Expected columns in Excel:
# - one column for amino-acid code (e.g., A, C, D)
# - one column for capped residue smiles (Ac-Res-NMe)

def load_residue_dictionary_from_excel(
    excel_path,
    code_col='ID',
    smiles_col='SMILES (Ac-Res-NMe)',
):
    """Read the residue table and return a ``code -> capped SMILES`` lookup.

    Parameters
    ----------
    excel_path : path-like
        Excel workbook passed straight to ``pandas.read_excel``.
    code_col : str
        Name of the column holding the residue code.
    smiles_col : str
        Name of the column holding the capped residue SMILES.

    Returns
    -------
    dict
        Stripped string code mapped to stripped string SMILES. Rows where
        either cell is missing are skipped; a repeated code keeps the value
        from the last row that carries it.

    Raises
    ------
    ValueError
        If either requested column is absent from the sheet.
    """
    table = pd.read_excel(excel_path)

    has_both_columns = code_col in table.columns and smiles_col in table.columns
    if not has_both_columns:
        raise ValueError(
            f"Missing required columns. Found: {list(table.columns)}; "
            f"need code_col='{code_col}', smiles_col='{smiles_col}'."
        )

    lookup = {}
    for _, record in table.iterrows():
        code_cell = record[code_col]
        smiles_cell = record[smiles_col]
        # Skip incomplete rows instead of storing the string 'nan'.
        if pd.isna(code_cell) or pd.isna(smiles_cell):
            continue
        lookup[str(code_cell).strip()] = str(smiles_cell).strip()
    return lookup


# -----------------------------
# 2) Feature vocabularies
# -----------------------------
# All categorical values are converted to integer ids for nn.Embedding.

# NOTE: the position of a value inside each list below IS the integer id that
# the feature functions emit (they call ``list.index``), so reordering or
# inserting entries changes the encoding of every graph built afterwards.
# Every vocabulary except BOOL_VOCAB ends with an 'UNK' bucket for values that
# are not listed explicitly.

# Element symbols as returned by ``atom.GetSymbol()``.
ATOM_TYPE_VOCAB = [
    'H', 'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'Br', 'I', 'Se', 'UNK'
]
# Chiral tags are stored as ``str(enum)`` because the feature code compares
# against ``str(atom.GetChiralTag())``.
CHIRALITY_VOCAB = [
    str(Chem.rdchem.ChiralType.CHI_UNSPECIFIED),
    str(Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW),
    str(Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW),
    str(Chem.rdchem.ChiralType.CHI_OTHER),
    'UNK',
]
# Integer-valued atom properties; anything outside the listed range maps to 'UNK'.
DEGREE_VOCAB = [0, 1, 2, 3, 4, 5, 6, 'UNK']
FORMAL_CHARGE_VOCAB = [-3, -2, -1, 0, 1, 2, 3, 'UNK']
NUM_H_VOCAB = [0, 1, 2, 3, 4, 'UNK']
RADICAL_E_VOCAB = [0, 1, 2, 'UNK']
# Hybridization states as ``str(enum)``; states not listed here fall into 'UNK'.
HYBRIDIZATION_VOCAB = [
    str(Chem.rdchem.HybridizationType.SP),
    str(Chem.rdchem.HybridizationType.SP2),
    str(Chem.rdchem.HybridizationType.SP3),
    str(Chem.rdchem.HybridizationType.SP3D),
    str(Chem.rdchem.HybridizationType.SP3D2),
    'UNK',
]
# Shared by all boolean flags: False -> 0, True -> 1. No 'UNK' entry.
BOOL_VOCAB = [False, True]

# Bond order categories as ``str(enum)``.
BOND_TYPE_VOCAB = [
    str(Chem.rdchem.BondType.SINGLE),
    str(Chem.rdchem.BondType.DOUBLE),
    str(Chem.rdchem.BondType.TRIPLE),
    str(Chem.rdchem.BondType.AROMATIC),
    'UNK',
]
# Bond stereo descriptors as ``str(enum)``.
BOND_STEREO_VOCAB = [
    str(Chem.rdchem.BondStereo.STEREONONE),
    str(Chem.rdchem.BondStereo.STEREOANY),
    str(Chem.rdchem.BondStereo.STEREOZ),
    str(Chem.rdchem.BondStereo.STEREOE),
    str(Chem.rdchem.BondStereo.STEREOCIS),
    str(Chem.rdchem.BondStereo.STEREOTRANS),
    'UNK',
]


def _to_idx(value, vocab):
    """Return the position of ``value`` in ``vocab``, or of ``'UNK'`` if absent.

    Raises ``ValueError`` when ``value`` is missing and the vocabulary has no
    ``'UNK'`` entry to fall back on.
    """
    try:
        return vocab.index(value)
    except ValueError:
        # Unlisted value: route it to the catch-all bucket.
        return vocab.index('UNK')


def atom_features(atom):
    """Encode one RDKit atom as a list of nine categorical indices.

    Field order: atom type, chirality, total degree, formal charge, total
    number of Hs, radical electrons, hybridization, aromatic flag, ring flag.
    Values that are not present in a field's vocabulary are mapped to that
    vocabulary's 'UNK' index.
    """
    # (raw value, vocabulary) pairs for every field that has an 'UNK' bucket.
    categorical_fields = (
        (atom.GetSymbol(), ATOM_TYPE_VOCAB),
        (str(atom.GetChiralTag()), CHIRALITY_VOCAB),
        (atom.GetTotalDegree(), DEGREE_VOCAB),
        (atom.GetFormalCharge(), FORMAL_CHARGE_VOCAB),
        (atom.GetTotalNumHs(includeNeighbors=True), NUM_H_VOCAB),
        (atom.GetNumRadicalElectrons(), RADICAL_E_VOCAB),
        (str(atom.GetHybridization()), HYBRIDIZATION_VOCAB),
    )
    encoded = [_to_idx(raw, vocab) for raw, vocab in categorical_fields]

    # Boolean flags have no 'UNK' entry and are looked up directly.
    encoded.append(BOOL_VOCAB.index(atom.GetIsAromatic()))
    encoded.append(BOOL_VOCAB.index(atom.IsInRing()))
    return encoded


def bond_features(bond):
    """Encode one RDKit bond as three categorical indices.

    Field order: bond type, bond stereo, conjugation flag. Bond types and
    stereo descriptors missing from their vocabulary map to its 'UNK' index.
    """
    type_idx = _to_idx(str(bond.GetBondType()), BOND_TYPE_VOCAB)
    stereo_idx = _to_idx(str(bond.GetStereo()), BOND_STEREO_VOCAB)
    conjugated_idx = BOOL_VOCAB.index(bond.GetIsConjugated())
    return [type_idx, stereo_idx, conjugated_idx]



Collecting rdkit
  Downloading rdkit-2025.9.5-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.8 kB)
Downloading rdkit-2025.9.5-cp312-cp312-manylinux_2_28_x86_64.whl (36.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.7/36.7 MB[0m [31m50.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.9.5
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.3.0%2Bcu121/torch_scatter-2.1.2%2Bpt23cu121-cp312-cp312-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m89.7 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt23cu121
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.3.0%

  import torch_geometric.typing
  import torch_geometric.typing


smiles to atom-level graph features

In [None]:
from rdkit.Chem import rdchem


def _find_backbone_atoms_and_caps(mol):
    """
    Locate the Ac-Res-NMe scaffold and report which atoms to keep or drop.

    Pattern target: CH3-C(=O)-N-*-C(=O)-N-CH3

    Returns a tuple ``(backbone_n, c_term_carb, remove_set)`` where
    ``backbone_n`` is the residue's amide nitrogen, ``c_term_carb`` is the
    residue's carbonyl carbon, and ``remove_set`` holds the five cap atoms
    (acetyl CH3, C and O; N-methylamide N and CH3).

    Raises ValueError when the pattern is not present in ``mol``.
    """
    scaffold = Chem.MolFromSmarts('[CH3:1]-[C:2](=[O:3])-[N:4]-[*:5]-[C:6](=[O:7])-[N:8]-[CH3:9]')
    hits = mol.GetSubstructMatches(scaffold)
    if not hits:
        raise ValueError('Cannot detect Ac-Res-NMe backbone pattern in residue SMILES.')

    # Only the first hit is used; a more specific SMARTS is needed if a
    # template can match in more than one way.
    (
        ac_methyl,
        ac_carb,
        ac_oxy,
        backbone_n,
        _alpha,
        c_term_carb,
        _carbonyl_oxy,
        nme_n,
        nme_methyl,
    ) = hits[0]

    return backbone_n, c_term_carb, {ac_methyl, ac_carb, ac_oxy, nme_n, nme_methyl}


def _remove_atoms_and_map(mol, remove_set):
    """Remove a set of atom indices and return new molecule + old->new index mapping.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Source molecule; it is copied, not modified.
    remove_set : set of int
        Atom indices (in ``mol``'s numbering) to delete.

    Returns
    -------
    (rdkit.Chem.Mol, dict)
        The sanitized molecule without the removed atoms, and a dict mapping
        every surviving old index to its index in the new molecule.
    """
    rw = Chem.RWMol(mol)
    # Delete from the highest index downwards so the indices still waiting to
    # be removed are not invalidated by earlier deletions.
    for idx in sorted(remove_set, reverse=True):
        rw.RemoveAtom(idx)

    # A surviving atom's new index is its old index minus the number of
    # removed atoms that preceded it. This relies on RemoveAtom shifting all
    # later atoms down by one (the same assumption the original code made).
    old_to_new = {}
    removed_before = 0
    for old_idx in range(mol.GetNumAtoms()):
        if old_idx in remove_set:
            removed_before += 1
            continue
        old_to_new[old_idx] = old_idx - removed_before

    mol2 = rw.GetMol()
    Chem.SanitizeMol(mol2)
    return mol2, old_to_new


def _decap_residue(smiles):
    """
    Strip the Ac and NMe caps from a capped residue SMILES.

    Returns ``(fragment, n_idx, c_idx)``: the decapped molecule plus the
    indices, in the fragment's numbering, of the backbone N and the backbone
    carbonyl C that later take part in peptide bonds.

    Raises ValueError for unparsable SMILES, or if either backbone atom does
    not survive cap removal.
    """
    capped = Chem.MolFromSmiles(smiles)
    if capped is None:
        raise ValueError(f'Invalid residue SMILES: {smiles}')

    backbone_n, backbone_c, cap_atoms = _find_backbone_atoms_and_caps(capped)
    fragment, index_map = _remove_atoms_and_map(capped, cap_atoms)

    both_kept = backbone_n in index_map and backbone_c in index_map
    if not both_kept:
        raise ValueError('Backbone N or C was removed unexpectedly during decapping.')

    return fragment, index_map[backbone_n], index_map[backbone_c]


def _combine_two_residues(mol_a, c_idx_a, mol_b, n_idx_b):
    """
    Merge two fragments into one molecule joined by a C(=O)-N single bond.

    ``c_idx_a`` indexes the carbonyl carbon in ``mol_a``; ``n_idx_b`` indexes
    the nitrogen in ``mol_b``. In the combined molecule the atoms of ``mol_b``
    come after those of ``mol_a``, so ``n_idx_b`` is shifted by the atom count
    of ``mol_a``. Returns the sanitized combined molecule.
    """
    shifted_n = n_idx_b + mol_a.GetNumAtoms()
    editable = Chem.RWMol(Chem.CombineMols(mol_a, mol_b))

    # Only add the bond when the two atoms are not already connected.
    existing = editable.GetBondBetweenAtoms(c_idx_a, shifted_n)
    if existing is None:
        editable.AddBond(c_idx_a, shifted_n, rdchem.BondType.SINGLE)

    joined = editable.GetMol()
    Chem.SanitizeMol(joined)
    return joined


def molecule_to_pyg_data(mol):
    """Turn an RDKit molecule into a torch_geometric ``Data`` object.

    ``x`` holds one row of categorical atom indices per atom. Each bond is
    stored twice (i->j and j->i) with the same feature row, so ``edge_index``
    has two columns per bond and ``edge_attr`` two rows per bond. A molecule
    without bonds gets empty (2, 0) and (0, 3) long tensors.
    """
    node_rows = [atom_features(a) for a in mol.GetAtoms()]
    x = torch.tensor(node_rows, dtype=torch.long)

    pairs = []
    pair_feats = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feats = bond_features(bond)
        # Both directions share the same bond features.
        pairs.extend(([begin, end], [end, begin]))
        pair_feats.extend((feats, feats))

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(pair_feats, dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 3), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def build_peptide_graph(sequence, res_dict):
    """
    Assemble an atom-level graph for a peptide given as residue codes.

    Parameters
    ----------
    sequence : str
        e.g., 'ACD'; each character is looked up in ``res_dict``.
    res_dict : dict
        maps AA code -> capped residue smiles (Ac-Res-NMe)

    Returns
    -------
    torch_geometric.data.Data
        with x, edge_index, edge_attr

    Raises
    ------
    ValueError
        If ``sequence`` is empty.
    KeyError
        If a residue code is missing from ``res_dict``.
    """
    if not sequence:
        raise ValueError('Sequence is empty.')

    fragments = []
    for aa in sequence:
        if aa not in res_dict:
            raise KeyError(f"Residue '{aa}' not found in dictionary.")
        fragments.append(_decap_residue(res_dict[aa]))

    # Start from the first residue and append the rest one at a time.
    (peptide, _, tail_c), *remaining = fragments
    for frag_mol, frag_n, frag_c in remaining:
        shift = peptide.GetNumAtoms()
        peptide = _combine_two_residues(peptide, tail_c, frag_mol, frag_n)
        # The appended fragment's atoms sit after the existing ones, so its
        # carbonyl carbon index moves up by the previous atom count.
        tail_c = frag_c + shift

    Chem.SanitizeMol(peptide)
    return molecule_to_pyg_data(peptide)



build atom-level graph dataset (train/val/test)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# -----------------------------
# Paths
# -----------------------------
# Excel sheet with one row per residue (code + capped SMILES).
RESIDUAL_DICT_XLSX = "/content/drive/MyDrive/master_thesis/sampled_data_5000/residual_dictionary.xlsx"
# Excel sheet with one row per peptide (sequence + label).
PEPTIDE_EXCEL_PATH = "/content/drive/MyDrive/master_thesis/sampled_data_5000/canya_data_sampled_5000_smiles.xlsx"
# Output folder for the train/val/test .pt files.
SAVE_DIR = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles"

# Column names inside the residue dictionary sheet; change them only if the
# sheet uses different headers.
AA_CODE_COL = "ID"
AA_SMILES_COL = "SMILES (Ac-Res-NMe)"

# Column names inside the peptide sheet. A value of None makes the loader
# fall back to auto-detection from a list of common header names; any other
# value must match a column exactly or loading fails.
SEQUENCE_COL = "aa_seq"
LABEL_COL = "seed_bh"


def _pick_column(df, explicit_name, candidates, kind):
    """Resolve which column of ``df`` to use for a given role.

    When ``explicit_name`` is given it must exist in ``df.columns``;
    otherwise the first entry of ``candidates`` that exists is returned.
    ``kind`` is only used to word the error messages.

    Raises ValueError if the explicit column is missing or no candidate
    matches.
    """
    if explicit_name is None:
        for name in candidates:
            if name in df.columns:
                return name
        raise ValueError(
            f"Cannot auto-detect {kind} column. Available columns: {list(df.columns)}. "
            f"Please set {kind.upper()}_COL explicitly."
        )

    if explicit_name in df.columns:
        return explicit_name
    raise ValueError(f"{kind} column '{explicit_name}' not found. Available: {list(df.columns)}")


def load_peptide_rows(peptide_excel_path, sequence_col=None, label_col=None):
    """Read ``(sequence, label)`` pairs from the peptide Excel sheet.

    Column names are resolved with ``_pick_column``: an explicit name must
    exist, ``None`` triggers auto-detection from a fixed candidate list.
    Rows with a missing sequence or label, or with a sequence that is empty
    after stripping, are skipped. Labels are converted with ``int``.

    Returns ``(rows, seq_col, y_col)`` where ``rows`` is a list of
    ``(str, int)`` tuples and the other two are the column names used.
    """
    df = pd.read_excel(peptide_excel_path)

    sequence_candidates = ["sequence", "Sequence", "peptide", "Peptide", "seq", "SEQ"]
    label_candidates = ["label", "Label", "aggregation", "Aggregation", "y", "Y"]
    seq_col = _pick_column(df, sequence_col, sequence_candidates, "sequence")
    y_col = _pick_column(df, label_col, label_candidates, "label")

    rows = []
    for _, record in df.iterrows():
        raw_seq = record[seq_col]
        raw_label = record[y_col]
        if pd.isna(raw_seq) or pd.isna(raw_label):
            continue
        cleaned = str(raw_seq).strip()
        if cleaned:
            rows.append((cleaned, int(raw_label)))

    return rows, seq_col, y_col


def build_atom_level_dataset(
    peptide_excel_path,
    res_dict,
    sequence_col=None,
    label_col=None,
    strict=True,
):
    """Build one atom-level graph per peptide row of an Excel sheet.

    Parameters
    ----------
    peptide_excel_path : path-like
        Sheet with a sequence column and a label column.
    res_dict : dict
        Residue code -> capped residue SMILES.
    sequence_col, label_col : str or None
        Column names; ``None`` lets ``load_peptide_rows`` auto-detect them.
    strict : bool
        If True, the first sequence that cannot be converted aborts the build.
        If False, failing sequences are counted and skipped.

    Returns
    -------
    list
        Graph objects, each carrying ``y`` (float tensor of shape [1]) and
        ``sequence`` (the peptide string) in addition to the graph tensors.

    Raises
    ------
    RuntimeError
        In strict mode, when a sequence fails; the original exception is
        attached as ``__cause__``.
    ValueError
        If no graph could be built at all.
    """
    rows, used_seq_col, used_label_col = load_peptide_rows(
        peptide_excel_path,
        sequence_col=sequence_col,
        label_col=label_col,
    )

    data_list = []
    dropped = 0

    for seq, y in rows:
        try:
            data = build_peptide_graph(seq, res_dict)
            data.y = torch.tensor([y], dtype=torch.float)
            data.sequence = seq
            data_list.append(data)
        except Exception as e:
            if strict:
                # Chain explicitly so the traceback shows the root cause as
                # the direct cause of this error.
                raise RuntimeError(f"Failed on sequence '{seq}': {e}") from e
            dropped += 1

    if not data_list:
        raise ValueError("No valid graph samples were generated.")

    print(f"Loaded {len(rows)} rows from peptide excel")
    print(f"Using sequence column: {used_seq_col}")
    print(f"Using label column: {used_label_col}")
    print(f"Built {len(data_list)} graphs (dropped={dropped})")

    return data_list


def stratified_split_data_list(data_list, random_state=42):
    """Split graphs into train/val/test (70/15/15), stratified by label.

    The label of each graph is read as ``int(d.y.item())``. The 30% holdout
    is carved off first and then halved into validation and test, with the
    same ``random_state`` used for both splits.

    Returns ``(train_data, val_data, test_data)`` as lists.
    """
    labels = [int(graph.y.item()) for graph in data_list]
    all_positions = list(range(len(data_list)))

    train_idx, holdout_idx = train_test_split(
        all_positions,
        test_size=0.3,
        stratify=labels,
        random_state=random_state,
    )

    holdout_labels = [labels[pos] for pos in holdout_idx]
    val_idx, test_idx = train_test_split(
        holdout_idx,
        test_size=0.5,
        stratify=holdout_labels,
        random_state=random_state,
    )

    def _gather(positions):
        return [data_list[pos] for pos in positions]

    return _gather(train_idx), _gather(val_idx), _gather(test_idx)


def save_splits(train_data, val_data, test_data, save_dir):
    """Write the three splits to ``save_dir`` and print a short summary.

    Each split is saved with ``torch.save`` as ``{"data_list": split}`` under
    a fixed file name. ``save_dir`` is created if it does not exist. The
    summary reads the first training sample, so ``train_data`` must not be
    empty.
    """
    os.makedirs(save_dir, exist_ok=True)

    named_splits = (
        ("atom_level_train.pt", train_data),
        ("atom_level_val.pt", val_data),
        ("atom_level_test.pt", test_data),
    )
    for filename, split in named_splits:
        torch.save({"data_list": split}, os.path.join(save_dir, filename))

    print(f"Saved splits to: {save_dir}")
    print(f"train/val/test = {len(train_data)}/{len(val_data)}/{len(test_data)}")
    first_sample = train_data[0]
    print(f"Sample x shape: {first_sample.x.shape}")
    print(f"Sample edge_attr shape: {first_sample.edge_attr.shape}")


def main_build_atom_level_dataset():
    """Run the full preprocessing pipeline using the module-level settings.

    Loads the residue dictionary, builds one graph per peptide (strict mode),
    splits 70/15/15 with a fixed seed, and saves the splits to ``SAVE_DIR``.
    """
    # Residue dictionary: AA code -> capped residue SMILES
    residue_lookup = load_residue_dictionary_from_excel(
        RESIDUAL_DICT_XLSX,
        code_col=AA_CODE_COL,
        smiles_col=AA_SMILES_COL,
    )

    # Atom-level graphs from the peptide sequences
    graphs = build_atom_level_dataset(
        PEPTIDE_EXCEL_PATH,
        residue_lookup,
        sequence_col=SEQUENCE_COL,
        label_col=LABEL_COL,
        strict=True,
    )

    # Stratified split with fixed random seed, then write the .pt files
    splits = stratified_split_data_list(graphs, random_state=42)
    save_splits(*splits, SAVE_DIR)


# Run when you are ready.
main_build_atom_level_dataset()



Loaded 5000 rows from peptide excel
Using sequence column: aa_seq
Using label column: seed_bh
Built 5000 graphs (dropped=0)
Saved splits to: /content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles
train/val/test = 3500/750/750
Sample x shape: torch.Size([154, 9])
Sample edge_attr shape: torch.Size([312, 3])


model

model_smiles.py code (copied for notebook use, source file unchanged)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool


def _make_gine_mlp(hidden_dim: int) -> nn.Module:
    """Build the Linear -> ReLU -> Linear block used inside each GINEConv.

    Input and output width are both ``hidden_dim``.
    """
    layers = [
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
    ]
    return nn.Sequential(*layers)


class AtomEncoder(nn.Module):
    """
    Embed categorical atom features: one embedding table per feature column,
    with the per-column embeddings summed into a single hidden vector.

    Expected atom feature order in data.x:
      0 Atom Type
      1 Chirality
      2 Total Degree
      3 Formal Charge
      4 Total Number of Hs
      5 Number of Radical Electrons
      6 Hybridization
      7 Aromatic (bool)
      8 Part of Ring (bool)
    """

    def __init__(self, hidden_dim: int, feature_dims=None):
        """Create one ``nn.Embedding(vocab_size, hidden_dim)`` per feature.

        ``feature_dims`` lists the vocabulary size of each column; the default
        must match the preprocessing vocab sizes.
        """
        super().__init__()
        self.feature_dims = (
            [14, 5, 8, 8, 6, 4, 6, 2, 2] if feature_dims is None else feature_dims
        )

        tables = []
        for vocab_size in self.feature_dims:
            tables.append(nn.Embedding(vocab_size, hidden_dim))
        self.embeddings = nn.ModuleList(tables)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every embedding table with Xavier-uniform weights."""
        for table in self.embeddings:
            nn.init.xavier_uniform_(table.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[num_nodes, num_features]`` indices to ``[num_nodes, hidden_dim]``.

        Raises ValueError if ``x`` is not 2-D or has the wrong column count.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected x with shape [num_nodes, num_atom_features], got {x.shape}")

        expected_cols = len(self.embeddings)
        if x.size(1) != expected_cols:
            raise ValueError(
                f"Expected {expected_cols} atom feature columns, got {x.size(1)}"
            )

        columns = x.long().unbind(dim=1)
        total = 0
        for table, column in zip(self.embeddings, columns):
            total = total + table(column)
        return total


class EdgeEncoder(nn.Module):
    """
    Embed categorical bond features: one embedding table per feature column,
    with the per-column embeddings summed into the edge vector fed to GINEConv.

    Expected edge feature order in data.edge_attr:
      0 Bond Type
      1 Bond Stereo
      2 Conjugation (bool)
    """

    def __init__(self, hidden_dim: int, feature_dims=None):
        """Create one ``nn.Embedding(vocab_size, hidden_dim)`` per feature.

        ``feature_dims`` lists the vocabulary size of each column; the default
        must match the preprocessing vocab sizes.
        """
        super().__init__()
        self.feature_dims = [5, 7, 2] if feature_dims is None else feature_dims

        tables = []
        for vocab_size in self.feature_dims:
            tables.append(nn.Embedding(vocab_size, hidden_dim))
        self.embeddings = nn.ModuleList(tables)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every embedding table with Xavier-uniform weights."""
        for table in self.embeddings:
            nn.init.xavier_uniform_(table.weight)

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        """Map ``[num_edges, num_features]`` indices to ``[num_edges, hidden_dim]``.

        Raises ValueError if ``edge_attr`` is not 2-D or has the wrong column
        count.
        """
        if edge_attr.dim() != 2:
            raise ValueError(
                f"Expected edge_attr with shape [num_edges, num_edge_features], got {edge_attr.shape}"
            )

        expected_cols = len(self.embeddings)
        if edge_attr.size(1) != expected_cols:
            raise ValueError(
                f"Expected {expected_cols} edge feature columns, got {edge_attr.size(1)}"
            )

        columns = edge_attr.long().unbind(dim=1)
        total = 0
        for table, column in zip(self.embeddings, columns):
            total = total + table(column)
        return total


class GINEVirtualNodeClassifierAtom(nn.Module):
    """
    Atom-level GINEConv + Virtual Node classifier for graph-level binary prediction.

    Input:
      data.x:        [num_nodes, 9]   categorical atom features
      data.edge_attr:[num_edges, 3]   categorical bond features
      data.batch:    [num_nodes]      graph id per node; optional. When it is
                                      missing or None, all nodes are treated
                                      as belonging to a single graph.

    Output:
      One raw logit per graph, shape [num_graphs].
    """

    def __init__(
        self,
        hidden_dim: int = 128,
        num_layers: int = 4,
        dropout: float = 0.2,
        pooling: str = "mean",
        atom_feature_dims=None,
        edge_feature_dims=None,
    ):
        """
        Parameters
        ----------
        hidden_dim : int
            Width of all embeddings and hidden layers.
        num_layers : int
            Number of GINEConv + BatchNorm layers.
        dropout : float
            Dropout probability after each conv layer and in the classifier.
        pooling : str
            Graph readout, either "mean" or "add".
        atom_feature_dims, edge_feature_dims : list of int or None
            Vocabulary sizes forwarded to the encoders; None keeps their
            defaults.

        Raises
        ------
        ValueError
            If ``pooling`` is neither "mean" nor "add".
        """
        super().__init__()
        if pooling not in ["mean", "add"]:
            raise ValueError(f"pooling must be 'mean' or 'add', got {pooling}")

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.pooling = pooling

        self.atom_encoder = AtomEncoder(hidden_dim=hidden_dim, feature_dims=atom_feature_dims)
        self.edge_encoder = EdgeEncoder(hidden_dim=hidden_dim, feature_dims=edge_feature_dims)

        self.convs = nn.ModuleList(
            [GINEConv(_make_gine_mlp(hidden_dim)) for _ in range(num_layers)]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)])

        # A single MLP updates the virtual node after every layer but the last.
        self.vn_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        """Return one logit per graph for a (batched or single) graph object."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = getattr(data, "batch", None)

        x = self.atom_encoder(x)
        e = self.edge_encoder(edge_attr)

        if batch is None:
            # A graph that did not go through a DataLoader carries no batch
            # vector; assign every node to graph 0 so single-graph inference
            # works instead of failing on ``None``.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
        # The virtual node starts at zero for every graph.
        virtualnode_emb = x.new_zeros((num_graphs, self.hidden_dim))

        for i in range(self.num_layers):
            # Broadcast each graph's virtual-node state to its atoms.
            x = x + virtualnode_emb[batch]
            x = self.convs[i](x, edge_index, e)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # The virtual node is not needed after the final layer.
            if i != self.num_layers - 1:
                vn_update = global_add_pool(x, batch)
                virtualnode_emb = virtualnode_emb + self.vn_mlp(vn_update)

        graph_emb = global_add_pool(x, batch) if self.pooling == "add" else global_mean_pool(x, batch)
        logits = self.classifier(graph_emb).view(-1)
        return logits


# Backward-compatible class name
class GINEVirtualNodeClassifier(GINEVirtualNodeClassifierAtom):
    """Alias of ``GINEVirtualNodeClassifierAtom`` kept under the older name.

    Adds no behavior of its own.
    """
    pass



train_atom_gnn.py code (saved file + notebook copy)

In [None]:
import importlib.util
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader


# Atom-level dataset built from notebook pipeline
# Split files; each is loaded with torch.load and indexed with "data_list".
TRAIN_PT = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles/atom_level_train.pt"
VAL_PT = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles/atom_level_val.pt"
TEST_PT = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles/atom_level_test.pt"

# LOG_DIR is handed to the training helper as ``log_dir``; BEST_DIR is created
# on start-up and receives one best-model checkpoint per seed.
LOG_DIR = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles/logs_tensorboard"
BEST_DIR = "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/atom-level-smiles/BEST_MODEL"


def _resolve_data_py_dir():
    """
    Find the .../GNN/data_py_file directory.

    When the code runs from a file (``__file__`` is defined) the parent of
    that file's directory is returned. Otherwise, e.g. in a notebook cell, a
    few known locations are probed and the first existing directory wins.

    Raises FileNotFoundError if none of the probed locations exists.
    """
    try:
        script_path = __file__
    except NameError:
        # No __file__ in this execution mode: fall through to the path probing.
        pass
    else:
        script_dir = os.path.dirname(os.path.abspath(script_path))
        return os.path.dirname(script_dir)

    cwd_drive_path = os.path.join(
        os.getcwd(),
        "My Drive/master_thesis/sampled_data_5000/GNN/data_py_file",
    )
    cwd_local_path = os.path.join(os.getcwd(), "data_py_file")
    search_order = (
        "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/data_py_file",
        cwd_drive_path,
        cwd_local_path,
    )
    existing = [path for path in search_order if os.path.isdir(path)]
    if existing:
        return existing[0]

    raise FileNotFoundError(
        "Cannot locate '.../GNN/data_py_file'. "
        "Set working directory to project root or run this file as a script."
    )



def _load_model_class():
    """
    Import the classifier class from the model file on disk:
      .../data_py_file/atom-level-smiles/model_smiles.py

    ``GINEVirtualNodeClassifierAtom`` is preferred; the older name
    ``GINEVirtualNodeClassifier`` is accepted as a fallback.

    Raises FileNotFoundError if the file is missing and AttributeError if it
    defines neither class.
    """
    # NOTE(review): this folder name is spelled with hyphens; confirm it
    # matches the directory that actually holds model_smiles.py on disk.
    model_path = os.path.join(_resolve_data_py_dir(), "atom-level-smiles", "model_smiles.py")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model_smiles.py not found at: {model_path}")

    # Load the file as a standalone module named "model_smiles".
    spec = importlib.util.spec_from_file_location("model_smiles", model_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    for class_name in ("GINEVirtualNodeClassifierAtom", "GINEVirtualNodeClassifier"):
        if hasattr(module, class_name):
            return getattr(module, class_name)

    raise AttributeError("No compatible classifier class found in model_smiles.py")


def _import_train_utils():
    """Import the shared training helpers and return them in a dict by name.

    The resolved data_py_file directory is put at the front of ``sys.path``
    (once) so that ``train_utils`` can be imported from it.
    """
    helpers_dir = _resolve_data_py_dir()
    if helpers_dir not in sys.path:
        sys.path.insert(0, helpers_dir)

    from train_utils import (
        compute_pos_weight,
        evaluate,
        final_test_report,
        find_best_threshold_f1,
        fit_with_validation,
        set_seed,
    )

    return dict(
        compute_pos_weight=compute_pos_weight,
        evaluate=evaluate,
        final_test_report=final_test_report,
        find_best_threshold_f1=find_best_threshold_f1,
        fit_with_validation=fit_with_validation,
        set_seed=set_seed,
    )


def build_loaders_with_seed_pyg(train_dataset, val_dataset, test_dataset, batch_size, seed):
    """Create train/val/test loaders with a seeded shuffle for training.

    Only the training loader shuffles, using its own ``torch.Generator``
    seeded with ``seed``; validation and test loaders keep dataset order.
    Returns ``(train_loader, val_loader, test_loader)``.
    """
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, generator=shuffle_rng
    )
    val_loader, test_loader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=False)
        for dataset in (val_dataset, test_dataset)
    )
    return train_loader, val_loader, test_loader


def mean_std(vals):
    """Return ``(mean, sample standard deviation)`` of ``vals``.

    The standard deviation uses ``ddof=1``; with fewer than two values it is
    reported as ``0.0``.
    """
    arr = np.array(vals, dtype=float)
    if len(arr) > 1:
        spread = arr.std(ddof=1)
    else:
        spread = 0.0
    return arr.mean(), spread


def main():
    """Train and evaluate the atom-level classifier once per seed.

    For each seed: build loaders, create a fresh model, train with the
    external ``fit_with_validation`` helper, reload the checkpoint stored at
    ``best_model_path``, evaluate on validation and test, and record the
    metrics. Finally print mean ± std of each metric over all seeds.
    """
    os.makedirs(BEST_DIR, exist_ok=True)

    # weights_only=False allows arbitrary pickled objects to be loaded (the
    # files hold a "data_list" entry, not plain weights). This unpickles
    # code-capable data, so only load split files you created yourself.
    train_list = torch.load(TRAIN_PT, map_location="cpu", weights_only=False)["data_list"]
    val_list = torch.load(VAL_PT, map_location="cpu", weights_only=False)["data_list"]
    test_list = torch.load(TEST_PT, map_location="cpu", weights_only=False)["data_list"]

    if torch.cuda.is_available():
        print("GPU available")
        device = torch.device("cuda")
    else:
        print("Using CPU")
        device = torch.device("cpu")

    # Model class and training helpers are loaded from files on disk.
    ModelCls = _load_model_class()
    utils = _import_train_utils()

    seeds = [0, 1, 2]
    results = []

    for seed in seeds:
        print(f"\n=== Run seed={seed} ===")
        # Seed first so loader shuffling and model initialisation below are
        # tied to this seed (presumably set_seed seeds the global RNGs —
        # confirm in train_utils).
        utils["set_seed"](seed)

        train_loader, val_loader, test_loader = build_loaders_with_seed_pyg(
            train_list, val_list, test_list, batch_size=32, seed=seed
        )

        model = ModelCls(
            hidden_dim=64,
            num_layers=3,
            dropout=0.2,
            pooling="mean",
        )

        # pos_weight comes from the training loader and is fed to the BCE loss
        # (presumably to offset class imbalance — confirm in train_utils).
        pos_weight = utils["compute_pos_weight"](train_loader, device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

        # One checkpoint file and one run prefix per seed.
        exp_prefix = "atom_01"
        run_prefix = f"{exp_prefix}-seed{seed:02d}"
        best_model_path = os.path.join(BEST_DIR, f"{run_prefix}_best_model.pt")

        utils["fit_with_validation"](
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            optimizer=optimizer,
            loss_fn=loss_fn,
            epochs=200,
            patience=10,
            model_path=best_model_path,
            log_dir=LOG_DIR,
            run_prefix=run_prefix,
        )

        # Evaluate the checkpoint stored at best_model_path rather than the
        # model's state at the end of training (assumes fit_with_validation
        # wrote that file — confirm in train_utils).
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        model = model.to(device)

        # The decision threshold is chosen from validation outputs only and
        # then reused unchanged on the test split.
        val_metrics = utils["evaluate"](model, val_loader, device, loss_fn, threshold=0.5)
        best_t, best_f1, _, _ = utils["find_best_threshold_f1"](
            val_metrics["probs"], val_metrics["labels"]
        )
        test_metrics = utils["evaluate"](model, test_loader, device, loss_fn, threshold=best_t)

        utils["final_test_report"](
            model=model,
            val_loader=val_loader,
            test_loader=test_loader,
            device=device,
            loss_fn=loss_fn,
            model_path=best_model_path,
        )

        # NOTE(review): val_acc below comes from the threshold=0.5 evaluation,
        # whereas val_f1 is the value returned by the threshold search and the
        # test metrics use best_t — the validation columns mix thresholds.
        results.append(
            {
                "seed": seed,
                "val_auc": val_metrics["roc_auc"],
                "val_pr_auc": val_metrics["pr_auc"],
                "val_f1": best_f1,
                "val_acc": val_metrics["accuracy"],
                "test_auc": test_metrics["roc_auc"],
                "test_pr_auc": test_metrics["pr_auc"],
                "test_f1": test_metrics["f1"],
                "test_acc": test_metrics["accuracy"],
            }
        )

    print("\n=== Summary over seeds ===")
    for split in ["val", "test"]:
        for key in ["auc", "pr_auc", "f1", "acc"]:
            m, s = mean_std([r[f"{split}_{key}"] for r in results])
            print(f"{split}/{key}: {m:.4f} ± {s:.4f}")


if __name__ == "__main__":
    main()




Using CPU


NameError: name '__file__' is not defined

In [None]:
# Option 1 (recommended): run the script file directly
!python "/content/drive/MyDrive/master_thesis/sampled_data_5000/GNN/data_py_file/atom_level_smiles/train_atom_gnn.py"


  import torch_geometric.typing
  import torch_geometric.typing
Using CPU
2026-02-19 14:53:55.399807: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-19 14:53:55.925781: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-19 14:53:56.227986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771512836.285137    5984 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771512836.297215    5984 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771512836.332722    5984 computation_placer.cc:177] co