<a href="https://colab.research.google.com/github/rachitt-t/AI-in-healthcare/blob/main/lab13_e22cseu0118.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

RACHIT TAYAL E22CSEU0118

In [None]:
 !pip install rdkit-pypi torch torchvision pandas tqdm

[31mERROR: Could not find a version that satisfies the requirement rdkit-pypi (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for rdkit-pypi[0m[31m
[0m

In [None]:
# Full end-to-end SMILES generative pipeline (PyTorch + RDKit)
# Runs on Google Colab or local environment.


# -------------------------
# 0) Install dependencies (Colab)
# -------------------------
!pip install rdkit torch torchvision pandas tqdm


# -------------------------
# 1) Imports
# -------------------------
import re
import random
from collections import Counter
from typing import List

import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# RDKit
from rdkit import Chem




# -------------------------
# 2) Configuration & helper tokens
# -------------------------
PAD_TOKEN = "<pad>"
START_TOKEN = "<s>"
END_TOKEN = "</s>"
UNK_TOKEN = "<unk>"

EMBED_DIM = 128
HIDDEN_DIM = 512
NUM_LAYERS = 2
BATCH_SIZE = 256
LR = 1e-3
EPOCHS = 8
MAX_SEQ_LEN = 120

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -------------------------
# 3) Tokenizer (regex)
# -------------------------
SMILES_TOKENIZER_REGEX = re.compile(
    r"(\%\d{2}|Br|Cl|Si|Se|\[[^\]]+\]|@@?|==|=|#|:|~|\/|\\|\(|\)|\.|\+|\-|\d+|[A-Za-z])"
)

def tokenize_smiles(smiles: str) -> List[str]:
    tokens = SMILES_TOKENIZER_REGEX.findall(smiles)
    return [t for t in tokens if t.strip()]


# -------------------------
# -------------------------
# 4) Load dataset from TXT file (using fixed path)
# -------------------------

TXT_PATH = "/content/test.txt"   # your file path

def load_txt_smiles(path):
    smiles_list = []
    with open(path, "r") as f:
        for line in f:
            s = line.strip()
            if len(s) == 0:
                continue
            if len(s) > MAX_SEQ_LEN - 2:
                continue
            smiles_list.append(s)
    return smiles_list

print("Loading SMILES from:", TXT_PATH)
smiles_all = load_txt_smiles(TXT_PATH)
print("Loaded", len(smiles_all), "SMILES from txt file")



# -------------------------
# 5) Build vocabulary
# -------------------------
counter = Counter()
for s in smiles_all:
    counter.update(tokenize_smiles(s))

tokens = sorted(counter.keys(), key=lambda x: (-counter[x], x))
vocab = [PAD_TOKEN, START_TOKEN, END_TOKEN, UNK_TOKEN] + tokens

token2idx = {t: i for i, t in enumerate(vocab)}
idx2token = {i: t for t, i in token2idx.items()}

vocab_size = len(vocab)
print("Vocab size:", vocab_size)


# -------------------------
# 6) Encode + decode
# -------------------------
def encode(tokens: List[str], token2idx: dict, max_len: int):
    seq = [token2idx[START_TOKEN]]
    for t in tokens:
        seq.append(token2idx.get(t, token2idx[UNK_TOKEN]))
        if len(seq) >= max_len - 1:
            break
    seq.append(token2idx[END_TOKEN])
    if len(seq) < max_len:
        seq = seq + [token2idx[PAD_TOKEN]] * (max_len - len(seq))
    return seq[:max_len]

def decode(indices: List[int], idx2token: dict) -> str:
    result = []
    for idx in indices:
        tok = idx2token[idx]
        if tok == END_TOKEN:
            break
        if tok in (PAD_TOKEN, START_TOKEN):
            continue
        result.append(tok)
    return "".join(result)


# -------------------------
# 7) PyTorch Dataset
# -------------------------
class SmilesDataset(Dataset):
    def __init__(self, smiles_list, token2idx, max_len):
        self.smiles = smiles_list
        self.token2idx = token2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        s = self.smiles[idx]
        toks = tokenize_smiles(s)
        seq = encode(toks, self.token2idx, self.max_len)
        x = torch.tensor(seq[:-1], dtype=torch.long)  # input
        y = torch.tensor(seq[1:], dtype=torch.long)   # next token target
        return x, y

dataset = SmilesDataset(smiles_all, token2idx, MAX_SEQ_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


# -------------------------
# 8) Model (LSTM)
# -------------------------
class SmilesRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.embed(x)
        out, hidden = self.lstm(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

model = SmilesRNN(
    vocab_size=vocab_size,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    pad_idx=token2idx[PAD_TOKEN]
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=token2idx[PAD_TOKEN])

print(model)


# -------------------------
# 9) Train model
# -------------------------
def train_one_epoch(model, dataloader):
    model.train()
    total_loss = 0.0
    pbr = tqdm(dataloader)
    for x, y in pbr:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        logits, _ = model(x)
        loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        total_loss += loss.item()
        pbr.set_postfix(loss=total_loss/(pbr.n+1))
    return total_loss / len(dataloader)

for epoch in range(1, EPOCHS+1):
    loss = train_one_epoch(model, dataloader)
    print(f"Epoch {epoch}/{EPOCHS} \u2014 loss: {loss:.4f}")


# -------------------------
# 10) Generate SMILES
# -------------------------
def sample_next_token(logits, temperature=1.0):
    logits = logits / temperature
    probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
    return np.random.choice(len(probs), p=probs)

def generate_smiles(model, temperature=0.9):
    model.eval()
    idx = token2idx[START_TOKEN]
    seq = [idx]
    hidden = None
    inp = torch.tensor([[idx]], dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        for _ in range(MAX_SEQ_LEN):
            logits, hidden = model(inp, hidden)
            logits = logits[0, -1].cpu().numpy()
            next_idx = sample_next_token(logits, temperature)
            seq.append(next_idx)

            if next_idx == token2idx[END_TOKEN]:
                break

            inp = torch.tensor([[next_idx]], dtype=torch.long).to(DEVICE)

    return decode(seq, idx2token)

print("Example Generated SMILES:")
for _ in range(10):
    print(generate_smiles(model))

Collecting rdkit
  Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.9.1
Loading SMILES from: /content/test.txt
Loaded 176075 SMILES from txt file
Vocab size: 52
SmilesRNN(
  (embed): Embedding(52, 128, padding_idx=0)
  (lstm): LSTM(128, 512, num_layers=2, batch_first=True)
  (fc): Linear(in_features=512, out_features=52, bias=True)
)


100%|██████████| 687/687 [02:15<00:00,  5.05it/s, loss=0.807]


Epoch 1/8 — loss: 0.8070


100%|██████████| 687/687 [02:24<00:00,  4.77it/s, loss=0.568]


Epoch 2/8 — loss: 0.5683


100%|██████████| 687/687 [02:24<00:00,  4.77it/s, loss=0.536]


Epoch 3/8 — loss: 0.5360


100%|██████████| 687/687 [02:24<00:00,  4.76it/s, loss=0.517]


Epoch 4/8 — loss: 0.5172


100%|██████████| 687/687 [02:24<00:00,  4.77it/s, loss=0.504]


Epoch 5/8 — loss: 0.5040


100%|██████████| 687/687 [02:24<00:00,  4.76it/s, loss=0.494]


Epoch 6/8 — loss: 0.4938


100%|██████████| 687/687 [02:24<00:00,  4.76it/s, loss=0.485]


Epoch 7/8 — loss: 0.4855


100%|██████████| 687/687 [02:24<00:00,  4.77it/s, loss=0.478]


Epoch 8/8 — loss: 0.4785
Example Generated SMILES:
Cc1csc(NC(=O)Cn2nnc3c(cnn3C)c2=O)n1test
CCNC(=O)c1cccc(NC(=O)C(C)(C)C)c1Ctest
O=C(Nc1ccc2c(c1)OCO2)c1c[nH]nc1-c1ccc(Cl)cc1test
COc1ccc(NC(=O)c2sc(-n3cccc3)nc2C)cc1test
CCCS(=O)(=O)N1CCC(C(=O)c2ccccc2C)CC1test
CC1CCCCC1NC(=O)c1c[nH]nc1-c1ccccc1test
COCCN(Cc1ccccc1)C(=O)Nc1ccccc1Ntest
CCOC1CCN(C(=O)COc2cccc(Cl)c2C)CC1test
COc1c(F)cc(NCC2CCCN(C(=O)C(C)C)C2)c(Cl)c1test
CCOc1cccc2c1OCC(C(=O)Nc1ccc(C)c(F)c1)C2test
