<a href="https://colab.research.google.com/github/vaishnavey/CALVADOS_poly/blob/main/RealKcat_Inference_Interface_Class_Prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Robust Prediction of Enzyme Variant Kinetics with RealKcat**

This notebook predicts ranges for the enzyme kinetic parameters $k_{\text{cat}}$ and $K_M$ for given protein sequences and substrates. It uses the RealKcat model, which takes the enzyme's amino acid sequence and the substrate's isomeric SMILES string as input.
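RealKcat reports each prediction as a class index, and each class stands for the range between two consecutive bin edges. Below is a minimal sketch of that lookup, using the bin edges that appear in the mechanistic (ESM-C) branch of the inference cell. `class_to_range` is a hypothetical helper written for illustration, not a function in the notebook.

```python
# Bin edges copied from the mechanistic (ESM-C) branch of the inference cell.
# Class i spans edges[i] .. edges[i + 1].
KCAT_BINS = [0, 1e-08, 1e-02, 1e-01, 1e+00, 1e+01, 1e+02, 1e+03, 1e+08]
KM_BINS = [1e-14, 1e-05, 1e-04, 1e-03, 1e-02, 1e-01, 1e+04]


def class_to_range(cls: int, edges: list) -> tuple:
    """Return the (low, high) bounds that a predicted class index stands for."""
    if not 0 <= cls < len(edges) - 1:
        raise ValueError(f"class {cls} is outside 0..{len(edges) - 2}")
    return float(edges[cls]), float(edges[cls + 1])


print(class_to_range(4, KCAT_BINS))  # (1.0, 10.0)
print(class_to_range(2, KM_BINS))    # (0.0001, 0.001)
```

With these edges there are 8 kcat classes (0 to 7) and 6 KM classes (0 to 5). The ESM-2 (binary) branch defines its own, slightly different class boundaries.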

_No coding experience required. Collapse cells to keep the notebook clean._  

---

## **How to Use**
1. **Select Mode**: `Demo`, `Interactive`, `Bulk`, or `Bulk-large`  
2. **Select Mutation Pathway**:  
   - **Mechanistic Mutation-Aware (default)** – realistic graded mutation effects (recommended)  
   - **Binary Over-Simplified (poor and optional)** – legacy binary alanine handling for benchmarking only  
3. **Run the cells sequentially below**  
4. **Follow on-screen prompts** (enter input or upload CSV)  
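5. **Download results** (written to `inference_results.csv` in the Colab working directory; download it from the Files panel if it is not downloaded automatically)  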
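For `Bulk` and `Bulk-large`, the CSV needs two columns named exactly `sequence` and `Isomeric SMILES`. The input cell strips whitespace from both. Below is a minimal sketch of writing such a file with Python's standard `csv` module. The file name `my_pairs.csv` and both rows are hypothetical placeholders; the sequences are short fragments, not real enzymes.

```python
import csv

# Hypothetical example rows: replace with your own enzyme sequences and substrates.
rows = [
    {"sequence": "MKVAVLGAAGGIGQALALLLK", "Isomeric SMILES": "C(C(=O)C(=O)O)C(=O)O"},
    {"sequence": "MGVEQILKRKTGVIVGEDVHN", "Isomeric SMILES": "C([C@H](C=O)O)OP(=O)(O)O"},
]


def write_input_csv(path: str, rows: list) -> None:
    """Write enzyme-substrate pairs in the two-column layout the input cell reads."""
    with open(path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["sequence", "Isomeric SMILES"])
        writer.writeheader()
        writer.writerows(rows)


write_input_csv("my_pairs.csv", rows)
```

Upload the resulting file when the input cell prompts for it. Extra columns, such as an ID, appear to be carried through to `inference_results.csv` unchanged, because the notebook appends its prediction columns to the uploaded table.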

> ⚠️ Each time you run the inference cell, Colab will:  
> - Install required dependencies  
> - Download pretrained model files (~2 minutes setup)  

---

## **Modes**
- **`Demo`** → Test run with predefined inputs (~2 min)  
- **`Interactive`** → Manually enter up to 10 enzyme–substrate pairs  
- **`Bulk`** → Upload CSV with ≤10 pairs  
  - [Sample CSV](https://drive.google.com/uc?export=download&id=1X9bR67NW-sTNHaKU4W1Htlfxhl-OKMmu)  
- **`Bulk-large`** → For >10 pairs (default batch size 20; can be slow, so a GPU is recommended).  
  - Outputs extended results, including ±1-class error bounds (“e-accuracy”) that widen the predicted range by one class on each side (one order of magnitude per side for the decade-wide classes), to reflect physiological variability (pH, cofactors, _in vitro_ vs _in vivo_).  
  - **If your sequences are long (more than ~500 aa), start with a very small BATCH_SIZE (1–2) to avoid running out of RAM in Google Colab, then increase it gradually.**
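The ±1-class bounds extend the predicted class one bin down and one bin up, clamped at the first and last classes. The inference cell does this with `class_ranges_kcat[max(ck-1,0)]["low"]` and `class_ranges_kcat[min(ck+1,max_kcat_c)]["high"]`. Below is a minimal sketch of the same arithmetic, assuming the kcat bin edges from the mechanistic branch. `widened_range` is a hypothetical name used only here.

```python
# kcat bin edges from the mechanistic (ESM-C) branch of the inference cell.
KCAT_BINS = [0, 1e-08, 1e-02, 1e-01, 1e+00, 1e+01, 1e+02, 1e+03, 1e+08]


def widened_range(cls: int, edges: list) -> tuple:
    """Return (low, high) covering classes cls-1 .. cls+1, clamped at the end classes."""
    n_classes = len(edges) - 1
    lo_cls = max(cls - 1, 0)              # one class down, but not below class 0
    hi_cls = min(cls + 1, n_classes - 1)  # one class up, but not above the last class
    return float(edges[lo_cls]), float(edges[hi_cls + 1])


print(widened_range(4, KCAT_BINS))  # (0.1, 100.0)
print(widened_range(0, KCAT_BINS))  # (0.0, 0.01)
```

For an interior class such as 4 (1 to 10), the widened range spans 0.1 to 100. For the end classes only one side can widen.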

---

## **Limits**
- Maximum sequence length: **1022 amino acids**  
- Maximum SMILES length: **512 characters**  
- Entries exceeding these will be marked as `"skipped"` in the output.  
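The length checks only run after dependencies and models have been set up, so pre-screening a CSV locally can save a wasted run. The sketch below applies the character limits listed above after stripping whitespace, as the notebook does. `precheck` is a hypothetical helper. Treat it as a first filter only: the inference cell also skips pairs whose SMILES exceed 512 tokens after tokenization, or whose embedding fails.

```python
import re

MAX_SEQ_LENGTH = 1022     # amino acids, from the limits above
MAX_SMILES_LENGTH = 512   # characters, from the limits above


def precheck(sequence: str, smiles: str) -> str:
    """Return "ok", or the reason a pair would be skipped (character-length checks only)."""
    seq = re.sub(r"\s+", "", sequence).upper()
    smi = re.sub(r"\s+", "", smiles)
    if len(seq) > MAX_SEQ_LENGTH:
        return f"sequence too long ({len(seq)} aa > {MAX_SEQ_LENGTH})"
    if len(smi) > MAX_SMILES_LENGTH:
        return f"SMILES too long ({len(smi)} chars > {MAX_SMILES_LENGTH})"
    return "ok"


# The two oversized Demo inputs
print(precheck("M" * 1023, "C(C=O)"))       # sequence too long (1023 aa > 1022)
print(precheck("MKKVAV", "C" + "C" * 512))  # SMILES too long (513 chars > 512)
```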


In [None]:
#@title Select RealKcat Mode of Inference, $k_{cat}$ and $K_{M}$ [Demo takes ~4-10 minutes]
mode = "Bulk-large"  #@param ["Demo", "Interactive", "Bulk", "Bulk-large"]
#@title Select mutation handling pathway
mutation_mode = "Mechanistic Mutation-Aware (default)"  #@param ["Mechanistic Mutation-Aware (default)", "Binary Over-Simplified (poor and optional)"]
import pandas as pd, io, os, sys
CSV_OUT = "infer_input.csv"
BATCH_SIZE = 20
try:
    from google.colab import files as _colab_files
    IS_COLAB = True
except Exception:
    _colab_files = None
    IS_COLAB = False
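def _clean(df):
    """Check the required columns, then strip all whitespace from sequences and SMILES."""
    missing = [c for c in ("sequence", "Isomeric SMILES") if c not in df.columns]
    if missing: raise ValueError(f"Input is missing required column(s): {missing}")
    df = df.copy()
    df["sequence"] = df["sequence"].astype(str).str.replace(r"\s+", "", regex=True)
    df["Isomeric SMILES"] = df["Isomeric SMILES"].astype(str).str.replace(r"\s+", "", regex=True)
    return df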
def _read_uploaded_csv():
    """Upload a CSV in Colab, or ask for a local path when running elsewhere."""
    if IS_COLAB:
        uploaded = _colab_files.upload()
        if not uploaded: raise RuntimeError("No CSV file was uploaded.")
        f = list(uploaded.keys())[0]
        return pd.read_csv(io.BytesIO(uploaded[f]))
    p = input("Path to CSV with 'sequence' and 'Isomeric SMILES': ").strip()
    if not p: raise RuntimeError("No CSV path was given.")
    return pd.read_csv(p)
if mode in ("Bulk", "Bulk-large"):
    df = _clean(_read_uploaded_csv())
    df.to_csv(CSV_OUT, index=False)
    if mode == "Bulk-large":
        try:
            s = input("Batch size (default 20): ").strip()
            if s: BATCH_SIZE = max(1, int(s))
        except Exception:
            print(f"Invalid batch size; using {BATCH_SIZE}.")
elif mode in ("Interactive", "Demo"):
    pairs = []
    if mode == "Demo":
        pairs = [
            ("MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK","C(C(=O)C(=O)O)C(=O)O"),
            ("MGVEQILKRKTGVIVGEDVHNLFTYAKEHKFAIPAINVTSSSTAVAALEAARDSKSPIILQTSNGGAAYFAGKGISNEGQNASIKGAIAAAHYIRSIAPAYGIPVVLHSDHCAKKLLPWFDGMLEADEAYFKEHGEPLFSSHMLDLSEETDEENISTCVKYFKRMAAMDQWLEMEIGITGGEEDGVNNENADKEDLYTKPEQVYNVYKALHPISPNFSIAAAFGNCHGLYAGDIALRPEILAEHQKYTREQVGCKEEKPLFLVFHGGSGSTVQEFHTGIDNGVVKVNLDTDCQYAYLTGIRDYVLNKKDYIMSPVGNPEGPEKPNKKFFDPRVWVREGEKTMGAKITKSLETFRTTNTL","C([C@H](C=O)O)OP(=O)(O)O"),
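            ("M"*1023,"C(C=O)"),        # 1023 aa > 1022 limit: reported as "skipped"
            ("MKKVAV","C"+"C"*512),     # 513-character SMILES > 512 limit: reported as "skipped"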
        ]
    else:
        MAX=10
        while len(pairs)<MAX:
            s=input(f"Sequence #{len(pairs)+1} (blank to stop): ").strip()
            if not s: break
            m=input("Isomeric SMILES: ").strip()
            if not m: continue
            pairs.append((s,m))
    if not pairs: raise RuntimeError("no inputs")
    df = pd.DataFrame(pairs, columns=["sequence","Isomeric SMILES"])
    df = _clean(df)
    df.to_csv(CSV_OUT, index=False)
else:
    raise ValueError(f"Unknown mode: {mode!r}")
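# Hand the selections to the %%bash inference cell through environment variables
# (the bash subprocess inherits the kernel's environment).
os.environ["MODE"] = mode
os.environ["CSV_PATH"] = os.path.abspath(CSV_OUT)
os.environ["BATCH_SIZE"] = str(BATCH_SIZE)
os.environ["MUTATION_MODE"] = mutation_mode
print(f"env: MODE={mode}  CSV_PATH={os.environ['CSV_PATH']}  BATCH_SIZE={BATCH_SIZE}  MUTATION_MODE={mutation_mode}")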
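
The `BATCH_SIZE` chosen above is passed to the inference cell. In `Bulk-large` mode, that cell walks the valid (non-skipped) row indices in consecutive chunks and embeds one chunk at a time, so memory use grows roughly with the batch size rather than with the whole file. Below is a minimal sketch of that chunking. `make_batches` is a hypothetical helper name, not a function in the notebook.

```python
import math


def make_batches(valid_rows: list, batch_size: int) -> list:
    """Split the indices of valid rows into consecutive chunks of at most batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [valid_rows[b:b + batch_size] for b in range(0, len(valid_rows), batch_size)]


# Hypothetical example: 7 input rows, of which row 3 was skipped by the length limits
valid = [0, 1, 2, 4, 5, 6]
batches = make_batches(valid, 4)
print(batches)                                     # [[0, 1, 2, 4], [5, 6]]
print(math.ceil(len(valid) / 4) == len(batches))   # True
```

Keeping the original row indices in each batch is what lets predictions be written back to the correct rows, with skipped rows left as `"skipped"`.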


In [None]:
#@title Run RealKcat Inference (Mechanistic or Binary, depending on selection above)
%%bash
set -e
export DEBIAN_FRONTEND=noninteractive
export TF_CPP_MIN_LOG_LEVEL=3   # <-- suppress TF INFO/WARNING/ERROR logs
export XLA_FLAGS="--xla_cpu_enable_fast_math=false"
# pip uninstall -y tensorflow tensorflow-gpu tensorflow-cpu tensorflow-intel tf-nightly 2>/dev/null || true
if [ -z "${MODE:-}" ]; then
  echo "No mode selected. Please run the cell above to choose a mode of inference first."
  exit 1
fi
: "${CSV_PATH:?}"; : "${BATCH_SIZE:=20}"
MUTATION_MODE="${MUTATION_MODE:-Mechanistic Mutation-Aware (default)}"
echo "[info] MODE=${MODE}  MUTATION_MODE=${MUTATION_MODE}  CSV_PATH=${CSV_PATH}  BATCH_SIZE=${BATCH_SIZE}"

# ---------------------------- MECHANISTIC BRANCH (ESM-C, v2-like) ----------------------------
if [ "${MUTATION_MODE}" = "Mechanistic Mutation-Aware (default)" ]; then
  echo "[RealKcat] Running Mechanistic Mutation-Aware mode (recommended, ESM-C)..."
  # optional clean-up
  # pip uninstall -y tensorflow tensorflow-gpu tensorflow-cpu tensorflow-intel tf-nightly 2>/dev/null || true
  echo "[setup] Installing system deps for Python 3.12…"
  sudo sed -i 's/^[[:space:]]*deb-src .*r2u\.stat\.illinois\.edu.*$/# &/' /etc/apt/sources.list /etc/apt/sources.list.d/*.list 2>/dev/null || true
  sudo apt-get update -qq > /dev/null
  sudo apt-get install -qq -y python3.12 python3.12-venv python3-distutils curl wget unzip > /dev/null
  curl -sS https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py
  python3.12 /tmp/get-pip.py --quiet
  echo "[setup] Installing Python deps…"
  python3.12 -m pip install -q \
    numpy==2.1 pandas==2.2.2 scikit-learn==1.7.1 imbalanced-learn==0.8.1 \
    seaborn==0.11.2 joblib==1.2.0 ipython==7.34.0 \
    matplotlib==3.10.5 \
    notebook==6.5.4 jupyterlab==3.6.1 openpyxl==3.1.2 xlrd==2.0.1 XlsxWriter==3.0.3
  python3.12 -m pip install -q \
    xgboost==2.1.4 torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 \
    tqdm esm==3.2.3 httpx==0.28.1 biotite==1.2.0 \
    huggingface_hub==0.25.2
  # Alternative pins kept for reference (not installed): huggingface_hub==1.2.3, transformers==4.46.3
  : > latex_queue.txt
  echo "[run] Starting mechanistic inference… MODE=${MODE}  CSV_PATH=${CSV_PATH}  BATCH_SIZE=${BATCH_SIZE}  ENCODER=${ENCODER:-esmc}  ALLOW_ENCODER_MISMATCH=${ALLOW_ENCODER_MISMATCH:-1}"
  MODE="$MODE" CSV_PATH="$CSV_PATH" BATCH_SIZE="$BATCH_SIZE" stdbuf -oL -eL python3.12 - <<'PY'
import os, sys, warnings, torch, argparse, random, math, joblib, numpy as np, pandas as pd, io, gc
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
from torch.utils.data import Dataset, DataLoader
from torch.serialization import add_safe_globals
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from huggingface_hub import snapshot_download
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
# ---------------- basic runtime ----------------
os.environ["OMP_NUM_THREADS"]="1"; os.environ["MKL_NUM_THREADS"]="1"
try:
    torch.set_num_threads(1)
except Exception:
    pass
MODE = os.environ.get("MODE","Demo").strip()
CSV_PATH = os.environ["CSV_PATH"]
BATCH_SIZE = int(os.environ.get("BATCH_SIZE","20"))
ENCODER = os.environ.get("ENCODER","esmc").strip().lower()          # only "esmc" supported here
ALLOW_ENCODER_MISMATCH = int(os.environ.get("ALLOW_ENCODER_MISMATCH","1"))
try:
    from google.colab import files as _cf
    IS_COLAB=True
except Exception:
    _cf=None; IS_COLAB=False
def check_device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = check_device()
print("Using device:", device); sys.stdout.flush()
print(sys.version); sys.stdout.flush()
print("Starting inference setup…"); sys.stdout.flush()
MAX_SEQ_LENGTH=1022
MAX_SMILES_LENGTH=512
# ---------------- class ranges ----------------
#  Build class ranges for kcat and km
kcat_log_bins = np.array([0, 1e-08, 1e-02, 1e-01, 1e+00, 1e+01, 1e+02, 1e+03, 1e+08])
km_log_bins   = np.array([1e-14, 1e-05, 1e-04, 1e-03, 1e-02, 1e-01, 1e+04])
class_ranges_kcat = {i: {"low": float(kcat_log_bins[i]), "high": float(kcat_log_bins[i+1])}
                for i in range(len(kcat_log_bins)-1)}
class_ranges_km = {i: {"low": float(km_log_bins[i]), "high": float(km_log_bins[i+1])}
                for i in range(len(km_log_bins)-1)}
def format_sci(v):
    s=f"{v:.2e}"
    if "e" in s:
        b,e=s.split("e"); return f"{b}x10^{int(e)}"
    return s
warnings.filterwarnings("ignore"); add_safe_globals([argparse.Namespace])
# ---------------- data load & basic cleaning ----------------
df=pd.read_csv(CSV_PATH)
df["sequence"]=df["sequence"].astype(str).str.replace(r"\s+","",regex=True).str.upper() \
            .str.replace(r"[^ACDEFGHIKLMNPQRSTVWY]", "X", regex=True)
df["Isomeric SMILES"]=df["Isomeric SMILES"].astype(str).str.replace(r"\s+","",regex=True)
pairs=list(zip(df["sequence"],df["Isomeric SMILES"]))
# ---------------- ESM-C loader & embed ----------------
def load_esmc_model(device, work_dir="."):
    """
    Loads ESM-C 600M to work_dir/esmc_model/... if needed, then returns eval() model.
    """
    repo_id = "EvolutionaryScale/esmc-600m-2024-12"
    local_dir = os.path.join(work_dir, "esmc_model")
    weights_rel = "data/weights/esmc_600m_2024_12_v0.pth"
    weights_path = os.path.join(local_dir, weights_rel)
    if not os.path.exists(weights_path):
        print("[ESMC] Weights missing; downloading from HuggingFace…")
        from huggingface_hub import logging as hf_logging
        hf_logging.set_verbosity_error()
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        snapshot_download(repo_id=repo_id, local_dir=local_dir, force_download=False,tqdm_class=None)
    model = ESMC.from_pretrained("esmc_600m").to(device)
    if os.path.exists(weights_path):
        try:
            sd = torch.load(weights_path, map_location=device)
            model.load_state_dict(sd, strict=False)
        except Exception as e:
            print(f"[ESMC] Warning: could not load local weights strictly: {e}")
    model.eval()
    return model
@torch.no_grad()
def esmc_mean_embed(model, sequence: str, device):
    """
    Returns mean-pooled embedding over residues (excl. special tokens).
    """
    protein = ESMProtein(sequence=sequence)
    toks = model.encode(protein)  # ESMProteinTensor: BOS + one token per residue + EOS
    out = model.logits(toks, LogitsConfig(sequence=True, structure=True, return_embeddings=True))
    L = len(sequence)
    reps = out.embeddings[0, 1:L+1]  # residues only: drop BOS (index 0) and EOS (index L+1)
    return reps.mean(dim=0).float().to(device)  # (D_esmc,)
# ---------------- chem encoder (unchanged) ----------------
_ESMC=None; _TOK=None; _CHEM=None
def _get_models():
    global _ESMC,_TOK,_CHEM
    if ENCODER != "esmc":
        raise RuntimeError("This cell implements only ENCODER=esmc. (esm2 path omitted.)")
    if _ESMC is None:
        _ESMC = load_esmc_model(device)
    if _TOK is None or _CHEM is None:
        _TOK=AutoTokenizer.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
        _CHEM=AutoModel.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k'); _CHEM.eval().to(device)
    return _ESMC, _TOK, _CHEM
# ---------------- dataset/standardization (unchanged split) ----------------
class TensorDataset(Dataset):
    def __init__(self,d,l): self.d,self.l=d,l
    def __len__(self): return len(self.d)
    def __getitem__(self,i): return self.d[i],self.l[i]
def dataset_to_tensors(ds):
    ld=DataLoader(ds,batch_size=len(ds),shuffle=False); return next(iter(ld))
def standardize_x_global_separate(data,g1,s1,g2,s2):
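    # Split point: the first 1152 dims are the ESM-C sequence embedding, the rest the chem embedding
    # (the ESM-2 model used 1280: X1,X2=data[:,:1280],data[:,1280:]).
    # NOTE: StandardizedDatasetGlobalSeparate hands in one sample reshaped to (D, 1), so these
    # column slices put all D features in X1 and leave X2 empty; every feature is then scaled
    # with (g1, s1). Confirm this matches how the training features were standardized.
    X1,X2 = data[:, :1152], data[:, 1152:]       # ESM-C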
    s1=torch.clamp(s1,min=1e-7); s2=torch.clamp(s2,min=1e-7)
    return torch.cat(((X1-g1)/s1,(X2-g2)/s2),dim=1).squeeze(1)
class StandardizedDatasetGlobalSeparate(Dataset):
    def __init__(self,sub,g1,s1,g2,s2): self.sub=sub; self.g1=g1; self.s1=s1; self.g2=g2; self.s2=s2
    def __len__(self): return len(self.sub)
    def __getitem__(self,i):
        x,y=self.sub[i]
        if len(x.shape)==1: x=x.unsqueeze(1)
        return standardize_x_global_separate(x,self.g1,self.s1,self.g2,self.s2),y
def apply_global_standardization_separate(ds,g1,s1,g2,s2):
    return StandardizedDatasetGlobalSeparate(ds,g1,s1,g2,s2)
# legacy global stats for sequence (1152) and chem block
global_mean_1=torch.tensor(-0.0004980484955012798,device=device)
global_std_1=torch.tensor(0.027508573606610298,device=device)
global_mean_2=torch.tensor(-2.7928688723477535e-05,device=device)
global_std_2=torch.tensor(0.6221200823783875,device=device)
# ---------------- model weights (unchanged) ----------------
kcat_model_v1b_path="model_weights/kcat_model_v1b.pkl"
km_model_v1b_path="model_weights/km_model_v1b.pkl"
if not (os.path.exists(kcat_model_v1b_path) and os.path.exists(km_model_v1b_path)):
    os.system("wget -q https://github.com/TKAI-LAB-Mali/RealKcat/raw/main/model_weights.zip -O model_weights.zip")
    os.system("unzip -o -q model_weights.zip")
# ---------------- inference wrapper ----------------
class KcatInference:
    def __init__(self,model_path,device=None,verbose=True):
        self.device=device if device is not None else check_device()
        self.model=joblib.load(model_path); self.verbose=verbose
        self.X_test_tensor=None; self.y1_test_tensor=None; self.keep_local=[]
    def load_data_from_pairs(self,pairs):
        embs=[]; keep=[]
        esmc_model, tok, chem = _get_models()
        if self.verbose:
            print(f"Models loaded (ENCODER={ENCODER}). Embedding sequences and substrates..."); sys.stdout.flush()
        for i,(seq,smi) in enumerate(pairs,1):
            if not isinstance(seq,str) or not isinstance(smi,str): continue
            if len(seq)>MAX_SEQ_LENGTH: continue
            # ---- sequence (ESM-C) ----
            try:
                es = esmc_mean_embed(esmc_model, seq, device)   # (D_esmc,)
                # enforce legacy 1152-dim for downstream scaler/models
                target_d = 1152
                d = es.numel()
                if d != target_d:
                    if not ALLOW_ENCODER_MISMATCH:
                        raise RuntimeError(
                            f"ESM-C embedding dim {d} != {target_d} expected by trained models. "
                            f"Set ALLOW_ENCODER_MISMATCH=1 to pad/truncate (for plumbing only), "
                            f"or retrain RealKcat with ESM-C."
                        )
                    es = es[:target_d] if d > target_d else torch.cat([es, torch.zeros(target_d - d, device=es.device)])
            except Exception:
                continue
            # ---- substrate SMILES (PubChem tokenizer/encoder) ----
            try:
                inp=tok([smi],return_tensors='pt',padding=True,truncation=False)
                inp={k:v.to(device) for k,v in inp.items()}
                if inp['input_ids'].shape[1]>MAX_SMILES_LENGTH: continue
                with torch.no_grad(): out=chem(**inp)
                cs=out.last_hidden_state.mean(dim=1).squeeze(0).float()
            except Exception:
                continue
            if es.dim()>1: es=es.flatten()
            if cs.dim()>1: cs=cs.flatten()
            embs.append(torch.cat((es,cs))); keep.append(i-1)
        if not embs:
            self.X_test_tensor=None; self.y1_test_tensor=None; self.keep_local=[]
            return
        self.X_test_tensor=torch.stack(embs).to(device)
        self.y1_test_tensor=torch.zeros(len(embs),dtype=torch.long).to(device)
        self.keep_local=keep
    def standardize_test_data(self,g1,s1,g2,s2):
        self.test_dataset_std=apply_global_standardization_separate(
            TensorDataset(self.X_test_tensor,self.y1_test_tensor),g1,s1,g2,s2
        ) if self.X_test_tensor is not None else None
    def convert_to_numpy(self):
        if self.test_dataset_std is None: return None,None,[]
        X,y=dataset_to_tensors(self.test_dataset_std); return X.cpu().numpy(),y.cpu().numpy(),self.keep_local
    def predict(self,X): return self.model.predict(X)
    def _display_prediction_ranges(self,preds,cr,label):
        # section added to analyze preds - by mariana
        print("\n" + "="*80)
        print(f"display_prediction_ranges_{label}: preds parameter contents")
        print(f"Type of preds: {type(preds)}")
        print(f"Length of preds: {len(preds)}")
        print(f"First 10 elements: {preds[:10]}")
        print(f"Full preds list: {preds}")
        print("="*80 + "\n")
        # end of section
        for i,p in enumerate(preds):
            if p is None or p=="skipped":
                print(f"Sample {i+1}: skipped (exceeds length limits or could not be embedded)")
            else:
                print(f"Sample {i+1}: Predicted Class = {p}, {label} range = [{format_sci(cr[p]['low'])}, {format_sci(cr[p]['high'])}]")
    def display_prediction_ranges_kcat(self,preds,cr):
        self._display_prediction_ranges(preds,cr,"kcat")
    def display_prediction_ranges_km(self,preds,cr):
        self._display_prediction_ranges(preds,cr,"km")
# ---------------- driver ----------------
pairs_all=pairs
valid=[i for i,(s,m) in enumerate(pairs_all)
      if isinstance(s,str) and isinstance(m,str)
      and len(s)<=MAX_SEQ_LENGTH and len(m)<=MAX_SMILES_LENGTH]
if not valid:
    df["Predicted_Kcat_low"]="skipped"; df["Predicted_Kcat_high"]="skipped"
    df["Predicted_KM_low"]="skipped"; df["Predicted_KM_high"]="skipped"
    df.to_csv("inference_results.csv",index=False)
    print("Inference complete. Saved inference_results.csv")
    if MODE in ("Bulk","Bulk-large") and IS_COLAB: _cf.download("inference_results.csv")
    sys.exit(0)
# Load each pickled model once and reuse it for every batch.
kcat_inf=KcatInference(model_path=kcat_model_v1b_path,device=device,verbose=False)
km_inf=KcatInference(model_path=km_model_v1b_path,device=device,verbose=False)
def run_batch(idxs):
    b=[pairs_all[i] for i in idxs]
    kcat_inf.load_data_from_pairs(b)   # keep_local = positions in b that were embedded
    if not kcat_inf.keep_local:
        return [],[],[]
    kcat_inf.standardize_test_data(global_mean_1,global_std_1,global_mean_2,global_std_2)
    X,_,keep_local=kcat_inf.convert_to_numpy()
    yk=kcat_inf.predict(X)
    ym=km_inf.predict(X)
    # release this batch's tensors before the next batch
    kcat_inf.X_test_tensor=None; kcat_inf.y1_test_tensor=None; kcat_inf.test_dataset_std=None
    del X; gc.collect()
    return yk,ym,keep_local
N=len(df)
kcat_low_full=['skipped']*N; kcat_high_full=['skipped']*N; kcat_low_m1=['skipped']*N; kcat_high_p1=['skipped']*N
km_low_full=['skipped']*N; km_high_full=['skipped']*N; km_low_m1=['skipped']*N; km_high_p1=['skipped']*N
# added section for predicted class numbers - by mariana
kcat_class_full=['skipped']*N
km_class_full=['skipped']*N
# end of added section
max_kcat_c=max(class_ranges_kcat.keys()); max_km_c=max(class_ranges_km.keys())
if MODE=="Bulk-large":
    total_batches = math.ceil(len(valid)/BATCH_SIZE)
    for b in tqdm(range(0, len(valid), BATCH_SIZE), desc="Batches", unit="batch", dynamic_ncols=True, file=sys.stdout, miniters=1):
        idxs=valid[b:b+BATCH_SIZE]; yk,ym,keep=run_batch(idxs)
        if not keep: continue
        for j,k in enumerate(keep):
            i0=idxs[k]; ck=int(yk[j]); cm=int(ym[j])
            # added section to store class numbers - by mariana
            kcat_class_full[i0]=ck
            km_class_full[i0]=cm
            # end of added section
            kcat_low_full[i0]=class_ranges_kcat[ck]["low"]; kcat_high_full[i0]=class_ranges_kcat[ck]["high"]
            kcat_low_m1[i0]=class_ranges_kcat[max(ck-1,0)]["low"]; kcat_high_p1[i0]=class_ranges_kcat[min(ck+1,max_kcat_c)]["high"]
            km_low_full[i0]=class_ranges_km[cm]["low"]; km_high_full[i0]=class_ranges_km[cm]["high"]
            km_low_m1[i0]=class_ranges_km[max(cm-1,0)]["low"]; km_high_p1[i0]=class_ranges_km[min(cm+1,max_km_c)]["high"]
        gc.collect(); sys.stdout.flush()
else:
    yk,ym,keep=run_batch(valid)
    if keep:
        for j,k in enumerate(keep):
            i0=valid[k]; ck=int(yk[j]); cm=int(ym[j])
            # added section to store class numbers - by mariana
            kcat_class_full[i0]=ck
            km_class_full[i0]=cm
            # end of added section
            kcat_low_full[i0]=class_ranges_kcat[ck]["low"]; kcat_high_full[i0]=class_ranges_kcat[ck]["high"]
            km_low_full[i0]=class_ranges_km[cm]["low"]; km_high_full[i0]=class_ranges_km[cm]["high"]
        if MODE in ("Demo","Interactive"):
            final_kcat_predictions=["skipped"]*N; final_km_predictions=["skipped"]*N
            for j,k in enumerate(keep):
                i0=valid[k]; final_kcat_predictions[i0]=int(yk[j]); final_km_predictions[i0]=int(ym[j])
            kcat_inference_print=KcatInference(model_path=kcat_model_v1b_path,device=device,verbose=False)
            km_inference_print=KcatInference(model_path=km_model_v1b_path,device=device,verbose=False)
            print("\n=== kcat Prediction Results ==="); sys.stdout.flush()
            kcat_inference_print.display_prediction_ranges_kcat(final_kcat_predictions,class_ranges_kcat); sys.stdout.flush()
            print("\n=== KM Prediction Results ==="); sys.stdout.flush()
            km_inference_print.display_prediction_ranges_km(final_km_predictions,class_ranges_km); sys.stdout.flush()
# ---------------- write results ----------------
# added section to add kcat predicted class numbers - by mariana
df['Predicted_Kcat_class']=kcat_class_full
# end of added section
df['Predicted_Kcat_low']=kcat_low_full
df['Predicted_Kcat_high']=kcat_high_full
df['Predicted_Kcat_low (-1 class error)']=kcat_low_m1
df['Predicted_Kcat_high (+1 class error)']=kcat_high_p1
# added section to add km predicted class numbers - by mariana
df['Predicted_KM_class']=km_class_full
# end of added section
df['Predicted_KM_low']=km_low_full
df['Predicted_KM_high']=km_high_full
df['Predicted_KM_low (-1 class error)']=km_low_m1
df['Predicted_KM_high (+1 class error)']=km_high_p1

df.to_csv("inference_results.csv",index=False)
print("Inference complete. Saved inference_results.csv"); sys.stdout.flush()
# try:
#     if MODE=="Bulk-large" and IS_COLAB:
#         _cf.download("inference_results.csv")
#     if MODE=="Bulk" and IS_COLAB and os.environ.get("DOWNLOAD_CSV","")=="1":
#         _cf.download("inference_results.csv")
# except Exception as e:
#     print(f"[warn] Auto-download failed in subprocess: {e}")
PY
# ---------------------------- BINARY BRANCH (ESM-2, v1-like) ----------------------------
else
  echo "[RealKcat] Running Binary Alanine Simplified mode (optional, ESM-2)…"
  echo "[setup] Installing system deps for Python 3.10 …"
  echo "Installing dependencies and downloading/unzipping model weights [~2 minutes]"
  sudo sed -i 's/^[[:space:]]*deb-src .*r2u\.stat\.illinois\.edu.*$/# &/' /etc/apt/sources.list /etc/apt/sources.list.d/*.list 2>/dev/null || true
  sudo apt-get update -qq > /dev/null
  sudo apt-get install -qq -y python3.10 python3.10-distutils python3.10-venv curl wget unzip > /dev/null
  curl -sS https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py
  python3.10 /tmp/get-pip.py --quiet
  python3.10 -m pip install -q numpy==1.23.5 pandas==1.5.3 scikit-learn==1.1.3 imbalanced-learn==0.8.1 matplotlib==3.6.3 seaborn==0.11.2 joblib==1.2.0 ipython==7.33.0 notebook==6.5.4 jupyterlab==3.6.1 openpyxl==3.1.2 xlrd==2.0.1 XlsxWriter==3.0.3
  python3.10 -m pip install -q xgboost==2.1.4 torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 transformers==4.33.3 fair-esm==2.0.0 mkl==2022.1.0 mkl-service==2.4.0 intel-openmp==2022.1.0 tqdm
  : > latex_queue.txt
  MODE="$MODE" CSV_PATH="$CSV_PATH" BATCH_SIZE="$BATCH_SIZE" stdbuf -oL -eL python3.10 - <<'PY'
import os, sys, warnings, torch, argparse, random, math, joblib, numpy as np, pandas as pd, io, gc
from torch.utils.data import Dataset, DataLoader
from torch.serialization import add_safe_globals
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import esm
os.environ["OMP_NUM_THREADS"]="1"; os.environ["MKL_NUM_THREADS"]="1"
try:
    torch.set_num_threads(1)
except Exception:
    pass
MODE=os.environ.get("MODE","Demo").strip()
CSV_PATH=os.environ["CSV_PATH"]
BATCH_SIZE=int(os.environ.get("BATCH_SIZE","20"))
try:
    from google.colab import files as _cf
    IS_COLAB=True
except Exception:
    _cf=None; IS_COLAB=False
def check_device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu")
device=check_device()
print("Using device:",device); sys.stdout.flush()
print(sys.version); sys.stdout.flush()
print("Starting inference setup…"); sys.stdout.flush()
MAX_SEQ_LENGTH=1022
MAX_SMILES_LENGTH=512
class_ranges_kcat={0:{"low":0.0,"high":3.32e-8},1:{"low":3.33e-8,"high":1.0e-2},2:{"low":1.01e-2,"high":1.0e-1},3:{"low":1.01e-1,"high":1.0},4:{"low":1.001,"high":10.0},5:{"low":1.004e1,"high":1.0e2},6:{"low":1.0025e2,"high":1.0e3},7:{"low":1.002e3,"high":7.0e7}}
class_ranges_km={0:{"low":1.0e-10,"high":1.0e-5},1:{"low":1.01e-5,"high":1.0e-4},2:{"low":1.002e-4,"high":1.0e-3},3:{"low":1.002e-3,"high":1.0e-2},4:{"low":1.008e-2,"high":1.0e-1},5:{"low":1.01e-1,"high":1.02e2}}
def format_sci(v):
    s=f"{v:.2e}"
    if "e" in s:
        b,e=s.split("e"); return f"{b}x10^{int(e)}"
    return s
warnings.filterwarnings("ignore"); add_safe_globals([argparse.Namespace])
df=pd.read_csv(CSV_PATH)
df["sequence"]=df["sequence"].astype(str).str.replace(r"\s+","",regex=True).str.upper().str.replace(r"[^ACDEFGHIKLMNPQRSTVWY]", "X", regex=True)
df["Isomeric SMILES"]=df["Isomeric SMILES"].astype(str).str.replace(r"\s+","",regex=True)
pairs=list(zip(df["sequence"],df["Isomeric SMILES"]))
def load_esm2_model(device,work_dir=".",verbose=True):
    url="https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt"
    fn=url.split("/")[-1]; p=os.path.join(work_dir,fn)
    if not os.path.exists(p):
        m=torch.hub.load_state_dict_from_url(url,progress=True,map_location=device); torch.save(m,p)
    else:
        if verbose: print("ESM2 model weights found locally. Loading from disk...")
        m=torch.load(p,map_location=device)
    model,alphabet=esm.pretrained.load_model_and_alphabet_core("esm2_t33_650M_UR50D",m)
    model.eval().to(device); return model,alphabet
class TensorDataset(Dataset):
    def __init__(self,d,l): self.d,self.l=d,l
    def __len__(self): return len(self.d)
    def __getitem__(self,i): return self.d[i],self.l[i]
def dataset_to_tensors(ds):
    ld=DataLoader(ds,batch_size=len(ds),shuffle=False); return next(iter(ld))
def standardize_x_global_separate(data,g1,s1,g2,s2):
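    # First 1280 dims are the ESM-2 sequence embedding, the rest the chem embedding.
    # NOTE: StandardizedDatasetGlobalSeparate hands in one sample reshaped to (D, 1), so X2 is
    # empty here and every feature is scaled with (g1, s1).
    X1,X2=data[:,:1280],data[:,1280:]; s1=torch.clamp(s1,min=1e-7); s2=torch.clamp(s2,min=1e-7)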
    return torch.cat(((X1-g1)/s1,(X2-g2)/s2),dim=1).squeeze(1)
class StandardizedDatasetGlobalSeparate(Dataset):
    def __init__(self,sub,g1,s1,g2,s2): self.sub=sub; self.g1=g1; self.s1=s1; self.g2=g2; self.s2=s2
    def __len__(self): return len(self.sub)
    def __getitem__(self,i):
        x,y=self.sub[i]
        if len(x.shape)==1: x=x.unsqueeze(1)
        return standardize_x_global_separate(x,self.g1,self.s1,self.g2,self.s2),y
def apply_global_standardization_separate(ds,g1,s1,g2,s2):
    return StandardizedDatasetGlobalSeparate(ds,g1,s1,g2,s2)
global_mean_1=torch.tensor(-0.0006011285004206002,device=device)
global_std_1=torch.tensor(0.18902993202209473,device=device)
global_mean_2=torch.tensor(-0.00015002528380136937,device=device)
global_std_2=torch.tensor(0.6113553047180176,device=device)
kcat_model_path="model_weights/kcat_model.pkl"
km_model_path="model_weights/km_model.pkl"
if not (os.path.exists(kcat_model_path) and os.path.exists(km_model_path)):
    os.system("wget -q https://github.com/TKAI-LAB-Mali/RealKcat/raw/main/model_weights.zip -O model_weights.zip")
    os.system("unzip -o -q model_weights.zip")
_ESM=None; _ALPH=None; _TOK=None; _CHEM=None
def _get_models():
    global _ESM,_ALPH,_TOK,_CHEM
    if _ESM is None or _ALPH is None:
        _ESM,_ALPH=load_esm2_model(device,verbose=False)
    if _TOK is None or _CHEM is None:
        _TOK=AutoTokenizer.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
        _CHEM=AutoModel.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k'); _CHEM.eval().to(device)
    return _ESM,_ALPH,_TOK,_CHEM
class KcatInference:
    def __init__(self,model_path,device=None,verbose=True):
        self.device=device if device is not None else check_device()
        self.model=joblib.load(model_path); self.verbose=verbose
        self.X_test_tensor=None; self.y1_test_tensor=None; self.keep_local=[]
    def load_data_from_pairs(self,pairs):
        embs=[]; keep=[]
        esm_model,alphabet,tok,chem=_get_models(); bc=alphabet.get_batch_converter()
        if self.verbose: print("Models loaded. Embedding sequences and substrates..."); sys.stdout.flush()
        for i,(seq,smi) in enumerate(pairs,1):
            if not isinstance(seq,str) or not isinstance(smi,str): continue
            if len(seq)>MAX_SEQ_LENGTH: continue
            try:
                _,_,bt=bc([(f"sample_{i}",seq)]); bt=bt.to(device); bl=(bt!=alphabet.padding_idx).sum(1)
                with torch.no_grad(): r=esm_model(bt,repr_layers=[33],return_contacts=False)
                es=r["representations"][33][0,1:bl-1].mean(dim=0).float()
            except Exception:
                continue
            try:
                inp=tok([smi],return_tensors='pt',padding=True,truncation=False); inp={k:v.to(device) for k,v in inp.items()}
                if inp['input_ids'].shape[1]>MAX_SMILES_LENGTH: continue
                with torch.no_grad(): out=chem(**inp)
                cs=out.last_hidden_state.mean(dim=1).squeeze(0).float()
            except Exception:
                continue
            if es.dim()>1: es=es.flatten()
            if cs.dim()>1: cs=cs.flatten()
            embs.append(torch.cat((es,cs))); keep.append(i-1)
        if not embs:
            self.X_test_tensor=None; self.y1_test_tensor=None; self.keep_local=[]
            return
        self.X_test_tensor=torch.stack(embs).to(device); self.y1_test_tensor=torch.zeros(len(embs),dtype=torch.long).to(device); self.keep_local=keep
    def standardize_test_data(self,g1,s1,g2,s2):
        self.test_dataset_std=apply_global_standardization_separate(TensorDataset(self.X_test_tensor,self.y1_test_tensor),g1,s1,g2,s2) if self.X_test_tensor is not None else None
    def convert_to_numpy(self):
        if self.test_dataset_std is None: return None,None,[]
        X,y=dataset_to_tensors(self.test_dataset_std); return X.cpu().numpy(),y.cpu().numpy(),self.keep_local
    def predict(self,X): return self.model.predict(X)
    def display_prediction_ranges_kcat(self,preds,cr):
        for i,p in enumerate(preds):
            if p is None or p=="skipped": print(f"Sample {i+1}: skipped (invalid input, length limit exceeded, or embedding failed)")
            else: print(f"Sample {i+1}: Predicted Class = {p}, kcat range = [{format_sci(cr[p]['low'])}, {format_sci(cr[p]['high'])}]")
    def display_prediction_ranges_km(self,preds,cr):
        for i,p in enumerate(preds):
            if p is None or p=="skipped": print(f"Sample {i+1}: skipped (invalid input, length limit exceeded, or embedding failed)")
            else: print(f"Sample {i+1}: Predicted Class = {p}, km range = [{format_sci(cr[p]['low'])}, {format_sci(cr[p]['high'])}]")
pairs_all=pairs
valid=[i for i,(s,m) in enumerate(pairs_all) if isinstance(s,str) and isinstance(m,str) and len(s)<=MAX_SEQ_LENGTH and len(m)<=MAX_SMILES_LENGTH]
if not valid:
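    for col in ("Predicted_Kcat_class","Predicted_Kcat_low","Predicted_Kcat_high","Predicted_Kcat_low (-1 class error)","Predicted_Kcat_high (+1 class error)",
                "Predicted_KM_class","Predicted_KM_low","Predicted_KM_high","Predicted_KM_low (-1 class error)","Predicted_KM_high (+1 class error)"):
        df[col]="skipped"  # same columns as a normal run, so downstream parsing does not break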
    df.to_csv("inference_results.csv",index=False); print("Inference complete. Saved inference_results.csv")
    if MODE in ("Bulk","Bulk-large") and IS_COLAB: _cf.download("inference_results.csv")
    sys.exit(0)
def run_batch(idxs):
    b=[pairs_all[i] for i in idxs]
    inf=KcatInference(model_path=kcat_model_path,device=device,verbose=False)
    inf.load_data_from_pairs(b)
    if not inf.keep_local:
        return [],[],[]
    inf.standardize_test_data(global_mean_1,global_std_1,global_mean_2,global_std_2)
    X,_,keep_local=inf.convert_to_numpy()
    yk=inf.predict(X)  # reuse the kcat model already loaded in inf instead of reloading it from disk
    ym=KcatInference(model_path=km_model_path,device=device,verbose=False).predict(X)
    del inf, X; gc.collect()
    return yk,ym,keep_local
N=len(df)
kcat_low_full=['skipped']*N; kcat_high_full=['skipped']*N; kcat_low_m1=['skipped']*N; kcat_high_p1=['skipped']*N
km_low_full=['skipped']*N; km_high_full=['skipped']*N; km_low_m1=['skipped']*N; km_high_p1=['skipped']*N
# added section for predicted class numbers - by mariana
kcat_class_full=['skipped']*N
km_class_full=['skipped']*N
# end of added section
max_kcat_c=max(class_ranges_kcat.keys()); max_km_c=max(class_ranges_km.keys())
if MODE=="Bulk-large":
    total_batches = math.ceil(len(valid)/BATCH_SIZE)
    for b in tqdm(range(0, len(valid), BATCH_SIZE), desc="Batches", unit="batch", dynamic_ncols=True, file=sys.stdout, miniters=1):
        idxs=valid[b:b+BATCH_SIZE]; yk,ym,keep=run_batch(idxs)
        if not keep:
            continue
        for j,k in enumerate(keep):
            i0=idxs[k]; ck=int(yk[j]); cm=int(ym[j])
            # added section to store class numbers - by mariana
            kcat_class_full[i0]=ck
            km_class_full[i0]=cm
            # end of added section
            kcat_low_full[i0]=class_ranges_kcat[ck]["low"]; kcat_high_full[i0]=class_ranges_kcat[ck]["high"]
            kcat_low_m1[i0]=class_ranges_kcat[max(ck-1,0)]["low"]; kcat_high_p1[i0]=class_ranges_kcat[min(ck+1,max_kcat_c)]["high"]
            km_low_full[i0]=class_ranges_km[cm]["low"]; km_high_full[i0]=class_ranges_km[cm]["high"]
            km_low_m1[i0]=class_ranges_km[max(cm-1,0)]["low"]; km_high_p1[i0]=class_ranges_km[min(cm+1,max_km_c)]["high"]
        gc.collect(); sys.stdout.flush()
else:
    yk,ym,keep=run_batch(valid)
    if keep:
        for j,k in enumerate(keep):
            i0=valid[k]; ck=int(yk[j]); cm=int(ym[j])
            # added section to store class numbers - by mariana
            kcat_class_full[i0]=ck
            km_class_full[i0]=cm
            # end of added section
            kcat_low_full[i0]=class_ranges_kcat[ck]["low"]; kcat_high_full[i0]=class_ranges_kcat[ck]["high"]
            km_low_full[i0]=class_ranges_km[cm]["low"]; km_high_full[i0]=class_ranges_km[cm]["high"]
        if MODE in ("Demo","Interactive"):
            final_kcat_predictions=["skipped"]*N; final_km_predictions=["skipped"]*N
            for j,k in enumerate(keep):
                i0=valid[k]; final_kcat_predictions[i0]=int(yk[j]); final_km_predictions[i0]=int(ym[j])
            kcat_inference_print=KcatInference(model_path=kcat_model_path,device=device,verbose=False)
            km_inference_print=KcatInference(model_path=km_model_path,device=device,verbose=False)
            print("\n=== kcat Prediction Results ==="); sys.stdout.flush()
            kcat_inference_print.display_prediction_ranges_kcat(final_kcat_predictions,class_ranges_kcat); sys.stdout.flush()
            print("\n=== KM Prediction Results ==="); sys.stdout.flush()
            km_inference_print.display_prediction_ranges_km(final_km_predictions,class_ranges_km); sys.stdout.flush()
# added section to add kcat predicted class numbers - by mariana
df['Predicted_Kcat_class']=kcat_class_full
# end of added section
df['Predicted_Kcat_low']=kcat_low_full
df['Predicted_Kcat_high']=kcat_high_full
df['Predicted_Kcat_low (-1 class error)']=kcat_low_m1
df['Predicted_Kcat_high (+1 class error)']=kcat_high_p1
# added section to add km predicted class numbers - by mariana
df['Predicted_KM_class']=km_class_full
# end of added section
df['Predicted_KM_low']=km_low_full
df['Predicted_KM_high']=km_high_full
df['Predicted_KM_low (-1 class error)']=km_low_m1
df['Predicted_KM_high (+1 class error)']=km_high_p1
df.to_csv("inference_results.csv",index=False)
print("Inference complete. Saved inference_results.csv"); sys.stdout.flush()
# try:
#     if MODE=="Bulk-large" and IS_COLAB:
#         _cf.download("inference_results.csv")
#     if MODE=="Bulk" and IS_COLAB and os.environ.get("DOWNLOAD_CSV","")=="1":
#         _cf.download("inference_results.csv")
# except Exception as e:
#     print(f"[warn] Auto-download failed in subprocess: {e}")
PY
fi

[info] MODE=Bulk-large  MUTATION_MODE=Mechanistic Mutation-Aware (default)  CSV_PATH=/content/infer_input.csv  BATCH_SIZE=256
[RealKcat] Running Mechanistic Mutation-Aware mode (recommended, ESM-C)...
[setup] Installing system deps for Python 3.12…
[setup] Installing Python deps…
[run] Starting mechanistic inference… MODE=Bulk-large  CSV_PATH=/content/infer_input.csv  BATCH_SIZE=256  ENCODER=esmc  ALLOW_ENCODER_MISMATCH=1
Using device: cuda
3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
Starting inference setup…
Batches:   0%|          | 0/20 [00:00<?, ?batch/s]
[ESMC] Weights missing; downloading from HuggingFace…
Batches:   5%|▌         | 1/20 [03:22<1:04:03, 202.30s/batch]
Batches:  10%|█         | 2/20 [04:18<34:59, 116.64s/batch]
Batches:  15%|█▌        | 3/20 [05:14<25:08, 88.74s/batch]
Batches:  20%|██        | 4/20 [06:10<20:12, 75.79s/batch]
Batches:  25%|██▌       | 5/20 [07:06<17:08, 68.57s/batch]
Batches:  30%|███       | 6/20 [08:01<14:57, 64.09s/batch]
Batches:  3

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tobler 0.13.0 requires joblib>=1.4, but you have joblib 1.2.0 which is incompatible.
numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.1.0 which is incompatible.

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s][A
Fetching 4 files:  25%|██▌       | 1/4 [00:00<00:01,  2.63it/s][A
Fetching 4 files: 100%|██████████| 4/4 [00:52<00:00, 14.18s/it][AFetchi
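The inference cell above does two pieces of bookkeeping per batch: it maps the batch-local `keep` positions back to dataframe rows (`i0=idxs[k]`), and in `Bulk-large` mode it widens each predicted class by one neighbour on each side to produce the ±1-class ("e-accuracy") range, clamped at the ends of the class table. The sketch below reproduces that logic with the standard library only. `EXAMPLE_RANGES`, `class_to_ranges`, `fill_in_batches`, and `toy_predict` are hypothetical names, and the bin edges are placeholders, not RealKcat's actual class boundaries.

```python
# Hypothetical class table: class id -> bounds (s^-1). Placeholder values only;
# the notebook's real bins live in class_ranges_kcat / class_ranges_km.
EXAMPLE_RANGES = {
    0: {"low": 1e-3, "high": 1e-2},
    1: {"low": 1e-2, "high": 1e-1},
    2: {"low": 1e-1, "high": 1e0},
    3: {"low": 1e0, "high": 1e1},
}


def class_to_ranges(c, ranges):
    """Return (low, high, low_minus_1, high_plus_1) for predicted class c.

    The +/-1 bounds borrow the neighbouring classes' edges and clamp at the
    ends of the table, like max(c-1, 0) and min(c+1, max_class) in the cell.
    Assumes class ids are contiguous integers starting at 0.
    """
    max_c = max(ranges)
    low, high = ranges[c]["low"], ranges[c]["high"]
    low_m1 = ranges[max(c - 1, 0)]["low"]
    high_p1 = ranges[min(c + 1, max_c)]["high"]
    return low, high, low_m1, high_p1


def fill_in_batches(n_rows, valid, predict_batch, ranges, batch_size):
    """Fill one result per dataframe row; rows never predicted stay 'skipped'.

    predict_batch(idxs) returns (classes, keep), where keep lists the
    batch-local positions that were successfully embedded and classified.
    """
    out = ["skipped"] * n_rows
    for b in range(0, len(valid), batch_size):
        idxs = valid[b:b + batch_size]
        classes, keep = predict_batch(idxs)
        for j, k in enumerate(keep):
            row = idxs[k]  # batch-local position -> dataframe row index
            out[row] = class_to_ranges(int(classes[j]), ranges)
    return out


# Toy predictor (hypothetical): drops the second pair of every batch, predicts class 3.
def toy_predict(idxs):
    keep = [k for k in range(len(idxs)) if k != 1]
    return [3] * len(keep), keep


rows = fill_in_batches(5, valid=[0, 2, 3, 4], predict_batch=toy_predict,
                       ranges=EXAMPLE_RANGES, batch_size=2)
print(rows[0])  # (1.0, 10.0, 0.1, 10.0)
```

Row 1 is never in `valid`, and rows 2 and 4 are dropped inside their batches, so all three stay `"skipped"`; this is the same reason a row in `inference_results.csv` can show `skipped` even when the input looked valid.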

In [None]:
#@title Download inference_results.csv to your computer
import os
from google.colab import files
if os.path.exists("inference_results.csv"):
    files.download("inference_results.csv")
else:
    print("inference_results.csv not found — run the cell above first.")

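Once `inference_results.csv` is downloaded, it can help to check how many rows actually received a prediction before analysing the ranges. A minimal sketch using only the standard library follows; `summarize_results` is a hypothetical helper, and the `sequence`/`smiles` column names in the sample are placeholders for whatever columns your input CSV had. It relies only on the `Predicted_Kcat_class` column written by the inference cell and treats a missing column as "every row skipped".

```python
import csv
import io


def summarize_results(fh):
    """Count predicted vs. skipped rows in an inference_results.csv stream.

    Reads the Predicted_Kcat_class column written by the inference cell; a
    missing column is treated as 'skipped' for every row.
    """
    summary = {"predicted": 0, "skipped": 0}
    for row in csv.DictReader(fh):
        if row.get("Predicted_Kcat_class", "skipped") == "skipped":
            summary["skipped"] += 1
        else:
            summary["predicted"] += 1
    return summary


# Tiny in-memory stand-in for the real file; the first two column names are hypothetical.
sample = io.StringIO(
    "sequence,smiles,Predicted_Kcat_class,Predicted_KM_class\n"
    "MKT,CCO,3,2\n"
    "MKV,CCN,skipped,skipped\n"
)
print(summarize_results(sample))  # {'predicted': 1, 'skipped': 1}
```

On the real file, open it with `open("inference_results.csv", newline="")` and pass the handle to `summarize_results`. Only the kcat column is checked because kcat and KM are predicted together in the same batch, so a row is either skipped for both or predicted for both.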