In [None]:
!pip install torch torchvision torchaudio rdkit datasets tokenizers tqdm

In [1]:

#final_version
# stereochemistry_fixed

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data, Batch
from datasets import load_dataset
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, rdFMCS, EnumerateStereoisomers
from rdkit import DataStructs
from rdkit.Chem import rdFingerprintGenerator
from tqdm import tqdm
import math
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast
import optuna
from nltk.translate.bleu_score import sentence_bleu
from Levenshtein import distance
from pandarallel import pandarallel #Added by Pawan
import time  #Added by Pawan
import os #Added by Pawan
import glob #Added by Pawan
%matplotlib inline

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Define token variables early
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
MASK_TOKEN = "[MASK]"

In [2]:
# Load and preprocess dataset
dataset = load_dataset('roman-bushuiev/MassSpecGym', split='val')
df = pd.DataFrame(dataset)

In [3]:
# Simulate external dataset (e.g., NIST-like) by splitting
df_massspecgym, df_external = df.iloc[:int(0.9*len(df))], df.iloc[int(0.9*len(df)):]
print("MassSpecGym size:", len(df_massspecgym), "External test size:", len(df_external))


MassSpecGym size: 207993 External test size: 23111


In [4]:
# Inspect dataset
print("Dataset Columns:", df_massspecgym.columns.tolist())
print("\nFirst few rows of MassSpecGym dataset:")
print(df_massspecgym[['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'precursor_mz']].head())
print("\nUnique adduct values:", df_massspecgym['adduct'].unique())


Dataset Columns: ['identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge']

First few rows of MassSpecGym dataset:
             identifier                                                mzs  \
0  MassSpecGymID0000001  91.0542,125.0233,154.0499,155.0577,185.0961,20...   
1  MassSpecGymID0000002  91.0542,125.0233,155.0577,185.0961,229.0859,24...   
2  MassSpecGymID0000003  69.0343,91.0542,125.0233,127.039,153.0699,154....   
3  MassSpecGymID0000004  69.0343,91.0542,110.06,111.0441,112.0393,120.0...   
4  MassSpecGymID0000005  91.0542,125.0233,185.0961,229.0859,246.1125,28...   

                                         intensities  \
0  0.24524524524524524,1.0,0.08008008008008008,0....   
1  0.0990990990990991,0.28128128128128127,0.04004...   
2  0.03403403403403404,0.31431431431431434,1.0,0....   
3  0.17917917917917917,0.47347347347347346,0.03

In [None]:
identifier 	mzs 	intensities 	inchikey 	formula 	precursor_formula 	parent_mass 	precursor_mz 	adduct 	instrument_type 	collision_energy 	fold 	simulation_challenge 	smiles

In [2]:
# Data augmentation: SMILES enumeration and spectral noise
def augment_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            stereoisomers = EnumerateStereoisomers.EnumerateStereoisomers(mol)
            return [Chem.MolToSmiles(m, canonical=True, doRandom=True) for m in stereoisomers]
        return [smiles]
    except:
        return [smiles]

def bin_spectrum_to_graph(mzs, intensities, ion_mode, precursor_mz, adduct, n_bins=1000, max_mz=1000, noise_level=0.05):
    spectrum = np.zeros(n_bins)
    for mz, intensity in zip(mzs, intensities):
        try:
            mz = float(mz)
            intensity = float(intensity)
            if mz < max_mz:
                bin_idx = int((mz / max_mz) * n_bins)
                spectrum[bin_idx] += intensity
        except (ValueError, TypeError):
            continue
    if spectrum.max() > 0:
        spectrum = spectrum / spectrum.max()
    spectrum += np.random.normal(0, noise_level, spectrum.shape).clip(0, 1)
    x = torch.tensor(spectrum, dtype=torch.float).unsqueeze(-1)
    edge_index = []
    for i in range(n_bins-1):
        edge_index.append([i, i+1])
        edge_index.append([i+1, i])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t()
    ion_mode = torch.tensor([ion_mode], dtype=torch.float)
    precursor_mz = torch.tensor([precursor_mz], dtype=torch.float)
    adduct_idx = adduct_to_idx.get(adduct, 0)
    return spectrum, Data(x=x, edge_index=edge_index, ion_mode=ion_mode, precursor_mz=precursor_mz, adduct_idx=adduct_idx)

In [6]:
# Canonicalize SMILES and augment
pandarallel.initialize(nb_workers=16, progress_bar=True) #Added by Pawan
start_time = time.time()
def canonicalize_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol:
            return Chem.MolToSmiles(mol, canonical=True)
        return None
    except:
        return None

df_massspecgym['smiles'] = df_massspecgym['smiles'].parallel_apply(canonicalize_smiles) # Added by Pawan
df_external['smiles'] = df_external['smiles'].parallel_apply(canonicalize_smiles) #Added by Pawan
df_massspecgym = df_massspecgym.dropna(subset=['smiles']) 
df_external = df_external.dropna(subset=['smiles'])
df_massspecgym['smiles_list'] = df_massspecgym['smiles'].parallel_apply(augment_smiles)
df_massspecgym = df_massspecgym.explode('smiles_list').dropna(subset=['smiles_list'])
df_massspecgym = df_massspecgym.drop(columns=['smiles']) # Drop original 'smiles' to prevent duplicates; added by Pawan
df_massspecgym = df_massspecgym.rename(columns={'smiles_list': 'smiles'}) # Rename exploded list column to 'smiles'; added by Pawan
df_massspecgym.to_parquet("df_massspecgym.parquet")
df_external.to_parquet("df_external.parquet")
print("Completed in {:.2f} seconds".format(time.time() - start_time)) #Added by Pawan


INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=13000), Label(value='0 / 13000')))…

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_massspecgym['smiles'] = df_massspecgym['smiles'].parallel_apply(canonicalize_smiles) # Added by Pawan


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1445), Label(value='0 / 1445'))), …

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_external['smiles'] = df_external['smiles'].parallel_apply(canonicalize_smiles) #Added by Pawan


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=13000), Label(value='0 / 13000')))…

Completed in 1823.35 seconds


In [10]:
print(df_massspecgym.shape) #Added by Pawan
print(df_massspecgym.index.nunique()) #Added by Pawan
print(df_external.shape) #Added by Pawan
print(df_external.index.nunique()) #Added by Pawan

(19329189, 14)
207993
(23111, 14)
23111


In [11]:
df_massspecgym.reset_index(drop=True, inplace=True) #Added by Pawan

In [12]:
print(df_massspecgym.shape) #Added by Pawan
print(df_massspecgym.index.nunique()) #Added by Pawan
print(df_external.shape) #Added by Pawan
print(df_external.index.nunique()) #Added by Pawan

(19329189, 14)
19329189
(23111, 14)
23111


In [13]:
df_massspecgym.head(5)

Unnamed: 0,identifier,mzs,intensities,inchikey,formula,precursor_formula,parent_mass,precursor_mz,adduct,instrument_type,collision_energy,fold,simulation_challenge,smiles
0,MassSpecGymID0000001,"91.0542,125.0233,154.0499,155.0577,185.0961,20...","0.24524524524524524,1.0,0.08008008008008008,0....",VFMQMACUYWGDOJ,C16H17NO4,C16H18NO4,287.115224,288.1225,[M+H]+,Orbitrap,30.0,train,True,COc1cc(oc(c1)=O)[C@@H](NC(C)=O)Cc1ccccc1
1,MassSpecGymID0000002,"91.0542,125.0233,155.0577,185.0961,229.0859,24...","0.0990990990990991,0.28128128128128127,0.04004...",VFMQMACUYWGDOJ,C16H17NO4,C16H18NO4,287.115224,288.1225,[M+H]+,Orbitrap,20.0,train,True,c1ccc(C[C@H](NC(=O)C)c2cc(cc(=O)o2)OC)cc1
2,MassSpecGymID0000003,"69.0343,91.0542,125.0233,127.039,153.0699,154....","0.03403403403403404,0.31431431431431434,1.0,0....",VFMQMACUYWGDOJ,C16H17NO4,C16H18NO4,287.115224,288.1225,[M+H]+,Orbitrap,40.0,train,True,c1cccc(c1)C[C@@H](c1cc(cc(o1)=O)OC)NC(C)=O
3,MassSpecGymID0000004,"69.0343,91.0542,110.06,111.0441,112.0393,120.0...","0.17917917917917917,0.47347347347347346,0.0380...",VFMQMACUYWGDOJ,C16H17NO4,C16H18NO4,287.115224,288.1225,[M+H]+,Orbitrap,55.0,train,True,c1cc(ccc1)C[C@@H](c1oc(=O)cc(OC)c1)NC(=O)C
4,MassSpecGymID0000005,"91.0542,125.0233,185.0961,229.0859,246.1125,28...","0.07807807807807808,0.1841841841841842,0.03503...",VFMQMACUYWGDOJ,C16H17NO4,C16H18NO4,287.115224,288.1225,[M+H]+,Orbitrap,10.0,train,True,N([C@H](c1oc(cc(OC)c1)=O)Cc1ccccc1)C(C)=O


#Preprocess ion mode, precursor m/z, and adducts
from pandarallel import pandarallel #Added by Pawan
pandarallel.initialize(nb_workers=16, progress_bar=True) #Added by Pawan
import time # Added by Pawan
start_time = time.time()
df_massspecgym['ion_mode'] = df_massspecgym['adduct'].parallel_apply(lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0).fillna(0)
df_massspecgym['precursor_bin'] = pd.qcut(df_massspecgym['precursor_mz'], q=100, labels=False, duplicates='drop')
df_external['ion_mode'] = df_external['adduct'].parallel_apply(lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0).fillna(0)
df_external['precursor_bin'] = pd.qcut(df_external['precursor_mz'], q=100, labels=False, duplicates='drop')
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}
df_massspecgym['adduct_idx'] = df_massspecgym['adduct'].map(adduct_to_idx)
df_external['adduct_idx'] = df_external['adduct'].map(adduct_to_idx)

df_massspecgym[['binned', 'graph_data']] = df_massspecgym.parallel_apply(
    lambda row: pd.Series(bin_spectrum_to_graph(row['mzs'], row['intensities'], row['ion_mode'], row['precursor_mz'], row['adduct'])),
    axis=1
)
df_external[['binned', 'graph_data']] = df_external.parallel_apply(
    lambda row: pd.Series(bin_spectrum_to_graph(row['mzs'], row['intensities'], row['ion_mode'], row['precursor_mz'], row['adduct'])),
    axis=1
)
print("Completed in {:.2f} seconds".format(time.time() - start_time)) #Added by Pawan

In [3]:
#Preprocess ion mode, precursor m/z, and adducts
#Setup and Load
import os
import time
import pandas as pd
from pandarallel import pandarallel
import pyarrow as pa

pandarallel.initialize(nb_workers=16, progress_bar=True)

# Load datasets
df_massspecgym = pd.read_parquet("df_massspecgym.parquet")
df_external = pd.read_parquet("df_external.parquet")

# Build adduct mapping from df_massspecgym
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}

# Create output directory
os.makedirs("processed_chunks", exist_ok=True)


INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [4]:
#Preprocess ion mode, precursor m/z, and adducts
#Define Processing Function
def preprocess_chunk(df_chunk, chunk_idx):
    start_time = time.time()

    df_chunk['ion_mode'] = df_chunk['adduct'].parallel_apply(
        lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0
    ).fillna(0)

    df_chunk['precursor_bin'] = pd.qcut(
        df_chunk['precursor_mz'], q=100, labels=False, duplicates='drop'
    )

    df_chunk['adduct_idx'] = df_chunk['adduct'].map(adduct_to_idx)

    df_chunk[['binned', 'graph_data']] = df_chunk.parallel_apply(
        lambda row: pd.Series(bin_spectrum_to_graph(
            row['mzs'], row['intensities'], row['ion_mode'],
            row['precursor_mz'], row['adduct']
        )),
        axis=1
    )

    # Drop graph_data column before saving to avoid pyarrow error
    df_chunk.drop(columns=['graph_data'], inplace=True)

    df_chunk.to_parquet(f"processed_chunks/df_massspecgym_chunk_{chunk_idx:03}.parquet")
    print(f"✅ Saved chunk {chunk_idx} | Rows: {len(df_chunk)} | Time: {time.time() - start_time:.2f} sec")


In [5]:
#Preprocess ion mode, precursor m/z, and adducts
#Chunk df_massspecgym
chunk_size = 100_000
n_chunks = (len(df_massspecgym) + chunk_size - 1) // chunk_size

for i in range(n_chunks):
    output_file = f"processed_chunks/df_massspecgym_chunk_{i:03}.parquet"
    if os.path.exists(output_file):
        print(f"⏩ Skipping chunk {i} (already exists)")
        continue

    df_chunk = df_massspecgym.iloc[i * chunk_size : (i + 1) * chunk_size].copy()
    preprocess_chunk(df_chunk, i)


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 0 | Rows: 100000 | Time: 52.61 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 1 | Rows: 100000 | Time: 64.91 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 2 | Rows: 100000 | Time: 55.39 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 3 | Rows: 100000 | Time: 54.43 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 4 | Rows: 100000 | Time: 56.63 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 5 | Rows: 100000 | Time: 56.80 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 6 | Rows: 100000 | Time: 55.22 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 7 | Rows: 100000 | Time: 55.11 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 8 | Rows: 100000 | Time: 54.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 9 | Rows: 100000 | Time: 53.41 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 10 | Rows: 100000 | Time: 53.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 11 | Rows: 100000 | Time: 53.84 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 12 | Rows: 100000 | Time: 53.49 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 13 | Rows: 100000 | Time: 53.63 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 14 | Rows: 100000 | Time: 55.54 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 15 | Rows: 100000 | Time: 53.54 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 16 | Rows: 100000 | Time: 53.66 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 17 | Rows: 100000 | Time: 55.17 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 18 | Rows: 100000 | Time: 54.15 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 19 | Rows: 100000 | Time: 54.73 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 20 | Rows: 100000 | Time: 56.42 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 21 | Rows: 100000 | Time: 54.24 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 22 | Rows: 100000 | Time: 54.60 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 23 | Rows: 100000 | Time: 53.51 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 24 | Rows: 100000 | Time: 52.96 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 25 | Rows: 100000 | Time: 54.11 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 26 | Rows: 100000 | Time: 54.78 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 27 | Rows: 100000 | Time: 53.36 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 28 | Rows: 100000 | Time: 53.02 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 29 | Rows: 100000 | Time: 53.29 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 30 | Rows: 100000 | Time: 53.08 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 31 | Rows: 100000 | Time: 53.84 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 32 | Rows: 100000 | Time: 53.15 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 33 | Rows: 100000 | Time: 53.13 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 34 | Rows: 100000 | Time: 56.98 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 35 | Rows: 100000 | Time: 56.03 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 36 | Rows: 100000 | Time: 54.57 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 37 | Rows: 100000 | Time: 53.01 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 38 | Rows: 100000 | Time: 57.84 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 39 | Rows: 100000 | Time: 55.42 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 40 | Rows: 100000 | Time: 55.70 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 41 | Rows: 100000 | Time: 57.48 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 42 | Rows: 100000 | Time: 58.72 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 43 | Rows: 100000 | Time: 57.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 44 | Rows: 100000 | Time: 59.25 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 45 | Rows: 100000 | Time: 59.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 46 | Rows: 100000 | Time: 58.32 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 47 | Rows: 100000 | Time: 53.31 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 48 | Rows: 100000 | Time: 55.91 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 49 | Rows: 100000 | Time: 57.57 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 50 | Rows: 100000 | Time: 56.63 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 51 | Rows: 100000 | Time: 58.24 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 52 | Rows: 100000 | Time: 58.92 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 53 | Rows: 100000 | Time: 57.24 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 54 | Rows: 100000 | Time: 57.62 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 55 | Rows: 100000 | Time: 59.16 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 56 | Rows: 100000 | Time: 57.05 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 57 | Rows: 100000 | Time: 58.11 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 58 | Rows: 100000 | Time: 60.19 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 59 | Rows: 100000 | Time: 59.55 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 60 | Rows: 100000 | Time: 55.67 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 61 | Rows: 100000 | Time: 57.70 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 62 | Rows: 100000 | Time: 57.26 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 63 | Rows: 100000 | Time: 58.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 64 | Rows: 100000 | Time: 55.41 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 65 | Rows: 100000 | Time: 54.48 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 66 | Rows: 100000 | Time: 53.52 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 67 | Rows: 100000 | Time: 54.81 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 68 | Rows: 100000 | Time: 53.88 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 69 | Rows: 100000 | Time: 54.76 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 70 | Rows: 100000 | Time: 54.96 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 71 | Rows: 100000 | Time: 55.29 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 72 | Rows: 100000 | Time: 54.90 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 73 | Rows: 100000 | Time: 55.23 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 74 | Rows: 100000 | Time: 55.60 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 75 | Rows: 100000 | Time: 53.83 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 76 | Rows: 100000 | Time: 54.99 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 77 | Rows: 100000 | Time: 54.22 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 78 | Rows: 100000 | Time: 54.21 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 79 | Rows: 100000 | Time: 53.47 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 80 | Rows: 100000 | Time: 54.33 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 81 | Rows: 100000 | Time: 54.28 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 82 | Rows: 100000 | Time: 54.08 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 83 | Rows: 100000 | Time: 55.78 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 84 | Rows: 100000 | Time: 53.51 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 85 | Rows: 100000 | Time: 55.01 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 86 | Rows: 100000 | Time: 55.19 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 87 | Rows: 100000 | Time: 54.78 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 88 | Rows: 100000 | Time: 53.62 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 89 | Rows: 100000 | Time: 55.40 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 90 | Rows: 100000 | Time: 54.04 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 91 | Rows: 100000 | Time: 56.18 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 92 | Rows: 100000 | Time: 55.13 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 93 | Rows: 100000 | Time: 55.88 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 94 | Rows: 100000 | Time: 55.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 95 | Rows: 100000 | Time: 55.54 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 96 | Rows: 100000 | Time: 54.58 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 97 | Rows: 100000 | Time: 53.99 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 98 | Rows: 100000 | Time: 55.62 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 99 | Rows: 100000 | Time: 53.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 100 | Rows: 100000 | Time: 55.82 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 101 | Rows: 100000 | Time: 53.20 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 102 | Rows: 100000 | Time: 56.08 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 103 | Rows: 100000 | Time: 55.22 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 104 | Rows: 100000 | Time: 55.63 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 105 | Rows: 100000 | Time: 54.24 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 106 | Rows: 100000 | Time: 55.91 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 107 | Rows: 100000 | Time: 55.11 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 108 | Rows: 100000 | Time: 54.94 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 109 | Rows: 100000 | Time: 54.34 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 110 | Rows: 100000 | Time: 54.43 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 111 | Rows: 100000 | Time: 55.51 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 112 | Rows: 100000 | Time: 54.67 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 113 | Rows: 100000 | Time: 54.32 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 114 | Rows: 100000 | Time: 55.08 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 115 | Rows: 100000 | Time: 54.82 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 116 | Rows: 100000 | Time: 54.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 117 | Rows: 100000 | Time: 55.36 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 118 | Rows: 100000 | Time: 55.54 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 119 | Rows: 100000 | Time: 54.93 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 120 | Rows: 100000 | Time: 53.91 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 121 | Rows: 100000 | Time: 55.60 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 122 | Rows: 100000 | Time: 56.46 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 123 | Rows: 100000 | Time: 54.73 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 124 | Rows: 100000 | Time: 55.00 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 125 | Rows: 100000 | Time: 55.30 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 126 | Rows: 100000 | Time: 53.83 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 127 | Rows: 100000 | Time: 53.61 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 128 | Rows: 100000 | Time: 52.53 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 129 | Rows: 100000 | Time: 55.85 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 130 | Rows: 100000 | Time: 55.56 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 131 | Rows: 100000 | Time: 55.16 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 132 | Rows: 100000 | Time: 53.82 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 133 | Rows: 100000 | Time: 55.20 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 134 | Rows: 100000 | Time: 55.19 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 135 | Rows: 100000 | Time: 54.82 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 136 | Rows: 100000 | Time: 56.06 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 137 | Rows: 100000 | Time: 54.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 138 | Rows: 100000 | Time: 54.48 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 139 | Rows: 100000 | Time: 54.61 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 140 | Rows: 100000 | Time: 53.74 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 141 | Rows: 100000 | Time: 53.52 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 142 | Rows: 100000 | Time: 55.65 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 143 | Rows: 100000 | Time: 53.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 144 | Rows: 100000 | Time: 54.40 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 145 | Rows: 100000 | Time: 55.60 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 146 | Rows: 100000 | Time: 54.78 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 147 | Rows: 100000 | Time: 54.40 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 148 | Rows: 100000 | Time: 54.88 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 149 | Rows: 100000 | Time: 56.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 150 | Rows: 100000 | Time: 53.74 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 151 | Rows: 100000 | Time: 55.23 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 152 | Rows: 100000 | Time: 55.26 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 153 | Rows: 100000 | Time: 54.96 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 154 | Rows: 100000 | Time: 56.22 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 155 | Rows: 100000 | Time: 54.70 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 156 | Rows: 100000 | Time: 55.76 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 157 | Rows: 100000 | Time: 53.87 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 158 | Rows: 100000 | Time: 53.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 159 | Rows: 100000 | Time: 55.37 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 160 | Rows: 100000 | Time: 55.76 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 161 | Rows: 100000 | Time: 55.58 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 162 | Rows: 100000 | Time: 55.78 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 163 | Rows: 100000 | Time: 55.21 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 164 | Rows: 100000 | Time: 56.27 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 165 | Rows: 100000 | Time: 56.30 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 166 | Rows: 100000 | Time: 56.02 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 167 | Rows: 100000 | Time: 55.57 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 168 | Rows: 100000 | Time: 55.55 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 169 | Rows: 100000 | Time: 56.81 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 170 | Rows: 100000 | Time: 55.74 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 171 | Rows: 100000 | Time: 56.58 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 172 | Rows: 100000 | Time: 56.46 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 173 | Rows: 100000 | Time: 55.59 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 174 | Rows: 100000 | Time: 53.90 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 175 | Rows: 100000 | Time: 55.13 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 176 | Rows: 100000 | Time: 53.81 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 177 | Rows: 100000 | Time: 56.33 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 178 | Rows: 100000 | Time: 55.36 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 179 | Rows: 100000 | Time: 54.97 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 180 | Rows: 100000 | Time: 56.63 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 181 | Rows: 100000 | Time: 54.41 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 182 | Rows: 100000 | Time: 56.11 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 183 | Rows: 100000 | Time: 56.42 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 184 | Rows: 100000 | Time: 56.36 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 185 | Rows: 100000 | Time: 57.40 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 186 | Rows: 100000 | Time: 55.19 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 187 | Rows: 100000 | Time: 55.36 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 188 | Rows: 100000 | Time: 57.28 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 189 | Rows: 100000 | Time: 55.09 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 190 | Rows: 100000 | Time: 56.96 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 191 | Rows: 100000 | Time: 54.87 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6250), Label(value='0 / 6250'))), …

✅ Saved chunk 192 | Rows: 100000 | Time: 57.44 sec


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1825), Label(value='0 / 1825'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1825), Label(value='0 / 1825'))), …

✅ Saved chunk 193 | Rows: 29189 | Time: 20.75 sec


In [None]:
#Preprocess ion mode, precursor m/z, and adducts
#Process and save each chunk safely — minimal RAM use
import pandas as pd
import pickle
import glob
import gc
import os
from pandarallel import pandarallel
from tqdm import tqdm
import time

start_time = time.time()
pandarallel.initialize(nb_workers=8, progress_bar=False)

# Load adduct mapping
df_massspecgym = pd.read_parquet("df_massspecgym.parquet", columns=["adduct"])
adduct_types = df_massspecgym['adduct'].unique()
adduct_to_idx = {adduct: i for i, adduct in enumerate(adduct_types)}
del df_massspecgym

chunk_files = sorted(glob.glob("processed_chunks/df_massspecgym_chunk_*.parquet"))
output_dir = "graph_data_chunks"
os.makedirs(output_dir, exist_ok=True)

for i, chunk_file in enumerate(tqdm(chunk_files, desc="Processing chunks")):
    df = pd.read_parquet(chunk_file)

    graph_data = df.parallel_apply(
        lambda row: bin_spectrum_to_graph(
            row['mzs'], row['intensities'], row['ion_mode'],
            row['precursor_mz'], row['adduct']
        )[1],
        axis=1
    )

    # Save per chunk — no accumulation
    out_path = os.path.join(output_dir, f"graph_data_chunk_{i:03}.pkl")
    with open(out_path, "wb") as f:
        pickle.dump(graph_data.tolist(), f)

    del df
    del graph_data
    gc.collect()
    print(f"✅ Saved {out_path}")

print("🎉 All chunks saved individually.")
print("🕒 Completed in {:.2f} seconds".format(time.time() - start_time))


INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


Processing chunks:   1%|                      | 1/194 [01:13<3:55:18, 73.15s/it]

✅ Saved graph_data_chunks/graph_data_chunk_000.pkl


Processing chunks:   1%|▏                     | 2/194 [02:23<3:49:14, 71.64s/it]

✅ Saved graph_data_chunks/graph_data_chunk_001.pkl


Processing chunks:   2%|▎                     | 3/194 [03:33<3:45:52, 70.95s/it]

✅ Saved graph_data_chunks/graph_data_chunk_002.pkl


Processing chunks:   2%|▍                     | 4/194 [04:43<3:43:25, 70.55s/it]

✅ Saved graph_data_chunks/graph_data_chunk_003.pkl


Processing chunks:   3%|▌                     | 5/194 [05:56<3:44:16, 71.20s/it]

✅ Saved graph_data_chunks/graph_data_chunk_004.pkl


Processing chunks:   3%|▋                     | 6/194 [07:08<3:44:44, 71.73s/it]

✅ Saved graph_data_chunks/graph_data_chunk_005.pkl


Processing chunks:   4%|▊                     | 7/194 [08:20<3:43:40, 71.76s/it]

✅ Saved graph_data_chunks/graph_data_chunk_006.pkl


Processing chunks:   4%|▉                     | 8/194 [09:31<3:41:19, 71.39s/it]

✅ Saved graph_data_chunks/graph_data_chunk_007.pkl


Processing chunks:   5%|█                     | 9/194 [10:43<3:40:31, 71.52s/it]

✅ Saved graph_data_chunks/graph_data_chunk_008.pkl


Processing chunks:   5%|█                    | 10/194 [11:53<3:37:58, 71.08s/it]

✅ Saved graph_data_chunks/graph_data_chunk_009.pkl


Processing chunks:   6%|█▏                   | 11/194 [13:03<3:35:45, 70.74s/it]

✅ Saved graph_data_chunks/graph_data_chunk_010.pkl


Processing chunks:   6%|█▎                   | 12/194 [14:13<3:34:25, 70.69s/it]

✅ Saved graph_data_chunks/graph_data_chunk_011.pkl


Processing chunks:   7%|█▍                   | 13/194 [15:24<3:32:57, 70.59s/it]

✅ Saved graph_data_chunks/graph_data_chunk_012.pkl


Processing chunks:   7%|█▌                   | 14/194 [16:35<3:32:09, 70.72s/it]

✅ Saved graph_data_chunks/graph_data_chunk_013.pkl


Processing chunks:   8%|█▌                   | 15/194 [17:46<3:31:23, 70.86s/it]

✅ Saved graph_data_chunks/graph_data_chunk_014.pkl


Processing chunks:   8%|█▋                   | 16/194 [18:57<3:30:15, 70.87s/it]

✅ Saved graph_data_chunks/graph_data_chunk_015.pkl


Processing chunks:   9%|█▊                   | 17/194 [20:07<3:28:44, 70.76s/it]

✅ Saved graph_data_chunks/graph_data_chunk_016.pkl


Processing chunks:   9%|█▉                   | 18/194 [21:18<3:27:15, 70.66s/it]

✅ Saved graph_data_chunks/graph_data_chunk_017.pkl


Processing chunks:  10%|██                   | 19/194 [22:29<3:26:39, 70.85s/it]

✅ Saved graph_data_chunks/graph_data_chunk_018.pkl


Processing chunks:  10%|██▏                  | 20/194 [23:40<3:25:52, 70.99s/it]

✅ Saved graph_data_chunks/graph_data_chunk_019.pkl


Processing chunks:  11%|██▎                  | 21/194 [24:54<3:26:47, 71.72s/it]

✅ Saved graph_data_chunks/graph_data_chunk_020.pkl


Processing chunks:  11%|██▍                  | 22/194 [26:06<3:26:24, 72.00s/it]

✅ Saved graph_data_chunks/graph_data_chunk_021.pkl


Processing chunks:  12%|██▍                  | 23/194 [27:18<3:24:54, 71.90s/it]

✅ Saved graph_data_chunks/graph_data_chunk_022.pkl


Processing chunks:  12%|██▌                  | 24/194 [28:29<3:22:51, 71.60s/it]

✅ Saved graph_data_chunks/graph_data_chunk_023.pkl


Processing chunks:  13%|██▋                  | 25/194 [29:39<3:20:39, 71.24s/it]

✅ Saved graph_data_chunks/graph_data_chunk_024.pkl


Processing chunks:  13%|██▊                  | 26/194 [30:49<3:18:33, 70.92s/it]

✅ Saved graph_data_chunks/graph_data_chunk_025.pkl


Processing chunks:  14%|██▉                  | 27/194 [32:00<3:16:54, 70.74s/it]

✅ Saved graph_data_chunks/graph_data_chunk_026.pkl


Processing chunks:  14%|███                  | 28/194 [33:10<3:15:38, 70.71s/it]

✅ Saved graph_data_chunks/graph_data_chunk_027.pkl


Processing chunks:  15%|███▏                 | 29/194 [34:21<3:14:18, 70.66s/it]

✅ Saved graph_data_chunks/graph_data_chunk_028.pkl


Processing chunks:  15%|███▏                 | 30/194 [35:32<3:13:11, 70.68s/it]

✅ Saved graph_data_chunks/graph_data_chunk_029.pkl


Processing chunks:  16%|███▎                 | 31/194 [36:42<3:11:54, 70.64s/it]

✅ Saved graph_data_chunks/graph_data_chunk_030.pkl


Processing chunks:  16%|███▍                 | 32/194 [37:53<3:10:23, 70.51s/it]

✅ Saved graph_data_chunks/graph_data_chunk_031.pkl


Processing chunks:  17%|███▌                 | 33/194 [39:04<3:10:09, 70.87s/it]

✅ Saved graph_data_chunks/graph_data_chunk_032.pkl


Processing chunks:  18%|███▋                 | 34/194 [40:15<3:08:56, 70.85s/it]

✅ Saved graph_data_chunks/graph_data_chunk_033.pkl


Processing chunks:  18%|███▊                 | 35/194 [41:27<3:08:37, 71.18s/it]

✅ Saved graph_data_chunks/graph_data_chunk_034.pkl


Processing chunks:  19%|███▉                 | 36/194 [42:41<3:09:25, 71.93s/it]

✅ Saved graph_data_chunks/graph_data_chunk_035.pkl


Processing chunks:  19%|████                 | 37/194 [43:51<3:07:08, 71.52s/it]

✅ Saved graph_data_chunks/graph_data_chunk_036.pkl


Processing chunks:  20%|████                 | 38/194 [45:02<3:05:43, 71.43s/it]

✅ Saved graph_data_chunks/graph_data_chunk_037.pkl


Processing chunks:  20%|████▏                | 39/194 [46:16<3:06:18, 72.12s/it]

✅ Saved graph_data_chunks/graph_data_chunk_038.pkl


Processing chunks:  21%|████▎                | 40/194 [47:28<3:05:09, 72.14s/it]

✅ Saved graph_data_chunks/graph_data_chunk_039.pkl


Processing chunks:  21%|████▍                | 41/194 [48:41<3:04:02, 72.17s/it]

✅ Saved graph_data_chunks/graph_data_chunk_040.pkl


Processing chunks:  22%|████▌                | 42/194 [49:54<3:03:53, 72.59s/it]

✅ Saved graph_data_chunks/graph_data_chunk_041.pkl


Processing chunks:  22%|████▋                | 43/194 [51:09<3:04:26, 73.29s/it]

✅ Saved graph_data_chunks/graph_data_chunk_042.pkl


Processing chunks:  23%|████▊                | 44/194 [52:24<3:04:29, 73.80s/it]

✅ Saved graph_data_chunks/graph_data_chunk_043.pkl


Processing chunks:  23%|████▊                | 45/194 [53:42<3:06:21, 75.05s/it]

✅ Saved graph_data_chunks/graph_data_chunk_044.pkl


Processing chunks:  24%|████▉                | 46/194 [54:59<3:06:12, 75.49s/it]

✅ Saved graph_data_chunks/graph_data_chunk_045.pkl


Processing chunks:  24%|█████                | 47/194 [56:13<3:03:54, 75.06s/it]

✅ Saved graph_data_chunks/graph_data_chunk_046.pkl


Processing chunks:  25%|█████▏               | 48/194 [57:24<3:00:11, 74.05s/it]

✅ Saved graph_data_chunks/graph_data_chunk_047.pkl


Processing chunks:  25%|█████▎               | 49/194 [58:36<2:57:19, 73.37s/it]

✅ Saved graph_data_chunks/graph_data_chunk_048.pkl


Processing chunks:  26%|█████▍               | 50/194 [59:50<2:56:32, 73.56s/it]

✅ Saved graph_data_chunks/graph_data_chunk_049.pkl


Processing chunks:  26%|████▉              | 51/194 [1:01:04<2:55:22, 73.58s/it]

✅ Saved graph_data_chunks/graph_data_chunk_050.pkl


Processing chunks:  27%|█████              | 52/194 [1:02:18<2:54:26, 73.71s/it]

✅ Saved graph_data_chunks/graph_data_chunk_051.pkl


Processing chunks:  27%|█████▏             | 53/194 [1:03:34<2:55:20, 74.61s/it]

✅ Saved graph_data_chunks/graph_data_chunk_052.pkl


Processing chunks:  28%|█████▎             | 54/194 [1:04:49<2:54:21, 74.73s/it]

✅ Saved graph_data_chunks/graph_data_chunk_053.pkl


Processing chunks:  28%|█████▍             | 55/194 [1:06:06<2:54:40, 75.40s/it]

✅ Saved graph_data_chunks/graph_data_chunk_054.pkl


Processing chunks:  29%|█████▍             | 56/194 [1:07:22<2:53:23, 75.39s/it]

✅ Saved graph_data_chunks/graph_data_chunk_055.pkl


Processing chunks:  29%|█████▌             | 57/194 [1:08:38<2:52:26, 75.52s/it]

✅ Saved graph_data_chunks/graph_data_chunk_056.pkl


Processing chunks:  30%|█████▋             | 58/194 [1:09:55<2:52:35, 76.15s/it]

✅ Saved graph_data_chunks/graph_data_chunk_057.pkl


Processing chunks:  30%|█████▊             | 59/194 [1:11:10<2:50:43, 75.87s/it]

✅ Saved graph_data_chunks/graph_data_chunk_058.pkl


Processing chunks:  31%|█████▉             | 60/194 [1:12:27<2:49:49, 76.04s/it]

✅ Saved graph_data_chunks/graph_data_chunk_059.pkl


Processing chunks:  31%|█████▉             | 61/194 [1:13:40<2:46:38, 75.17s/it]

✅ Saved graph_data_chunks/graph_data_chunk_060.pkl


Processing chunks:  32%|██████             | 62/194 [1:14:56<2:45:37, 75.28s/it]

✅ Saved graph_data_chunks/graph_data_chunk_061.pkl


Processing chunks:  32%|██████▏            | 63/194 [1:16:11<2:44:30, 75.35s/it]

✅ Saved graph_data_chunks/graph_data_chunk_062.pkl


Processing chunks:  33%|██████▎            | 64/194 [1:17:27<2:43:39, 75.53s/it]

✅ Saved graph_data_chunks/graph_data_chunk_063.pkl


Processing chunks:  34%|██████▎            | 65/194 [1:18:39<2:40:08, 74.48s/it]

✅ Saved graph_data_chunks/graph_data_chunk_064.pkl


Processing chunks:  34%|██████▍            | 66/194 [1:19:51<2:37:03, 73.62s/it]

✅ Saved graph_data_chunks/graph_data_chunk_065.pkl


Processing chunks:  35%|██████▌            | 67/194 [1:21:02<2:34:16, 72.88s/it]

✅ Saved graph_data_chunks/graph_data_chunk_066.pkl


Processing chunks:  35%|██████▋            | 68/194 [1:22:13<2:31:53, 72.33s/it]

✅ Saved graph_data_chunks/graph_data_chunk_067.pkl


Processing chunks:  36%|██████▊            | 69/194 [1:23:25<2:30:15, 72.13s/it]

✅ Saved graph_data_chunks/graph_data_chunk_068.pkl


Processing chunks:  36%|██████▊            | 70/194 [1:24:35<2:28:13, 71.72s/it]

✅ Saved graph_data_chunks/graph_data_chunk_069.pkl


Processing chunks:  37%|██████▉            | 71/194 [1:25:46<2:26:08, 71.29s/it]

✅ Saved graph_data_chunks/graph_data_chunk_070.pkl


Processing chunks:  37%|███████            | 72/194 [1:26:56<2:24:40, 71.15s/it]

✅ Saved graph_data_chunks/graph_data_chunk_071.pkl


Processing chunks:  38%|███████▏           | 73/194 [1:28:08<2:23:36, 71.21s/it]

✅ Saved graph_data_chunks/graph_data_chunk_072.pkl


Processing chunks:  38%|███████▏           | 74/194 [1:29:20<2:23:12, 71.61s/it]

✅ Saved graph_data_chunks/graph_data_chunk_073.pkl


Processing chunks:  39%|███████▎           | 75/194 [1:30:31<2:21:25, 71.31s/it]

✅ Saved graph_data_chunks/graph_data_chunk_074.pkl


Processing chunks:  39%|███████▍           | 76/194 [1:31:42<2:19:48, 71.09s/it]

✅ Saved graph_data_chunks/graph_data_chunk_075.pkl


Processing chunks:  40%|███████▌           | 77/194 [1:32:53<2:18:55, 71.25s/it]

✅ Saved graph_data_chunks/graph_data_chunk_076.pkl


Processing chunks:  40%|███████▋           | 78/194 [1:34:04<2:17:26, 71.09s/it]

✅ Saved graph_data_chunks/graph_data_chunk_077.pkl


Processing chunks:  41%|███████▋           | 79/194 [1:35:14<2:15:59, 70.95s/it]

✅ Saved graph_data_chunks/graph_data_chunk_078.pkl


Processing chunks:  41%|███████▊           | 80/194 [1:36:25<2:14:41, 70.89s/it]

✅ Saved graph_data_chunks/graph_data_chunk_079.pkl


Processing chunks:  42%|███████▉           | 81/194 [1:37:36<2:13:36, 70.94s/it]

✅ Saved graph_data_chunks/graph_data_chunk_080.pkl


Processing chunks:  42%|████████           | 82/194 [1:38:47<2:12:04, 70.76s/it]

✅ Saved graph_data_chunks/graph_data_chunk_081.pkl


Processing chunks:  43%|████████▏          | 83/194 [1:39:58<2:11:13, 70.93s/it]

✅ Saved graph_data_chunks/graph_data_chunk_082.pkl


Processing chunks:  43%|████████▏          | 84/194 [1:41:08<2:09:38, 70.71s/it]

✅ Saved graph_data_chunks/graph_data_chunk_083.pkl


Processing chunks:  44%|████████▎          | 85/194 [1:42:18<2:08:12, 70.57s/it]

✅ Saved graph_data_chunks/graph_data_chunk_084.pkl


Processing chunks:  44%|████████▍          | 86/194 [1:43:30<2:07:51, 71.03s/it]

✅ Saved graph_data_chunks/graph_data_chunk_085.pkl


Processing chunks:  45%|████████▌          | 87/194 [1:44:41<2:06:29, 70.93s/it]

✅ Saved graph_data_chunks/graph_data_chunk_086.pkl


Processing chunks:  45%|████████▌          | 88/194 [1:45:52<2:05:17, 70.92s/it]

✅ Saved graph_data_chunks/graph_data_chunk_087.pkl


Processing chunks:  46%|████████▋          | 89/194 [1:47:03<2:03:55, 70.81s/it]

✅ Saved graph_data_chunks/graph_data_chunk_088.pkl


Processing chunks:  46%|████████▊          | 90/194 [1:48:13<2:02:31, 70.69s/it]

✅ Saved graph_data_chunks/graph_data_chunk_089.pkl


Processing chunks:  47%|████████▉          | 91/194 [1:49:23<2:01:04, 70.53s/it]

✅ Saved graph_data_chunks/graph_data_chunk_090.pkl


Processing chunks:  47%|█████████          | 92/194 [1:50:34<2:00:02, 70.61s/it]

✅ Saved graph_data_chunks/graph_data_chunk_091.pkl


Processing chunks:  48%|█████████          | 93/194 [1:51:45<1:58:52, 70.62s/it]

✅ Saved graph_data_chunks/graph_data_chunk_092.pkl


Processing chunks:  48%|█████████▏         | 94/194 [1:52:56<1:57:49, 70.70s/it]

✅ Saved graph_data_chunks/graph_data_chunk_093.pkl


Processing chunks:  49%|█████████▎         | 95/194 [1:54:06<1:56:46, 70.78s/it]

✅ Saved graph_data_chunks/graph_data_chunk_094.pkl


Processing chunks:  49%|█████████▍         | 96/194 [1:55:17<1:55:20, 70.62s/it]

✅ Saved graph_data_chunks/graph_data_chunk_095.pkl


Processing chunks:  50%|█████████▌         | 97/194 [1:56:29<1:54:47, 71.01s/it]

✅ Saved graph_data_chunks/graph_data_chunk_096.pkl


Processing chunks:  51%|█████████▌         | 98/194 [1:57:40<1:53:38, 71.03s/it]

✅ Saved graph_data_chunks/graph_data_chunk_097.pkl


Processing chunks:  51%|█████████▋         | 99/194 [1:58:51<1:52:25, 71.01s/it]

✅ Saved graph_data_chunks/graph_data_chunk_098.pkl


Processing chunks:  52%|█████████▎        | 100/194 [2:00:01<1:50:50, 70.75s/it]

✅ Saved graph_data_chunks/graph_data_chunk_099.pkl


Processing chunks:  52%|█████████▎        | 101/194 [2:01:11<1:49:29, 70.64s/it]

✅ Saved graph_data_chunks/graph_data_chunk_100.pkl


Processing chunks:  53%|█████████▍        | 102/194 [2:02:21<1:48:09, 70.53s/it]

✅ Saved graph_data_chunks/graph_data_chunk_101.pkl


Processing chunks:  53%|█████████▌        | 103/194 [2:03:32<1:46:54, 70.49s/it]

✅ Saved graph_data_chunks/graph_data_chunk_102.pkl


Processing chunks:  54%|█████████▋        | 104/194 [2:04:42<1:45:41, 70.46s/it]

✅ Saved graph_data_chunks/graph_data_chunk_103.pkl


Processing chunks:  54%|█████████▋        | 105/194 [2:05:53<1:44:28, 70.43s/it]

✅ Saved graph_data_chunks/graph_data_chunk_104.pkl


Processing chunks:  55%|█████████▊        | 106/194 [2:07:03<1:43:25, 70.51s/it]

✅ Saved graph_data_chunks/graph_data_chunk_105.pkl


Processing chunks:  55%|█████████▉        | 107/194 [2:08:14<1:42:28, 70.67s/it]

✅ Saved graph_data_chunks/graph_data_chunk_106.pkl


Processing chunks:  56%|██████████        | 108/194 [2:09:25<1:41:03, 70.51s/it]

✅ Saved graph_data_chunks/graph_data_chunk_107.pkl


Processing chunks:  56%|██████████        | 109/194 [2:10:35<1:39:51, 70.49s/it]

✅ Saved graph_data_chunks/graph_data_chunk_108.pkl


Processing chunks:  57%|██████████▏       | 110/194 [2:11:45<1:38:38, 70.46s/it]

✅ Saved graph_data_chunks/graph_data_chunk_109.pkl


Processing chunks:  57%|██████████▎       | 111/194 [2:12:56<1:37:43, 70.64s/it]

✅ Saved graph_data_chunks/graph_data_chunk_110.pkl


Processing chunks:  58%|██████████▍       | 112/194 [2:14:07<1:36:32, 70.64s/it]

✅ Saved graph_data_chunks/graph_data_chunk_111.pkl


Processing chunks:  58%|██████████▍       | 113/194 [2:15:18<1:35:20, 70.62s/it]

✅ Saved graph_data_chunks/graph_data_chunk_112.pkl


Processing chunks:  59%|██████████▌       | 114/194 [2:16:28<1:33:59, 70.49s/it]

✅ Saved graph_data_chunks/graph_data_chunk_113.pkl


Processing chunks:  59%|██████████▋       | 115/194 [2:17:39<1:32:55, 70.58s/it]

✅ Saved graph_data_chunks/graph_data_chunk_114.pkl


Processing chunks:  60%|██████████▊       | 116/194 [2:18:49<1:31:44, 70.57s/it]

✅ Saved graph_data_chunks/graph_data_chunk_115.pkl


Processing chunks:  60%|██████████▊       | 117/194 [2:20:00<1:30:34, 70.58s/it]

✅ Saved graph_data_chunks/graph_data_chunk_116.pkl


Processing chunks:  61%|██████████▉       | 118/194 [2:21:11<1:29:41, 70.81s/it]

✅ Saved graph_data_chunks/graph_data_chunk_117.pkl


Processing chunks:  61%|███████████       | 119/194 [2:22:21<1:28:19, 70.66s/it]

✅ Saved graph_data_chunks/graph_data_chunk_118.pkl


Processing chunks:  62%|███████████▏      | 120/194 [2:23:32<1:27:02, 70.57s/it]

✅ Saved graph_data_chunks/graph_data_chunk_119.pkl


Processing chunks:  62%|███████████▏      | 121/194 [2:24:42<1:25:47, 70.51s/it]

✅ Saved graph_data_chunks/graph_data_chunk_120.pkl


Processing chunks:  63%|███████████▎      | 122/194 [2:25:52<1:24:27, 70.38s/it]

✅ Saved graph_data_chunks/graph_data_chunk_121.pkl


Processing chunks:  63%|███████████▍      | 123/194 [2:27:03<1:23:19, 70.41s/it]

✅ Saved graph_data_chunks/graph_data_chunk_122.pkl


Processing chunks:  64%|███████████▌      | 124/194 [2:28:13<1:22:11, 70.45s/it]

✅ Saved graph_data_chunks/graph_data_chunk_123.pkl


Processing chunks:  64%|███████████▌      | 125/194 [2:29:24<1:20:59, 70.43s/it]

✅ Saved graph_data_chunks/graph_data_chunk_124.pkl


Processing chunks:  65%|███████████▋      | 126/194 [2:30:34<1:19:50, 70.45s/it]

✅ Saved graph_data_chunks/graph_data_chunk_125.pkl


Processing chunks:  65%|███████████▊      | 127/194 [2:31:45<1:18:43, 70.50s/it]

✅ Saved graph_data_chunks/graph_data_chunk_126.pkl


Processing chunks:  66%|███████████▉      | 128/194 [2:32:55<1:17:21, 70.33s/it]

✅ Saved graph_data_chunks/graph_data_chunk_127.pkl


Processing chunks:  66%|███████████▉      | 129/194 [2:34:05<1:16:14, 70.38s/it]

✅ Saved graph_data_chunks/graph_data_chunk_128.pkl


Processing chunks:  67%|████████████      | 130/194 [2:35:16<1:15:22, 70.66s/it]

✅ Saved graph_data_chunks/graph_data_chunk_129.pkl


Processing chunks:  68%|████████████▏     | 131/194 [2:36:27<1:14:00, 70.48s/it]

✅ Saved graph_data_chunks/graph_data_chunk_130.pkl


Processing chunks:  68%|████████████▏     | 132/194 [2:37:37<1:12:57, 70.61s/it]

✅ Saved graph_data_chunks/graph_data_chunk_131.pkl


Processing chunks:  69%|████████████▎     | 133/194 [2:38:48<1:11:42, 70.54s/it]

✅ Saved graph_data_chunks/graph_data_chunk_132.pkl


Processing chunks:  69%|████████████▍     | 134/194 [2:39:58<1:10:25, 70.43s/it]

✅ Saved graph_data_chunks/graph_data_chunk_133.pkl


Processing chunks:  70%|████████████▌     | 135/194 [2:41:10<1:09:42, 70.88s/it]

✅ Saved graph_data_chunks/graph_data_chunk_134.pkl


In [7]:
#Preprocess ion mode, precursor m/z, and adducts
#Full Processing of df_external
import time
import pickle

start_time = time.time()

# Compute ion_mode
df_external['ion_mode'] = df_external['adduct'].parallel_apply(
    lambda x: 0 if '+' in str(x) else 1 if '-' in str(x) else 0
).fillna(0)

# Compute precursor_bin
df_external['precursor_bin'] = pd.qcut(
    df_external['precursor_mz'], q=100, labels=False, duplicates='drop'
)

# Map adduct to index
df_external['adduct_idx'] = df_external['adduct'].map(adduct_to_idx)

# Generate binned and graph_data columns
df_external[['binned', 'graph_data']] = df_external.parallel_apply(
    lambda row: pd.Series(bin_spectrum_to_graph(
        row['mzs'], row['intensities'], row['ion_mode'],
        row['precursor_mz'], row['adduct']
    )),
    axis=1
)

# 🔒 Save graph_data separately (optional)
with open("df_external_graph_data.pkl", "wb") as f:
    pickle.dump(df_external['graph_data'].tolist(), f)

# ❌ Drop graph_data column before saving to parquet
df_external.drop(columns=['graph_data'], inplace=True)

# ✅ Save remaining data to Parquet
df_external.to_parquet("df_external_processed.parquet")

print("✅ df_external processed and saved in {:.2f} seconds".format(time.time() - start_time))


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1445), Label(value='0 / 1445'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1445), Label(value='0 / 1445'))), …

✅ df_external processed and saved in 33.56 seconds


In [3]:
#Preprocess ion mode, precursor m/z, and adducts
import pandas as pd
import glob
import pyarrow as pa
import pyarrow.parquet as pq
import time

output_path = "df_massspecgym_processed_full.parquet"
chunk_files = sorted(glob.glob("processed_chunks/df_massspecgym_chunk_*.parquet"))

first_df = pd.read_parquet(chunk_files[0])
first_df['precursor_bin'] = first_df['precursor_bin'].fillna(-1).astype('int64')

table = pa.Table.from_pandas(first_df)
writer = pq.ParquetWriter(output_path, table.schema)

writer.write_table(table)
print(f"✅ Wrote chunk 0 / {len(chunk_files)}")

for i, file in enumerate(chunk_files[1:], start=1):
    start = time.time()
    df = pd.read_parquet(file)

    df['precursor_bin'] = df['precursor_bin'].fillna(-1).astype('int64')
    df = df[first_df.columns]

    table = pa.Table.from_pandas(df)
    writer.write_table(table)
    print(f"✅ Wrote chunk {i} / {len(chunk_files)} in {time.time() - start:.2f}s")

writer.close()
print(f"🎉 Done merging {len(chunk_files)} chunks ➜ {output_path}")


✅ Wrote chunk 0 / 194
✅ Wrote chunk 1 / 194 in 6.07s
✅ Wrote chunk 2 / 194 in 6.17s
✅ Wrote chunk 3 / 194 in 6.10s
✅ Wrote chunk 4 / 194 in 6.05s
✅ Wrote chunk 5 / 194 in 6.10s
✅ Wrote chunk 6 / 194 in 5.94s
✅ Wrote chunk 7 / 194 in 5.91s
✅ Wrote chunk 8 / 194 in 6.04s
✅ Wrote chunk 9 / 194 in 5.93s
✅ Wrote chunk 10 / 194 in 5.93s
✅ Wrote chunk 11 / 194 in 6.01s
✅ Wrote chunk 12 / 194 in 6.34s
✅ Wrote chunk 13 / 194 in 7.31s
✅ Wrote chunk 14 / 194 in 6.02s
✅ Wrote chunk 15 / 194 in 6.03s
✅ Wrote chunk 16 / 194 in 5.95s
✅ Wrote chunk 17 / 194 in 5.95s
✅ Wrote chunk 18 / 194 in 6.70s
✅ Wrote chunk 19 / 194 in 6.01s
✅ Wrote chunk 20 / 194 in 6.07s
✅ Wrote chunk 21 / 194 in 5.99s
✅ Wrote chunk 22 / 194 in 5.92s
✅ Wrote chunk 23 / 194 in 6.16s
✅ Wrote chunk 24 / 194 in 5.94s
✅ Wrote chunk 25 / 194 in 5.85s
✅ Wrote chunk 26 / 194 in 5.92s
✅ Wrote chunk 27 / 194 in 5.90s
✅ Wrote chunk 28 / 194 in 5.87s
✅ Wrote chunk 29 / 194 in 7.08s
✅ Wrote chunk 30 / 194 in 5.94s
✅ Wrote chunk 31 / 194 in 5

In [5]:
# SMILES Tokenization with Stereochemistry

import pyarrow.parquet as pq
import pyarrow as pa
import time 
# Special tokens
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
MASK_TOKEN = "<MASK>"

# Open the large Parquet file (91 GB)
parquet_file = pq.ParquetFile("df_massspecgym_processed_full.parquet")

start = time.time()
# Step 1: Build vocabulary from all SMILES across row groups
unique_chars = set()
for i in range(parquet_file.num_row_groups):
    table = parquet_file.read_row_group(i, columns=["smiles"])
    df = table.to_pandas()
    unique_chars |= set(''.join(df['smiles'].dropna().astype(str).tolist()))

# Define vocabulary
valid_atoms = {'C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I', 'H'}
tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN] + sorted(unique_chars - {PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN})
token_to_idx = {tok: i for i, tok in enumerate(tokens) if tok in valid_atoms or tok in {PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, MASK_TOKEN, '(', ')', '=', '#', '@', '[', ']', '/', '\\', '.', ':'}}
idx_to_token = {i: tok for tok, i in token_to_idx.items()}
vocab_size = len(token_to_idx)

# Determine supervised max length (longest SMILES + 2 for SOS and EOS)
SUPERVISED_MAX_LEN = 0
for i in range(parquet_file.num_row_groups):
    table = parquet_file.read_row_group(i, columns=["smiles"])
    df = table.to_pandas()
    max_len = max(len(s) for s in df['smiles'].dropna().astype(str))
    SUPERVISED_MAX_LEN = max(SUPERVISED_MAX_LEN, max_len + 2)

PRETRAIN_MAX_LEN = 100

print(f"✅ Vocabulary size: {vocab_size}, Supervised MAX_LEN: {SUPERVISED_MAX_LEN}, Pretrain MAX_LEN: {PRETRAIN_MAX_LEN}")
print(f"Completed in {time.time() - start:.2f}s")
# Step 2: Define encoder function
start = time.time()
def encode_smiles(smiles, max_len=PRETRAIN_MAX_LEN):
    tokens = [SOS_TOKEN] + [c for c in smiles[:max_len - 2] if c in token_to_idx] + [EOS_TOKEN]
    token_ids = [token_to_idx.get(tok, token_to_idx[PAD_TOKEN]) for tok in tokens]
    if len(token_ids) > max_len:
        return token_ids[:max_len]
    else:
        return token_ids + [token_to_idx[PAD_TOKEN]] * (max_len - len(token_ids))
print(f"Completed in {time.time() - start:.2f}s")


✅ Vocabulary size: 21, Supervised MAX_LEN: 272, Pretrain MAX_LEN: 100
Completed in 32.10s
Completed in 0.00s


In [3]:
# Precompute Morgan fingerprints
import pandas as pd
import pyarrow.parquet as pq
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from joblib import Parallel, delayed
import pickle
import time

start = time.time()

# Load df_massspecgym from large Parquet file (only SMILES column)
massspec_parquet = pq.ParquetFile("df_massspecgym_processed_full.parquet")
df_massspecgym = pd.concat([
    massspec_parquet.read_row_group(i, columns=["smiles"]).to_pandas()
    for i in range(massspec_parquet.num_row_groups)
], ignore_index=True)

# Load df_external from smaller Parquet file
df_external = pd.read_parquet("df_external_processed.parquet")

# Combine and deduplicate SMILES
all_smiles = list(set(df_massspecgym['smiles'].dropna().tolist() + df_external['smiles'].dropna().tolist()))

# Function that avoids unpicklable generator
def fingerprint_one(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
            return smiles, generator.GetFingerprint(mol)
    except Exception as e:
        print(f"Failed: {smiles} → {e}")
    return smiles, None

# Run parallel computation with 12 CPUs
results = Parallel(n_jobs=12, verbose=5)(
    delayed(fingerprint_one)(s) for s in all_smiles
)

# Collect into dictionary
all_fingerprints = {s: fp for s, fp in results if fp is not None}

# Save to file
with open("all_morgan_fingerprints.pkl", "wb") as f:
    pickle.dump(all_fingerprints, f)

print(f"✅ Done in {time.time() - start:.2f}s — {len(all_fingerprints)} fingerprints saved to all_morgan_fingerprints.pkl")


[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done  93 tasks      | elapsed:    0.5s
[Parallel(n_jobs=12)]: Done 1320 tasks      | elapsed:    0.9s
[Parallel(n_jobs=12)]: Done 49128 tasks      | elapsed:    2.5s
[Parallel(n_jobs=12)]: Done 215016 tasks      | elapsed:    7.7s
[Parallel(n_jobs=12)]: Done 417768 tasks      | elapsed:   13.6s
[Parallel(n_jobs=12)]: Done 657384 tasks      | elapsed:   20.5s
[Parallel(n_jobs=12)]: Done 933864 tasks      | elapsed:   28.4s
[Parallel(n_jobs=12)]: Done 1247208 tasks      | elapsed:   37.5s
[Parallel(n_jobs=12)]: Done 1597416 tasks      | elapsed:   47.1s
[Parallel(n_jobs=12)]: Done 1984488 tasks      | elapsed:   57.6s
[Parallel(n_jobs=12)]: Done 2408424 tasks      | elapsed:  1.2min
[Parallel(n_jobs=12)]: Done 2869224 tasks      | elapsed:  1.4min
[Parallel(n_jobs=12)]: Done 3366888 tasks      | elapsed:  1.6min
[Parallel(n_jobs=12)]: Done 3901416 tasks      | elapsed:  1.9min
[Parallel(n_

✅ Done in 667.19s — 19320594 fingerprints saved to all_morgan_fingerprints.pkl


In [6]:
#Dataset class
class MSMSDataset(Dataset):
    def __init__(self, dataframe, max_len=PRETRAIN_MAX_LEN, is_ssl=False):
        self.spectra = np.stack(dataframe['binned'].values)
        self.graph_data = dataframe['graph_data'].values
        self.ion_modes = dataframe['ion_mode'].values
        self.precursor_bins = dataframe['precursor_bin'].values
        self.adduct_indices = dataframe['adduct_idx'].values
        self.raw_smiles = dataframe['smiles'].values
        self.is_ssl = is_ssl
        if is_ssl:
            self.smiles = []
            self.masked_smiles = []
            for s in self.raw_smiles:
                masked_s, orig_s = self.mask_smiles(s)
                self.smiles.append(encode_smiles(orig_s, max_len))
                self.masked_smiles.append(encode_smiles(masked_s, max_len))
        else:
            self.smiles = [encode_smiles(s, max_len=SUPERVISED_MAX_LEN) for s in self.raw_smiles]

    def mask_smiles(self, smiles, mask_ratio=0.10):
        chars = list(smiles)[:PRETRAIN_MAX_LEN-2]
        masked_chars = chars.copy()
        n_mask = int(mask_ratio * len(chars))
        mask_indices = np.random.choice(len(chars), n_mask, replace=False)
        for idx in mask_indices:
            masked_chars[idx] = MASK_TOKEN
        return ''.join(masked_chars), ''.join(chars)

    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        if self.is_ssl:
            return (
                torch.tensor(self.spectra[idx], dtype=torch.float),
                self.graph_data[idx],
                torch.tensor(self.smiles[idx], dtype=torch.long),
                torch.tensor(self.masked_smiles[idx], dtype=torch.long),
                torch.tensor(self.ion_modes[idx], dtype=torch.long),
                torch.tensor(self.precursor_bins[idx], dtype=torch.long),
                torch.tensor(self.adduct_indices[idx], dtype=torch.long),
                self.raw_smiles[idx]
            )
        return (
            torch.tensor(self.spectra[idx], dtype=torch.float),
            self.graph_data[idx],
            torch.tensor(self.smiles[idx], dtype=torch.long),
            torch.tensor(self.ion_modes[idx], dtype=torch.long),
            torch.tensor(self.precursor_bins[idx], dtype=torch.long),
            torch.tensor(self.adduct_indices[idx], dtype=torch.long),
            self.raw_smiles[idx]
        )

In [7]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]


In [8]:
# Transformer Encoder
class SpectrumTransformerEncoder(nn.Module):
    def __init__(self, input_dim=1000, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.metadata_emb = nn.Linear(2 + 32, 64)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model + 64, d_model // 2)
        self.adduct_emb = nn.Embedding(len(adduct_types), 32)

    def forward(self, src, ion_mode_idx, precursor_idx, adduct_idx):
        src = self.input_proj(src).unsqueeze(1)
        adduct_embed = self.adduct_emb(adduct_idx)
        metadata = self.metadata_emb(torch.cat([ion_mode_idx.unsqueeze(-1).float(), precursor_idx.unsqueeze(-1).float(), adduct_embed], dim=-1))
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src).squeeze(1)
        output = self.norm(output)
        output = torch.cat([output, metadata], dim=-1)
        output = self.fc(output)
        return output, self.transformer_encoder.layers[-1].self_attn(src, src, src)[1]

In [9]:
# GNN Encoder with Expanded Substructures
class SpectrumGNNEncoder(MessagePassing):
    def __init__(self, d_model=768, hidden_dim=256, num_layers=3, dropout=0.2):
        super().__init__(aggr='mean')
        self.d_model = d_model
        self.num_layers = num_layers
        self.input_proj = nn.Linear(1, hidden_dim)
        self.message_nets = nn.ModuleList([nn.Linear(2 * hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.update_nets = nn.ModuleList([nn.Linear(2 * hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.metadata_emb = nn.Linear(2 + 32, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, d_model // 2)
        self.dropout = nn.Dropout(dropout)
        self.substructure_head = nn.Linear(hidden_dim, 30)  # 30 substructures
        self.adduct_emb = nn.Embedding(len(adduct_types), 32)
        self.substructures = ['C=O', 'C=C', 'c1ccccc1', 'C#N', 'C(=O)O', 'N=O', 'S=O', 'P=O', 'C#C', 'C-N-C',
                              'C-O-C', 'C-S-C', 'C(=O)N', 'C(=O)S', 'C=C-C', 'c1ccncc1', 'c1cncnc1', 'c1ccoc1',
                              'c1ccsc1', 'C(=O)C', 'N-C-N', 'S-C-S', 'P-C-P', 'C-F', 'C-Cl', 'C-Br', 'C-I', 'N-N',
                              'O-O', 'S-S']

    def forward(self, graph_data, ion_mode_idx, precursor_idx, adduct_idx):
        batch = Batch.from_data_list(graph_data).to(device)
        x, edge_index = batch.x, batch.edge_index
        ion_mode = batch.ion_mode
        precursor_mz = batch.precursor_mz
        adduct_embed = self.adduct_emb(adduct_idx)

        x = self.input_proj(x)
        metadata = self.metadata_emb(torch.cat([ion_mode.unsqueeze(-1), precursor_mz.unsqueeze(-1), adduct_embed], dim=-1))

        edge_weights = []
        for i in range(self.num_layers):
            self._propagate_layer = i
            x_before = x.clone()
            x = self.propagate(edge_index, x=x)
            x = self.update_nets[i](torch.cat([x, metadata], dim=-1))
            x = self.norm(x)
            x = F.relu(x)
            x = self.dropout(x)
            edge_weights.append((x - x_before).norm(dim=-1))

        x = global_mean_pool(x, batch.batch)
        substructure_pred = self.substructure_head(x)
        x = self.output_layer(x)
        return x, substructure_pred, edge_weights

    def message(self, x_i, x_j):
        return F.relu(self.message_nets[self._propagate_layer](torch.cat([x_i, x_j], dim=-1)))


In [10]:
# Novel Decoder with Stereochemistry and Substructure Guidance
class SmilesTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
        self.valence_rules = {
            'C': 4, 'N': 3, 'O': 2, 'S': 2, 'P': 3, 'F': 1, 'Cl': 1, 'Br': 1, 'I': 1, 'H': 1
        }
        self.stereo_tokens = {'@', '/'}
        self.substructure_condition = nn.Linear(30, d_model)

    def compute_valence(self, smiles_tokens, batch_size):
        valence_counts = torch.zeros(batch_size, len(self.valence_rules)).to(smiles_tokens.device)
        atom_indices = {tok: i for i, tok in enumerate(self.valence_rules.keys())}
        bond_counts = torch.zeros(batch_size, device=smiles_tokens.device)
        for t in range(smiles_tokens.size(1)):
            for tok, idx in atom_indices.items():
                mask = smiles_tokens[:, t] == token_to_idx[tok]
                valence_counts[mask, idx] += self.valence_rules[tok]
            for tok in ['=', '#']:
                mask = smiles_tokens[:, t] == token_to_idx[tok]
                bond_counts[mask] += 2 if tok == '#' else 1
        valence_counts = valence_counts - bond_counts.unsqueeze(-1)
        return torch.relu(valence_counts - torch.tensor(list(self.valence_rules.values()), device=smiles_tokens.device)).sum(dim=-1)

    def forward(self, tgt, memory, substructure_pred, tgt_mask=None, memory_key_padding_mask=None):
        embedded = self.embedding(tgt) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        substructure_emb = self.substructure_condition(substructure_pred).unsqueeze(1)
        embedded = embedded + substructure_emb
        output = self.transformer_decoder(embedded, memory, tgt_mask, memory_key_padding_mask)
        output = self.norm(output)
        logits = self.output_layer(output)
        valence_penalty = self.compute_valence(tgt, tgt.size(0))
        return logits, valence_penalty

In [11]:
# Full Model with RL Component
class MSMS2SmilesHybrid(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2, fp_size=2048):
        super().__init__()
        self.transformer_encoder = SpectrumTransformerEncoder(input_dim=1000, d_model=d_model, nhead=nhead, num_layers=num_layers, dim_feedforward=dim_feedforward, dropout=dropout)
        self.gnn_encoder = SpectrumGNNEncoder(d_model=d_model, hidden_dim=256, num_layers=3, dropout=dropout)
        self.decoder = SmilesTransformerDecoder(vocab_size, d_model, nhead, num_layers, dim_feedforward, dropout)
        self.combine_layer = nn.Linear(d_model, d_model)
        self.fp_head = nn.Linear(d_model, fp_size)
        self.fp_size = fp_size
        self.log_sigma_smiles = nn.Parameter(torch.zeros(1))
        self.log_sigma_fp = nn.Parameter(torch.zeros(1))
        self.log_sigma_sub = nn.Parameter(torch.zeros(1))

    def generate_square_subsequent_mask(self, tgt_len):
        mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)
        mask = mask.float().masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
        return mask

    def forward(self, spectrum, graph_data, tgt, ion_mode_idx, precursor_idx, adduct_idx, tgt_mask=None, memory_key_padding_mask=None):
        trans_output, attn_weights = self.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)
        gnn_output, substructure_pred, edge_weights = self.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
        memory = self.combine_layer(torch.cat([trans_output, gnn_output], dim=-1)).unsqueeze(1)
        smiles_output, valence_penalty = self.decoder(tgt, memory, substructure_pred, tgt_mask, memory_key_padding_mask)
        fp_output = self.fp_head(memory.squeeze(1))
        return smiles_output, fp_output, valence_penalty, attn_weights, edge_weights, substructure_pred

In [12]:
# SSL Pretraining
def ssl_pretrain(model, dataloader, epochs=3, lr=1e-4):
    model.train()
    scaler = GradScaler()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    for epoch in range(epochs):
        total_loss = 0
        for spectra, graph_data, smiles_tokens, masked_tokens, ion_modes, precursor_bins, adduct_indices, _ in tqdm(dataloader, desc=f"SSL Epoch {epoch+1}/{epochs}"):
            spectra = spectra.to(device)
            ion_modes = ion_modes.to(device)
            precursor_bins = precursor_bins.to(device)
            adduct_indices = adduct_indices.to(device)
            smiles_tokens = smiles_tokens.to(device)
            masked_tokens = masked_tokens.to(device)
            tgt_input = masked_tokens[:, :-1]
            tgt_output = smiles_tokens[:, 1:]
            tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            optimizer.zero_grad()
            with autocast():
                smiles_output, _, valence_penalty, _, _, _ = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                loss = criterion(smiles_output.reshape(-1, vocab_size), tgt_output.reshape(-1)) + 0.1 * valence_penalty.mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"SSL Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, f'ssl_checkpoint_epoch_{epoch+1}.pt')
        print(f"Saved SSL checkpoint: ssl_checkpoint_epoch_{epoch+1}.pt")


In [13]:
# Supervised Training with RL
def supervised_train(model, train_loader, val_loader, epochs=30, lr=1e-4, patience=5):
    model.train()
    scaler = GradScaler()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    smiles_criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    fp_criterion = nn.BCEWithLogitsLoss()
    mw_criterion = nn.MSELoss()
    sub_criterion = nn.BCEWithLogitsLoss()
    best_val_loss = float('inf')
    no_improve = 0

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, raw_smiles in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            spectra = spectra.to(device)
            ion_modes = ion_modes.to(device)
            precursor_bins = precursor_bins.to(device)
            adduct_indices = adduct_indices.to(device)
            smiles_tokens = smiles_tokens.to(device)
            tgt_input = smiles_tokens[:, :-1]
            tgt_output = smiles_tokens[:, 1:]
            tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            optimizer.zero_grad()
            with autocast():
                smiles_output, fp_output, valence_penalty, _, _, substructure_pred = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                smiles_loss = smiles_criterion(smiles_output.reshape(-1, vocab_size), tgt_output.reshape(-1))
                fp_loss = 0
                mw_loss = 0
                sub_loss = 0
                valid_count = 0
                substructure_targets = torch.zeros(len(raw_smiles), 30, dtype=torch.float, device=device)
                for i, (smiles, fp) in enumerate(zip(raw_smiles, fp_output)):
                    mol = Chem.MolFromSmiles(smiles, sanitize=True)
                    if mol:
                        true_fp = morgan_gen.GetFingerprint(mol)
                        fp_loss += fp_criterion(fp, torch.tensor([int(b) for b in true_fp.ToBitString()], dtype=torch.float, device=device))
                        mw_loss += mw_criterion(torch.tensor(Descriptors.MolWt(mol), dtype=torch.float, device=device), torch.tensor(500.0, dtype=torch.float, device=device))
                        for j, smarts in enumerate(model.gnn_encoder.substructures):
                            if mol.HasSubstructMatch(Chem.MolFromSmarts(smarts)):
                                substructure_targets[i, j] = 1
                        valid_count += 1
                fp_loss = fp_loss / valid_count if valid_count > 0 else torch.tensor(0.0, device=device)
                mw_loss = mw_loss / valid_count if valid_count > 0 else torch.tensor(0.0, device=device)
                sub_loss = sub_criterion(substructure_pred, substructure_targets)
                sigma_smiles = torch.clamp(torch.exp(model.log_sigma_smiles), 0.1, 10.0)
                sigma_fp = torch.clamp(torch.exp(model.log_sigma_fp), 0.1, 10.0)
                sigma_sub = torch.clamp(torch.exp(model.log_sigma_sub), 0.1, 10.0)
                supervised_loss = (smiles_loss / (2 * sigma_smiles**2) + model.log_sigma_smiles) + \
                                 (0.1 * fp_loss / (2 * sigma_fp**2) + model.log_sigma_fp) + \
                                 (0.1 * sub_loss / (2 * sigma_sub**2) + model.log_sigma_sub) + \
                                 0.1 * valence_penalty.mean() + 0.1 * mw_loss
                # RL component: Tanimoto reward
                rl_loss = 0
                if epoch >= 5:  # Start RL after initial training
                    pred_smiles = beam_search(model, spectra[0], graph_data[0], ion_modes[0], precursor_bins[0], adduct_indices[0], raw_smiles[0], beam_width=5, max_len=SUPERVISED_MAX_LEN, device=device)
                    if pred_smiles[0][0] != "Invalid SMILES":
                        tanimoto = tanimoto_similarity(pred_smiles[0][0], raw_smiles[0], all_fingerprints)
                        rl_loss = -torch.log(torch.tensor(tanimoto + 1e-6, device=device))
                loss = supervised_loss + 0.1 * rl_loss
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for spectra, graph_data, smiles_tokens, ion_modes, precursor_bins, adduct_indices, raw_smiles in val_loader:
                spectra = spectra.to(device)
                ion_modes = ion_modes.to(device)
                precursor_bins = precursor_bins.to(device)
                adduct_indices = adduct_indices.to(device)
                smiles_tokens = smiles_tokens.to(device)
                tgt_input = smiles_tokens[:, :-1]
                tgt_output = smiles_tokens[:, 1:]
                tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
                with autocast():
                    smiles_output, fp_output, valence_penalty, _, _, substructure_pred = model(spectra, graph_data, tgt_input, ion_modes, precursor_bins, adduct_indices, tgt_mask)
                    smiles_loss = smiles_criterion(smiles_output.reshape(-1, vocab_size), tgt_output.reshape(-1))
                    fp_loss = 0
                    mw_loss = 0
                    sub_loss = 0
                    valid_count = 0
                    substructure_targets = torch.zeros(len(raw_smiles), 30, dtype=torch.float, device=device)
                    for i, (smiles, fp) in enumerate(zip(raw_smiles, fp_output)):
                        mol = Chem.MolFromSmiles(smiles, sanitize=True)
                        if mol:
                            true_fp = morgan_gen.GetFingerprint(mol)
                            fp_loss += fp_criterion(fp, torch.tensor([int(b) for b in true_fp.ToBitString()], dtype=torch.float, device=device))
                            mw_loss += mw_criterion(torch.tensor(Descriptors.MolWt(mol), dtype=torch.float, device=device), torch.tensor(500.0, dtype=torch.float, device=device))
                            for j, smarts in enumerate(model.gnn_encoder.substructures):
                                if mol.HasSubstructMatch(Chem.MolFromSmarts(smarts)):
                                    substructure_targets[i, j] = 1
                            valid_count += 1
                    fp_loss = fp_loss / valid_count if valid_count > 0 else torch.tensor(0.0, device=device)
                    mw_loss = mw_loss / valid_count if valid_count > 0 else torch.tensor(0.0, device=device)
                    sub_loss = sub_criterion(substructure_pred, substructure_targets)
                    sigma_smiles = torch.clamp(torch.exp(model.log_sigma_smiles), 0.1, 10.0)
                    sigma_fp = torch.clamp(torch.exp(model.log_sigma_fp), 0.1, 10.0)
                    sigma_sub = torch.clamp(torch.exp(model.log_sigma_sub), 0.1, 10.0)
                    loss = (smiles_loss / (2 * sigma_smiles**2) + model.log_sigma_smiles) + \
                           (0.1 * fp_loss / (2 * sigma_fp**2) + model.log_sigma_fp) + \
                           (0.1 * sub_loss / (2 * sigma_sub**2) + model.log_sigma_sub) + \
                           0.1 * valence_penalty.mean() + 0.1 * mw_loss
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
        scheduler.step(avg_val_loss)

        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': avg_val_loss,
                'token_to_idx': token_to_idx,
                'idx_to_token': idx_to_token
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'token_to_idx': token_to_idx,
                'idx_to_token': idx_to_token
            }, 'best_msms_hybrid.pt')
        else:
            no_improve += 1
        if no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return best_val_loss

In [14]:
# SMILES Syntax Validator
def is_valid_smiles_syntax(smiles):
    stack = []
    for c in smiles:
        if c in '([':
            stack.append(c)
        elif c == ')':
            if not stack or stack[-1] != '(':
                return False
            stack.pop()
        elif c == ']':
            if not stack or stack[-1] != '[':
                return False
            stack.pop()
    if stack:
        return False
    i = 0
    while i < len(smiles):
        if smiles[i] == '[':
            j = smiles.find(']', i)
            if j == -1:
                return False
            atom = smiles[i+1:j]
            if not any(a in atom for a in valid_atoms):
                return False
            i = j + 1
        else:
            if smiles[i] in valid_atoms or smiles[i] in '()=#/\\@.:':
                i += 1
            else:
                return False
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        return mol is not None
    except:
        return False

In [15]:
# RDKit-based Molecular Property Filter
def is_plausible_molecule(smiles, true_mol, max_mw=1500, min_logp=-7, max_logp=7):
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    if not mol or not is_valid_smiles_syntax(smiles):
        return False
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    true_mw = Descriptors.MolWt(true_mol) if true_mol else 500
    return mw <= max_mw and min_logp <= logp <= max_logp and abs(mw - true_mw) < 300

# Evaluation Metrics
def dice_similarity(smiles1, smiles2):
    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    if mol1 and mol2:
        fp1 = morgan_gen.GetFingerprint(mol1)
        fp2 = morgan_gen.GetFingerprint(mol2)
        return DataStructs.DiceSimilarity(fp1, fp2)
    return 0.0

def mcs_similarity(true_smiles, pred_smiles):
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if mol1 and mol2:
        mcs = rdFMCS.FindMCS([mol1, mol2], timeout=30)
        return mcs.numAtoms / max(mol1.GetNumAtoms(), mol2.GetNumAtoms())
    return 0.0

def mw_difference(true_smiles, pred_smiles):
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if mol1 and mol2:
        return abs(Descriptors.MolWt(mol1) - Descriptors.MolWt(mol2))
    return float('inf')

def logp_difference(true_smiles, pred_smiles):
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if mol1 and mol2:
        return abs(Descriptors.MolLogP(mol1) - Descriptors.MolLogP(mol2))
    return float('inf')

def substructure_match(true_smiles, pred_smiles, substructures):
    mol1 = Chem.MolFromSmiles(true_smiles)
    mol2 = Chem.MolFromSmiles(pred_smiles)
    if not mol1 or not mol2:
        return 0
    matches = 0
    for smarts in substructures:
        pattern = Chem.MolFromSmarts(smarts)
        if mol1.HasSubstructMatch(pattern) and mol2.HasSubstructMatch(pattern):
            matches += 1
    return matches / len(substructures)

def validity_rate(pred_smiles_list):
    valid = sum(1 for smiles in pred_smiles_list if Chem.MolFromSmiles(smiles, sanitize=True) is not None)
    return valid / len(pred_smiles_list) * 100

def tanimoto_similarity(smiles1, smiles2, precomputed_fps=None):
    mol1 = Chem.MolFromSmiles(smiles1, sanitize=True)
    if not mol1:
        return 0.0
    fp1 = morgan_gen.GetFingerprint(mol1)
    if precomputed_fps and smiles2 in precomputed_fps:
        fp2 = precomputed_fps[smiles2]
    else:
        mol2 = Chem.MolFromSmiles(smiles2, sanitize=True)
        if not mol2:
            return 0.0
        fp2 = morgan_gen.GetFingerprint(mol2)
    return DataStructs.TanimotoSimilarity(fp1, fp2)

def prediction_diversity(pred_smiles_list):
    if len(pred_smiles_list) < 2:
        return 0.0
    total_tanimoto = 0
    count = 0
    for i in range(len(pred_smiles_list)):
        for j in range(i+1, len(pred_smiles_list)):
            total_tanimoto += tanimoto_similarity(pred_smiles_list[i], pred_smiles_list[j])
            count += 1
    return 1 - (total_tanimoto / count) if count > 0 else 0.0

In [16]:
# Beam Search with Stereochemistry
def beam_search(model, spectrum, graph_data, ion_mode_idx, precursor_idx, adduct_idx, true_smiles, beam_width=10, max_len=150, nucleus_p=0.9, device='cpu'):
    model.eval()
    true_mol = Chem.MolFromSmiles(true_smiles) if true_smiles else None
    with torch.no_grad():
        spectrum = spectrum.unsqueeze(0).to(device)
        graph_data = Batch.from_data_list([graph_data]).to(device)
        ion_mode_idx = torch.tensor([ion_mode_idx], dtype=torch.long).to(device)
        precursor_idx = torch.tensor([precursor_idx], dtype=torch.long).to(device)
        adduct_idx = torch.tensor([adduct_idx], dtype=torch.long).to(device)
        memory = model.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)[0]
        gnn_output, substructure_pred, _ = model.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
        memory = model.combine_layer(torch.cat([memory, gnn_output], dim=-1)).unsqueeze(1)
        sequences = [([token_to_idx[SOS_TOKEN]], 0.0)]

        for _ in range(max_len):
            all_candidates = []
            for seq, score in sequences:
                if seq[-1] == token_to_idx[EOS_TOKEN]:
                    all_candidates.append((seq, score))
                    continue
                partial_smiles = ''.join([idx_to_token.get(idx, '') for idx in seq[1:]])
                if not is_valid_smiles_syntax(partial_smiles):
                    continue
                tgt_input = torch.tensor([seq], dtype=torch.long).to(device)
                tgt_mask = model.generate_square_subsequent_mask(len(seq)).to(device)
                outputs, valence_penalty = model.decoder(tgt_input, memory, substructure_pred, tgt_mask)
                log_probs = F.log_softmax(outputs[0, -1], dim=-1).cpu().numpy() - 0.1 * valence_penalty.cpu().numpy()
                # Boost stereochemistry tokens
                for tok in ['@', '/']:
                    if tok in token_to_idx:
                        log_probs[token_to_idx[tok]] += 0.5
                sorted_probs = np.sort(np.exp(log_probs))[::-1]
                cumulative_probs = np.cumsum(sorted_probs)
                cutoff_idx = np.searchsorted(cumulative_probs, nucleus_p)
                top_tokens = np.argsort(log_probs)[-cutoff_idx:] if cutoff_idx > 0 else np.argsort(log_probs)[-1:]
                top_probs = np.exp(log_probs[top_tokens]) / np.sum(np.exp(log_probs[top_tokens]))
                for tok in np.random.choice(top_tokens, size=min(beam_width, len(top_tokens)), p=top_probs):
                    new_smiles = partial_smiles + idx_to_token.get(int(tok), '')
                    if is_valid_smiles_syntax(new_smiles):
                        diversity_penalty = 0.2 * sum(1 for s, _ in sequences if tok in s[1:-1])
                        all_candidates.append((seq + [int(tok)], score + log_probs[tok] - diversity_penalty))
            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if all(seq[-1] == token_to_idx[EOS_TOKEN] for seq, _ in sequences):
                break

        results = []
        for seq, score in sequences:
            smiles = ''.join([idx_to_token.get(idx, '') for idx in seq[1:-1]])
            try:
                mol = Chem.MolFromSmiles(smiles, sanitize=True)
                if mol and is_plausible_molecule(smiles, true_mol):
                    smiles = Chem.MolToSmiles(mol, canonical=True, doRandom=True)
                    confidence = np.exp(score / len(seq))
                    results.append((smiles, confidence))
            except:
                continue
        return results if results else [("Invalid SMILES", 0.0)]


In [17]:
# Visualization Functions
def plot_attention_weights(attn_weights, title="Transformer Attention Weights"):
    plt.figure(figsize=(10, 8))
    plt.imshow(attn_weights.squeeze().cpu().numpy(), cmap='viridis')
    plt.colorbar()
    plt.title(title)
    plt.xlabel("Key Tokens")
    plt.ylabel("Query Tokens")
    plt.show()

def plot_gnn_edge_weights(edge_weights, edge_index, title="GNN Edge Importance"):
    edge_scores = edge_weights[-1].cpu().numpy()
    plt.figure(figsize=(10, 8))
    plt.hist(edge_scores, bins=50)
    plt.title(title)
    plt.xlabel("Edge Weight Magnitude")
    plt.ylabel("Frequency")
    plt.show()

# Error Analysis
def error_analysis(pred_smiles_list, true_smiles_list, adducts, precomputed_fps):
    errors = {'small': 0, 'large': 0, 'aromatic': 0, 'aliphatic': 0}
    adduct_errors = {adduct: [] for adduct in adduct_types}
    for pred_smiles, true_smiles, adduct in zip(pred_smiles_list, true_smiles_list, adducts):
        tanimoto = tanimoto_similarity(pred_smiles, true_smiles, precomputed_fps)
        if tanimoto < 0.3:
            mol = Chem.MolFromSmiles(true_smiles)
            if mol:
                mw = Descriptors.MolWt(mol)
                is_aromatic = any(atom.GetIsAromatic() for atom in mol.GetAtoms())
                errors['small' if mw < 300 else 'large'] += 1
                errors['aromatic' if is_aromatic else 'aliphatic'] += 1
                adduct_errors[adduct].append(tanimoto)
    print("Error Analysis:")
    print(f"Small molecules (<300 Da) errors: {errors['small']}")
    print(f"Large molecules (≥300 Da) errors: {errors['large']}")
    print(f"Aromatic molecule errors: {errors['aromatic']}")
    print(f"Aliphatic molecule errors: {errors['aliphatic']}")
    for adduct, scores in adduct_errors.items():
        if scores:
            print(f"Adduct {adduct} - Avg Tanimoto: {np.mean(scores):.4f}, Count: {len(scores)}")

In [19]:
# Hyperparameter Tuning
def objective(trial, train_data, val_data):
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    train_dataset = MSMSDataset(train_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    val_dataset = MSMSDataset(val_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=2)
    model = MSMS2SmilesHybrid(vocab_size=vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2, fp_size=2048).to(device)
    return supervised_train(model, train_loader, val_loader, epochs=10, lr=lr)

In [20]:
# Cross-Validation and Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

external_dataset = MSMSDataset(df_external, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
external_loader = DataLoader(external_dataset, batch_size=32, num_workers=2)

for fold, (train_idx, val_idx) in enumerate(kf.split(df_massspecgym)):
    print(f"\nFold {fold+1}/5")
    train_data = df_massspecgym.iloc[train_idx]
    val_data = df_massspecgym.iloc[val_idx]
    ssl_data = train_data.sample(frac=0.3, random_state=42)

    train_dataset = MSMSDataset(train_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    val_dataset = MSMSDataset(val_data, max_len=SUPERVISED_MAX_LEN, is_ssl=False)
    ssl_dataset = MSMSDataset(ssl_data, max_len=PRETRAIN_MAX_LEN, is_ssl=True)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=2)
    ssl_loader = DataLoader(ssl_dataset, batch_size=128, shuffle=True, num_workers=2)

    # Hyperparameter tuning
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: objective(trial, train_data, val_data), n_trials=10)
    best_lr = study.best_params['lr']
    print(f"Best learning rate for fold {fold+1}: {best_lr:.6f}")

    # Initialize and train model
    model = MSMS2SmilesHybrid(vocab_size=vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048, dropout=0.2, fp_size=2048).to(device)
    print(f"Starting SSL pretraining for fold {fold+1}...")
    ssl_pretrain(model, ssl_loader, epochs=3, lr=best_lr)
    print(f"Starting supervised training for fold {fold+1}...")
    best_val_loss = supervised_train(model, train_loader, val_loader, epochs=30, lr=best_lr, patience=5)
    fold_results.append(best_val_loss)
    torch.save({
        'model_state_dict': model.state_dict(),
        'token_to_idx': token_to_idx,
        'idx_to_token': idx_to_token
    }, f'best_msms_hybrid_fold_{fold+1}.pt')

print(f"Cross-validation results: {fold_results}")
print(f"Average validation loss: {np.mean(fold_results):.4f}")

KeyError: 'graph_data'

In [None]:


# External Dataset Evaluation
model.eval()
external_metrics = {'tanimoto': [], 'dice': [], 'mcs': [], 'mw_diff': [], 'logp_diff': [], 'substructure': []}
pred_smiles_list = []
true_smiles_list = []
adducts_list = []
num_samples = min(5, len(external_dataset))

for sample_idx in range(num_samples):
    sample_spectrum = external_dataset[sample_idx][0]
    sample_graph = external_dataset[sample_idx][1]
    sample_ion_mode = external_dataset[sample_idx][3]
    sample_precursor_bin = external_dataset[sample_idx][4]
    sample_adduct_idx = external_dataset[sample_idx][5]
    true_smiles = external_dataset[sample_idx][6]

    predicted_results = beam_search(model, sample_spectrum, sample_graph, sample_ion_mode, sample_precursor_bin, sample_adduct_idx, true_smiles, beam_width=10, max_len=SUPERVISED_MAX_LEN, device=device)
    pred_smiles_list.extend([smiles for smiles, _ in predicted_results])
    true_smiles_list.extend([true_smiles] * len(predicted_results))
    adducts_list.extend([df_external.iloc[sample_idx]['adduct']] * len(predicted_results))

    print(f"\nExternal Sample {sample_idx} - True SMILES: {true_smiles}")
    print("Top Predicted SMILES:")
    for smiles, confidence in predicted_results[:3]:
        external_metrics['tanimoto'].append(tanimoto_similarity(smiles, true_smiles, all_fingerprints))
        external_metrics['dice'].append(dice_similarity(smiles, true_smiles))
        external_metrics['mcs'].append(mcs_similarity(smiles, true_smiles))
        external_metrics['mw_diff'].append(mw_difference(smiles, true_smiles))
        external_metrics['logp_diff'].append(logp_difference(smiles, true_smiles))
        external_metrics['substructure'].append(substructure_match(smiles, true_smiles, model.gnn_encoder.substructures))
        print(f"SMILES: {smiles}, Confidence: {confidence:.4f}, Tanimoto: {external_metrics['tanimoto'][-1]:.4f}, Dice: {external_metrics['dice'][-1]:.4f}, MCS: {external_metrics['mcs'][-1]:.4f}")
        if len(smiles) > 100 and smiles.count('C') > len(smiles) * 0.8:
            print("Warning: Predicted SMILES is a long carbon chain, indicating potential model underfitting.")
        if smiles != "Invalid SMILES":
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                print(f"Molecular Weight: {Descriptors.MolWt(mol):.2f}, LogP: {Descriptors.MolLogP(mol):.2f}")

    # Visualize molecules
    if predicted_results[0][0] != "Invalid SMILES":
        pred_mol = Chem.MolFromSmiles(predicted_results[0][0], sanitize=True)
        true_mol = Chem.MolFromSmiles(true_smiles, sanitize=True)
        if pred_mol and true_mol:
            img = Draw.MolsToGridImage([true_mol, pred_mol], molsPerRow=2, subImgSize=(300, 300), legends=['True', 'Predicted'])
            img_array = np.array(img.convert('RGB'))
            plt.figure(figsize=(10, 5))
            plt.imshow(img_array)
            plt.axis('off')
            plt.title(f"External Sample {sample_idx} - Tanimoto: {external_metrics['tanimoto'][0]:.4f}")
            plt.show()

    # Visualize attention and GNN weights for first sample
    if sample_idx == 0:
        with torch.no_grad():
            spectrum = sample_spectrum.unsqueeze(0).to(device)
            graph_data = Batch.from_data_list([sample_graph]).to(device)
            ion_mode_idx = torch.tensor([sample_ion_mode], dtype=torch.long).to(device)
            precursor_idx = torch.tensor([sample_precursor_bin], dtype=torch.long).to(device)
            adduct_idx = torch.tensor([sample_adduct_idx], dtype=torch.long).to(device)
            _, attn_weights = model.transformer_encoder(spectrum, ion_mode_idx, precursor_idx, adduct_idx)
            _, _, edge_weights = model.gnn_encoder(graph_data, ion_mode_idx, precursor_idx, adduct_idx)
            plot_attention_weights(attn_weights, title=f"External Fold Transformer Attention Weights")
            plot_gnn_edge_weights(edge_weights, sample_graph.edge_index, title=f"External Fold GNN Edge Importance")

# Final Evaluation
print(f"External Validity Rate: {validity_rate(pred_smiles_list):.2f}%")
print(f"External Prediction Diversity: {prediction_diversity(pred_smiles_list):.4f}")
print("External Metrics Summary:")
print(f"Avg Tanimoto: {np.mean(external_metrics['tanimoto']):.4f}")
print(f"Avg Dice: {np.mean(external_metrics['dice']):.4f}")
print(f"Avg MCS: {np.mean(external_metrics['mcs']):.4f}")
print(f"Avg MW Difference: {np.mean([x for x in external_metrics['mw_diff'] if x != float('inf')]):.2f}")
print(f"Avg LogP Difference: {np.mean([x for x in external_metrics['logp_diff'] if x != float('inf')]):.2f}")
print(f"Avg Substructure Match: {np.mean(external_metrics['substructure']):.4f}")
error_analysis(pred_smiles_list, true_smiles_list, adducts_list, all_fingerprints)
