# TEM-1 Target Finder (Demo)
Predict binding to TEM-1 β-lactamase using small pretrained embeddings (ESM-2 for protein, ChemBERTa for ligands) + lightweight heads (XGBoost, LogisticRegression).

Outputs **pAff** (−log10 of the affinity in molar; the training rows previewed in section 3 are IC50 values), calibrated **P(binder)**, and simple graphs/heatmaps; runs on CPU.

**Why this matters:** fast, docking-free triage of small molecules against a resistance enzyme.

## 1) Setup/Install

In [16]:
!pip -q install --no-cache-dir \
  torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
!pip -q install --no-cache-dir "transformers==4.44.2" "tokenizers==0.19.1" \
  pandas==2.2.2 scikit-learn xgboost gradio requests

In [17]:
import sys, platform, numpy, torch, sklearn, transformers, tokenizers, xgboost, pandas, gradio
print("Python:", sys.version.split()[0], "|", platform.platform())
print("NumPy:", numpy.__version__)
print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
# Report the tokenizers package version itself, not the transformers version
print("Transformers:", transformers.__version__, "| Tokenizers:", tokenizers.__version__)
print("scikit-learn:", sklearn.__version__, "| XGBoost:", xgboost.__version__)
print("pandas:", pandas.__version__, "| Gradio:", gradio.__version__)

Python: 3.11.13 | Linux-6.1.123+-x86_64-with-glibc2.35
NumPy: 2.0.2
PyTorch: 2.3.1+cpu | CUDA available: False
Transformers: 4.44.2 | Tokenizers: 0.19.1
scikit-learn: 1.6.1 | XGBoost: 3.0.3
pandas: 2.2.2 | Gradio: 5.41.0


## 2) Fetch protein sequence & generate ESM-2 embedding

Fetch and clean the TEM-1 sequence from UniProt

In [18]:
# Fetch TEM-1 beta-lactamase protein sequence from UniProt
import requests, re

UNIPROT_ID = "P62593"
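resp = requests.get(f"https://rest.uniprot.org/uniprotkb/{UNIPROT_ID}.fasta", timeout=30)
resp.raise_for_status()  # fail loudly on HTTP errors instead of parsing an error page
fasta = resp.text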

# Remove FASTA header and non-amino acid characters
TEM1_SEQ = "".join(line.strip() for line in fasta.splitlines() if not line.startswith(">"))
TEM1_SEQ = re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", TEM1_SEQ.upper())

print(f"TEM-1 length: {len(TEM1_SEQ)} amino acids")


TEM-1 length: 286 amino acids
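
To check the cleaning logic without a network call, the same two steps (drop header lines, keep only the 20 standard residues) can be run on a toy FASTA record. This is a minimal sketch: the record is made up for illustration and is not the UniProt entry, and `clean_fasta` is a hypothetical helper name.

```python
import re

def clean_fasta(fasta_text: str) -> str:
    """Drop FASTA header lines, join the rest, keep only the 20 standard residues."""
    body = "".join(
        line.strip() for line in fasta_text.splitlines() if not line.startswith(">")
    )
    return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", body.upper())

# Made-up record: lowercase letters, a space, a gap (-), an ambiguous residue (X), a stop (*)
toy = ">sp|TOY|EXAMPLE made-up record\nmsiqhf rva\nLIPF-XAAF*\n"
cleaned = clean_fasta(toy)
print(cleaned, len(cleaned))
```

Anything outside the 20-letter alphabet is silently removed, so a length that differs from the expected one is a cue to inspect the raw response.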


Generate a protein embedding with ESM-2 (35M)

In [19]:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Select device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the lightweight ESM-2 model (12 layers, 35M parameters)
tok_p = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
mdl_p = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D").to(device).eval()

# Encode sequence into a mean-pooled embedding
with torch.inference_mode():
    toks = tok_p(TEM1_SEQ, return_tensors="pt", add_special_tokens=True).to(device)
    rep = mdl_p(**toks).last_hidden_state[0, 1:-1, :].mean(dim=0).cpu().numpy()

# Store embedding as float32 (~480-D vector)
prot_vec = rep.astype(np.float32)

print("Protein embedding shape:", prot_vec.shape)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Protein embedding shape: (480,)
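
The slice `[0, 1:-1, :]` in the cell above drops the first and last token positions (the special tokens the tokenizer adds around the sequence) before averaging over residues. A minimal sketch of that pooling on a toy matrix, with plain lists standing in for the model output; the numbers and the `mean_pool_residues` name are illustrative only.

```python
def mean_pool_residues(hidden):
    """Average per-token vectors over residues, skipping the first and last (special) tokens."""
    residues = hidden[1:-1]
    return [sum(col) / len(residues) for col in zip(*residues)]

# Toy "last_hidden_state" for one sequence: 5 tokens x 3 dims
toy_hidden = [
    [9.0, 9.0, 9.0],   # start token (ignored)
    [1.0, 2.0, 3.0],   # residue 1
    [3.0, 4.0, 5.0],   # residue 2
    [5.0, 6.0, 7.0],   # residue 3
    [9.0, 9.0, 9.0],   # end token (ignored)
]
pooled = mean_pool_residues(toy_hidden)
print(pooled)  # [3.0, 4.0, 5.0]
```

The output has one value per hidden dimension regardless of sequence length, which is why the 286-residue protein becomes a single 480-element vector.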


## 3) Load ligand dataset & prepare scaffolds

In [20]:
# Upload CSV file from local machine
import pandas as pd
from google.colab import files
uploaded = files.upload()  # pick tem1_clean.csv

csv_path = [k for k in uploaded.keys() if k.endswith(".csv")][0]
df = pd.read_csv(csv_path)

# Compute Bemis–Murcko scaffolds if missing (this branch needs RDKit: pip install rdkit)
if "scaffold" not in df.columns:
    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold

    def murcko_scaffold(smi):
        m = Chem.MolFromSmiles(str(smi))
        return MurckoScaffold.MurckoScaffoldSmiles(mol=m) if m else None
    df["scaffold"] = df["smiles"].apply(murcko_scaffold)

# Data sanity check
needed = {"smiles","pAff"}
missing = needed - set(df.columns)
assert not missing, f"Missing columns: {missing}"
df = df.dropna(subset=["smiles","pAff"]).reset_index(drop=True)
print(df.shape)
df.head(3)

Saving tem1_clean.csv to tem1_clean.csv
(316, 11)


| | smiles | affinity_nM | pAff | affinity_type | uniprot_id | target_name | organism | ligand_name | ligand_id | pmid | scaffold |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | `[O-]C(=O)C1=CS[C@H]2N1C(=O)\C2=C/c1cn2CCOCc2n1` | 0.4 | 9.39794 | IC50 | P62593 | Beta-lactamase TEM | | CHEMBL212163::sodium (R,E)-6-((6,8-dihydro-5H-... | 50191378 | 16854068.0 | `O=C1C(=Cc2cn3c(n2)COCC3)C2SC=CN12` |
| 1 | `[O-]C(=O)C1=CS[C@H]2N1C(=O)\C2=C\c1cnc2COCCn12` | 0.4 | 9.39794 | IC50 | P62593 | Beta-lactamase TEM | | CHEMBL263746::Sodium; (R)-6-[1-(5,6-dihydro-8H... | 50149468 | 15214794.0 | `O=C1C(=Cc2cnc3n2CCOC3)C2SC=CN12` |
| 2 | `CC1=C[C@H](N2C[C@@H]1N(OC(F)(F)C(O)=O)C2=O)C(N)=O` | 0.47 | 9.327902 | IC50 | P62593 | Beta-lactamase TEM | | 2-(((2S,5R)-2-carbamoyl-4-methyl-7-oxo-1,6-dia... | 467000 | | `O=C1NC2C=CCN1C2` |
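
The `pAff` column is the negative base-10 logarithm of `affinity_nM` after converting to mol/L; in the rows previewed here `affinity_type` is IC50. A short sketch that reproduces the preview values, assuming that relationship holds for every row (`nM_to_pAff` is a hypothetical helper, not part of the notebook):

```python
import math

def nM_to_pAff(affinity_nM: float) -> float:
    """pAff = -log10(affinity in mol/L); 1 nM = 1e-9 mol/L."""
    return -math.log10(affinity_nM * 1e-9)

# 0.4 and 0.47 nM come from the preview rows; 1000 nM (1 µM) is the binder cutoff
for nM in (0.4, 0.47, 1000.0):
    print(f"{nM} nM -> pAff {nM_to_pAff(nM):.6f}")
```

Each unit of pAff is a tenfold change in concentration, so pAff 6 corresponds to 1 µM and pAff 9 to 1 nM.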


## 4) Ligand Embedding with ChemBERTa & Feature Combination

Load ChemBERTa model

In [21]:
import numpy as np, torch, time
from transformers import AutoTokenizer, AutoModel

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load ChemBERTa (ligand encoder)
tok_l = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mdl_l = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device).eval()



Define embedding function for SMILES

In [22]:
def chemberta_embed(smiles_list, batch_size=64, max_length=256):
    """
    Generate ligand embeddings from SMILES strings using ChemBERTa.
    - Uses CLS token representation as molecule embedding.
    - Processes molecules in batches for speed.
    """
    vecs = []
    for i in range(0, len(smiles_list), batch_size):
        batch = smiles_list[i:i+batch_size]
        enc = tok_l(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
        with torch.inference_mode():
            out = mdl_l(**enc).last_hidden_state  # Shape: [B, L, D]
            cls = out[:, 0, :].detach().cpu().numpy().astype(np.float32)  # CLS token vector
        vecs.append(cls)
    return np.vstack(vecs)
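
The loop in `chemberta_embed` walks the SMILES list in steps of `batch_size`, and Python slicing clips the final batch automatically. A small sketch of the resulting index ranges for the 316 ligands in this dataset; `batch_slices` is a hypothetical helper used only to make the arithmetic visible.

```python
import math

def batch_slices(n_items: int, batch_size: int):
    """Yield (start, stop) index pairs that cover n_items in order."""
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)

slices = list(batch_slices(316, 64))
print(len(slices), slices[-1])  # 5 (256, 316)

# The number of batches is ceil(n_items / batch_size)
assert len(slices) == math.ceil(316 / 64)
```

With `padding=True`, each batch is padded to its own longest sequence, so smaller batches can reduce wasted computation when SMILES lengths vary widely.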


Embed ligands & combine with protein embedding

In [23]:
# Generate ligand embeddings (~768-D)
t0 = time.time()
lig_X = chemberta_embed(df["smiles"].tolist(), batch_size=64)
print("Ligand embed:", lig_X.shape, f"in {time.time()-t0:.1f}s")

# Repeat protein embedding for each ligand (~480-D)
prot_X = np.repeat(prot_vec.reshape(1, -1), len(df), axis=0)

# Combine protein and ligand embeddings (~1248-D feature space)
X = np.hstack([prot_X, lig_X]).astype(np.float32)

# Labels: continuous pAff and binary binder classification
y = df["pAff"].values.astype(np.float32)
y_bin = (y >= 6.0).astype(np.int32)
print("X:", X.shape, "| binders:", int(y_bin.sum()))


Ligand embed: (316, 768) in 40.7s
X: (316, 1248) | binders: 170


## 5) Scaffold-/Cluster-Aware Data Splitting & Model Training

Why:
To ensure realistic evaluation, we split the dataset so that structurally similar ligands (same scaffold or same cluster) do not appear in both train and test sets. This prevents data leakage and ensures that the model is tested on truly novel chemical structures.

We then train two models:

*   XGBoost Regressor → Predict continuous binding affinity (pAff)
*   Logistic Regression → Predict binary "binder vs non-binder" labels

Group-wise train/test split

In [24]:
import numpy as np, sklearn
from sklearn.cluster import KMeans

def groupwise_split(groups, test_frac=0.2, seed=7):
    """
    Splits dataset into train/test groups based on scaffolds or clusters.
    Ensures no scaffold appears in both sets.
    """
    rng = np.random.default_rng(seed)
    buckets = {}
    for i, g in enumerate(groups):
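        # Only missing/blank labels become singleton groups; integer cluster label 0 is a valid group
        key = str(g) if (g is not None and str(g).strip()) else f"None_{i}"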
        buckets.setdefault(key, []).append(i)
    keys = list(buckets.keys())
    rng.shuffle(keys)

    test_idx, taken = [], 0
    N = len(groups)
    for k in keys:
        if taken / N < test_frac:
            test_idx.extend(buckets[k])
            taken += len(buckets[k])
        else:
            break

    train_idx = sorted(set(range(N)) - set(test_idx))
    return np.array(train_idx), np.array(test_idx)

# Group by scaffold (if available) or cluster in embedding space
if "scaffold" in df.columns and df["scaffold"].notna().sum() > 0:
    groups = df["scaffold"].fillna("").tolist()
else:
    k = max(5, min(50, len(df) // 50))  # adaptive # clusters
    km = KMeans(n_clusters=k, random_state=7, n_init=10)
    groups = km.fit_predict(lig_X).tolist()

# Perform group-wise split
tr_idx, te_idx = groupwise_split(groups, test_frac=0.2, seed=7)

# Create split datasets
X_tr, X_te = X[tr_idx], X[te_idx]
y_tr, y_te = y[tr_idx], y[te_idx]
yb_tr, yb_te = y_bin[tr_idx], y_bin[te_idx]

print(f"train={len(tr_idx)} (binders {yb_tr.sum()}) | test={len(te_idx)} (binders {yb_te.sum()})")

train=241 (binders 141) | test=75 (binders 29)
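
`groupwise_split` adds whole groups to the test set while the running fraction is still below `test_frac`, so the realized test share can overshoot the target: here 75 of 316 rows, about 23.7% rather than 20%. The sketch below shows one way to confirm that a split is leak-free and to compute that fraction. `group_overlap` and the toy labels are hypothetical and not part of the notebook.

```python
def group_overlap(groups, train_idx, test_idx):
    """Return the set of group labels that occur on both sides of a split."""
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    return train_groups & test_groups

# Toy example: 6 ligands in 3 scaffold groups (labels are illustrative)
toy_groups = ["A", "A", "B", "C", "C", "C"]
good = group_overlap(toy_groups, train_idx=[0, 1, 2], test_idx=[3, 4, 5])
leaky = group_overlap(toy_groups, train_idx=[0, 2, 3], test_idx=[1, 4, 5])
print("clean split overlap:", good)
print("leaky split overlap:", sorted(leaky))

# Counts reported in the split output: 75 test rows out of 316
test_fraction = 75 / 316
print(f"realized test fraction: {test_fraction:.3f}")
```

In the notebook this would be called as `group_overlap(groups, tr_idx, te_idx)`. Rows with an empty scaffold are bucketed individually by the split, so they would need to be excluded before the check.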


Train models & evaluate

In [26]:
from xgboost import XGBRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error, r2_score, roc_auc_score, average_precision_score

reg = XGBRegressor(
    n_estimators=600, max_depth=6, learning_rate=0.05,
    subsample=0.8, colsample_bytree=0.8, n_jobs=-1
).fit(X_tr, y_tr)

pred = reg.predict(X_te)
try:
    rmse = mean_squared_error(y_te, pred, squared=False)
except TypeError:
    rmse = mean_squared_error(y_te, pred) ** 0.5
r2 = r2_score(y_te, pred)

clf = LogisticRegression(max_iter=2000).fit(X_tr, yb_tr)
p_bin = clf.predict_proba(X_te)[:, 1]
roc = roc_auc_score(yb_te, p_bin)
pr  = average_precision_score(yb_te, p_bin)

print({"RMSE": round(float(rmse),3),
       "R2": round(float(r2),3),
       "ROC-AUC": round(float(roc),3),
       "PR-AUC": round(float(pr),3)})

{'RMSE': 1.996, 'R2': 0.227, 'ROC-AUC': 0.968, 'PR-AUC': 0.933}
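
Because pAff is a log10 scale, an RMSE in pAff units translates into a multiplicative error in concentration. The worked conversion below uses the reported RMSE and is the same `10**rmse` factor the Gradio metrics line displays.

```python
rmse_paff = 1.996  # held-out RMSE reported in the metrics output, in pAff (log10) units

# An error of d log10 units multiplies or divides the concentration by 10**d
fold_error = 10 ** rmse_paff
print(f"typical error: about {fold_error:.0f}-fold in concentration")
```

Read together with R² = 0.227, this suggests the regressor is better suited to coarse ranking than to point estimates of affinity, while the classifier separates binders from non-binders well on this split (ROC-AUC 0.968).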


## 6) Final Gradio UI — Prediction logic, uncertainty, calibration

Helpers — units & prediction-interval lookup

In [27]:
# Helpers: pAff ↔ concentration (nM) formatting
import numpy as np

def pAff_to_nM(p):
    return 1e9 * (10 ** (-p))

def fmt_conc(nM):
    # Pick the unit that keeps the printed number in a readable range
    if nM < 1:      return f"{nM*1e3:.2f} pM"
    if nM < 1e3:    return f"{nM:.2f} nM"
    if nM < 1e6:    return f"{nM/1e3:.2f} µM"
    return f"{nM/1e6:.2f} mM"

# Ensure test preds exist (used to build conditional 90% intervals)
try:
    pred
except NameError:
    pred = reg.predict(X_te)

# Conditional 90% absolute error by predicted pAff bin
bins = np.linspace(float(pred.min()), float(pred.max()), 8)
bin_idx = np.digitize(pred, bins)
global_q90 = float(np.quantile(np.abs(y_te - pred), 0.90))

q90_table = []
for b in range(len(bins)+1):
    m = bin_idx == b
    if m.sum() >= 15:
        q90_table.append(float(np.quantile(np.abs(y_te[m] - pred[m]), 0.90)))
    else:
        q90_table.append(global_q90)

def q90_for(p):
    i = int(np.digitize([p], bins)[0])
    i = max(0, min(i, len(q90_table)-1))
    return q90_table[i]
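
Eight bin edges give nine possible `np.digitize` indices (0 through 8), which is why the table loop runs over `range(len(bins)+1)`. Since the edges span exactly the minimum and maximum test prediction, index 0 is empty on the test set and always falls back to the global quantile, and index 8 holds only predictions equal to the maximum. The sketch below illustrates the indexing with the standard library: for increasing edges, `bisect_right` returns the same index as `np.digitize(p, edges)` with its default `right=False`. The edge range is hypothetical.

```python
from bisect import bisect_right

def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi inclusive (stand-in for np.linspace)."""
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]

# Hypothetical predicted-pAff range, chosen so the edges are whole numbers
edges = linspace(2.0, 9.0, 8)

def bin_index(p):
    # Same index np.digitize(p, edges) would return for increasing edges
    return bisect_right(edges, p)

indices = [bin_index(p) for p in (1.0, 2.0, 6.5, 9.0, 9.5)]
print(indices)  # [0, 1, 5, 8, 8]
```

New predictions below or above the test range land in index 0 or 8, so they receive the global 90% error rather than a bin-specific one.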

Calibration & distribution-shift check

In [28]:
# Calibrate classifier probabilities (isotonic, 3-fold cross-validation on the training data)
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics.pairwise import cosine_similarity
import torch

HI, MID = 0.80, 0.60  # Likely / Uncertain thresholds

# Avoid name shadowing: keep original 'clf' and create a calibrated wrapper
clf_cal = CalibratedClassifierCV(clf, method="isotonic", cv=3).fit(X_tr, yb_tr)

def conf_label(p): return "Likely" if p >= HI else ("Uncertain" if p >= MID else "Unlikely")
def conf_emoji(p): return "🟢" if p >= HI else ("🟡" if p >= MID else "🔴")

# In-distribution check: nearest ligand similarity in training set (ChemBERTa space)
prot_dim = prot_vec.shape[0]
lig_tr   = X_tr[:, prot_dim:]  # ligand features only

def train_similarity(smiles):
    enc = tok_l([smiles], padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.inference_mode():
        lig = mdl_l(**enc).last_hidden_state[:,0,:].cpu().numpy().astype(np.float32)
    sim = cosine_similarity(lig, lig_tr)[0]
    return float(sim.max())
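
The out-of-distribution cue is the highest cosine similarity between the query ligand embedding and any training ligand embedding. A dependency-free sketch of that computation on toy 2-D vectors; the values and helper names are illustrative, and the notebook itself uses `sklearn.metrics.pairwise.cosine_similarity` on the 768-D ChemBERTa vectors.

```python
import math

def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def max_train_similarity(query, train_vectors):
    """Highest cosine similarity between the query and any training vector."""
    return max(cosine(query, t) for t in train_vectors)

# Toy 2-D "embeddings" (illustrative values only)
train_vecs = [[1.0, 0.0], [0.0, 1.0]]
best = max_train_similarity([3.0, 4.0], train_vecs)
print(round(best, 2))  # 0.8
```

Cosine similarity ignores vector length, so it measures direction in embedding space only. A high value means the query resembles some training ligand in ChemBERTa space, which is not the same as sharing a scaffold.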

Predictor used by the UI

In [29]:
# Predict for a single SMILES: pAff + 90% PI + calibrated P(binder) + similarity note
def predict_smiles(smiles: str):
    if not smiles:
        return "Please enter a SMILES", ""

    # 1) ChemBERTa ligand embedding
    enc = tok_l([smiles], padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.inference_mode():
        out = mdl_l(**enc).last_hidden_state
        lig = out[:, 0, :].detach().cpu().numpy().astype(np.float32)

    # 2) Build joint feature with TEM-1 (WT) protein embedding
    fx = np.hstack([prot_vec.reshape(1, -1), lig]).astype(np.float32)

    # 3) Regression → pAff and conditional 90% interval
    p_aff = float(reg.predict(fx)[0])
    q90   = q90_for(p_aff)
    p_lo, p_hi = p_aff - q90, p_aff + q90

    # Pretty concentration readouts
    nM_center = pAff_to_nM(p_aff)
    nM_hi, nM_lo = pAff_to_nM(p_hi), pAff_to_nM(p_lo)

    # 4) Calibrated binder probability and label
    p_cal = float(clf_cal.predict_proba(fx)[0, 1])  # index a scalar; float() on a 1-element array is deprecated in NumPy
    label = conf_label(p_cal); mark = conf_emoji(p_cal)
    badge = " (≤1 µM)" if p_aff >= 6 else ""

    # 5) Training-set similarity (OOD cue)
    sim = train_similarity(smiles)
    sim_note = (f"\nNearest training-set similarity: {sim:.2f}"
                if sim >= 0.60 else
                f"\n⚠️ Low training similarity (cosine={sim:.2f}) — higher uncertainty.")

    # 6) Final message (compact but informative)
    msg = (
        f"{mark} **pAff={p_aff:.2f}** (≈ {fmt_conc(nM_center)}) • "
        f"90%≈[{p_lo:.2f}, {p_hi:.2f}] (≈ {fmt_conc(nM_hi)}–{fmt_conc(nM_lo)})\n"
        f"**P(binder)={p_cal:.2f} → {label}**{badge} "
        f"[Likely≥{HI:.2f}, Uncertain {MID:.2f}–{HI:.2f}, Unlikely<{MID:.2f}]"
        + sim_note
    )
    return msg, smiles
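
Because a higher pAff means a lower concentration, the two ends of the pAff interval swap when converted to nM, which is why `predict_smiles` prints `nM_hi` before `nM_lo`. A worked example with illustrative numbers (not model output), reusing the same conversion as `pAff_to_nM`:

```python
def pAff_to_nM(p):
    """Same conversion as the notebook helper: pAff -> concentration in nM."""
    return 1e9 * (10 ** (-p))

p_aff, q90 = 7.0, 1.0          # illustrative values, not model output
p_lo, p_hi = p_aff - q90, p_aff + q90

# The upper pAff bound is the tight-binding (low concentration) end
tight_end_nM = pAff_to_nM(p_hi)   # about 10 nM
weak_end_nM = pAff_to_nM(p_lo)    # about 1000 nM, i.e. 1 µM
print(f"pAff [{p_lo:.1f}, {p_hi:.1f}] -> {tight_end_nM:.0f} nM to {weak_end_nM:.0f} nM")
```

A symmetric interval of ±1 pAff unit therefore covers a 100-fold range in concentration.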

Helpers for Batch Prediction & Parsing

In [30]:
import re, numpy as np, matplotlib.pyplot as plt
import torch

def _parse_smiles_block(text, limit=100):
    smi = [s.strip() for s in re.split(r'[\n,;]+', str(text or "")) if s.strip()]
    return smi[:limit]

def _embed_ligands(smiles_list):
    enc = tok_l(smiles_list, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.inference_mode():
        out = mdl_l(**enc).last_hidden_state
        ligs = out[:, 0, :].detach().cpu().numpy().astype(np.float32)
    return ligs  # (L, D)

def batch_predict(smiles_text):
    smi = _parse_smiles_block(smiles_text)
    if not smi:
        return [], np.array([]), np.array([])
    lig = _embed_ligands(smi)                              # (L, Dl)
    P   = np.repeat(prot_vec.reshape(1, -1), len(smi), 0)  # (L, Dp)
    X   = np.hstack([P, lig]).astype(np.float32)           # (L, Dp+Dl)
    p_aff  = reg.predict(X)
    p_bind = clf_cal.predict_proba(X)[:, 1]  # calibrated, to match predict_smiles and the "Calibrated P(binder)" axis label
    return smi, p_aff, p_bind
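
The batch helpers accept free-form text, so it helps to see exactly how the splitting behaves. The sketch below restates the parsing rule from `_parse_smiles_block` as a standalone function (the public name `parse_smiles_block` is only for this example) and runs it on a few inputs.

```python
import re

def parse_smiles_block(text, limit=100):
    """Split on newlines, commas, or semicolons; drop blanks; cap the list length."""
    items = [s.strip() for s in re.split(r"[\n,;]+", str(text or "")) if s.strip()]
    return items[:limit]

raw = "CCO, c1ccccc1\n\nCC(=O)O;\n"
parsed = parse_smiles_block(raw)
print(parsed)                            # ['CCO', 'c1ccccc1', 'CC(=O)O']
print(parse_smiles_block(None))          # []
print(parse_smiles_block(raw, limit=2))  # ['CCO', 'c1ccccc1']
```

The parser does not validate chemistry: any non-empty token is passed to the tokenizer, so malformed SMILES still produce a prediction.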


Plotting Functions

In [31]:
def plot_paff_bars(smiles_text, top_n=20, paff_threshold=6.0):
    names, paff, pbind = batch_predict(smiles_text)
    fig = plt.figure(figsize=(10, max(3, 0.35 * max(1, len(names)))))
    if len(names) == 0:
        plt.text(0.5, 0.5, "Provide one or more SMILES", ha="center", va="center")
        plt.axis("off"); return fig

    idx = np.argsort(-paff)[:top_n]
    names = [names[i] for i in idx]; paff = paff[idx]; pbind = pbind[idx]

    y = np.arange(len(names))
    plt.barh(y, paff)
    plt.axvline(paff_threshold, linestyle="--")  # ≈ 1 µM
    for i, (x, pb) in enumerate(zip(paff, pbind)):
        plt.text(x, i, f"  p={pb:.2f}", va="center")
    plt.yticks(y, [n[:45] + ("…" if len(n) > 45 else "") for n in names])
    plt.xlabel("Predicted pAff  (−log10 M) — higher = tighter")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig

def plot_paff_vs_pbind(smiles_text, hi=0.80, mid=0.60, paff_thr=6.0):
    names, paff, pbind = batch_predict(smiles_text)
    fig = plt.figure(figsize=(7, 5))
    if len(names) == 0:
        plt.text(0.5, 0.5, "Provide one or more SMILES", ha="center", va="center")
        plt.axis("off"); return fig

    plt.scatter(paff, pbind, s=36)
    plt.axvline(paff_thr, linestyle="--"); plt.axhline(hi, linestyle="--"); plt.axhline(mid, linestyle="--")
    # annotate a few promising points
    top = np.argsort(-(paff + pbind))[:10]
    for i in top:
        lbl = names[i][:18] + ("…" if len(names[i]) > 18 else "")
        plt.annotate(lbl, (paff[i], pbind[i]), xytext=(4, 4), textcoords="offset points")
    plt.xlabel("Predicted pAff (−log10 M)"); plt.ylabel("Calibrated P(binder)")
    plt.title("Batch predictions"); plt.tight_layout()
    return fig

def plot_eval_true_vs_pred():
    try:
        y_true = y_te
        y_pred = reg.predict(X_te)
    except Exception:
        fig = plt.figure(figsize=(6, 2))
        plt.text(0.5, 0.5, "No held-out set available in this session.", ha="center", va="center")
        plt.axis("off"); return fig
    fig = plt.figure(figsize=(5, 5))
    plt.scatter(y_true, y_pred, s=20, alpha=0.7)
    lo = float(min(y_true.min(), y_pred.min())); hi = float(max(y_true.max(), y_pred.max()))
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel("True pAff"); plt.ylabel("Predicted pAff"); plt.title("Held-out evaluation")
    plt.tight_layout(); return fig

Heatmap Predictor: Batch SMILES Affinity Map (WT TEM-1)

In [10]:
import re, numpy as np, torch, matplotlib.pyplot as plt

def heatmap_predict(smiles_text):
    smi_list = [s.strip() for s in re.split(r'[\n,;]+', str(smiles_text)) if s.strip()]
    smi_list = smi_list[:20]
    if not smi_list:
        return None

    # Batch ChemBERTa embeddings
    enc = tok_l(smi_list, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.inference_mode():
        out = mdl_l(**enc).last_hidden_state
        ligs = out[:, 0, :].detach().cpu().numpy().astype(np.float32)  # (L, D)

    pv = prot_vec.reshape(1, -1)
    pv_rep = np.repeat(pv, len(smi_list), axis=0)
    fx = np.hstack([pv_rep, ligs]).astype(np.float32)
    p_affs = reg.predict(fx)  # (L,)

    M = p_affs.reshape(1, -1)  # single row: WT only

    fig, ax = plt.subplots(figsize=(max(6, len(smi_list)*0.8), 2.8))
    im = ax.imshow(M, aspect="auto")
    ax.set_xticks(range(len(smi_list)))
    ax.set_xticklabels([s[:14] + ("…" if len(s) > 14 else "") for s in smi_list], rotation=45, ha="right")
    ax.set_yticks([0]); ax.set_yticklabels(["TEM-1 (WT)"])
    cbar = fig.colorbar(im, ax=ax); cbar.set_label("Predicted pAff")

    for j in range(M.shape[1]):
        if M[0, j] >= 6.0:
            ax.text(j, 0, "★", ha="center", va="center", fontsize=8)

    ax.set_xlabel("Ligands"); ax.set_ylabel("Variant")
    fig.tight_layout()
    return fig

## 7) Interactive Gradio App (Demo)

In [32]:
import gradio as gr, numpy as np, torch

# --- tolerant access to thresholds/metrics so UI renders even if they don't exist yet ---
HI_T  = float(globals().get("HI", 0.80))
MID_T = float(globals().get("MID", 0.60))
_rmse = globals().get("rmse", None)
_r2   = globals().get("r2", None)
_roc  = globals().get("roc", None)
_pr   = globals().get("pr", None)

metrics_md = (
    f"**Eval (held-out)** — RMSE: {_rmse:.2f} pAff (≈×{10**_rmse:.1f}), "
    f"R²: {_r2:.2f}, ROC-AUC: {_roc:.2f}, PR-AUC: {_pr:.2f}"
    if all(v is not None for v in [_rmse, _r2, _roc, _pr])
    else "*(Train a model or run evaluation to populate metrics here.)*"
)

# --- a tiny helper to clear inputs/outputs (nice for live demos) ---
def _clear_inputs():
    return "", "", "",""

with gr.Blocks(title="Antibiotic Resistance Target Finder — TEM-1") as demo:
    # ===== Header / quick intro =====
    gr.Markdown(
        """
# Antibiotic Resistance Target Finder — TEM-1
**Goal:** Predict how tightly a small molecule binds wild-type **TEM-1 β-lactamase** (an antibiotic resistance enzyme).

**How to use (2 steps):**
1) Paste a **SMILES** string and click **Submit** to get a prediction.
2) Explore **heatmaps** and **binding-affinity graphs** for batches of SMILES.

*Protein embeddings:* ESM-2 (35M) • *Ligand embeddings:* ChemBERTa • *Models:* small XGBoost + LogisticRegression
        """
    )

    # ===== INPUT + PREDICTION =====
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 1) Enter a molecule (SMILES)")
            inp = gr.Textbox(
                label="SMILES",
                placeholder="e.g., CCO    (ethanol)",
                lines=1,
            )
            gr.Examples(
                label="Try examples",
                examples=[
                    ["CCO"],  # unlikely binder
                    ["[O-]C(=O)C1=CS[C@H]2N1C(=O)\\C2=C/c1cn2CCOCc2n1"],  # strong-ish example
                    ["CC1=C[C@H](N2C[C@@H]1N(OC(F)(F)C(O)=O)C2=O)C(N)=O"], # moderate example
                ],
                inputs=[inp],
            )

            with gr.Row():
                btn = gr.Button("🚀 Submit", variant="primary")
                clr = gr.Button("Clear")

        with gr.Column(scale=3):
            gr.Markdown("### 2) Model output")
            out_pred = gr.Markdown(label="Prediction")
            out_smi  = gr.Textbox(label="Echoed SMILES", interactive=False)

            # Quick legend so judges instantly understand labels
            gr.Markdown(
                f"""
**How to read this:**
- **pAff** = −log10(Kd) in molar (higher ⇒ tighter).
  6 ≈ 1 µM, 7 ≈ 100 nM, 8 ≈ 10 nM, 9 ≈ 1 nM.
- **P(binder)** is a calibrated probability.
  We label as **Likely** (≥ {HI_T:.2f}), **Uncertain** ({MID_T:.2f}–{HI_T:.2f}), **Unlikely** (< {MID_T:.2f}).
- We also show a nearest-neighbor similarity to flag **distribution shift** (low similarity ⇒ higher uncertainty).
                """
            )

    # wire prediction + clear
    btn.click(predict_smiles, [inp], [out_pred, out_smi])
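    # Clear the input and both outputs; the number of returned values must match the number of outputs
    clr.click(lambda: _clear_inputs()[:3], inputs=None, outputs=[inp, out_pred, out_smi])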

    # ===== EXPLANATION ACCORDIONS =====
    with gr.Accordion("What is pAff and why −log10?", open=False):
        gr.Markdown(
            """
**pAff** is the negative log10 of affinity in molar (e.g., Kd).
We use −log10(Kd) because it’s easy to compare (bigger is better), and it keeps values in a compact 5–10 range.

**Examples**
- 1 µM → pAff=6
- 100 nM → pAff=7
- 10 nM → pAff=8
            """
        )

    with gr.Accordion("How this model works (1-paragraph)", open=False):
        gr.Markdown(
            """
**Embeddings:** ESM-2 (35M) encodes the protein; ChemBERTa encodes the ligand.
We concatenate embeddings and train (a) **XGBoost** for pAff and (b) **LogisticRegression** for P(binder).
This keeps compute tiny while leveraging powerful pre-training.
            """
        )

    # ===== HEATMAP (WT only, quick batch view) =====
    with gr.Accordion("Ligand heatmap (WT only)", open=False):
        gr.Markdown(
            "Paste multiple SMILES to see a quick heatmap of predicted pAff (TEM-1 wild-type)."
        )
        smi_multi = gr.Textbox(
            label="SMILES list (comma or newline separated)",
            lines=4,
            placeholder="CCO\nc1ccccc1\n[O-]C(=O)C1=CS[C@H]2N1C(=O)\\C2=C/c1cn2CCOCc2n1",
            value="CCO\nc1ccccc1\n[O-]C(=O)C1=CS[C@H]2N1C(=O)\\C2=C/c1cn2CCOCc2n1",
        )
        hm_btn = gr.Button("Build heatmap")
        hm_plot = gr.Plot(label="pAff heatmap")
        hm_btn.click(heatmap_predict, [smi_multi], hm_plot)

    # ===== BINDING-AFFINITY GRAPHS (bar / scatter / eval) =====
    with gr.Accordion("Binding-affinity graphs", open=True):
        gr.Markdown(
            """
**Why these graphs?**
- **Bar chart:** rank molecules by predicted pAff (quick top-N).
- **Scatter:** shows pAff vs **P(binder)** (agreement helps trust; disagreements signal uncertainty).
- **Eval plot (optional):** if ground truth is loaded, compare predicted vs true.
            """
        )
        smi_batch = gr.Textbox(
            label="SMILES list (newline or comma separated)",
            lines=4,
            placeholder="CCO\nc1ccccc1\n[O-]C(=O)C1=CS[C@H]2N1C(=O)\\C2=C/c1cn2CCOCc2n1",
            value="CCO\nc1ccccc1\n[O-]C(=O)C1=CS[C@H]2N1C(=O)\\C2=C/c1cn2CCOCc2n1",
        )
        with gr.Row():
            bar_btn  = gr.Button("Top-N pAff bar chart")
            scat_btn = gr.Button("pAff vs P(binder) scatter")
            eval_btn = gr.Button("Eval: true vs predicted")
        out_plot = gr.Plot(label="Graph")

        # wire actions (your plotting fns must already exist)
        bar_btn.click(lambda s: plot_paff_bars(s, top_n=20), [smi_batch], out_plot)
        scat_btn.click(plot_paff_vs_pbind, [smi_batch], out_plot)
        eval_btn.click(lambda: plot_eval_true_vs_pred(), [], out_plot)

    # ===== Model card / metrics / limitations =====
    with gr.Accordion("Model card: assumptions, metrics & limits", open=False):
        gr.Markdown(
            f"""
**Compute footprint:** small (≤50M embeddings + lightweight heads). Runs on CPU in Colab/Spaces.
{metrics_md}

**Assumptions / caveats**
- Trained on **TEM-1** datasets; predictions for very dissimilar chemotypes are less certain.
- Reported “confidence” is **calibrated** with 3-fold cross-validation on the training set; not a substitute for wet-lab validation.
- Use as a **ranking/triage** tool, not as a definitive activity claim.
            """
        )

demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://cda9106cdd3e369ee0.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


