In [None]:
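#!pip install torch torchvision torchaudio torch_geometric rdkit datasets tokenizers tqdm pandas pyarrow numpy scikit-learn matplotlib optuna nltk Levenshtein pandarallel joblib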

In [None]:

#final_version
# stereochemistry_fixed

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data, Batch
from datasets import load_dataset
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, rdFMCS, EnumerateStereoisomers
from rdkit import DataStructs
from rdkit.Chem import rdFingerprintGenerator
from tqdm import tqdm
import math
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast
import optuna
from nltk.translate.bleu_score import sentence_bleu
from Levenshtein import distance
from pandarallel import pandarallel #Added by Pawan
import time  #Added by Pawan
import os #Added by Pawan
import glob #Added by Pawan
%matplotlib inline

# Seed NumPy's global RNG and PyTorch so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Special tokens of the SMILES vocabulary, defined early so every later cell shares them.
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN = "<PAD>", "<SOS>", "<EOS>", "[MASK]"

In [None]:
# Load the MassSpecGym data from the Hugging Face Hub.
# - `load_dataset` (from the `datasets` library) downloads the repository
#   "roman-bushuiev/MassSpecGym" uploaded by user 'roman-bushuiev'.
# - split='val' selects the split named "val" as defined by the repository's own
#   configuration; which rows belong to it is decided there, not in this notebook.
# - The returned `Dataset` is converted to a pandas DataFrame for easier manipulation.
dataset = load_dataset(path='roman-bushuiev/MassSpecGym', split='val')
df = pd.DataFrame(dataset)

In [None]:
# Simulate an external test set (e.g. NIST-like) with a manual 90/10 positional split.
# NOTE: `df` (loaded from the Hugging Face "val" split) has a `fold` column holding the
#       dataset creator's own 'train' / 'val' / 'test' labels; they are NOT used here.
#       The split is purely by row position:
#         - df_massspecgym: rows [0, int(0.9 * len(df)))
#         - df_external:    rows [int(0.9 * len(df)), len(df))
#       Both subsets may therefore mix the original fold labels, and nothing prevents
#       rows of the same molecule from landing on both sides of the cut.
# .copy() makes each subset an independent frame, so later column assignments do not
# write into slices of `df` (which pandas flags with SettingWithCopyWarning).
split_at = int(0.9 * len(df))
df_massspecgym = df.iloc[:split_at].copy()
df_external = df.iloc[split_at:].copy()
print("MassSpecGym size:", len(df_massspecgym), "External test size:", len(df_external))


In [None]:
# Quick look at the training subset: column names, a few key columns, and the adduct types.
preview_columns = ['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'precursor_mz']
print("Dataset Columns:", list(df_massspecgym.columns))
print("\nFirst few rows of MassSpecGym dataset:")
print(df_massspecgym.loc[:, preview_columns].head())
print("\nUnique adduct values:", df_massspecgym['adduct'].unique())


In [None]:
# Canonicalize SMILES
# -------------------
# Parse a SMILES string with RDKit and write it back in RDKit's canonical form, so that
# different spellings of the same molecule become identical strings
# (e.g. "OC" and "CO" both -> "CO"); this is what enables deduplication by structure.
#
# Chem.MolFromSmiles(..., sanitize=True) runs RDKit sanitization (aromaticity,
# hybridization, valence checks, stereo perception). Invalid SMILES make it return None.
# Chem.MolToSmiles(..., canonical=True) emits the canonical string; stereochemistry is
# written too because isomeric SMILES is RDKit's default output.
#
# Return values:
#   - canonical SMILES string for a parseable, non-empty string
#   - None for non-string input (e.g. NaN), an empty/blank string, unparseable SMILES,
#     or any RDKit error -- callers drop these rows with dropna().
#
# Only Exception subclasses are caught: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and make long pandarallel runs hard to stop.
def canonicalize_smiles(smiles):
    """Return the RDKit canonical SMILES for ``smiles``, or None if it cannot be parsed."""
    # An empty string would otherwise parse to an empty molecule and survive dropna() as "".
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None

# Data augmentation: stereoisomer enumeration + randomized SMILES
# ----------------------------------------------------------------
# For a parseable SMILES, enumerate stereoisomers with RDKit's EnumerateStereoisomers
# (RDKit's default StereoEnumerationOptions enumerate only unassigned stereo elements)
# and write each isomer as a *randomized* SMILES string.
# Note: with doRandom=True RDKit emits a random atom ordering, so the strings are NOT
# canonical even though canonical=True is passed. RDKit's random generator is not
# seeded by the numpy/torch seeds, so the exact strings can differ between runs.
#
# Parameters:
#   smiles      : input SMILES string
#   max_isomers : optional cap on the number of enumerated stereoisomers. Each
#                 unassigned stereo element doubles the count, and every isomer becomes
#                 its own training row after explode(). None keeps RDKit's defaults.
#
# Returns a list of SMILES strings. If RDKit cannot parse the input or raises, the
# original string is returned unchanged in a one-element list.
def augment_smiles(smiles, max_isomers=None):
    """Return randomized SMILES for the stereoisomers of ``smiles`` (``[smiles]`` on failure)."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return [smiles]
        if max_isomers is None:
            isomers = EnumerateStereoisomers.EnumerateStereoisomers(mol)
        else:
            options = EnumerateStereoisomers.StereoEnumerationOptions(maxIsomers=max_isomers)
            isomers = EnumerateStereoisomers.EnumerateStereoisomers(mol, options=options)
        variants = [Chem.MolToSmiles(m, canonical=True, doRandom=True) for m in isomers]
        return variants or [smiles]
    except Exception:
        # Only Exception subclasses: a bare except would also swallow KeyboardInterrupt.
        return [smiles]


# Spectrum -> graph featurization with simple binning + chain edges
# -----------------------------------------------------------------
# Converts one (m/z, intensity) peak list into
#   1) a binned intensity vector of length n_bins (max-normalised, then noised), and
#   2) a PyTorch Geometric Data graph whose nodes are the bins, each linked to its neighbours.
#
# Peak inputs may be sequences (list/array) or comma-separated strings such as
# "91.05,125.02,154.05". Strings are split on "," first, because iterating a raw string
# would walk single characters instead of peaks. Entries that cannot be converted to
# float, m/z outside [0, max_mz), and non-finite intensities are skipped.
#
# Needs the global dict `adduct_to_idx` (adduct string -> index) at call time.
def bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct,
                          n_bins=1000, max_mz=1000, noise_level=0.05, rng=None):
    """Bin a peak list into ``n_bins`` equal-width m/z bins and build a chain graph over them.

    Args:
        mzs, intensities: peak m/z values and intensities (sequences or comma-separated strings).
        ion_mode: polarity value, stored on the graph as a float tensor of shape [1].
        precursor_mz: precursor m/z, stored on the graph as a float tensor of shape [1].
        adduct: adduct string, mapped through ``adduct_to_idx`` (unseen adducts -> 0).
        n_bins: number of bins spanning [0, max_mz); must be >= 1.
        max_mz: upper m/z bound; peaks with m/z outside [0, max_mz) are ignored.
        noise_level: std-dev of the Gaussian noise added after normalisation (0 disables it).
            The noise array itself is clipped to [0, 1] before it is added, so the
            perturbation is never negative; the final spectrum is not clipped.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` for the noise;
            defaults to NumPy's global RNG.

    Returns:
        (spectrum, data): ``spectrum`` is the float64 NumPy vector of length ``n_bins``;
        ``data`` is a ``Data`` with x [n_bins, 1] float, edge_index [2, 2*(n_bins-1)] long
        (edges (0,1),(1,0),(1,2),(2,1),...), ion_mode [1], precursor_mz [1], adduct_idx int.

    Raises:
        ValueError: if ``n_bins < 1``.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    # Comma-separated strings -> list of fields; None -> no peaks.
    if mzs is None:
        mzs = ()
    elif isinstance(mzs, str):
        mzs = mzs.split(",")
    if intensities is None:
        intensities = ()
    elif isinstance(intensities, str):
        intensities = intensities.split(",")

    # 1-2) Accumulate peak intensities into equal-width m/z bins.
    spectrum = np.zeros(n_bins)
    for mz, intensity in zip(mzs, intensities):
        try:
            mz = float(mz)
            intensity = float(intensity)
        except (ValueError, TypeError):
            continue  # malformed entry
        # Negative m/z would otherwise index from the end of the array; NaN/inf
        # intensities would turn the whole normalised spectrum into NaN.
        if not (0 <= mz < max_mz) or not np.isfinite(intensity):
            continue
        # min() guards against float rounding mapping m/z just below max_mz to bin n_bins.
        bin_idx = min(int((mz / max_mz) * n_bins), n_bins - 1)
        spectrum[bin_idx] += intensity

    # 3) Normalise to a maximum of 1 when there is positive signal.
    peak = spectrum.max()
    if peak > 0:
        spectrum = spectrum / peak

    # 4) Add non-negative noise (see noise_level in the docstring).
    if noise_level:
        sampler = np.random if rng is None else rng
        spectrum += sampler.normal(0, noise_level, spectrum.shape).clip(0, 1)

    # 5) Node features: one scalar per bin -> [n_bins, 1].
    x = torch.tensor(spectrum, dtype=torch.float).unsqueeze(-1)

    # 6) Bidirectional chain edges interleaved as (i, i+1), (i+1, i) for i = 0..n_bins-2.
    #    Tensor ops replace the per-call Python loop and also give shape [2, 0] for n_bins == 1.
    left = torch.arange(n_bins - 1, dtype=torch.long)
    right = left + 1
    edge_index = torch.stack([
        torch.stack([left, right], dim=1).reshape(-1),
        torch.stack([right, left], dim=1).reshape(-1),
    ])

    # 7) Auxiliary scalar features + adduct category index.
    ion_mode = torch.tensor([ion_mode], dtype=torch.float)
    precursor_mz = torch.tensor([precursor_mz], dtype=torch.float)
    adduct_idx = adduct_to_idx.get(adduct, 0)  # default to 0 if unseen adduct

    return spectrum, Data(
        x=x,
        edge_index=edge_index,
        ion_mode=ion_mode,
        precursor_mz=precursor_mz,
        adduct_idx=adduct_idx
    )


In [None]:


# Canonicalize SMILES in both splits, drop rows RDKit cannot parse, then augment the
# training split: each molecule is expanded to one row per SMILES returned by
# augment_smiles (the external split is not augmented). Parallelised by Pawan.
pandarallel.initialize(nb_workers=16, progress_bar=True)
start_time = time.time()

# .assign builds new frames instead of assigning columns into slices of `df`,
# which pandas flags with SettingWithCopyWarning.
df_massspecgym = df_massspecgym.assign(
    smiles=df_massspecgym['smiles'].parallel_apply(canonicalize_smiles)
).dropna(subset=['smiles'])
df_external = df_external.assign(
    smiles=df_external['smiles'].parallel_apply(canonicalize_smiles)
).dropna(subset=['smiles'])

# Replace each canonical SMILES by its list of variants and give every variant its own
# row. explode() repeats the original index label for every variant.
df_massspecgym = (
    df_massspecgym
    .assign(smiles=df_massspecgym['smiles'].parallel_apply(augment_smiles))
    .explode('smiles')
    .dropna(subset=['smiles'])
)

df_massspecgym.to_parquet("df_massspecgym.parquet")
df_external.to_parquet("df_external.parquet")
print("Completed in {:.2f} seconds".format(time.time() - start_time))


In [None]:
# Diagnostics (added by Pawan): shape and number of distinct index labels per split.
# After explode() the training index can repeat, so nunique() may be below the row count.
for frame in (df_massspecgym, df_external):
    print(frame.shape)
    print(frame.index.nunique())

In [None]:
# Give the augmented training frame a fresh 0..n-1 index (added by Pawan).
df_massspecgym = df_massspecgym.reset_index(drop=True)

In [None]:
# Diagnostics (added by Pawan): re-check shapes and distinct index labels after the reset.
for frame in [df_massspecgym, df_external]:
    rows_and_cols, distinct_labels = frame.shape, frame.index.nunique()
    print(rows_and_cols)
    print(distinct_labels)

In [None]:
# Preview the first five rows of the training frame.
df_massspecgym.iloc[:5]

In [None]:
# Preprocess ion mode, precursor m/z, and adducts -- setup and load
import os
import time
import pandas as pd
from pandarallel import pandarallel
import pyarrow as pa

pandarallel.initialize(nb_workers=16, progress_bar=True)

# Reload both splits from disk. These files were written by the canonicalization cell
# before the in-memory reset_index, so df_massspecgym comes back with the repeated
# index labels produced by explode().
df_massspecgym = pd.read_parquet("df_massspecgym.parquet")
df_external = pd.read_parquet("df_external.parquet")

# Adduct vocabulary: each adduct string gets the position of its first appearance in
# the training split. Adducts that occur only in df_external are absent from the map.
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = dict(zip(adduct_types, range(len(adduct_types))))

# Output folder for the per-chunk Parquet files.
os.makedirs("processed_chunks", exist_ok=True)


In [None]:
# Preprocess ion mode, precursor m/z, and adducts -- per-chunk processing function
def preprocess_chunk(df_chunk, chunk_idx, bin_edges=None):
    """Add model-input columns to one chunk of df_massspecgym and save it as Parquet.

    Adds ``ion_mode``, ``precursor_bin``, ``adduct_idx`` and ``binned`` (the binned,
    normalised, noised spectrum returned by bin_spectrum_to_graph) to ``df_chunk`` in
    place, then writes ``processed_chunks/df_massspecgym_chunk_{chunk_idx:03}.parquet``.

    Args:
        df_chunk: rows to process; modified in place (the chunk loop passes a ``.copy()``).
        chunk_idx: chunk number used in the output file name.
        bin_edges: optional precursor-m/z bin edges shared by all chunks, e.g. the edges
            returned by ``pd.qcut(all_precursor_mz, q=100, retbins=True, duplicates='drop')``.
            With shared edges a given ``precursor_bin`` covers the same m/z range in every
            chunk; values outside the edges become NaN. When omitted, 100 quantile bins are
            fitted to this chunk alone, so the same bin number generally covers a different
            m/z range in different chunks.
    """
    start_time = time.time()

    # 0 if the adduct contains '+', else 1 if it contains '-', else 0.
    # NOTE(review): a negative adduct that also contains '+' (e.g. "[M+Cl]-") gets 0.
    # The df_external cell uses the same rule, so change both together if needed.
    df_chunk['ion_mode'] = df_chunk['adduct'].parallel_apply(
        lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0
    ).fillna(0)

    if bin_edges is None:
        df_chunk['precursor_bin'] = pd.qcut(
            df_chunk['precursor_mz'], q=100, labels=False, duplicates='drop'
        )
    else:
        df_chunk['precursor_bin'] = pd.cut(
            df_chunk['precursor_mz'], bins=bin_edges, labels=False, include_lowest=True
        )

    # Unseen adducts map to NaN here.
    df_chunk['adduct_idx'] = df_chunk['adduct'].map(adduct_to_idx)

    # Only the binned vector is returned from the workers. The Data graph half of the
    # result cannot be written to Parquet, so pickling it back would be wasted work.
    binned = df_chunk.parallel_apply(
        lambda row: pd.Series({'binned': bin_spectrum_to_graph(
            row['mzs'], row['intensities'], row['ion_mode'],
            row['precursor_mz'], row['adduct']
        )[0]}),
        axis=1
    )
    df_chunk['binned'] = binned['binned']

    df_chunk.to_parquet(f"processed_chunks/df_massspecgym_chunk_{chunk_idx:03}.parquet")
    print(f"‚úÖ Saved chunk {chunk_idx} | Rows: {len(df_chunk)} | Time: {time.time() - start_time:.2f} sec")


In [None]:
# Preprocess ion mode, precursor m/z, and adducts -- run preprocess_chunk over
# df_massspecgym in fixed-size slices. Chunks whose output file already exists are
# skipped, so an interrupted run can be resumed.
chunk_size = 100_000
n_chunks = -(-len(df_massspecgym) // chunk_size)  # ceiling division

for i in range(n_chunks):
    output_file = f"processed_chunks/df_massspecgym_chunk_{i:03}.parquet"
    if not os.path.exists(output_file):
        lo = i * chunk_size
        df_chunk = df_massspecgym.iloc[lo:lo + chunk_size].copy()
        preprocess_chunk(df_chunk, i)
    else:
        print(f"‚è© Skipping chunk {i} (already exists)")


In [None]:
# Build the PyG graph for every processed row, chunk by chunk: write each chunk's list
# of graphs to a temp dir on the local SSD, then move it to the external HDD.
#
# Node features are taken from the `binned` column saved by preprocess_chunk instead
# of calling bin_spectrum_to_graph again. That function draws fresh random noise on
# every call, so a second call would give graph features (x) that no longer match the
# stored `binned` vector of the same row.

import pandas as pd
import pickle
import glob
import gc
import os
import shutil
from pandarallel import pandarallel
from tqdm import tqdm
import time

start_time = time.time()
# Graph construction below runs in the main process; this call only sets the worker
# configuration used by later parallel_apply calls.
pandarallel.initialize(nb_workers=8, progress_bar=False)

# External HDD target directory (Seagate)
external_dir = "/media/onepaw/seagate_manual/graph_data_chunks"
os.makedirs(external_dir, exist_ok=True)

# Temporary SSD write directory
temp_dir = "graph_data_tmp"
os.makedirs(temp_dir, exist_ok=True)

# Adduct -> index mapping, rebuilt from the training file in first-appearance order.
df_massspecgym = pd.read_parquet("df_massspecgym.parquet", columns=["adduct"])
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}
del df_massspecgym

# One chain edge_index per spectrum length, shared by every graph of that length
# (pickle stores a shared tensor once per chunk file). Clone it before any in-place edit.
_edge_index_cache = {}


def _chain_edge_index(n_nodes):
    """Return bidirectional chain edges (0,1),(1,0),(1,2),(2,1),... as a [2, 2*(n_nodes-1)] long tensor."""
    if n_nodes not in _edge_index_cache:
        left = torch.arange(max(n_nodes - 1, 0), dtype=torch.long)
        right = left + 1
        _edge_index_cache[n_nodes] = torch.stack([
            torch.stack([left, right], dim=1).reshape(-1),
            torch.stack([right, left], dim=1).reshape(-1),
        ])
    return _edge_index_cache[n_nodes]


def _graph_from_row(row):
    """Build a Data object with the same fields as bin_spectrum_to_graph, from a stored chunk row."""
    binned = np.asarray(row.binned, dtype=np.float32)
    return Data(
        x=torch.tensor(binned, dtype=torch.float).unsqueeze(-1),
        edge_index=_chain_edge_index(len(binned)),
        ion_mode=torch.tensor([row.ion_mode], dtype=torch.float),
        precursor_mz=torch.tensor([row.precursor_mz], dtype=torch.float),
        adduct_idx=adduct_to_idx.get(row.adduct, 0),  # 0 for unseen adducts
    )


# Process each chunk from SSD (only the columns the graph needs are read).
graph_columns = ['binned', 'ion_mode', 'precursor_mz', 'adduct']
chunk_files = sorted(glob.glob("processed_chunks/df_massspecgym_chunk_*.parquet"))

for i, chunk_file in enumerate(tqdm(chunk_files, desc="Processing chunks")):
    df = pd.read_parquet(chunk_file, columns=graph_columns)
    graph_data = [_graph_from_row(row) for row in df.itertuples(index=False)]

    # Save to SSD first (fast write)
    temp_path = os.path.join(temp_dir, f"graph_data_chunk_{i:03}.pkl")
    with open(temp_path, "wb") as f:
        pickle.dump(graph_data, f)

    # Then move to external HDD to free SSD space
    final_path = os.path.join(external_dir, f"graph_data_chunk_{i:03}.pkl")
    shutil.move(temp_path, final_path)

    del df
    del graph_data
    gc.collect()

    print(f"‚úÖ Processed and moved: graph_data_chunk_{i:03}.pkl")

print("üéâ All chunks processed and saved to external drive.")
print("üïí Completed in {:.2f} seconds".format(time.time() - start_time))


In [None]:
# Merge the per-chunk graph pickles into one file on the external HDD, holding only
# one chunk in RAM at a time.
#
# Output format: NOT one pickled list, but one pickle record per chunk (each record is
# that chunk's list of graphs), written back-to-back. Read it with repeated
# pickle.load() calls until EOFError, e.g. via iter_streamed_pickle() below.
import pickle
import glob
import os
from tqdm import tqdm


def iter_streamed_pickle(path):
    """Yield every record of a file written by successive ``pickle.dump`` calls.

    Only use on files produced by this notebook: unpickling untrusted data can
    execute arbitrary code.
    """
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                return


# Directory where pickled graph_data chunks are stored (on external HDD)
external_dir = "/media/onepaw/seagate_manual/graph_data_chunks"
chunk_files = sorted(glob.glob(os.path.join(external_dir, "graph_data_chunk_*.pkl")))
if not chunk_files:
    raise FileNotFoundError(f"No graph_data_chunk_*.pkl files found in {external_dir}")

# Output path on external HDD
merged_path = "/media/onepaw/seagate_manual/df_massspecgym_graph_data_streamed.pkl"

# "wb" starts a fresh file (a previous merge is overwritten); records are then
# appended one per chunk while the file stays open.
with open(merged_path, "wb") as out_f:
    for chunk_file in tqdm(chunk_files, desc="Merging (streamed)"):
        with open(chunk_file, "rb") as in_f:
            data = pickle.load(in_f)
        pickle.dump(data, out_f)  # stream-write this chunk's record
        del data  # release the chunk before loading the next one

print(f"‚úÖ Streamed merge completed to: {merged_path}")


In [None]:
# Preprocess ion mode, precursor m/z, and adducts -- full processing of df_external
import time
import pickle

start_time = time.time()

# Ion mode: 0 if the adduct contains '+', else 1 if it contains '-', else 0 (same rule
# as the training chunks). NOTE(review): a negative adduct that also contains '+'
# (e.g. "[M+Cl]-") gets 0 -- confirm against the actual adduct set.
df_external['ion_mode'] = df_external['adduct'].parallel_apply(
    lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0
).fillna(0)

# Precursor bins: quantile edges are fitted on the *training* precursor m/z and then
# applied here, instead of fitting qcut on the held-out data itself. (Training chunks
# are binned per chunk by preprocess_chunk, so the correspondence is approximate.)
# The outermost edges are widened to +/-inf so out-of-range values land in the end bins.
train_precursor_mz = pd.read_parquet("df_massspecgym.parquet", columns=["precursor_mz"])["precursor_mz"]
_, precursor_bin_edges = pd.qcut(train_precursor_mz, q=100, retbins=True, duplicates='drop')
precursor_bin_edges[0], precursor_bin_edges[-1] = -np.inf, np.inf
df_external['precursor_bin'] = pd.cut(df_external['precursor_mz'], bins=precursor_bin_edges, labels=False)
del train_precursor_mz

# Map adduct to index (adducts unseen in training become NaN)
df_external['adduct_idx'] = df_external['adduct'].map(adduct_to_idx)

# Generate binned and graph_data columns from the same call, so both agree.
# NOTE(review): bin_spectrum_to_graph adds its default noise here too. That noise is
# clipped to be non-negative (an upward shift of every bin), so disabling it only for
# this split would create a train/test offset; revisit together with training preprocessing.
df_external[['binned', 'graph_data']] = df_external.parallel_apply(
    lambda row: pd.Series(bin_spectrum_to_graph(
        row['mzs'], row['intensities'], row['ion_mode'],
        row['precursor_mz'], row['adduct']
    )),
    axis=1
)

# Save graph_data separately (Data objects cannot go into Parquet)
with open("df_external_graph_data.pkl", "wb") as f:
    pickle.dump(df_external['graph_data'].tolist(), f)

# Drop graph_data column before saving to parquet
df_external.drop(columns=['graph_data'], inplace=True)

# Save remaining data to Parquet
df_external.to_parquet("df_external_processed.parquet")

print("‚úÖ df_external processed and saved in {:.2f} seconds".format(time.time() - start_time))


In [None]:
# Preprocess ion mode, precursor m/z, and adducts -- merge all chunk files into one Parquet file.
# NOTE(review): token_ids are written into the chunk files by the pre-tokenization cell
# further below; if this merge runs before it, the merged file has no token_ids column.
import pandas as pd
import glob
import pyarrow as pa
import pyarrow.parquet as pq
import time

output_path = "df_massspecgym_processed_full.parquet"
chunk_files = sorted(glob.glob("processed_chunks/df_massspecgym_chunk_*.parquet"))
if not chunk_files:
    raise FileNotFoundError("No processed chunk files found. Please run the earlier preprocessing steps.")


def _read_chunk(path):
    """Read one chunk with precursor_bin as int64 (NaN -> -1) so every chunk shares one schema."""
    chunk = pd.read_parquet(path)
    chunk['precursor_bin'] = chunk['precursor_bin'].fillna(-1).astype('int64')
    return chunk


first_df = _read_chunk(chunk_files[0])
table = pa.Table.from_pandas(first_df)

# The with-block closes the writer (finalising the Parquet footer) even if a later chunk fails.
with pq.ParquetWriter(output_path, table.schema) as writer:
    writer.write_table(table)
    print(f"‚úÖ Wrote chunk 0 / {len(chunk_files)}")

    for i, file in enumerate(chunk_files[1:], start=1):
        start = time.time()
        # Same column order as the first chunk, so the table matches the writer's schema.
        df = _read_chunk(file)[first_df.columns]
        table = pa.Table.from_pandas(df)
        writer.write_table(table)
        print(f"‚úÖ Wrote chunk {i} / {len(chunk_files)} in {time.time() - start:.2f}s")

print(f"üéâ Done merging {len(chunk_files)} chunks ‚ûú {output_path}")


In [None]:
# SMILES Tokenization with Stereochemistry (Final Corrected Version)
import pyarrow.parquet as pq
import time
import re

# Special tokens
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
MASK_TOKEN = "[MASK]"

# Regex tokenizer for SMILES: bracket atoms ("[C@@H]", "[nH]") stay one token,
# "Br"/"Cl" are tried before "B"/"C", and two-digit ring closures use "%NN".
#
# In a raw string, one literal backslash is the regex r"\\". The previous r"\\\\"
# only matched TWO consecutive backslashes, so the single "\" bond-direction symbol
# (e.g. in "F/C=C\F") was silently skipped by findall and E/Z stereo was lost.
SMILES_TOKENIZER_PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
smiles_regex = re.compile(SMILES_TOKENIZER_PATTERN)


def smiles_tokenizer(smiles):
    """Split a SMILES string into tokens with ``smiles_regex``.

    Characters the pattern does not match are dropped by ``findall``. Callers that need
    to detect this can compare ``"".join(tokens)`` with the input.
    """
    return smiles_regex.findall(smiles)

# Open the merged Parquet file (only the smiles column is read, one row group at a time).
parquet_file = pq.ParquetFile("df_massspecgym_processed_full.parquet")

start = time.time()
# Step 1: in ONE pass, collect the token vocabulary and the supervised max length.
# SUPERVISED_MAX_LEN counts TOKENS (not characters) plus 2 for SOS/EOS; it stays 0 if
# the file holds no non-null SMILES. (Previously the file was read twice.)
all_tokens = set()
SUPERVISED_MAX_LEN = 0
for i in range(parquet_file.num_row_groups):
    table = parquet_file.read_row_group(i, columns=["smiles"])
    df = table.to_pandas()
    for smiles in df['smiles'].dropna().astype(str):
        smiles_tokens = smiles_tokenizer(smiles)
        all_tokens.update(smiles_tokens)
        SUPERVISED_MAX_LEN = max(SUPERVISED_MAX_LEN, len(smiles_tokens) + 2)

# Vocabulary = special tokens first (ids 0-3), then the SMILES tokens in sorted order.
# Special tokens are removed from the SMILES set so no string can receive two ids.
# NOTE(review): built from the training file only; tokens that occur only in df_external
# SMILES are missing, and encode_smiles maps them to the PAD id.
special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN]
tokens = special_tokens + sorted(all_tokens - set(special_tokens))

# Contiguous mapping from 0 to vocab_size-1
token_to_idx = {tok: i for i, tok in enumerate(tokens)}
idx_to_token = {i: tok for tok, i in token_to_idx.items()}
vocab_size = len(tokens)

PRETRAIN_MAX_LEN = 100

print(f"‚úÖ Vocabulary size: {vocab_size}, Supervised MAX_LEN: {SUPERVISED_MAX_LEN}, Pretrain MAX_LEN: {PRETRAIN_MAX_LEN}")
print("Sample of token_to_idx to verify 'Cl' exists:", {k: v for k, v in token_to_idx.items() if 'Cl' in k or 'Br' in k})
print(f"Completed in {time.time() - start:.2f}s")


# Step 2: encoder mapping a SMILES string to a fixed-length list of token ids.
def encode_smiles(smiles, max_len=PRETRAIN_MAX_LEN):
    """Encode ``smiles`` as ids of ``[SOS] + tokens + [EOS]``, truncated/padded to ``max_len``.

    The SMILES tokens are cut to ``max_len - 2`` so that (for ``max_len >= 2``) EOS
    survives truncation. Tokens missing from ``token_to_idx`` map to the PAD id.
    """
    pad_id = token_to_idx[PAD_TOKEN]
    body = smiles_tokenizer(smiles)[:max_len - 2]
    ids = [token_to_idx.get(tok, pad_id) for tok in (SOS_TOKEN, *body, EOS_TOKEN)]
    ids.extend([pad_id] * (max_len - len(ids)))  # no-op when already long enough
    return ids[:max_len]

In [None]:
# FINAL PREPROCESSING STEP: pre-tokenize all SMILES strings into a `token_ids` column
import pandas as pd
import glob
from tqdm import tqdm
import os

# Make sure pandarallel is initialized for this heavy task
pandarallel.initialize(nb_workers=16, progress_bar=True)

print("--- Starting SMILES pre-tokenization ---")
processed_chunk_dir = "processed_chunks"
chunk_files = sorted(glob.glob(os.path.join(processed_chunk_dir, "*.parquet")))

if not chunk_files:
    raise FileNotFoundError("No processed chunk files found. Please run the earlier preprocessing steps.")

# Rows re-encoded per chunk to check that stored token_ids still match the current encoder.
TOKEN_CHECK_ROWS = 1000


def _token_ids_current(chunk_df, n_check=TOKEN_CHECK_ROWS):
    """Return True if the first ``n_check`` stored token_ids equal a fresh encoding.

    token_ids depend on the vocabulary order and on SUPERVISED_MAX_LEN, so a chunk that
    was tokenized before either changed must be re-encoded rather than skipped.
    """
    sample = chunk_df.head(n_check)
    return all(
        stored is not None and list(stored) == encode_smiles(smiles, max_len=SUPERVISED_MAX_LEN)
        for smiles, stored in zip(sample['smiles'], sample['token_ids'])
    )


for chunk_path in tqdm(chunk_files, desc="Pre-tokenizing SMILES chunks"):
    df = pd.read_parquet(chunk_path)

    # Skip only chunks whose existing token_ids are still valid.
    if 'token_ids' in df.columns and _token_ids_current(df):
        print(f"Skipping {os.path.basename(chunk_path)}, already tokenized.")
        continue

    # Apply the encode_smiles function to create (or refresh) the column
    df['token_ids'] = df['smiles'].parallel_apply(lambda s: encode_smiles(s, max_len=SUPERVISED_MAX_LEN))

    # Overwrite the chunk file with the tokenized version
    df.to_parquet(chunk_path)

print("‚úÖ All chunks have been pre-tokenized and saved.")

In [None]:
# Precompute Morgan fingerprints
import pandas as pd
import pyarrow.parquet as pq
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from joblib import Parallel, delayed
import pickle
import time

start = time.time()

# Only the SMILES column is needed from the large training file; read it row group by row group.
massspec_parquet = pq.ParquetFile("df_massspecgym_processed_full.parquet")
df_massspecgym = pd.concat(
    (
        massspec_parquet.read_row_group(rg, columns=["smiles"]).to_pandas()
        for rg in range(massspec_parquet.num_row_groups)
    ),
    ignore_index=True,
)

# The external split is small enough to read whole.
df_external = pd.read_parquet("df_external_processed.parquet")

# Distinct SMILES over both splits, so each fingerprint is computed only once.
all_smiles = list(set(
    df_massspecgym['smiles'].dropna().tolist()
    + df_external['smiles'].dropna().tolist()
))


def fingerprint_one(smiles):
    """Return ``(smiles, Morgan fingerprint)``; the fingerprint is None on failure.

    Radius 2, 2048 bits. The RDKit generator is created inside the function so joblib
    only ships this function to its workers, not an RDKit generator object.
    """
    fingerprint = None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            morgan = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
            fingerprint = morgan.GetFingerprint(mol)
    except Exception as e:
        print(f"Failed: {smiles} ‚Üí {e}")
    return smiles, fingerprint


# 12 parallel joblib jobs.
runner = Parallel(n_jobs=12, verbose=5)
results = runner(map(delayed(fingerprint_one), all_smiles))

# Keep successful fingerprints only: {smiles: fingerprint}.
all_fingerprints = {smi: fp for smi, fp in results if fp is not None}

with open("all_morgan_fingerprints.pkl", "wb") as f:
    pickle.dump(all_fingerprints, f)

print(f"‚úÖ Done in {time.time() - start:.2f}s ‚Äî {len(all_fingerprints)} fingerprints saved to all_morgan_fingerprints.pkl")


In [None]:
# MSMSDataset Class (Final Optimized Version)
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

class MSMSDataset(Dataset):
    """Torch dataset pairing processed spectrum rows with their pre-built PyG graphs.

    Args:
        dataframe: rows with columns ``binned``, ``token_ids``, ``ion_mode``,
            ``precursor_bin``, ``adduct_idx`` and ``smiles`` (``precursor_mz`` is read
            only when a placeholder graph has to be built). Its index is reset.
        graph_data_list: one graph per row, in the same order as ``dataframe``.

    Each item is the tuple
    ``(spectrum, graph, smiles_tensor, ion_mode, precursor_bin, adduct_idx, raw_smiles)``;
    ion_mode, precursor_bin and adduct_idx are 0-d long tensors with NaN replaced by 0.

    Raises:
        ValueError: if ``dataframe`` and ``graph_data_list`` differ in length.
    """

    def __init__(self, dataframe, graph_data_list):
        self.df = dataframe.reset_index(drop=True)
        self.graph_data_list = graph_data_list

        if len(self.df) != len(self.graph_data_list):
            raise ValueError(f"DataFrame length ({len(self.df)}) and graph_data_list length ({len(self.graph_data_list)}) must match.")

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _long_or_zero(value):
        """Return ``int(value)`` as a 0-d long tensor; NaN/None become 0."""
        return torch.tensor(0 if pd.isna(value) else int(value), dtype=torch.long)

    @staticmethod
    def _float_or_zero(value):
        """Return ``float(value)``; NaN/None become 0.0."""
        return 0.0 if pd.isna(value) else float(value)

    @staticmethod
    def _unwrap_graph(graph):
        """Return ``graph`` if it is a Data, else the first Data inside a tuple, else None.

        Covers both ``(Data, ...)`` and the ``(spectrum, Data)`` tuple that
        bin_spectrum_to_graph returns.
        """
        if isinstance(graph, Data):
            return graph
        if isinstance(graph, tuple):
            return next((item for item in graph if isinstance(item, Data)), None)
        return None

    @classmethod
    def _placeholder_graph(cls, row):
        """Stand-in graph for a corrupted entry, so one bad item does not crash training.

        It carries the same attributes as graphs from bin_spectrum_to_graph
        (x, edge_index, ion_mode, precursor_mz, adduct_idx), so its attribute set
        matches real graphs when they are collated into one batch.
        """
        return Data(
            x=torch.zeros((1, 1), dtype=torch.float),
            edge_index=torch.empty((2, 0), dtype=torch.long),
            ion_mode=torch.tensor([cls._float_or_zero(row.get("ion_mode"))], dtype=torch.float),
            precursor_mz=torch.tensor([cls._float_or_zero(row.get("precursor_mz"))], dtype=torch.float),
            adduct_idx=int(cls._long_or_zero(row.get("adduct_idx"))),
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        spectrum = torch.tensor(row["binned"], dtype=torch.float)
        graph = self._unwrap_graph(self.graph_data_list[idx])
        if graph is None:
            graph = self._placeholder_graph(row)

        # Pre-tokenized list of ids (fixed length, see encode_smiles).
        smiles_tensor = torch.tensor(row["token_ids"], dtype=torch.long)

        # All categorical fields get the same NaN-safe conversion.
        ion_mode = self._long_or_zero(row["ion_mode"])
        precursor_bin = self._long_or_zero(row["precursor_bin"])
        adduct_idx = self._long_or_zero(row["adduct_idx"])
        raw_smiles = row["smiles"]

        return (spectrum, graph, smiles_tensor, ion_mode, precursor_bin, adduct_idx, raw_smiles)

In [None]:
# 1. Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape [batch, seq_len, d_model].

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len, :]

In [None]:
# 2. Transformer Encoder
class SpectrumTransformerEncoder(nn.Module):
    """Transformer encoder over a binned spectrum, fused with precursor/adduct metadata.

    The whole binned spectrum is projected to a single d_model token, so the encoder
    runs on a length-1 sequence (self-attention there only sees that one token). The
    encoded token is layer-normalised, concatenated with a 64-d metadata embedding
    (ion mode, precursor bin, 32-d adduct embedding) and projected to ``d_model // 2``.

    Args:
        num_adducts: size of the adduct embedding table.
        input_dim: number of spectrum bins in ``src``.
        d_model, nhead, num_layers, dim_feedforward, dropout: nn.TransformerEncoder settings.
    """

    def __init__(self, num_adducts, input_dim=1000, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        adduct_dim = 32
        metadata_dim = 64
        # Submodules are created in the same order as before, so parameter
        # initialisation under a fixed seed is unchanged.
        self.input_proj = nn.Linear(in_features=input_dim, out_features=d_model)
        self.metadata_emb = nn.Linear(in_features=2 + adduct_dim, out_features=metadata_dim)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(in_features=d_model + metadata_dim, out_features=d_model // 2)
        self.adduct_emb = nn.Embedding(num_embeddings=num_adducts, embedding_dim=adduct_dim)

    def forward(self, src, ion_mode_idx, precursor_idx, adduct_idx):
        """Encode ``src`` (presumably [batch, input_dim]) plus per-sample metadata.

        Returns ``(features, None)``: features of size ``d_model // 2`` per sample; the
        ``None`` stands in for attention weights.
        """
        tokens = self.input_proj(src).unsqueeze(1)  # [batch, 1, d_model]
        meta_features = torch.cat(
            [
                ion_mode_idx.unsqueeze(-1).float(),
                precursor_idx.unsqueeze(-1).float(),
                self.adduct_emb(adduct_idx),
            ],
            dim=-1,
        )
        meta = self.metadata_emb(meta_features)
        encoded = self.transformer_encoder(self.pos_encoder(tokens)).squeeze(1)
        fused = torch.cat([self.norm(encoded), meta], dim=-1)
        return self.fc(fused), None


In [None]:
# 3. GNN Encoder (Corrected for Missing Attribute)
class SpectrumGNNEncoder(MessagePassing):
    """GRU-updated message passing over the spectrum bin graph, mean-pooled per graph.

    In each of ``num_layers`` rounds, neighbour states are transformed by that round's
    Linear and mean-aggregated. The per-graph metadata embedding is added, and that
    round's GRUCell updates every node's hidden state. Node states are then mean-pooled
    per graph, layer-normalised and projected to ``d_model // 2``.

    Args:
        num_adducts: size of the adduct embedding table.
        d_model: the output size is ``d_model // 2``.
        hidden_dim: node hidden-state size.
        num_layers: number of message-passing rounds.
        dropout: dropout on node states after each round.
    """

    def __init__(self, num_adducts, d_model=768, hidden_dim=256, num_layers=3, dropout=0.2):
        super().__init__(aggr='mean')
        # Stored because forward() loops over it.
        self.num_layers = num_layers

        self.input_proj = nn.Linear(1, hidden_dim)
        self.message_nets = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.update_nets = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_layers)])
        # Input: ion mode (1) + precursor bin (1) + adduct embedding (32).
        self.metadata_emb = nn.Linear(2 + 32, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, d_model // 2)
        self.dropout = nn.Dropout(dropout)
        self.adduct_emb = nn.Embedding(num_adducts, 32)

    def forward(self, batch_data, ion_mode_idx, precursor_idx, adduct_idx):
        """Encode a (batched) graph into one ``d_model // 2`` vector per graph.

        ``ion_mode_idx``, ``precursor_idx`` and ``adduct_idx`` hold one value per graph;
        they are gathered to nodes through the node->graph ``batch`` vector.
        """
        x, edge_index = batch_data.x, batch_data.edge_index
        batch = getattr(batch_data, 'batch', None)
        if batch is None:
            # A single un-batched Data has no batch vector: treat all nodes as graph 0.
            # Indexing the metadata with None would insert a new axis instead of gathering rows.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        adduct_embed = self.adduct_emb(adduct_idx)
        metadata_per_graph = self.metadata_emb(torch.cat([
            ion_mode_idx.unsqueeze(-1).float(), precursor_idx.unsqueeze(-1).float(), adduct_embed
        ], dim=-1))
        metadata = metadata_per_graph[batch]  # per-graph metadata copied to each of its nodes
        x = self.input_proj(x)
        h = F.relu(x)

        for i in range(self.num_layers):
            # message() reads this attribute to pick the current round's Linear.
            self._propagate_layer = i
            m = self.propagate(edge_index, x=h)
            m = m + metadata
            h = self.update_nets[i](m, h)
            h = self.dropout(h)

        pooled_x = global_mean_pool(h, batch)
        pooled_x = self.norm(pooled_x)
        return self.output_layer(pooled_x)

    def message(self, x_j):
        """Transform neighbour states with the Linear of the round currently propagating."""
        layer_idx = getattr(self, '_propagate_layer', 0)
        return self.message_nets[layer_idx](x_j)

In [None]:
# 4. Transformer Decoder
class SmilesTransformerDecoder(nn.Module):
    """Autoregressive SMILES decoder with a heuristic valence penalty.

    Token embeddings (scaled by sqrt(d_model)) plus positional encoding go through an
    ``nn.TransformerDecoder`` (batch_first) attending over ``memory``; a final LayerNorm
    and Linear give per-position vocabulary logits.
    """
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

        # Per-token lookup tables for the valence heuristic, built from the global
        # `token_to_idx` vocabulary: valence_map holds an element's nominal valence,
        # bond_map a bond "cost" (2 for '=', 3 for '#', 1 for other bond/syntax tokens).
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'P': 3, 'F': 1, 'Cl': 1, 'Br': 1, 'I': 1, 'H': 1}
        valence_map = torch.zeros(vocab_size)
        bond_map = torch.zeros(vocab_size)
        for token, idx in token_to_idx.items():
            if token in valence_rules: valence_map[idx] = valence_rules[token]
            elif token == '=': bond_map[idx] = 2
            elif token == '#': bond_map[idx] = 3
            elif token in ['-', '(', ')', '[', ']', '.', ':', '@', '/', '\\']: bond_map[idx] = 1
        self.register_buffer('valence_map', valence_map)
        self.register_buffer('bond_map', bond_map)

    def compute_valence(self, smiles_token_ids):
        """Per-sequence excess of bond tokens over available atom valence.

        Args:
            smiles_token_ids: Integer token ids, summed over dim 1 (sequence axis).

        Returns:
            ``relu(sum(bond costs) - sum(atom valences))`` per sequence. This is a coarse
            token-count heuristic, not real valence checking. It is a lookup on integer
            ids, so it carries no gradient w.r.t. the model parameters.
        """
        atom_valences = self.valence_map[smiles_token_ids]
        bond_connections = self.bond_map[smiles_token_ids]
        total_valence_provided = torch.sum(atom_valences, dim=1)
        total_bonds_formed = torch.sum(bond_connections, dim=1)
        return torch.relu(total_bonds_formed - total_valence_provided)

    def forward(self, tgt, memory, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
        """Decode ``tgt`` token ids against encoder ``memory``.

        Args:
            tgt: Target token ids (batch, seq).
            memory: Encoder memory (batch, mem_len, d_model).
            tgt_mask: Optional causal mask over target positions.
            memory_key_padding_mask: Optional padding mask over memory positions.
            tgt_key_padding_mask: Optional padding mask over target positions.

        Returns:
            (logits, valence_penalty): logits of shape (batch, seq, vocab) and the
            per-sequence penalty from ``compute_valence(tgt)``.
        """
        embedded = self.embedding(tgt) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        # Keyword arguments are required: the 4th positional parameter of
        # nn.TransformerDecoder.forward is memory_mask, not memory_key_padding_mask.
        output = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        output = self.norm(output)
        logits = self.output_layer(output)
        valence_penalty = self.compute_valence(tgt)
        return logits, valence_penalty

In [None]:

# 5. Full Hybrid Model
class MSMS2SmilesHybrid(nn.Module):
    """Spectrum -> SMILES model fusing a transformer and a GNN spectrum encoder.

    Both encoders receive the same metadata (ion mode, precursor bin, adduct). Their
    pooled outputs are concatenated, mixed by ``combine_layer`` and used as a
    single-position memory for the SMILES transformer decoder.
    """

    def __init__(self, vocab_size, num_adducts, d_model=768, nhead=12, num_layers=8, dropout=0.2):
        super().__init__()
        # Submodules are created in a fixed order: it determines seeded initialisation
        # and the state_dict layout of saved checkpoints.
        self.transformer_encoder = SpectrumTransformerEncoder(
            num_adducts=num_adducts,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.gnn_encoder = SpectrumGNNEncoder(
            num_adducts=num_adducts,
            d_model=d_model,
            num_layers=3,
            dropout=dropout,
        )
        self.decoder = SmilesTransformerDecoder(vocab_size, d_model, nhead, num_layers, dropout=dropout)
        self.combine_layer = nn.Linear(d_model, d_model)

    def generate_square_subsequent_mask(self, tgt_len):
        """Float causal mask: 0 on/below the diagonal, -inf above, on the model's device."""
        blocked = torch.ones(tgt_len, tgt_len).triu(diagonal=1).bool()
        causal = torch.zeros(tgt_len, tgt_len).masked_fill(blocked, float('-inf'))
        return causal.to(next(self.parameters()).device)

    def forward(self, spectrum, graph_data, tgt, ion_mode_idx, precursor_idx, adduct_idx, tgt_mask=None):
        """Return ``(smiles_logits, valence_penalty)`` for teacher-forced targets ``tgt``."""
        metadata = (ion_mode_idx, precursor_idx, adduct_idx)
        spectrum_features, _ = self.transformer_encoder(spectrum, *metadata)
        graph_features = self.gnn_encoder(graph_data, *metadata)
        fused = torch.cat([spectrum_features, graph_features], dim=-1)
        memory = self.combine_layer(fused).unsqueeze(1)
        logits, penalty = self.decoder(tgt, memory, tgt_mask)
        return logits, penalty

In [None]:
# SSL Pretraining
def ssl_pretrain(model, dataloader, epochs=3, lr=1e-4):
    """Self-supervised pretraining: reconstruct the original SMILES from a masked copy.

    Each batch is ``(spectra, graph_data, smiles_tokens, masked_tokens, ion_modes,
    precursor_bins, adduct_indices, <unused>)``. The decoder is teacher-forced on the
    masked tokens and trained to predict the unmasked tokens shifted by one position,
    with ``0.1 * mean(valence_penalty)`` added to the cross-entropy. A checkpoint is
    written after every epoch to ``ssl_checkpoint_epoch_{n}.pt``.

    Args:
        model: MSMS2SmilesHybrid-style model returning ``(logits, valence_penalty)``.
        dataloader: Iterable of batches as described above.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        list[float]: Average loss per epoch.
    """
    model.train()
    scaler = GradScaler()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        for spectra, graph_data, smiles_tokens, masked_tokens, ion_modes, precursor_bins, adduct_indices, _ in tqdm(dataloader, desc=f"SSL Epoch {epoch+1}/{epochs}"):
            spectra = spectra.to(device)
            # The GNN encoder runs on `device`, so the graph batch must be moved as well.
            graph_data = graph_data.to(device)
            ion_modes = ion_modes.to(device)
            precursor_bins = precursor_bins.to(device)
            adduct_indices = adduct_indices.to(device)
            smiles_tokens = smiles_tokens.to(device)
            masked_tokens = masked_tokens.to(device)
            # Input: masked sequence without its last token; target: clean sequence shifted by one.
            tgt_input = masked_tokens[:, :-1]
            tgt_output = smiles_tokens[:, 1:]
            tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            optimizer.zero_grad()
            with autocast():
                # MSMS2SmilesHybrid.forward returns exactly (logits, valence_penalty).
                smiles_output, valence_penalty = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                loss = criterion(smiles_output.reshape(-1, vocab_size), tgt_output.reshape(-1)) + 0.1 * valence_penalty.mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            num_batches += 1
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / max(num_batches, 1)
        epoch_losses.append(avg_loss)
        print(f"SSL Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, f'ssl_checkpoint_epoch_{epoch+1}.pt')
        print(f"Saved SSL checkpoint: ssl_checkpoint_epoch_{epoch+1}.pt")
    return epoch_losses


In [None]:
# Supervised Training with RL
def _supervised_multitask_loss(model, outputs, tgt_output, raw_smiles,
                               smiles_criterion, fp_criterion, mw_criterion, sub_criterion):
    """Uncertainty-weighted multi-task loss shared by the training and validation loops.

    Args:
        model: Must expose ``log_sigma_smiles``, ``log_sigma_fp``, ``log_sigma_sub`` and
            ``gnn_encoder.substructures`` (a list of SMARTS strings).
        outputs: Model output 6-tuple ``(smiles_logits, fp_logits, valence_penalty, _, _,
            substructure_logits)``.
        tgt_output: Teacher-forcing target token ids.
        raw_smiles: Ground-truth SMILES strings for the batch.
        smiles_criterion, fp_criterion, mw_criterion, sub_criterion: Per-task loss modules.

    Returns:
        Scalar tensor: for each of SMILES / fingerprint / substructure,
        ``w * L / (2 * sigma**2) + log_sigma`` (w = 1, 0.1, 0.1; sigma = exp(log_sigma)
        clamped to [0.1, 10]), plus ``0.1 * mean(valence_penalty) + 0.1 * mw_loss``.

    NOTE(review): MSMS2SmilesHybrid as defined in this notebook returns only
    ``(logits, valence_penalty)`` and has none of the attributes above, so this path
    targets a different model variant -- confirm before use.
    """
    smiles_output, fp_output, valence_penalty, _, _, substructure_pred = outputs
    smiles_loss = smiles_criterion(smiles_output.reshape(-1, vocab_size), tgt_output.reshape(-1))

    # Compile each SMARTS once per batch rather than once per molecule.
    patterns = [Chem.MolFromSmarts(smarts) for smarts in model.gnn_encoder.substructures]
    # Shape follows the prediction so the BCE inputs always agree in size.
    substructure_targets = torch.zeros(substructure_pred.shape, dtype=torch.float, device=device)
    fp_loss = 0
    mw_loss = 0
    valid_count = 0
    for i, (smiles, fp) in enumerate(zip(raw_smiles, fp_output)):
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if not mol:
            continue
        true_fp = morgan_gen.GetFingerprint(mol)
        fp_target = torch.tensor([int(b) for b in true_fp.ToBitString()], dtype=torch.float, device=device)
        fp_loss += fp_criterion(fp, fp_target)
        # Both operands are constants (true MolWt vs a fixed 500 Da): this term changes the
        # reported loss value but contributes no gradient.
        mw_loss += mw_criterion(torch.tensor(Descriptors.MolWt(mol), dtype=torch.float, device=device),
                                torch.tensor(500.0, dtype=torch.float, device=device))
        for j, pattern in enumerate(patterns):
            # Unparsable SMARTS (None) count as non-matches instead of raising.
            if pattern is not None and mol.HasSubstructMatch(pattern):
                substructure_targets[i, j] = 1
        valid_count += 1
    if valid_count > 0:
        fp_loss = fp_loss / valid_count
        mw_loss = mw_loss / valid_count
    else:
        fp_loss = torch.tensor(0.0, device=device)
        mw_loss = torch.tensor(0.0, device=device)
    sub_loss = sub_criterion(substructure_pred, substructure_targets)

    sigma_smiles = torch.clamp(torch.exp(model.log_sigma_smiles), 0.1, 10.0)
    sigma_fp = torch.clamp(torch.exp(model.log_sigma_fp), 0.1, 10.0)
    sigma_sub = torch.clamp(torch.exp(model.log_sigma_sub), 0.1, 10.0)
    return (smiles_loss / (2 * sigma_smiles**2) + model.log_sigma_smiles) + \
           (0.1 * fp_loss / (2 * sigma_fp**2) + model.log_sigma_fp) + \
           (0.1 * sub_loss / (2 * sigma_sub**2) + model.log_sigma_sub) + \
           0.1 * valence_penalty.mean() + 0.1 * mw_loss


def supervised_train(model, train_loader, val_loader, epochs=30, lr=1e-4, patience=5):
    """Supervised multi-task training with an optional Tanimoto "RL" term and early stopping.

    Uses Adam with AMP, gradient clipping (max norm 1.0) and ReduceLROnPlateau on the
    validation loss. From epoch 6 on, one beam-searched prediction per batch is scored by
    Tanimoto similarity. Saves ``checkpoint_epoch_{n}.pt`` every 10 epochs and
    ``best_msms_hybrid.pt`` whenever validation loss improves. Stops after ``patience``
    epochs without improvement.

    Returns:
        float: Best (lowest) average validation loss observed (``inf`` if none).
    """
    model.train()
    scaler = GradScaler()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    smiles_criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    fp_criterion = nn.BCEWithLogitsLoss()
    mw_criterion = nn.MSELoss()
    sub_criterion = nn.BCEWithLogitsLoss()
    criteria = (smiles_criterion, fp_criterion, mw_criterion, sub_criterion)
    best_val_loss = float('inf')
    no_improve = 0

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        num_train_batches = 0
        for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, raw_smiles in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            spectra = spectra.to(device)
            graph_data = graph_data.to(device)
            ion_modes = ion_modes.to(device)
            precursor_bins = precursor_bins.to(device)
            adduct_indices = adduct_indices.to(device)
            smiles_tokens = smiles_tokens.to(device)
            tgt_input = smiles_tokens[:, :-1]
            tgt_output = smiles_tokens[:, 1:]
            tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            optimizer.zero_grad()
            with autocast():
                outputs = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                supervised_loss = _supervised_multitask_loss(model, outputs, tgt_output, raw_smiles, *criteria)
                # RL component: Tanimoto reward for the first sample of the batch.
                # NOTE: the reward is built from a Python float, so rl_loss shifts the
                # reported loss but carries no gradient back to the model.
                rl_loss = 0
                if epoch >= 5:  # Start RL after initial training
                    pred_smiles = beam_search(model, spectra[0], graph_data[0], ion_modes[0], precursor_bins[0], adduct_indices[0], raw_smiles[0], beam_width=5, max_len=SUPERVISED_MAX_LEN, device=device)
                    # beam_search may leave the model in eval mode; restore training mode so
                    # dropout stays active for the remainder of the epoch.
                    model.train()
                    if pred_smiles[0][0] != "Invalid SMILES":
                        tanimoto = tanimoto_similarity(pred_smiles[0][0], raw_smiles[0], all_fingerprints)
                        rl_loss = -torch.log(torch.tensor(tanimoto + 1e-6, device=device))
                loss = supervised_loss + 0.1 * rl_loss
            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients rather than
            # the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_train_loss += loss.item()
            num_train_batches += 1
        avg_train_loss = total_train_loss / max(num_train_batches, 1)

        # Validation
        model.eval()
        total_val_loss = 0
        num_val_batches = 0
        with torch.no_grad():
            for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, raw_smiles in val_loader:
                spectra = spectra.to(device)
                graph_data = graph_data.to(device)
                ion_modes = ion_modes.to(device)
                precursor_bins = precursor_bins.to(device)
                adduct_indices = adduct_indices.to(device)
                smiles_tokens = smiles_tokens.to(device)
                tgt_input = smiles_tokens[:, :-1]
                tgt_output = smiles_tokens[:, 1:]
                tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
                with autocast():
                    outputs = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                    loss = _supervised_multitask_loss(model, outputs, tgt_output, raw_smiles, *criteria)
                total_val_loss += loss.item()
                num_val_batches += 1
        # An empty validation pass must never register as an improvement.
        avg_val_loss = total_val_loss / num_val_batches if num_val_batches else float('inf')

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
        scheduler.step(avg_val_loss)

        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': avg_val_loss,
                'token_to_idx': token_to_idx,
                'idx_to_token': idx_to_token
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'token_to_idx': token_to_idx,
                'idx_to_token': idx_to_token
            }, 'best_msms_hybrid.pt')
        else:
            no_improve += 1
        if no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return best_val_loss

In [None]:
# SMILES Syntax Validator
# Non-atom characters accepted outside square brackets: branches '(' ')', bond symbols
# ('-', '=', '#', ':', '/', '\'), chirality '@', disconnection '.', and ring-closure
# digits including two-digit '%nn' labels.
_SMILES_NON_ATOM_CHARS = "()=#/\\@.:-%0123456789"


def is_valid_smiles_syntax(smiles, atoms=None):
    """Check that ``smiles`` is well-formed and parsable by RDKit.

    A cheap lexical pre-filter runs first so obviously malformed strings are rejected
    without calling RDKit:
      1. ``()`` and ``[]`` must be balanced and properly nested;
      2. every bracket atom must contain an allowed atom symbol, and every character
         outside brackets must be an allowed atom symbol or a SMILES bond / branch /
         ring-closure character.
    The final verdict comes from ``Chem.MolFromSmiles``.

    Args:
        smiles: SMILES string to check.
        atoms: Allowed atom symbols; defaults to the notebook-level ``valid_atoms``.
            Multi-character symbols (e.g. ``'Cl'``, ``'Br'``) are matched greedily.

    Returns:
        bool: True only if the pre-filter passes and RDKit parses ``smiles``.
    """
    if atoms is None:
        atoms = valid_atoms

    # 1) Balanced, properly nested brackets.
    stack = []
    for c in smiles:
        if c in '([':
            stack.append(c)
        elif c == ')':
            if not stack or stack[-1] != '(':
                return False
            stack.pop()
        elif c == ']':
            if not stack or stack[-1] != '[':
                return False
            stack.pop()
    if stack:
        return False

    # 2) Lexical scan. Longest multi-character symbols are tried first so that 'Cl'
    #    is not split into 'C' plus an unknown 'l'.
    multi_char_atoms = sorted((a for a in atoms if isinstance(a, str) and len(a) > 1), key=len, reverse=True)
    i = 0
    while i < len(smiles):
        if smiles[i] == '[':
            j = smiles.find(']', i)
            if j == -1:
                return False
            bracket_atom = smiles[i + 1:j]
            if not any(a in bracket_atom for a in atoms):
                return False
            i = j + 1
            continue
        symbol = next((a for a in multi_char_atoms if smiles.startswith(a, i)), None)
        if symbol is not None:
            i += len(symbol)
        elif smiles[i] in atoms or smiles[i] in _SMILES_NON_ATOM_CHARS:
            i += 1
        else:
            return False

    # 3) RDKit has the final say (parse + sanitisation).
    try:
        return Chem.MolFromSmiles(smiles, sanitize=True) is not None
    except Exception:
        return False

In [None]:
# RDKit-based Molecular Property Filter
def is_plausible_molecule(smiles, true_mol, max_mw=1500, min_logp=-7, max_logp=7,
                          max_mw_delta=300, fallback_reference_mw=500):
    """Heuristic plausibility filter for a generated molecule.

    Accepts ``smiles`` only if all of the following hold:
      * RDKit parses it and it passes ``is_valid_smiles_syntax``;
      * its molecular weight is at most ``max_mw``;
      * its RDKit ``MolLogP`` lies in ``[min_logp, max_logp]``;
      * its weight is strictly within ``max_mw_delta`` Da of a reference weight. The
        reference is ``true_mol``'s weight, or ``fallback_reference_mw`` when
        ``true_mol`` is falsy.

    Returns:
        bool
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    if not mol or not is_valid_smiles_syntax(smiles):
        return False
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    reference_mw = Descriptors.MolWt(true_mol) if true_mol else fallback_reference_mw
    return mw <= max_mw and min_logp <= logp <= max_logp and abs(mw - reference_mw) < max_mw_delta

# Evaluation Metrics
def dice_similarity(smiles1, smiles2):
    """Dice similarity of the two molecules' Morgan fingerprints (``morgan_gen``).

    Returns 0.0 when either SMILES fails to parse.
    """
    parsed = [Chem.MolFromSmiles(text) for text in (smiles1, smiles2)]
    if not all(parsed):
        return 0.0
    fp_a, fp_b = (morgan_gen.GetFingerprint(m) for m in parsed)
    return DataStructs.DiceSimilarity(fp_a, fp_b)

def mcs_similarity(true_smiles, pred_smiles, timeout=30):
    """Maximum-common-substructure atom count divided by the larger molecule's atom count.

    Args:
        true_smiles, pred_smiles: SMILES to compare (the score is symmetric).
        timeout: Seconds allowed for ``rdFMCS.FindMCS``; on timeout RDKit returns the
            best MCS found so far.

    Returns:
        float in [0, 1]; 0.0 if either SMILES fails to parse or both molecules have no atoms.
    """
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if not (mol1 and mol2):
        return 0.0
    largest = max(mol1.GetNumAtoms(), mol2.GetNumAtoms())
    if largest == 0:
        # e.g. two empty SMILES strings: avoid dividing by zero.
        return 0.0
    mcs = rdFMCS.FindMCS([mol1, mol2], timeout=timeout)
    return mcs.numAtoms / largest

def mw_difference(true_smiles, pred_smiles):
    """Absolute molecular-weight difference in Da; ``inf`` if either SMILES fails to parse."""
    true_mol = Chem.MolFromSmiles(true_smiles)
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    if not (true_mol and pred_mol):
        return float('inf')
    return abs(Descriptors.MolWt(true_mol) - Descriptors.MolWt(pred_mol))

def logp_difference(true_smiles, pred_smiles):
    """Absolute difference of RDKit ``MolLogP``; ``inf`` if either SMILES fails to parse."""
    true_mol = Chem.MolFromSmiles(true_smiles)
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    if not (true_mol and pred_mol):
        return float('inf')
    return abs(Descriptors.MolLogP(true_mol) - Descriptors.MolLogP(pred_mol))

def substructure_match(true_smiles, pred_smiles, substructures):
    """Fraction of SMARTS patterns in ``substructures`` matched by BOTH molecules.

    Returns:
        float in [0, 1]. 0.0 when ``substructures`` is empty or either SMILES fails to
        parse. Patterns RDKit cannot parse count as non-matches; they still count in
        the denominator.
    """
    if len(substructures) == 0:
        # Nothing to match against; avoids dividing by zero below.
        return 0.0
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if not mol1 or not mol2:
        return 0.0
    matches = 0
    for smarts in substructures:
        pattern = Chem.MolFromSmarts(smarts)
        if pattern is None:
            # HasSubstructMatch(None) would raise; treat an invalid pattern as no match.
            continue
        if mol1.HasSubstructMatch(pattern) and mol2.HasSubstructMatch(pattern):
            matches += 1
    return matches / len(substructures)

def validity_rate(pred_smiles_list):
    """Percentage (0-100) of predictions RDKit can parse and sanitise; 0.0 for an empty list."""
    if len(pred_smiles_list) == 0:
        return 0.0
    valid = sum(Chem.MolFromSmiles(smiles, sanitize=True) is not None for smiles in pred_smiles_list)
    return valid / len(pred_smiles_list) * 100

def tanimoto_similarity(smiles1, smiles2, precomputed_fps=None):
    """Tanimoto similarity between the Morgan fingerprints of two SMILES.

    ``smiles2``'s fingerprint is taken from ``precomputed_fps`` (a mapping keyed by
    SMILES) when present; otherwise it is computed. Returns 0.0 if a molecule that has
    to be parsed fails to parse.
    """
    query_mol = Chem.MolFromSmiles(smiles1, sanitize=True)
    if not query_mol:
        return 0.0
    query_fp = morgan_gen.GetFingerprint(query_mol)

    cached = precomputed_fps and smiles2 in precomputed_fps
    if cached:
        reference_fp = precomputed_fps[smiles2]
    else:
        reference_mol = Chem.MolFromSmiles(smiles2, sanitize=True)
        if not reference_mol:
            return 0.0
        reference_fp = morgan_gen.GetFingerprint(reference_mol)
    return DataStructs.TanimotoSimilarity(query_fp, reference_fp)

def prediction_diversity(pred_smiles_list):
    """One minus the mean pairwise Morgan-fingerprint Tanimoto similarity of the predictions.

    Each SMILES is parsed and fingerprinted once; before this change every SMILES was
    re-parsed for every pair. An unparsable SMILES contributes similarity 0 to each pair
    it is in.

    Returns:
        float; 0.0 for fewer than two predictions.
    """
    if len(pred_smiles_list) < 2:
        return 0.0
    fingerprints = []
    for smiles in pred_smiles_list:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        fingerprints.append(morgan_gen.GetFingerprint(mol) if mol else None)
    total_tanimoto = 0
    count = 0
    for i, fp_i in enumerate(fingerprints):
        for fp_j in fingerprints[i + 1:]:
            if fp_i is not None and fp_j is not None:
                total_tanimoto += DataStructs.TanimotoSimilarity(fp_i, fp_j)
            count += 1
    return 1 - (total_tanimoto / count)

In [None]:
# Beam Search with Stereochemistry
def _smiles_prefix_ok(prefix):
    """Return False only if no suffix could turn ``prefix`` into valid SMILES.

    Rejects a ``)`` with no open ``(``, a ``]`` with no open ``[``, and a nested ``[``.
    Unclosed branches, brackets and ring bonds are accepted because later tokens may
    still close them. A full RDKit parse would reject every such prefix.
    """
    branch_depth = 0
    in_bracket = False
    for ch in prefix:
        if ch == '[':
            if in_bracket:
                return False
            in_bracket = True
        elif ch == ']':
            if not in_bracket:
                return False
            in_bracket = False
        elif not in_bracket and ch == '(':
            branch_depth += 1
        elif not in_bracket and ch == ')':
            branch_depth -= 1
            if branch_depth < 0:
                return False
    return True


def beam_search(model, spectrum, graph_data, ion_mode_idx, precursor_idx, adduct_idx, true_smiles, beam_width=10, max_len=150, nucleus_p=0.9, device='cpu'):
    """Decode candidate SMILES for one spectrum with nucleus-sampled beam search.

    At each step, every live beam proposes up to ``beam_width`` distinct next tokens.
    They are sampled without replacement from the smallest top-probability set reaching
    ``nucleus_p``, after these adjustments:
      * PAD / SOS / MASK are banned;
      * '@' and '/' get a +0.5 log-prob boost;
      * ``0.1 * valence_penalty`` is subtracted.
    Prefixes that can never become valid SMILES are pruned, and EOS is only accepted
    after a non-empty string RDKit can parse. A repetition-based diversity penalty is
    applied. The model's train/eval mode is restored on exit.

    Returns:
        list[tuple[str, float]]: ``(canonical SMILES, exp(score / len(seq)))`` for each
        final beam that parses and passes ``is_plausible_molecule``, or
        ``[("Invalid SMILES", 0.0)]`` if none do.
    """
    sos_idx = token_to_idx[SOS_TOKEN]
    eos_idx = token_to_idx[EOS_TOKEN]
    # Special tokens other than EOS must never be emitted mid-sequence.
    banned_ids = [token_to_idx[t] for t in (PAD_TOKEN, SOS_TOKEN, MASK_TOKEN) if t in token_to_idx]
    stereo_ids = [token_to_idx[t] for t in ('@', '/') if t in token_to_idx]

    def decode(seq):
        """Token ids -> SMILES text, dropping the leading SOS and a trailing EOS."""
        body = seq[1:]
        if body and body[-1] == eos_idx:
            body = body[:-1]
        return ''.join(idx_to_token.get(idx, '') for idx in body)

    was_training = model.training
    model.eval()
    try:
        true_mol = Chem.MolFromSmiles(true_smiles) if true_smiles else None
        with torch.no_grad():
            spectrum = spectrum.unsqueeze(0).to(device)
            graph_batch = Batch.from_data_list([graph_data]).to(device)
            # as_tensor accepts Python ints and 0-d tensors (e.g. ion_modes[0]) alike.
            ion_mode_t = torch.as_tensor(ion_mode_idx, dtype=torch.long, device=device).reshape(1)
            precursor_t = torch.as_tensor(precursor_idx, dtype=torch.long, device=device).reshape(1)
            adduct_t = torch.as_tensor(adduct_idx, dtype=torch.long, device=device).reshape(1)
            spectrum_features = model.transformer_encoder(spectrum, ion_mode_t, precursor_t, adduct_t)[0]
            # SpectrumGNNEncoder.forward returns a single tensor.
            gnn_output = model.gnn_encoder(graph_batch, ion_mode_t, precursor_t, adduct_t)
            memory = model.combine_layer(torch.cat([spectrum_features, gnn_output], dim=-1)).unsqueeze(1)
            sequences = [([sos_idx], 0.0)]

            for _ in range(max_len):
                all_candidates = []
                for seq, score in sequences:
                    if seq[-1] == eos_idx:
                        all_candidates.append((seq, score))
                        continue
                    tgt_input = torch.tensor([seq], dtype=torch.long, device=device)
                    tgt_mask = model.generate_square_subsequent_mask(len(seq)).to(device)
                    outputs, valence_penalty = model.decoder(tgt_input, memory, tgt_mask)
                    log_probs = (F.log_softmax(outputs[0, -1].float(), dim=-1).cpu().numpy().astype(np.float64)
                                 - 0.1 * valence_penalty.float().cpu().numpy())
                    log_probs[banned_ids] = -np.inf
                    # Boost stereochemistry tokens
                    log_probs[stereo_ids] += 0.5

                    # Nucleus (top-p) set on the renormalised distribution: the smallest
                    # group of top tokens whose cumulative probability reaches nucleus_p.
                    probs = np.exp(log_probs - log_probs.max())
                    probs /= probs.sum()
                    ranked = np.argsort(probs)[::-1]
                    nucleus_size = int(np.searchsorted(np.cumsum(probs[ranked]), nucleus_p)) + 1
                    top_tokens = ranked[:nucleus_size]
                    top_probs = probs[top_tokens] / probs[top_tokens].sum()
                    # Without replacement so one token cannot occupy several beam slots.
                    num_draws = min(beam_width, int(np.count_nonzero(top_probs)))
                    for tok in np.random.choice(top_tokens, size=num_draws, replace=False, p=top_probs):
                        tok = int(tok)
                        new_seq = seq + [tok]
                        candidate = decode(new_seq)
                        if tok == eos_idx:
                            # Only terminate on a non-empty string RDKit can parse.
                            if not candidate or Chem.MolFromSmiles(candidate) is None:
                                continue
                        elif not _smiles_prefix_ok(candidate):
                            continue
                        diversity_penalty = 0.2 * sum(1 for s, _ in sequences if tok in s[1:-1])
                        all_candidates.append((new_seq, score + float(log_probs[tok]) - diversity_penalty))
                sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
                if all(seq[-1] == eos_idx for seq, _ in sequences):
                    break

        results = []
        for seq, score in sequences:
            smiles = decode(seq)
            try:
                mol = Chem.MolFromSmiles(smiles, sanitize=True)
                if mol and is_plausible_molecule(smiles, true_mol):
                    # Deterministic canonical output so identical molecules compare equal.
                    smiles = Chem.MolToSmiles(mol, canonical=True)
                    confidence = np.exp(score / len(seq))
                    results.append((smiles, confidence))
            except Exception:
                continue
        return results if results else [("Invalid SMILES", 0.0)]
    finally:
        if was_training:
            model.train()


In [None]:
# Visualization Functions
def plot_attention_weights(attn_weights, title="Transformer Attention Weights"):
    """Render an attention-weight tensor as a heatmap (query tokens on y, key tokens on x).

    The tensor is squeezed and moved to CPU first; presumably the result is 2-D
    (query x key) -- TODO confirm at call sites.
    """
    matrix = attn_weights.squeeze().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(matrix, cmap='viridis')
    fig.colorbar(image, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Key Tokens")
    ax.set_ylabel("Query Tokens")
    plt.show()

def plot_gnn_edge_weights(edge_weights, edge_index, title="GNN Edge Importance"):
    """Histogram (50 bins) of the last entry of ``edge_weights``.

    Presumably ``edge_weights`` is a per-layer sequence, so the last entry is the final
    layer -- TODO confirm. ``edge_index`` is accepted for API compatibility but unused.
    """
    final_scores = edge_weights[-1].cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.hist(final_scores, bins=50)
    ax.set_title(title)
    ax.set_xlabel("Edge Weight Magnitude")
    ax.set_ylabel("Frequency")
    plt.show()

# Error Analysis
def error_analysis(pred_smiles_list, true_smiles_list, adducts, precomputed_fps, tanimoto_threshold=0.3):
    """Print a breakdown of low-similarity predictions by size, aromaticity and adduct.

    A prediction counts as an error when its Tanimoto similarity to the true SMILES is
    below ``tanimoto_threshold``. Only errors whose true SMILES parses are bucketed:
    small (<300 Da) vs large, aromatic vs aliphatic, and per adduct (mean Tanimoto and
    count). Adducts missing from the global ``adduct_types`` are added on demand
    instead of raising KeyError.
    """
    errors = {'small': 0, 'large': 0, 'aromatic': 0, 'aliphatic': 0}
    adduct_errors = {adduct: [] for adduct in adduct_types}
    for pred_smiles, true_smiles, adduct in zip(pred_smiles_list, true_smiles_list, adducts):
        tanimoto = tanimoto_similarity(pred_smiles, true_smiles, precomputed_fps)
        if not tanimoto < tanimoto_threshold:
            continue
        mol = Chem.MolFromSmiles(true_smiles)
        if not mol:
            continue
        mw = Descriptors.MolWt(mol)
        is_aromatic = any(atom.GetIsAromatic() for atom in mol.GetAtoms())
        errors['small' if mw < 300 else 'large'] += 1
        errors['aromatic' if is_aromatic else 'aliphatic'] += 1
        adduct_errors.setdefault(adduct, []).append(tanimoto)
    print("Error Analysis:")
    print(f"Small molecules (<300 Da) errors: {errors['small']}")
    print(f"Large molecules (≥300 Da) errors: {errors['large']}")
    print(f"Aromatic molecule errors: {errors['aromatic']}")
    print(f"Aliphatic molecule errors: {errors['aliphatic']}")
    for adduct, scores in adduct_errors.items():
        if scores:
            print(f"Adduct {adduct} - Avg Tanimoto: {np.mean(scores):.4f}, Count: {len(scores)}")

In [None]:
# Hyperparameter Tuning
def objective(trial, train_data, val_data):
    """Optuna objective: train with a sampled learning rate and return the best validation loss.

    NOTE(review): ``supervised_train`` unpacks a 6-output model and reads attributes
    (``log_sigma_*``, ``gnn_encoder.substructures``) that the MSMS2SmilesHybrid in this
    notebook does not define. ``MSMSDataset`` is built here as
    ``(data, max_len=..., is_ssl=...)`` but as ``(df_chunk, graph_data_chunk)`` in the
    cross-validation cell. Confirm both before running a study.
    """
    # Local import: the cross-validation cell feeds MSMSDataset through the PyG loader,
    # and the module-level `DataLoader` name depends on which cells have been executed.
    from torch_geometric.loader import DataLoader as GraphDataLoader

    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    train_dataset = MSMSDataset(train_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    val_dataset = MSMSDataset(val_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    train_loader = GraphDataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
    val_loader = GraphDataLoader(val_dataset, batch_size=32, num_workers=2)
    # MSMS2SmilesHybrid requires num_adducts and accepts neither dim_feedforward nor fp_size.
    model = MSMS2SmilesHybrid(vocab_size=vocab_size, num_adducts=len(adduct_types), d_model=768, nhead=12, num_layers=8, dropout=0.2).to(device)
    return supervised_train(model, train_loader, val_loader, epochs=10, lr=lr)

In [None]:
# === Optimized Cross-Validation and Training ===
import torch
import pickle
import numpy as np
import pandas as pd
import time
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold
from tqdm import tqdm
import glob
import gc
import os

# --- Setup ---
print("--- Initializing Setup ---")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

df_massspecgym_sample = pd.read_parquet("df_massspecgym.parquet", columns=["adduct"])
adduct_types = df_massspecgym_sample['adduct'].unique()
del df_massspecgym_sample
print(f"Adduct types loaded successfully. Count: {len(adduct_types)}")

processed_chunk_dir = "processed_chunks"
graph_chunk_dir = "/media/onepaw/seagate_manual/graph_data_chunks"
processed_files = np.array(sorted(glob.glob(os.path.join(processed_chunk_dir, "*.parquet"))))
graph_files = np.array(sorted(glob.glob(os.path.join(graph_chunk_dir, "*.pkl"))))
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []
print(f"Found {len(processed_files)} chunk file pairs for 5-fold cross-validation.")

# === Training Parameters ===
NUM_EPOCHS = 10
NUM_WORKERS = 4
PATIENCE = 5
early_break = True  # Set to False to run all folds

for fold_idx, (train_indices, val_indices) in enumerate(kf.split(processed_files)):
    print(f"\n{'='*20} FOLD {fold_idx + 1}/5 {'='*20}")
    train_proc_files, train_graph_files = processed_files[train_indices], graph_files[train_indices]
    val_proc_files, val_graph_files = processed_files[val_indices], graph_files[val_indices]

    print(f"Initializing model for Fold {fold_idx + 1}...")
    model = MSMS2SmilesHybrid(vocab_size=vocab_size, num_adducts=len(adduct_types)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    scaler = torch.cuda.amp.GradScaler()  # ‚úÖ fixed: no device_type here
    smiles_criterion = torch.nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    best_val_loss, epochs_no_improve = float('inf'), 0

    for epoch in range(NUM_EPOCHS):
        print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")
        model.train()
        total_train_loss, train_batches = 0, 0
        start_time = time.time()

        for proc_file, graph_file in tqdm(zip(train_proc_files, train_graph_files), total=len(train_proc_files), desc="Training"):
            try:
                df_chunk = pd.read_parquet(proc_file)
                with open(graph_file, 'rb') as f:
                    graph_data_chunk = pickle.load(f)

                dataset = MSMSDataset(df_chunk, graph_data_chunk)
                loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

                for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, _ in loader:
                    graph_data = graph_data.to(device)
                    spectra, smiles_tokens = spectra.to(device), smiles_tokens.to(device)
                    ion_modes, precursor_bins, adduct_indices = ion_modes.to(device), precursor_bins.to(device), adduct_indices.to(device)
                    tgt_input, tgt_output = smiles_tokens[:, :-1], smiles_tokens[:, 1:]
                    tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1))

                    optimizer.zero_grad(set_to_none=True)
                    with torch.cuda.amp.autocast():  # ‚úÖ fixed
                        out, val_penalty = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                        loss = smiles_criterion(out.reshape(-1, vocab_size), tgt_output.reshape(-1)) + 0.1 * val_penalty.mean()

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    total_train_loss += loss.item()
                    train_batches += 1

                del df_chunk, graph_data_chunk, dataset, loader
                gc.collect()

            except Exception as e:
                print(f"‚ùå Error in {os.path.basename(proc_file)}: {e}")
                continue

        avg_train_loss = total_train_loss / max(train_batches, 1)
        print(f"‚úÖ Epoch {epoch+1} Train Loss: {avg_train_loss:.4f} | Time: {(time.time()-start_time)/60:.2f} mins")

        # --- VALIDATION ---
        model.eval()
        total_val_loss, val_batches = 0, 0

        with torch.no_grad():
            for proc_file, graph_file in tqdm(zip(val_proc_files, val_graph_files), total=len(val_proc_files), desc="Validating"):
                try:
                    df_chunk = pd.read_parquet(proc_file)
                    with open(graph_file, 'rb') as f:
                        graph_data_chunk = pickle.load(f)

                    dataset = MSMSDataset(df_chunk, graph_data_chunk)
                    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

                    for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, _ in loader:
                        graph_data = graph_data.to(device)
                        spectra = spectra.to(device)
                        smiles_tokens = smiles_tokens.to(device)
                        ion_modes = ion_modes.to(device)
                        precursor_bins = precursor_bins.to(device)
                        adduct_indices = adduct_indices.to(device)
                        tgt_input, tgt_output = smiles_tokens[:, :-1], smiles_tokens[:, 1:]
                        tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1))

                        with torch.cuda.amp.autocast():  # ‚úÖ fixed
                            out, val_penalty = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                            loss = smiles_criterion(out.reshape(-1, vocab_size), tgt_output.reshape(-1)) + 0.1 * val_penalty.mean()

                        total_val_loss += loss.item()
                        val_batches += 1

                    del df_chunk, graph_data_chunk, dataset, loader
                    gc.collect()

                except Exception as e:
                    print(f"‚ùå Validation error in {os.path.basename(proc_file)}: {e}")
                    continue

        avg_val_loss = total_val_loss / max(val_batches, 1)
        print(f"üìä Val Loss: {avg_val_loss:.4f}")
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            print(f"üì• Best model updated (Fold {fold_idx+1})")
            torch.save({'model_state_dict': model.state_dict(), 'token_to_idx': token_to_idx, 'idx_to_token': idx_to_token},
                       f"best_model_fold_{fold_idx+1}.pt")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= PATIENCE:
            print("üõë Early stopping")
            break

    fold_results.append(best_val_loss)
    print(f"‚úÖ Fold {fold_idx+1} done | Best Val Loss: {best_val_loss:.4f}")

    if early_break:
        break

print("\nüéâ Training Complete")
print(f"Fold results: {fold_results}")
print(f"Avg best val loss: {np.mean(fold_results):.4f}")


In [None]:


# External Dataset Evaluation
# Decode the first few external samples with beam search, print per-prediction similarity
# metrics for the top-ranked candidates, draw the top prediction next to the ground truth,
# and (first sample only) plot encoder attention / GNN edge weights.
# NOTE(review): assumes external_dataset[i] is indexable as
# (spectrum, graph, <unused here>, ion_mode, precursor_bin, adduct_idx, true_smiles) and that
# df_external row i describes external_dataset[i] — confirm against the Dataset class.
NUM_EXTERNAL_SAMPLES = 5  # upper bound on the number of external samples decoded
EXTERNAL_BEAM_WIDTH = 10  # beam width passed to beam_search
EXTERNAL_TOP_K = 3        # top-ranked predictions per sample that get detailed metrics

model.eval()
external_metrics = {'tanimoto': [], 'dice': [], 'mcs': [], 'mw_diff': [], 'logp_diff': [], 'substructure': []}
pred_smiles_list = []
true_smiles_list = []
adducts_list = []
num_samples = min(NUM_EXTERNAL_SAMPLES, len(external_dataset))

for sample_idx in range(num_samples):
    # Fetch the item once: __getitem__ may be expensive (e.g. graph construction) or
    # stochastic, and all fields below must come from the same draw.
    sample = external_dataset[sample_idx]
    sample_spectrum = sample[0]
    sample_graph = sample[1]
    sample_ion_mode = sample[3]
    sample_precursor_bin = sample[4]
    sample_adduct_idx = sample[5]
    true_smiles = sample[6]

    # Presumably a list of (smiles, confidence) pairs ordered best-first — TODO confirm in beam_search.
    predicted_results = beam_search(
        model, sample_spectrum, sample_graph, sample_ion_mode, sample_precursor_bin,
        sample_adduct_idx, true_smiles,
        beam_width=EXTERNAL_BEAM_WIDTH, max_len=SUPERVISED_MAX_LEN, device=device,
    )
    sample_adduct = df_external.iloc[sample_idx]['adduct']
    pred_smiles_list.extend(smiles for smiles, _ in predicted_results)
    true_smiles_list.extend([true_smiles] * len(predicted_results))
    adducts_list.extend([sample_adduct] * len(predicted_results))

    print(f"\nExternal Sample {sample_idx} - True SMILES: {true_smiles}")
    print("Top Predicted SMILES:")
    if not predicted_results:
        print("(beam search returned no candidates)")

    # Position of THIS sample's top-ranked prediction in the metric lists; the plot title
    # below must use it (index 0 would always be the first sample's score).
    top_metric_idx = len(external_metrics['tanimoto'])
    for smiles, confidence in predicted_results[:EXTERNAL_TOP_K]:
        tanimoto = tanimoto_similarity(smiles, true_smiles, all_fingerprints)
        dice = dice_similarity(smiles, true_smiles)
        mcs = mcs_similarity(smiles, true_smiles)
        external_metrics['tanimoto'].append(tanimoto)
        external_metrics['dice'].append(dice)
        external_metrics['mcs'].append(mcs)
        external_metrics['mw_diff'].append(mw_difference(smiles, true_smiles))
        external_metrics['logp_diff'].append(logp_difference(smiles, true_smiles))
        external_metrics['substructure'].append(substructure_match(smiles, true_smiles, model.gnn_encoder.substructures))
        print(f"SMILES: {smiles}, Confidence: {confidence:.4f}, Tanimoto: {tanimoto:.4f}, Dice: {dice:.4f}, MCS: {mcs:.4f}")
        # Heuristic: a very long string that is mostly 'C' suggests degenerate decoding.
        if len(smiles) > 100 and smiles.count('C') > len(smiles) * 0.8:
            print("Warning: Predicted SMILES is a long carbon chain, indicating potential model underfitting.")
        if smiles != "Invalid SMILES":
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                print(f"Molecular Weight: {Descriptors.MolWt(mol):.2f}, LogP: {Descriptors.MolLogP(mol):.2f}")

    # Visualize the top-ranked prediction next to the ground truth (skipped if nothing was decoded).
    if predicted_results and predicted_results[0][0] != "Invalid SMILES":
        pred_mol = Chem.MolFromSmiles(predicted_results[0][0], sanitize=True)
        true_mol = Chem.MolFromSmiles(true_smiles, sanitize=True)
        if pred_mol and true_mol:
            img = Draw.MolsToGridImage([true_mol, pred_mol], molsPerRow=2, subImgSize=(300, 300), legends=['True', 'Predicted'])
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.imshow(np.array(img.convert('RGB')))
            ax.axis('off')
            ax.set_title(f"External Sample {sample_idx} - Tanimoto: {external_metrics['tanimoto'][top_metric_idx]:.4f}")
            plt.show()
            plt.close(fig)  # release the figure; this runs once per sample

    # Visualize attention and GNN weights for the first sample only.
    if sample_idx == 0:
        with torch.no_grad():
            spectrum = sample_spectrum.unsqueeze(0).to(device)
            graph_data = Batch.from_data_list([sample_graph]).to(device)
            ion_mode_idx = torch.tensor([sample_ion_mode], dtype=torch.long).to(device)
            precursor_idx = torch.tensor([sample_precursor_bin], dtype=torch.long).to(device)
            adduct_idx = torch.tensor([sample_adduct_idx], dtype=torch.long).to(device)
            _, attn_weights = model.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)
            _, _, edge_weights = model.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
            plot_attention_weights(attn_weights, title="External Fold Transformer Attention Weights")
            plot_gnn_edge_weights(edge_weights, sample_graph.edge_index, title="External Fold GNN Edge Importance")

# Final Evaluation
# Summarize the external-evaluation metrics and predictions collected above.
def _mean_or_nan(values, finite_only=False):
    """Return the mean of ``values`` as a float, or NaN when there is nothing to average.

    Args:
        values: iterable of numeric metric values.
        finite_only: drop non-finite entries (inf, -inf, NaN) before averaging. Used for
            the MW/LogP differences, where ``inf`` presumably marks an unparsable SMILES
            (TODO confirm in mw_difference / logp_difference).

    Returns:
        The arithmetic mean, or ``float('nan')`` for an empty selection. This avoids the
        "Mean of empty slice" RuntimeWarning that ``np.mean([])`` emits.
    """
    kept = [v for v in values if math.isfinite(v)] if finite_only else list(values)
    return float(np.mean(kept)) if kept else float('nan')


if not pred_smiles_list:
    # Nothing was decoded. The rate/diversity/error helpers may divide by the list length,
    # so skip them rather than risk a ZeroDivisionError.
    print("No external predictions were generated; skipping the external summary.")
else:
    print(f"External Validity Rate: {validity_rate(pred_smiles_list):.2f}%")
    print(f"External Prediction Diversity: {prediction_diversity(pred_smiles_list):.4f}")
    print("External Metrics Summary:")
    print(f"Avg Tanimoto: {_mean_or_nan(external_metrics['tanimoto']):.4f}")
    print(f"Avg Dice: {_mean_or_nan(external_metrics['dice']):.4f}")
    print(f"Avg MCS: {_mean_or_nan(external_metrics['mcs']):.4f}")
    print(f"Avg MW Difference: {_mean_or_nan(external_metrics['mw_diff'], finite_only=True):.2f}")
    print(f"Avg LogP Difference: {_mean_or_nan(external_metrics['logp_diff'], finite_only=True):.2f}")
    print(f"Avg Substructure Match: {_mean_or_nan(external_metrics['substructure']):.4f}")
    error_analysis(pred_smiles_list, true_smiles_list, adducts_list, all_fingerprints)


In [None]:
# Environment check: print the installed torch version and the torch.amp.GradScaler class.
# Each value is looked up just before it is printed, so if this torch build lacks
# torch.amp.GradScaler the AttributeError surfaces only after the version line appears.
import torch

_env_checks = (
    ("Torch version:", lambda: torch.__version__),
    ("AMP GradScaler:", lambda: torch.amp.GradScaler),
)
for _label, _lookup in _env_checks:
    print(_label, _lookup())
