In [None]:
!pip install torch torchvision torchaudio rdkit-pypi datasets selfies tokenizers tqdm -q

In [None]:
# nature_msms

#SOLVED variable issue

# Import packages
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import pandas as pd
import numpy as np
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from tqdm import tqdm
import math
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

# Set random seed for reproducibility: seeds the torch and NumPy global RNGs.
# Python's `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Special sequence tokens. They are defined before any tokenisation code uses them
# and are placed at the front of the vocabulary, so PAD/SOS/EOS get indices 0/1/2.
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# Load the whole 'train' split of the dataset stored at the Kaggle input path and
# convert it to a pandas DataFrame (the entire split is materialised in memory).
dataset = load_dataset('/kaggle/input/massdata', split='train')  # Use full dataset
df = pd.DataFrame(dataset)

# Inspect dataset: the columns selected below are the ones the rest of the notebook
# relies on; a KeyError here means the dataset schema differs from what is expected.
print("Dataset Columns:", df.columns.tolist())
print("\nFirst few rows of the dataset:")
print(df[['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'precursor_mz']].head())
print("\nUnique adduct values:", df['adduct'].unique())

# Binning spectra
def _as_peak_values(values):
    """Return ``values`` as an iterable of individual peak entries.

    A single string is treated as a comma-separated list (e.g. ``"91.05,125.02"``)
    instead of being iterated character by character; ``None`` means no peaks.
    """
    if values is None:
        return ()
    if isinstance(values, str):
        return values.split(',')
    return values


def bin_spectrum(mzs, intensities, n_bins=1000, max_mz=1000):
    """Bin an MS/MS peak list into a fixed-length, max-normalised intensity vector.

    Peaks with ``0 <= m/z < max_mz`` are summed into ``n_bins`` equal-width bins and
    the vector is divided by its largest value, so the base peak becomes 1.0. An
    all-zero spectrum is returned unchanged.

    Args:
        mzs: peak m/z values - an iterable of numbers / numeric strings, or one
            comma-separated string.
        intensities: peak intensities, in the same form and order as ``mzs``.
        n_bins: number of output bins.
        max_mz: exclusive upper m/z bound; peaks at or above it are dropped.

    Returns:
        ``np.ndarray`` of shape ``(n_bins,)``.
    """
    spectrum = np.zeros(n_bins)
    for mz, intensity in zip(_as_peak_values(mzs), _as_peak_values(intensities)):
        try:
            mz = float(mz)
            intensity = float(intensity)
        except (ValueError, TypeError):
            continue  # skip unparsable peak entries
        # A negative m/z would produce a negative index that silently wraps to the
        # end of the array; NaN/inf fail the range test and are skipped as well.
        if 0 <= mz < max_mz:
            bin_idx = min(int((mz / max_mz) * n_bins), n_bins - 1)  # guard float rounding
            spectrum[bin_idx] += intensity
    if spectrum.max() > 0:
        spectrum = spectrum / spectrum.max()
    return spectrum

# Bin every spectrum into a fixed-length, max-normalised intensity vector.
df['binned'] = df.apply(lambda row: bin_spectrum(row['mzs'], row['intensities']), axis=1)


def _smiles_to_selfies(smiles):
    """Return the SELFIES encoding of ``smiles``, or None when it cannot be used.

    Non-string values (e.g. NaN), SMILES that RDKit cannot parse, and SMILES the
    SELFIES encoder rejects all map to None, so the row is dropped below instead
    of aborting the whole preprocessing run.
    """
    if not isinstance(smiles, str) or Chem.MolFromSmiles(smiles) is None:
        return None
    try:
        return sf.encoder(smiles)
    except sf.EncoderError:
        return None


def _ion_mode_from_adduct(adduct):
    """Map an adduct string to an ion-mode index: 0 = positive, 1 = negative.

    The trailing charge sign decides, so adducts such as ``[M+Cl]-`` (which contain
    a '+' inside the brackets) are classified as negative. Strings without a
    trailing sign fall back to the presence of '+' or '-'; anything else is 0.
    """
    text = str(adduct).strip()
    if text.endswith('-'):
        return 1
    if text.endswith('+'):
        return 0
    return 0 if '+' in text else 1 if '-' in text else 0


# Convert SMILES to SELFIES and drop rows that could not be converted.
df['selfies'] = df['smiles'].apply(_smiles_to_selfies)
df = df.dropna(subset=['selfies'])  # Drop invalid SMILES/SELFIES

# Ion mode from the adduct, and precursor m/z quantile-binned into at most 100
# integer classes (duplicates='drop' can yield fewer). Missing precursor m/z gets
# bin 0 so the column stays integer for the torch.long conversion in the Dataset.
# Bin edges are computed on the full dataset, before the train/val split.
df['ion_mode'] = df['adduct'].apply(_ion_mode_from_adduct)
df['precursor_bin'] = (
    pd.qcut(df['precursor_mz'], q=100, labels=False, duplicates='drop')
    .fillna(0)
    .astype(int)
)

# Train-validation split
df_train, df_val = train_test_split(df, test_size=0.1, random_state=42)

# Verify preprocessing
print("\nFirst few rows of preprocessed data:")
print(df[['identifier', 'binned', 'selfies', 'ion_mode', 'precursor_bin']].head())

# Longest token sequence (SELFIES symbols plus SOS and EOS) over all rows.
max_len = max(len(list(sf.split_selfies(s)) + [SOS_TOKEN, EOS_TOKEN]) for s in df['selfies'])
print(f"Maximum SELFIES length: {max_len}")

# Tokenization
# Build the SELFIES vocabulary from the training split only. Special tokens come
# first (indices 0/1/2); the remaining symbols follow in sorted order.
all_selfies = df_train['selfies'].tolist()
unique_tokens = {
    symbol
    for selfies_string in all_selfies
    for symbol in sf.split_selfies(selfies_string)
}

special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]
tokens = special_tokens + sorted(unique_tokens)
token_to_idx = {symbol: index for index, symbol in enumerate(tokens)}
idx_to_token = {index: symbol for symbol, index in token_to_idx.items()}
vocab_size = len(tokens)
# Fixed sequence length: at least 150, and never shorter than the longest sequence.
MAX_LEN = max(max_len, 150)

def encode_selfies(sf_string):
    """Encode a SELFIES string as a fixed-length list of vocabulary indices.

    The symbols are wrapped in SOS/EOS. Symbols missing from ``token_to_idx`` are
    silently skipped. The list is then truncated to ``MAX_LEN``, or right-padded
    with the PAD index up to ``MAX_LEN``. Truncation can remove the EOS marker.
    """
    symbols = [SOS_TOKEN, *sf.split_selfies(sf_string), EOS_TOKEN]
    ids = [token_to_idx[symbol] for symbol in symbols if symbol in token_to_idx][:MAX_LEN]
    padding = [token_to_idx[PAD_TOKEN]] * (MAX_LEN - len(ids))
    return ids + padding

# Dataset class
class MSMSDataset(Dataset):
    """Map-style dataset of (spectrum, token_ids, ion_mode, precursor_bin) tensors.

    Args:
        dataframe: frame with 'binned', 'selfies', 'ion_mode' and 'precursor_bin'
            columns. SELFIES strings are tokenised once, up front.
    """

    def __init__(self, dataframe):
        self.spectra = np.stack(dataframe['binned'].values)
        self.selfies = list(map(encode_selfies, dataframe['selfies']))
        self.ion_modes = dataframe['ion_mode'].values
        self.precursor_bins = dataframe['precursor_bin'].values

    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        # Spectrum as float32; every index-like field as int64 (torch.long).
        spectrum = torch.tensor(self.spectra[idx], dtype=torch.float)
        token_ids = torch.tensor(self.selfies[idx], dtype=torch.long)
        ion_mode = torch.tensor(self.ion_modes[idx], dtype=torch.long)
        precursor_bin = torch.tensor(self.precursor_bins[idx], dtype=torch.long)
        return spectrum, token_ids, ion_mode, precursor_bin

# Wrap both splits. Batches are (spectrum, token_ids, ion_mode, precursor_bin) tuples
# of 64 examples; only the training loader is shuffled, and each loader uses two
# worker processes for loading.
train_dataset = MSMSDataset(df_train)
val_dataset = MSMSDataset(df_val)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=2)

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence tensor.

    Even channels hold ``sin(pos * w_k)`` and odd channels hold ``cos(pos * w_k)``.
    An odd ``d_model`` is supported; the last sine channel has no cosine partner.

    Args:
        d_model: embedding width.
        max_len: longest sequence length covered by the precomputed table.

    ``forward`` expects ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not a parameter): follows .to(device) and is saved in the
        # state_dict under 'pe', but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

# Metadata Embedding
class SpectrumMetadataEmbedding(nn.Module):
    """Embed an (ion mode, precursor bin) index pair into one ``emb_dim`` vector.

    Each index has its own lookup table. The two vectors are concatenated and
    projected back to ``emb_dim`` by a linear layer.
    """

    def __init__(self, emb_dim=64, ion_mode_dim=2, precursor_bins=100):
        super().__init__()
        # Creation order is intentionally fixed: it determines the RNG draws used
        # for parameter initialisation and the state_dict key order.
        self.ion_emb = nn.Embedding(ion_mode_dim, emb_dim)
        self.prec_emb = nn.Embedding(precursor_bins, emb_dim)
        self.linear = nn.Linear(emb_dim * 2, emb_dim)

    def forward(self, ion_mode_idx, precursor_idx):
        parts = (self.ion_emb(ion_mode_idx), self.prec_emb(precursor_idx))
        return self.linear(torch.cat(parts, dim=-1))

# Transformer Encoder
class SpectrumTransformerEncoder(nn.Module):
    """Encode a binned spectrum plus metadata into a single ``d_model`` vector.

    The binned spectrum is linearly projected to one ``d_model`` token, passed
    through a TransformerEncoder stack, concatenated with a 64-dim metadata
    embedding (ion mode, precursor bin), and projected back to ``d_model``.

    NOTE(review): the spectrum becomes a sequence of length 1, so each
    self-attention layer only attends to that single position and the stack
    behaves like a deep feed-forward network. Splitting the spectrum into several
    tokens (e.g. peaks or bin patches) would let attention relate m/z regions.
    """

    def __init__(self, input_dim=1000, d_model=512, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        # Submodule creation order fixes the parameter-initialisation RNG draws and
        # the state_dict key order, so saved checkpoints depend on it.
        self.input_proj = nn.Linear(input_dim, d_model)
        self.metadata_emb = SpectrumMetadataEmbedding(emb_dim=64)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # The extra 64 inputs must match the metadata embedding's emb_dim above.
        self.fc = nn.Linear(d_model + 64, d_model)

    def forward(self, src, ion_mode_idx, precursor_idx):
        """Return a ``(batch_size, d_model)`` embedding.

        Assumes ``src`` is ``(batch_size, input_dim)`` and both index tensors are
        ``(batch_size,)`` - TODO confirm against callers.
        """
        src = self.input_proj(src).unsqueeze(1)  # Shape: (batch_size, 1, d_model)
        metadata = self.metadata_emb(ion_mode_idx, precursor_idx)  # Shape: (batch_size, emb_dim)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src).squeeze(1)  # Shape: (batch_size, d_model)
        output = torch.cat([output, metadata], dim=-1)  # Concat metadata
        output = self.fc(output)  # Shape: (batch_size, d_model)
        return output

# Transformer Decoder
class SelfiesTransformerDecoder(nn.Module):
    """Autoregressive SELFIES token decoder conditioned on encoder memory.

    Token ids are embedded, scaled by sqrt(d_model), given sinusoidal position
    encodings, run through TransformerDecoder layers (self-attention over the
    target, which is causal when the caller passes a subsequent mask as
    ``tgt_mask``, plus cross-attention to ``memory``), and projected to
    vocabulary logits.

    NOTE(review): no ``tgt_key_padding_mask`` is applied. This is harmless only
    if targets are right-padded and a causal mask is used, because real tokens
    then never attend to later PAD positions - confirm with the data pipeline.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        # Creation order affects initialisation RNG draws and state_dict key order.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, tgt, memory, tgt_mask=None, memory_key_padding_mask=None):
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        ``tgt`` is a ``(batch, tgt_len)`` tensor of token ids; ``memory`` is
        batch-first, ``(batch, mem_len, d_model)``.
        """
        # Scale embeddings by sqrt(d_model) before adding position encodings.
        embedded = self.embedding(tgt) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        output = self.transformer_decoder(embedded, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)
        return self.output_layer(output)

# Full Transformer Model
class MSMS2SelfiesTransformer(nn.Module):
    """Spectrum-to-SELFIES encoder-decoder model.

    The encoder turns (spectrum, ion mode, precursor bin) into one ``d_model``
    vector. The decoder cross-attends to that vector as a length-1 memory while
    predicting SELFIES token logits.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        # The encoder input width is fixed at 1000, so spectra must be binned into
        # exactly 1000 bins.
        self.encoder = SpectrumTransformerEncoder(input_dim=1000, d_model=d_model, nhead=nhead, num_layers=num_layers, dim_feedforward=dim_feedforward, dropout=dropout)
        self.decoder = SelfiesTransformerDecoder(vocab_size, d_model, nhead, num_layers, dim_feedforward, dropout)

    def generate_square_subsequent_mask(self, tgt_len):
        """Return a ``(tgt_len, tgt_len)`` float causal mask.

        Entries are 0.0 on and below the diagonal and -inf above it, so position i
        can only attend to positions <= i. The mask is created on the CPU; callers
        move it to their device.
        """
        mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)
        mask = mask.float().masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
        return mask

    def forward(self, src, tgt, ion_mode_idx, precursor_idx, tgt_mask=None, memory_key_padding_mask=None):
        """Return decoder logits ``(batch_size, tgt_len, vocab_size)`` for teacher-forced ``tgt``."""
        memory = self.encoder(src, ion_mode_idx, precursor_idx).unsqueeze(1)  # Shape: (batch_size, 1, d_model)
        output = self.decoder(tgt, memory, tgt_mask, memory_key_padding_mask)
        return output

# SSL Pretraining for Encoder
def mask_spectrum(spectrum, mask_ratio=0.15):
    """Return a copy of a 1-D spectrum with a random subset of its bins set to 0.

    Exactly ``int(mask_ratio * len(spectrum))`` distinct bins, chosen with
    ``torch.randperm``, are zeroed. The input tensor is left untouched.
    """
    length = spectrum.size(0)
    n_mask = int(mask_ratio * length)
    masked = spectrum.clone()
    masked[torch.randperm(length)[:n_mask]] = 0
    return masked

def ssl_pretrain_encoder(encoder, dataloader, epochs=3, lr=1e-4, mask_ratio=0.15, device=None):
    """Self-supervised masked-spectrum pretraining of ``encoder``.

    For each spectrum, a random ``mask_ratio`` fraction of bins is zeroed with
    ``mask_spectrum``. The encoder embeds the masked spectrum, and a temporary
    linear head must reconstruct the original binned spectrum from that
    embedding (MSE). Encoder and head are trained jointly with Adam, and the
    head is discarded afterwards.

    Args:
        encoder: module called as ``encoder(spectra, ion_modes, precursor_bins)``
            that returns a ``(batch, d)`` embedding.
        dataloader: iterable of ``(spectra, tokens, ion_modes, precursor_bins)``
            batches; the token tensor is ignored.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        mask_ratio: fraction of bins zeroed per spectrum.
        device: device batches are moved to; defaults to the device of the
            encoder's parameters.

    Returns:
        List with the mean reconstruction loss of each epoch.
    """
    if device is None:
        device = next(encoder.parameters()).device
    encoder.train()
    criterion = nn.MSELoss()
    head = None  # built on the first batch, once embedding/spectrum sizes are known
    optimizer = None
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for spectra, _, ion_modes, precursor_bins in tqdm(dataloader, desc=f"SSL Epoch {epoch+1}/{epochs}"):
            spectra = spectra.to(device)
            ion_modes = ion_modes.to(device)
            precursor_bins = precursor_bins.to(device)
            masked_spectra = torch.stack([mask_spectrum(s, mask_ratio=mask_ratio) for s in spectra])
            embedding = encoder(masked_spectra, ion_modes, precursor_bins)
            if head is None:
                head = nn.Linear(embedding.size(-1), spectra.size(-1)).to(device)
                optimizer = optim.Adam(list(encoder.parameters()) + list(head.parameters()), lr=lr)
            optimizer.zero_grad()
            # The target is the raw input spectrum. Matching the output to another
            # encoder output (with gradients through both) is minimised by a
            # constant encoder, i.e. the representation collapses.
            loss = criterion(head(embedding), spectra)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        avg_loss = total_loss / max(n_batches, 1)
        epoch_losses.append(avg_loss)
        print(f"SSL Epoch {epoch+1}/{epochs} - Reconstruction Loss: {avg_loss:.4f}")
    return epoch_losses

# Supervised Training with Early Stopping
def supervised_train(model, train_loader, val_loader, epochs=10, lr=1e-4, patience=3,
                     device=None, checkpoint_path='best_msms_transformer.pt'):
    """Train ``model`` with teacher forcing and early stopping on validation loss.

    Each batch's token sequence is shifted by one: the model sees tokens
    ``[:-1]`` under a causal mask and is scored against tokens ``[1:]`` with
    cross-entropy that ignores PAD. After every epoch the mean validation loss is
    computed. On improvement, a checkpoint (weights plus vocabulary maps) is
    written to ``checkpoint_path``. Training stops after ``patience`` epochs
    without improvement. Finally, the best epoch's weights are restored into
    ``model``, so the caller continues with the checkpointed model rather than
    the last epoch's weights.

    Args:
        device: device batches are moved to; defaults to the device of the
            model's parameters.
        checkpoint_path: file the best checkpoint is saved to.

    Returns:
        Best mean validation loss (``inf`` if no epoch produced one).
    """
    if device is None:
        device = next(model.parameters()).device
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN])
    best_val_loss = float('inf')
    best_state = None
    no_improve = 0

    def _batch_loss(batch):
        """Teacher-forced cross-entropy for one (spectra, tokens, ion, prec) batch."""
        spectra, selfies_tokens, ion_modes, precursor_bins = (t.to(device) for t in batch)
        tgt_input = selfies_tokens[:, :-1]   # the model sees tokens 0..L-2
        tgt_output = selfies_tokens[:, 1:]   # and predicts tokens 1..L-1
        tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
        # No memory padding mask: the encoder memory has a single position.
        output = model(spectra, tgt_input, ion_modes, precursor_bins, tgt_mask, None)
        return criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        n_train = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            loss = _batch_loss(batch)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            n_train += 1
        avg_train_loss = total_train_loss / max(n_train, 1)

        # Validation
        model.eval()
        total_val_loss = 0.0
        n_val = 0
        with torch.no_grad():
            for batch in val_loader:
                total_val_loss += _batch_loss(batch).item()
                n_val += 1
        avg_val_loss = total_val_loss / n_val if n_val else float('inf')

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve = 0
            # Keep a CPU copy so the best weights can be restored without re-reading the file.
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            torch.save({
                'model_state_dict': model.state_dict(),
                'token_to_idx': token_to_idx,
                'idx_to_token': idx_to_token
            }, checkpoint_path)
        else:
            no_improve += 1
        if no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val_loss

# Beam Search Inference with Diversity Penalty
def beam_search(model, spectrum, ion_mode_idx, precursor_idx, beam_width=10, max_len=150, device='cpu'):
    """Decode candidate SMILES for a single spectrum with beam search.

    At every step each unfinished beam is extended with its ``beam_width`` most
    likely next tokens. A candidate's score is the sum of its token
    log-probabilities, minus 0.1 for every current beam that already contains
    the new token (ignoring each beam's first and last position). The best
    ``beam_width`` candidates survive, and beams that emitted EOS are carried
    over unchanged. The search stops when all beams have ended or after
    ``max_len`` steps.

    Args:
        model: model exposing ``encoder``, ``decoder`` and
            ``generate_square_subsequent_mask``.
        spectrum: binned spectrum tensor without a batch dimension (one is added here).
        ion_mode_idx, precursor_idx: integer metadata indices for this spectrum.
        beam_width: beams kept per step, and tokens tried per beam.
        max_len: maximum number of decoding steps.
        device: device the inputs are moved to; should match the model's.

    Returns:
        List of ``(smiles, confidence)`` for beams that decode to a non-empty,
        RDKit-parsable molecule, sorted by decreasing confidence, where
        ``confidence = exp(score / len(token_sequence))``. Returns
        ``[("Invalid SMILES", 0.0)]`` if no beam decodes.
    """
    sos_idx = token_to_idx[SOS_TOKEN]
    eos_idx = token_to_idx[EOS_TOKEN]

    model.eval()
    with torch.no_grad():
        spectrum = spectrum.unsqueeze(0).to(device)
        ion_mode_idx = torch.tensor([ion_mode_idx], dtype=torch.long).to(device)
        precursor_idx = torch.tensor([precursor_idx], dtype=torch.long).to(device)
        memory = model.encoder(spectrum, ion_mode_idx, precursor_idx).unsqueeze(1)  # Shape: (1, 1, d_model)
        sequences = [([sos_idx], 0.0)]  # (token ids, cumulative score)

        for _ in range(max_len):
            all_candidates = []
            for seq, score in sequences:
                if seq[-1] == eos_idx:
                    all_candidates.append((seq, score))  # finished beams carry over
                    continue
                tgt_input = torch.tensor([seq], dtype=torch.long).to(device)
                tgt_mask = model.generate_square_subsequent_mask(len(seq)).to(device)
                outputs = model.decoder(tgt_input, memory, tgt_mask)
                log_probs = F.log_softmax(outputs[0, -1], dim=-1).cpu().numpy()
                for tok in np.argsort(log_probs)[-beam_width:]:
                    tok = int(tok)  # keep token lists as plain Python ints
                    diversity_penalty = 0.1 * sum(1 for s, _ in sequences if tok in s[1:-1])
                    all_candidates.append((seq + [tok], score + float(log_probs[tok]) - diversity_penalty))
            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if all(seq[-1] == eos_idx for seq, _ in sequences):
                break

    results = []
    for seq, score in sequences:
        # Strip SOS, and strip EOS only if the beam actually ended with it: beams
        # cut off by max_len have no EOS, and [1:-1] would drop a real token.
        body = seq[1:-1] if seq[-1] == eos_idx else seq[1:]
        sf_str = ''.join(idx_to_token.get(idx, '') for idx in body)
        try:
            smiles = sf.decoder(sf_str)
        except Exception:  # best effort: skip beams the SELFIES decoder rejects
            continue
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        # An empty SELFIES string decodes to '', which RDKit accepts as a
        # zero-atom molecule; that is not a prediction.
        if mol is None or mol.GetNumAtoms() == 0:
            continue
        confidence = np.exp(score / len(seq))  # length-normalised score
        results.append((smiles, confidence))
    results.sort(key=lambda item: item[1], reverse=True)
    return results if results else [("Invalid SMILES", 0.0)]

# Tanimoto Similarity
def tanimoto_similarity(smiles1, smiles2, radius=2, n_bits=2048):
    """Return the Tanimoto similarity of two molecules' Morgan bit-vector fingerprints.

    Args:
        smiles1, smiles2: SMILES strings. Non-string values (None, NaN) and
            strings RDKit cannot parse give 0.0 instead of raising.
        radius: Morgan fingerprint radius.
        n_bits: fingerprint length in bits.

    Returns:
        Similarity in [0, 1], or 0.0 when either input is unusable.
    """
    if not isinstance(smiles1, str) or not isinstance(smiles2, str):
        return 0.0
    # DataStructs is a top-level rdkit module; import it from there rather than
    # relying on it being reachable as an attribute of rdkit.Chem.
    from rdkit import DataStructs

    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    if mol1 is None or mol2 is None:
        return 0.0
    fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, radius, n_bits)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, radius, n_bits)
    return DataStructs.TanimotoSimilarity(fp1, fp2)

# Initialize and Train Model
# Use the GPU when available. The model is sized by the vocabulary built from the
# training split.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MSMS2SelfiesTransformer(vocab_size=vocab_size, d_model=512, num_layers=6, dim_feedforward=1024, dropout=0.2).to(device)

# SSL Pretraining: trains only model.encoder, on the training loader.
print("Starting SSL pretraining...")
ssl_pretrain_encoder(model.encoder, train_loader, epochs=3)

# Supervised Training: early stopping on validation loss; the best epoch's weights
# are written to 'best_msms_transformer.pt' by supervised_train.
print("Starting supervised training...")
best_val_loss = supervised_train(model, train_loader, val_loader, epochs=10, patience=3)

print(f"Training complete. Best validation loss: {best_val_loss:.4f}")
print("Model saved as 'best_msms_transformer.pt'")

# Inference and Visualization
# Decode one validation example with beam search and compare each of the top
# candidates with the ground-truth SMILES.
sample_idx = 0
sample_spectrum = torch.tensor(df_val['binned'].iloc[sample_idx], dtype=torch.float)
sample_ion_mode = df_val['ion_mode'].iloc[sample_idx]
sample_precursor_bin = df_val['precursor_bin'].iloc[sample_idx]
true_smiles = df_val['smiles'].iloc[sample_idx]

predicted_results = beam_search(model, sample_spectrum, sample_ion_mode, sample_precursor_bin, beam_width=10, device=device)
print(f"True SMILES: {true_smiles}")
print("Top Predicted SMILES:")
for smiles, confidence in predicted_results[:3]:
    print(f"SMILES: {smiles}, Confidence: {confidence:.4f}")
    similarity = tanimoto_similarity(true_smiles, smiles)
    print(f"Tanimoto Similarity: {similarity:.4f}")
    # Heuristic: a long SMILES (>100 chars) made mostly of uppercase 'C' is flagged
    # as a sign of an under-trained model.
    if len(smiles) > 100 and smiles.count('C') > len(smiles) * 0.8:
        print("Warning: Predicted SMILES is a long carbon chain, indicating potential model underfitting.")

# Visualize molecules
# Draw the true and the first predicted molecule side by side, skipping the
# "Invalid SMILES" placeholder and anything RDKit cannot parse.
if predicted_results[0][0] != "Invalid SMILES":
    pred_mol = Chem.MolFromSmiles(predicted_results[0][0])
    true_mol = Chem.MolFromSmiles(true_smiles)
    if pred_mol and true_mol:
        img = Draw.MolsToGridImage([true_mol, pred_mol], molsPerRow=2, subImgSize=(300, 300), legends=['True', 'Predicted'])
        img_array = np.array(img.convert('RGB'))  # Convert PIL Image to RGB and then to NumPy array
        plt.figure(figsize=(10, 5))
        plt.imshow(img_array)
        plt.axis('off')
        plt.show()
