# MS-to-Structure Deep Learning Pipeline 

Steps: Step-by-Step Pipeline Overview
1. Setup & Environment (Cells 1-3)
Installs required packages (PyTorch, RDKit, XGBoost, FAISS, etc.)
Configures GPU optimization for RTX 3080 Ti
Sets random seeds for reproducibility
Defines special tokens (PAD, SOS, EOS, MASK)

2. Data Loading & Configuration (Cell 4)
Loads MassSpecGym dataset (231K samples total)
Splits into training (208K) and external test (23K) sets
Configures hyperparameters (model dimensions, epochs, batch sizes)
Inspects dataset structure (m/z values, intensities, SMILES, adducts)

3. Data Preprocessing (Cells 5-7)
SMILES Processing : Canonicalizes and augments SMILES with stereoisomers
Spectrum Binning : Converts raw m/z peaks into 1000-bin vectors
Graph Creation : Builds molecular graphs from spectra
Feature Engineering : Extracts ion modes, precursor m/z, adduct types
Data Cleaning : Handles shape issues and missing values

4. XGBoost Baseline (Cell 8)
Extracts numerical features from spectra (mean, std, max intensity, peak count)
Encodes SMILES strings as classification targets
Trains gradient boosting model on spectral features
Provides feature importance analysis

5. RAG System (Cells 10-15)
Molecular Indexing : Creates semantic embeddings of molecular descriptions
Fingerprint Database : Builds Morgan fingerprint index for structure similarity

Search Capabilities :
Semantic search using sentence transformers
Structure-based search using Tanimoto similarity
Hybrid search combining both approaches
FAISS Integration : Enables fast similarity search

6. Deep Learning Architecture (Cells 16-20)
SELFIES Tokenization : Converts SMILES to robust molecular tokens

Hybrid Model :
Transformer Encoder : Processes binned spectra with attention
GNN Encoder : Analyzes molecular graph structure
Fusion Layer : Combines transformer and GNN representations
Transformer Decoder : Generates SMILES sequences

Training Pipeline :
Self-supervised pretraining with masked language modeling
Supervised fine-tuning with cross-entropy loss
Beam search for inference

7. Training & Validation (Cells 21-22)
Cross-Validation : 5-fold CV for robust evaluation
Hyperparameter Optimization : Uses Optuna for learning rate tuning
Memory Management : Handles OOM errors with batch size reduction
Model Checkpointing : Saves best models for each fold

8. Evaluation & Analysis (Cells 23-24)
Multiple Metrics : Tanimoto similarity, Dice coefficient, MCS overlap
Molecular Properties : MW and LogP difference analysis
Visualization : Attention weights, molecular structures, performance plots
Error Analysis : Identifies failure modes and improvement areas

9. Integration System (Cell 21)
Ensemble Prediction : Combines XGBoost, Deep Learning, and RAG
Weighted Scoring : Balances different prediction approaches
Confidence Estimation : Provides uncertainty quantification


In [1]:
#cell1
# Install required packages for MS-to-Structure pipeline
#! pip install torch torch_geometric rdkit-pypi selfies datasets optuna nltk python-Levenshtein tqdm scikit-learn matplotlib xgboost faiss-cpu sentence-transformers

In [2]:
#cell 2
# Import libraries and set up logging for Jupyter compatibility
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data, Batch
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, rdFMCS, EnumerateStereoisomers
from rdkit.Chem.EnumerateStereoisomers import StereoEnumerationOptions
from rdkit import DataStructs
from rdkit.Chem import rdFingerprintGenerator
from rdkit import RDLogger
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import selfies as sf
import optuna
from nltk.translate.bleu_score import sentence_bleu
from Levenshtein import distance
import logging
import traceback
import math
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder
import faiss
from sentence_transformers import SentenceTransformer

# Setup logging for Jupyter (prints to stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s'
)


In [3]:
#cell 3
# Set random seed for reproducibility and define global variables
np.random.seed(42)
torch.manual_seed(42)

PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
MASK_TOKEN = "[MASK]"

# GPU optimization for RTX 3080 Ti
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.95)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name()}')


GPU: NVIDIA GeForce RTX 3080 Ti


In [4]:
#cell 4
# Production Configuration
class Config:
    DATASET_PATH = '/home/onepaw/dataset'  # Change to your dataset path
    TRAIN_SPLIT = 0.9
    RANDOM_SEED = 42
    N_BINS = 1000
    MAX_MZ = 1000
    NOISE_LEVEL = 0.05
    MAX_ISOMERS = 8
    D_MODEL = 512
    NHEAD = 8
    NUM_LAYERS = 6
    BATCH_SIZE = 64
    SSL_EPOCHS = 3
    SUPERVISED_EPOCHS = 30
    LEARNING_RATE = 1e-4
    PATIENCE = 5
    N_FOLDS = 5
    # Token definitions
    PAD_TOKEN = '<PAD>'
    SOS_TOKEN = '< SOS >'
    EOS_TOKEN = '<EOS>'
    MASK_TOKEN = '[MASK]'

config = Config()

# Load dataset with configurable path
try:
    dataset = load_dataset(config.DATASET_PATH, split='train')
    df = pd.DataFrame(dataset)
    print(f'Loaded dataset with {len(df)} samples')
except Exception as e:
    print(f'Error loading dataset: {e}')
    print('Please update config.DATASET_PATH')
    raise

# Split dataset based on configuration
split_idx = int(config.TRAIN_SPLIT * len(df))
df_massspecgym = df.iloc[:split_idx].copy()
df_external = df.iloc[split_idx:].copy()
print("MassSpecGym size:", len(df_massspecgym), "External test size:", len(df_external))

# Inspect dataset
print("Dataset Columns:", df_massspecgym.columns.tolist())
print("\nFirst few rows of MassSpecGym dataset:")
print(df_massspecgym[['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'precursor_mz']].head())
print("\nUnique adduct values:", df_massspecgym['adduct'].unique())


Loaded dataset with 231104 samples
MassSpecGym size: 207993 External test size: 23111
Dataset Columns: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge']

First few rows of MassSpecGym dataset:
             identifier                                                mzs  \
0  MassSpecGymID0000001  91.0542,125.0233,154.0499,155.0577,185.0961,20...   
1  MassSpecGymID0000002  91.0542,125.0233,155.0577,185.0961,229.0859,24...   
2  MassSpecGymID0000003  69.0343,91.0542,125.0233,127.039,153.0699,154....   
3  MassSpecGymID0000004  69.0343,91.0542,110.06,111.0441,112.0393,120.0...   
4  MassSpecGymID0000005  91.0542,125.0233,185.0961,229.0859,246.1125,28...   

                                         intensities  \
0  0.24524524524524524,1.0,0.08008008008008008,0....   
1  0.0990990990990991,0.28128128128128127,0.04004...   
2  0.034034034034

In [5]:
#cell 5
# Canonicalize SMILES, augment, and bin spectra
def canonicalize_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol:
            return Chem.MolToSmiles(mol, canonical=True)
        return None
    except Exception as e:
        logging.error(f"canonicalize_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return None

def augment_smiles(smiles, max_isomers=None):
    max_isomers = max_isomers or config.MAX_ISOMERS
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            opts = EnumerateStereoisomers.StereoEnumerationOptions()
            opts.maxIsomers = max_isomers
            stereoisomers = EnumerateStereoisomers.EnumerateStereoisomers(mol, options=opts)
            return [
                Chem.MolToSmiles(m, canonical=True, doRandom=True) for m in stereoisomers
                ]
        return [smiles]
    except Exception as e:
        logging.error(f"augment_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return [smiles]

def bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct, n_bins=1000, max_mz=1000, noise_level=0.05):
    try:
        spectrum = np.zeros(n_bins)
        for mz, intensity in zip(mzs, intensities):
            try:
                mz = float(mz)
                intensity = float(intensity)
                if mz < max_mz:
                    bin_idx = int((mz / max_mz) * n_bins)
                    spectrum[bin_idx] += intensity
            except (ValueError, TypeError) as e:
                logging.warning(f"bin_spectrum_to_graph: Skipping value error: {e}")
                continue
        if spectrum.max() > 0:
            spectrum = spectrum / spectrum.max()
        spectrum += np.random.normal(0, noise_level, spectrum.shape).clip(0, 1)
        x = torch.tensor(spectrum, dtype=torch.float).unsqueeze(-1)
        edge_index = []
        for i in range(n_bins-1):
            edge_index.append([i, i+1])
            edge_index.append([i+1, i])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t()
        ion_mode = torch.tensor([ion_mode], dtype=torch.float)
        precursor_mz = torch.tensor([precursor_mz], dtype=torch.float)
        adduct_idx = adduct_to_idx.get(adduct, 0)
        return spectrum, Data(x=x, edge_index=edge_index, ion_mode=ion_mode, precursor_mz=precursor_mz, adduct_idx=adduct_idx)
    except Exception as e:
        logging.error(f"bin_spectrum_to_graph failed: {e}\n{traceback.format_exc()}")
        return np.zeros(n_bins), Data(x=torch.zeros(n_bins, 1), edge_index=torch.zeros(2, 0, dtype=torch.long), ion_mode=torch.zeros(1), precursor_mz=torch.zeros(1), adduct_idx=0)
# Canonicalize SMILES, augment, and bin spectra
def canonicalize_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol:
            return Chem.MolToSmiles(mol, canonical=True)
        return None
    except Exception as e:
        logging.error(f"canonicalize_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return None

def augment_smiles(smiles, max_isomers=None):
    max_isomers = max_isomers or config.MAX_ISOMERS
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            opts = EnumerateStereoisomers.StereoEnumerationOptions()
            opts.maxIsomers = max_isomers
            stereoisomers = EnumerateStereoisomers.EnumerateStereoisomers(mol, options=opts)
            return [
                Chem.MolToSmiles(m, canonical=True, doRandom=True) for m in stereoisomers
                ]
        return [smiles]
    except Exception as e:
        logging.error(f"augment_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return [smiles]

def bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct, n_bins=1000, max_mz=1000, noise_level=0.05):
    try:
        spectrum = np.zeros(n_bins)
        for mz, intensity in zip(mzs, intensities):
            try:
                mz = float(mz)
                intensity = float(intensity)
                if mz < max_mz:
                    bin_idx = int((mz / max_mz) * n_bins)
                    spectrum[bin_idx] += intensity
            except (ValueError, TypeError) as e:
                logging.warning(f"bin_spectrum_to_graph: Skipping value error: {e}")
                continue
        if spectrum.max() > 0:
            spectrum = spectrum / spectrum.max()
        spectrum += np.random.normal(0, noise_level, spectrum.shape).clip(0, 1)
        x = torch.tensor(spectrum, dtype=torch.float).unsqueeze(-1)
        edge_index = []
        for i in range(n_bins-1):
            edge_index.append([i, i+1])
            edge_index.append([i+1, i])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t()
        ion_mode = torch.tensor([ion_mode], dtype=torch.float)
        precursor_mz = torch.tensor([precursor_mz], dtype=torch.float)
        adduct_idx = adduct_to_idx.get(adduct, 0)
        return spectrum, Data(x=x, edge_index=edge_index, ion_mode=ion_mode, precursor_mz=precursor_mz, adduct_idx=adduct_idx)
    except Exception as e:
        logging.error(f"bin_spectrum_to_graph failed: {e}\n{traceback.format_exc()}")
        return np.zeros(n_bins), Data(x=torch.zeros(n_bins, 1), edge_index=torch.zeros(2, 0, dtype=torch.long), ion_mode=torch.zeros(1), precursor_mz=torch.zeros(1), adduct_idx=0)


In [6]:
#cell 6
# Apply canonicalization, augmentation, and binning to the dataframe
# Preprocess ion mode, precursor m/z, and adducts
df_massspecgym['smiles'] = df_massspecgym['smiles'].apply(canonicalize_smiles)
df_external['smiles'] = df_external['smiles'].apply(canonicalize_smiles)
df_massspecgym = df_massspecgym.dropna(subset=['smiles'])
df_external = df_external.dropna(subset=['smiles'])
df_massspecgym['smiles_list'] = df_massspecgym['smiles'].apply(augment_smiles)
df_massspecgym = df_massspecgym.explode('smiles_list').dropna(subset=['smiles_list']).rename(columns={'smiles_list': 'smiles'})

df_massspecgym['ion_mode'] = df_massspecgym['adduct'].apply(lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0).fillna(0)
df_massspecgym['precursor_bin'] = pd.qcut(df_massspecgym['precursor_mz'], q=100, labels=False, duplicates='drop')
df_external['ion_mode'] = df_external['adduct'].apply(lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0).fillna(0)
df_external['precursor_bin'] = pd.qcut(df_external['precursor_mz'], q=100, labels=False, duplicates='drop')
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}
df_massspecgym['adduct_idx'] = df_massspecgym['adduct'].map(adduct_to_idx)
df_external['adduct_idx'] = df_external['adduct'].map(adduct_to_idx)

def safe_bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct):
    try:
        # Clean mzs and intensities to remove non-numeric values
        mzs_clean = [float(x) for x in mzs if isinstance(x, (int, float)) or (isinstance(x, str) and x.replace('.','',1).replace('-','',1).isdigit())]
        intensities_clean = [float(x) for x in intensities if isinstance(x, (int, float)) or (isinstance(x, str) and x.replace('.','',1).replace('-','',1).isdigit())]
        return bin_spectrum_to_graph(mzs_clean, intensities_clean, ion_mode, precursor_mz, adduct)
    except Exception as e:
        logging.warning(f"Skipping value error in safe_bin_spectrum_to_graph: {e}")
        return np.zeros(100), None

df_massspecgym[['binned', 'graph_data']] = df_massspecgym.apply(
    lambda row: pd.Series(safe_bin_spectrum_to_graph(row['mzs'], row['intensities'], row['ion_mode'], row['precursor_mz'], row['adduct'])),
    axis=1
    )
df_external[['binned', 'graph_data']] = df_external.apply(
    lambda row: pd.Series(safe_bin_spectrum_to_graph(row['mzs'], row['intensities'], row['ion_mode'], row['precursor_mz'], row['adduct'])),
    axis=1
)


In [7]:
#cell 7

# Fix SMILES column shape issue
# Ensure SMILES column is 1D and contains only strings
def flatten_smiles_column(col):
    # Debug print
    print(f"Initial column shape: {col.values.shape if hasattr(col.values, 'shape') else 'no shape'}")
    print(f"Sample value: {col.iloc[0]}")
    
    # Force to Series if DataFrame
    if isinstance(col, pd.DataFrame):
        col = col.iloc[:, 0]
    
    # Force 2D array to 1D
    if hasattr(col.values, 'shape') and len(col.values.shape) > 1:
        col = pd.Series(col.values.ravel())
    
    # Flatten any remaining sequences in cells
    col = col.apply(lambda x: str(x[0]) if isinstance(x, (list, tuple, np.ndarray)) else str(x))
    
    # Debug print
    print(f"Final column shape: {col.values.shape if hasattr(col.values, 'shape') else 'no shape'}")
    print(f"Final sample value: {col.iloc[0]}")
    
    return col

# Reset the column to ensure clean state
df_massspecgym = df_massspecgym.copy()
df_external = df_external.copy()

print("Processing df_massspecgym['smiles']...")
df_massspecgym['smiles'] = flatten_smiles_column(df_massspecgym['smiles'])

print("\nProcessing df_external['smiles']...")
df_external['smiles'] = flatten_smiles_column(df_external['smiles'])

# Verify the columns are 1D before LabelEncoder
print("\nFinal verification:")
print(f"df_massspecgym['smiles'] shape: {df_massspecgym['smiles'].values.shape}")
print(f"df_external['smiles'] shape: {df_external['smiles'].values.shape}")



Processing df_massspecgym['smiles']...
Initial column shape: (562533, 2)
Sample value: smiles     COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
smiles    COc1cc(oc(c1)=O)[C@@H](NC(C)=O)Cc1ccccc1
Name: 0, dtype: object
Final column shape: (562533,)
Final sample value: COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1

Processing df_external['smiles']...
Initial column shape: (23111,)
Sample value: CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(OC)c(OC)c1-2)C(=O)Nc1nc(C(=O)OC)c(-c2ccccc2)s1
Final column shape: (23111,)
Final sample value: CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(OC)c(OC)c1-2)C(=O)Nc1nc(C(=O)OC)c(-c2ccccc2)s1

Final verification:
df_massspecgym['smiles'] shape: (562533, 2)
df_external['smiles'] shape: (23111,)


In [8]:
# Cell 8
# Extract features and prepare XGBoost data

# ----------------------------
# 1) Flatten SMILES column
# ----------------------------
smiles_flat = df_massspecgym['smiles'].iloc[:, 0].astype(str)
print("Flattened SMILES shape:", smiles_flat.shape)
print("Example:", smiles_flat.iloc[0])

# ----------------------------
# 2) Encode labels
# ----------------------------
le = LabelEncoder()
y_all = le.fit_transform(smiles_flat)
print("y_all shape:", y_all.shape)
print("Number of unique labels:", len(np.unique(y_all)))

# ----------------------------
# 3) Feature extraction
# ----------------------------
def extract_features(df):
    feats = []
    for _, row in df.iterrows():
        spectrum = np.array(row['binned'], dtype=np.float32)
        mzs = row['mzs']
        if isinstance(mzs, str):
            mzs_list = [m for m in mzs.split(',') if m]
        else:
            mzs_list = mzs if isinstance(mzs, (list, tuple, np.ndarray)) else []
        mz_len = len(mzs_list)

        feat = [
            float(np.mean(spectrum)),
            float(np.std(spectrum)),
            float(np.max(spectrum)),
            float(np.sum(spectrum > 0.1)),
            float(row['precursor_mz']),
            float(row['ion_mode']),
            float(row['adduct_idx']),
            float(mz_len)
        ]
        feats.append(feat)
    return np.array(feats, dtype=np.float32)

X_all = extract_features(df_massspecgym)
print("X_all shape:", X_all.shape)

# ----------------------------
# 4) Subsample if needed
# ----------------------------
subset_size = min(50000, X_all.shape[0])
X_all = X_all[:subset_size]
smiles_subset = smiles_flat[:subset_size]

# Re-encode labels after subsetting
le_subset = LabelEncoder()
y_all = le_subset.fit_transform(smiles_subset)

# Filter out classes with fewer than 2 samples
from collections import Counter
class_counts = Counter(y_all)
valid_classes = [cls for cls, count in class_counts.items() if count >= 2]
mask = np.isin(y_all, valid_classes)

X_all = X_all[mask]
y_all = y_all[mask]
smiles_subset = smiles_subset[mask]

# NEW: Re-encode labels after filtering to ensure consecutive integers
le_final = LabelEncoder()
y_all = le_final.fit_transform(smiles_subset)  # Re-encode filtered SMILES

print("Filtered X_all shape:", X_all.shape)
print("Filtered number of unique labels:", len(np.unique(y_all)))

# ----------------------------
# 5) Train/test split (stratified)
# ----------------------------
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, random_state=42, stratify=y_all
)

print("Train classes:", np.unique(y_train).size)
print("Test classes:", np.unique(y_test).size)

# ----------------------------
# 6) XGBoost model setup
# ----------------------------
n_classes = len(np.unique(y_train))
xgb_model = xgb.XGBClassifier(
    n_estimators=200,
    max_depth=6,
    learning_rate=0.1,
    random_state=42,
    n_jobs=1,
    num_class=n_classes  # Explicitly set number of classes
)

# ----------------------------
# 7) Train model
# ----------------------------
print("Training XGBoost...")
xgb_model.fit(X_train, y_train)

# ----------------------------
# 8) Predict & evaluate
# ----------------------------
y_pred = xgb_model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"XGBoost Accuracy: {acc:.4f}")

# ----------------------------
# 9) Feature importances
# ----------------------------
feature_names = [
    'mean_intensity', 'std_intensity', 'max_intensity', 'peak_count',
    'precursor_mz', 'ion_mode', 'adduct_idx', 'spectrum_length'
]

for name, val in zip(feature_names, xgb_model.feature_importances_):
    print(f"{name}: {val:.4f}")

Flattened SMILES shape: (562533,)
Example: COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
y_all shape: (562533,)
Number of unique labels: 25944
X_all shape: (562533, 8)
Filtered X_all shape: (49994, 8)
Filtered number of unique labels: 1012
Train classes: 1012
Test classes: 992
Training XGBoost...
XGBoost Accuracy: 0.6966
mean_intensity: 0.0061
std_intensity: 0.0103
max_intensity: 0.0056
peak_count: 0.0056
precursor_mz: 0.5358
ion_mode: 0.0000
adduct_idx: 0.4180
spectrum_length: 0.0185


X_all shape: (562533, 8)
Filtered X_all shape: (49994, 8)
Filtered number of unique labels: 1012
Train classes: 1012
Test classes: 992
Training XGBoost...
XGBoost Accuracy: 0.6966
mean_intensity: 0.0061
std_intensity: 0.0103
max_intensity: 0.0056
peak_count: 0.0056
precursor_mz: 0.5358
ion_mode: 0.0000
adduct_idx: 0.4180
spectrum_length: 0.0185


In [9]:
#cell 9
# Display results
print('\nSample predictions:')
for i in range(min(5, len(y_test))):
    true_smiles = le.inverse_transform([y_test[i]])[0]
    pred_smiles = le.inverse_transform([y_pred[i]])[0]
    print(f'True: {true_smiles}')
    print(f'Pred: {pred_smiles}')
    print(f'Match: {true_smiles == pred_smiles}\n')


Sample predictions:
True: C/C=C1/[C@H](O[C@@H]2O[C@H](CO)[C@@H](O)[C@H](O)[C@H]2O)OC=C(C(=O)OC)[C@H]1CC(=O)O[C@@H]1O[C@H](CO)[C@@H](O)[C@H](O)[C@H]1O
Pred: C/C=C1/[C@H](O[C@@H]2O[C@H](CO)[C@@H](O)[C@H](O)[C@H]2O)OC=C(C(=O)OC)[C@H]1CC(=O)O[C@@H]1O[C@H](CO)[C@@H](O)[C@H](O)[C@H]1O
Match: True

True: C/C=C1/C[C@@]2(CO)O[C@@]2(C)C(=O)OCC2=CCN3CC[C@@H](OC1=O)[C@@H]23
Pred: C/C=C1/C[C@@]2(CO)O[C@@]2(C)C(=O)OCC2=CCN3CC[C@@H](OC1=O)[C@@H]23
Match: True

True: C=C1CCCC2C1(C)CCC(C)C2(C)CC1=C(O)C(NCCC(=O)O)=CC(=O)C1=O
Pred: C=C1C(O)CCC2(C)C1CCC1=C3CCC(C(C)CCC(C)C(C)C)C3(C)CCC12
Match: False

True: C=C(C)[C@H]1COc2cc3oc(=O)ccc3cc2O1
Pred: C=C(C)[C@H]1COc2cc3oc(=O)ccc3cc2O1
Match: True

True: C=C(C(=O)O)C1CCC(C)C2CCC(=O)OC2(C)C1
Pred: C=C1/C=C/C(=O)N(C)CC(=O)N[C@@H]([C@H](O)C(N)=O)C(=O)OC([C@H](C)CCCCCC)[C@H](C)C(=O)N1
Match: False



In [10]:
#cell 10
# Enhanced RAG System for Molecular Data
class MolecularRAG:
    def __init__(self, df):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.df = df.copy()
        self.morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
        self.build_molecular_descriptions()
        self.build_index()
        self.build_fingerprint_index()

    def get_molecular_properties(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol:
                mw = Descriptors.MolWt(mol)
                logp = Descriptors.MolLogP(mol)
                hbd = Descriptors.NumHDonors(mol)
                hba = Descriptors.NumHAcceptors(mol)
                rings = Descriptors.RingCount(mol)
                aromatic = Descriptors.NumAromaticRings(mol)
                return {'mw': mw, 'logp': logp, 'hbd': hbd, 'hba': hba, 'rings': rings, 'aromatic': aromatic}
        except:
            pass
        return {'mw': 0, 'logp': 0, 'hbd': 0, 'hba': 0, 'rings': 0, 'aromatic': 0}

    def build_molecular_descriptions(self):
        descriptions = []
        for _, row in self.df.iterrows():
            props = self.get_molecular_properties(row['smiles'])
            desc = f"Molecule with SMILES {row['smiles']}. "
            desc += f"Molecular weight: {props['mw']:.1f} Da. "
            desc += f"LogP: {props['logp']:.2f}. "
            desc += f"H-bond donors: {props['hbd']}, acceptors: {props['hba']}. "
            desc += f"Contains {props['rings']} rings, {props['aromatic']} aromatic. "
            desc += f"Adduct: {row['adduct']}, precursor m/z: {row['precursor_mz']:.2f}. "
            desc += f"Ion mode: {'positive' if row['ion_mode'] == 0 else 'negative'}."
            descriptions.append(desc)
        self.descriptions = descriptions

    def build_index(self):
        print('Building semantic index...')
        self.embeddings = self.encoder.encode(self.descriptions, show_progress_bar=True)
        self.semantic_index = faiss.IndexFlatIP(self.embeddings.shape[1])
        faiss.normalize_L2(self.embeddings)
        self.semantic_index.add(self.embeddings.astype('float32'))

    def build_fingerprint_index(self):
        print('Building fingerprint index...')
        fingerprints = []
        for smiles in self.df['smiles']:
            mol = Chem.MolFromSmiles(smiles)
            if mol:
                fp = self.morgan_gen.GetFingerprint(mol)
                fp_array = np.zeros(2048)
                DataStructs.ConvertToNumpyArray(fp, fp_array)
                fingerprints.append(fp_array)
            else:
                fingerprints.append(np.zeros(2048))
        self.fingerprints = np.array(fingerprints)
        self.fp_index = faiss.IndexFlatIP(2048)
        self.fp_index.add(self.fingerprints.astype('float32'))

    def semantic_search(self, query, k=5):
        query_emb = self.encoder.encode([query])
        faiss.normalize_L2(query_emb)
        scores, indices = self.semantic_index.search(query_emb.astype('float32'), k)
        results = []
        for i, idx in enumerate(indices[0]):
            row = self.df.iloc[idx]
            results.append({
                'smiles': row['smiles'],
                'score': scores[0][i],
                'adduct': row['adduct'],
                'precursor_mz': row['precursor_mz'],
                'description': self.descriptions[idx]
            })
        return results

    def structure_search(self, query_smiles, k=5):
        mol = Chem.MolFromSmiles(query_smiles)
        if not mol:
            return []
        query_fp = self.morgan_gen.GetFingerprint(mol)
        query_array = np.zeros(2048)
        DataStructs.ConvertToNumpyArray(query_fp, query_array)
        scores, indices = self.fp_index.search(query_array.reshape(1, -1).astype('float32'), k)
        results = []
        for i, idx in enumerate(indices[0]):
            row = self.df.iloc[idx]
            results.append({
                'smiles': row['smiles'],
                'tanimoto': scores[0][i],
                'adduct': row['adduct'],
                'precursor_mz': row['precursor_mz']
            })
        return results

    def hybrid_search(self, query, query_smiles=None, k=5, alpha=0.7):
        semantic_results = self.semantic_search(query, k*2)
        if query_smiles:
            structure_results = self.structure_search(query_smiles, k*2)
            # Combine scores
            combined = {}
            for r in semantic_results:
                combined[r['smiles']] = {'semantic': r['score'], 'structure': 0, 'data': r}
            for r in structure_results:
                if r['smiles'] in combined:
                    combined[r['smiles']]['structure'] = r['tanimoto']
                else:
                    combined[r['smiles']] = {'semantic': 0, 'structure': r['tanimoto'], 'data': r}
            # Hybrid scoring
            for smiles in combined:
                combined[smiles]['hybrid_score'] = alpha * combined[smiles]['semantic'] + (1-alpha) * combined[smiles]['structure']
            sorted_results = sorted(combined.items(), key=lambda x: x[1]['hybrid_score'], reverse=True)
            return [{'smiles': smiles, 'hybrid_score': data['hybrid_score'], 'semantic_score': data['semantic'], 'structure_score': data['structure']} for smiles, data in sorted_results[:k]]
        return semantic_results[:k]

print('Initializing enhanced RAG system...')
rag_system = MolecularRAG(df_massspecgym)
print('RAG system ready!')


2025-11-19 03:52:47,501 INFO Use pytorch device_name: cuda:0
2025-11-19 03:52:47,501 INFO Load pretrained SentenceTransformer: all-MiniLM-L6-v2


Initializing enhanced RAG system...
Building semantic index...


Batches:   0%|          | 0/17580 [00:00<?, ?it/s]

Building fingerprint index...
RAG system ready!


[04:00:54] SMILES Parse Error: syntax error while parsing: smiles
[04:00:54] SMILES Parse Error: check for mistakes around position 2:
[04:00:54] smiles
[04:00:54] ~^
[04:00:54] SMILES Parse Error: Failed parsing SMILES 'smiles' for input: 'smiles'
[04:00:54] SMILES Parse Error: syntax error while parsing: smiles
[04:00:54] SMILES Parse Error: check for mistakes around position 2:
[04:00:54] smiles
[04:00:54] ~^
[04:00:54] SMILES Parse Error: Failed parsing SMILES 'smiles' for input: 'smiles'


In [11]:
#cell 11
# Semantic Search Examples
queries = [
    'aromatic compound with hydroxyl group',
    'small molecule with high logP',
    'compound with multiple rings and nitrogen'
]

for query in queries:
    print(f'\nQuery: {query}')
    results = rag_system.semantic_search(query, k=3)
    for i, result in enumerate(results):
        print(f'{i+1}. SMILES: {result["smiles"]} (Score: {result["score"]:.4f})')
        print(f'   Adduct: {result["adduct"]}, m/z: {result["precursor_mz"]:.2f}')


Query: aromatic compound with hydroxyl group


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

1. SMILES: smiles    c1ccc2c(c1)cnc1ccccc12
smiles    c1ccc2c(c1)cnc1ccccc12
Name: 99678, dtype: object (Score: 0.3758)
   Adduct: [M+H]+, m/z: 180.08
2. SMILES: smiles    c1ccc2c(c1)cnc1ccccc12
smiles    c1ccc2c(c1)cnc1ccccc12
Name: 99679, dtype: object (Score: 0.3756)
   Adduct: [M+H]+, m/z: 180.08
3. SMILES: smiles    CN(C)c1cccc(Br)c1
smiles    CN(C)c1cccc(Br)c1
Name: 52971, dtype: object (Score: 0.3739)
   Adduct: [M+H]+, m/z: 200.01

Query: small molecule with high logP


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

1. SMILES: smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
Name: 112392, dtype: object (Score: 0.4392)
   Adduct: [M+H]+, m/z: 455.35
2. SMILES: smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
Name: 112392, dtype: object (Score: 0.4392)
   Adduct: [M+H]+, m/z: 455.35
3. SMILES: smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
smiles    C[C@@H]1CC[C@@]23CC[C@]4(C)[C@@](C=CC5[C@@]6(C...
Name: 112392, dtype: object (Score: 0.4392)
   Adduct: [M+H]+, m/z: 455.35

Query: compound with multiple rings and nitrogen


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

1. SMILES: smiles    Nc1nc(N)nc(N)n1
smiles    Nc1nc(N)nc(N)n1
Name: 16817, dtype: object (Score: 0.4792)
   Adduct: [M+H]+, m/z: 127.07
2. SMILES: smiles    Nc1nc(N)nc(N)n1
smiles    Nc1nc(N)nc(N)n1
Name: 16815, dtype: object (Score: 0.4789)
   Adduct: [M+H]+, m/z: 127.07
3. SMILES: smiles    Nc1nc(N)nc(N)n1
smiles    Nc1nc(N)nc(N)n1
Name: 16812, dtype: object (Score: 0.4781)
   Adduct: [M+H]+, m/z: 127.07


In [12]:
#cell 12
# Structure-based Search
query_smiles = 'c1ccccc1O'  # phenol
print(f'Structure search for: {query_smiles}')
results = rag_system.structure_search(query_smiles, k=5)
for i, result in enumerate(results):
    print(f'{i+1}. SMILES: {result["smiles"]} (Tanimoto: {result["tanimoto"]:.4f})')
    print(f'   Adduct: {result["adduct"]}, m/z: {result["precursor_mz"]:.2f}')

Structure search for: c1ccccc1O
1. SMILES: smiles    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
smiles    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
Name: 1, dtype: object (Tanimoto: 0.0000)
   Adduct: [M+H]+, m/z: 288.12
2. SMILES: smiles    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
smiles    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
Name: 0, dtype: object (Tanimoto: 0.0000)
   Adduct: [M+H]+, m/z: 288.12
3. SMILES: smiles    CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(O...
smiles    CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(O...
Name: 207992, dtype: object (Tanimoto: -340282346638528859811704183484516925440.0000)
   Adduct: [M+Na]+, m/z: 737.26
4. SMILES: smiles    CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(O...
smiles    CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(O...
Name: 207992, dtype: object (Tanimoto: -340282346638528859811704183484516925440.0000)
   Adduct: [M+Na]+, m/z: 737.26
5. SMILES: smiles    CCC(C)C(Nc1ccc2c(cc1=O)C(NC(C)=O)CCc1cc(OC)c(O...
smiles    CCC(C)C(Nc1ccc2c(cc1=

In [13]:
#cell 13
# Hybrid Search
query = 'compound with multiple rings and nitrogen'
def hybrid_search(self, text_query, structure_query, k=5):
    # 1️ Get semantic and structure search results
    semantic_results = self.semantic_search(text_query, k=k)
    structure_results = self.structure_search(structure_query, k=k)

    # 2️ Merge safely
    combined = {}
    for r in semantic_results:
        smiles_key = r['smiles']
        if isinstance(smiles_key, pd.Series):
            smiles_key = smiles_key.iloc[0]
        if pd.isna(smiles_key):
            continue
        smiles_key = str(smiles_key)
        combined[smiles_key] = {'semantic': r['score'], 'structure': 0, 'data': r}

    for r in structure_results:
        smiles_key = r['smiles']
        if isinstance(smiles_key, pd.Series):
            smiles_key = smiles_key.iloc[0]
        if pd.isna(smiles_key):
            continue
        smiles_key = str(smiles_key)
        if smiles_key in combined:
            combined[smiles_key]['structure'] = r['score']
        else:
            combined[smiles_key] = {'semantic': 0, 'structure': r['score'], 'data': r}

    # 3️ Compute hybrid score
    results = []
    for k_smiles, v in combined.items():
        hybrid_score = v['semantic'] + v['structure']  # adjust weighting if needed
        results.append({
            'smiles': k_smiles,
            'semantic_score': v['semantic'],
            'structure_score': v['structure'],
            'hybrid_score': hybrid_score,
            'data': v['data']
        })

    # 4️ Sort and return top-k
    results = sorted(results, key=lambda x: x['hybrid_score'], reverse=True)
    return results[:k]


In [14]:

# Cell 22 
# RAG System Analysis
import time
from rdkit import Chem

print('RAG System Statistics:')
print(f'Total molecules indexed: {len(rag_system.df)}')
print(f'Embedding dimension: {rag_system.embeddings.shape[1]}')
print(f'Fingerprint dimension: {rag_system.fingerprints.shape[1]}')

# Sample molecular properties distribution
mw_values = []
smiles_col = [col for col in rag_system.df.columns if 'smiles' in col.lower()][0]
for smiles in rag_system.df[smiles_col].head(100):
    props = rag_system.get_molecular_properties(smiles)
    if props['mw'] > 0:  # Only include valid molecular weights
        mw_values.append(props['mw'])
if mw_values:
    print(f'Sample MW range: {min(mw_values):.1f} - {max(mw_values):.1f} Da')
else:
    print("No valid molecular weights found in the first 100 SMILES.")

# Test query performance
start = time.time()
_ = rag_system.semantic_search('test query', k=10)
semantic_time = time.time() - start

start = time.time()
_ = rag_system.structure_search('CCO', k=10)
structure_time = time.time() - start

print(f'Semantic search time: {semantic_time:.4f}s')
print(f'Structure search time: {structure_time:.4f}s')

RAG System Statistics:
Total molecules indexed: 562533
Embedding dimension: 384
Fingerprint dimension: 2048
No valid molecular weights found in the first 100 SMILES.


[04:00:54] SMILES Parse Error: syntax error while parsing: smiles
[04:00:54] SMILES Parse Error: check for mistakes around position 2:
[04:00:54] smiles
[04:00:54] ~^
[04:00:54] SMILES Parse Error: Failed parsing SMILES 'smiles' for input: 'smiles'
[04:00:54] SMILES Parse Error: syntax error while parsing: smiles
[04:00:54] SMILES Parse Error: check for mistakes around position 2:
[04:00:54] smiles
[04:00:54] ~^
[04:00:54] SMILES Parse Error: Failed parsing SMILES 'smiles' for input: 'smiles'


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Semantic search time: 0.0439s
Structure search time: 0.0013s


In [15]:
#cell 15
# Hybrid Search Examples
queries = [
    'aromatic compound with hydroxyl group',
    'small molecule with high logP',
    'compound with multiple rings and nitrogen'
]
# Cell to verify df_massspecgym
print('df_massspecgym' in globals())
if 'df_massspecgym' in globals():
    print("df_massspecgym shape:", df_massspecgym.shape)
    print("Columns:", df_massspecgym.columns.tolist())

True
df_massspecgym shape: (562533, 20)
Columns: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'smiles', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']


In [16]:
#cell 16
# SELFIES Vocabulary Construction
# Enhanced SMILES Extraction and SELFIES Vocabulary Construction

import logging
from rdkit import Chem
from rdkit import RDLogger
import selfies as sf
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')

# Verify dataset
if 'df_massspecgym' not in globals():
    raise NameError("df_massspecgym is not defined. Please load the dataset in Cell 3.")

# Step 1: Inspect dataset
logging.info(f"Dataset size: {df_massspecgym.shape}")
logging.info(f"Columns in dataset: {df_massspecgym.columns.tolist()}")
logging.info(f"First 5 rows:\n{df_massspecgym.head(5)}")

# Step 2: Auto-detect SMILES column
smiles_col_candidates = [col for col in df_massspecgym.columns if 'smiles' in col.lower()]
if not smiles_col_candidates:
    raise ValueError("No column containing 'smiles' found in MassSpecGym dataset.")

# Handle duplicate 'smiles' columns
if len(smiles_col_candidates) > 1:
    logging.warning(f"Multiple SMILES columns found: {smiles_col_candidates}")
    # Get indices of all 'smiles' columns
    smiles_col_indices = [i for i, col in enumerate(df_massspecgym.columns) if col == 'smiles']
    logging.info(f"Indices of 'smiles' columns: {smiles_col_indices}")
    
    # Inspect contents of each 'smiles' column
    for i, idx in enumerate(smiles_col_indices):
        col_data = df_massspecgym.iloc[:, idx].head(5)
        logging.info(f"Sample data from 'smiles' column at index {idx}:\n{col_data}")
    
    smiles_col_index = smiles_col_indices[0]
    smiles_col = 'smiles'
else:
    smiles_col_index = df_massspecgym.columns.get_loc(smiles_col_candidates[0])
    smiles_col = smiles_col_candidates[0]

logging.info(f"Using column '{smiles_col}' at index {smiles_col_index} as SMILES source.")

# Step 3: Extract clean SMILES strings safely
def is_valid_smiles(smiles):
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        return mol is not None
    except ValueError as e:
        logging.debug(f"Invalid SMILES '{smiles}': {e}")
        return False

smiles_series = df_massspecgym.iloc[:, smiles_col_index]
logging.info(f"Type of selected column: {type(smiles_series)}")
if not isinstance(smiles_series, pd.Series):
    raise TypeError(f"Selected column at index {smiles_col_index} is {type(smiles_series)}, expected pandas.Series")

all_smiles = smiles_series.dropna().tolist()
valid_smiles = [s for s in all_smiles if is_valid_smiles(s)]
logging.info(f"Extracted {len(valid_smiles)} valid SMILES strings.")
if not valid_smiles:
    raise ValueError("No valid SMILES could be extracted. Check your dataset!")

# Step 4: Convert to SELFIES
all_selfies = []
failed_conversions = []
for s in valid_smiles:
    try:
        selfies_str = sf.encoder(s)
        all_selfies.append(selfies_str)
    except Exception as e:
        failed_conversions.append((s, str(e)))
        continue

if failed_conversions:
    logging.warning(f"Failed to convert {len(failed_conversions)} SMILES to SELFIES. First few errors: {failed_conversions[:5]}")
if not all_selfies:
    raise ValueError("No valid SELFIES could be generated. Check your SMILES extraction!")
logging.info(f"Generated {len(all_selfies)} SELFIES strings successfully.")

# Step 5: Build SELFIES vocabulary
selfies_alphabet = set()
for s in all_selfies:
    selfies_alphabet.update(sf.split_selfies(s))

token_to_idx = {token: idx for idx, token in enumerate(sorted(selfies_alphabet))}
idx_to_token = {idx: token for token, idx in token_to_idx.items()}

# Add special tokens if needed
special_tokens = ['<pad>', '<unk>', '<start>', '<end>']
for token in special_tokens:
    if token not in token_to_idx:
        token_to_idx[token] = len(token_to_idx)
        idx_to_token[len(idx_to_token)] = token

vocab_size = len(token_to_idx)
PRETRAIN_MAX_LEN = min(100, max(len(list(sf.split_selfies(s))) for s in all_selfies) if all_selfies else 0)
SUPERVISED_MAX_LEN = max(len(list(sf.split_selfies(s))) + 2 for s in all_selfies) if all_selfies else 0

logging.info(f"SELFIES vocabulary size: {vocab_size}")
logging.info(f"Pretrain MAX_LEN: {PRETRAIN_MAX_LEN}, Supervised MAX_LEN: {SUPERVISED_MAX_LEN}")

# Define encode_selfies function
def encode_selfies(selfies_str, max_len):
    """Encode a SELFIES string into a list of token indices, padded/truncated to max_len."""
    tokens = list(sf.split_selfies(selfies_str))
    token_indices = [token_to_idx.get(token, token_to_idx['<unk>']) for token in tokens]
    token_indices = [token_to_idx['<start>']] + token_indices + [token_to_idx['<end>']]
    if len(token_indices) > max_len:
        token_indices = token_indices[:max_len]
    padding_idx = token_to_idx['<pad>']
    token_indices += [padding_idx] * (max_len - len(token_indices))
    return token_indices

# Save results
pd.DataFrame({'SELFIES': all_selfies}).to_csv('massspecgym_selfies.csv', index=False)
with open('selfies_vocab.json', 'w') as f:
    import json
    json.dump({'token_to_idx': token_to_idx, 'idx_to_token': idx_to_token}, f)

2025-11-19 04:00:54,287 INFO Dataset size: (562533, 20)
2025-11-19 04:00:54,288 INFO Columns in dataset: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'smiles', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']
2025-11-19 04:00:54,368 INFO First 5 rows:
             identifier                                                mzs  \
0  MassSpecGymID0000001  91.0542,125.0233,154.0499,155.0577,185.0961,20...   
1  MassSpecGymID0000002  91.0542,125.0233,155.0577,185.0961,229.0859,24...   
2  MassSpecGymID0000003  69.0343,91.0542,125.0233,127.039,153.0699,154....   
3  MassSpecGymID0000004  69.0343,91.0542,110.06,111.0441,112.0393,120.0...   
4  MassSpecGymID0000005  91.0542,125.0233,185.0961,229.0859,246.1125,28...   

                                         intensities  \
0  0.24524524524524524,1.0,0.08008008008008008,

In [17]:
#cell 17
# Precompute Morgan Fingerprints for All Unique SMILES

import logging
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Verify datasets
if 'df_massspecgym' not in globals():
    raise NameError("df_massspecgym is not defined. Please load the dataset in Cell 3.")
if 'df_external' not in globals():
    raise NameError("df_external is not defined. Please load the dataset.")

# Step 1: Inspect datasets
logging.info(f"df_massspecgym columns: {df_massspecgym.columns.tolist()}")
logging.info(f"df_external columns: {df_external.columns.tolist()}")

# Step 2: Select SMILES column from df_massspecgym
massspecgym_smiles_cols = [i for i, col in enumerate(df_massspecgym.columns) if col == 'smiles']
if not massspecgym_smiles_cols:
    raise ValueError("No 'smiles' column found in df_massspecgym.")
if len(massspecgym_smiles_cols) > 1:
    logging.warning(f"Multiple 'smiles' columns found in df_massspecgym at indices: {massspecgym_smiles_cols}")
    for idx in massspecgym_smiles_cols:
        logging.info(f"Sample data from df_massspecgym 'smiles' at index {idx}:\n{df_massspecgym.iloc[:, idx].head(5)}")
massspecgym_smiles_index = massspecgym_smiles_cols[0]
logging.info(f"Using df_massspecgym 'smiles' column at index {massspecgym_smiles_index}")

# Step 3: Select SMILES column from df_external
external_smiles_cols = [i for i, col in enumerate(df_external.columns) if col == 'smiles']
if not external_smiles_cols:
    raise ValueError("No 'smiles' column found in df_external.")
if len(external_smiles_cols) > 1:
    logging.warning(f"Multiple 'smiles' columns found in df_external at indices: {external_smiles_cols}")
    for idx in external_smiles_cols:
        logging.info(f"Sample data from df_external 'smiles' at index {idx}:\n{df_external.iloc[:, idx].head(5)}")
external_smiles_index = external_smiles_cols[0]
logging.info(f"Using df_external 'smiles' column at index {external_smiles_index}")

# Step 4: Extract and combine unique SMILES
massspecgym_smiles = df_massspecgym.iloc[:, massspecgym_smiles_index].dropna().tolist()
external_smiles = df_external.iloc[:, external_smiles_index].dropna().tolist()
all_smiles = list(set(massspecgym_smiles + external_smiles))
logging.info(f"Extracted {len(all_smiles)} unique SMILES strings.")

# Step 5: Precompute Morgan fingerprints for all unique SMILES
all_fingerprints = {}
morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
for smiles in all_smiles:
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            all_fingerprints[smiles] = morgan_gen.GetFingerprint(mol)
        else:
            logging.warning(f"Invalid SMILES '{smiles}' skipped during fingerprint generation.")
    except Exception as e:
        logging.warning(f"Failed to process SMILES '{smiles}': {e}")
logging.info(f"Generated fingerprints for {len(all_fingerprints)} SMILES strings.")

2025-11-19 04:05:05,177 INFO df_massspecgym columns: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'smiles', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']
2025-11-19 04:05:05,177 INFO df_external columns: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']
2025-11-19 04:05:05,179 INFO Sample data from df_massspecgym 'smiles' at index 3:
0    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
1    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
2    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
3    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
4    COc1cc([C@H](Cc2ccccc2)NC(C)=O)oc(=O)c1
Name: smiles, dtype: object
2025-11-19 0

In [18]:
#cell 18
# Dataset class for MS/MS data

class MSMSDataset(Dataset):
    def __init__(self, dataframe, max_len=PRETRAIN_MAX_LEN, is_ssl=False):
        self.spectra = np.stack(dataframe['binned'].values)
        self.graph_data = dataframe['graph_data'].values
        self.ion_modes = dataframe['ion_mode'].values
        self.precursor_bins = dataframe['precursor_bin'].values
        self.adduct_indices = dataframe['adduct_idx'].values
        self.raw_smiles = dataframe['smiles'].values

        self.is_ssl = is_ssl
        self.ssl_max_len = max_len                 # for self-supervised (pretrain)
        self.sup_max_len = SUPERVISED_MAX_LEN      # for supervised fine-tune

        if is_ssl:
            # self-supervised: create original & masked SELFIES encodings
            self.smiles = []
            self.masked_smiles = []
            for s in self.raw_smiles:
                selfies = sf.encoder(s)
                masked_s, orig_s = self.mask_selfies(selfies)
                self.smiles.append(encode_selfies(orig_s, self.ssl_max_len))
                self.masked_smiles.append(encode_selfies(masked_s, self.ssl_max_len))
        else:
            # supervised: only original SELFIES with supervised max length
            self.smiles = [
                encode_selfies(sf.encoder(s), max_len=self.sup_max_len)
                for s in self.raw_smiles
            ]

    def mask_selfies(self, selfies, mask_ratio=0.10):
        """Randomly mask a fraction of SELFIES tokens for SSL pretraining."""
        max_len = self.ssl_max_len
        try:
            # split_selfies returns a generator → convert to list, then truncate
            tokens = list(sf.split_selfies(selfies))[:max_len - 2]

            masked_tokens = tokens.copy()
            n_mask = int(mask_ratio * len(tokens))

            if n_mask > 0:
                mask_indices = np.random.choice(len(tokens), n_mask, replace=False)
                for idx in mask_indices:
                    masked_tokens[idx] = config.MASK_TOKEN

            # return masked version and original (both as SELFIES strings)
            return ''.join(masked_tokens), ''.join(tokens)

        except Exception as e:
            logging.error(
                f"mask_selfies failed for {selfies}: {e}\n{traceback.format_exc()}"
            )
            # Fallback: no masking so dataset creation doesn't crash
            return selfies, selfies

    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        if self.is_ssl:
            return (
                torch.tensor(self.spectra[idx], dtype=torch.float),
                self.graph_data[idx],
                torch.tensor(self.smiles[idx], dtype=torch.long),
                torch.tensor(self.masked_smiles[idx], dtype=torch.long),
                torch.tensor(self.ion_modes[idx], dtype=torch.long),
                torch.tensor(self.precursor_bins[idx], dtype=torch.long),
                torch.tensor(self.adduct_indices[idx], dtype=torch.long),
                self.raw_smiles[idx],
            )

        return (
            torch.tensor(self.spectra[idx], dtype=torch.float),
            self.graph_data[idx],
            torch.tensor(self.smiles[idx], dtype=torch.long),
            torch.tensor(self.ion_modes[idx], dtype=torch.long),
            torch.tensor(self.precursor_bins[idx], dtype=torch.long),
            torch.tensor(self.adduct_indices[idx], dtype=torch.long),
            self.raw_smiles[idx],
        )



In [19]:
#cell 19
# Positional encoding and model encoder/decoder classes

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Neural Network Models
class SpectrumTransformerEncoder(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=6, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.input_proj = nn.Linear(1, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x):
        x = self.input_proj(x.unsqueeze(-1))
        x = self.pos_encoding(x)
        x = self.dropout(x)
        return self.transformer(x)

class SpectrumGNNEncoder(MessagePassing):
    def __init__(self, d_model=512):
        super().__init__(aggr='mean')
        self.d_model = d_model
        self.lin = nn.Linear(1, d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))

    def forward(self, x, edge_index, batch):
        x = self.lin(x)
        x = self.propagate(edge_index, x=x)
        return global_mean_pool(x, batch)

    def message(self, x_j):
        return self.mlp(x_j)

class SmilesTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None):
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_encoding(tgt)
        output = self.transformer(tgt, memory, tgt_mask=tgt_mask)
        return self.output_proj(output)

class MSMS2SmilesHybrid(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, **kwargs):
        super().__init__()
        self.transformer_encoder = SpectrumTransformerEncoder(d_model, nhead, num_layers)
        self.gnn_encoder = SpectrumGNNEncoder(d_model)
        self.decoder = SmilesTransformerDecoder(vocab_size, d_model, nhead, num_layers)
        self.fusion = nn.Linear(d_model * 2, d_model)

    def forward(self, spectrum, graph_data, tgt, tgt_mask=None):
        transformer_out = self.transformer_encoder(spectrum)
        gnn_out = self.gnn_encoder(graph_data.x, graph_data.edge_index, graph_data.batch)
        memory = self.fusion(torch.cat([transformer_out.mean(1), gnn_out], dim=1)).unsqueeze(1)
        return self.decoder(tgt, memory, tgt_mask)


In [20]:
# cell 20
# Training and evaluation functions

from torch_geometric.data import Batch as GeoBatch

def make_graph_batch(graph_data, device):
    """
    Handle graph_data coming from either:
    - torch_geometric.loader.DataLoader  -> already a Batch / tupleBatch
    - torch.utils.data.DataLoader        -> list/tuple of Data
    - a single Data object
    """
    # Already a PyG Batch (common with GeoDataLoader)
    if isinstance(graph_data, GeoBatch):
        return graph_data.to(device)

    # Likely a sequence (list/tuple/tupleBatch) of Data
    try:
        return GeoBatch.from_data_list(list(graph_data)).to(device)
    except Exception:
        # Fallback: assume single Data
        return GeoBatch.from_data_list([graph_data]).to(device)


def ssl_pretrain(model, dataloader, epochs=3, lr=1e-4):
    from torch.cuda.amp import autocast, GradScaler
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scaler = GradScaler()
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc=f'SSL Epoch {epoch+1}'):
            spectrum, graph_data, target, masked, _, _, _, _ = batch
            spectrum, target = spectrum.to(device, non_blocking=True), target.to(device, non_blocking=True)
            graph_batch = make_graph_batch(graph_data, device)
            
            optimizer.zero_grad()
            with autocast():
                # For “true” SSL you might want masked[:, :-1] here instead of target[:, :-1]
                output = model(spectrum, graph_batch, target[:, :-1])
                loss = F.cross_entropy(
                    output.reshape(-1, output.size(-1)),
                    target[:, 1:].reshape(-1),
                    ignore_index=0
                )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        print(f'SSL Epoch {epoch+1} Loss: {total_loss/len(dataloader):.4f}')


def supervised_train(model, train_loader, val_loader, epochs=30, lr=1e-4, patience=5):
    from torch.cuda.amp import autocast, GradScaler
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}'):
            spectrum, graph_data, target, _, _, _, _ = batch
            spectrum, target = spectrum.to(device, non_blocking=True), target.to(device, non_blocking=True)
            graph_batch = make_graph_batch(graph_data, device)
            
            # Create attention (causal) mask
            tgt_mask = torch.triu(
                torch.ones(target.size(1)-1, target.size(1)-1, device=device),
                diagonal=1
            ).bool()
            
            optimizer.zero_grad()
            with autocast():
                output = model(spectrum, graph_batch, target[:, :-1], tgt_mask=tgt_mask)
                loss = F.cross_entropy(
                    output.reshape(-1, output.size(-1)),
                    target[:, 1:].reshape(-1),
                    ignore_index=0
                )
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                spectrum, graph_data, target, _, _, _, _ = batch
                spectrum, target = spectrum.to(device, non_blocking=True), target.to(device, non_blocking=True)
                graph_batch = make_graph_batch(graph_data, device)
                with autocast():
                    output = model(spectrum, graph_batch, target[:, :-1])
                    loss = F.cross_entropy(
                        output.reshape(-1, output.size(-1)),
                        target[:, 1:].reshape(-1),
                        ignore_index=0
                    )
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        scheduler.step()
        print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
    return best_val_loss


def beam_search(model, spectrum, graph_data, ion_mode, precursor_bin, adduct_idx,
                true_smiles, beam_width=5, max_len=100, device='cpu'):
    model.eval()
    with torch.no_grad():
        spectrum = spectrum.unsqueeze(0).to(device)
        # graph_data here is usually a single Data, so make_graph_batch will handle it
        graph_batch = make_graph_batch(graph_data, device)
        
        # Start with SOS token
        sequences = [[token_to_idx[config.SOS_TOKEN]]]
        scores = [0.0]
        
        for _ in range(max_len):
            candidates = []
            for i, seq in enumerate(sequences):
                if seq[-1] == token_to_idx[config.EOS_TOKEN]:
                    candidates.append((seq, scores[i]))
                    continue
                
                tgt = torch.tensor([seq]).to(device)
                output = model(spectrum, graph_batch, tgt)
                probs = F.softmax(output[0, -1], dim=-1)
                
                top_probs, top_indices = torch.topk(probs, beam_width)
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = scores[i] + torch.log(prob).item()
                    candidates.append((new_seq, new_score))
            
            candidates.sort(key=lambda x: x[1], reverse=True)
            sequences = [seq for seq, _ in candidates[:beam_width]]
            scores = [score for _, score in candidates[:beam_width]]
        
        results = []
        for seq, score in zip(sequences, scores):
            smiles = decode_selfies(seq)
            if smiles:
                results.append((smiles, score))
        return results[:beam_width]


# Missing evaluation functions
def mw_difference(smiles1, smiles2):
    try:
        mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
        if mol1 and mol2:
            return abs(Descriptors.MolWt(mol1) - Descriptors.MolWt(mol2))
    except:
        pass
    return float('inf')


def logp_difference(smiles1, smiles2):
    try:
        mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
        if mol1 and mol2:
            return abs(Descriptors.MolLogP(mol1) - Descriptors.MolLogP(mol2))
    except:
        pass
    return float('inf')


def substructure_match(smiles1, smiles2, substructures=None):
    return 0.5  # Placeholder


def error_analysis(pred_list, true_list, adduct_list, fingerprints):
    print('Error analysis completed')


def plot_attention_weights(weights, title='Attention'):
    print(f'Attention visualization: {title}')


def plot_gnn_edge_weights(weights, edges, title='GNN'):
    print(f'GNN visualization: {title}')


def calculate_bleu(predicted_smiles, true_smiles):
    try:
        pred_tokens = list(predicted_smiles)
        true_tokens = list(true_smiles)
        return sentence_bleu([true_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25))
    except:
        return 0.0


def tanimoto_similarity(smiles1, smiles2, fingerprint_dict):
    if smiles1 in fingerprint_dict and smiles2 in fingerprint_dict:
        return DataStructs.TanimotoSimilarity(fingerprint_dict[smiles1], fingerprint_dict[smiles2])
    return 0.0


def validity_rate(smiles_list):
    valid = sum(1 for s in smiles_list if Chem.MolFromSmiles(s) is not None)
    return (valid / len(smiles_list)) * 100 if smiles_list else 0


def objective(trial, train_data, val_data):
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    return lr  # Simplified for demo


# Additional metrics and visualization
def dice_similarity(smiles1, smiles2):
    try:
        mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
        if mol1 and mol2:
            fp1 = Chem.RDKFingerprint(mol1)
            fp2 = Chem.RDKFingerprint(mol2)
            return DataStructs.DiceSimilarity(fp1, fp2)
    except:
        pass
    return 0.0


def mcs_similarity(smiles1, smiles2):
    try:
        mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
        if mol1 and mol2:
            mcs = rdFMCS.FindMCS([mol1, mol2])
            return mcs.numAtoms / max(mol1.GetNumAtoms(), mol2.GetNumAtoms())
    except:
        pass
    return 0.0


def prediction_diversity(smiles_list):
    unique_smiles = set(smiles_list)
    return len(unique_smiles) / len(smiles_list) if smiles_list else 0


def plot_molecular_comparison(true_smiles, pred_smiles, title='Comparison'):
    try:
        true_mol = Chem.MolFromSmiles(true_smiles)
        pred_mol = Chem.MolFromSmiles(pred_smiles)
        if true_mol and pred_mol:
            img = Draw.MolsToGridImage(
                [true_mol, pred_mol],
                molsPerRow=2,
                subImgSize=(300, 300),
                legends=['True', 'Predicted']
            )
            plt.figure(figsize=(10, 5))
            plt.imshow(np.array(img))
            plt.axis('off')
            plt.title(title)
            plt.show()
    except Exception as e:
        print(f'Visualization error: {e}')


# Model checkpointing
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'vocab_size': vocab_size,
        'token_to_idx': token_to_idx,
        'idx_to_token': idx_to_token
    }, filepath)


def load_checkpoint(filepath, model, optimizer=None):
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']


# Data validation functions
def validate_spectrum_quality(mzs, intensities, min_peaks=5, max_mz_range=2000):
    if len(mzs) < min_peaks or max(mzs) > max_mz_range:
        return False
    return True


def validate_molecular_properties(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return False
        mw = Descriptors.MolWt(mol)
        return 50 <= mw <= 1000
    except:
        return False


def remove_duplicates(df, subset=['smiles', 'precursor_mz']):
    return df.drop_duplicates(subset=subset, keep='first')


# Memory management
def clear_memory():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


In [21]:
#cell 21
# Integration: XGBoost + RAG + Deep Learning
class HybridPredictor:
    def __init__(self, dl_model, xgb_model, rag_system, label_encoder):
        self.dl_model = dl_model
        self.xgb_model = xgb_model
        self.rag_system = rag_system
        self.label_encoder = label_encoder

    def predict_ensemble(self, spectrum, graph_data, features, query_text=None, weights=[0.5, 0.3, 0.2]):
        # Ensure spectrum is a torch tensor on the correct device
        if not isinstance(spectrum, torch.Tensor):
            spectrum = torch.tensor(spectrum, dtype=torch.float32, device=device)
        else:
            spectrum = spectrum.to(device)

        # Validate weights (normalize if slightly off)
        if abs(sum(weights) - 1.0) > 0.01:
            weights = [w / sum(weights) for w in weights]
        
        predictions = []
        
        # Deep learning prediction with error handling
        try:
            dl_results = beam_search(
                self.dl_model,
                spectrum,
                graph_data,
                ion_mode=0,
                precursor_bin=0,
                adduct_idx=0,
                true_smiles='',
                beam_width=5,
                device=device
            )
            if dl_results and dl_results[0][0]:
                predictions.append(('DL', dl_results[0][0], weights[0]))
        except Exception as e:
            print(f'DL prediction failed: {e}')
        
        # XGBoost prediction with error handling
        try:
            xgb_pred = self.xgb_model.predict([features])[0]
            xgb_smiles = self.label_encoder.inverse_transform([xgb_pred])[0]
            predictions.append(('XGB', xgb_smiles, weights[1]))
        except Exception as e:
            print(f'XGBoost prediction failed: {e}')
        
        # RAG prediction with error handling
        if query_text:
            try:
                rag_results = self.rag_system.semantic_search(query_text, k=1)
                if rag_results and len(rag_results) > 0:
                    predictions.append(('RAG', rag_results[0]['smiles'], weights[2]))
            except Exception as e:
                print(f'RAG prediction failed: {e}')
        
        return predictions

    def evaluate_ensemble(self, test_data, n_samples=10):
        results = {'dl': [], 'xgb': [], 'rag': [], 'ensemble': []}
        
        for i in range(min(n_samples, len(test_data))):
            row = test_data.iloc[i]
            true_smiles = row['smiles']
            
            # Extract features
            spectrum = row['binned']          # numpy array
            graph_data = row['graph_data']    # PyG Data object
            features = [
                np.mean(spectrum),
                np.std(spectrum),
                np.max(spectrum),
                np.sum(spectrum > 0.1),
                row['precursor_mz'],
                row['ion_mode'],
                row['adduct_idx'],
                len(row['mzs'])
            ]
            
            # Get ensemble predictions
            preds = self.predict_ensemble(
                spectrum,
                graph_data,
                features,
                query_text=f"molecule with MW {row['precursor_mz']:.1f}"
            )

            if not preds:
                continue  # nothing to evaluate for this sample
            
            # Evaluate each method
            for method, pred_smiles, weight in preds:
                similarity = tanimoto_similarity(pred_smiles, true_smiles, all_fingerprints)
                results[method.lower()].append(similarity)
            
            # Weighted ensemble score
            total_weight = sum(w for _, _, w in preds)
            if total_weight > 0:
                ensemble_score = sum(
                    tanimoto_similarity(pred, true_smiles, all_fingerprints) * w
                    for _, pred, w in preds
                ) / total_weight
                results['ensemble'].append(ensemble_score)
        
        return {k: np.mean(v) if v else 0 for k, v in results.items()}

print('Integration system ready')


Integration system ready


In [22]:
# cell 22
import logging
import torch
import numpy as np
from sklearn.model_selection import KFold
import optuna
from IPython.display import FileLink

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Verify required variables
required_vars = ['MSMSDataset', 'config', 'objective', 'ssl_pretrain', 'supervised_train', 
                'token_to_idx', 'idx_to_token', 'vocab_size', 'PRETRAIN_MAX_LEN', 'SUPERVISED_MAX_LEN', 'device']
for var in required_vars:
    if var not in globals():
        raise NameError(f"{var} is not defined. Ensure it is defined in previous cells.")

# Verify datasets
if 'df_massspecgym' not in globals():
    raise NameError("df_massspecgym is not defined. Please load the dataset in Cell 3.")
if 'df_external' not in globals():
    raise NameError("df_external is not defined. Please load the dataset.")

# Remove duplicate columns
df_massspecgym = df_massspecgym.loc[:, ~df_massspecgym.columns.duplicated(keep='first')]
df_external = df_external.loc[:, ~df_external.columns.duplicated(keep='first')]
logging.info(f"df_massspecgym columns after removing duplicates: {df_massspecgym.columns.tolist()}")
logging.info(f"df_external columns after removing duplicates: {df_external.columns.tolist()}")

# Verify 'smiles' column exists
if 'smiles' not in df_massspecgym.columns:
    raise ValueError("No 'smiles' column in df_massspecgym after removing duplicates.")
if 'smiles' not in df_external.columns:
    raise ValueError("No 'smiles' column in df_external after removing duplicates.")

# Cross-validation setup
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# Debug/interactive safety
FAST_DEBUG = globals().get('FAST_DEBUG', False)
if FAST_DEBUG:
    logging.info('FAST_DEBUG mode enabled: using small subsets, fewer epochs and trials, single-worker loaders')

# Create external dataset
external_dataset = MSMSDataset(df_external, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
logging.info(f"External dataset size: {len(external_dataset)}")
external_loader = None  # Set to None as in original code; update if needed

# Prefer torch_geometric DataLoader
use_geo_loader = False
try:
    from torch_geometric.loader import DataLoader as GeoDataLoader
    use_geo_loader = True
except Exception:
    from torch.utils.data import DataLoader as TorchDataLoader
    GeoDataLoader = None
    logging.info("Using torch.utils.data.DataLoader as fallback.")

for fold, (train_idx, val_idx) in enumerate(kf.split(df_massspecgym)):
    logging.info(f"\nFold {fold+1}/5")
    train_data = df_massspecgym.iloc[train_idx].copy()
    val_data = df_massspecgym.iloc[val_idx].copy()
    ssl_data = train_data.sample(frac=0.3, random_state=42).copy()

    # FAST_DEBUG sampling
    if FAST_DEBUG:
        train_data = train_data.sample(n=min(512, len(train_data)), random_state=42)
        val_data = val_data.sample(n=min(128, len(val_data)), random_state=42)
        ssl_data = ssl_data.sample(n=min(256, len(ssl_data)), random_state=42)
        optuna_trials = 2
        ssl_epochs = 1
        supervised_epochs = 1
    else:
        optuna_trials = 10
        ssl_epochs = config.SSL_EPOCHS
        supervised_epochs = config.SUPERVISED_EPOCHS

    # DataLoader parameters
    workers = 0  # Single worker for Jupyter stability
    pin_memory = False  # Disable for notebook safety

    # Batch sizes
    if torch.cuda.is_available():
        train_bs = max(4, int(config.BATCH_SIZE // 8))
        val_bs = max(4, int(config.BATCH_SIZE // 8))
        ssl_bs = max(8, int(config.BATCH_SIZE // 4))
    else:
        train_bs = max(8, config.BATCH_SIZE)
        val_bs = max(8, config.BATCH_SIZE)
        ssl_bs = max(32, config.BATCH_SIZE)

    # Build datasets and loaders
    train_dataset = MSMSDataset(train_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    val_dataset = MSMSDataset(val_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    ssl_dataset = MSMSDataset(ssl_data, max_len=PRETRAIN_MAX_LEN, is_ssl=True)

    if use_geo_loader and GeoDataLoader is not None:
        train_loader = GeoDataLoader(train_dataset, batch_size=train_bs, shuffle=True)
        val_loader = GeoDataLoader(val_dataset, batch_size=val_bs, shuffle=False)
        ssl_loader = GeoDataLoader(ssl_dataset, batch_size=ssl_bs, shuffle=True)
    else:
        train_loader = TorchDataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=workers, pin_memory=pin_memory)
        val_loader = TorchDataLoader(val_dataset, batch_size=val_bs, shuffle=False, num_workers=workers, pin_memory=pin_memory)
        ssl_loader = TorchDataLoader(ssl_dataset, batch_size=ssl_bs, shuffle=True, num_workers=workers, pin_memory=pin_memory)

    # Hyperparameter tuning
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: objective(trial, train_data, val_data), n_trials=optuna_trials)
    best_lr = study.best_params.get('lr', config.LEARNING_RATE)
    logging.info(f"Best learning rate for fold {fold+1}: {best_lr:.6f}")

    # Initialize and train model with multiple OOM retries
    max_retries = 3
    retry_count = 0
    while retry_count < max_retries:
        try:
            model = MSMS2SmilesHybrid(vocab_size=vocab_size, d_model=config.D_MODEL, nhead=config.NHEAD, num_layers=config.NUM_LAYERS).to(device)
            logging.info(f"Starting SSL pretraining for fold {fold+1} with train_bs={train_bs}, ssl_bs={ssl_bs}...")
            ssl_pretrain(model, ssl_loader, epochs=ssl_epochs, lr=best_lr)
            logging.info(f"Starting supervised training for fold {fold+1}...")
            best_val_loss = supervised_train(model, train_loader, val_loader, epochs=supervised_epochs, lr=best_lr, patience=config.PATIENCE)
            break
        except RuntimeError as e:
            if 'out of memory' in str(e).lower() and retry_count < max_retries - 1 and torch.cuda.is_available():
                logging.warning(f"CUDA OOM detected (retry {retry_count+1}/{max_retries}). Reducing batch sizes...")
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                train_bs = max(2, train_bs // 2)
                val_bs = max(2, val_bs // 2)
                ssl_bs = max(4, ssl_bs // 2)
                if use_geo_loader and GeoDataLoader is not None:
                    train_loader = GeoDataLoader(train_dataset, batch_size=train_bs, shuffle=True)
                    val_loader = GeoDataLoader(val_dataset, batch_size=val_bs, shuffle=False)
                    ssl_loader = GeoDataLoader(ssl_dataset, batch_size=ssl_bs, shuffle=True)
                else:
                    train_loader = TorchDataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=0, pin_memory=False)
                    val_loader = TorchDataLoader(val_dataset, batch_size=val_bs, shuffle=False, num_workers=0, pin_memory=False)
                    ssl_loader = TorchDataLoader(ssl_dataset, batch_size=ssl_bs, shuffle=True, num_workers=0, pin_memory=False)
                logging.info(f"New batch sizes -> train: {train_bs}, val: {val_bs}, ssl: {ssl_bs}. Retrying fold...")
                retry_count += 1
            else:
                raise

    fold_results.append(best_val_loss)
    model_path = f'best_msms_hybrid_fold_{fold+1}.pt'
    torch.save({
        'model_state_dict': model.state_dict(),
        'token_to_idx': token_to_idx,
        'idx_to_token': idx_to_token
    }, model_path)
    
    try:
        display(FileLink(model_path))
    except Exception as e:
        logging.warning(f"Failed to create FileLink for {model_path}: {e}")

logging.info(f"Cross-validation results: {fold_results}")
logging.info(f"Average validation loss: {np.mean(fold_results):.4f}")

# Training time estimation
total_samples = len(df_massspecgym)
samples_per_epoch = total_samples // max(1, train_bs)
total_epochs = config.N_FOLDS * (ssl_epochs + supervised_epochs)
estimated_hours = (samples_per_epoch * total_epochs * 0.5) / 3600
logging.info(f"Estimated training time: {estimated_hours:.1f} hours ({estimated_hours/24:.1f} days)")
logging.info("With RTX 3080 Ti optimizations, expect 8-12 hours total.")



2025-11-19 04:05:12,214 INFO df_massspecgym columns after removing duplicates: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']
2025-11-19 04:05:12,214 INFO df_external columns after removing duplicates: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'ion_mode', 'precursor_bin', 'adduct_idx', 'binned', 'graph_data']
2025-11-19 04:05:19,131 INFO External dataset size: 23111
2025-11-19 04:05:19,142 INFO 
Fold 1/5
[I 2025-11-19 04:08:47,983] A new study created in memory with name: no-name-46d1e456-7aae-4c15-be2a-7a8ec35f18b3
[I 2025-11-19 04:08:47,984] Trial 0 finished with value: 0.0002792285899920133 and parameters:

SSL Epoch 1:   0%|          | 0/8438 [00:00<?, ?it/s]

  with autocast():


SSL Epoch 1 Loss: 0.8453


SSL Epoch 2:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 2 Loss: 0.6898


SSL Epoch 3:   0%|          | 0/8438 [00:00<?, ?it/s]

2025-11-19 04:48:04,518 INFO Starting supervised training for fold 1...


SSL Epoch 3 Loss: 0.5377


  scaler = GradScaler()


Train Epoch 1:   0%|          | 0/56254 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1: Train Loss: 0.3470, Val Loss: 1.8403


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.2281, Val Loss: 2.3627


Train Epoch 3:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.1886, Val Loss: 2.8418


Train Epoch 4:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.1665, Val Loss: 2.5303


Train Epoch 5:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.1519, Val Loss: 2.0756


Train Epoch 6:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.1415, Val Loss: 1.9566
Early stopping at epoch 6


2025-11-19 11:11:02,703 INFO 
Fold 2/5
[I 2025-11-19 11:14:31,674] A new study created in memory with name: no-name-ec7f8262-2608-4a76-89eb-8b16e19d51c1
[I 2025-11-19 11:14:31,674] Trial 0 finished with value: 6.194087750668737e-05 and parameters: {'lr': 6.194087750668737e-05}. Best is trial 0 with value: 6.194087750668737e-05.
[I 2025-11-19 11:14:31,675] Trial 1 finished with value: 0.00021306179025333818 and parameters: {'lr': 0.00021306179025333818}. Best is trial 0 with value: 6.194087750668737e-05.
[I 2025-11-19 11:14:31,675] Trial 2 finished with value: 1.5359791733001554e-05 and parameters: {'lr': 1.5359791733001554e-05}. Best is trial 2 with value: 1.5359791733001554e-05.
[I 2025-11-19 11:14:31,676] Trial 3 finished with value: 0.0007827674805337467 and parameters: {'lr': 0.0007827674805337467}. Best is trial 2 with value: 1.5359791733001554e-05.
[I 2025-11-19 11:14:31,676] Trial 4 finished with value: 0.0003959116260564962 and parameters: {'lr': 0.0003959116260564962}. Best is

SSL Epoch 1:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 1 Loss: 0.8225


SSL Epoch 2:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 2 Loss: 0.6171


SSL Epoch 3:   0%|          | 0/8438 [00:00<?, ?it/s]

2025-11-19 11:53:56,408 INFO Starting supervised training for fold 2...


SSL Epoch 3 Loss: 0.4077


Train Epoch 1:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.3164, Val Loss: 1.6130


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.2032, Val Loss: 1.9953


Train Epoch 3:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.1686, Val Loss: 2.2239


Train Epoch 4:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.1496, Val Loss: 1.9556


Train Epoch 5:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.1372, Val Loss: 2.2566


Train Epoch 6:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.1283, Val Loss: 2.1675
Early stopping at epoch 6


2025-11-19 18:18:58,154 INFO 
Fold 3/5
[I 2025-11-19 18:22:27,001] A new study created in memory with name: no-name-72521e91-9284-4ec9-9762-91314a402c07
[I 2025-11-19 18:22:27,002] Trial 0 finished with value: 0.0001032367669556582 and parameters: {'lr': 0.0001032367669556582}. Best is trial 0 with value: 0.0001032367669556582.
[I 2025-11-19 18:22:27,003] Trial 1 finished with value: 5.966805509951177e-05 and parameters: {'lr': 5.966805509951177e-05}. Best is trial 1 with value: 5.966805509951177e-05.
[I 2025-11-19 18:22:27,003] Trial 2 finished with value: 3.213029237082538e-05 and parameters: {'lr': 3.213029237082538e-05}. Best is trial 2 with value: 3.213029237082538e-05.
[I 2025-11-19 18:22:27,003] Trial 3 finished with value: 2.586835394494521e-05 and parameters: {'lr': 2.586835394494521e-05}. Best is trial 3 with value: 2.586835394494521e-05.
[I 2025-11-19 18:22:27,004] Trial 4 finished with value: 6.295667178538748e-05 and parameters: {'lr': 6.295667178538748e-05}. Best is trial

SSL Epoch 1:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 1 Loss: 0.7675


SSL Epoch 2:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 2 Loss: 0.4179


SSL Epoch 3:   0%|          | 0/8438 [00:00<?, ?it/s]

2025-11-19 19:01:52,299 INFO Starting supervised training for fold 3...


SSL Epoch 3 Loss: 0.2390


Train Epoch 1:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.2722, Val Loss: 1.6226


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.1717, Val Loss: 1.6672


Train Epoch 3:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.1440, Val Loss: 1.5602


Train Epoch 4:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.1292, Val Loss: 1.8181


Train Epoch 5:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.1197, Val Loss: 1.7263


Train Epoch 6:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.1130, Val Loss: 1.8249


Train Epoch 7:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 7: Train Loss: 0.1077, Val Loss: 2.2292


Train Epoch 8:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 8: Train Loss: 0.1035, Val Loss: 2.2150
Early stopping at epoch 8


2025-11-20 03:33:59,431 INFO 
Fold 4/5
[I 2025-11-20 03:37:33,627] A new study created in memory with name: no-name-45c52553-aafd-4ceb-87e2-cf656ea0f5bf
[I 2025-11-20 03:37:33,633] Trial 0 finished with value: 0.0009244023700960408 and parameters: {'lr': 0.0009244023700960408}. Best is trial 0 with value: 0.0009244023700960408.
[I 2025-11-20 03:37:33,633] Trial 1 finished with value: 3.964538104980318e-05 and parameters: {'lr': 3.964538104980318e-05}. Best is trial 1 with value: 3.964538104980318e-05.
[I 2025-11-20 03:37:33,634] Trial 2 finished with value: 1.0294526142437247e-05 and parameters: {'lr': 1.0294526142437247e-05}. Best is trial 2 with value: 1.0294526142437247e-05.
[I 2025-11-20 03:37:33,634] Trial 3 finished with value: 7.527120538990596e-05 and parameters: {'lr': 7.527120538990596e-05}. Best is trial 2 with value: 1.0294526142437247e-05.
[I 2025-11-20 03:37:33,635] Trial 4 finished with value: 1.1238236740539683e-05 and parameters: {'lr': 1.1238236740539683e-05}. Best is

SSL Epoch 1:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 1 Loss: 0.8571


SSL Epoch 2:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 2 Loss: 0.7143


SSL Epoch 3:   0%|          | 0/8438 [00:00<?, ?it/s]

2025-11-20 04:17:00,880 INFO Starting supervised training for fold 4...


SSL Epoch 3 Loss: 0.5702


Train Epoch 1:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.3572, Val Loss: 1.6442


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.2352, Val Loss: 1.6139


Train Epoch 3:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.1944, Val Loss: 1.6089


Train Epoch 4:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.1714, Val Loss: 1.5430


Train Epoch 5:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.1564, Val Loss: 1.6183


Train Epoch 6:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.1454, Val Loss: 1.5469


Train Epoch 7:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 7: Train Loss: 0.1372, Val Loss: 1.5390


Train Epoch 8:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 8: Train Loss: 0.1307, Val Loss: 1.5371


Train Epoch 9:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 9: Train Loss: 0.1255, Val Loss: 1.4400


Train Epoch 10:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 10: Train Loss: 0.1211, Val Loss: 1.6763


Train Epoch 11:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 11: Train Loss: 0.1174, Val Loss: 1.4413


Train Epoch 12:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 12: Train Loss: 0.1142, Val Loss: 1.4727


Train Epoch 13:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 13: Train Loss: 0.1115, Val Loss: 1.3508


Train Epoch 14:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 14: Train Loss: 0.1091, Val Loss: 1.4924


Train Epoch 15:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 15: Train Loss: 0.1070, Val Loss: 1.4519


Train Epoch 16:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 16: Train Loss: 0.1052, Val Loss: 1.4772


Train Epoch 17:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 17: Train Loss: 0.1035, Val Loss: 1.4576


Train Epoch 18:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 18: Train Loss: 0.1021, Val Loss: 1.7320
Early stopping at epoch 18


2025-11-20 23:32:06,361 INFO 
Fold 5/5
[I 2025-11-20 23:35:35,163] A new study created in memory with name: no-name-eac48373-4469-4e18-aba9-2856d7c877b9
[I 2025-11-20 23:35:35,164] Trial 0 finished with value: 0.00013692890054654416 and parameters: {'lr': 0.00013692890054654416}. Best is trial 0 with value: 0.00013692890054654416.
[I 2025-11-20 23:35:35,164] Trial 1 finished with value: 5.131721849901983e-05 and parameters: {'lr': 5.131721849901983e-05}. Best is trial 1 with value: 5.131721849901983e-05.
[I 2025-11-20 23:35:35,164] Trial 2 finished with value: 1.0814978119940094e-05 and parameters: {'lr': 1.0814978119940094e-05}. Best is trial 2 with value: 1.0814978119940094e-05.
[I 2025-11-20 23:35:35,165] Trial 3 finished with value: 1.3455377067737503e-05 and parameters: {'lr': 1.3455377067737503e-05}. Best is trial 2 with value: 1.0814978119940094e-05.
[I 2025-11-20 23:35:35,165] Trial 4 finished with value: 9.961255485368533e-05 and parameters: {'lr': 9.961255485368533e-05}. Best

SSL Epoch 1:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 1 Loss: 0.8478


SSL Epoch 2:   0%|          | 0/8438 [00:00<?, ?it/s]

SSL Epoch 2 Loss: 0.7081


SSL Epoch 3:   0%|          | 0/8438 [00:00<?, ?it/s]

2025-11-21 00:15:02,410 INFO Starting supervised training for fold 5...


SSL Epoch 3 Loss: 0.5749


Train Epoch 1:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.3542, Val Loss: 2.6217


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.2327, Val Loss: 2.1373


Train Epoch 3:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.1923, Val Loss: 2.9632


Train Epoch 4:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.1699, Val Loss: 3.6648


Train Epoch 5:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.1550, Val Loss: 3.3032


Train Epoch 6:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.1444, Val Loss: 3.5895


Train Epoch 7:   0%|          | 0/56254 [00:00<?, ?it/s]

Epoch 7: Train Loss: 0.1362, Val Loss: 3.4068
Early stopping at epoch 7


2025-11-21 07:44:25,403 INFO Cross-validation results: [1.8402759010932144, 1.6129989732794285, 1.5601578291889124, 1.3507649380442464, 2.137286630874464]
2025-11-21 07:44:25,404 INFO Average validation loss: 1.7003
2025-11-21 07:44:25,404 INFO Estimated training time: 1611.4 hours (67.1 days)
2025-11-21 07:44:25,405 INFO With RTX 3080 Ti optimizations, expect 8-12 hours total.


Train Epoch 2:   0%|          | 0/56254 [00:00<?, ?it/s]

In [23]:
#cell 23
# Load the best trained model
model = MSMS2SmilesHybrid(vocab_size=vocab_size, d_model=config.D_MODEL, nhead=config.NHEAD, num_layers=config.NUM_LAYERS).to(device)
checkpoint = torch.load('best_msms_hybrid_fold_1.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print('Model loaded successfully')


Model loaded successfully


In [24]:
#cell 24
# External dataset evaluation and visualization
model.eval()
external_metrics = {'tanimoto': [], 'dice': [], 'mcs': [], 'mw_diff': [], 'logp_diff': [], 'substructure': []}
pred_smiles_list = []
true_smiles_list = []
adducts_list = []
num_samples = min(5, len(external_dataset))

for sample_idx in range(num_samples):
    sample_spectrum = external_dataset[sample_idx][0]
    sample_graph = external_dataset[sample_idx][1]
    sample_ion_mode = external_dataset[sample_idx][3]
    sample_precursor_bin = external_dataset[sample_idx][4]
    sample_adduct_idx = external_dataset[sample_idx][5]
    true_smiles = external_dataset[sample_idx][6]

    predicted_results = beam_search(model, sample_spectrum, sample_graph, sample_ion_mode, sample_precursor_bin, sample_adduct_idx, true_smiles, beam_width=10, max_len=SUPERVISED_MAX_LEN, device=device)
    pred_smiles_list.extend([smiles for smiles, _ in predicted_results])
    true_smiles_list.extend([true_smiles] * len(predicted_results))
    adducts_list.extend([df_external.iloc[sample_idx]['adduct']] * len(predicted_results))

    print(f"\nExternal Sample {sample_idx} - True SMILES: {true_smiles}")
    print("Top Predicted SMILES:")
    for smiles, confidence in predicted_results[:3]:
        external_metrics['tanimoto'].append(tanimoto_similarity(smiles, true_smiles, all_fingerprints))
        external_metrics['dice'].append(dice_similarity(smiles, true_smiles))
        external_metrics['mcs'].append(mcs_similarity(smiles, true_smiles))
        external_metrics['mw_diff'].append(mw_difference(smiles, true_smiles))
        external_metrics['logp_diff'].append(logp_difference(smiles, true_smiles))
        external_metrics['substructure'].append(substructure_match(smiles, true_smiles, model.gnn_encoder.substructures))
        print(f"SMILES: {smiles}, Confidence: {confidence:.4f}, Tanimoto: {external_metrics['tanimoto'][-1]:.4f}, Dice: {external_metrics['dice'][-1]:.4f}, MCS: {external_metrics['mcs'][-1]:.4f}")
        if len(smiles) > 100 and smiles.count('C') > len(smiles) * 0.8:
            print("Warning: Predicted SMILES is a long carbon chain, indicating potential model underfitting.")
        if smiles != "Invalid SMILES":
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                print(f"Molecular Weight: {Descriptors.MolWt(mol):.2f}, LogP: {Descriptors.MolLogP(mol):.2f}")

    # Visualize molecules
    if predicted_results[0][0] != "Invalid SMILES":
        pred_mol = Chem.MolFromSmiles(predicted_results[0][0], sanitize=True)
        true_mol = Chem.MolFromSmiles(true_smiles, sanitize=True)
        if pred_mol and true_mol:
            img = Draw.MolsToGridImage([true_mol, pred_mol], molsPerRow=2, subImgSize=(300, 300), legends=['True', 'Predicted'])
            img_array = np.array(img.convert('RGB'))
            plt.figure(figsize=(10, 5))
            plt.imshow(img_array)
            plt.axis('off')
            plt.title(f"External Sample {sample_idx} - Tanimoto: {external_metrics['tanimoto'][0]:.4f}")
            plt.show()

    # Visualize attention and GNN weights for first sample
    if sample_idx == 0:
        with torch.no_grad():
            spectrum = sample_spectrum.unsqueeze(0).to(device)
            graph_data = Batch.from_data_list([sample_graph]).to(device)
            ion_mode_idx = torch.tensor([sample_ion_mode], dtype=torch.long).to(device)
            precursor_idx = torch.tensor([sample_precursor_bin], dtype=torch.long).to(device)
            adduct_idx = torch.tensor([sample_adduct_idx], dtype=torch.long).to(device)
            _, attn_weights = model.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)
            _, _, edge_weights = model.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
            plot_attention_weights(attn_weights, title=f"External Fold Transformer Attention Weights")
            plot_gnn_edge_weights(edge_weights, sample_graph.edge_index, title=f"External Fold GNN Edge Importance")

# Final Evaluation
print(f"External Validity Rate: {validity_rate(pred_smiles_list):.2f}%")
print(f"External Prediction Diversity: {prediction_diversity(pred_smiles_list):.4f}")
print("External Metrics Summary:")
print(f"Avg Tanimoto: {np.mean(external_metrics['tanimoto']):.4f}")
print(f"Avg Dice: {np.mean(external_metrics['dice']):.4f}")
print(f"Avg MCS: {np.mean(external_metrics['mcs']):.4f}")
print(f"Avg MW Difference: {np.mean([x for x in external_metrics['mw_diff'] if x != float('inf')]):.2f}")
print(f"Avg LogP Difference: {np.mean([x for x in external_metrics['logp_diff'] if x != float('inf')]):.2f}")
print(f"Avg Substructure Match: {np.mean(external_metrics['substructure']):.4f}")
error_analysis(pred_smiles_list, true_smiles_list, adducts_list, all_fingerprints)


KeyError: '< SOS >'

backup code: 

In [None]:
# Additional helpers merged from notebook-for-pc.ipynb - added without removing any existing features
# 1) SMILES/SELFIES validators and plausibility checks
from rdkit import Chem
from rdkit.Chem import Descriptors, rdFMCS
from rdkit import DataStructs
import math
import numpy as np
import logging

# Define a conservative set of valid atom symbols used for basic syntactic checks
valid_atoms = set([
    'H','B','C','N','O','F','P','S','Cl','Br','I',
    'c','n','o','s','p'  # aromatic/lowercase tokens sometimes used in SMILES-like checks
])

def is_valid_smiles_syntax(smiles):
    "Basic SMILES syntax validator: bracket and paren matching + simple token checks."
    if not isinstance(smiles, str) or len(smiles) == 0:
        return False
    stack = []
    for c in smiles:
        if c in '([':
            stack.append(c)
        elif c == ')':
            if not stack or stack[-1] != '(':
                return False
            stack.pop()
        elif c == ']':
            if not stack or stack[-1] != '[':
                return False
            stack.pop()
    if stack:
        return False
    i = 0
    while i < len(smiles):
        if smiles[i] == '[':
            j = smiles.find(']', i)
            if j == -1:
                return False
            atom = smiles[i+1:j]
            # simple check: at least one valid atom symbol is present
            if not any(a in atom for a in valid_atoms):
                return False
            i = j + 1
        else:
            if smiles[i] in valid_atoms or smiles[i] in '()=#/\\@.:+-0123456789%[]':
                i += 1
            else:
                # allow SELFIES tokens and other chars; fallback to RDKit for final check
                i += 1
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        return mol is not None
    except Exception:
        return False

def is_plausible_molecule(smiles, true_mol=None, max_mw=1500, min_logp=-7, max_logp=7):
    "Basic RDKit plausibility check used for filtering beam search outputs."
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if not mol or not is_valid_smiles_syntax(smiles):
            return False
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        true_mw = Descriptors.MolWt(true_mol) if true_mol else None
        if mw > max_mw:
            return False
        if not (min_logp <= logp <= max_logp):
            return False
        if true_mw is not None and abs(mw - true_mw) > 300:
            return False
        return True
    except Exception as e:
        logging.debug(f'is_plausible_molecule error: {e}')
        return False

# 2) Safe adduct mapping helper (builds mapping if missing and fills NaNs)
def ensure_adduct_mapping(df_massspecgym, df_external=None):
    "Builds adduct_types and adduct_to_idx if not present, fills NaNs in adduct_idx."
    global adduct_types, adduct_to_idx
    try:
        if 'adduct_to_idx' in globals() and 'adduct_types' in globals():
            return adduct_to_idx
    except Exception:
        pass
    adduct_types = df_massspecgym['adduct'].dropna().unique().tolist()
    adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}
    df_massspecgym['adduct_idx'] = df_massspecgym['adduct'].map(adduct_to_idx).fillna(0).astype(int)
    if df_external is not None:
        df_external['adduct_idx'] = df_external['adduct'].map(adduct_to_idx).fillna(0).astype(int)
    return adduct_to_idx

# attempt to ensure mapping if dataframes exist in notebook globals
try:
    if 'df_massspecgym' in globals():
        ensure_adduct_mapping(df_massspecgym, df_external if 'df_external' in globals() else None)
except Exception as e:
    logging.debug(f'ensure_adduct_mapping failed: {e}')

# 3) Enhanced beam search variant (keeps original beam_search intact)
def beam_search_enhanced(model, spectrum, graph_data, ion_mode_idx, precursor_idx, adduct_idx, true_smiles=None, beam_width=8, max_len=150, nucleus_p=0.9, device='cpu'):
    "Beam-search that applies SMILES/SELFIES syntax checks, stereochemistry boosting, and plausibility filters."
    model.eval()
    true_mol = Chem.MolFromSmiles(true_smiles) if true_smiles else None
    with torch.no_grad():
        spectrum = spectrum.unsqueeze(0).to(device) if isinstance(spectrum, torch.Tensor) else torch.tensor(spectrum, dtype=torch.float).unsqueeze(0).to(device)
        try:
            graph_batch = Batch.from_data_list([graph_data]).to(device)
        except Exception:
            graph_batch = graph_data if hasattr(graph_data, 'batch') else graph_data
        ion_mode_idx = torch.tensor([int(ion_mode_idx)], dtype=torch.long).to(device)
        precursor_idx = torch.tensor([int(precursor_idx)], dtype=torch.long).to(device)
        adduct_idx = torch.tensor([int(adduct_idx)], dtype=torch.long).to(device)
        sequences = [([token_to_idx.get(config.SOS_TOKEN, token_to_idx.get(PAD_TOKEN, 0))], 0.0)]

        for _ in range(max_len):
            all_candidates = []
            for seq, score in sequences:
                if seq[-1] == token_to_idx.get(config.EOS_TOKEN, token_to_idx.get(EOS_TOKEN, 0)):
                    all_candidates.append((seq, score))
                    continue
                tgt_input = torch.tensor([seq], dtype=torch.long).to(device)
                # generate mask if model provides helper, otherwise create causal mask
                tgt_mask = None
                try:
                    outputs = model.decoder(tgt_input, model.combine_layer(torch.cat([model.transformer_encoder(spectrum), model.gnn_encoder(graph_batch)], dim=-1)).unsqueeze(1), None)
                    logits = outputs[0, -1]
                except Exception:
                    # Fallback to calling model forward if decoder signature differs
                    outputs = model(spectrum, graph_batch, tgt_input)
                    logits = outputs[0][0, -1] if isinstance(outputs, tuple) else outputs[0, -1]
                log_probs = F.log_softmax(logits, dim=-1).cpu().numpy()
                # boost stereochemistry tokens if present
                for tok in ['@','/','\\']:
                    if tok in token_to_idx:
                        log_probs[token_to_idx[tok]] += 0.25
                topk = np.argsort(log_probs)[-min(len(log_probs), beam_width*4):][::-1]
                for tok in topk[:beam_width]:
                    new_seq = seq + [int(tok)]
                    new_score = score + float(log_probs[tok])
                    # quick syntax check on partial reconstruction
                    partial = ''.join([idx_to_token.get(i, '') for i in new_seq[1:]])
                    if not is_valid_smiles_syntax(partial):
                        # allow but penalize
                        new_score -= 1.0
                    all_candidates.append((new_seq, new_score))
            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if all(seq[-1] == token_to_idx.get(config.EOS_TOKEN, token_to_idx.get(EOS_TOKEN, 0)) for seq, _ in sequences):
                break

        results = []
        for seq, score in sequences:
            toks = [idx_to_token.get(i, '') for i in seq[1:]]
            cand_str = ''.join([t for t in toks if t not in {config.PAD_TOKEN, config.SOS_TOKEN, config.EOS_TOKEN}])
            # attempt to decode either as SELFIES or SMILES depending on available decoders
            smiles = None
            try:
                # prefer SELFIES decode if token set is SELFIES-like
                if 'sf' in globals():
                    s = ''.join(toks)
                    smiles = sf.decoder(s) if s else ''
            except Exception:
                smiles = None
            if not smiles:
                smiles = cand_str
            if smiles and is_plausible_molecule(smiles, true_mol):
                results.append((smiles, math.exp(score / max(1, len(seq)))))
        return results if results else [(, 0.0)]

# Small runtime check to show new helpers are present
print('Merged helpers: is_valid_smiles_syntax, is_plausible_molecule, ensure_adduct_mapping, beam_search_enhanced available')
