In [None]:
# Install required packages for MS-to-Structure pipeline
!pip install torch torch_geometric rdkit-pypi selfies datasets optuna nltk python-Levenshtein tqdm scikit-learn matplotlib

# MS-to-Structure Deep Learning Pipeline (Jupyter Version)

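This notebook implements a mass spectrometry-to-structure (MS-to-structure) deep learning pipeline adapted for interactive use. It covers data preprocessing, molecular string handling with SELFIES, model definition, training, and evaluation. The model classes and the training/evaluation helper functions are not reproduced here. Fill in the placeholder cells below from the accompanying script before running the training and evaluation cells.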

In [None]:
# Import libraries and set up logging for Jupyter compatibility
import logging
import math
import sys
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data, Batch
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, rdFMCS, EnumerateStereoisomers
from rdkit import DataStructs
from rdkit.Chem import rdFingerprintGenerator
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold
from sklearn.model_selection import GroupKFold
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import selfies as sf
import optuna
from nltk.translate.bleu_score import sentence_bleu
from Levenshtein import distance

# Set up logging for Jupyter so that log lines appear in the cell's stdout output.
# basicConfig writes to stderr unless a stream is given, so stdout is passed explicitly.
# force=True replaces any handlers already attached to the root logger; without it,
# basicConfig does nothing when a handler exists and INFO messages would be lost.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    stream=sys.stdout,
    force=True,
)


In [None]:
# Set random seeds for reproducibility and define global constants.
SEED = 42
# numpy's global RNG drives the spectrum noise (bin_spectrum_to_graph) and the
# SELFIES masking (MSMSDataset.mask_selfies) in later cells.
np.random.seed(SEED)
torch.manual_seed(SEED)
# NOTE(review): RDKit's randomized SMILES (doRandom=True in augment_smiles) are
# probably not controlled by these seeds — confirm if augmentation must be reproducible.

# Special tokens placed at the start of the SELFIES vocabulary.
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
# Presumably bracketed like a SELFIES symbol so that re-tokenizing a masked string keeps it intact — confirm.
MASK_TOKEN = "[MASK]"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# Load the dataset (replace DATA_PATH as needed for your environment).
DATA_PATH = '/kaggle/input/tandem'
TRAIN_FRACTION = 0.9  # leading fraction kept as MassSpecGym; the remaining tail simulates an external set

dataset = load_dataset(DATA_PATH, split='train')
df = pd.DataFrame(dataset)

# Simulate an external dataset (e.g., NIST-like) by splitting off the last rows (no shuffling).
# .copy() makes each split an independent frame, so later column assignments do not
# trigger pandas' SettingWithCopyWarning on an iloc slice.
# NOTE(review): the split is by row, so the same molecule may appear in both splits.
split_row = int(TRAIN_FRACTION * len(df))
df_massspecgym = df.iloc[:split_row].copy()
df_external = df.iloc[split_row:].copy()
print("MassSpecGym size:", len(df_massspecgym), "External test size:", len(df_external))

# Inspect dataset
print("Dataset Columns:", df_massspecgym.columns.tolist())
print("\nFirst few rows of MassSpecGym dataset:")
print(df_massspecgym[['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'precursor_mz']].head())
print("\nUnique adduct values:", df_massspecgym['adduct'].unique())
print("\nUnique adduct values:", df_massspecgym['adduct'].unique())


In [None]:
# Canonicalize SMILES, augment, and bin spectra
def canonicalize_smiles(smiles):
    """Return the RDKit canonical SMILES for ``smiles``, or None if it cannot be used.

    Non-string and empty inputs (e.g. NaN/None from missing dataframe cells) return
    None without logging. SMILES that RDKit fails to parse also return None.
    Unexpected exceptions are logged with a traceback and return None.
    """
    # Missing values would otherwise raise inside RDKit and log a full traceback per row.
    if not isinstance(smiles, str) or not smiles:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol:
            return Chem.MolToSmiles(mol, canonical=True)
        return None
    except Exception as e:
        logging.error(f"canonicalize_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return None

def augment_smiles(smiles, max_isomers=8):
    """Return SMILES variants of ``smiles`` for data augmentation.

    Enumerates up to ``max_isomers`` stereoisomers with RDKit. Each isomer is written
    as a randomized SMILES string (``doRandom=True``), so the outputs are deliberately
    not a single canonical form.

    Returns ``[smiles]`` unchanged if the input cannot be parsed, if enumeration
    yields nothing, or if RDKit raises (the error is logged).
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return [smiles]
        # The options class in rdkit.Chem.EnumerateStereoisomers is StereoEnumerationOptions.
        opts = EnumerateStereoisomers.StereoEnumerationOptions()
        opts.maxIsomers = max_isomers
        stereoisomers = EnumerateStereoisomers.EnumerateStereoisomers(mol, options=opts)
        variants = [Chem.MolToSmiles(m, canonical=True, doRandom=True) for m in stereoisomers]
        return variants or [smiles]
    except Exception as e:
        logging.error(f"augment_smiles failed for {smiles}: {e}\n{traceback.format_exc()}")
        return [smiles]

def bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct, n_bins=1000, max_mz=1000, noise_level=0.05, adduct_map=None):
    """Bin a peak list into a fixed-length spectrum and a chain graph over the bins.

    Parameters
    ----------
    mzs, intensities : iterables of peak m/z values and intensities, paired by position.
        Pairs that cannot be converted to float are skipped with a warning.
    ion_mode : ion-mode flag, stored on the graph as a 1-element float tensor.
    precursor_mz : precursor m/z, stored on the graph as a 1-element float tensor.
    adduct : adduct label, mapped to an integer index (unknown adducts map to 0).
    n_bins, max_mz : peaks with 0 <= m/z < max_mz are summed into ``n_bins`` equal-width bins.
    noise_level : standard deviation of the Gaussian noise added after max-normalisation.
        The noisy spectrum is clipped to [0, 1].
    adduct_map : optional adduct -> index mapping. Defaults to the notebook-global
        ``adduct_to_idx``.

    Returns
    -------
    (spectrum, graph)
        ``spectrum`` is the binned numpy array of length ``n_bins``. ``graph`` is a
        torch_geometric ``Data`` with node features ``x`` of shape (n_bins, 1) and
        bidirectional edges between neighbouring bins. On any unexpected error the
        error is logged and an all-zero spectrum with an edgeless graph is returned.
    """
    try:
        spectrum = np.zeros(n_bins)
        for mz, intensity in zip(mzs, intensities):
            try:
                mz = float(mz)
                intensity = float(intensity)
                # A negative m/z would otherwise produce a negative index into the array.
                if 0 <= mz < max_mz:
                    # min() guards against float rounding pushing an mz just below max_mz into bin n_bins.
                    bin_idx = min(int((mz / max_mz) * n_bins), n_bins - 1)
                    spectrum[bin_idx] += intensity
            except (ValueError, TypeError) as e:
                logging.warning(f"bin_spectrum_to_graph: Skipping value error: {e}")
                continue
        if spectrum.max() > 0:
            spectrum = spectrum / spectrum.max()
        # Clip the noisy spectrum, not the noise itself. Clipping the noise made it
        # non-negative (biased upwards) and let spectrum values exceed 1.
        noise = np.random.normal(0, noise_level, spectrum.shape)
        spectrum = np.clip(spectrum + noise, 0, 1)
        x = torch.tensor(spectrum, dtype=torch.float).unsqueeze(-1)
        # Chain graph over neighbouring bins. Edges are interleaved as
        # (0,1), (1,0), (1,2), (2,1), ... and built without a Python loop.
        src = torch.arange(n_bins - 1, dtype=torch.long)
        forward_edges = torch.stack([src, src + 1], dim=1)
        backward_edges = torch.stack([src + 1, src], dim=1)
        edge_index = torch.stack([forward_edges, backward_edges], dim=1).reshape(-1, 2).t().contiguous()
        ion_mode = torch.tensor([ion_mode], dtype=torch.float)
        precursor_mz = torch.tensor([precursor_mz], dtype=torch.float)
        mapping = adduct_to_idx if adduct_map is None else adduct_map
        adduct_idx = mapping.get(adduct, 0)
        return spectrum, Data(x=x, edge_index=edge_index, ion_mode=ion_mode, precursor_mz=precursor_mz, adduct_idx=adduct_idx)
    except Exception as e:
        logging.error(f"bin_spectrum_to_graph failed: {e}\n{traceback.format_exc()}")
        return np.zeros(n_bins), Data(x=torch.zeros(n_bins, 1), edge_index=torch.zeros(2, 0, dtype=torch.long), ion_mode=torch.zeros(1), precursor_mz=torch.zeros(1), adduct_idx=0)


In [None]:
# Apply canonicalization, augmentation, and binning to the dataframes.
# Preprocess ion mode, precursor m/z bins, and adduct indices.
# Every step rebinds df_massspecgym / df_external to a new frame (assign / method chains)
# instead of assigning into possible iloc slices, which avoids SettingWithCopyWarning.
# NOTE: this cell is not idempotent. Re-running it without re-running the loading cell
# would augment already-augmented rows a second time.
TRAIN_NOISE_LEVEL = 0.05  # Gaussian noise added to training spectra as augmentation
EVAL_NOISE_LEVEL = 0.0    # held-out (external) spectra are binned without synthetic noise


def _ion_mode_from_adduct(adduct):
    """Return 1 for a negative-mode adduct and 0 otherwise (including missing values)."""
    # The overall charge sign is the final character ("[M+H]+", "[M+Cl]-").
    # A '+' inside the brackets must not make "[M+Cl]-" count as positive mode.
    return 1 if str(adduct).strip().endswith('-') else 0


def _bin_precursors(precursor_mzs, bin_edges):
    """Assign precursor m/z values to the quantile bins described by ``bin_edges``.

    Uses right-closed intervals like ``pd.qcut``, so labels match those of the training
    split. Values outside the training range go to the first or last bin.
    """
    n_bins = max(len(bin_edges) - 1, 1)
    positions = np.searchsorted(bin_edges[1:-1], np.asarray(precursor_mzs, dtype=float), side='left')
    return np.clip(positions, 0, n_bins - 1)


def _with_binned_spectra(frame, noise_level):
    """Return a copy of ``frame`` with 'binned' and 'graph_data' columns added."""
    binned = frame.apply(
        lambda row: pd.Series(bin_spectrum_to_graph(
            row['mzs'], row['intensities'], row['ion_mode'], row['precursor_mz'], row['adduct'],
            noise_level=noise_level,
        )),
        axis=1,
    )
    frame = frame.copy()
    frame[['binned', 'graph_data']] = binned
    return frame


df_massspecgym = df_massspecgym.assign(smiles=df_massspecgym['smiles'].apply(canonicalize_smiles)).dropna(subset=['smiles'])
df_external = df_external.assign(smiles=df_external['smiles'].apply(canonicalize_smiles)).dropna(subset=['smiles'])

# Replace each training row by one row per augmented SMILES. The original 'smiles'
# column is dropped first: renaming 'smiles_list' onto an existing 'smiles' column would
# leave two columns with that name, and df['smiles'] would then return a DataFrame.
df_massspecgym = (
    df_massspecgym
    .assign(smiles_list=df_massspecgym['smiles'].apply(augment_smiles))
    .drop(columns=['smiles'])
    .explode('smiles_list')
    .dropna(subset=['smiles_list'])
    .rename(columns={'smiles_list': 'smiles'})
    .reset_index(drop=True)  # explode repeats index labels; give every row a unique label
)

# Precursor bins: quantile edges are learned on the training split only and reused for
# the external split, so a given bin index means the same m/z range in both.
train_precursor_bins, precursor_bin_edges = pd.qcut(
    df_massspecgym['precursor_mz'], q=100, labels=False, retbins=True, duplicates='drop'
)
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}

df_massspecgym = df_massspecgym.assign(
    ion_mode=df_massspecgym['adduct'].apply(_ion_mode_from_adduct),
    precursor_bin=train_precursor_bins,
    adduct_idx=df_massspecgym['adduct'].map(adduct_to_idx),
)
df_external = df_external.assign(
    ion_mode=df_external['adduct'].apply(_ion_mode_from_adduct),
    precursor_bin=_bin_precursors(df_external['precursor_mz'], precursor_bin_edges),
    # Adducts never seen in training fall back to index 0, the same default bin_spectrum_to_graph uses.
    adduct_idx=df_external['adduct'].map(adduct_to_idx).fillna(0).astype(int),
)

df_massspecgym = _with_binned_spectra(df_massspecgym, TRAIN_NOISE_LEVEL)
df_external = _with_binned_spectra(df_external, EVAL_NOISE_LEVEL)


In [None]:
# SELFIES tokenization and vocabulary setup (built from the training split only).
all_smiles = df_massspecgym['smiles'].tolist()
all_selfies = [sf.encoder(s) for s in all_smiles]
selfies_alphabet = set()
for s in all_selfies:
    selfies_alphabet.update(sf.split_selfies(s))
# Special tokens take the first indices (PAD == 0); SELFIES symbols follow in sorted order.
selfies_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN] + sorted(selfies_alphabet)
token_to_idx = {tok: i for i, tok in enumerate(selfies_tokens)}
idx_to_token = {i: tok for tok, i in token_to_idx.items()}
vocab_size = len(token_to_idx)
PRETRAIN_MAX_LEN = 100
# sf.split_selfies returns a generator, so it is materialised before len().
# The +2 accounts for the SOS and EOS tokens added by encode_selfies.
SUPERVISED_MAX_LEN = max(len(list(sf.split_selfies(s))) + 2 for s in all_selfies)
print(f"SELFIES vocabulary size: {vocab_size}, Supervised MAX_LEN: {SUPERVISED_MAX_LEN}, Pretrain MAX_LEN: {PRETRAIN_MAX_LEN}")

def encode_selfies(selfies, max_len=PRETRAIN_MAX_LEN):
    """Encode a SELFIES string as a list of exactly ``max_len`` token ids.

    The sequence is ``[SOS] + first (max_len - 2) symbols + [EOS]``, right-padded
    with the PAD id.

    NOTE(review): symbols missing from ``token_to_idx`` (e.g. ones that occur only in
    the external split) are encoded as PAD, which makes them indistinguishable from
    padding. Consider a dedicated unknown token.
    """
    # sf.split_selfies returns a generator, so it must be materialised before slicing.
    symbols = list(sf.split_selfies(selfies))[:max_len - 2]
    pad_id = token_to_idx[PAD_TOKEN]
    token_ids = [token_to_idx.get(tok, pad_id) for tok in [SOS_TOKEN, *symbols, EOS_TOKEN]]
    # The slice above already bounds the length; this only matters for max_len < 2.
    token_ids = token_ids[:max_len]
    token_ids += [pad_id] * (max_len - len(token_ids))
    return token_ids

def decode_selfies(token_ids):
    """Decode a sequence of token ids back into a SMILES string.

    Decoding stops at the first EOS token. PAD and SOS ids are skipped, and ids not in
    ``idx_to_token`` are treated as PAD. Accepts any iterable of integer-like ids
    (list, numpy array, or 1-D torch tensor). Returns "" if SELFIES decoding raises.
    """
    symbols = []
    for idx in token_ids:
        # int() makes tensor elements look up by value; a 0-d tensor hashes by identity.
        token = idx_to_token.get(int(idx), PAD_TOKEN)
        if token == EOS_TOKEN:
            break
        if token not in (PAD_TOKEN, SOS_TOKEN):
            symbols.append(token)
    selfies_str = ''.join(symbols)
    try:
        return sf.decoder(selfies_str)
    except Exception:
        return ""


In [None]:
# Build a SMILES -> Morgan fingerprint lookup (radius 2, 2048 bits) covering every
# unique SMILES in both the training and external splits.
# SMILES that RDKit cannot parse are simply left out of the table.
morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
all_smiles = list(set(df_massspecgym['smiles'].tolist()) | set(df_external['smiles'].tolist()))
all_fingerprints = {}
for smiles in all_smiles:
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        continue
    all_fingerprints[smiles] = morgan_gen.GetFingerprint(mol)


In [None]:
# Dataset class for MS/MS data
class MSMSDataset(Dataset):
    """Torch dataset over a preprocessed MS/MS dataframe.

    Reads the columns 'binned', 'graph_data', 'ion_mode', 'precursor_bin',
    'adduct_idx' and 'smiles'.

    Each item is a tuple
    ``(spectrum, graph, target_ids, [masked_ids,] ion_mode, precursor_bin, adduct_idx, smiles)``,
    where ``masked_ids`` is present only when ``is_ssl`` is True.

    Args:
        dataframe: source dataframe.
        max_len: token length used in SSL mode. Supervised targets always use
            SUPERVISED_MAX_LEN, regardless of this argument.
        is_ssl: if True, also build randomly masked SELFIES inputs for pretraining.
    """

    def __init__(self, dataframe, max_len=PRETRAIN_MAX_LEN, is_ssl=False):
        self.spectra = np.stack(dataframe['binned'].values)
        self.graph_data = dataframe['graph_data'].values
        self.ion_modes = dataframe['ion_mode'].values
        self.precursor_bins = dataframe['precursor_bin'].values
        self.adduct_indices = dataframe['adduct_idx'].values
        self.raw_smiles = dataframe['smiles'].values
        self.is_ssl = is_ssl
        if is_ssl:
            # Pretraining pairs: the (truncated) original sequence is the target and a
            # randomly masked copy is the input. Both are encoded with `max_len`.
            self.smiles = []
            self.masked_smiles = []
            for s in self.raw_smiles:
                selfies = sf.encoder(s)
                masked_s, orig_s = self.mask_selfies(selfies)
                self.smiles.append(encode_selfies(orig_s, max_len))
                self.masked_smiles.append(encode_selfies(masked_s, max_len))
        else:
            # NOTE: supervised targets always use SUPERVISED_MAX_LEN; `max_len` is not used here.
            self.smiles = [encode_selfies(sf.encoder(s), max_len=SUPERVISED_MAX_LEN) for s in self.raw_smiles]

    def mask_selfies(self, selfies, mask_ratio=0.10):
        """Replace a random ``mask_ratio`` fraction of SELFIES symbols with MASK_TOKEN.

        The input is first truncated to PRETRAIN_MAX_LEN - 2 symbols. Returns
        ``(masked, original)`` SELFIES strings. On error, both are the unmodified input
        and the error is logged.
        """
        try:
            # sf.split_selfies yields symbols lazily; materialise it so it can be sliced.
            tokens = list(sf.split_selfies(selfies))[:PRETRAIN_MAX_LEN - 2]
            masked_tokens = tokens.copy()
            # int() truncates, so sequences shorter than 1 / mask_ratio symbols are not masked.
            n_mask = int(mask_ratio * len(tokens))
            if n_mask > 0:
                mask_indices = np.random.choice(len(tokens), n_mask, replace=False)
                for idx in mask_indices:
                    masked_tokens[idx] = MASK_TOKEN
            return ''.join(masked_tokens), ''.join(tokens)
        except Exception as e:
            logging.error(f"mask_selfies failed for {selfies}: {e}\n{traceback.format_exc()}")
            return selfies, selfies

    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        spectrum = torch.tensor(self.spectra[idx], dtype=torch.float)
        target_ids = torch.tensor(self.smiles[idx], dtype=torch.long)
        conditioning = (
            torch.tensor(self.ion_modes[idx], dtype=torch.long),
            torch.tensor(self.precursor_bins[idx], dtype=torch.long),
            torch.tensor(self.adduct_indices[idx], dtype=torch.long),
        )
        if self.is_ssl:
            masked_ids = torch.tensor(self.masked_smiles[idx], dtype=torch.long)
            return (spectrum, self.graph_data[idx], target_ids, masked_ids, *conditioning, self.raw_smiles[idx])
        return (spectrum, self.graph_data[idx], target_ids, *conditioning, self.raw_smiles[idx])


In [None]:
# Positional encoding and model encoder/decoder classes
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence tensor.

    Even feature indices receive sin(pos * w_k) and odd indices cos(pos * w_k), with
    w_k = 10000 ** (-2k / d_model). Odd ``d_model`` values are supported.

    Args:
        d_model: feature size of the inputs.
        max_len: longest sequence covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies, so trim div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # The leading batch dimension lets the table broadcast over the batch in forward().
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1), :]

# NOTE: SpectrumTransformerEncoder, SpectrumGNNEncoder, SmilesTransformerDecoder and MSMS2SmilesHybrid
# are not defined in this notebook. Copy their definitions from the accompanying script into this cell
# (or split them across several cells) before running the training cells, which otherwise fail with NameError.


In [None]:
# Training and evaluation utilities (SSL pretraining, supervised training, metrics, beam search, plotting, etc.).
# NOTE: these functions are not defined in this notebook. Copy them from the accompanying script into this
# cell (or split them across several cells) before running the cross-validation and evaluation cells:
# ssl_pretrain, supervised_train, is_valid_smiles_syntax, is_plausible_molecule, dice_similarity, mcs_similarity, mw_difference, logp_difference, substructure_match, validity_rate, tanimoto_similarity, prediction_diversity, beam_search, batch_beam_search, plot_attention_weights, plot_gnn_edge_weights, error_analysis, objective


In [None]:
# Cross-validation, training, and evaluation loop.
# Requires objective, MSMS2SmilesHybrid, ssl_pretrain and supervised_train to be defined first.
N_FOLDS = 5


def collate_msms(batch):
    """Collate MSMSDataset items into one batch.

    torch's default collate cannot merge torch_geometric ``Data`` objects. The graph,
    which is at position 1 of every MSMSDataset item, is therefore merged with
    ``Batch.from_data_list``. Every other position uses the default collate: tensors
    are stacked and SMILES strings become a list. Returns a list, as the default
    collate does for tuple items.
    """
    collated = []
    for position, column in enumerate(zip(*batch)):
        if position == 1:
            collated.append(Batch.from_data_list(list(column)))
        else:
            collated.append(torch.utils.data.dataloader.default_collate(list(column)))
    return collated


# Augmented rows of the same spectrum share an 'identifier'. Grouping keeps all copies
# in one fold, so validation never scores a spectrum the model was trained on.
if 'identifier' in df_massspecgym.columns:
    kf = GroupKFold(n_splits=N_FOLDS)
    fold_groups = df_massspecgym['identifier'].values
else:
    kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)
    fold_groups = None
fold_results = []

external_dataset = MSMSDataset(df_external, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
external_loader = DataLoader(external_dataset, batch_size=32, num_workers=2, collate_fn=collate_msms)

for fold, (train_idx, val_idx) in enumerate(kf.split(df_massspecgym, groups=fold_groups)):
    print(f"\nFold {fold+1}/{N_FOLDS}")
    train_data = df_massspecgym.iloc[train_idx]
    val_data = df_massspecgym.iloc[val_idx]
    ssl_data = train_data.sample(frac=0.3, random_state=42)

    train_dataset = MSMSDataset(train_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    val_dataset = MSMSDataset(val_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    ssl_dataset = MSMSDataset(ssl_data, max_len=PRETRAIN_MAX_LEN, is_ssl=True)
    # NOTE(review): collate_msms lives in the notebook's __main__; with num_workers > 0
    # this relies on fork-based workers (Linux/Kaggle) — confirm if running elsewhere.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, collate_fn=collate_msms)
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=2, collate_fn=collate_msms)
    ssl_loader = DataLoader(ssl_dataset, batch_size=128, shuffle=True, num_workers=2, collate_fn=collate_msms)

    # Hyperparameter tuning
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: objective(trial, train_data, val_data), n_trials=10)
    best_lr = study.best_params['lr']
    print(f"Best learning rate for fold {fold+1}: {best_lr:.6f}")

    # Initialize and train model
    model = MSMS2SmilesHybrid(vocab_size=vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2, fp_size=2048).to(device)
    print(f"Starting SSL pretraining for fold {fold+1}...")
    ssl_pretrain(model, ssl_loader, epochs=3, lr=best_lr)
    print(f"Starting supervised training for fold {fold+1}...")
    best_val_loss = supervised_train(model, train_loader, val_loader, epochs=30, lr=best_lr, patience=5)
    fold_results.append(best_val_loss)
    torch.save({
        'model_state_dict': model.state_dict(),
        'token_to_idx': token_to_idx,
        'idx_to_token': idx_to_token
    }, f'best_msms_hybrid_fold_{fold+1}.pt')

# NOTE: after this loop `model` holds the model from the last fold only.
print(f"Cross-validation results: {fold_results}")
print(f"Average validation loss: {np.mean(fold_results):.4f}")


In [None]:
# External dataset evaluation and visualization.
# NOTE(review): `model` is whatever the cross-validation cell left bound, i.e. the model
# from the last fold, not the best fold or an ensemble.
model.eval()
external_metrics = {'tanimoto': [], 'dice': [], 'mcs': [], 'mw_diff': [], 'logp_diff': [], 'substructure': []}
pred_smiles_list = []
true_smiles_list = []
adducts_list = []
num_samples = min(5, len(external_dataset))

for sample_idx in range(num_samples):
    # Fetch the item once: each __getitem__ call rebuilds all of its tensors.
    (sample_spectrum, sample_graph, sample_target_ids, sample_ion_mode,
     sample_precursor_bin, sample_adduct_idx, true_smiles) = external_dataset[sample_idx]

    predicted_results = beam_search(model, sample_spectrum, sample_graph, sample_ion_mode, sample_precursor_bin, sample_adduct_idx, true_smiles, beam_width=10, max_len=SUPERVISED_MAX_LEN, device=device)
    pred_smiles_list.extend([smiles for smiles, _ in predicted_results])
    true_smiles_list.extend([true_smiles] * len(predicted_results))
    adducts_list.extend([df_external.iloc[sample_idx]['adduct']] * len(predicted_results))

    print(f"\nExternal Sample {sample_idx} - True SMILES: {true_smiles}")
    print("Top Predicted SMILES:")
    # This sample's metrics start at this index; the first entry belongs to its top-ranked prediction.
    sample_metrics_start = len(external_metrics['tanimoto'])
    for smiles, confidence in predicted_results[:3]:
        external_metrics['tanimoto'].append(tanimoto_similarity(smiles, true_smiles, all_fingerprints))
        external_metrics['dice'].append(dice_similarity(smiles, true_smiles))
        external_metrics['mcs'].append(mcs_similarity(smiles, true_smiles))
        external_metrics['mw_diff'].append(mw_difference(smiles, true_smiles))
        external_metrics['logp_diff'].append(logp_difference(smiles, true_smiles))
        external_metrics['substructure'].append(substructure_match(smiles, true_smiles, model.gnn_encoder.substructures))
        print(f"SMILES: {smiles}, Confidence: {confidence:.4f}, Tanimoto: {external_metrics['tanimoto'][-1]:.4f}, Dice: {external_metrics['dice'][-1]:.4f}, MCS: {external_metrics['mcs'][-1]:.4f}")
        if len(smiles) > 100 and smiles.count('C') > len(smiles) * 0.8:
            print("Warning: Predicted SMILES is a long carbon chain, indicating potential model underfitting.")
        if smiles != "Invalid SMILES":
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                print(f"Molecular Weight: {Descriptors.MolWt(mol):.2f}, LogP: {Descriptors.MolLogP(mol):.2f}")

    # Visualize the top-ranked prediction next to the ground truth (skipped if beam search returned nothing).
    if predicted_results and predicted_results[0][0] != "Invalid SMILES":
        pred_mol = Chem.MolFromSmiles(predicted_results[0][0], sanitize=True)
        true_mol = Chem.MolFromSmiles(true_smiles, sanitize=True)
        if pred_mol and true_mol:
            img = Draw.MolsToGridImage([true_mol, pred_mol], molsPerRow=2, subImgSize=(300, 300), legends=['True', 'Predicted'])
            img_array = np.array(img.convert('RGB'))
            plt.figure(figsize=(10, 5))
            plt.imshow(img_array)
            plt.axis('off')
            plt.title(f"External Sample {sample_idx} - Tanimoto: {external_metrics['tanimoto'][sample_metrics_start]:.4f}")
            plt.show()

    # Visualize attention and GNN weights for first sample
    if sample_idx == 0:
        with torch.no_grad():
            spectrum = sample_spectrum.unsqueeze(0).to(device)
            graph_data = Batch.from_data_list([sample_graph]).to(device)
            # The dataset returns 0-d long tensors; reshape them to batch size 1.
            ion_mode_idx = sample_ion_mode.reshape(1).long().to(device)
            precursor_idx = sample_precursor_bin.reshape(1).long().to(device)
            adduct_idx = sample_adduct_idx.reshape(1).long().to(device)
            _, attn_weights = model.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)
            _, _, edge_weights = model.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
            plot_attention_weights(attn_weights, title=f"External Fold Transformer Attention Weights")
            plot_gnn_edge_weights(edge_weights, sample_graph.edge_index, title=f"External Fold GNN Edge Importance")

# Final Evaluation
print(f"External Validity Rate: {validity_rate(pred_smiles_list):.2f}%")
print(f"External Prediction Diversity: {prediction_diversity(pred_smiles_list):.4f}")
print("External Metrics Summary:")
print(f"Avg Tanimoto: {np.mean(external_metrics['tanimoto']):.4f}")
print(f"Avg Dice: {np.mean(external_metrics['dice']):.4f}")
print(f"Avg MCS: {np.mean(external_metrics['mcs']):.4f}")
print(f"Avg MW Difference: {np.mean([x for x in external_metrics['mw_diff'] if x != float('inf')]):.2f}")
print(f"Avg LogP Difference: {np.mean([x for x in external_metrics['logp_diff'] if x != float('inf')]):.2f}")
print(f"Avg Substructure Match: {np.mean(external_metrics['substructure']):.4f}")
error_analysis(pred_smiles_list, true_smiles_list, adducts_list, all_fingerprints)
