In [11]:
import pandas as pd
import selfies as sf
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from rdkit import Chem
from collections import Counter
import seaborn as sns
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import csv
import selfies as sf
from collections import Counter
import torch.nn as nn
import csv
import os

### Create Embedding for SMILES
- **Input:** SMILES string, converted to SELFIES and tokenized  
- **Encoder:** Compresses the token sequence into a low-dimensional latent vector  
- **Decoder:** Reconstructs the original sequence from the latent vector  
- **Goal:** Learn a compact, invertible representation of SMILES


In [20]:
def tokenize_smiles(smiles, token2idx, max_len):
    """Turn a SMILES string into a list of exactly ``max_len`` vocabulary ids.

    The SMILES is encoded with ``sf.encoder`` and split into symbols with
    ``sf.split_selfies``. Each symbol is looked up in ``token2idx``; symbols
    missing from the vocabulary map to the ``'[UNK]'`` id. Short sequences are
    right-padded with the ``'[PAD]'`` id and long ones are cut at ``max_len``.

    Args:
        smiles: SMILES string to tokenize.
        token2idx: Mapping from symbol to integer id; must contain ``'[UNK]'``
            and ``'[PAD]'``.
        max_len: Length of the returned list.

    Returns:
        A list of integer ids of length ``max_len``.
    """
    encoded = sf.encoder(smiles)

    ids = []
    for symbol in sf.split_selfies(encoded):
        ids.append(token2idx.get(symbol, token2idx['[UNK]']))

    # Right-pad up to max_len; the final slice handles over-long inputs.
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids.extend([token2idx['[PAD]']] * shortfall)
    return ids[:max_len]

class SMILESAutoencoder(nn.Module):
    """Fully connected autoencoder over fixed-length token-id sequences.

    The token ids are embedded, flattened, and compressed by an MLP encoder to
    a ``latent_dim`` vector. An MLP decoder expands the latent back to one
    embedding-sized vector per position, and a shared linear layer maps each
    of those to vocabulary logits.

    Args:
        vocab_size: Number of entries in the token vocabulary.
        embedding_dim: Size of each token embedding.
        hidden_dim: Width of the hidden layer in the encoder and decoder.
        latent_dim: Size of the latent vector.
        max_len: Sequence length the model expects; inputs must be padded or
            truncated to exactly this length.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=256, latent_dim=64, max_len=1801):
        super().__init__()
        self.max_len = max_len
        self.latent_dim = latent_dim

        # Index 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.encoder = nn.Sequential(
            nn.Linear(max_len * embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, max_len * embedding_dim)
        )

        # Applied position-wise to the decoded embeddings.
        self.output = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Encode and reconstruct a batch of token-id sequences.

        Args:
            x: Integer tensor of shape (B, max_len).

        Returns:
            A tuple ``(logits, latent_with_stats)``: ``logits`` has shape
            (B, max_len, vocab_size); ``latent_with_stats`` has shape
            (B, latent_dim + 2) and holds the per-sample standardized latent
            followed by that sample's mean and standard deviation.
        """
        batch_size = x.size(0)
        embedded = self.embedding(x)  # (B, L, E)
        flat = embedded.view(batch_size, -1)
        latent = self.encoder(flat)  # (B, latent_dim)

        # Standardize each latent across its own dimensions and keep mu/std
        # alongside so the raw latent can be recovered later.
        mu = latent.mean(dim=1, keepdim=True)
        std = latent.std(dim=1, keepdim=True) + 1e-8  # epsilon avoids division by zero
        norm_latent = (latent - mu) / std
        latent_with_stats = torch.cat([norm_latent, mu, std], dim=1)  # (B, latent_dim + 2)

        # Reconstruction uses the raw (unnormalized) latent.
        decoded_flat = self.decoder(latent)
        decoded_emb = decoded_flat.view(batch_size, self.max_len, -1)
        logits = self.output(decoded_emb)  # (B, L, vocab_size)

        return logits, latent_with_stats

    def decode_from_latent(self, latent_with_stats, idx2token):
        """Decode one standardized latent vector back to a molecule string.

        Runs in eval mode without gradient tracking, then restores whatever
        train/eval mode the module was in, so calling this in the middle of a
        training loop does not leave the model stuck in eval mode.

        Args:
            latent_with_stats: 1-D tensor of length ``latent_dim + 2``, i.e. a
                single row of the second value returned by ``forward``.
            idx2token: Mapping from vocabulary id to token string.

        Returns:
            The result of ``sf.decoder`` applied to the predicted SELFIES
            string (argmax token per position, ``'[PAD]'`` tokens removed).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Undo the per-sample standardization applied in forward().
                norm_part = latent_with_stats[:-2]
                mu = latent_with_stats[-2]
                std = latent_with_stats[-1]
                latent = norm_part * std + mu

                decoded_flat = self.decoder(latent.unsqueeze(0))
                decoded_emb = decoded_flat.view(1, self.max_len, -1)
                logits = self.output(decoded_emb)
                token_ids = logits.argmax(-1).squeeze(0).tolist()
                tokens = [idx2token[i] for i in token_ids if idx2token[i] != '[PAD]']
                selfies_str = ''.join(tokens)
                return sf.decoder(selfies_str)
        finally:
            # Restore the caller's mode (train or eval).
            self.train(was_training)

# --- Build vocabulary from dataset ---
# Count every SELFIES symbol that occurs in the first column of the CSV.
token_counter = Counter()
# newline="" is the mode the csv module documents for files passed to csv.reader.
with open("computed_spectra.csv", "r", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header line so the column name is never tokenized
    for row in reader:
        if not row:
            continue  # blank line: csv.reader yields an empty list
        smiles = row[0].strip()
        try:
            selfies_str = sf.encoder(smiles)
            tokens = list(sf.split_selfies(selfies_str))
        except Exception:
            # Best effort: rows that cannot be converted to SELFIES are left
            # out of the vocabulary. Ctrl-C still interrupts the loop.
            continue
        token_counter.update(tokens)

# Special tokens come first so that '[PAD]' gets id 0.
special_tokens = ['[PAD]', '[SOS]', '[EOS]', '[UNK]']
unique_tokens = sorted(token_counter.keys())
vocab = special_tokens + unique_tokens
token2idx = {tok: i for i, tok in enumerate(vocab)}
idx2token = {i: tok for tok, i in token2idx.items()}

# --- Tokenize and pad SMILES ---
# Fixed sequence length; it must equal the max_len the model is built with
# (SMILESAutoencoder defaults to 1801).
max_len = 1801

smiles_list = []
padded_sequences = []

with open("computed_spectra.csv", "r", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header line (no error on an empty file)
    for row in reader:
        if not row:
            continue  # blank line: csv.reader yields an empty list
        smiles = row[0].strip()
        try:
            token_ids = tokenize_smiles(smiles, token2idx, max_len=max_len)
        except Exception:
            # Best effort: rows that cannot be tokenized are dropped, keeping
            # smiles_list and padded_sequences aligned.
            continue
        padded_sequences.append(token_ids)
        smiles_list.append(smiles)

X_tensor = torch.tensor(padded_sequences, dtype=torch.long)

# --- Initialize model ---
# Seed torch so weight initialization and batch shuffling are reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

vocab_size = len(token2idx)
# Pass max_len explicitly so the model always matches the padded data length
# instead of relying on the constructor default.
model = SMILESAutoencoder(vocab_size, max_len=max_len)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Padding positions are deliberately included in the loss: decoding relies on
# the model predicting '[PAD]' to mark where a sequence ends.
loss_fn = nn.CrossEntropyLoss()

# --- Dataloader ---
from torch.utils.data import TensorDataset, DataLoader
# Autoencoder: the reconstruction target is the input itself.
dataset = TensorDataset(X_tensor, X_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# --- Training loop with periodic reconstruction checks ---
NUM_EPOCHS = 100
EVAL_EVERY = 5   # run the reconstruction check every N epochs
N_PREVIEW = 5    # number of molecules (from the top of the dataset) to check


def _canonical_smiles(smiles):
    """Return RDKit's canonical SMILES, or the input unchanged if it cannot be parsed."""
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    return Chem.MolToSmiles(mol) if mol is not None else smiles


def report_reconstructions(model, smiles_batch, token2idx, idx2token, max_len):
    """Round-trip each SMILES through the autoencoder and print the outcome.

    A reconstruction counts as a match when it denotes the same molecule as
    the input, judged by comparing canonical SMILES. A raw string comparison
    would reject correct reconstructions that are merely written differently
    (e.g. kekulized ``C=CC=C`` versus aromatic ``cccc``).

    The model is evaluated in eval mode without gradient tracking; its
    previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        for test_smiles in smiles_batch:
            try:
                token_ids = tokenize_smiles(test_smiles, token2idx, max_len=max_len)
                input_tensor = torch.tensor([token_ids], dtype=torch.long)
                with torch.no_grad():
                    _, latent_with_stats = model(input_tensor)
                reconstructed_smiles = model.decode_from_latent(latent_with_stats.squeeze(0), idx2token)
                same_molecule = _canonical_smiles(reconstructed_smiles) == _canonical_smiles(test_smiles)
                match = "✅" if same_molecule else "❌"
                print(f"{match} {test_smiles} → {reconstructed_smiles}")
            except Exception as e:
                print(f"⚠️ Skipped {test_smiles}: {e}")
    finally:
        model.train(was_training)


for epoch in range(NUM_EPOCHS):
    # Re-enter train mode every epoch; the periodic check below runs in eval mode.
    model.train()
    total_loss = 0
    for x_batch, y_batch in dataloader:
        optimizer.zero_grad()
        logits, _ = model(x_batch)
        # Flatten (B, L, V) logits and (B, L) targets to per-token rows.
        loss = loss_fn(logits.view(-1, vocab_size), y_batch.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the sum of per-batch mean losses over the epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    if (epoch + 1) % EVAL_EVERY == 0:
        print(f"\n[Epoch {epoch+1}] SMILES Reconstruction Test:")
        report_reconstructions(model, smiles_list[:N_PREVIEW], token2idx, idx2token, max_len)

# --- Final test on top 5 molecules from your dataset ---
print("\n=== Final SMILES Reconstructions ===")
model.eval()  # training is finished; leave the model in eval mode
report_reconstructions(model, smiles_list[:N_PREVIEW], token2idx, idx2token, max_len)


Epoch 1, Loss: 23.6473
Epoch 2, Loss: 5.8407
Epoch 3, Loss: 5.4546
Epoch 4, Loss: 5.1361
Epoch 5, Loss: 4.8057

[Epoch 5] SMILES Reconstruction Test:
❌ CN(c1ccc(NC(=O)Nc2ccccc2)cc1)S(=O)(=O)c1ccc(-c2ccn(CCO)n2)s1 → C=CC=CC=CC=C(CCNC=CC=CC=CC=CC=C)CCCCCC1=CC=CCCCCCCCC=2CC=21
❌ CC(O)CC(C)C#COC#CC(C)CC(C)O → CCCCCCCCCCCCCCCCCCC=C
❌ Cc1cc([N+](=O)[O-])ccc1NC(=O)c1ccc(OCC(C)C)c(Br)c1 → CC=CC=CCCNC=CC=CC=CCCNC=CC=CC=CC=CC=CC=CCCCC=C=CC#C
❌ COC12C(COC(N)=O)C3=C(C(=O)C(C)=C(N)C3=O)N1CC1NC12 → CCCCCNC=CC(C)C=CC=CC=CC=CC1CCCCCCCCCCCCCCC2CC21
❌ CC(C)(O)C(NC(=O)c1cnn2cc(C3CC3)cnc12)c1ccc(OC(F)(F)F)c(F)c1 → CCCCCCC=CC(N)C(C)C=CCCCCCCCCCCCCC=CC=CC=CC=CCCCCCCCC=C1CC1=CC
Epoch 6, Loss: 4.5090
Epoch 7, Loss: 4.2149
Epoch 8, Loss: 3.9452
Epoch 9, Loss: 3.6833
Epoch 10, Loss: 3.3847

[Epoch 10] SMILES Reconstruction Test:
❌ CN(c1ccc(NC(=O)Nc2ccccc2)cc1)S(=O)(=O)c1ccc(-c2ccn(CCO)n2)s1 → CCC=CC1=CC(=O)NC=CC=CC=CC1=C
❌ CC(O)CC(C)C#COC#CC(C)CC(C)O → CC(C)CCCCCCCCC(=O)CCCC(C)O
❌ Cc1cc([N+](=O)[O-])ccc1NC(=O)c