<a href="https://colab.research.google.com/github/yamenfargaly-maker/TRANSFORMER/blob/main/MOLECULAR_TRANSER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Cell 1 — Code — Install dependencies

# 2. Downgrade numpy before installing rdkit
!pip install "numpy<2.0"

# 3. Install rdkit
!pip install rdkit

# 4. Install rdchiral properly
!pip install git+https://github.com/connorcoley/rdchiral.git

# 5. Other dependencies
!pip install torch torchvision torchaudio pandas tqdm

Collecting git+https://github.com/connorcoley/rdchiral.git
  Cloning https://github.com/connorcoley/rdchiral.git to /tmp/pip-req-build-1gs1mxwz
  Running command git clone --filter=blob:none --quiet https://github.com/connorcoley/rdchiral.git /tmp/pip-req-build-1gs1mxwz
  Resolved https://github.com/connorcoley/rdchiral.git to commit da174b921b921f6547e46f32812b5d4af937cc94
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
## Cell 2 — Code — Create folders / scaffolding

import os

# Project scaffolding: data prep scripts, model code, training runs, and outputs.
for folder in ("retroB/data", "retroB/mt", "retroB/runs", "retroB/out"):
    os.makedirs(folder, exist_ok=True)


In [None]:
## Cell 3 — Code — Helper to write modules

def write_script(path, code):
    """Write ``code`` to ``path`` as UTF-8 text, replacing any existing file.

    Args:
        path: Destination file path. Missing parent folders are created so the
            helper also works when the scaffolding cell has not been run.
        code: Source text to store.
    """
    import os  # local import keeps the helper usable on a fresh kernel

    parent = os.path.dirname(path)
    if parent:  # a bare filename has no folder component to create
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(code)

In [None]:
## Cell 4 - Code - Cleans, canonicalizes, and randomly splits a CSV of chemical reactions into 80% training, 10% validation, and 10% test sets, saved as text files
## Takes the "reactants>reagents>production" column (reaction SMILES of the form reactants>>product) as input and writes processed train, validation, and test text files with one tab-separated line per reaction: reaction id, product, reactants
## Saves the script as retroB/data/prepare_uspto.py to be run later on

prep_code = """
import pandas as pd, re, argparse, random, os
from rdkit import Chem

def canon(smi: str):
    '''Return the canonical SMILES for ``smi``, or None when it is missing or cannot be parsed.'''
    if pd.isna(smi):
        return None
    parsed = Chem.MolFromSmiles(str(smi))
    if parsed:
        Chem.SanitizeMol(parsed)
        return Chem.MolToSmiles(parsed, canonical=True)
    return None

def unmap(rx: str):
    '''Remove atom-map labels from a (reaction) SMILES string, e.g. [CH3:1] -> [CH3].

    Args:
        rx: SMILES text; any other value is converted with str() first.

    Returns:
        The text with every atom-map label removed.
    '''
    # An atom-map label is the last item inside a bracket atom, so only a
    # colon+digits run directly before the closing bracket is stripped. An
    # explicit aromatic bond followed by a ring-closure digit (c:1) is kept.
    return re.sub(r':[0-9]+(?=])', '', str(rx))

def split_df(df: pd.DataFrame, seed=1337):
    '''Shuffle the row labels with a seeded RNG and cut them 80/10/10 into (train, valid, test).'''
    order = list(df.index)
    rng = random.Random(seed)
    rng.shuffle(order)
    total = len(order)
    train_end = int(0.8 * total)
    valid_end = train_end + int(0.1 * total)
    train_part = df.loc[order[:train_end]]
    valid_part = df.loc[order[train_end:valid_end]]
    test_part = df.loc[order[valid_end:]]  # whatever is left after the two cuts
    return train_part, valid_part, test_part

def to_lines(df: pd.DataFrame, out_path: str):
    '''Write one tab-separated line per row: row label, product_canon, reactants_canon.

    Rows where either canonical SMILES is missing are skipped.

    Args:
        df: Frame with "product_canon" and "reactants_canon" columns.
        out_path: Destination text file; its folder is created when needed.
    '''
    out_dir = os.path.dirname(out_path)
    if out_dir:  # a bare filename has no folder part, and makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        for rid, r in df.iterrows():
            p, q = r["product_canon"], r["reactants_canon"]
            if pd.isna(p) or pd.isna(q):
                continue
            f.write(f"{rid}\\t{p}\\t{q}\\n")

if __name__ == "__main__":
    # Command-line entry point: read the reaction CSV, strip atom maps,
    # canonicalize both sides, then write train/valid/test files to --outdir.
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="CSV with column 'reactants>reagents>production' (reactants>>product only)")
    ap.add_argument("--outdir", default="data/processed")
    ap.add_argument("--col", default="reactants>reagents>production", help="Name of the CSV column that holds the reaction SMILES")
    ap.add_argument("--seed", type=int, default=1337, help="Seed for the train/valid/test shuffle")
    args = ap.parse_args()
    os.makedirs(args.outdir, exist_ok=True)

    df = pd.read_csv(args.csv)

    # get the reactions column
    colname = args.col
    if colname not in df.columns:
        raise ValueError(f"Expected column '{colname}' not found")

    rx = df[colname].astype(str).apply(unmap)

    # split into reactants and product (split once, reuse for both sides);
    # rows without a ">>" separator get a missing product and are dropped below
    sides = rx.str.split(">>")
    df["reactants"] = sides.str[0]
    df["product"]   = sides.str[1]

    # canonicalize both sides
    df["reactants_canon"] = df["reactants"].apply(canon)
    df["product_canon"]   = df["product"].apply(canon)
    df = df.dropna(subset=["reactants_canon","product_canon"]).reset_index(drop=True)

    # split and save
    tr, va, te = split_df(df, seed=args.seed)
    to_lines(tr, f"{args.outdir}/train.txt")
    to_lines(va, f"{args.outdir}/valid.txt")
    to_lines(te, f"{args.outdir}/test.txt")
"""
write_script("retroB/data/prepare_uspto.py", prep_code)

In [None]:
## Cell 5
## Tokenizes SMILES strings into characters and builds a vocab of the characters seen in the data, which are then converted into integer IDs
## PAD - Makes all sequences in a batch the same length
## BOS - Marks where decoding starts
## EOS - Marks where decoding ends.
## Saves in tokenizer.py

tok_code = """
PAD, BOS, EOS = "<pad>", "<bos>", "<eos>"

def tok_smiles(s: str):
    '''Split a SMILES string into single-character tokens after trimming surrounding whitespace.'''
    trimmed = s.strip()
    return [ch for ch in trimmed]

def build_vocab(paths, min_freq=1):
    '''Build the character vocabulary from tab-separated reaction files.

    Args:
        paths: Files whose lines hold three tab-separated fields (id, product, reactants).
        min_freq: Minimum count for a character to be kept.

    Returns:
        (stoi, itos): token-to-id dict and id-to-token list. The special
        tokens PAD, BOS, EOS occupy ids 0, 1, 2; the rest are sorted.
    '''
    from collections import Counter
    cnt = Counter()
    for p in paths:
        # The context manager closes each file instead of leaking the handle.
        with open(p, "r") as fh:
            for line in fh:
                line = line.rstrip("\\n")
                if not line.strip():
                    continue  # tolerate blank lines instead of failing the unpack
                _, prod, react = line.split("\\t")
                cnt.update(tok_smiles(prod)); cnt.update(tok_smiles(react))
    itos = [PAD, BOS, EOS] + sorted([t for t,c in cnt.items() if c >= min_freq])
    stoi = {t:i for i,t in enumerate(itos)}
    return stoi, itos

def encode(tokens, stoi):
    '''Map each token to its integer id through ``stoi``; an unknown token raises KeyError.'''
    ids = []
    for token in tokens:
        ids.append(stoi[token])
    return ids

def decode(ids, itos):
    '''Turn token ids back into a string, leaving out the PAD/BOS/EOS markers.'''
    skipped = {PAD, BOS, EOS}
    pieces = []
    for i in ids:
        token = itos[i]
        if token not in skipped:
            pieces.append(token)
    return "".join(pieces)
"""
write_script("retroB/mt/tokenizer.py", tok_code)

In [None]:
## Cell 6
## Takes in tokenized SMILES strings and learns to "translate" products --> reactants
## Adds positional information so the model understands the SMILES structures
## Uses a custom learning rate schedule called Noam decay: increases the learning rate for the first few thousand steps and then decays it over time to stabilize training


model_code = """
import math, torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    '''Adds fixed sinusoidal position signals to a batch-first tensor.

    Args:
        d_model: Embedding width (even or odd).
        max_len: Number of positions precomputed in the "pe" buffer.
    '''
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
        angles = pos * div  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so the angle table is trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, L, D]
    def forward(self, x):  # x: [B,L,D]
        return x + self.pe[:, :x.size(1)]

class MT(nn.Module):
    '''Encoder-decoder Transformer that maps source token ids to target-token logits.

    Args:
        vocab: Sized vocabulary (the training scripts pass the stoi dict); only its
            length and, when it is a dict, its "<pad>" entry are used.
        d_model, nhead, num_layers, d_ff, dropout: nn.Transformer hyper-parameters.
        pad_id: Id of the padding token. When None it is read from vocab["<pad>"]
            if available, otherwise 0 is used.
    '''
    def __init__(self, vocab, d_model=512, nhead=8, num_layers=6, d_ff=2048, dropout=0.1, pad_id=None):
        super().__init__()
        V = len(vocab)
        if pad_id is None:
            pad_id = vocab["<pad>"] if isinstance(vocab, dict) and "<pad>" in vocab else 0
        self.pad_id = pad_id
        self.src_emb = nn.Embedding(V, d_model)
        self.tgt_emb = nn.Embedding(V, d_model)
        self.pos     = PositionalEncoding(d_model)
        self.tf = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=d_ff, dropout=dropout, batch_first=True)
        self.proj = nn.Linear(d_model, V)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt):
        '''Return logits of shape [B, L_tgt, V] for id tensors src [B, L_src] and tgt [B, L_tgt].'''
        # Padding masks (True = ignore) are taken from the raw ids before embedding,
        # so attention never reads PAD positions in batched inputs.
        src_pad = src.eq(self.pad_id)
        tgt_pad = tgt.eq(self.pad_id)
        src = self.pos(self.dropout(self.src_emb(src)))
        tgt = self.pos(self.dropout(self.tgt_emb(tgt)))
        # Boolean causal mask (True = blocked); same dtype as the padding masks.
        steps = tgt.size(1)
        tgt_mask = torch.triu(torch.ones(steps, steps, dtype=torch.bool, device=tgt.device), diagonal=1)
        out = self.tf(src, tgt, tgt_mask=tgt_mask,
                      src_key_padding_mask=src_pad,
                      tgt_key_padding_mask=tgt_pad,
                      memory_key_padding_mask=src_pad)
        return self.proj(out)  # [B,L,V]

def noam_lr(opt, d_model, warmup=8000):
    '''Create the Noam warmup/decay schedule as a LambdaLR multiplier.

    The multiplier is d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
    with a 1-based step, so the optimizer's base lr acts as a scale factor.

    Args:
        opt: Optimizer to schedule.
        d_model: Model width used in the scale term.
        warmup: Number of warmup steps.

    Returns:
        A torch.optim.lr_scheduler.LambdaLR instance.
    '''
    def _rate(sched_step):
        # LambdaLR supplies its own 0-based counter. Using it (instead of a
        # closure counter bumped on every call) keeps the rate correct when the
        # optimizer has several param groups.
        step = sched_step + 1
        return (d_model**-0.5) * min(step**-0.5, step*(warmup**-1.5))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=_rate)
"""
write_script("retroB/mt/model.py", model_code)

In [None]:
##Cell 7
##Baseline with standard cross-entropy
train_code = """
import argparse, torch, os
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tokenizer import build_vocab, tok_smiles, encode, decode, PAD, BOS, EOS
from model import MT, noam_lr
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Dataset
# -------------------------
class RxDataset(Dataset):
    '''In-memory dataset of (reaction_id, product, reactants) string triples.'''
    def __init__(self, path):
        # Each line: reaction_id <TAB> product <TAB> reactants.
        # The file is closed after loading and blank lines are skipped so they
        # do not become malformed rows.
        with open(path) as fh:
            self.rows = [l.rstrip("\\n").split("\\t") for l in fh if l.strip()]
    def __len__(self): return len(self.rows)
    def __getitem__(self, i):
        rid, p, r = self.rows[i]
        return rid, p, r

def collate(batch, stoi, max_len=256):
    '''Turn (rid, product, reactants) triples into padded id tensors.

    Args:
        batch: Sequence of (rid, source SMILES, target SMILES) triples.
        stoi: Token-to-id mapping that contains PAD, BOS and EOS.
        max_len: Maximum number of ids per sequence, markers included.

    Returns:
        (rids, src_t, tgt_in, tgt_y): the ids tuple, padded source ids, decoder
        input (target without its last id) and decoder labels (target without BOS).
    '''
    rids, src, tgt = zip(*batch)
    pad_id, eos_id = stoi[PAD], stoi[EOS]

    def to_ids(smiles):
        ids = encode([BOS] + tok_smiles(smiles) + [EOS], stoi)
        if len(ids) > max_len:
            # Keep the end marker when truncating so over-long sequences still terminate.
            ids = ids[:max_len - 1] + [eos_id]
        return ids

    src_ids = [to_ids(s) for s in src]
    tgt_ids = [to_ids(t) for t in tgt]
    Ls = max(map(len, src_ids)); Lt = max(map(len, tgt_ids))
    pad = lambda seq, L: seq + [pad_id]*(L-len(seq))
    src_t = torch.tensor([pad(s,Ls) for s in src_ids]).long()
    tgt_in = torch.tensor([pad(t[:-1],Lt-1) for t in tgt_ids]).long()
    tgt_y  = torch.tensor([pad(t[1:], Lt-1) for t in tgt_ids]).long()
    return rids, src_t, tgt_in, tgt_y

# -------------------------
# Loss
# -------------------------
def ce_loss(logits, target, ignore_index=0):
    '''Mean token-level cross-entropy for [B,T,V] logits, skipping ``ignore_index`` targets.'''
    # F.cross_entropy wants the class dimension second: [B,V,T].
    class_first = logits.permute(0, 2, 1)
    return F.cross_entropy(class_first, target, ignore_index=ignore_index, reduction="mean")

# -------------------------
# Beam Search Decoder
# -------------------------
@torch.no_grad()
def beam_search(model, src, stoi, itos, beam=10, max_len=256):
    '''Decode one source sequence with beam search.

    Args:
        model: Callable as model(src_ids [1,L], tgt_ids [1,T]) -> logits [1,T,V].
        src: 1-D tensor of source token ids.
        stoi, itos: Vocabulary mappings; BOS and EOS ids are read from stoi.
        beam: Beam width and maximum number of returned strings.
        max_len: Maximum number of decoding steps.

    Returns:
        Up to ``beam`` decoded strings, ordered by summed log-probability (best first).
    '''
    model.eval()
    bos_id, eos_id = stoi[BOS], stoi[EOS]
    src_batch = src.to(DEVICE).unsqueeze(0)  # built once, reused at every step

    beams = [([bos_id], 0.0)]  # sequence, score
    finished = []

    for _ in range(max_len):
        new_beams = []
        for seq, score in beams:
            if seq[-1] == eos_id:
                finished.append((seq, score))
                continue
            tgt = torch.tensor(seq).unsqueeze(0).to(DEVICE)
            logits = model(src_batch, tgt)[:,-1]  # last step
            probs = F.log_softmax(logits, dim=-1).squeeze(0)
            # topk raises when k exceeds the vocabulary size, so cap it.
            topk = torch.topk(probs, min(beam, probs.size(-1)))
            for idx, s in zip(topk.indices.tolist(), topk.values.tolist()):
                new_beams.append((seq+[idx], score+s))
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam]
        if not beams:
            break  # every hypothesis has emitted EOS; later steps would do nothing

    finished.extend(beams)
    finished = sorted(finished, key=lambda x: x[1], reverse=True)
    decoded = [decode(seq, itos) for seq, _ in finished]
    return decoded[:beam]


def canonicalize(smi):
    '''Return the canonical SMILES for ``smi``, or None when it cannot be parsed.'''
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol:
            return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Treat a failing conversion as "not a molecule". Unlike a bare except,
        # this lets KeyboardInterrupt/SystemExit stop a long validation run.
        return None
    return None

@torch.no_grad()
def validate(model, loader, stoi, itos, beam=10):
    '''Compute the validation loss and top-1/5/10 exact-match accuracy.

    A prediction counts as a hit when its canonical SMILES equals the canonical
    target; beams that do not canonicalize are dropped before ranking.

    Returns:
        (mean batch loss, top1, top5, top10), accuracies as fractions of the dataset size.
    '''
    model.eval()
    tot_loss, n = 0.0, 0
    top1, top5, top10 = 0, 0, 0

    for _, src, tgt_in, tgt_y in loader:
        src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)

        # Loss
        logits = model(src, tgt_in)
        tot_loss += ce_loss(logits, tgt_y, ignore_index=stoi[PAD]).item()
        n += 1

        # Decode & compare
        for i in range(src.size(0)):
            can_target = canonicalize(decode(tgt_y[i].tolist(), itos))
            if not can_target:
                continue  # nothing can match, so skip the costly beam search

            beams = beam_search(model, src[i], stoi, itos, beam=beam)
            # Canonicalize each beam once and keep the valid ones in rank order.
            can_beams = [c for c in (canonicalize(b) for b in beams) if c is not None]

            # A top-1 hit is also a top-5 hit, and a top-5 hit is also a top-10 hit.
            if can_target in can_beams[:10]:
                top10 += 1
                if can_target in can_beams[:5]:
                    top5 += 1
                    if can_target in can_beams[:1]:
                        top1 += 1

    n_examples = max(1, len(loader.dataset))  # avoid dividing by zero on an empty set
    return (
        tot_loss/max(1,n),
        top1/n_examples,
        top5/n_examples,
        top10/n_examples
    )


# -------------------------
# Main
# -------------------------
def main(args):
    '''Train the standard cross-entropy baseline and keep the best checkpoint.

    Reads args.train, args.valid, args.outdir, args.bsz, args.max_epochs and
    args.patience. Writes vocab.pt and best.pt into args.outdir.
    '''
    # The vocabulary covers both the training and the validation file.
    stoi, itos = build_vocab([args.train, args.valid])
    os.makedirs(args.outdir, exist_ok=True)
    torch.save({"stoi":stoi, "itos":itos}, os.path.join(args.outdir,"vocab.pt"))

    tr = RxDataset(args.train); va = RxDataset(args.valid)
    C = lambda b: collate(b, stoi)
    dl_tr = DataLoader(tr, batch_size=args.bsz, shuffle=True, collate_fn=C)
    dl_va = DataLoader(va, batch_size=1, shuffle=False, collate_fn=C)  # beam search needs batch=1

    model = MT(stoi).to(DEVICE)
    # lr=1.0 so the Noam multiplier from the scheduler is the effective learning rate.
    opt = torch.optim.AdamW(model.parameters(), lr=1.0, betas=(0.9,0.98), weight_decay=0.01)
    # d_model=512 matches the default width of MT.
    sch = noam_lr(opt, d_model=512, warmup=8000)

    # best: lowest validation loss so far; bad: epochs since it last improved.
    best, bad = 1e9, 0
    patience = args.patience
    for epoch in range(args.max_epochs):
        # Training
        model.train()
        for _, src, tgt_in, tgt_y in dl_tr:
            src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)
            logits = model(src, tgt_in)
            loss = ce_loss(logits, tgt_y, ignore_index=stoi[PAD])
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # The scheduler advances once per optimizer step, not once per epoch.
            opt.step(); sch.step()

        # Validation
        vloss, acc1, acc5, acc10 = validate(model, dl_va, stoi, itos)
        print(f"epoch {epoch} | val_loss {vloss:.4f} | top1 {acc1:.3f} | top5 {acc5:.3f} | top10 {acc10:.3f}")

        # Checkpoint on improvement; otherwise count towards early stopping.
        if vloss < best:
            best, bad = vloss, 0
            torch.save({"model":model.state_dict(),"stoi":stoi,"itos":itos},
                       os.path.join(args.outdir,"best.pt"))
        else:
            bad += 1
            if bad >= patience: break

if __name__ == "__main__":
    # Command-line interface for the standard cross-entropy baseline.
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="data/processed/train.txt")
    parser.add_argument("--valid", default="data/processed/valid.txt")
    parser.add_argument("--outdir", default="runs/mt_standardce")
    parser.add_argument("--bsz", type=int, default=64)
    parser.add_argument("--max_epochs", type=int, default=8)
    parser.add_argument("--patience", type=int, default=5, help="Early stopping patience (epochs without improvement)")

    args = parser.parse_args()
    main(args)
"""
write_script("retroB/mt/train.py", train_code)

In [None]:
##Baseline + Label Smoothing Cross-Entropy
train_code = """
import argparse, torch, os
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tokenizer import build_vocab, tok_smiles, encode, PAD, BOS, EOS
from model import MT, noam_lr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class RxDataset(Dataset):
    '''In-memory dataset of (reaction_id, product, reactants) string triples.'''
    def __init__(self, path):
        # Each line holds three tab-separated fields. The file is closed after
        # loading and blank lines are skipped so they do not become malformed rows.
        with open(path) as fh:
            self.rows = [l.rstrip("\\n").split("\\t") for l in fh if l.strip()]
    def __len__(self): return len(self.rows)
    def __getitem__(self,i): rid, p, r = self.rows[i]; return rid, p, r

def collate(batch, stoi, max_len=256):
    '''Turn (rid, product, reactants) triples into padded id tensors.

    Returns (rids, src_t, tgt_in, tgt_y): the ids tuple, padded source ids,
    decoder input (target without its last id) and labels (target without BOS).
    '''
    rids, src, tgt = zip(*batch)
    pad_id, eos_id = stoi[PAD], stoi[EOS]

    def to_ids(smiles):
        ids = encode([BOS] + tok_smiles(smiles) + [EOS], stoi)
        if len(ids) > max_len:
            # Keep the end marker when truncating so over-long sequences still terminate.
            ids = ids[:max_len - 1] + [eos_id]
        return ids

    src_ids = [to_ids(s) for s in src]
    tgt_ids = [to_ids(t) for t in tgt]
    Ls = max(map(len, src_ids)); Lt = max(map(len, tgt_ids))
    pad = lambda seq, L: seq + [pad_id]*(L-len(seq))
    src_t = torch.tensor([pad(s,Ls) for s in src_ids]).long()
    tgt_in = torch.tensor([pad(t[:-1],Lt-1) for t in tgt_ids]).long()
    tgt_y  = torch.tensor([pad(t[1:], Lt-1) for t in tgt_ids]).long()
    return rids, src_t, tgt_in, tgt_y

def label_smoothing_loss(logits, target, eps=0.1, ignore_index=0, reduction="none"):
    '''Label-smoothed cross-entropy over [B,T,V] logits with padding ignored.

    Args:
        logits: Float tensor [B,T,V].
        target: Long tensor [B,T].
        eps: Probability mass spread evenly over the non-target classes.
        ignore_index: Target id whose positions contribute no loss.
        reduction: "none" returns one token-averaged loss per example [B];
            any other value returns the average over all non-ignored tokens.
    '''
    n_class = logits.size(-1)
    logp = F.log_softmax(logits, -1)
    with torch.no_grad():
        true = torch.zeros_like(logp)
        true.fill_(eps/(n_class-1))
        true.scatter_(2, target.unsqueeze(-1), 1-eps)
        true.masked_fill_(target.eq(ignore_index).unsqueeze(-1), 0)
    loss = -(true*logp).sum(-1)
    mask = (~target.eq(ignore_index)).float()
    loss = loss * mask
    # clamp_min keeps a row (or batch) made only of ignored tokens at 0 instead of 0/0 = NaN.
    if reduction == "none":
        return loss.sum(-1) / mask.sum(-1).clamp_min(1.0)  # per-example loss
    else:
        return loss.sum() / mask.sum().clamp_min(1.0)

# -------------------------
# Load precomputed UQ weights
# -------------------------
def load_uq_weights(path):
    '''Load per-reaction loss weights from a CSV of precomputed UQ scores.

    Only rows with rank == 1 are used. Scores are min-max normalized and
    inverted, so a low UQ score gives a weight near 1 and a high score a
    weight near 0.

    Returns:
        dict mapping reaction id (as str) to a float weight.
    '''
    df = pd.read_csv(path)
    top1 = df[df["rank"] == 1][["reaction_id", "UQ_score"]]
    lo, hi = top1["UQ_score"].min(), top1["UQ_score"].max()
    # normalize: low UQ = high weight
    uq = 1.0 - (top1["UQ_score"] - lo) / (hi - lo + 1e-8)
    # Dataset ids come from text files and are strings, while pandas parses
    # numeric ids as ints; keying by str makes the lookups in the training loop match.
    weights = dict(zip(top1["reaction_id"].astype(str), uq.astype(float)))
    return weights

def main(args):
    '''Train with label-smoothed cross-entropy, optionally weighted by precomputed UQ scores.

    Reads args.train, args.valid, args.outdir, args.bsz, args.uq_csv,
    args.max_epochs and args.patience. Writes vocab.pt and best.pt into args.outdir.
    '''
    # The vocabulary covers both the training and the validation file.
    stoi, itos = build_vocab([args.train, args.valid])
    os.makedirs(args.outdir, exist_ok=True)
    torch.save({"stoi":stoi, "itos":itos}, os.path.join(args.outdir, "vocab.pt"))

    tr = RxDataset(args.train); va = RxDataset(args.valid)
    C = lambda b: collate(b, stoi)
    dl_tr = DataLoader(tr, batch_size=args.bsz, shuffle=True,  collate_fn=C)
    dl_va = DataLoader(va, batch_size=args.bsz, shuffle=False, collate_fn=C)

    # load UQ weights if provided (an empty dict turns the weighting off)
    uq_weights = load_uq_weights(args.uq_csv) if args.uq_csv else {}

    model = MT(stoi).to(DEVICE)
    # lr=1.0 so the Noam multiplier from the scheduler is the effective learning rate.
    opt = torch.optim.AdamW(model.parameters(), lr=1.0, betas=(0.9,0.98), weight_decay=0.01)
    # d_model=512 matches the default width of MT.
    sch = noam_lr(opt, d_model=512, warmup=8000)

    # best: lowest validation loss so far; bad: epochs since it last improved.
    best, bad = 1e9, 0
    # args must provide a patience attribute for early stopping.
    patience = args.patience
    for epoch in range(args.max_epochs):   # capped by the --max_epochs CLI argument
        model.train()
        for rids, src, tgt_in, tgt_y in dl_tr:
            src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)
            logits = model(src, tgt_in)
            per_ex_loss = label_smoothing_loss(logits, tgt_y, ignore_index=stoi[PAD], reduction="none")

            # apply UQ weighting; reactions missing from the table keep weight 1.0
            if uq_weights:
                w = torch.tensor([uq_weights.get(rid, 1.0) for rid in rids], device=DEVICE)
                loss = (per_ex_loss * w).mean()
            else:
                loss = per_ex_loss.mean()

            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # The scheduler advances once per optimizer step, not once per epoch.
            opt.step(); sch.step()

        # validation: average of the per-batch mean losses
        model.eval(); vloss, cnt = 0.0, 0
        with torch.no_grad():
            for _, src, tgt_in, tgt_y in dl_va:
                src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)
                logits = model(src, tgt_in)
                vloss += label_smoothing_loss(logits, tgt_y, ignore_index=stoi[PAD]).mean().item(); cnt += 1
        vloss /= max(cnt,1)
        print(f"epoch {epoch} | val_loss {vloss:.4f}")

        # Checkpoint on improvement; otherwise count towards early stopping.
        if vloss < best:
            best, bad = vloss, 0
            torch.save({"model":model.state_dict(),"stoi":stoi,"itos":itos},
                       os.path.join(args.outdir,"best.pt"))
        else:
            bad += 1
            if bad >= patience: break

if __name__ == "__main__":
    # Command-line interface for the label-smoothing baseline.
    ap = argparse.ArgumentParser()
    ap.add_argument("--train",  default="data/processed/train.txt")
    ap.add_argument("--valid",  default="data/processed/valid.txt")
    ap.add_argument("--outdir", default="runs/mt")
    ap.add_argument("--bsz", type=int, default=64)
    ap.add_argument("--uq_csv", default="", help="CSV of precomputed UQ scores for training set")
    ap.add_argument("--max_epochs", type=int, default=8, help="maximum number of epochs to train")
    # main() reads args.patience; without this option the script fails with
    # AttributeError before the first epoch.
    ap.add_argument("--patience", type=int, default=5, help="Early stopping patience (epochs without improvement)")
    args = ap.parse_args(); main(args)
"""

write_script("retroB/mt/train.py", train_code)

In [None]:
##Baseline + UQ + Standard Cross Entropy
train_code = """
import argparse, os, torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tokenizer import build_vocab, tok_smiles, encode, PAD, BOS, EOS
from model import MT, noam_lr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Dataset / Collate
# -------------------------
class RxDataset(Dataset):
    '''In-memory dataset of (reaction_id, product, reactants) string triples.'''
    def __init__(self, path):
        # Each line holds three tab-separated fields. The file is closed after
        # loading and blank lines are skipped so they do not become malformed rows.
        with open(path) as fh:
            self.rows = [l.rstrip("\\n").split("\\t") for l in fh if l.strip()]
    def __len__(self): return len(self.rows)
    def __getitem__(self,i): rid, p, r = self.rows[i]; return rid, p, r

def collate(batch, stoi, max_len=256):
    '''Turn (rid, product, reactants) triples into padded id tensors.

    Returns (rids, src_t, tgt_in, tgt_y): the ids tuple, padded source ids,
    decoder input (target without its last id) and labels (target without BOS).
    '''
    rids, src, tgt = zip(*batch)
    pad_id, eos_id = stoi[PAD], stoi[EOS]

    def to_ids(smiles):
        ids = encode([BOS] + tok_smiles(smiles) + [EOS], stoi)
        if len(ids) > max_len:
            # Keep the end marker when truncating so over-long sequences still terminate.
            ids = ids[:max_len - 1] + [eos_id]
        return ids

    src_ids = [to_ids(s) for s in src]
    tgt_ids = [to_ids(t) for t in tgt]
    Ls = max(map(len, src_ids)); Lt = max(map(len, tgt_ids))
    pad = lambda seq, L: seq + [pad_id] * (L - len(seq))
    src_t = torch.tensor([pad(s, Ls) for s in src_ids]).long()
    tgt_in = torch.tensor([pad(t[:-1], Lt - 1) for t in tgt_ids]).long()
    tgt_y  = torch.tensor([pad(t[1:],  Lt - 1) for t in tgt_ids]).long()
    return rids, src_t, tgt_in, tgt_y

# -------------------------
# Pure cross-entropy loss
# -------------------------
def ce_loss(logits, target, ignore_index=0, reduction="none"):
    '''Cross-entropy over [B,T,V] logits with ``ignore_index`` positions excluded.

    Args:
        logits: Float tensor [B,T,V].
        target: Long tensor [B,T].
        ignore_index: Target id whose positions contribute no loss.
        reduction: "none" returns one token-averaged loss per example [B];
            any other value returns the average over all non-ignored tokens.
    '''
    loss = F.cross_entropy(
        logits.transpose(1, 2),  # [B, V, T]
        target,
        ignore_index=ignore_index,
        reduction="none"
    )  # [B, T]
    mask = (~target.eq(ignore_index)).float()
    if reduction == "none":
        return (loss * mask).sum(-1) / mask.sum(-1).clamp_min(1.0)
    # Average over real tokens only; a plain .mean() would also count the
    # zero-loss padding positions and understate the loss.
    return (loss * mask).sum() / mask.sum().clamp_min(1.0)

# -------------------------
# Uncertainty helpers
# -------------------------
@torch.no_grad()
def seq_entropy(logits):
    '''Mean per-position predictive entropy for each sequence: [B,T,V] logits -> [B].'''
    log_p = F.log_softmax(logits, dim=-1)
    p = F.softmax(logits, dim=-1)
    token_entropy = -(p * log_p).sum(-1)  # [B, T]
    return token_entropy.mean(-1)         # [B]

@torch.no_grad()
def mc_dropout_uq(model, src, tgt_in, n_samples=6):
    '''Estimate epistemic uncertainty as the variance of softmax outputs over dropout passes.

    Args:
        model: Module called as model(src, tgt_in) that returns [B,T,V] logits.
        src, tgt_in: Inputs forwarded to the model unchanged.
        n_samples: Number of stochastic passes (at least one is always run).

    Returns:
        Tensor [B]: variance averaged over vocabulary and positions.
    '''
    was_training = model.training
    model.train()  # enable dropout
    try:
        preds = []
        for _ in range(max(1, n_samples)):
            logits = model(src, tgt_in)
            preds.append(F.softmax(logits, dim=-1).detach())
    finally:
        model.train(was_training)  # hand the model back in the mode it arrived in
    stack = torch.stack(preds, dim=0)  # [S, B, T, V]
    # The unbiased estimator divides by S-1 and is NaN for a single sample;
    # use the population variance (zero) in that case.
    var = stack.var(dim=0, unbiased=stack.size(0) > 1)  # [B, T, V]
    epi = var.mean(dim=-1).mean(dim=-1) # [B]
    return epi

def combine_uq(aleatoric, epistemic, lam=0.5):
    '''Blend the two uncertainty terms: ``lam`` weights aleatoric, ``1 - lam`` weights epistemic.'''
    epistemic_share = 1 - lam
    return lam * aleatoric + epistemic_share * epistemic

def make_weights_from_uq(uq_scores, alpha=1.0, scheme="exp", min_w=0.1):
    '''Convert uncertainty scores into loss weights clamped to [min_w, 1].

    Scores are min-max normalized to u first. "exp" gives exp(-alpha*u),
    "linear" gives 1 - alpha*u, and any other scheme gives 1 - u.
    '''
    lowest = uq_scores.min()
    span = uq_scores.max() - lowest + 1e-8
    u = (uq_scores - lowest) / span
    if scheme == "exp":
        raw = torch.exp(-alpha * u)
    elif scheme == "linear":
        raw = 1.0 - alpha * u
    else:  # invert
        raw = 1.0 - u
    return torch.clamp(raw, min=min_w, max=1.0)

def load_uq_weights(path):
    '''Read a UQ CSV and return {reaction id as str: weight}, where weight = 1 - normalized score.

    Uses the "reaction_id" / "UQ_score" columns when present, otherwise the
    first and last columns. Rows with missing values are dropped.
    '''
    table = pd.read_csv(path)
    id_col = "reaction_id" if "reaction_id" in table.columns else table.columns[0]
    uq_col = "UQ_score" if "UQ_score" in table.columns else table.columns[-1]
    rows = table[[id_col, uq_col]].dropna()
    scores = rows[uq_col]
    scaled = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    inverted = 1.0 - scaled
    keys = rows[id_col].astype(str)
    return dict(zip(keys, inverted.astype(float)))

# -------------------------
# Training
# -------------------------
def train_one_epoch(model, loader, stoi, opt, sch, uq_map, args):
    '''Run one training pass with uncertainty-weighted cross-entropy.

    Args:
        model, loader, opt, sch: Model, batch iterator, optimizer and LR scheduler.
        stoi: Vocabulary mapping; only the PAD id is read here.
        uq_map: Precomputed weights keyed by reaction id string. When empty,
            weights are computed from the current batch instead.
        args: Provides mc_samples, lambda_mix, alpha, weighting and min_w.

    Returns:
        Mean weighted loss per batch.
    '''
    model.train()
    total = 0.0
    for rids, src, tgt_in, tgt_y in loader:
        src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)
        logits = model(src, tgt_in)
        per_ex_loss = ce_loss(logits, tgt_y, ignore_index=stoi[PAD], reduction="none")

        if uq_map:  # offline UQ: reactions missing from the table keep weight 1.0
            w = torch.tensor([uq_map.get(str(r), 1.0) for r in rids], device=DEVICE, dtype=per_ex_loss.dtype)
        else:  # online UQ: weights come from this batch's own predictions
            alea = seq_entropy(logits)
            # The MC-dropout passes are skipped entirely when mc_samples is 0.
            epi  = mc_dropout_uq(model, src, tgt_in, n_samples=args.mc_samples) if args.mc_samples > 0 else torch.zeros_like(alea)
            uq   = combine_uq(alea, epi, lam=args.lambda_mix)
            w    = make_weights_from_uq(uq, alpha=args.alpha, scheme=args.weighting, min_w=args.min_w)

        loss = (per_ex_loss * w).mean()

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # The scheduler advances once per optimizer step.
        opt.step(); sch.step()

        total += loss.item()
    return total / max(1, len(loader))

@torch.no_grad()
def validate(model, loader, stoi):
    '''Return the average over batches of the mean per-example cross-entropy.'''
    model.eval()
    batch_losses = []
    for _, src, tgt_in, tgt_y in loader:
        src = src.to(DEVICE)
        tgt_in = tgt_in.to(DEVICE)
        tgt_y = tgt_y.to(DEVICE)
        per_example = ce_loss(model(src, tgt_in), tgt_y, ignore_index=stoi[PAD])
        batch_losses.append(per_example.mean().item())
    return sum(batch_losses, 0.0) / max(1, len(batch_losses))

def main(args):
    '''Train the UQ-weighted cross-entropy model and keep the best checkpoint.

    Reads args.train, args.valid, args.outdir, args.bsz, args.uq_csv,
    args.max_epochs and args.patience, plus the UQ options used by
    train_one_epoch. Writes vocab.pt and best.pt into args.outdir.
    '''
    # The vocabulary covers both the training and the validation file.
    stoi, itos = build_vocab([args.train, args.valid])
    os.makedirs(args.outdir, exist_ok=True)
    torch.save({"stoi": stoi, "itos": itos}, os.path.join(args.outdir, "vocab.pt"))

    tr = RxDataset(args.train); va = RxDataset(args.valid)
    C  = lambda b: collate(b, stoi)
    dl_tr = DataLoader(tr, batch_size=args.bsz, shuffle=True,  collate_fn=C, num_workers=0)
    dl_va = DataLoader(va, batch_size=args.bsz, shuffle=False, collate_fn=C, num_workers=0)

    # An empty map selects the online (per-batch) weighting in train_one_epoch.
    uq_map = load_uq_weights(args.uq_csv) if args.uq_csv else {}

    model = MT(stoi).to(DEVICE)
    # lr=1.0 so the Noam multiplier from the scheduler is the effective learning rate.
    opt = torch.optim.AdamW(model.parameters(), lr=1.0, betas=(0.9,0.98), weight_decay=0.01)
    # d_model=512 matches the default width of MT.
    sch = noam_lr(opt, d_model=512, warmup=8000)

    # best: lowest validation loss so far; bad: epochs since it last improved.
    best, patience, bad = 1e9, args.patience, 0
    for epoch in range(args.max_epochs):
        tr_loss = train_one_epoch(model, dl_tr, stoi, opt, sch, uq_map, args)
        va_loss = validate(model, dl_va, stoi)
        print(f"epoch {epoch} | train_loss {tr_loss:.4f} | val_loss {va_loss:.4f}")
        # Checkpoint on improvement; otherwise count towards early stopping.
        if va_loss < best:
            best, bad = va_loss, 0
            torch.save({"model": model.state_dict(), "stoi": stoi, "itos": itos},
                       os.path.join(args.outdir, "best.pt"))
        else:
            bad += 1
            if bad >= patience: break

if __name__ == "__main__":
    # Command-line interface for the UQ-weighted cross-entropy run.
    parser = argparse.ArgumentParser()

    # Data and output locations
    parser.add_argument("--train",  default="data/processed/train.txt")
    parser.add_argument("--valid",  default="data/processed/valid.txt")
    parser.add_argument("--outdir", default="runs/mt_uq_ce")
    parser.add_argument("--bsz", type=int, default=64)

    # Uncertainty weighting options
    parser.add_argument("--uq_csv", default="", help="Optional CSV with columns [reaction_id,UQ_score] for offline weights")
    parser.add_argument("--mc_samples", type=int, default=0, help="If >0, do MC-dropout passes for epistemic UQ (online).")
    parser.add_argument("--lambda_mix", type=float, default=0.5, help="Mix weight for aleatoric vs epistemic (0..1).")
    parser.add_argument("--alpha", type=float, default=1.0, help="Strength of down-weighting uncertain samples.")
    parser.add_argument("--weighting", choices=["exp","linear","invert"], default="exp")
    parser.add_argument("--min_w", type=float, default=0.1)

    # Training length and early stopping
    parser.add_argument("--max_epochs", type=int, default=8)
    parser.add_argument("--patience", type=int, default=5)

    args = parser.parse_args()
    main(args)
    """
write_script("retroB/mt/train.py", train_code)


In [None]:
##Baseline + UQ + Label Smoothing Cross Entropy
train_code = """
import argparse, os, torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tokenizer import build_vocab, tok_smiles, encode, PAD, BOS, EOS
from model import MT, noam_lr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Dataset / Collate
# -------------------------
class RxDataset(Dataset):
    '''In-memory dataset of (reaction_id, product, reactants) string triples.'''
    def __init__(self, path):
        # Each line holds three tab-separated fields. The file is closed after
        # loading and blank lines are skipped so they do not become malformed rows.
        with open(path) as fh:
            self.rows = [l.rstrip("\\n").split("\\t") for l in fh if l.strip()]
    def __len__(self): return len(self.rows)
    def __getitem__(self,i): rid, p, r = self.rows[i]; return rid, p, r

def collate(batch, stoi, max_len=256):
    '''Turn (rid, product, reactants) triples into padded id tensors.

    Returns (rids, src_t, tgt_in, tgt_y): the ids tuple, padded source ids,
    decoder input (target without its last id) and labels (target without BOS).
    '''
    rids, src, tgt = zip(*batch)
    pad_id, eos_id = stoi[PAD], stoi[EOS]

    def to_ids(smiles):
        ids = encode([BOS] + tok_smiles(smiles) + [EOS], stoi)
        if len(ids) > max_len:
            # Keep the end marker when truncating so over-long sequences still terminate.
            ids = ids[:max_len - 1] + [eos_id]
        return ids

    src_ids = [to_ids(s) for s in src]
    tgt_ids = [to_ids(t) for t in tgt]
    Ls = max(map(len, src_ids)); Lt = max(map(len, tgt_ids))
    pad = lambda seq, L: seq + [pad_id]*(L-len(seq))
    src_t = torch.tensor([pad(s,Ls) for s in src_ids]).long()
    tgt_in = torch.tensor([pad(t[:-1],Lt-1) for t in tgt_ids]).long()
    tgt_y  = torch.tensor([pad(t[1:], Lt-1) for t in tgt_ids]).long()
    return rids, src_t, tgt_in, tgt_y

# -------------------------
# Label smoothing loss
# -------------------------
def label_smoothing_loss(logits, target, eps=0.1, ignore_index=0, reduction="none"):
    '''Label-smoothed cross-entropy over [B,T,V] logits with padding ignored.

    Args:
        logits: Float tensor [B,T,V].
        target: Long tensor [B,T].
        eps: Probability mass spread evenly over the non-target classes.
        ignore_index: Target id whose positions contribute no loss.
        reduction: "none" returns one token-averaged loss per example [B];
            any other value returns the average over all non-ignored tokens.
    '''
    n_class = logits.size(-1)
    logp = F.log_softmax(logits, -1)
    with torch.no_grad():
        true = torch.zeros_like(logp)
        true.fill_(eps/(n_class-1))
        true.scatter_(2, target.unsqueeze(-1), 1-eps)
        true.masked_fill_(target.eq(ignore_index).unsqueeze(-1), 0)
    loss = -(true*logp).sum(-1)
    mask = (~target.eq(ignore_index)).float()
    loss = loss * mask
    # clamp_min keeps a row (or batch) made only of ignored tokens at 0 instead of 0/0 = NaN.
    if reduction == "none":
        return loss.sum(-1) / mask.sum(-1).clamp_min(1.0)
    else:
        return loss.sum() / mask.sum().clamp_min(1.0)

# -------------------------
# Uncertainty helpers
# -------------------------
@torch.no_grad()
def seq_entropy(logits):
    '''Mean per-position predictive entropy for each sequence: [B,T,V] logits -> [B].'''
    log_p = F.log_softmax(logits, dim=-1)
    p = F.softmax(logits, dim=-1)
    token_entropy = -(p * log_p).sum(-1)  # [B, T]
    return token_entropy.mean(-1)         # [B]

@torch.no_grad()
def mc_dropout_uq(model, src, tgt_in, n_samples=6):
    '''Estimate epistemic uncertainty as the variance of softmax outputs over dropout passes.

    Args:
        model: Module called as model(src, tgt_in) that returns [B,T,V] logits.
        src, tgt_in: Inputs forwarded to the model unchanged.
        n_samples: Number of stochastic passes (at least one is always run).

    Returns:
        Tensor [B]: variance averaged over vocabulary and positions.
    '''
    was_training = model.training
    model.train()  # enable dropout
    try:
        preds = []
        for _ in range(max(1, n_samples)):
            logits = model(src, tgt_in)
            preds.append(F.softmax(logits, dim=-1).detach())
    finally:
        model.train(was_training)  # hand the model back in the mode it arrived in
    stack = torch.stack(preds, dim=0)  # [S, B, T, V]
    # The unbiased estimator divides by S-1 and is NaN for a single sample;
    # use the population variance (zero) in that case.
    var = stack.var(dim=0, unbiased=stack.size(0) > 1)
    epi = var.mean(dim=-1).mean(dim=-1)
    return epi

def combine_uq(aleatoric, epistemic, lam=0.5):
    '''Blend the two uncertainty terms: ``lam`` weights aleatoric, ``1 - lam`` weights epistemic.'''
    epistemic_share = 1 - lam
    return lam * aleatoric + epistemic_share * epistemic

def make_weights_from_uq(uq_scores, alpha=1.0, scheme="exp", min_w=0.1):
    '''Convert uncertainty scores into loss weights clamped to [min_w, 1].

    Scores are min-max normalized to u first. "exp" gives exp(-alpha*u),
    "linear" gives 1 - alpha*u, and any other scheme gives 1 - u.
    '''
    lowest = uq_scores.min()
    span = uq_scores.max() - lowest + 1e-8
    u = (uq_scores - lowest) / span
    if scheme == "exp":
        raw = torch.exp(-alpha * u)
    elif scheme == "linear":
        raw = 1.0 - alpha * u
    else:
        raw = 1.0 - u
    return torch.clamp(raw, min=min_w, max=1.0)

def load_uq_weights(path):
    '''Read a UQ CSV and return {reaction id as str: weight}, where weight = 1 - normalized score.

    Uses the "reaction_id" / "UQ_score" columns when present, otherwise the
    first and last columns. Rows with missing values are dropped.
    '''
    table = pd.read_csv(path)
    id_col = "reaction_id" if "reaction_id" in table.columns else table.columns[0]
    uq_col = "UQ_score" if "UQ_score" in table.columns else table.columns[-1]
    rows = table[[id_col, uq_col]].dropna()
    scores = rows[uq_col]
    scaled = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    inverted = 1.0 - scaled
    keys = rows[id_col].astype(str)
    return dict(zip(keys, inverted.astype(float)))

# -------------------------
# Training
# -------------------------
def train_one_epoch(model, loader, stoi, opt, sch, uq_map, args):
    '''Run one training pass with uncertainty-weighted label-smoothed loss.

    Args:
        model, loader, opt, sch: Model, batch iterator, optimizer and LR scheduler.
        stoi: Vocabulary mapping; only the PAD id is read here.
        uq_map: Precomputed weights keyed by reaction id string. When empty,
            weights are computed from the current batch instead.
        args: Provides mc_samples, lambda_mix, alpha, weighting and min_w.

    Returns:
        Mean weighted loss per batch.
    '''
    model.train()
    total = 0.0
    for rids, src, tgt_in, tgt_y in loader:
        src, tgt_in, tgt_y = src.to(DEVICE), tgt_in.to(DEVICE), tgt_y.to(DEVICE)
        logits = model(src, tgt_in)
        per_ex_loss = label_smoothing_loss(logits, tgt_y, ignore_index=stoi[PAD], reduction="none")

        if uq_map:
            # Offline weights: reactions missing from the table keep weight 1.0.
            w = torch.tensor([uq_map.get(str(r), 1.0) for r in rids], device=DEVICE)
        else:
            # Online weights derived from this batch's own predictions.
            alea = seq_entropy(logits)
            # The MC-dropout passes are skipped entirely when mc_samples is 0.
            epi  = mc_dropout_uq(model, src, tgt_in, n_samples=args.mc_samples) if args.mc_samples > 0 else torch.zeros_like(alea)
            uq   = combine_uq(alea, epi, lam=args.lambda_mix)
            w    = make_weights_from_uq(uq, alpha=args.alpha, scheme=args.weighting, min_w=args.min_w)

        loss = (per_ex_loss * w).mean()
        opt.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # The scheduler advances once per optimizer step.
        opt.step(); sch.step()
        total += loss.item()
    return total / max(1, len(loader))

@torch.no_grad()
def validate(model, loader, stoi):
    '''Return the average over batches of the mean per-example label-smoothed loss.'''
    model.eval()
    batch_losses = []
    for _, src, tgt_in, tgt_y in loader:
        src = src.to(DEVICE)
        tgt_in = tgt_in.to(DEVICE)
        tgt_y = tgt_y.to(DEVICE)
        per_example = label_smoothing_loss(model(src, tgt_in), tgt_y, ignore_index=stoi[PAD])
        batch_losses.append(per_example.mean().item())
    return sum(batch_losses, 0.0) / max(1, len(batch_losses))

def main(args):
    '''Train the model with early stopping and keep the best checkpoint.

    Builds the vocabulary from the train and validation files, saves it as
    vocab.pt in `args.outdir`, then trains for up to `args.max_epochs` epochs.
    Whenever the validation loss improves, best.pt (weights plus vocabulary) is
    overwritten; training stops after `args.patience` epochs without improvement.
    '''
    stoi, itos = build_vocab([args.train, args.valid])
    os.makedirs(args.outdir, exist_ok=True)
    vocab_path = os.path.join(args.outdir, "vocab.pt")
    torch.save({"stoi": stoi, "itos": itos}, vocab_path)

    def collate_batch(batch):
        # Bind the vocabulary so the DataLoader can call this with a batch only.
        return collate(batch, stoi)

    train_set = RxDataset(args.train)
    valid_set = RxDataset(args.valid)
    dl_tr = DataLoader(train_set, batch_size=args.bsz, shuffle=True, collate_fn=collate_batch)
    dl_va = DataLoader(valid_set, batch_size=args.bsz, shuffle=False, collate_fn=collate_batch)

    # An empty map makes train_one_epoch fall back to online weighting.
    uq_map = {}
    if args.uq_csv:
        uq_map = load_uq_weights(args.uq_csv)

    model = MT(stoi).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1.0, betas=(0.9,0.98), weight_decay=0.01)
    sch = noam_lr(opt, d_model=512, warmup=8000)

    best_val = 1e9
    stale_epochs = 0
    for epoch in range(args.max_epochs):
        tr_loss = train_one_epoch(model, dl_tr, stoi, opt, sch, uq_map, args)
        va_loss = validate(model, dl_va, stoi)
        print(f"epoch {epoch} | train_loss {tr_loss:.4f} | val_loss {va_loss:.4f}")
        if va_loss < best_val:
            best_val = va_loss
            stale_epochs = 0
            torch.save({"model": model.state_dict(), "stoi": stoi, "itos": itos},
                       os.path.join(args.outdir, "best.pt"))
            continue
        stale_epochs += 1
        if stale_epochs >= args.patience:
            break

if __name__ == "__main__":
    # Command-line entry point for training.
    parser = argparse.ArgumentParser()

    # Input files and output directory.
    parser.add_argument("--train",  default="data/processed/train.txt")
    parser.add_argument("--valid",  default="data/processed/valid.txt")
    parser.add_argument("--outdir", default="runs/mt")
    parser.add_argument("--bsz", type=int, default=64)

    # Loss weighting: a precomputed CSV of scores, or online estimation.
    parser.add_argument("--uq_csv", default="")
    parser.add_argument("--mc_samples", type=int, default=0)
    parser.add_argument("--lambda_mix", type=float, default=0.5)
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--weighting", choices=["exp","linear","invert"], default="exp")
    parser.add_argument("--min_w", type=float, default=0.1)

    # Training length and early stopping.
    parser.add_argument("--max_epochs", type=int, default=8)
    parser.add_argument("--patience", type=int, default=5)

    args = parser.parse_args()
    main(args)
    """
write_script("retroB/mt/train.py", train_code)


In [None]:
infer_code = """
import argparse, torch, numpy as np, pandas as pd
import torch.nn.functional as F
from tokenizer import PAD, BOS, EOS, decode
from model import MT

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------------
# Combine Aleatoric + Epistemic
# ------------------------------
def compute_combined_uq(aleatoric, epistemic, method="weighted_sum", alea_weight=0.5, epis_weight=0.5):
    '''Merge an aleatoric and an epistemic uncertainty into one score.

    method:
      "weighted_sum"   -> alea_weight * aleatoric + epis_weight * epistemic
      "geometric_mean" -> aleatoric**a * epistemic**e, where a and e are the two
                          weights rescaled to sum to 1
      "rss"            -> sqrt(aleatoric**2 + epistemic**2); the weights are unused

    Raises ValueError for any other method name.
    '''
    if method == "weighted_sum":
        return alea_weight * aleatoric + epis_weight * epistemic

    if method == "geometric_mean":
        weight_sum = alea_weight + epis_weight
        alea_exponent = alea_weight / weight_sum
        epis_exponent = epis_weight / weight_sum
        return np.power(aleatoric, alea_exponent) * np.power(epistemic, epis_exponent)

    if method == "rss":
        return np.sqrt(np.square(aleatoric) + np.square(epistemic))

    raise ValueError(f"Unsupported combination method: {method}")

# ------------------------------
# Token entropy = Aleatoric proxy
# ------------------------------
def token_entropy(step_logits):  # [T,V]
    '''Shannon entropy (in nats) of the softmax over the last axis of `step_logits`.'''
    probs = F.softmax(step_logits, dim=-1)
    # Clamp before the log so zero probabilities contribute 0 rather than NaN.
    log_probs = torch.log(probs.clamp_min(1e-12))
    return -torch.sum(probs * log_probs, dim=-1)  # [T]

# ------------------------------
# Beam search decoder
# ------------------------------
@torch.no_grad()
def beam_search(model, src, beam=10, max_len=256, pad_id=0, bos_id=1, eos_id=2):
    '''Beam-search decode one source sequence.

    Args:
        model: callable as model(src, tgt_in), returning logits from which the
            last time step is taken via [:, -1, :].
        src: source tensor for a single example; its device is reused for the
            decoder input.
        beam: number of live hypotheses kept per step and the maximum number of
            hypotheses returned.
        max_len: maximum number of decoding steps.
        pad_id: accepted for interface compatibility; not used by the search.
        bos_id, eos_id: ids that start and terminate a hypothesis.

    Returns:
        Up to `beam` tuples (tokens, seq_logprob, step_logits), best first.
        `tokens` begins with bos_id. It ends with eos_id unless no hypothesis
        finished within `max_len`, in which case the unfinished live
        hypotheses are returned instead.
    '''
    beams = [([bos_id], 0.0, [])]
    finished = []
    for _ in range(max_len):
        new = []
        for tokens, lp, slogits in beams:
            inp = torch.tensor([tokens], device=src.device)
            step_logits = model(src, inp)[:, -1, :].squeeze(0)   # [V]
            logp = F.log_softmax(step_logits, -1)                # [V]
            # topk raises when k exceeds the vocabulary size, so cap it.
            k = min(beam, logp.size(-1))
            vals, idxs = torch.topk(logp, k=k)
            for addlp, tok in zip(vals.tolist(), idxs.tolist()):
                ntoks = tokens + [tok]
                nlp = lp + addlp
                nslog = slogits + [step_logits]
                if tok == eos_id:
                    finished.append((ntoks, nlp, nslog))
                else:
                    new.append((ntoks, nlp, nslog))
        beams = sorted(new, key=lambda x: x[1], reverse=True)[:beam]
        # Stop once enough hypotheses finished, or when nothing is left to expand.
        if len(finished) >= beam or not beams:
            break
    if not finished:
        finished = beams
    finished = sorted(finished, key=lambda x: x[1], reverse=True)[:beam]
    return finished

# ------------------------------
# Epistemic proxy via MC Dropout
# ------------------------------
def mc_dropout_var(model, src, bos_id, eos_id, passes=5, beam=1):
    '''Variance of the top-1 sequence log-probability over stochastic decodes.

    The model is put in train mode so dropout stays active, decoded `passes`
    times with beam_search, and then returned to whatever mode it was in on
    entry, even if decoding raises.

    Returns 0.0 when `passes` is 0 or negative (no samples to take a variance of).
    '''
    was_training = model.training
    model.train()  # keep dropout active
    vals = []
    try:
        with torch.no_grad():
            for _ in range(passes):
                beams = beam_search(model, src, beam=beam, bos_id=bos_id, eos_id=eos_id)
                vals.append(beams[0][1])   # take seq logprob of top-1
    finally:
        # Restore the caller's mode instead of forcing eval unconditionally.
        model.train(was_training)
    if not vals:
        return 0.0
    return np.var(vals)

# ------------------------------
# Main Inference
# ------------------------------
def main(args):
    '''Decode every reaction in `args.test` and write predictions with UQ scores.

    Each non-blank line of the test file must hold three tab-separated fields:
    reaction id, product, gold reactants. For every product the function runs
    MC-dropout once (epistemic score) and one beam search, then emits one row
    per returned beam with its mean token entropy (aleatoric score) and the
    combined UQ score.

    Outputs:
        args.outcsv                   - all rows as CSV
        <outcsv stem>_with_logits.json - the same rows as JSON records; only
                                         rank-1 rows carry a "logits" entry
    '''
    import os  # local import: this script's header does not import os

    ckpt = torch.load(args.ckpt, map_location=DEVICE)
    stoi, itos = ckpt["stoi"], ckpt["itos"]
    model = MT(stoi).to(DEVICE); model.load_state_dict(ckpt["model"]); model.eval()

    pad_id, bos_id, eos_id = stoi[PAD], stoi[BOS], stoi[EOS]

    def enc(prod_str):
        # Character-level encoding wrapped in BOS/EOS; raises KeyError on
        # characters that are absent from the checkpoint vocabulary.
        ids = [bos_id] + [stoi[c] for c in list(prod_str)] + [eos_id]
        return torch.tensor([ids], device=DEVICE)

    # Close the file deterministically and ignore blank lines, which would
    # otherwise break the three-way unpacking below.
    with open(args.test) as fh:
        rows = [l.rstrip("\\n").split("\t") for l in fh if l.strip()]

    out = []
    for rid, prod, gold in rows:
        src = enc(prod)

        # epistemic once per product
        epistemic = mc_dropout_var(model, src, bos_id, eos_id, passes=5)

        beams = beam_search(model, src, beam=args.beam, pad_id=pad_id, bos_id=bos_id, eos_id=eos_id)
        scores = np.array([b[1] for b in beams])
        # Softmax over the returned beams only: with a single beam the
        # confidence is therefore always 1.0.
        probs = np.exp(scores - scores.max()); probs = probs / probs.sum()

        for rank, (tokens, logp, slogits) in enumerate(beams, start=1):
            step_logits = torch.stack(slogits)             # [T,V]
            ents = token_entropy(step_logits).cpu().numpy()

            # Strip BOS, and strip the final token only when it really is EOS:
            # beams that hit max_len are returned unfinished.
            body = tokens[1:]
            if body and body[-1] == eos_id:
                body = body[:-1]
            pred = decode(body, itos)

            # Aleatoric proxy = mean entropy
            aleatoric = float(ents.mean())

            # Combined UQ using chosen method
            UQ_score = compute_combined_uq(
                aleatoric, float(epistemic),
                method=args.uq_method,
                alea_weight=args.alea_weight,
                epis_weight=args.epis_weight
            )

            row = {
                "reaction_id": rid,
                "product": prod,
                "rank": rank,
                "pred_reactants": pred,
                "seq_logprob": float(logp),
                "aleatoric_entropy": aleatoric,
                "epistemic_var": float(epistemic),
                "UQ_score": float(UQ_score),
                "confidence": float(probs[rank-1]),
                "gold_reactants": gold
            }

            if rank == 1:  # save logits only for top beam
                row["logits"] = step_logits.cpu().numpy().tolist()

            out.append(row)

    frame = pd.DataFrame(out)

    out_dir = os.path.dirname(args.outcsv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Derive the JSON name from the stem so it can never equal the CSV path,
    # whatever extension --outcsv carries.
    stem, _ext = os.path.splitext(args.outcsv)
    frame.to_json(stem + "_with_logits.json", orient="records")
    frame.to_csv(args.outcsv, index=False)

# ------------------------------
# CLI
# ------------------------------
if __name__ == "__main__":
    # Command-line entry point for inference.
    parser = argparse.ArgumentParser()

    # Checkpoint, input file, beam width and output location.
    parser.add_argument("--ckpt",  default="runs/mt/best.pt")
    parser.add_argument("--test",  default="data/processed/test.txt")
    parser.add_argument("--beam",  type=int, default=10)
    parser.add_argument("--outcsv", default="out/mt_preds_with_uq.csv")

    # How the two uncertainty terms are merged into UQ_score.
    parser.add_argument(
        "--uq_method",
        type=str,
        default="weighted_sum",
        choices=["weighted_sum", "geometric_mean", "rss"],
        help="Method for combining aleatoric & epistemic UQ",
    )
    parser.add_argument(
        "--alea_weight",
        type=float,
        default=0.5,
        help="Weight for aleatoric (for weighted_sum/geometric_mean)",
    )
    parser.add_argument(
        "--epis_weight",
        type=float,
        default=0.5,
        help="Weight for epistemic (for weighted_sum/geometric_mean)",
    )

    args = parser.parse_args()
    main(args)
"""
write_script("retroB/mt/infer.py", infer_code)


In [None]:
metrics_code = """
import argparse, pandas as pd, numpy as np, json, torch, os
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rdkit import Chem
from rdkit import RDLogger

# --------------------
# Suppress RDKit warnings (helper)
# --------------------
RDLogger.DisableLog("rdApp.*")

# --------------------
# Canonicalize SMILES (helper)
# --------------------
def canonicalize(smiles):
    '''Return the canonical SMILES for `smiles`, or None if it cannot be parsed.

    Any exception raised while parsing or writing is treated as "invalid".
    '''
    try:
        parsed = Chem.MolFromSmiles(smiles)
        result = None if parsed is None else Chem.MolToSmiles(parsed, canonical=True)
    except Exception:
        result = None
    return result

# --------------------
# Top-k Accuracy (metric)
# --------------------
def topk(df):
    '''Top-1/5/10 exact-match accuracy after SMILES canonicalisation.

    A reaction counts as solved at k when any of its rows with rank <= k has a
    canonical prediction equal to the canonical gold reactants. Returns a dict
    with keys "top1", "top5" and "top10".
    '''
    def acc_at(k):
        within_k = df[df["rank"] <= k]
        pred_canon = within_k["pred_reactants"].apply(canonicalize)
        gold_canon = within_k["gold_reactants"].apply(canonicalize)
        scored = within_k.assign(correct=(pred_canon == gold_canon).astype(int))
        # One hit among the first k ranks is enough for the reaction.
        solved = scored.groupby("reaction_id")["correct"].max()
        return solved.mean()

    results = {}
    for k in [1,5,10]:
        results[f"top{k}"] = float(acc_at(k))
    return results

# --------------------
# Expected Calibration Error (metric)
# --------------------
def ece(conf, corr, n_bins=15):
    '''Expected Calibration Error over `n_bins` equal-width confidence bins.

    Args:
        conf: confidences in [0, 1] (array-like).
        corr: matching 0/1 correctness indicators (array-like).
        n_bins: number of bins covering [0, 1].

    Returns:
        The sum over bins of (fraction of samples in the bin) *
        |mean correctness - mean confidence|, as a float. 0.0 for empty input.

    Bins are (lower, upper], except the first, which also includes 0.0 so that
    samples with confidence exactly 0 are not silently dropped.
    '''
    conf = np.asarray(conf, dtype=float)
    corr = np.asarray(corr, dtype=float)
    if conf.size == 0:
        return 0.0

    bins = np.linspace(0., 1., n_bins + 1)
    e = 0.0
    for i in range(n_bins):
        above_lower = (conf >= bins[i]) if i == 0 else (conf > bins[i])
        m = above_lower & (conf <= bins[i + 1])
        if m.any():
            e += m.mean() * abs(corr[m].mean() - conf[m].mean())
    return float(e)

# --------------------
# Brier Score (metric)
# --------------------
def brier(conf, corr):
    '''Mean squared difference between confidence and 0/1 correctness.'''
    outcomes = corr.astype(float)
    squared_error = np.square(conf - outcomes)
    return float(np.mean(squared_error))

# --------------------
# Reliability Plot (helper)
# --------------------
def reliability_plot(conf, corr, n_bins=10, title="Reliability Diagram", outpath=None):
    '''Draw a reliability diagram: per-bin accuracy bars and mean confidence.

    Args:
        conf: confidences in [0, 1] (array-like).
        corr: matching 0/1 correctness indicators (array-like).
        n_bins: number of equal-width bins over [0, 1].
        title: figure title.
        outpath: file to save the figure to; when falsy the figure is shown.

    Empty bins are drawn with accuracy 0 and confidence 0.
    '''
    conf = np.asarray(conf, dtype=float)
    corr = np.asarray(corr, dtype=float)

    bins = np.linspace(0, 1, n_bins+1)
    # digitize places conf == 1.0 one past the last bin (and anything below 0
    # before the first); clip so those samples land in the edge bins instead
    # of vanishing from the plot.
    binids = np.clip(np.digitize(conf, bins) - 1, 0, n_bins - 1)
    accuracies, confidences = [], []

    for b in range(n_bins):
        idx = binids == b
        if np.any(idx):
            accuracies.append(corr[idx].mean())
            confidences.append(conf[idx].mean())
        else:
            accuracies.append(0.0)
            confidences.append(0.0)

    bin_width = 1.0 / n_bins
    plt.figure(figsize=(5,5))
    plt.plot([0,1],[0,1],"--",color="gray")
    # align="edge" makes each bar span its own bin rather than straddling the
    # left edge; the markers sit at the bin centres for any n_bins.
    plt.bar(bins[:-1], accuracies, width=bin_width, align="edge", alpha=0.6, label="Accuracy")
    plt.plot(bins[:-1] + bin_width / 2, confidences, "o-", color="red", label="Confidence")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(title)
    plt.legend()
    if outpath:
        out_dir = os.path.dirname(outpath)
        if out_dir:  # a bare filename has no directory to create
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(outpath, bbox_inches="tight")
    else:
        plt.show()
    plt.close()

# --------------------
# Temperature Scaling (calibration)
# --------------------
def set_temperature(logits, labels):
    '''Fit a scalar temperature T by minimising cross-entropy of logits / T.

    T starts at 1.5 and is optimised with a single LBFGS step call (up to 50
    inner iterations, lr 0.01). Returns the fitted temperature as a float.
    '''
    criterion = torch.nn.CrossEntropyLoss()
    temperature = torch.nn.Parameter(torch.ones(1) * 1.5)
    lbfgs = torch.optim.LBFGS([temperature], lr=0.01, max_iter=50)

    def closure():
        # LBFGS re-evaluates the objective several times per step.
        lbfgs.zero_grad()
        nll_value = criterion(logits / temperature, labels)
        nll_value.backward()
        return nll_value

    lbfgs.step(closure)
    return temperature.item()

# --------------------
# Retro-BLEU (plausibility metric)
# --------------------
def retro_bleu(preds, refs, n_gram=4):
    '''Mean character-level sentence BLEU between predictions and references.

    Pairs where either side is None are skipped. Uniform n-gram weights up to
    `n_gram` and NLTK smoothing method1 are used; a ZeroDivisionError for a
    pair scores it 0.0. Returns 0.0 when no pair was scored.
    '''
    smoother = SmoothingFunction().method1
    collected = []
    for hyp, ref in zip(preds, refs):
        if hyp is None or ref is None:
            continue
        hyp_chars = list(hyp)
        ref_chars = list(ref)
        try:
            uniform = tuple([1/n_gram]*n_gram)
            value = sentence_bleu([ref_chars], hyp_chars,
                                  weights=uniform,
                                  smoothing_function=smoother)
        except ZeroDivisionError:
            value = 0.0
        collected.append(value)
    if not collected:
        return 0.0
    return float(np.mean(collected))

# --------------------
# Plausibility Score (valid SMILES fraction)
# --------------------
def plausibility_score(smiles_list):
    '''Fraction of the non-None entries that RDKit can parse as SMILES.

    None entries are excluded from both numerator and denominator. Returns 0.0
    when there is nothing left to score.
    '''
    candidates = [s for s in smiles_list if s is not None]
    if not candidates:
        return 0.0
    parsed_ok = sum(1 for s in candidates if Chem.MolFromSmiles(s) is not None)
    return parsed_ok / len(candidates)

# --------------------
# Main Script
# --------------------
if __name__ == "__main__":
    # Evaluate one predictions CSV (plus its logits JSON) and write a metrics JSON.
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", default="out/mt_preds_with_uq.csv")
    ap.add_argument("--json_logits", default="out/mt_preds_with_uq_with_logits.json")
    ap.add_argument("--out", default="out/metrics.json")
    args = ap.parse_args()

    # Everything this run produces goes next to --out. A bare filename has no
    # directory component, so only create one when it exists.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plot_dir = out_dir or "."

    # Load data (helper)
    df = pd.read_csv(args.csv)
    with open(args.json_logits, "r") as f:
        data_logits = json.load(f)

    res = topk(df)

    # Top-1 correctness (metric). Canonicalise once and reuse below. A
    # prediction that fails to parse is never counted as correct.
    top1 = df[df["rank"] == 1].copy()
    preds = [canonicalize(s) for s in top1["pred_reactants"].tolist()]
    refs  = [canonicalize(s) for s in top1["gold_reactants"].tolist()]
    top1["correct"] = [int(p is not None and p == r) for p, r in zip(preds, refs)]

    # Before calibration (metrics + plot)
    conf = top1["confidence"].to_numpy()
    corr = top1["correct"].to_numpy()
    res["ece_top1_before"]   = ece(conf, corr)
    res["brier_top1_before"] = brier(conf, corr)
    reliability_plot(conf, corr, n_bins=10,
                     title="Reliability (Before Calibration)",
                     outpath=os.path.join(plot_dir, "reliability_before.png"))

    # Prepare logits + labels (helper)
    logits_list, labels = [], []
    for row in data_logits:
        if row["rank"] == 1 and row.get("logits"):
            step_logits = np.array(row["logits"])  # [T,V]
            final_logits = step_logits[-1]
            logits_list.append(final_logits)
            cand = canonicalize(row["pred_reactants"])
            gold = canonicalize(row["gold_reactants"])
            # Plain `==` would score None == None as a match.
            labels.append(1 if (cand is not None and cand == gold) else 0)

    if logits_list:
        logits_tensor = torch.tensor(np.stack(logits_list), dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        labels_arr = np.array(labels)

        # Temperature scaling (calibration)
        # NOTE(review): `labels` are 0/1 correctness flags, yet set_temperature
        # feeds them to CrossEntropyLoss as class indices into the last-step
        # vocabulary logits. The fitted T and the "after" numbers therefore do
        # not calibrate sequence-level correctness; revisit before relying on them.
        T = set_temperature(logits_tensor, labels_tensor)
        scaled_probs = torch.softmax(logits_tensor / T, dim=1).detach().numpy()
        calibrated_conf = scaled_probs.max(axis=1)

        # After calibration (metrics + plot)
        res["ece_top1_after"]   = ece(calibrated_conf, labels_arr)
        res["brier_top1_after"] = brier(calibrated_conf, labels_arr)
        reliability_plot(calibrated_conf, labels_arr, n_bins=10,
                         title=f"Reliability (After Calibration, T={T:.2f})",
                         outpath=os.path.join(plot_dir, "reliability_after.png"))
    else:
        # np.stack raises on an empty list; report and carry on instead.
        print("No top-1 logits found in", args.json_logits, "- skipping temperature scaling")

    # Plausibility metrics (metrics)
    res["retro_bleu_top1"] = retro_bleu(preds, refs)
    # Share of top-1 predictions RDKit can parse. Scoring the canonicalised
    # list with plausibility_score would skip every invalid prediction (they
    # are None by then) and always report 1.0.
    res["plausibility_top1"] = (
        sum(p is not None for p in preds) / len(preds) if preds else 0.0
    )

    # Save metrics (helper)
    with open(args.out, "w") as f:
        json.dump(res, f, indent=2)
    print(json.dumps(res, indent=2))
"""

write_script("retroB/mt/metrics.py", metrics_code)

In [None]:
from google.colab import files
uploaded = files.upload()

Saving raw_test.csv to raw_test.csv
Saving raw_train.csv to raw_train.csv
Saving raw_val.csv to raw_val.csv


In [None]:
import shutil

shutil.move("raw_train.csv", "retroB/data/raw_train.csv")
shutil.move("raw_val.csv", "retroB/data/raw_val.csv")
shutil.move("raw_test.csv", "retroB/data/raw_test.csv")

'retroB/data/raw_test.csv'

In [None]:
# Run the preprocessing script once per raw CSV; all three runs share one --outdir.
# NOTE(review): the cell that defines this script describes it as splitting ONE CSV
# 80/10/10 into train/valid/test text files. If that holds, each run below rewrites
# the same files in retroB/data/processed and only the last run (raw_test.csv)
# survives. TODO confirm; if so, give each run its own --outdir or merge the CSVs first.
!python3 retroB/data/prepare_uspto.py \
    --csv retroB/data/raw_train.csv \
    --outdir retroB/data/processed

!python3 retroB/data/prepare_uspto.py \
    --csv retroB/data/raw_val.csv \
    --outdir retroB/data/processed

!python3 retroB/data/prepare_uspto.py \
    --csv retroB/data/raw_test.csv \
    --outdir retroB/data/processed

In [None]:
!head -n 190 retroB/data/processed/train.txt > retroB/data/processed/train_debug.txt

In [None]:
!head -n 190 retroB/data/processed/test.txt > retroB/data/processed/test_debug.txt

In [None]:
!head -n 190 retroB/data/processed/valid.txt > retroB/data/processed/valid_debug.txt

In [None]:
with open("retroB/data/processed/train.txt") as f:
    for i in range(5):
        print(f.readline().strip())

849	COc1ncc(-c2ccc3ncc4nnc(C)n4c3c2)cc1N	COc1ncc(B2OC(C)(C)C(C)(C)O2)cc1N.Cc1nnc2cnc3ccc(Br)cc3n12
4185	CCCCCCCCCCCCCCCCCCCCCCOc1ccc(CO)c(OCCCCCCCCCCCCCCCCCCCCCC)c1	CCCCCCCCCCCCCCCCCCCCCCOc1ccc(C=O)c(OCCCCCCCCCCCCCCCCCCCCCC)c1
4373	CCCn1c(COCC)nc2c(N)nc3cc(OC4CCN(C(=O)NC(C)C)CC4)ccc3c21	CC(C)N=C=O.CCCn1c(COCC)nc2c(N)nc3cc(OC4CCNCC4)ccc3c21
4715	Oc1ccc(-c2nnc(CSCCOc3ccccc3)o2)cc1	O=C(CSCCOc1ccccc1)NNC(=O)c1ccc(O)cc1
654	CC(C)(C)OC(=O)N1CCC(C(=O)Nc2cc(OCc3ccccc3)ccc2Br)CC1	CC(C)(C)OC(=O)N1CCC(C(=O)O)CC1.Nc1cc(OCc2ccccc2)ccc1Br


In [None]:
from rdkit import Chem

def canonicalize(smi):
    """Return the canonical SMILES for `smi`, or None if it cannot be parsed.

    Errors raised by RDKit are treated as "invalid". Only `Exception` is
    caught, so KeyboardInterrupt and SystemExit still propagate.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        # Test against None explicitly rather than relying on Mol truthiness.
        if mol is not None:
            return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None
    return None

print(canonicalize("OCC"))   # should print "CCO"
print(canonicalize("CCO"))   # should also print "CCO"


CCO
CCO


In [None]:
import pandas as pd
raw = pd.read_csv("retroB/data/raw_train.csv")
print(raw.head())

                id class                      reactants>reagents>production
0       US05849732   UNK  O=C(OCc1ccccc1)[NH:1][CH2:2][CH2:3][CH2:4][CH2...
1  US20120114765A1   UNK  O[C:1](=[O:2])[c:3]1[cH:4][c:5]([N+:6](=[O:7])...
2     US08003648B2   UNK  O=[CH:1][c:2]1[cH:3][cH:4][c:5](-[c:6]2[n:7][c...
3     US09045475B2   UNK  O=[C:1]([CH2:2][F:3])[CH2:4][F:5].[CH3:6][C:7]...
4     US08188098B2   UNK  Cl[C:1](=[O:2])[O:3][CH:4]1[CH2:5][CH2:6][CH2:...


In [None]:
import sys
sys.path.append("./retroB/mt")

from tokenizer import decode
from model import MT
import torch

ckpt = torch.load("retroB/runs/mt_standardce/best.pt", map_location="cpu")
stoi, itos = ckpt["stoi"], ckpt["itos"]
model = MT(stoi); model.load_state_dict(ckpt["model"]); model.eval()

# Inspect one sample from train.txt
with open("retroB/data/processed/train.txt") as f:
    rid, product, gold_reactants = f.readline().strip().split("\t")

print("Product:", product)
print("Gold reactants:", gold_reactants)

Product: COc1ncc(-c2ccc3ncc4nnc(C)n4c3c2)cc1N
Gold reactants: COc1ncc(B2OC(C)(C)C(C)(C)O2)cc1N.Cc1nnc2cnc3ccc(Br)cc3n12


In [None]:
# Baseline with standard cross-entropy
!python3 retroB/mt/train.py \
  --train retroB/data/processed/train_debug.txt \
  --valid retroB/data/processed/valid_debug.txt \
  --outdir retroB/runs/mt_standardce \
  --bsz 8 \
  --max_epochs 5 \
  --patience 5


python3: can't open file '/content/retroB/mt/train.py': [Errno 2] No such file or directory


In [None]:
import torch
from tokenizer import decode, PAD, BOS, EOS
from model import MT

ckpt = torch.load("retroB/runs/mt_standardce/best.pt", map_location="cpu")
stoi, itos = ckpt["stoi"], ckpt["itos"]
model = MT(stoi); model.load_state_dict(ckpt["model"]); model.eval()

def enc(prod_str):
    """Encode a product string character by character, wrapped in BOS/EOS.

    Returns a [1, len(prod_str) + 2] tensor of token ids. Raises KeyError for
    characters that are not in the checkpoint vocabulary.
    """
    ids = [stoi[BOS]] + [stoi[c] for c in prod_str] + [stoi[EOS]]
    return torch.tensor([ids])

product = "CCO"  # ethanol
src = enc(product)

# Greedy autoregressive decode: start from BOS and feed the model its own
# output one token at a time. Passing the source as the decoder input would
# only score the product against itself, not generate reactants.
MAX_DECODE_LEN = 256
generated = [stoi[BOS]]
with torch.no_grad():
    for _ in range(MAX_DECODE_LEN):
        tgt_in = torch.tensor([generated])
        next_id = model(src, tgt_in)[:, -1, :].argmax(-1).item()
        if next_id == stoi[EOS]:
            break
        generated.append(next_id)

pred = decode(generated[1:], itos)  # drop the leading BOS

print("Predicted:", pred)


Predicted: 3333


In [None]:
## Baseline with labelsmoothing cross-entropy
!python3 retroB/mt/train.py \
  --train retroB/data/processed/train.txt \
  --valid retroB/data/processed/valid.txt \
  --outdir retroB/runs/mt_labelsmce \
  --bsz 64 \
  --max_epochs 5

epoch 0 | val_loss 1.2139
epoch 1 | val_loss 0.9564
epoch 2 | val_loss 0.8720
epoch 3 | val_loss 0.8320
epoch 4 | val_loss 0.7991


In [None]:
## Baseline with UQ and standard cross-entropy (online UQ)
!python3 retroB/mt/train.py \
  --train retroB/data/processed/train_debug.txt \
  --valid retroB/data/processed/valid_debug.txt \
  --outdir retroB/runs/mt_uq_online_ce \
  --bsz 64 \
  --max_epochs 5 \
  --mc_samples 6 \
  --lambda_mix 0.5 \
  --alpha 1.0 \
  --weighting exp \
  --min_w 0.1

epoch 0 | train_loss 2.5840 | val_loss 4.0902
epoch 1 | train_loss 2.5009 | val_loss 4.0549
epoch 2 | train_loss 2.6223 | val_loss 3.9992
epoch 3 | train_loss 2.5906 | val_loss 3.9243
epoch 4 | train_loss 2.4536 | val_loss 3.8320


In [None]:
## Baseline with UQ and labelsmoothing cross-entropy
!python3 retroB/mt/train.py \
  --train retroB/data/processed/train.txt \
  --valid retroB/data/processed/valid.txt \
  --outdir retroB/runs/mt_uq_labelsmce \
  --bsz 64 \
  --max_epochs 5 \
  --uq_csv retroB/data/processed/train_uq_scores.csv

In [None]:
## Baseline and Standard Cross-Entropy
!python3 retroB/mt/infer.py \
    --ckpt retroB/runs/mt_uq_online_ce/best.pt \
    --test retroB/data/processed/test_debug.txt \
    --beam 1 \
    --uq_method geometric_mean \
    --alea_weight 0.7 \
    --epis_weight 0.3 \
    --outcsv retroB/out/mt_uq_preds.csv

Traceback (most recent call last):
  File "/content/retroB/mt/infer.py", line 150, in <module>
    args = ap.parse_args(); main(args)
                            ^^^^^^^^^^
  File "/content/retroB/mt/infer.py", line 75, in main
    ckpt = torch.load(args.ckpt, map_location=DEVICE)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1484, in load
    with _open_file_like(f, "rb") as opened_file:
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 759, in _open_file_like
    return _open_file(name_or_buffer, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 740, in __init__
    super().__init__(open(name, mode))
                     ^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'retroB/runs/mt_uq_online_ce/best.pt'


In [None]:
## Baseline and Label Smoothing Cross-Entropy
!python3 retroB/mt/infer.py \
  --ckpt retroB/runs/mt_labelsmce/best.pt \
  --test retroB/data/processed/test_debug.txt \
  --beam 1 \
  --outcsv retroB/out/mt_labelsmce_preds.csv

In [None]:
## Baseline and UQ and Standard Cross-Entropy
!python3 retroB/mt/infer.py \
  --ckpt retroB/runs/mt_uq_online_ce/best.pt \
  --test retroB/data/processed/test_debug.txt \
  --beam 1 \
  --outcsv retroB/out/mt_uq_ce_preds.csv

In [None]:
## Baseline and UQ and Label Smoothing Cross-Entropy
!python3 retroB/mt/infer.py \
  --ckpt retroB/runs/mt_uq_labelsmce/best.pt \
  --test retroB/data/processed/test_debug.txt \
  --beam 1 \
  --outcsv retroB/out/mt_uq_labelsmce_preds.csv

In [None]:
##METRICS PART BELOW SECTIONED IN 4 PARTS

In [None]:
## Baseline and Standard Cross-Entropy
!python3 retroB/mt/metrics.py \
  --csv retroB/out/mt_standardce_preds.csv \
  --json_logits retroB/out/mt_standardce_preds_with_logits.json \
  --out retroB/out/mt_standardce_metrics.json

Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  loss = float(closure())
{
  "top1": 0.0,
  "top5": 0.0,
  "top10": 0.0,
  "ece_top1_before": 1.0,
  "brier_top1_before": 1.0,
  "ece_top1_after": 0.8438736224174499,
  "brier_top1_after": 0.750165443642741,
  "retro_bleu_top1": 0.7864217469357003,
  "plausibility_top1": 1.0
}


In [None]:
## Baseline and Labeling Smoothing Cross-Entropy
!python3 retroB/mt/metrics.py \
  --csv retroB/out/mt_labelsmce_preds.csv \
  --json_logits retroB/out/mt_labelsmce_preds_with_logits.json \
  --out retroB/out/mt_labelsmce_metrics.json

Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  loss = float(closure())
{
  "top1": 0.0,
  "top5": 0.0,
  "top10": 0.0,
  "ece_top1_before": 1.0,
  "brier_top1_before": 1.0,
  "ece_top1_after": 0.40727545738220217,
  "brier_top1_after": 0.1715270954949243,
  "retro_bleu_top1": 0.6462103208837113,
  "plausibility_top1": 1.0
}


In [None]:
## Baseline and UQ and Standard Cross-Entropy
!python3 retroB/mt/metrics.py \
  --csv retroB/out/mt_uq_ce_preds.csv \
  --json_logits retroB/out/mt_uq_ce_preds_with_logits.json \
  --out retroB/out/mt_uq_ce_metrics.json

Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  loss = float(closure())
{
  "top1": 0.0,
  "top5": 0.0,
  "top10": 0.0,
  "ece_top1_before": 1.0,
  "brier_top1_before": 1.0,
  "ece_top1_after": 0.047933172434568405,
  "brier_top1_after": 0.0023018148618795965,
  "retro_bleu_top1": 0.00023626519801865363,
  "plausibility_top1": 1.0
}


In [None]:
## Baseline and UQ and Label Smoothing Cross-Entropy
!python3 retroB/mt/metrics.py \
  --csv retroB/out/mt_uq_labelsmce_preds.csv \
  --json_logits retroB/out/mt_uq_labelsmce_preds.json \
  --out retroB/out/mt_uq_labelsmce_metrics.json

In [None]:
## Each *_preds.csv holds one row per (reaction, beam rank) with the columns written by infer.py:
## reaction_id, product, rank, pred_reactants, seq_logprob, aleatoric_entropy, epistemic_var,
## UQ_score, confidence, gold_reactants (rank-1 rows also carry the per-step logits).
## Each *_metrics.json holds the aggregate metrics computed by metrics.py from that CSV:
## top1/top5/top10 accuracy, ECE and Brier score before and after temperature scaling,
## retro_bleu_top1 and plausibility_top1.

import os
from google.colab import files

# Download whatever the inference/metrics cells above actually produced; paths
# are relative to the notebook's working directory, like the cells that wrote them.
RUN_NAMES = ["mt_standardce", "mt_labelsmce", "mt_uq_ce", "mt_uq_labelsmce"]
for run_name in RUN_NAMES:
    for result_path in (f"retroB/out/{run_name}_preds.csv",
                        f"retroB/out/{run_name}_metrics.json"):
        if os.path.exists(result_path):
            files.download(result_path)
        else:
            print("skipping missing file:", result_path)