<a href="https://colab.research.google.com/github/pablo-arantes/Cloud-Bind/blob/main/Boltz_GNINA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Hi there!**

This notebook is part of **Cloud-Bind** and provides an end-to-end workflow for **protein–ligand structure prediction + docking** in Google Colab.

The main goal is to show how cloud compute can be used to run an accessible structure-based pipeline for **co-folding** and **re-docking**.

---

**This notebook is NOT a standard protocol for docking or scoring.** It is an instructional pipeline meant to illustrate common steps and a practical stack in a Colab setting.

**Runtime Note**: A **GPU runtime is strongly recommended**. CPU runs can be extremely slow for Boltz-2.

---

**Bugs**
- If you encounter any bugs, please report the issue to https://github.com/pablo-arantes/Cloud-Bind/issues

**Acknowledgments**
- **Core Pipeline**: We thank the developers of [Boltz](https://github.com/jwohlwend/boltz) (structure prediction), [ColabFold](https://github.com/sokrypton/ColabFold) (MSA server), [P2Rank](https://github.com/rdk/p2rank) (pocket detection), and [GNINA](https://github.com/gnina/gnina) (docking/scoring).
- **Cheminformatics & QC**: We thank the teams behind [Open Babel](https://github.com/openbabel/openbabel) (conversion), [RDKit](https://github.com/rdkit/rdkit) (ligand handling), [Dimorphite-DL](https://github.com/durrantlab/dimorphite_dl) (protonation), [PoseBusters](https://github.com/maabuu/posebusters) (QC), [ProLIF](https://github.com/chemosim-lab/ProLIF) (interactions), [MDAnalysis](https://github.com/MDAnalysis/mdanalysis) (parsing), and [Biopython](https://github.com/biopython/biopython).
- **Visualization & Runtime**: Credit to [py3Dmol](https://3dmol.csb.pitt.edu/) for visualization and [PyTorch](https://github.com/pytorch/pytorch) for runtime dependencies.
- **Cloud-Bind Team**: **Pablo R. Arantes** ([@pablitoarantes](https://twitter.com/pablitoarantes)), **Conrado Pedebos** ([@ConradoPedebos](https://twitter.com/ConradoPedebos)), **Rodrigo Ligabue-Braun** ([@ligabue_braun](https://twitter.com/ligabue_braun)), **Davidt da Silva Tarouco**, and **Saul J. Flores** ([@saulfloresjr](https://www.linkedin.com/in/saulfloresjr/)).

- For related notebooks see: [Cloud-Bind](https://github.com/pablo-arantes/Cloud-Bind)

## **A. Runtime & project folder**


In [None]:
#@title **A1) Initialize project folders**

from pathlib import Path
import os, json, datetime

# ----------------------------
# Project naming
# ----------------------------
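PROJECT_NAME = ""  #@param {type:"string"}
# Fall back to a default so an empty name does not resolve to /content itself
PROJECT_NAME = PROJECT_NAME.strip() or "cloudbind_project"
PROJECT_ROOT = Path(f"/content/{PROJECT_NAME}")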

TOOLS_DIR = Path("/content/cloudbind_tools")

RUNS_DIR = PROJECT_ROOT / "runs"

INSTALL_LOGS_DIR = PROJECT_ROOT / "logs"   # installs / environment
# Per-run logs live in: runs/<run_name>/logs/

# ----------------------------
# Create folders
# ----------------------------
for p in [PROJECT_ROOT, TOOLS_DIR, RUNS_DIR, INSTALL_LOGS_DIR]:
    p.mkdir(parents=True, exist_ok=True)

os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

# bootstrap
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
bootstrap = {
    "created_at": datetime.datetime.now().isoformat(),
    "project_name": PROJECT_NAME,
    "project_root": str(PROJECT_ROOT),
    "runs_dir": str(RUNS_DIR),
    "install_logs_dir": str(INSTALL_LOGS_DIR),
    "tools_dir": str(TOOLS_DIR),
}
BOOTSTRAP_PATH.write_text(json.dumps(bootstrap, indent=2))

print(f"Project root : {PROJECT_ROOT}")
print(f"Runs dir     : {RUNS_DIR}")
print(f"Tools dir    : {TOOLS_DIR}  (excluded from export zip)")
print(f"Install logs : {INSTALL_LOGS_DIR}")
print(f"Bootstrap    : {BOOTSTRAP_PATH}")


## **B. Install Dependencies**


In [None]:
#@title **B1) Install Boltz-2, P2Rank, GNINA, and analysis tools**

#@markdown Once this cell triggers a runtime restart, **run it again once** after the restart to finish installation.

import os, subprocess, json, sys, shlex
from pathlib import Path

# rehydrate paths (notebook globals are lost after the post-install restart)
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
if "INSTALL_LOGS_DIR" not in globals():
    if not BOOTSTRAP_PATH.exists():
        raise FileNotFoundError("Bootstrap file not found. Run A1 once to initialize the project.")
    _boot = json.loads(BOOTSTRAP_PATH.read_text())
    INSTALL_LOGS_DIR = Path(_boot.get("install_logs_dir", "/content/cloudbind_project/logs"))
INSTALL_LOGS_DIR.mkdir(parents=True, exist_ok=True)

# --------- pinned versions (edit if needed) ----------
BOLTZ_VERSION  = "2.2.1"
P2RANK_VERSION = "2.5.1"
GNINA_VERSION  = "1.3.2"

TOOLS_DIR = Path("/content/cloudbind_tools")
TOOLS_DIR.mkdir(parents=True, exist_ok=True)
os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

install_log = INSTALL_LOGS_DIR / "01_install.log"
stamp_path  = TOOLS_DIR / f"install_stamp_boltz{BOLTZ_VERSION}_p2rank{P2RANK_VERSION}_gnina{GNINA_VERSION}.json"
post_restart_flag = TOOLS_DIR / ".post_install_restart_done"

def _tail(path: Path, n=60):
    if not path.exists():
        return ""
    lines = path.read_text(errors="ignore").splitlines()
    return "\n".join(lines[-n:])

def run_cmd(cmd: str, log_path: Path, cwd: Path | None = None):
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a") as f:
        f.write(f"\n\n$ {cmd}\n")
        p = subprocess.run(cmd, shell=True, cwd=str(cwd) if cwd else None,
                           stdout=f, stderr=subprocess.STDOUT, text=True)
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}\n\n--- tail({log_path}) ---\n{_tail(log_path)}")

def have_file(path: Path) -> bool:
    return path.exists() and path.is_file()

# ----------------------------
# Detect runtime
# ----------------------------
def detect_runtime():
    if os.environ.get("COLAB_TPU_ADDR"):
        return "TPU"
    try:
        import torch
        if torch.cuda.is_available():
            return f"GPU ({torch.cuda.get_device_name(0)})"
    except Exception:
        pass
    return "CPU"

runtime = detect_runtime()
print(f"Runtime detected: {runtime}")
if "CPU" in runtime or "TPU" in runtime:
    print("Note: Boltz-2 protein-ligand inference is strongly recommended on GPU for speed.")

# ----------------------------
# Skip if already installed
# ----------------------------
p2rank_dir = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}"
p2rank_prank = p2rank_dir / "prank"
gnina_bin = TOOLS_DIR / "gnina"

did_install = False
if stamp_path.exists() and have_file(gnina_bin) and have_file(p2rank_prank):
    print("Existing installation detected (stamp found). Skipping re-install.")
else:
    did_install = True

    # --------- system deps ----------
    run_cmd("apt-get -qq update", install_log)
    run_cmd("apt-get -qq install -y openjdk-17-jre-headless wget unzip tar rsync", install_log)

    # Open Babel
    run_cmd("apt-get -qq install -y openbabel", install_log)

    # --------- python deps ----------
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except Exception:
        has_cuda = False

    boltz_spec = f"boltz[cuda]=={BOLTZ_VERSION}" if has_cuda else f"boltz=={BOLTZ_VERSION}"

    # avoid pip upgrade
    run_cmd(f"python -m pip -q install '{boltz_spec}'", install_log)

    # viz and qc stack
    run_cmd("python -m pip -q install py3Dmol prolif posebusters MDAnalysis rdkit dimorphite-dl biopython", install_log)

    # --------- GNINA binary ----------
    run_cmd(f"wget -q -O {gnina_bin} https://github.com/gnina/gnina/releases/download/v{GNINA_VERSION}/gnina.{GNINA_VERSION}", install_log)
    run_cmd(f"chmod +x {gnina_bin}", install_log)

    # --------- P2Rank release ----------
    p2rank_tar = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}.tar.gz"
    run_cmd(f"wget -q -O {p2rank_tar} https://github.com/rdk/p2rank/releases/download/{P2RANK_VERSION}/p2rank_{P2RANK_VERSION}.tar.gz", install_log)
    run_cmd(f"tar -xzf {p2rank_tar} -C {TOOLS_DIR}", install_log)

    # validate prank exists
    if not p2rank_prank.exists():
        raise RuntimeError(f"P2Rank installed but prank entrypoint not found at: {p2rank_prank}\nCheck install log: {install_log}")

    # write stamp
    stamp = {
        "timestamp": __import__("datetime").datetime.now().isoformat(),
        "boltz": BOLTZ_VERSION,
        "p2rank": P2RANK_VERSION,
        "gnina": GNINA_VERSION,
        "runtime": runtime,
        "tools_dir": str(TOOLS_DIR),
        "gnina_path": str(gnina_bin),
        "p2rank_dir": str(p2rank_dir),
        "p2rank_prank": str(p2rank_prank),
    }
    stamp_path.write_text(json.dumps(stamp, indent=2))

# update bootstrap
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
if BOOTSTRAP_PATH.exists():
    boot = json.loads(BOOTSTRAP_PATH.read_text())
else:
    boot = {}

boot.update({
    "boltz_version": BOLTZ_VERSION,
    "p2rank_version": P2RANK_VERSION,
    "gnina_version": GNINA_VERSION,
    "tools_dir": str(TOOLS_DIR),
    "gnina_path": str(gnina_bin),
    "p2rank_dir": str(p2rank_dir),
    "p2rank_prank": str(p2rank_prank),
    "install_stamp": str(stamp_path),
})
BOOTSTRAP_PATH.write_text(json.dumps(boot, indent=2))

print("Installed. Key tools:")
print(f"  boltz  : {BOLTZ_VERSION}")
print(f"  gnina  : {GNINA_VERSION}  -> {gnina_bin}")
print(f"  p2rank : {P2RANK_VERSION} -> {p2rank_prank}")
print(f"Install log: {install_log}")
print(f"Bootstrap: {BOOTSTRAP_PATH}")

# ----------------------------
# One-time post-install restart
# ----------------------------
if did_install and not post_restart_flag.exists():
    post_restart_flag.write_text("ok")
    print("\nColab: restarting runtime once to finalize binary wheels (prevents numpy/pandas ABI errors).")
    os.kill(os.getpid(), 9)

### **Shared helpers**

Logging and small display utilities used throughout the notebook.


In [None]:
#@title **Helpers**

import subprocess, warnings, json, os
from pathlib import Path

# ----------------------------
# Rehydrate bootstrap (lets you continue after the one-time install restart)
# ----------------------------
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
if BOOTSTRAP_PATH.exists():
    _boot = json.loads(BOOTSTRAP_PATH.read_text())
    PROJECT_NAME = _boot.get("project_name", "cloudbind_project")
    PROJECT_ROOT = Path(_boot.get("project_root", f"/content/{PROJECT_NAME}"))
    RUNS_DIR = Path(_boot.get("runs_dir", str(PROJECT_ROOT / "runs")))
    INSTALL_LOGS_DIR = Path(_boot.get("install_logs_dir", str(PROJECT_ROOT / "logs")))
    TOOLS_DIR = Path(_boot.get("tools_dir", "/content/cloudbind_tools"))

    # version check
    BOLTZ_VERSION  = _boot.get("boltz_version")
    P2RANK_VERSION = _boot.get("p2rank_version")
    GNINA_VERSION  = _boot.get("gnina_version")
else:
    raise FileNotFoundError("Bootstrap file not found. Run A1 once to initialize the project.")

os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

# load version if missing
if not (BOLTZ_VERSION and P2RANK_VERSION and GNINA_VERSION):
    stamp_files = sorted(TOOLS_DIR.glob("install_stamp_*.json"))
    if stamp_files:
        stamp = json.loads(stamp_files[-1].read_text())
        BOLTZ_VERSION  = BOLTZ_VERSION  or stamp.get("boltz")
        P2RANK_VERSION = P2RANK_VERSION or stamp.get("p2rank")
        GNINA_VERSION  = GNINA_VERSION  or stamp.get("gnina")

try:
    import pandas as pd
except ValueError as e:
    if "numpy.dtype size changed" in str(e):
        print("="*60)
        print("BINARY INCOMPATIBILITY DETECTED (numpy/pandas)")
        print("Fix: Runtime -> Restart session, then run Helpers and continue.")
        print("="*60)
    raise e

# ----------------------------
# logging control
# ----------------------------
warnings.filterwarnings("ignore", message=r"netCDF4 is not available.*")
warnings.filterwarnings("ignore", message=r"Unit cell dimensions not found.*")
warnings.filterwarnings("ignore", message=r".*CRYST1 record.*")
warnings.filterwarnings("ignore", message=r"Found no information for attr: 'formalcharges'.*")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=r"MDAnalysis\.topology\.tables.*")

# rdkit warning logging
try:
    from rdkit import RDLogger
    RDLogger.DisableLog("rdApp.warning")
except Exception:
    pass

# mda logging
import logging
logging.getLogger("MDAnalysis").setLevel(logging.ERROR)
logging.getLogger("MDAnalysis.coordinates.AMBER").setLevel(logging.ERROR)

def tail(path: Path, n=60) -> str:
    if not path.exists():
        return ""
    lines = path.read_text(errors="ignore").splitlines()
    return "\n".join(lines[-n:])

def run_cmd(cmd: str, log_path: Path, cwd: Path | None = None):
    """Run a shell command quietly; append stdout/stderr to log_path."""
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a") as f:
        f.write(f"\n\n$ {cmd}\n")
        p = subprocess.run(cmd, shell=True, cwd=str(cwd) if cwd else None,
                           stdout=f, stderr=subprocess.STDOUT, text=True)
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}\n\n--- tail({log_path}) ---\n{tail(log_path)}")

def show_table(df: pd.DataFrame, max_rows=10):
    """Display a compact table without spamming the notebook."""
    from IPython.display import display
    display(df.head(max_rows))


## **C. Inputs**

In [None]:
#@title **C1) Inputs + Boltz preset**
#@markdown Required inputs.

RUN_NAME = ""  #@param {type:"string"}
# clear testing inputs
PROTEIN_SEQUENCE = ""  #@param {type:"string"}
LIGAND_SMILES = ""  #@param {type:"string"}

#@markdown **Boltz preset**
#@markdown - Fast: demo / testing, fewer sampling steps
#@markdown - Standard: default, moderate sampling
#@markdown - Quality: refinement, more sampling steps and samples


BOLTZ_PRESET = "Fast"  #@param ["Fast","Standard","Quality"]

from pathlib import Path
import json

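# An empty RUN_NAME would make RUN_DIR collapse onto RUNS_DIR, so require a name
RUN_NAME = RUN_NAME.strip()
if not RUN_NAME:
    raise ValueError("Please set RUN_NAME before running this cell.")
RUN_DIR = RUNS_DIR / RUN_NAME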
for sub in ["inputs","boltz","p2rank","gnina","analysis","viz","logs","state","export"]:
    (RUN_DIR / sub).mkdir(parents=True, exist_ok=True)

# basic run settings
settings = {
    "run_name": RUN_NAME,
    "boltz_preset": BOLTZ_PRESET,
}
(RUN_DIR / "state" / "run_settings.json").write_text(json.dumps(settings, indent=2))

print(f"Run folder: {RUN_DIR}")


## **D. Boltz-2: predict the complex**



In [None]:
#@title **D1) Run Boltz-2**

import json, os, re
from pathlib import Path

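# ---- Torch 2.6+ compatibility: allow OmegaConf DictConfig in trusted checkpoints ----
# Note: this registers the safe global in the current kernel process only;
# the `boltz predict` subprocess launched below runs in its own interpreter.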
try:
    import torch, omegaconf
    torch.serialization.add_safe_globals([omegaconf.dictconfig.DictConfig])
except Exception as e:
    print(f"(torch safe-globals setup skipped: {e})")

# ---- sanitize inputs ----
seq = re.sub(r"\s+", "", PROTEIN_SEQUENCE.strip().upper())
if not seq:
    raise ValueError("Please paste a real PROTEIN_SEQUENCE (1-letter amino acids).")

smiles = LIGAND_SMILES.strip()
if not smiles:
    raise ValueError("Please provide a LIGAND_SMILES string.")

# ---- write Boltz YAML input (auto) ----
yaml_path = RUN_DIR / "boltz" / f"{RUN_NAME}.yaml"
yaml_text = f"""version: 1
sequences:
  - protein:
      id: A
      sequence: "{seq}"
  - ligand:
      id: L
      smiles: '{smiles}'
properties:
  - affinity:
      binder: L
"""
yaml_path.write_text(yaml_text)

# ---- runtime preset mapping ----
# Defaults in boltz docs are ~recycling_steps=3, sampling_steps=200 (slower, higher quality).
# We expose simplified presets for teaching/training.

if BOLTZ_PRESET == "Fast":
    params = {
        "recycling_steps": 1,
        "sampling_steps": 25,
        "diffusion_samples": 1,
        "sampling_steps_affinity": 25,
        "diffusion_samples_affinity": 1,
        "step_scale": 1.6,
        "max_msa_seqs": 1024,
        "num_subsampled_msa": 128,
        "subsample_msa": True
    }
elif BOLTZ_PRESET == "Standard":
    params = {
        "recycling_steps": 3,
        "sampling_steps": 200,
        "diffusion_samples": 1,
        "sampling_steps_affinity": 200,
        "diffusion_samples_affinity": 5,
        "step_scale": 1.638,
        "max_msa_seqs": 8192,
        "num_subsampled_msa": 1024,
        "subsample_msa": False
    }
else:  # Quality
    params = {
        "recycling_steps": 5,
        "sampling_steps": 300,
        "diffusion_samples": 2,
        "sampling_steps_affinity": 300,
        "diffusion_samples_affinity": 10,
        "step_scale": 1.638,
        "max_msa_seqs": 8192,
        "num_subsampled_msa": 1024,
        "subsample_msa": False
    }

(RUN_DIR / "state" / "boltz_params.json").write_text(json.dumps(params, indent=2))

# Determine accelerator
try:
    import torch
    if torch.cuda.is_available():
        accel = "gpu"
    else:
        accel = "cpu"
except Exception:
    accel = "cpu"

# TPU is not supported for this pipeline
if "COLAB_TPU_ADDR" in os.environ or "TPU_NAME" in os.environ:
    raise RuntimeError("TPU runtime detected. Please switch to a GPU runtime.")

if accel != "gpu":
    print("Note: GPU strongly recommended. CPU runs can be very slow for real systems.")

out_dir = RUN_DIR / "boltz" / "out"
out_dir.mkdir(exist_ok=True)

boltz_log = RUN_DIR / "logs" / "02_boltz.log"

cmd_parts = [
    "boltz predict",
    str(yaml_path),
    f"--out_dir {out_dir}",
    "--use_msa_server",
    f"--accelerator {accel}",
    "--devices 1",
    f"--recycling_steps {int(params['recycling_steps'])}",
    f"--sampling_steps {int(params['sampling_steps'])}",
    f"--diffusion_samples {int(params['diffusion_samples'])}",
    f"--step_scale {float(params['step_scale'])}",
    f"--sampling_steps_affinity {int(params['sampling_steps_affinity'])}",
    f"--diffusion_samples_affinity {int(params['diffusion_samples_affinity'])}",
    "--output_format pdb",
]

# MSA limiting (still uses the MSA server)
if params.get("subsample_msa", False):
    cmd_parts += [
        "--subsample_msa",
        f"--num_subsampled_msa {int(params['num_subsampled_msa'])}",
        f"--max_msa_seqs {int(params['max_msa_seqs'])}",
    ]

cmd = " ".join(cmd_parts)
run_cmd(cmd, boltz_log)

print("Boltz complete.")
print(f"  Preset: {BOLTZ_PRESET}")
print(f"  YAML  : {yaml_path}")
print(f"  Out   : {out_dir}")
print(f"  Log   : {boltz_log}")
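
The D1 cell only checks that the sequence is non-empty. As a minimal sketch of a stricter check, the hypothetical helper below (`clean_sequence` is not part of the notebook) rejects anything outside the 20 standard one-letter amino-acid codes. Treating non-standard letters such as `X` or `U` as errors is an assumption of this sketch, not a Boltz requirement.

```python
import re

# The 20 standard amino acids in one-letter code. Rejecting everything else
# (X, B, Z, U, O, digits, punctuation) is an assumption of this sketch.
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def clean_sequence(raw: str) -> str:
    """Strip whitespace, upper-case, and reject non-standard residue letters."""
    seq = re.sub(r"\s+", "", raw).upper()
    if not seq:
        raise ValueError("Empty protein sequence.")
    bad = sorted(set(seq) - STANDARD_AA)
    if bad:
        raise ValueError(f"Unexpected residue letters: {''.join(bad)}")
    return seq

print(clean_sequence("mkt ayi\nAKQR"))  # prints MKTAYIAKQR
```

Failing early here is cheaper than waiting for the MSA server or the Boltz run to reject the input.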

In [None]:
#@title **D2) Extract receptor, Boltz pose, and affinity**
#@markdown Temperature is used only to convert the predicted affinity into an approximate ΔG°.

TEMPERATURE_K = 298.15  #@param {type:"number"}

import json, math, shutil
from pathlib import Path
import MDAnalysis as mda

out_dir = RUN_DIR / "boltz" / "out"

# --- locate model file ---
pdb_candidates = list(out_dir.rglob("*model_0.pdb"))  # pattern also matches "*_model_0.pdb"
if not pdb_candidates:
    raise FileNotFoundError(f"No Boltz PDB outputs found under: {out_dir}")
model_pdb = sorted(pdb_candidates)[0]

# copy
complex_pdb = RUN_DIR / "analysis" / "boltz_complex.pdb"
shutil.copyfile(model_pdb, complex_pdb)

# --- locate affinity JSON ---
aff_json = None
json_candidates = list(out_dir.rglob("*.json"))
# Prefer files with "affinity" in the file name
for p in sorted(json_candidates):
    if "affinity" in p.name.lower():
        aff_json = p
        break
if aff_json is None:
    # Fallback: take the first JSON whose keys contain affinity_pred_value
    for p in sorted(json_candidates):
        try:
            data = json.loads(p.read_text())
            if any("affinity_pred_value" in k for k in data.keys()):
                aff_json = p
                break
        except Exception:
            continue

if aff_json is None:
    boltz_log = RUN_DIR / "logs" / "02_boltz.log"
    raise FileNotFoundError(
        "Affinity JSON not found in Boltz output. This notebook expects affinity outputs.\n"
        f"Check: {out_dir}\n"
        f"Tail of boltz log:\n{tail(boltz_log, n=60)}"
    )

# create a copy
affinity_json_out = RUN_DIR / "analysis" / "boltz_affinity.json"
shutil.copyfile(aff_json, affinity_json_out)

# --- parse affinity_pred_value (Boltz reports log10(IC50 in µM)) ---
data = json.loads(affinity_json_out.read_text())
affinity_value = None
affinity_prob = None

for k, v in data.items():
    if "affinity_pred_value" in k and isinstance(v, (int, float)):
        affinity_value = float(v)
        break
for k, v in data.items():
    if "affinity_probability_binary" in k and isinstance(v, (int, float)):
        affinity_prob = float(v)
        break

if affinity_value is None:
    raise RuntimeError(f"Could not find affinity_pred_value in: {affinity_json_out.name}")

ic50_uM = 10**affinity_value
R_kcal = 0.0019872041
# ΔG° ≈ RT ln(Kd) with Kd in M; the predicted IC50 stands in for Kd, so this is a rough estimate
dG = R_kcal * float(TEMPERATURE_K) * math.log(ic50_uM * 1e-6)

summary = {
    "run": RUN_NAME,
    "complex_pdb": str(complex_pdb),
    "affinity_log10_ic50_uM": affinity_value,
    "ic50_uM": ic50_uM,
    "ic50_nM": ic50_uM * 1000.0,
    "deltaG_kcal_per_mol": dG,
    "temperature_K": float(TEMPERATURE_K),
}
if affinity_prob is not None:
    summary["affinity_probability_binary"] = affinity_prob

# --- split protein + ligand from complex ---
u = mda.Universe(str(complex_pdb))

protein = u.select_atoms("protein")
if len(protein) == 0:
    raise RuntimeError("Could not identify protein atoms in the Boltz PDB output.")

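# Ligand: chainID L (from YAML) if present, otherwise any non-polymer, non-water, non-ion atoms
lig = u.select_atoms("chainID L and not protein and not nucleic and not resname HOH WAT")
if len(lig) == 0:
    lig = u.select_atoms("not protein and not nucleic and not resname HOH WAT and not name NA CL K CA MG ZN")
if len(lig) == 0:
    raise RuntimeError("Could not identify ligand atoms in the Boltz PDB output.")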

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
boltz_pose_pdb = RUN_DIR / "analysis" / "boltz_ligand_pose.pdb"

protein.write(str(receptor_pdb))
lig.write(str(boltz_pose_pdb))

# convert lig to sdf
extract_log = RUN_DIR / "logs" / "03_extract.log"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"
run_cmd(f"obabel -ipdb {boltz_pose_pdb} -osdf -O {boltz_pose_sdf}", extract_log)


summary_path = RUN_DIR / "analysis" / "boltz_summary.json"
summary_path.write_text(json.dumps(summary, indent=2))

# key paths
state = {
    "complex_pdb": str(complex_pdb),
    "receptor_pdb": str(receptor_pdb),
    "boltz_pose_pdb": str(boltz_pose_pdb),
    "boltz_pose_sdf": str(boltz_pose_sdf),
    "affinity_json": str(affinity_json_out),
}
(RUN_DIR / "state" / "paths.json").write_text(json.dumps(state, indent=2))

print("Extracted:")
print(f"  Complex PDB : {complex_pdb}")
print(f"  Receptor PDB: {receptor_pdb}")
print(f"  Boltz pose (PDB): {boltz_pose_pdb}")
print(f"  Boltz pose (SDF): {boltz_pose_sdf}")
print(f"  Affinity JSON: {affinity_json_out}")

print(f"\nAffinity (Boltz): log10(IC50 µM) = {affinity_value:.3f}")
print(f"  IC50 ~ {ic50_uM:.3g} µM  ({ic50_uM*1000:.3g} nM)")
print(f"  ΔG°({TEMPERATURE_K} K) ~ {dG:.2f} kcal/mol  (approx.)")
if affinity_prob is not None:
    print(f"  P(binder) ~ {affinity_prob:.3f}")

print(f"\nLogs: {extract_log}")
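
The conversion used in D2 can be checked by hand. The sketch below repeats the same arithmetic in a standalone function (`affinity_to_dg` is a hypothetical name, not part of the notebook) and treats the predicted IC50 as a stand-in for Kd, which is the same approximation D2 makes. At 298.15 K each log10 unit of IC50 shifts the estimate by RT ln(10), roughly 1.36 kcal/mol.

```python
import math

R_KCAL = 0.0019872041  # gas constant in kcal/(mol*K), same value as in D2

def affinity_to_dg(log10_ic50_uM: float, temperature_K: float = 298.15) -> dict:
    """Convert a log10(IC50 in uM) value into IC50 and an approximate dG."""
    ic50_uM = 10 ** log10_ic50_uM
    ic50_M = ic50_uM * 1e-6                       # convert uM to M before taking the log
    dG = R_KCAL * temperature_K * math.log(ic50_M)  # RT ln(IC50), IC50 used in place of Kd
    return {
        "ic50_uM": ic50_uM,
        "ic50_nM": ic50_uM * 1000.0,
        "deltaG_kcal_per_mol": dG,
    }

# A predicted value of 0.0 corresponds to IC50 = 1 uM
result = affinity_to_dg(0.0)
print(f"IC50 = {result['ic50_uM']:.3g} uM, dG ~ {result['deltaG_kcal_per_mol']:.2f} kcal/mol")
```

More negative predicted values therefore mean tighter predicted binding, and the ΔG° figure should be read as an order-of-magnitude guide rather than a measured free energy.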


In [None]:
#@title **D3) Visualize Boltz complex (py3Dmol)**

import py3Dmol
from pathlib import Path

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"


PINE_GREEN = "#01796f"  # Boltz pose color

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

# receptor viz
view.addModel(open(receptor_pdb).read(), "pdb")
view.setStyle({"cartoon": {"color": "white"}})

# surface viz
view.addSurface(py3Dmol.VDW, {'opacity': 0.6, 'color': 'silver'})

# ligand viz
view.addModel(open(boltz_pose_sdf).read(), "sdf")
view.setStyle({"model": 1}, {"stick": {"color": PINE_GREEN, "radius": 0.25}})

view.zoomTo()
view.show()

print("Boltz complex preview (protein + predicted ligand pose).")

## **E. P2Rank: pocket hypotheses**


In [None]:
#@title **E1) Run P2Rank + show top pockets**

import pandas as pd
from pathlib import Path
import shutil

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
p2rank_out = RUN_DIR / "p2rank"
p2rank_log = RUN_DIR / "logs" / "04_p2rank.log"

# --- call P2Rank ---
p2rank_version_dir = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}"
if not p2rank_version_dir.exists():
    raise FileNotFoundError(f"P2Rank installation directory not found: {p2rank_version_dir}. Re-run B1.")

prank_executable_path = p2rank_version_dir / "prank"
if not prank_executable_path.exists():
    raise FileNotFoundError(f"P2Rank 'prank' executable not found at: {prank_executable_path}. Re-run B1.")

run_cmd(f"{prank_executable_path} predict -f {receptor_pdb} -o {p2rank_out} -c alphafold", p2rank_log, cwd=p2rank_version_dir)

csv_candidates = list(p2rank_out.rglob("*_predictions.csv"))
if not csv_candidates:
    raise FileNotFoundError(f"No *_predictions.csv found under: {p2rank_out}\nCheck logs at: {p2rank_log}")
pred_csv = sorted(csv_candidates)[0]

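df = pd.read_csv(pred_csv, skipinitialspace=True)
df.columns = df.columns.str.strip()  # P2Rank may pad header names with spaces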

def pick_col(substrs):
    for s in substrs:
        for c in df.columns:
            if s in c.lower():
                return c
    return None

col_score = pick_col(["score", "probability", "pocket_score"])
col_res   = pick_col(["residue", "residues"])
col_cx    = pick_col(["center_x", "centerx", "center x"])
col_cy    = pick_col(["center_y", "centery", "center y"])
col_cz    = pick_col(["center_z", "centerz", "center z"])

df_disp = df.copy()
df_disp.insert(0, "pocket_rank", range(1, len(df_disp)+1))

keep_cols = ["pocket_rank"]
for c in [col_score, col_cx, col_cy, col_cz, col_res]:
    if c and c not in keep_cols:
        keep_cols.append(c)

p2rank_csv_out = RUN_DIR / "analysis" / "p2rank_pockets.csv"
df_disp[keep_cols].to_csv(p2rank_csv_out, index=False)

print("Top pockets (P2Rank):")
show_table(df_disp[keep_cols], max_rows=5)
print(f"\nSaved: {p2rank_csv_out}")
print(f"P2Rank log: {p2rank_log}")
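
The substring matching in `pick_col` exists because the prediction table may carry space-padded column names. As a sketch of the same idea with only the standard library, the snippet below strips headers and values, then reads the pocket centre and residue list of the top pocket. The two-row table is made up for illustration, and the column names (`center_x`, `residue_ids`, and so on) are assumptions about the P2Rank output layout.

```python
import csv
import io

# Made-up two-pocket table with space-padded headers; values are illustrative only.
sample = (
    "name     ,  rank,   score, probability,   center_x,   center_y,   center_z, residue_ids\n"
    "pocket1  ,     1,   12.50,       0.710,      10.0,      -4.5,      22.1, A_45 A_48 A_102\n"
    "pocket2  ,     2,    3.20,       0.150,      -7.3,      15.0,       1.8, A_210 A_214\n"
)

def read_pockets(text: str) -> list:
    """Parse a padded CSV into a list of dicts with stripped keys and values."""
    reader = csv.reader(io.StringIO(text))
    header = [h.strip() for h in next(reader)]
    rows = []
    for raw in reader:
        if not raw:
            continue  # skip blank lines
        rows.append(dict(zip(header, (v.strip() for v in raw))))
    return rows

pockets = read_pockets(sample)
top = pockets[0]
center = tuple(float(top[k]) for k in ("center_x", "center_y", "center_z"))
residues = top["residue_ids"].split()
print(center, residues)
```

Once the names are stripped, columns can be addressed by exact name instead of by substring.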


In [None]:
#@title **E2) Select pocket for GNINA (and visualize)**
#@markdown Select a pocket from the P2Rank table, or specify pocket residues manually.

POCKET_MODE = "choose_rank"  #@param ["auto_top1","choose_rank","manual_residues"]
POCKET_RANK = 1  #@param {type:"integer"}
#@markdown Manual residues examples: `A:123, A:125-130` or `A:45-60`

POCKET_RESIDUES = ""  #@param {type:"string"}

import pandas as pd
import json
import re
import MDAnalysis as mda
import py3Dmol
from pathlib import Path

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"
pockets_csv = RUN_DIR / "analysis" / "p2rank_pockets.csv"

if not pockets_csv.exists():
    raise FileNotFoundError("Missing p2rank_pockets.csv. Run E1 first.")

df = pd.read_csv(pockets_csv)

# residue-list column
res_col = None
for c in df.columns:
    if "res" in c.lower():
        res_col = c
        break

def parse_residue_spec(spec: str):
    """Parse strings like 'A:123, A:125-130'. Returns list of (chain, resid)."""
    out=[]
    spec = spec.replace(" ", "")
    if not spec:
        return out
    parts = [p for p in spec.split(",") if p]
    for p in parts:
        if ":" not in p:
            raise ValueError(f"Residue '{p}' must include a chain, e.g. A:123")
        chain, rng = p.split(":", 1)
        if "-" in rng:
            a,b = rng.split("-",1)
            for r in range(int(a), int(b)+1):
                out.append((chain, r))
        else:
            out.append((chain, int(rng)))
    return out

# residue select
if POCKET_MODE == "auto_top1":
    pocket_rank = 1
elif POCKET_MODE == "choose_rank":
    pocket_rank = int(POCKET_RANK)
else:
    pocket_rank = None  # manual

if pocket_rank is not None:
    if pocket_rank < 1 or pocket_rank > len(df):
        raise ValueError(f"Pocket rank {pocket_rank} is out of range (1..{len(df)}).")
    residues_blob = str(df.loc[pocket_rank-1, res_col]) if res_col else ""
    tokens = re.split(r"[,\s]+", residues_blob.strip())
    residues=[]
    for t in tokens:
        if not t:
            continue
        if ":" in t:
            ch, r = t.split(":", 1)
        elif "_" in t:
            ch, r = t.split("_", 1)
        else:
            continue
        try:
            residues.append((ch, int(re.sub(r"\D","", r))))
        except ValueError:
            pass
else:
    residues = parse_residue_spec(POCKET_RESIDUES)

if not residues:
    raise RuntimeError("No pocket residues found/selected. Try another pocket rank or manual residues.")

# pocket build
u = mda.Universe(str(receptor_pdb))
sel_parts = [f"(chainID {ch} and resid {rid})" for ch,rid in residues]
sel = " or ".join(sel_parts)
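pocket_atoms = u.select_atoms(sel)
if len(pocket_atoms) == 0:
    raise RuntimeError("Selected residues matched no atoms in the receptor. Check chain IDs and residue numbers.")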

pocket_pdb = RUN_DIR / "analysis" / "pocket_residues.pdb"
pocket_atoms.write(str(pocket_pdb))

sel_state = {
    "pocket_mode": POCKET_MODE,
    "pocket_rank": pocket_rank,
    "n_residues": len(residues),
    "pocket_pdb": str(pocket_pdb),
}
(RUN_DIR / "state" / "pocket.json").write_text(json.dumps(sel_state, indent=2))

print(f"Selected pocket residues: {len(residues)} residues")
print(f"Pocket residue PDB: {pocket_pdb}")

# --- viz (protein + Boltz pose + selected pocket residues) ---
PINE_GREEN = "#01796f"  # Boltz pose
ROSE_RED = "#ff033e"    # GNINA poses (used later)

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

view.addModel(open(receptor_pdb).read(), "pdb")
view.setStyle({"cartoon": {"color": "#fffafa"}})

view.addModel(open(pocket_pdb).read(), "pdb")
view.setStyle({"model": 1}, {"stick": {"color": "#cc9900", "radius": 0.25}})

view.addModel(open(boltz_pose_sdf).read(), "sdf")
view.setStyle({"model": 2}, {"stick": {"color": PINE_GREEN, "radius": 0.28}})

view.zoomTo()
view.show()

print("Pocket preview: yellow = selected pocket residues, green = Boltz ligand pose.")
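
To make the manual residue format concrete, here is a standalone sketch that parses a spec such as `A:123, A:125-127` and builds the same kind of selection string that E2 passes to MDAnalysis. It mirrors the logic of `parse_residue_spec` above; `to_selection` is a hypothetical helper added only for this illustration.

```python
def parse_residue_spec(spec: str) -> list[tuple[str, int]]:
    """Parse 'A:123, A:125-127' into [(chain, resid), ...]; ranges are inclusive."""
    out = []
    for part in spec.replace(" ", "").split(","):
        if not part:
            continue
        if ":" not in part:
            raise ValueError(f"Residue '{part}' must include a chain, e.g. A:123")
        chain, rng = part.split(":", 1)
        if "-" in rng:
            start, end = rng.split("-", 1)
            out.extend((chain, r) for r in range(int(start), int(end) + 1))
        else:
            out.append((chain, int(rng)))
    return out

def to_selection(residues: list[tuple[str, int]]) -> str:
    """Join residues into one selection string, one clause per residue."""
    return " or ".join(f"(chainID {ch} and resid {rid})" for ch, rid in residues)

residues = parse_residue_spec("A:123, A:125-127")
print(residues)                    # [('A', 123), ('A', 125), ('A', 126), ('A', 127)]
print(to_selection(residues[:2]))  # (chainID A and resid 123) or (chainID A and resid 125)
```

Note that this simple format cannot express negative residue numbers or insertion codes, because `-` is reserved for ranges.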


## **F. GNINA: re-dock + re-score**
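
The docking step in this section defines its search box from the selected pocket residues via `--autobox_ligand` and `--autobox_add`. As a rough sketch of what that implies, assuming the box is the axis-aligned bounding box of the reference atoms padded by `autobox_add` on every side, the centre and size can be estimated as follows. This is an illustration of the geometry, not GNINA's own code, and the coordinates are made up.

```python
def autobox(coords, autobox_add=4.0):
    """Return (center, size) of a padded axis-aligned bounding box.

    coords      : iterable of (x, y, z) tuples in Angstrom
    autobox_add : padding added on each of the six sides (assumed behaviour)
    """
    xs, ys, zs = zip(*coords)
    mins = (min(xs), min(ys), min(zs))
    maxs = (max(xs), max(ys), max(zs))
    center = tuple((lo + hi) / 2 for lo, hi in zip(mins, maxs))
    size = tuple((hi - lo) + 2 * autobox_add for lo, hi in zip(mins, maxs))
    return center, size

# Made-up pocket atom coordinates
coords = [(0.0, 0.0, 0.0), (10.0, 4.0, 2.0), (5.0, -2.0, 6.0)]
center, size = autobox(coords, autobox_add=4.0)
print(center)  # (5.0, 1.0, 3.0)
print(size)    # (18.0, 14.0, 14.0)
```

A larger `AUTOBOX_ADD` widens the search volume, which tends to increase run time and can let poses drift away from the selected pocket.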


In [None]:
#@title **F1) Run GNINA docking + re-score the Boltz pose**
#@markdown Dock into the **selected pocket** (E2), then **re-score** the original Boltz pose once for comparison.

#@markdown **Feel free to adjust the parameters below.**
GNINA_POSES = 10  #@param {type:"integer"}
GNINA_EXHAUSTIVENESS = 8  #@param {type:"integer"}
AUTOBOX_ADD = 4.0  #@param {type:"number"}
PH = 7.4  #@param {type:"number"}

import json
import pandas as pd
from pathlib import Path
import shutil
import re
from rdkit import Chem
from rdkit.Chem import AllChem
try:
    from IPython.display import display
except ImportError:
    pass

# logging
gnina_log = RUN_DIR / "logs" / "05_gnina.log"
gnina_score_log = RUN_DIR / "logs" / "05_gnina_score_only.log"
ligprep_log = RUN_DIR / "logs" / "05_ligprep.log"

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"

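# Resolve the pocket file from disk so this cell also works after a runtime restart
pocket_pdb = RUN_DIR / "analysis" / "pocket_residues.pdb"
if not pocket_pdb.exists():
    raise RuntimeError("Missing pocket selection. Run E2 first.")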

# ------------------------------------------------------------------------
# PREPARE RECEPTOR (protonate + fix UNK labels)
# ------------------------------------------------------------------------
receptor_dock_pdb = RUN_DIR / "gnina" / "receptor_ph.pdb"
receptor_dock_pdb.parent.mkdir(parents=True, exist_ok=True)

print(f"Protonating receptor at pH {PH}...")
run_cmd(f"obabel -ipdb {receptor_pdb} -opdb -O {receptor_dock_pdb} -p {PH}", ligprep_log)

# Patch with Biopython: restore residue names that Open Babel may rewrite (e.g. to 'UNK')
try:
    from Bio import PDB
    #print("Running residue name patch (fixing 'UNK' labels)...")
    parser = PDB.PDBParser(QUIET=True)
    ref_structure = parser.get_structure("ref", str(receptor_pdb))
    tgt_structure = parser.get_structure("tgt", str(receptor_dock_pdb))

    res_name_map = {}
    for model in ref_structure:
        for chain in model:
            for residue in chain:
                key = (chain.id, residue.id[1], residue.id[2])
                res_name_map[key] = residue.resname

    count_fixed = 0
    for model in tgt_structure:
        for chain in model:
            for residue in chain:
                key = (chain.id, residue.id[1], residue.id[2])
                if key in res_name_map:
                    original_name = res_name_map[key]
                    if residue.resname != original_name:  # also covers 'UNK' placeholders
                        residue.resname = original_name
                        count_fixed += 1

    receptor_dock_pdb_fixed = RUN_DIR / "gnina" / "receptor_ph_fixed.pdb"
    io = PDB.PDBIO()
    io.set_structure(tgt_structure)
    io.save(str(receptor_dock_pdb_fixed))
    receptor_dock_pdb = receptor_dock_pdb_fixed
    #print(f"   Patch applied. Restored {count_fixed} residue names.")

except ImportError:
    print("   BioPython not installed. Skipping 'UNK' patch.")
except Exception as e:
    print(f"   Patch failed: {e}. Proceeding with unpatched file.")

# ------------------------------------------------------------------------
# PREPARE LIGAND
# ------------------------------------------------------------------------
lig_in_sdf = RUN_DIR / "gnina" / "ligand_mmff.sdf"
smiles = LIGAND_SMILES.strip()
ph_smiles = smiles

try:
    from dimorphite_dl import protonate_smiles
    prot_list = protonate_smiles(smiles, ph_min=PH-0.5, ph_max=PH+0.5)
    if prot_list: ph_smiles = prot_list[0]
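except Exception as e:
    # Fall back to the input SMILES if Dimorphite-DL is unavailable or fails.
    print(f"   Dimorphite-DL protonation skipped: {e}")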

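mol = Chem.MolFromSmiles(ph_smiles)
if mol is None:
    raise ValueError("Failed to process ligand SMILES.")
mol = Chem.AddHs(mol)
# EmbedMolecule returns -1 when no conformer could be generated.
if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
    raise ValueError("Failed to generate a 3D conformer for the ligand.")
# MMFFOptimizeMolecule returns -1 (it does not raise) when the MMFF force
# field cannot be set up, so check the return value before falling back to UFF.
if AllChem.MMFFOptimizeMolecule(mol, mmffVariant="MMFF94s") == -1:
    AllChem.UFFOptimizeMolecule(mol)
with Chem.SDWriter(str(lig_in_sdf)) as w:
    w.write(mol)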

# ------------------------------------------------------------------------
# RUN GNINA (Docking)
# ------------------------------------------------------------------------
dock_sdf = RUN_DIR / "gnina" / "gnina_docked.sdf"
summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"

print("Running GNINA docking...")
gnina_cmd = (
    f"gnina --receptor {receptor_dock_pdb} --ligand {lig_in_sdf} "
    f"--autobox_ligand {pocket_pdb} --autobox_add {AUTOBOX_ADD} "
    f"--num_modes {GNINA_POSES} --exhaustiveness {GNINA_EXHAUSTIVENESS} "
    f"--out {dock_sdf}"
)
run_cmd(gnina_cmd, gnina_log)

# ------------------------------------------------------------------------
# SCORE BOLTZ POSE (Score Only)
# ------------------------------------------------------------------------
boltz_pose_ph_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose_ph.sdf"
run_cmd(f"obabel -isdf {boltz_pose_sdf} -osdf -O {boltz_pose_ph_sdf} -p {PH}", ligprep_log)

print("Scoring Boltz pose...")
run_cmd(f"gnina --score_only --receptor {receptor_dock_pdb} --ligand {boltz_pose_ph_sdf}", gnina_score_log)

# ------------------------------------------------------------------------
# PARSE RESULTS (docked poses from the SDF, Boltz pose from the score-only log)
# ------------------------------------------------------------------------
def _prop(m, key):
    return float(m.GetProp(key)) if m and m.HasProp(key) else float('nan')

rows = []
sup = Chem.SDMolSupplier(str(dock_sdf), removeHs=False)
for i, m in enumerate(sup):
    if m:
        rows.append({
            "pose_id": i+1,
            "CNNscore": _prop(m, "CNNscore"),
            "CNNaffinity": _prop(m, "CNNaffinity"),
            "minimizedAffinity": _prop(m, "minimizedAffinity")
        })
df_dock = pd.DataFrame(rows)
df_dock.to_csv(summary_csv, index=False)

txt = Path(gnina_score_log).read_text(errors="ignore")
def get_score(name):
    # Matches 'Name: 123' or 'Name 123'
    m = re.findall(rf"{name}[:\s]+([-+]?[0-9]*\.?[0-9]+)", txt)
    return float(m[-1]) if m else float('nan')

# fallback for vina affinity
vina_score = get_score("minimizedAffinity")
if pd.isna(vina_score):
    vina_score = get_score("Affinity")

boltz_scores = {
    "CNNscore": get_score("CNNscore"),
    "CNNaffinity": get_score("CNNaffinity"),
    "minimizedAffinity": vina_score
}
(RUN_DIR / "analysis" / "boltz_pose_gnina_scores.json").write_text(json.dumps(boltz_scores, indent=2))

paths = RUN_DIR / "state" / "paths.json"
if paths.exists():
    d = json.loads(paths.read_text())
    d.update({
        "receptor_pdb_ph": str(receptor_dock_pdb),
        "boltz_pose_sdf_ph": str(boltz_pose_ph_sdf)
    })
    paths.write_text(json.dumps(d, indent=2))

print(f"\n   Done. Summary: {summary_csv}")
print(f"   Boltz Scores (Raw): {boltz_scores}")

print("\n--- Top 3 GNINA Poses ---")
if not df_dock.empty:
    display(df_dock.sort_values("CNNaffinity", ascending=False).head(3))
else:
    print("No docked poses found.")
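
The log parsing in F1 uses a `Name: value` / `Name value` pattern. The sketch below shows one possible hardening of that idea: a lookbehind keeps a short name such as `Affinity` from matching the tail of a longer name such as `minimizedAffinity`. The helper name `last_value` and the `sample` text are hypothetical; the sample is hand-written for illustration and is not real GNINA output.

```python
import math
import re

# Signed decimal number, with an optional exponent.
NUMBER = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

def last_value(text: str, name: str) -> float:
    """Return the last 'name: value' or 'name value' number in text, or NaN."""
    # The lookbehind rejects matches where `name` is only the suffix of a
    # longer identifier (e.g. "Affinity" inside "minimizedAffinity").
    pattern = rf"(?<![A-Za-z]){re.escape(name)}[:\s]+({NUMBER})"
    hits = re.findall(pattern, text)
    return float(hits[-1]) if hits else math.nan

# Hand-written placeholder text, NOT real tool output.
sample = "minimizedAffinity: -7.25\nCNNscore 0.81\nAffinity: -6.90 (units)\n"

print(last_value(sample, "Affinity"))
print(last_value(sample, "minimizedAffinity"))
```

Returning NaN for a missing name keeps the same contract as `get_score`, so a fallback chain like the one used for the Vina term would still work.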

In [None]:
#@title **F2) Plot GNINA scores**
#@markdown Plot a metric from the saved GNINA summary and overlay the Boltz score.

PLOT_METRIC = "minimizedAffinity"  #@param ["CNNscore","CNNaffinity","minimizedAffinity"]
TOP_N = 10  #@param {type:"integer"}

import pandas as pd
import matplotlib.pyplot as plt
import json
import numpy as np
import re
from pathlib import Path

summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"
boltz_scores_json = RUN_DIR / "analysis" / "boltz_pose_gnina_scores.json"
plot_png = RUN_DIR / "analysis" / f"gnina_{PLOT_METRIC}_plot.png"

df = pd.read_csv(summary_csv)
boltz_scores = json.loads(Path(boltz_scores_json).read_text())

# --- NaN fix ---
gnina_score_log = RUN_DIR / "logs" / "05_gnina_score_only.log"
if any(pd.isna(v) for v in boltz_scores.values()) and gnina_score_log.exists():
    txt = gnina_score_log.read_text(errors="ignore")
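    def find_val(name):
        # Matches "Name 1.23" or "Name: 1.23". Uses the same number pattern as F1;
        # a looser class like [-0-9.]+ can capture a lone "-" or "." and break float().
        m = re.findall(rf"{name}[:\s]+([-+]?[0-9]*\.?[0-9]+)", txt)
        return float(m[-1]) if m else float("nan")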

    # update scores; NaN != NaN, so `s == s` is True only when a value was parsed
    s = find_val("CNNscore")
    if s == s: boltz_scores["CNNscore"] = s

    s = find_val("CNNaffinity")
    if s == s: boltz_scores["CNNaffinity"] = s

    s = find_val("minimizedAffinity")
    if s != s: s = find_val("Affinity")
    if s == s: boltz_scores["minimizedAffinity"] = s

    print(f"Recovered Boltz scores: {boltz_scores}")

ascending = (PLOT_METRIC == "minimizedAffinity")
dff = df.sort_values(PLOT_METRIC, ascending=ascending).head(TOP_N).copy()

x = np.arange(len(dff))
y = dff[PLOT_METRIC].astype(float).values

# --- viz ---
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.weight'] = 'semibold'
plt.rcParams['axes.labelweight'] = 'semibold'

fig, ax = plt.subplots(figsize=(8.0, 4.0))

ax.plot(x, y, marker='o', linestyle='None',
        color='#367588', markersize=8, label='GNINA poses')

b = boltz_scores.get(PLOT_METRIC, float("nan"))
if b == b:  # not NaN
    # dashed grey reference line at the Boltz pose score
    ax.axhline(b, linestyle="--", linewidth=1, color="#808080", label="Boltz")
else:
    print(f"(Note: Boltz baseline for {PLOT_METRIC} is NaN, so dashed line is hidden)")

ax.set_xticks(x)
ax.set_xticklabels(dff["pose_id"].astype(int).values)

ax.set_xlabel("Pose ID")

if PLOT_METRIC == "minimizedAffinity":
    ax.set_ylabel("minimizedAffinity (kcal/mol)")
elif PLOT_METRIC == "CNNaffinity":
    ax.set_ylabel("CNNaffinity")
else:
    ax.set_ylabel("CNNscore")

# spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# legend
ax.legend(bbox_to_anchor=(1.02, 0.9), loc='upper left', frameon=False)

plt.tight_layout()
plt.savefig(plot_png, dpi=400)
plt.show()

print(f"Saved plot: {plot_png}")
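
F2 sorts ascending for `minimizedAffinity` and descending for the CNN metrics, so "best" means lowest for the former and highest for the latter. A minimal standard-library sketch of that rule, which also drops NaN entries before ranking, is shown here. The helper `top_poses`, the `LOWER_IS_BETTER` set, and the example rows are illustrative assumptions, not notebook data.

```python
import math

# Metrics where a smaller (more negative) value ranks first, as in F2.
LOWER_IS_BETTER = {"minimizedAffinity"}

def top_poses(rows, metric, n):
    """Return the n best rows for `metric`, skipping rows where it is NaN."""
    valid = [r for r in rows if not math.isnan(r[metric])]
    reverse = metric not in LOWER_IS_BETTER  # higher-is-better metrics sort descending
    return sorted(valid, key=lambda r: r[metric], reverse=reverse)[:n]

# Made-up rows for illustration only.
rows = [
    {"pose_id": 1, "minimizedAffinity": -7.0, "CNNscore": 0.4},
    {"pose_id": 2, "minimizedAffinity": -8.5, "CNNscore": 0.9},
    {"pose_id": 3, "minimizedAffinity": float("nan"), "CNNscore": 0.6},
]

print([r["pose_id"] for r in top_poses(rows, "minimizedAffinity", 2)])
print([r["pose_id"] for r in top_poses(rows, "CNNscore", 2)])
```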

## **G. Quick visual comparison**

In [None]:
#@title **G1) Compare Boltz vs GNINA poses (animated)**
#@markdown Animate GNINA poses (red) over the Boltz pose (green).

MAX_MODES = 10  #@param {type:"integer"}
INTERVAL_MS = 600  #@param {type:"integer"}

from pathlib import Path
from rdkit import Chem
import py3Dmol

receptor_pdb   = RUN_DIR / "analysis" / "receptor.pdb"
pocket_pdb     = RUN_DIR / "analysis" / "pocket_residues.pdb"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"
docked_sdf     = RUN_DIR / "gnina" / "gnina_docked.sdf"

PINE_GREEN = "#01796f"
ROSE_RED   = "#ff033e"

# frame limit
poses = [m for m in Chem.SDMolSupplier(str(docked_sdf), removeHs=False) if m is not None]
if not poses:
    raise RuntimeError("No GNINA poses found. Run F1 first.")

n = max(1, min(int(MAX_MODES), len(poses)))
tmp_sdf = RUN_DIR / "analysis" / f"gnina_first_{n}_poses.sdf"
with Chem.SDWriter(str(tmp_sdf)) as w:
    for m in poses[:n]:
        w.write(m)

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

# --- receptor viz ---
view.addModel(open(receptor_pdb).read(), "pdb")
view.setStyle({'model': 0}, {"cartoon": {"color": "white"}})

# --- surface viz ---
view.addSurface(py3Dmol.VDW, {'opacity': 0.6, 'color': 'silver'}, {'model': 0})

current_idx = 0

# --- pocket optional ---
if pocket_pdb.exists():
    view.addModel(open(pocket_pdb).read(), "pdb")
    current_idx += 1
    view.setStyle({'model': current_idx}, {"stick": {"color": "grey", "radius": 0.25}})

# --- boltz pose viz ---
view.addModel(open(boltz_pose_sdf).read(), "sdf")
current_idx += 1
view.setStyle({'model': current_idx}, {"stick": {"color": PINE_GREEN, "radius": 0.28}})

# --- gnina poses viz ---
gnina_data = open(tmp_sdf).read()
view.addModelsAsFrames(gnina_data, "sdf")
current_idx += 1

view.setStyle({'model': current_idx}, {"stick": {"color": ROSE_RED, "radius": 0.22}})

view.zoomTo()
view.animate({'loop': 'forward', 'interval': int(INTERVAL_MS)})
view.show()

print(f"Animating {n} GNINA poses (rose red) over Boltz pose (pine green).")
print(f"Frames source: {tmp_sdf}")
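
G1 limits the animation to the first `n` poses by rewriting them with RDKit. For reference, the same frame limit can be sketched on raw SDF text, relying only on the fact that each SDF record ends with a `$$$$` line. The function name `first_n_sdf_records` is hypothetical, and the demo text uses placeholder record bodies rather than valid molecules.

```python
def first_n_sdf_records(sdf_text: str, n: int) -> str:
    """Return the first n records of a multi-record SDF string."""
    if n <= 0:
        return ""
    records, current = [], []
    for line in sdf_text.splitlines(keepends=True):
        current.append(line)
        if line.strip() == "$$$$":  # SDF record terminator
            records.append("".join(current))
            current = []
            if len(records) >= n:
                break
    return "".join(records)

# Placeholder records (not valid molecules), for illustration only.
demo = (
    "pose_1\nplaceholder body\n$$$$\n"
    "pose_2\nplaceholder body\n$$$$\n"
    "pose_3\nplaceholder body\n$$$$\n"
)
subset = first_n_sdf_records(demo, 2)
print(subset.count("$$$$"))
```

Unlike the RDKit route in G1, this text-level split does not validate the molecules, so unreadable poses would not be filtered out.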

## **H. QC + interactions**


In [None]:
#@title **H1) Run PoseBusters and ProLIF for QC and interaction comparison**

POSE_SELECTION = "minimizedAffinity"  #@param ["Manual", "CNNscore","CNNaffinity","minimizedAffinity"]
MANUAL_POSE_ID = 1  #@param {type:"integer"}
CONTACT_CUTOFF_A = 6.0  # Å; used only by the commented-out contact check at the end of this cell

import sys
import json
import datetime, html, subprocess, shutil
from pathlib import Path
from IPython.display import display, HTML
import MDAnalysis as mda
import pandas as pd
import numpy as np

# --- setup & checks ---
chem_path = RUN_DIR / "state" / "chem.json"
PH = float(json.loads(chem_path.read_text()).get("pH", 7.4)) if chem_path.exists() else 7.4
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
boltz_pose_sdf = RUN_DIR / "analysis" / "boltz_ligand_pose.sdf"
dock_sdf = RUN_DIR / "gnina" / "gnina_docked.sdf"
summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"

if not summary_csv.exists():
    raise RuntimeError("Missing GNINA summary. Run F1 first.")

df_summary = pd.read_csv(summary_csv)
if len(df_summary) == 0:
    raise RuntimeError("No GNINA poses found.")

# pose selection
if POSE_SELECTION == "Manual":
    pose_id = int(MANUAL_POSE_ID)
    print(f"Selected pose_id: {pose_id} (Manual)")
else:
    ascending = (POSE_SELECTION == "minimizedAffinity")
    pose_id = int(df_summary.sort_values(POSE_SELECTION, ascending=ascending).iloc[0]["pose_id"])
    print(f"Selected pose_id: {pose_id} (Best by {POSE_SELECTION})")

best_pose_sdf = RUN_DIR / "analysis" / f"gnina_best_pose_{pose_id}.sdf"

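# extract the selected pose from the docked SDF
from rdkit import Chem
sup = Chem.SDMolSupplier(str(dock_sdf), removeHs=False)
if not 1 <= pose_id <= len(sup) or sup[pose_id-1] is None:
    raise RuntimeError(f"pose_id {pose_id} not found in {dock_sdf}.")
with Chem.SDWriter(str(best_pose_sdf)) as w:
    w.write(sup[pose_id-1])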

# --- try posebusters CLI ---
print("\nRunning PoseBusters...")
for name, sdf_path in [("Boltz", boltz_pose_sdf), ("GNINA", best_pose_sdf)]:
    csv_out = RUN_DIR / "analysis" / f"posebusters_{name.lower()}_{ts}.csv"
    cmd = [sys.executable, "-m", "posebusters", str(sdf_path), "-p", str(receptor_pdb), "--outfmt", "csv", "--full-report"]
    try:
        p = subprocess.run(cmd, capture_output=True, text=True)
        if p.returncode == 0 and p.stdout.strip():
            # write the report only when PoseBusters produced usable output
            csv_out.write_text(p.stdout)
            print(f"  {name} report saved: {csv_out}")
        else:
            print(f"  WARNING: {name} PoseBusters failed (exit code {p.returncode}, no usable output).")
    except Exception as e:
        print(f"  Error running PoseBusters for {name}: {e}")

# --- add hydrogens at the target pH (ProLIF needs explicit hydrogens) ---
print(f"\nAdding hydrogens at pH {PH}...")
receptor_h_pdb = RUN_DIR / "analysis" / "receptor_prolif_H.pdb"
boltz_h_sdf = RUN_DIR / "analysis" / "boltz_pose_prolif_H.sdf"
gnina_h_sdf = RUN_DIR / "analysis" / f"gnina_pose_{pose_id}_prolif_H.sdf"
hlog = RUN_DIR / "logs" / "07_hydrogens.log"

def run_obabel_h(inp, outp, ph):
    run_cmd(f"obabel {inp} -O {outp} -p {ph}", hlog)

run_obabel_h(receptor_pdb, receptor_h_pdb, PH)

# restore residue names changed by Open Babel (e.g. 'UNK'); match by chain first, then by residue number only
try:
    from Bio import PDB
    parser = PDB.PDBParser(QUIET=True)
    ref_st = parser.get_structure("ref", str(receptor_pdb))
    tgt_st = parser.get_structure("tgt", str(receptor_h_pdb))

    mapping = {(c.id, r.id[1], r.id[2]): r.resname for m in ref_st for c in m for r in c}
    fallback = {(r.id[1], r.id[2]): r.resname for m in ref_st for c in m for r in c}

    fixed = 0
    for m in tgt_st:
        for c in m:
            for r in c:
                val = mapping.get((c.id, r.id[1], r.id[2])) or fallback.get((r.id[1], r.id[2]))
                if val and r.resname != val:
                    r.resname = val
                    fixed += 1
    if fixed > 0:
        io = PDB.PDBIO()
        io.set_structure(tgt_st)
        io.save(str(receptor_h_pdb))
        #print(f"  [Patch] Restored {fixed} residue names.")
except Exception as e:
    # non-fatal: continue with the unpatched receptor, but say so
    print(f"  [Patch] Warning: residue-name patch skipped: {e}")

run_obabel_h(boltz_pose_sdf, boltz_h_sdf, PH)
run_obabel_h(best_pose_sdf, gnina_h_sdf, PH)

# --- ProLIF fxns ---
print("\nRunning ProLIF (Interactions)...")
try:
    import prolif as plf
    from prolif.interactions import VdWContact

    # try loading
    prot_mol = Chem.MolFromPDBFile(str(receptor_h_pdb), removeHs=False, sanitize=False)
    Chem.SanitizeMol(prot_mol, Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    prot = plf.Molecule.from_rdkit(prot_mol)

    # fingerprint params to try
    fp = plf.Fingerprint(interactions=[
        "Hydrophobic", "HBDonor", "HBAcceptor", "PiStacking",
        "Anionic", "Cationic", "CationPi", "PiCation",
        "VdWContact",
        "XBDonor", "XBAcceptor",
        "MetalDonor", "MetalAcceptor"
    ])
    lig_files = {"Boltz": boltz_h_sdf, "GNINA": gnina_h_sdf}
    html_outputs = []

    for name, fpath in lig_files.items():
        try:
            suppl = Chem.SDMolSupplier(str(fpath), removeHs=False)
            lig_mol = suppl[0]
            # fix Resname for plot
            for atom in lig_mol.GetAtoms():
                mi = Chem.AtomPDBResidueInfo()
                mi.SetResidueName("LIG")
                mi.SetResidueNumber(1)
                mi.SetIsHeteroAtom(True)
                atom.SetMonomerInfo(mi)

            lig = plf.Molecule.from_rdkit(lig_mol)

            # try run
            fp.run_from_iterable([lig], prot)
            df = fp.to_dataframe()

            csv_path = RUN_DIR / "analysis" / f"prolif_{name.lower()}_ifp.csv"
            html_path = RUN_DIR / "analysis" / f"prolif_{name.lower()}.html"

            if df.empty or (df.shape[1] == 0):
                print(f"  {name}: No interactions found (even VdW). Check geometry.")
                html_outputs.append(f"<div style='flex:1; border:1px solid red; padding:10px;'><h4>{name}</h4><p>No interactions detected.</p></div>")
            else:
                df.to_csv(csv_path)
                net = fp.plot_lignetwork(ligand_mol=lig, kind="frame", frame=0)
                net.save(str(html_path))
                src = html.escape(Path(html_path).read_text())
                html_outputs.append(f"""
                <div style="flex: 1; border: 1px solid #ccc; padding: 5px;">
                  <h4 style="text-align:center;">{name} Pose</h4>
                  <iframe srcdoc="{src}" width="100%" height="600px" style="border:none;"></iframe>
                </div>""")
                print(f"  {name}: Success! ({int(df.sum(axis=1).sum())} interactions)")

        except Exception as e:
            print(f"  {name} failed: {e}")
            html_outputs.append(f"<div style='flex:1;'><h4>{name} Error</h4><p>{html.escape(str(e))}</p></div>")

    display(HTML(f'<div style="display: flex; flex-direction: row; gap: 20px; width: 100%;">{ "".join(html_outputs) }</div>'))

except Exception as e:
    print(f"CRITICAL ProLIF ERROR: {e}")

# --- debugging neighborhood check ---
#def count_contacts(lig_p):
#    try:
        #tmp = lig_p.parent / (lig_p.stem + "_tmp.pdb")
        #run_cmd(f"obabel -isdf {lig_p} -opdb -O {tmp}", hlog)
        #u_p = mda.Universe(str(receptor_pdb))
        #u_l = mda.Universe(str(tmp))
        #d = mda.lib.distances.distance_array(u_l.atoms.positions, u_p.select_atoms("protein").positions)
        #return len(u_p.select_atoms("protein")[np.any(d < CONTACT_CUTOFF_A, axis=0)].residues)
    #except: return -1

#print(f"\nResidues within {CONTACT_CUTOFF_A} Å:")
#print(f"  Boltz : {count_contacts(boltz_pose_sdf)}")
#print(f"  GNINA : {count_contacts(best_pose_sdf)}")
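
The residue-name patch in H1 looks up each residue by `(chain, number, insertion code)` and falls back to `(number, insertion code)` when the chain ID does not match. The sketch below reproduces that lookup on plain tuples, with one assumed refinement: the chain-agnostic fallback is kept only when the residue number maps to a single name, so overlapping numbering across chains cannot assign the wrong label. The function names and the example residues are hypothetical.

```python
from collections import defaultdict

def build_maps(reference):
    """reference: iterable of (chain, resseq, icode, resname) tuples."""
    exact = {}
    by_number = defaultdict(set)
    for chain, resseq, icode, resname in reference:
        exact[(chain, resseq, icode)] = resname
        by_number[(resseq, icode)].add(resname)
    # Assumed refinement: keep only unambiguous chain-agnostic entries.
    fallback = {k: next(iter(v)) for k, v in by_number.items() if len(v) == 1}
    return exact, fallback

def restore_names(reference, target):
    """Return target residues with names restored from the reference."""
    exact, fallback = build_maps(reference)
    fixed = []
    for chain, resseq, icode, resname in target:
        name = (exact.get((chain, resseq, icode))
                or fallback.get((resseq, icode))
                or resname)  # leave the name untouched if nothing matches
        fixed.append((chain, resseq, icode, name))
    return fixed

# Made-up residues for illustration only.
reference = [("A", 10, " ", "HIS"), ("A", 11, " ", "GLY"), ("B", 10, " ", "ASP")]
target = [("A", 10, " ", "UNK"), (" ", 11, " ", "UNK"), (" ", 10, " ", "UNK")]
print(restore_names(reference, target))
```

In the last target residue the chain ID is missing and number 10 exists in both chains with different names, so the sketch leaves it as `UNK` rather than guessing.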

## **I. Export results**

Compress the project folder (structures, SDFs, CSVs, logs).


In [None]:
#@title **I1) Export project folder**
#@markdown Choose what to export, then choose compression format.

EXPORT_SCOPE = "full_project" #@param ["current_run", "full_project"]
COMPRESSION_FORMAT = "zip" #@param ["zip", "tar", "gztar"]

import shutil
from pathlib import Path

# --- project-level name ---
try:
    PROJECT_NAME
except NameError:
    # best-effort fallback
    PROJECT_NAME = Path(PROJECT_ROOT).name if "PROJECT_ROOT" in globals() else "cloudbind_project"

ext_map = {"zip": "zip", "tar": "tar", "gztar": "tar.gz"}
ext = ext_map.get(COMPRESSION_FORMAT, "zip")

if EXPORT_SCOPE == "current_run":
    root_dir = RUN_DIR.parent
    base_dir = RUN_DIR.name
    out_base = Path("/content") / f"{RUN_NAME}_run"
    what = f"run folder: {RUN_DIR}"
else:
    root_dir = PROJECT_ROOT
    base_dir = None
    out_base = Path("/content") / f"{PROJECT_NAME}_full_project"
    what = f"full project: {PROJECT_ROOT}"

archive_path = Path(f"{out_base}.{ext}")

if archive_path.exists():
    archive_path.unlink()

print(f"Compressing {what} -> {archive_path} ...")

if base_dir is None:
    shutil.make_archive(str(out_base), COMPRESSION_FORMAT, root_dir=str(root_dir))
else:
    shutil.make_archive(str(out_base), COMPRESSION_FORMAT, root_dir=str(root_dir), base_dir=str(base_dir))

if archive_path.exists():
    sz_mb = archive_path.stat().st_size / (1024 * 1024)
    print(f"Created: {archive_path}")
    print(f"Size   : {sz_mb:.2f} MB")
else:
    print("Error: Archive creation failed.")

# download prompt
try:
    from google.colab import files
    files.download(str(archive_path))
except Exception:
    print("(Download is available in Colab only. Check the Files sidebar if the prompt doesn't appear.)")
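
The `current_run` branch in I1 passes both `root_dir` and `base_dir` to `shutil.make_archive`, which makes every archive entry start with the run folder name instead of being flattened into the archive root. A small self-contained sketch of that behaviour in a temporary directory follows. The helper `archive_run` and the folder names are hypothetical.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

def archive_run(run_dir: Path, out_base: Path, fmt: str = "zip") -> Path:
    """Archive run_dir so that entries are prefixed with the run folder name."""
    return Path(shutil.make_archive(
        str(out_base), fmt,
        root_dir=str(run_dir.parent),  # directory the archive paths are relative to
        base_dir=run_dir.name,         # only this sub-folder is included
    ))

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical layout: <tmp>/demo_project/run_001/logs/example.log
    run = Path(tmp) / "demo_project" / "run_001"
    (run / "logs").mkdir(parents=True)
    (run / "logs" / "example.log").write_text("ok\n")

    archive = archive_run(run, Path(tmp) / "run_001_export")
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()

print(names)
```

Writing the archive outside the folder being compressed, as I1 does with `/content`, avoids the archive trying to include itself.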


## **M. Manifest (Optional)**


In [None]:
#@title **M1) Write a manifest for reproducibility and debugging**

import json, platform, subprocess, datetime
from pathlib import Path

manifest = {
    "timestamp": datetime.datetime.now().isoformat(),
    "project_root": str(PROJECT_ROOT),
    "run_name": RUN_NAME if "RUN_NAME" in globals() else None,
    "run_dir": str(RUN_DIR) if "RUN_DIR" in globals() else None,
    "tools_dir": str(TOOLS_DIR),
    "versions": {
        "boltz": BOLTZ_VERSION,
        "p2rank": P2RANK_VERSION,
        "gnina": GNINA_VERSION,
        "python": platform.python_version(),
    },
}

# hardware info
try:
    import torch
    if torch.cuda.is_available():
        manifest["gpu"] = torch.cuda.get_device_name(0)
except Exception:
    pass

def tail_text(p: Path, n=40):
    if not p.exists():
        return None
    lines = p.read_text(errors="ignore").splitlines()
    return "\n".join(lines[-n:])

if "RUN_DIR" in globals():
    log_dir = RUN_DIR / "logs"
    manifest["log_tails"] = {
        p.name: tail_text(p, n=40) for p in sorted(log_dir.glob("*.log"))
    }
    # expected files
    expected = [
        RUN_DIR / "analysis" / "boltz_complex.pdb",
        RUN_DIR / "analysis" / "receptor.pdb",
        RUN_DIR / "analysis" / "boltz_ligand_pose.pdb",
        RUN_DIR / "analysis" / "boltz_ligand_pose.sdf",
        RUN_DIR / "analysis" / "boltz_affinity.json",
        RUN_DIR / "analysis" / "gnina_pose_summary.csv",  # summary written by F1
    ]
    manifest["artifact_exists"] = {str(p): p.exists() for p in expected}

manifest_ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
if "RUN_DIR" in globals():
    out_path = RUN_DIR / "analysis" / f"manifest_{manifest_ts}.json"
else:
    out_path = Path(PROJECT_ROOT) / "manifest.json"
Path(out_path).write_text(json.dumps(manifest, indent=2))
print(f"Wrote: {out_path}")
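
The manifest in M1 records whether each expected artifact exists. As a possible extension (an assumption, not part of the notebook), the same inventory could also store file sizes, which makes empty or truncated outputs easier to spot when debugging. The helper `artifact_inventory` and the file names below are hypothetical.

```python
import json
import tempfile
from pathlib import Path

def artifact_inventory(paths):
    """Map each path to its existence flag and size in bytes (None if missing)."""
    inventory = {}
    for p in map(Path, paths):
        is_file = p.is_file()
        inventory[str(p)] = {
            "exists": is_file,
            "bytes": p.stat().st_size if is_file else None,
        }
    return inventory

with tempfile.TemporaryDirectory() as tmp:
    present = Path(tmp) / "present.csv"   # hypothetical artifact that exists
    present.write_bytes(b"pose_id\n1\n")  # 10 bytes
    missing = Path(tmp) / "missing.csv"   # hypothetical artifact never written
    inventory = artifact_inventory([present, missing])
    present_key, missing_key = str(present), str(missing)

print(json.dumps(inventory, indent=2))
```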
