<a href="https://colab.research.google.com/github/pablo-arantes/Cloud-Bind/blob/main/OpenFold3%2BGNINA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Hi there!**

This notebook is part of **Cloud-Bind** and provides a workflow for **protein–ligand structure prediction + docking** in Google Colab.

The main goal is to show how cloud compute can be used to run an accessible structure-based pipeline for **co-folding** and **re-docking**.
___

**Hardware & Runtime:**
This notebook is designed for an A100 High-RAM runtime: OpenFold3 needs a GPU with compute capability ≥ 8.0 and ≥ 24 GB of VRAM.
For T4/L4 compatibility, try the [Boltz notebook](https://colab.research.google.com/github/pablo-arantes/Cloud-Bind/blob/main/Boltz%2BGNINA.ipynb).

**Note:** This is an instructional pipeline, not a standard docking protocol.
___

**Bugs**
- If you encounter any bugs, please report the issue to https://github.com/pablo-arantes/Cloud-Bind/issues

**Acknowledgments**
- **Core Pipeline**: We thank the developers of [OpenFold3](https://github.com/aqlaboratory/openfold-3) (structure prediction), [ColabFold](https://github.com/sokrypton/ColabFold) (MSA server), [P2Rank](https://github.com/rdk/p2rank) (pocket detection), and [GNINA](https://github.com/gnina/gnina) (docking/scoring).
- **Cheminformatics & QC**: We thank the teams behind [Open Babel](https://github.com/openbabel/openbabel) (conversion), [RDKit](https://github.com/rdkit/rdkit) (ligand handling), [Dimorphite-DL](https://github.com/durrantlab/dimorphite_dl) (protonation), [PoseBusters](https://github.com/maabuu/posebusters) (QC), [ProLIF](https://github.com/chemosim-lab/ProLIF) (interactions), [MDAnalysis](https://github.com/MDAnalysis/mdanalysis) (parsing), and [Biopython](https://github.com/biopython/biopython).
- **Visualization & Runtime**: Credit to [py3Dmol](https://3dmol.csb.pitt.edu/) for visualization and [PyTorch](https://github.com/pytorch/pytorch) for runtime dependencies.
- **Cloud-Bind Team**: **Pablo R. Arantes** ([@pablitoarantes](https://twitter.com/pablitoarantes)), **Conrado Pedebos** ([@ConradoPedebos](https://twitter.com/ConradoPedebos)), **Rodrigo Ligabue-Braun** ([@ligabue_braun](https://twitter.com/ligabue_braun)), **Davidt da Silva Tarouco**, and **Saul J. Flores** ([@saulfloresjr](https://www.linkedin.com/in/saulfloresjr/)).

- For related notebooks see: [Cloud-Bind](https://github.com/pablo-arantes/Cloud-Bind)


## **A. Runtime & project folder**


In [None]:
#@title **A1) Initialize project folders**

from pathlib import Path
import os, json, datetime

# ----------------------------
# Project naming
# ----------------------------
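PROJECT_NAME = ""  #@param {type:"string"}
# Fall back to a default name so PROJECT_ROOT never resolves to /content itself
PROJECT_NAME = PROJECT_NAME.strip() or "cloudbind_project"
PROJECT_ROOT = Path(f"/content/{PROJECT_NAME}")

TOOLS_DIR = Path("/content/cloudbind_tools")

RUNS_DIR = PROJECT_ROOT / "runs"

INSTALL_LOGS_DIR = PROJECT_ROOT / "logs"   # install logs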

# ----------------------------
# Create folders
# ----------------------------
for p in [PROJECT_ROOT, TOOLS_DIR, RUNS_DIR, INSTALL_LOGS_DIR]:
    p.mkdir(parents=True, exist_ok=True)

os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
bootstrap = {
    "created_at": datetime.datetime.now().isoformat(),
    "project_name": PROJECT_NAME,
    "project_root": str(PROJECT_ROOT),
    "runs_dir": str(RUNS_DIR),
    "install_logs_dir": str(INSTALL_LOGS_DIR),
    "tools_dir": str(TOOLS_DIR),
}
BOOTSTRAP_PATH.write_text(json.dumps(bootstrap, indent=2))

print(f"Project root : {PROJECT_ROOT}")
print(f"Runs dir     : {RUNS_DIR}")
print(f"Tools dir    : {TOOLS_DIR}  (excluded from export zip)")
print(f"Install logs : {INSTALL_LOGS_DIR}")
print(f"Bootstrap    : {BOOTSTRAP_PATH}")
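
Later cells (B1, Helpers) rehydrate paths from this bootstrap file rather than from notebook globals, because a runtime restart wipes the globals but leaves `/content` intact. A minimal sketch of that pattern, assuming the bootstrap is a flat JSON object; `load_bootstrap` and the default values are illustrative and not part of the notebook:

```python
import json
import tempfile
from pathlib import Path

def load_bootstrap(path: Path, defaults: dict) -> dict:
    """Return defaults overlaid with the non-empty values stored in the bootstrap JSON."""
    if not path.exists():
        raise FileNotFoundError(f"{path} not found; run A1 first.")
    stored = json.loads(path.read_text())
    merged = dict(defaults)
    # Skip empty values so a blank field never wipes out a sensible default
    merged.update({k: v for k, v in stored.items() if v not in (None, "")})
    return merged

# Usage: write a small bootstrap file, then rehydrate it with fallbacks
with tempfile.TemporaryDirectory() as tmp:
    boot_path = Path(tmp) / "bootstrap.json"
    boot_path.write_text(json.dumps({"project_name": "demo", "runs_dir": ""}))
    cfg = load_bootstrap(boot_path, {"project_name": "cloudbind_project",
                                     "runs_dir": "/content/cloudbind_project/runs"})
    print(cfg)  # {'project_name': 'demo', 'runs_dir': '/content/cloudbind_project/runs'}
```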


## **B. Install Dependencies**


In [None]:
#@title **B1) Install OpenFold3, P2Rank, GNINA, and analysis tools**
#@markdown Once this cell triggers a runtime restart, **run it again once** after the restart to finish installation.

import os, subprocess, json, sys
from pathlib import Path

# rehydrate project paths (globals are lost after a runtime restart)
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
if "INSTALL_LOGS_DIR" not in globals():
    if not BOOTSTRAP_PATH.exists():
        raise FileNotFoundError("Bootstrap file not found. Run A1 once to initialize the project.")
    _boot = json.loads(BOOTSTRAP_PATH.read_text())
    INSTALL_LOGS_DIR = Path(_boot.get("install_logs_dir", "/content/cloudbind_project/logs"))
    INSTALL_LOGS_DIR.mkdir(parents=True, exist_ok=True)

# --------- pinned versions (edit if needed) ----------
OPENFOLD3_VERSION = "0.3.1"
P2RANK_VERSION    = "2.5.1"
GNINA_VERSION     = "1.3.2"

TOOLS_DIR = Path("/content/cloudbind_tools")
TOOLS_DIR.mkdir(parents=True, exist_ok=True)
os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

install_log = INSTALL_LOGS_DIR / "01_install.log"
stamp_path  = TOOLS_DIR / f"install_stamp_of3{OPENFOLD3_VERSION}_p2rank{P2RANK_VERSION}_gnina{GNINA_VERSION}.json"
post_restart_flag = TOOLS_DIR / ".post_install_restart_done"

def _tail(path: Path, n=80):
    if not path.exists():
        return ""
    lines = path.read_text(errors="ignore").splitlines()
    return "\n".join(lines[-n:])

def run_cmd(cmd: str, log_path: Path, cwd: Path | None = None):
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a") as f:
        f.write(f"\n\n$ {cmd}\n")
        f.flush()  # write the header before the child process appends its own output
        p = subprocess.run(cmd, shell=True, cwd=str(cwd) if cwd else None,
                           stdout=f, stderr=subprocess.STDOUT, text=True)
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}\n\n--- tail({log_path}) ---\n{_tail(log_path)}")

def have_file(path: Path) -> bool:
    return path.exists() and path.is_file()

# ----------------------------
# Detect runtime
# ----------------------------
def detect_runtime():
    if os.environ.get("COLAB_TPU_ADDR"):
        return "TPU"
    try:
        import torch
        if torch.cuda.is_available():
            return f"GPU ({torch.cuda.get_device_name(0)})"
    except Exception:
        pass
    return "CPU"

runtime = detect_runtime()
print(f"Runtime detected: {runtime}")
if "CPU" in runtime or "TPU" in runtime:
    print("Note: a GPU is strongly recommended for OpenFold3 protein-ligand inference; CPU runs are very slow.")

# ----------------------------
# Skip if already installed
# ----------------------------
p2rank_dir    = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}"
p2rank_prank  = p2rank_dir / "prank"
gnina_bin     = TOOLS_DIR / "gnina"
micromamba    = TOOLS_DIR / "micromamba"
mamba_root    = TOOLS_DIR / "mamba_root"
of3_env_dir   = mamba_root / "envs" / "of3"
run_openfold_sh = TOOLS_DIR / "run_openfold"
of3_cache_dir = TOOLS_DIR / "openfold3_cache"

os.environ["MAMBA_ROOT_PREFIX"] = str(mamba_root)

did_install = False
if have_file(gnina_bin) and have_file(p2rank_prank) and have_file(micromamba) and have_file(run_openfold_sh) and of3_env_dir.exists():
    print("Existing installation detected. Skipping re-install.")
else:
    did_install = True

    # --------- system deps ----------
    run_cmd("apt-get -qq update", install_log)
    run_cmd("apt-get -qq install -y openjdk-17-jre-headless wget unzip tar rsync curl", install_log)


    run_cmd("apt-get -qq install -y openbabel", install_log)

    # --------- python deps (base runtime) ----------
    run_cmd("python -m pip -q install py3Dmol prolif posebusters MDAnalysis rdkit dimorphite-dl biopython pyyaml", install_log)

    # --------- GNINA binary ----------
    run_cmd(f"wget -q -O {gnina_bin} https://github.com/gnina/gnina/releases/download/v{GNINA_VERSION}/gnina.{GNINA_VERSION}", install_log)
    run_cmd(f"chmod +x {gnina_bin}", install_log)

    # --------- P2Rank release ----------
    p2rank_tar = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}.tar.gz"
    run_cmd(f"wget -q -O {p2rank_tar} https://github.com/rdk/p2rank/releases/download/{P2RANK_VERSION}/p2rank_{P2RANK_VERSION}.tar.gz", install_log)
    run_cmd(f"tar -xzf {p2rank_tar} -C {TOOLS_DIR}", install_log)

    if not p2rank_prank.exists():
        raise RuntimeError(f"P2Rank installed but prank entrypoint not found at: {p2rank_prank}\nCheck install log: {install_log}")

    # --------- Micromamba ----------
    if not micromamba.exists():
        run_cmd(
            f"curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | "
            f"tar -xvj -C {TOOLS_DIR} --strip-components=1 bin/micromamba",
            install_log
        )
        run_cmd(f"chmod +x {micromamba}", install_log)

    os.environ["DS_BUILD_OPS"] = "0"
    os.environ["TRITON_DISABLE"] = "1"

    run_cmd(f"{micromamba} create -y -n of3 -c conda-forge python=3.11 pip", install_log)
    run_cmd(f"{micromamba} install -y -n of3 -c bioconda -c conda-forge kalign2 aria2", install_log)
    run_cmd(f"ln -sf {micromamba} {of3_env_dir}/bin/conda", install_log)

    run_cmd(f'DS_BUILD_OPS=0 TRITON_DISABLE=1 {micromamba} run -n of3 python -m pip -q install "openfold3[cuequivariance]=={OPENFOLD3_VERSION}"', install_log)

    of3_cache_dir.mkdir(parents=True, exist_ok=True)
    setup_cmd = (
        f'{micromamba} run -n of3 bash -lc '
        f'"export DS_BUILD_OPS=0; export TRITON_DISABLE=1; '
        f'mkdir -p {of3_cache_dir}; '
        f'{{ printf \\"{of3_cache_dir}\\n{of3_cache_dir}\\n\\"; yes yes; }} | setup_openfold --accept_license"'
    )
    run_cmd(setup_cmd, install_log)

    wrapper = f"""#!/usr/bin/env bash
set -euo pipefail

export MAMBA_ROOT_PREFIX="{mamba_root}"
export DS_BUILD_OPS=0
export TRITON_DISABLE=1
# If cuEquivariance kernels are installed, force fallback to PyTorch for triangle attention.
export CUEQ_TRIATTN_FALLBACK_THRESHOLD=1000000

# Keep cache inside project tools dir (avoid /root surprises)
export XDG_CACHE_HOME="{of3_cache_dir}"

exec "{micromamba}" run -n of3 run_openfold "$@"
"""
    run_openfold_sh.write_text(wrapper)
    run_cmd(f"chmod +x {run_openfold_sh}", install_log)

    stamp = {
        "timestamp": __import__("datetime").datetime.now().isoformat(),
        "openfold3": OPENFOLD3_VERSION,
        "p2rank": P2RANK_VERSION,
        "gnina": GNINA_VERSION,
        "runtime": runtime,
        "tools_dir": str(TOOLS_DIR),
        "gnina_path": str(gnina_bin),
        "p2rank_dir": str(p2rank_dir),
        "p2rank_prank": str(p2rank_prank),
        "micromamba": str(micromamba),
        "mamba_root": str(mamba_root),
        "of3_env_dir": str(of3_env_dir),
        "run_openfold": str(run_openfold_sh),
        "openfold_cache": str(of3_cache_dir),
    }
    stamp_path.write_text(json.dumps(stamp, indent=2))

if BOOTSTRAP_PATH.exists():
    boot = json.loads(BOOTSTRAP_PATH.read_text())
else:
    boot = {}

boot.update({
    "openfold3_version": OPENFOLD3_VERSION,
    "p2rank_version": P2RANK_VERSION,
    "gnina_version": GNINA_VERSION,
    "tools_dir": str(TOOLS_DIR),
    "gnina_path": str(gnina_bin),
    "p2rank_dir": str(p2rank_dir),
    "p2rank_prank": str(p2rank_prank),
    "micromamba": str(micromamba),
    "mamba_root": str(mamba_root),
    "run_openfold": str(run_openfold_sh),
    "openfold_cache": str(of3_cache_dir),
    "install_stamp": str(stamp_path),
})
BOOTSTRAP_PATH.write_text(json.dumps(boot, indent=2))

print("Installed. Key tools:")
print(f"  OpenFold3 : {OPENFOLD3_VERSION}  (isolated micromamba env: of3)")
print(f"  gnina     : {GNINA_VERSION}  -> {gnina_bin}")
print(f"  p2rank    : {P2RANK_VERSION} -> {p2rank_prank}")
print(f"  run_openfold wrapper: {run_openfold_sh}")
print(f"  OpenFold cache dir  : {of3_cache_dir}")
print(f"Install log: {install_log}")
print(f"Bootstrap: {BOOTSTRAP_PATH}")

# ----------------------------
# post-install restart
# ----------------------------
if did_install and not post_restart_flag.exists():
    post_restart_flag.write_text("ok")
    print("\nColab: restarting runtime once to finalize binary wheels (prevents numpy/pandas ABI errors).")
    os.kill(os.getpid(), 9)
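
B1 skips reinstallation only when every expected artifact is present, but when it does reinstall it does not say which artifact was missing. The sketch below reports missing artifacts by name; `missing_artifacts` is a hypothetical helper, and treating keys ending in `_dir` as directories is a convention assumed only for this example:

```python
import tempfile
from pathlib import Path

def missing_artifacts(expected: dict) -> list:
    """Return the keys of expected install artifacts that are not on disk."""
    missing = []
    for name, path in expected.items():
        # Keys ending in "_dir" must be directories (e.g. the micromamba env);
        # everything else (binaries, wrapper scripts) must be a regular file.
        present = path.is_dir() if name.endswith("_dir") else path.is_file()
        if not present:
            missing.append(name)
    return missing

# Usage: simulate a tools folder where P2Rank was never unpacked
with tempfile.TemporaryDirectory() as tmp:
    tools = Path(tmp)
    (tools / "gnina").write_text("#!/bin/sh\n")
    (tools / "mamba_root" / "envs" / "of3").mkdir(parents=True)
    missing = missing_artifacts({
        "gnina": tools / "gnina",
        "prank": tools / "p2rank_2.5.1" / "prank",
        "of3_env_dir": tools / "mamba_root" / "envs" / "of3",
    })
    print(missing)  # ['prank']
```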


### **Shared helpers**

Logging and small display utilities used throughout the notebook.


In [None]:
#@title **Helpers**

import subprocess, warnings, json, os
from pathlib import Path

# ----------------------------
# Rehydrate bootstrap
# ----------------------------
BOOTSTRAP_PATH = Path("/content/.cloudbind_bootstrap.json")
if BOOTSTRAP_PATH.exists():
    _boot = json.loads(BOOTSTRAP_PATH.read_text())
    PROJECT_NAME = _boot.get("project_name", "cloudbind_project")
    PROJECT_ROOT = Path(_boot.get("project_root", f"/content/{PROJECT_NAME}"))
    RUNS_DIR = Path(_boot.get("runs_dir", str(PROJECT_ROOT / "runs")))
    INSTALL_LOGS_DIR = Path(_boot.get("install_logs_dir", str(PROJECT_ROOT / "logs")))
    TOOLS_DIR = Path(_boot.get("tools_dir", "/content/cloudbind_tools"))

    OPENFOLD3_VERSION = _boot.get("openfold3_version")
    P2RANK_VERSION    = _boot.get("p2rank_version")
    GNINA_VERSION     = _boot.get("gnina_version")

    GNINA_PATH   = Path(_boot.get("gnina_path", str(TOOLS_DIR / "gnina")))
    P2RANK_PRANK = Path(_boot.get("p2rank_prank", str(TOOLS_DIR / f"p2rank_{P2RANK_VERSION}" / "prank")))
    RUN_OPENFOLD = Path(_boot.get("run_openfold", str(TOOLS_DIR / "run_openfold")))
    OF3_CACHE    = Path(_boot.get("openfold_cache", str(TOOLS_DIR / "openfold3_cache")))
else:
    raise FileNotFoundError("Bootstrap file not found. Run A1 once to initialize the project.")

os.environ["PATH"] = f"{TOOLS_DIR}:{os.environ.get('PATH','')}"

if not (OPENFOLD3_VERSION and P2RANK_VERSION and GNINA_VERSION):
    stamp_files = sorted(TOOLS_DIR.glob("install_stamp_*.json"))
    if stamp_files:
        stamp = json.loads(stamp_files[-1].read_text())
        OPENFOLD3_VERSION = OPENFOLD3_VERSION or stamp.get("openfold3")
        P2RANK_VERSION    = P2RANK_VERSION    or stamp.get("p2rank")
        GNINA_VERSION     = GNINA_VERSION     or stamp.get("gnina")

try:
    import pandas as pd
except ValueError as e:
    if "numpy.dtype size changed" in str(e):
        print("="*60)
        print("BINARY INCOMPATIBILITY DETECTED (numpy/pandas)")
        print("Fix: Runtime -> Restart session, then run Helpers and continue.")
        print("="*60)
    raise e

# ----------------------------
# logging control
# ----------------------------
warnings.filterwarnings("ignore", message=r"netCDF4 is not available.*")
warnings.filterwarnings("ignore", message=r"Unit cell dimensions not found.*")
warnings.filterwarnings("ignore", message=r".*CRYST1 record.*")
warnings.filterwarnings("ignore", message=r"Found no information for attr: 'formalcharges'.*")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=r"MDAnalysis\.topology\.tables.*")

try:
    from rdkit import RDLogger
    RDLogger.DisableLog("rdApp.warning")
except Exception:
    pass

import logging
logging.getLogger("MDAnalysis").setLevel(logging.ERROR)
logging.getLogger("MDAnalysis.coordinates.AMBER").setLevel(logging.ERROR)

def tail(path: Path, n=60) -> str:
    if not path.exists():
        return ""
    lines = path.read_text(errors="ignore").splitlines()
    return "\n".join(lines[-n:])

def run_cmd(cmd: str, log_path: Path, cwd: Path | None = None, stdin_text: str | None = None):
    """Run a shell command quietly; append stdout/stderr to log_path."""
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a") as f:
        f.write(f"\n\n$ {cmd}\n")
        f.flush()  # write the header before the child process appends its own output
        p = subprocess.run(
            cmd, shell=True,
            cwd=str(cwd) if cwd else None,
            input=stdin_text,
            stdout=f, stderr=subprocess.STDOUT, text=True
        )
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}\n\n--- tail({log_path}) ---\n{tail(log_path)}")

def show_table(df: pd.DataFrame, max_rows=10):
    """Display a compact table without spamming the notebook."""
    from IPython.display import display
    display(df.head(max_rows))
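
`run_cmd` sends both stdout and stderr to the log and can answer interactive prompts through `stdin_text` (D1 passes `"yes\n"`). A minimal stand-alone sketch of the same idea, under the hypothetical name `run_logged`, which returns the exit code instead of raising; the `flush()` ensures the `$ cmd` header lands in the log before the child process writes to the same file:

```python
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

def run_logged(cmd: str, log_path: Path, stdin_text: Optional[str] = None) -> int:
    """Append the command and its combined stdout/stderr to log_path; return the exit code."""
    with log_path.open("a") as f:
        f.write(f"\n\n$ {cmd}\n")
        f.flush()  # Python buffers writes; the child writes straight to the file descriptor
        p = subprocess.run(cmd, shell=True, input=stdin_text,
                           stdout=f, stderr=subprocess.STDOUT, text=True)
    return p.returncode

# Usage: answer a prompt via stdin, then capture a failing command's stderr
with tempfile.TemporaryDirectory() as tmp:
    log = Path(tmp) / "demo.log"
    code = run_logged('read answer; echo "answer=$answer"', log, stdin_text="yes\n")
    failing = run_logged("echo oops >&2; exit 3", log)
    log_text = log.read_text()
    print(code, failing)  # 0 3
```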


## **C. Inputs**

In [None]:
#@title **C1) Inputs + OpenFold3 preset**
#@markdown All three inputs are required: a run name, the protein sequence (one-letter codes), and the ligand SMILES.

RUN_NAME = ""  #@param {type:"string"}
PROTEIN_SEQUENCE = ""  #@param {type:"string"}
LIGAND_SMILES = ""  #@param {type:"string"}

#@markdown **OpenFold3 preset**
#@markdown - Auto: chooses a stable setting based on detected hardware
#@markdown - Standard: default settings
#@markdown - LowMem: conservative memory settings

OF3_PRESET = "Auto"  #@param ["Auto","Standard","LowMem"]

#@markdown **Sampling**
NUM_MODEL_SEEDS = 1  #@param {type:"integer"}

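from pathlib import Path
import json

# Validate before creating folders; an empty name would write directly into RUNS_DIR
RUN_NAME = RUN_NAME.strip()
if not RUN_NAME:
    raise ValueError("Please set RUN_NAME (it names the run folder under RUNS_DIR).")

RUN_DIR = RUNS_DIR / RUN_NAME
for sub in ["inputs","of3","p2rank","gnina","analysis","viz","logs","state","export"]:
    (RUN_DIR / sub).mkdir(parents=True, exist_ok=True)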

(RUN_DIR / "inputs" / "protein.fasta").write_text(">protein\n" + PROTEIN_SEQUENCE.strip() + "\n")
(RUN_DIR / "inputs" / "ligand.smi").write_text(LIGAND_SMILES.strip() + "\n")

settings = {
    "run_name": RUN_NAME,
    "openfold3_preset": OF3_PRESET,
    "num_model_seeds": int(NUM_MODEL_SEEDS),
    "chain_ids": {"protein": "A", "ligand": "L"},
    "msa_server": True,
}
(RUN_DIR / "state" / "run_settings.json").write_text(json.dumps(settings, indent=2))

print(f"Run folder: {RUN_DIR}")
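
D1 only checks that the sequence is non-empty. To catch typos before spending GPU time, a sketch like the following can normalize and check the sequence. It assumes only the 20 standard one-letter codes are valid, which may be stricter than what OpenFold3 accepts, and `clean_sequence` is a hypothetical helper:

```python
import re

# Assumption: accept only the 20 standard amino acids; extend if you need X, U, etc.
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def clean_sequence(raw: str) -> str:
    """Drop FASTA header lines and whitespace, upper-case, and reject unknown codes."""
    lines = [ln for ln in raw.splitlines() if not ln.startswith(">")]
    seq = re.sub(r"\s+", "", "".join(lines)).upper()
    if not seq:
        raise ValueError("Empty protein sequence.")
    bad = sorted(set(seq) - STANDARD_AA)
    if bad:
        raise ValueError(f"Non-standard residue codes: {''.join(bad)}")
    return seq

print(clean_sequence(" mkta yiakq "))  # MKTAYIAKQ
```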


## **D. OpenFold3: predict the complex**


In [None]:
#@title **D1) Run OpenFold3 to predict the complex**

from pathlib import Path
import json, os
import yaml

seq = PROTEIN_SEQUENCE.strip().replace(" ", "").replace("\n", "")
smiles = LIGAND_SMILES.strip()

if not RUN_NAME:
    raise ValueError("Please set RUN_NAME in C1.")
if not seq:
    raise ValueError("Please set PROTEIN_SEQUENCE in C1.")
if not smiles:
    raise ValueError("Please set LIGAND_SMILES in C1.")

of3_dir = RUN_DIR / "of3"
of3_dir.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Hardware-aware defaults
# ----------------------------
def detect_gpu():
    """Return basic GPU info from the *base* Colab runtime (not the micromamba env)."""
    try:
        import torch
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(0)
            name = props.name
            cap = (props.major, props.minor)
            mem_gb = props.total_memory / (1024**3)
            bf16_ok = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
            return {"name": name, "cap": cap, "mem_gb": mem_gb, "bf16": bf16_ok}
    except Exception:
        pass
    return None

gpu = detect_gpu()

if gpu:
    major, minor = gpu["cap"]
    mem_gb = gpu["mem_gb"]
    if (major < 8) or (mem_gb < 24.0):
        raise RuntimeError(
            "OpenFold3-preview is not supported on this Colab runtime.\n"
            f"Detected GPU: {gpu['name']} (CC {major}.{minor}, {mem_gb:.1f} GB VRAM).\n\n"
            "OpenFold3-preview expects CC >= 8.0 and >= 24 GB VRAM (A100/H100 recommended).\n"
            "Fix: Runtime -> Change runtime type -> A100 GPU (Colab Pro/Pro+), or run the Boltz notebook on T4/L4/CPU."
        )

if gpu:
    ACCELERATOR = "gpu"
    major, minor = gpu["cap"]
    mem_gb = gpu["mem_gb"]
    name_l = gpu["name"].lower()

    PRECISION = "bf16-mixed" if gpu["bf16"] else "16-mixed"

    if OF3_PRESET == "Auto":
        OF3_PRESET_EFFECTIVE = "LowMem" if (mem_gb <= 24.5) else "Standard"
    else:
        OF3_PRESET_EFFECTIVE = OF3_PRESET
else:
    ACCELERATOR = "cpu"
    PRECISION = "32-true"
    OF3_PRESET_EFFECTIVE = "Standard" if OF3_PRESET == "Auto" else OF3_PRESET

print("OpenFold3 runtime:")
print(f"  accelerator: {ACCELERATOR}")
print(f"  precision  : {PRECISION}")
print(f"  preset     : {OF3_PRESET_EFFECTIVE}")
if gpu:
    print(f"  gpu        : {gpu['name']} (capability {gpu['cap'][0]}.{gpu['cap'][1]})")

# ----------------------------
# Query JSON (single complex)
# ----------------------------
query = {
    "queries": {
        "lig001": {
            "chains": [
                {"molecule_type": "protein", "chain_ids": ["A"], "sequence": seq},
                {"molecule_type": "ligand",  "chain_ids": ["L"], "smiles": smiles},
            ]
        }
    }
}
query_json = of3_dir / "query.json"
query_json.write_text(json.dumps(query, indent=2))

# ----------------------------
# Find checkpoint
# ----------------------------
def find_ckpt():
    candidates = []
    for root in [
        Path.home() / ".openfold3",
        OF3_CACHE,
        OF3_CACHE / ".openfold3",
        Path("/root/.openfold3"),
    ]:
        if root.exists():
            candidates += list(root.glob("*.pt"))
            candidates += list(root.glob("**/*.pt"))
    seen=set()
    uniq=[]
    for p in candidates:
        p = p.resolve()
        if p not in seen:
            seen.add(p); uniq.append(p)
    uniq.sort(key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True)
    return uniq[0] if uniq else None

ckpt_path = find_ckpt()
if ckpt_path:
    print(f"Using checkpoint: {ckpt_path}")
else:
    print("Checkpoint not found explicitly; OpenFold3 will resolve defaults (may download on first run).")

# ----------------------------
# Runner YAML
# ----------------------------
presets = ["predict", "pae_enabled"]
if OF3_PRESET_EFFECTIVE == "LowMem":
    presets.append("low_mem")

runner = {
    "experiment_settings": {
        "mode": "predict",
        "output_dir": str(of3_dir),
        "log_dir": str(RUN_DIR / "logs"),
        "use_msa_server": True,
        "num_seeds": int(NUM_MODEL_SEEDS),
        "skip_existing": True,
        "cache_path": str(OF3_CACHE),
    },
    "pl_trainer_args": {
        "precision": PRECISION,
        "accelerator": ACCELERATOR,
        "devices": 1,
    },
    "model_update": {
        "presets": presets,
        "custom": {
            "settings": {
                "memory": {
                    "eval": {
                        "use_cueq_triangle_kernels": False,
                        "use_deepspeed_evo_attention": False,
                    }
                }
            }
        },
    },
    "output_writer_settings": {
        "structure_format": "pdb",
        "full_confidence_output_format": "json",
    },
    "msa_computation_settings": {
        "msa_file_format": "npz",
        "cleanup_msa_dir": False,
        "msa_output_directory": str(of3_dir / "msas"),
    },
}
if ckpt_path:
    runner["experiment_settings"]["inference_ckpt_path"] = str(ckpt_path)

runner_yaml = of3_dir / "runner.yml"
runner_yaml.write_text(yaml.safe_dump(runner, sort_keys=False))

# ----------------------------
# Run
# ----------------------------
log_path = RUN_DIR / "logs" / "04_openfold3.log"

cmd = f"{RUN_OPENFOLD} predict --query_json {query_json} --runner_yaml {runner_yaml}"

try:
    run_cmd(cmd, log_path, stdin_text="yes\n")
except Exception as e:
    print("="*80)
    print("OpenFold3 inference failed. Key debugging artifacts:")
    print(f"  Runner YAML : {runner_yaml}")
    print(f"  Log file    : {log_path}")
    print("-"*80)
    print("Common fixes:")
    print("  • Verify you are on an A100/H100 runtime (compute capability >= 8.0, >= 24 GB VRAM).")
    print("  • If you see cuEquivariance/triangle-attention errors, ensure use_cueq_triangle_kernels is False.")
    print("  • If you see NaN coordinate errors, try a different seed or (on large GPUs) retry FP32.")
    print("="*80)

    trainer_args = runner["pl_trainer_args"]  # precision lives here, not at the top level
    if gpu and (gpu.get("mem_gb", 0) >= 40.0) and (trainer_args.get("precision") != "32-true"):
        print("Retrying once with precision=32-true (FP32) for stability (large GPU detected)...")
        trainer_args["precision"] = "32-true"
        runner_yaml.write_text(yaml.safe_dump(runner, sort_keys=False))
        run_cmd(cmd, log_path, stdin_text="yes\n")
    else:
        raise

print(f"Done. Log: {log_path}")
print(f"Outputs in: {of3_dir}")
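
The hardware branch in D1 can be restated as a pure function, which makes the decision table easy to read and to test without a GPU. This sketch mirrors D1's thresholds; `choose_of3_runtime` and `presets_for` are hypothetical names, and the `gpu` dict uses the same keys that `detect_gpu()` returns:

```python
from typing import Optional

def choose_of3_runtime(gpu: Optional[dict], preset: str = "Auto") -> dict:
    """Mirror D1: pick accelerator, Lightning precision string, and effective preset."""
    if gpu is None:
        return {"accelerator": "cpu", "precision": "32-true",
                "preset": "Standard" if preset == "Auto" else preset}
    major, _minor = gpu["cap"]
    if major < 8 or gpu["mem_gb"] < 24.0:
        raise RuntimeError(f"Unsupported GPU for OpenFold3: {gpu['name']}")
    precision = "bf16-mixed" if gpu["bf16"] else "16-mixed"
    if preset == "Auto":
        # D1 treats cards at roughly 24 GB as memory-constrained
        preset = "LowMem" if gpu["mem_gb"] <= 24.5 else "Standard"
    return {"accelerator": "gpu", "precision": precision, "preset": preset}

def presets_for(effective_preset: str) -> list:
    """Model presets that D1 writes into runner.yml."""
    presets = ["predict", "pae_enabled"]
    if effective_preset == "LowMem":
        presets.append("low_mem")
    return presets

a100 = {"name": "NVIDIA A100-SXM4-40GB", "cap": (8, 0), "mem_gb": 39.4, "bf16": True}
print(choose_of3_runtime(a100))  # {'accelerator': 'gpu', 'precision': 'bf16-mixed', 'preset': 'Standard'}
```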


In [None]:
#@title **D2) Extract receptor + OpenFold3 pose**

from pathlib import Path
import json
import MDAnalysis as mda
from rdkit import Chem
from rdkit.Chem import AllChem

of3_dir = RUN_DIR / "of3"
analysis_dir = RUN_DIR / "analysis"
analysis_dir.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Locate OpenFold3 output PDB
# ----------------------------
pdb_candidates = sorted(of3_dir.rglob("*_model.pdb"))
if not pdb_candidates:
    pdb_candidates = sorted(of3_dir.rglob("*.pdb"))

if not pdb_candidates:
    raise FileNotFoundError(f"No OpenFold3 PDB outputs found under: {of3_dir}\nRun D1 first.")

pred_pdb = pdb_candidates[0]
preferred = [p for p in pdb_candidates if ("seed_42" in p.as_posix() and "sample_1" in p.as_posix() and p.name.endswith("_model.pdb"))]
if preferred:
    pred_pdb = sorted(preferred)[0]

print(f"Using OpenFold3 structure: {pred_pdb}")

of3_complex_pdb = analysis_dir / "of3_complex.pdb"
of3_complex_pdb.write_text(pred_pdb.read_text())

# ----------------------------
# Split receptor/ligand by chain IDs
# ----------------------------
u = mda.Universe(str(of3_complex_pdb))

protein_sel = u.select_atoms("protein and chainID A")
ligand_sel  = u.select_atoms("not protein and not (resname HOH or resname WAT) and chainID L")

if protein_sel.n_atoms == 0:
    raise RuntimeError("No protein atoms found for chainID A. Check OpenFold3 output and chain IDs.")
if ligand_sel.n_atoms == 0:
    raise RuntimeError("No ligand atoms found for chainID L. Check OpenFold3 output and chain IDs.")

receptor_pdb = analysis_dir / "receptor.pdb"
protein_sel.write(receptor_pdb.as_posix())

ligand_pdb = analysis_dir / "of3_ligand_pose_raw.pdb"
ligand_sel.write(ligand_pdb.as_posix())

print(f"Receptor: {receptor_pdb}")
print(f"Ligand (raw coords): {ligand_pdb}")

# ----------------------------
# Template repair
# ----------------------------
smiles = LIGAND_SMILES.strip()
template = Chem.MolFromSmiles(smiles)
if template is None:
    raise ValueError("RDKit failed to parse LIGAND_SMILES. Please check it in C1.")

pdb_mol = Chem.MolFromPDBFile(str(ligand_pdb), removeHs=False, sanitize=False)
if pdb_mol is None:
    raise RuntimeError("RDKit failed to read ligand PDB coordinates from OpenFold3 output.")

template_noh = Chem.RemoveHs(template)
pdb_noh = Chem.RemoveHs(pdb_mol)

if template_noh.GetNumAtoms() != pdb_noh.GetNumAtoms():
    raise RuntimeError(
        f"Atom count mismatch (template {template_noh.GetNumAtoms()} vs PDB {pdb_noh.GetNumAtoms()}).\n"
        "Cannot safely template-repair. This usually means the SMILES does not match the predicted ligand.\n"
        "Fix: provide the exact ligand SMILES used in the OpenFold3 query."
    )

repaired = None
repair_notes = []

try:
    repaired = AllChem.AssignBondOrdersFromTemplate(template_noh, pdb_noh)
    if repaired is None:
        raise RuntimeError("AssignBondOrdersFromTemplate returned None")

    if repaired.GetNumConformers() == 0 and pdb_noh.GetNumConformers() > 0:
        repaired.AddConformer(pdb_noh.GetConformer(), assignId=True)

    Chem.SanitizeMol(repaired)
    repair_notes.append("AssignBondOrdersFromTemplate (AllChem): OK")

except Exception as e:
    repair_notes.append(f"AssignBondOrdersFromTemplate failed: {type(e).__name__}: {e}")
    repaired = None

if repaired is None:
    repair_notes.append("Fallback: coordinate injection by atom index (best-effort).")
    repaired = Chem.Mol(template_noh)

    if repaired.GetNumAtoms() != pdb_noh.GetNumAtoms():
        raise RuntimeError(
            "Template repair failed and atom counts do not match for fallback.\n"
            f"Template atoms: {repaired.GetNumAtoms()} vs PDB atoms: {pdb_noh.GetNumAtoms()}\n"
            "Try a different SMILES representation or confirm OpenFold3 ligand identity."
        )

    pdb_conf = pdb_noh.GetConformer()
    conf = Chem.Conformer(repaired.GetNumAtoms())
    for i in range(repaired.GetNumAtoms()):
        pos = pdb_conf.GetAtomPosition(i)
        conf.SetAtomPosition(i, pos)
    repaired.RemoveAllConformers()
    repaired.AddConformer(conf, assignId=True)

    try:
        Chem.SanitizeMol(repaired)
    except Exception as se:
        repair_notes.append(f"Fallback sanitize warning: {type(se).__name__}: {se}")

# ----------------------------
# Write repaired SDF
# ----------------------------
of3_pose_sdf = analysis_dir / "of3_ligand_pose.sdf"
with Chem.SDWriter(str(of3_pose_sdf)) as w:
    w.write(repaired)

print("Ligand repair notes:")
for n in repair_notes:
    print(f"  - {n}")
print(f"OpenFold3 pose (SDF): {of3_pose_sdf}")

# ----------------------------
# Update state paths
# ----------------------------
paths = {
    "of3_complex_pdb": str(of3_complex_pdb),
    "receptor_pdb": str(receptor_pdb),
    "of3_ligand_pose_raw_pdb": str(ligand_pdb),
    "of3_ligand_pose_sdf": str(of3_pose_sdf),
}
_ = (RUN_DIR / "state" / "paths.json").write_text(json.dumps(paths, indent=2))
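
D2 picks one structure out of possibly several seeds and samples. The selection rule, restated as a small function over path strings, is sketched below; `pick_prediction` is a hypothetical name, and the `seed_42`/`sample_1` tokens are simply the ones D2 filters on:

```python
from pathlib import PurePosixPath

def pick_prediction(paths: list) -> str:
    """Prefer seed_42/sample_1 *_model.pdb, else the first *_model.pdb, else the first .pdb."""
    pdbs = sorted(p for p in paths if p.endswith(".pdb"))
    if not pdbs:
        raise FileNotFoundError("No PDB outputs found.")
    models = [p for p in pdbs if PurePosixPath(p).name.endswith("_model.pdb")]
    pool = models or pdbs
    preferred = [p for p in pool if "seed_42" in p and "sample_1" in p]
    return (preferred or pool)[0]

outputs = [
    "of3/lig001/seed_7/lig001_seed_7_sample_1_model.pdb",
    "of3/lig001/seed_42/lig001_seed_42_sample_2_model.pdb",
    "of3/lig001/seed_42/lig001_seed_42_sample_1_model.pdb",
    "of3/extra.pdb",
]
print(pick_prediction(outputs))  # of3/lig001/seed_42/lig001_seed_42_sample_1_model.pdb
```

Because the match is a plain substring test, `sample_1` would also match a `sample_10` file if one existed (and it would sort first); a stricter rule would parse the sample index as an integer.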


In [None]:
#@title **D3) Visualize OpenFold3 complex**

import py3Dmol
from pathlib import Path

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
of3_pose_sdf = RUN_DIR / "analysis" / "of3_ligand_pose.sdf"

PINE_GREEN = "#01796f"  # pose color

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

# receptor viz
view.addModel(receptor_pdb.read_text(), "pdb")
view.setStyle({"cartoon": {"color": "white"}})

# surface viz
view.addSurface(py3Dmol.VDW, {'opacity': 0.6, 'color': 'silver'})

# ligand viz
view.addModel(of3_pose_sdf.read_text(), "sdf")
view.setStyle({"model": 1}, {"stick": {"color": PINE_GREEN, "radius": 0.25}})

view.zoomTo()
view.show()

print("OpenFold3 complex preview (protein + predicted ligand pose).")


## **E. P2Rank: pocket hypotheses**


In [None]:
#@title **E1) Run P2Rank + show top pockets**

import pandas as pd
from pathlib import Path
import shutil

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
p2rank_out = RUN_DIR / "p2rank"
p2rank_log = RUN_DIR / "logs" / "04_p2rank.log"

# --- call P2Rank ---
p2rank_version_dir = TOOLS_DIR / f"p2rank_{P2RANK_VERSION}"
if not p2rank_version_dir.exists():
    raise FileNotFoundError(f"P2Rank installation directory not found: {p2rank_version_dir}. Re-run B1.")

prank_executable_path = p2rank_version_dir / "prank"
if not prank_executable_path.exists():
    raise FileNotFoundError(f"P2Rank 'prank' executable not found at: {prank_executable_path}. Re-run B1.")

run_cmd(f"{prank_executable_path} predict -f {receptor_pdb} -o {p2rank_out} -c alphafold", p2rank_log, cwd=p2rank_version_dir)

csv_candidates = list(p2rank_out.rglob("*_predictions.csv"))
if not csv_candidates:
    raise FileNotFoundError(f"No *_predictions.csv found under: {p2rank_out}\nCheck logs at: {p2rank_log}")
pred_csv = sorted(csv_candidates)[0]

df = pd.read_csv(pred_csv)

def pick_col(substrs):
    for s in substrs:
        for c in df.columns:
            if s in c.lower():
                return c
    return None

col_score = pick_col(["score", "probability", "pocket_score"])
col_res   = pick_col(["residue", "residues"])
col_cx    = pick_col(["center_x", "centerx", "center x"])
col_cy    = pick_col(["center_y", "centery", "center y"])
col_cz    = pick_col(["center_z", "centerz", "center z"])

df_disp = df.copy()
df_disp.insert(0, "pocket_rank", range(1, len(df_disp)+1))

keep_cols = ["pocket_rank"]
for c in [col_score, col_cx, col_cy, col_cz, col_res]:
    if c and c not in keep_cols:
        keep_cols.append(c)

p2rank_csv_out = RUN_DIR / "analysis" / "p2rank_pockets.csv"
df_disp[keep_cols].to_csv(p2rank_csv_out, index=False)

print("Top pockets (P2Rank):")
show_table(df_disp[keep_cols], max_rows=5)
print(f"\nSaved: {p2rank_csv_out}")
print(f"P2Rank log: {p2rank_log}")
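
E2 converts the residue column saved here back into `(chain, resid)` pairs and then into an MDAnalysis selection. A stand-alone sketch of that parsing, assuming P2Rank-style tokens such as `A_45` (or `A:45` for manual input); the function names are hypothetical:

```python
import re

def parse_p2rank_residues(blob: str) -> list:
    """Turn a residue string like 'A_45 A_46 B_12' into (chain, resid) pairs."""
    residues = []
    for tok in re.split(r"[,\s]+", blob.strip()):
        sep = ":" if ":" in tok else "_" if "_" in tok else None
        if sep is None:
            continue  # skip empty or unrecognized tokens
        chain, num = tok.split(sep, 1)
        digits = re.sub(r"\D", "", num)  # same digit-only rule E2 uses
        if digits:
            residues.append((chain, int(digits)))
    return residues

def mda_selection(residues) -> str:
    """Build an MDAnalysis selection string from (chain, resid) pairs."""
    return " or ".join(f"(chainID {ch} and resid {rid})" for ch, rid in residues)

print(mda_selection(parse_p2rank_residues("A_45 A_46 B_12")))
# (chainID A and resid 45) or (chainID A and resid 46) or (chainID B and resid 12)
```

Non-digit characters are stripped from the residue number, so insertion codes (e.g. `45A`) collapse onto `45` and negative residue numbers lose their sign; E2 behaves the same way.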


In [None]:
#@title **E2) Select pocket for GNINA (and visualize)**
#@markdown Select a pocket by its P2Rank rank, or enter pocket residues manually.

POCKET_MODE = "choose_rank"  #@param ["auto_top1","choose_rank","manual_residues"]
POCKET_RANK = 1  #@param {type:"integer"}
#@markdown Manual residue examples: `A:123, A:125-130` or `A:45-60`

POCKET_RESIDUES = ""  #@param {type:"string"}

import pandas as pd
import json
import re
import MDAnalysis as mda
import py3Dmol
from pathlib import Path

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
of3_pose_sdf = RUN_DIR / "analysis" / "of3_ligand_pose.sdf"
pockets_csv = RUN_DIR / "analysis" / "p2rank_pockets.csv"

if not pockets_csv.exists():
    raise FileNotFoundError("Missing p2rank_pockets.csv. Run E1 first.")

df = pd.read_csv(pockets_csv)

res_col = None
for c in df.columns:
    if "residue" in c.lower():  # e.g. P2Rank's residue_ids column
        res_col = c
        break

def parse_residue_spec(spec: str):
    """Parse strings like 'A:123, A:125-130'. Returns list of (chain, resid)."""
    out=[]
    spec = spec.replace(" ", "")
    if not spec:
        return out
    parts = [p for p in spec.split(",") if p]
    for p in parts:
        if ":" not in p:
            raise ValueError(f"Residue '{p}' must include a chain, e.g. A:123")
        chain, rng = p.split(":", 1)
        if "-" in rng:
            a,b = rng.split("-",1)
            for r in range(int(a), int(b)+1):
                out.append((chain, r))
        else:
            out.append((chain, int(rng)))
    return out

if POCKET_MODE == "auto_top1":
    pocket_rank = 1
elif POCKET_MODE == "choose_rank":
    pocket_rank = int(POCKET_RANK)
else:
    pocket_rank = None  # manual

if pocket_rank is not None:
    if pocket_rank < 1 or pocket_rank > len(df):
        raise ValueError(f"Pocket rank {pocket_rank} is out of range (1..{len(df)}).")
    residues_blob = str(df.loc[pocket_rank-1, res_col]) if res_col else ""
    tokens = re.split(r"[,\s]+", residues_blob.strip())
    residues=[]
    for t in tokens:
        if not t:
            continue
        if ":" in t:
            ch, r = t.split(":", 1)
        elif "_" in t:
            ch, r = t.split("_", 1)
        else:
            continue
        try:
            residues.append((ch, int(re.sub(r"\D","", r))))
        except ValueError:
            pass
else:
    residues = parse_residue_spec(POCKET_RESIDUES)

if not residues:
    raise RuntimeError("No pocket residues found/selected. Try another pocket rank or manual residues.")

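u = mda.Universe(str(receptor_pdb))
sel_parts = [f"(chainID {ch} and resid {rid})" for ch, rid in residues]
sel = " or ".join(sel_parts)
pocket_atoms = u.select_atoms(sel)
if len(pocket_atoms) == 0:
    raise RuntimeError("Pocket selection matched no atoms in receptor.pdb. Check the chain IDs and residue numbers.")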

pocket_pdb = RUN_DIR / "analysis" / "pocket_residues.pdb"
pocket_atoms.write(str(pocket_pdb))

sel_state = {
    "pocket_mode": POCKET_MODE,
    "pocket_rank": pocket_rank,
    "n_residues": len(residues),
    "pocket_pdb": str(pocket_pdb),
}
(RUN_DIR / "state" / "pocket.json").write_text(json.dumps(sel_state, indent=2))

print(f"Selected pocket residues: {len(residues)} residues")
print(f"Pocket residue PDB: {pocket_pdb}")

# --- viz (protein + OpenFold3 pose + selected pocket residues) ---
PINE_GREEN = "#01796f"  # OpenFold3 pose
ROSE_RED = "#ff033e"    # GNINA poses

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

view.addModel(open(receptor_pdb).read(), "pdb")
view.setStyle({"cartoon": {"color": "#fffafa"}})

view.addModel(open(pocket_pdb).read(), "pdb")
view.setStyle({"model": 1}, {"stick": {"color": "#cc9900", "radius": 0.25}})

view.addModel(open(of3_pose_sdf).read(), "sdf")
view.setStyle({"model": 2}, {"stick": {"color": PINE_GREEN, "radius": 0.28}})

view.zoomTo()
view.show()

print("Pocket preview: yellow = selected pocket residues, green = OpenFold3 ligand pose.")
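E2 turns either a P2Rank `residue_ids` string (tokens like `A_45`) or a manual spec (`A:123, A:125-130`) into `(chain, resid)` pairs, then joins one MDAnalysis `chainID ... and resid ...` clause per residue. A standalone sketch of that parsing, with hypothetical helper names, is shown here so it can be tested without a receptor file:

```python
import re

def parse_p2rank_residues(blob):
    """Turn P2Rank residue tokens like 'A_45 A_46' into (chain, resid) pairs."""
    pairs = []
    for token in re.split(r"[,\s]+", blob.strip()):
        if "_" not in token:
            continue
        chain, num = token.split("_", 1)
        digits = re.sub(r"\D", "", num)  # keep digits only (drops insertion codes such as '52A')
        if digits:
            pairs.append((chain, int(digits)))
    return pairs

def parse_manual_spec(spec):
    """Expand 'A:123, A:125-127' into one (chain, resid) pair per residue."""
    pairs = []
    for part in filter(None, spec.replace(" ", "").split(",")):
        chain, rng = part.split(":", 1)
        start, _, end = rng.partition("-")  # end is '' for a single residue
        for resid in range(int(start), int(end or start) + 1):
            pairs.append((chain, resid))
    return pairs

def mda_selection(pairs):
    """Build an MDAnalysis selection string with one clause per residue."""
    return " or ".join(f"(chainID {c} and resid {r})" for c, r in pairs)

print(mda_selection(parse_manual_spec("A:123, A:125-126")))
```

Like the E2 cell, the digit-only cleanup also drops a minus sign, so negative residue numbers would need extra handling.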


## **F. GNINA: re-dock + re-score**
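In F1, GNINA derives its search box from the pocket residues written in E2 (`--autobox_ligand`) and pads it with `--autobox_add`. As we read the smina/GNINA help text, the padding is added on all six sides of the atoms' bounding box (default 4 Å). The sketch below reproduces that arithmetic under this assumption; GNINA may apply further adjustments internally, so treat it as a sanity check rather than a re-implementation.

```python
def autobox(coords, add=4.0):
    """Axis-aligned box around (x, y, z) coords, padded by `add` Å on every side.

    Assumes the smina/GNINA convention that --autobox_add is added to all six faces.
    Returns (center, size) as tuples of floats.
    """
    lo = [min(c[i] for c in coords) for i in range(3)]
    hi = [max(c[i] for c in coords) for i in range(3)]
    center = tuple((a + b) / 2 for a, b in zip(lo, hi))
    size = tuple((b - a) + 2 * add for a, b in zip(lo, hi))
    return center, size

# Made-up pocket atom coordinates (Å)
center, size = autobox([(0, 0, 0), (10, 4, 2), (2, 8, -2)], add=4.0)
print(center, size)
```

With `AUTOBOX_ADD = 4.0`, a pocket spanning 10 Å along x gives an 18 Å edge; each extra ångström of `AUTOBOX_ADD` grows every edge by 2 Å, which enlarges the search space GNINA must sample.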


In [None]:
#@title **F1) GNINA docking + OpenFold3 pose rescoring**
#@markdown Dock into the **selected pocket**, then **re-score** the original OpenFold3 pose once for comparison.

#@markdown **Feel free to adjust the parameters below.**
GNINA_POSES = 10  #@param {type:"integer"}
GNINA_EXHAUSTIVENESS = 8  #@param {type:"integer"}
AUTOBOX_ADD = 4.0  #@param {type:"number"}
PH = 7.4  #@param {type:"number"}

import json, re
import pandas as pd
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import AllChem

gnina_log = RUN_DIR / "logs" / "05_gnina.log"
gnina_score_log = RUN_DIR / "logs" / "05_gnina_score_only.log"
ligprep_log = RUN_DIR / "logs" / "05_ligprep.log"

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
of3_pose_sdf = RUN_DIR / "analysis" / "of3_ligand_pose.sdf"

if "pocket_pdb" not in globals():
    raise RuntimeError("Missing pocket selection. Run E2 first.")

pocket_pdb = Path(pocket_pdb)

GNINA_EXE = str(GNINA_PATH) if "GNINA_PATH" in globals() else "gnina"

# ------------------------------------------------------------------------
# Prepare receptor
# ------------------------------------------------------------------------
receptor_dock_pdb = RUN_DIR / "gnina" / "receptor_ph.pdb"
receptor_dock_pdb.parent.mkdir(parents=True, exist_ok=True)

print(f"Protonating receptor at pH {PH}...")
run_cmd(f"obabel -ipdb {receptor_pdb} -opdb -O {receptor_dock_pdb} -p {PH}", ligprep_log)

# Open Babel can rewrite residue names (e.g. to UNK); restore the originals with Biopython
try:
    from Bio import PDB
    parser = PDB.PDBParser(QUIET=True)
    ref_structure = parser.get_structure("ref", str(receptor_pdb))
    tgt_structure = parser.get_structure("tgt", str(receptor_dock_pdb))

    res_name_map = {}
    for model in ref_structure:
        for chain in model:
            for residue in chain:
                key = (chain.id, residue.id[1], residue.id[2])
                res_name_map[key] = residue.resname

    count_fixed = 0
    for model in tgt_structure:
        for chain in model:
            for residue in chain:
                key = (chain.id, residue.id[1], residue.id[2])
                if key in res_name_map:
                    original_name = res_name_map[key]
                    if residue.resname != original_name:
                        residue.resname = original_name
                        count_fixed += 1

    if count_fixed:
        receptor_dock_pdb_fixed = RUN_DIR / "gnina" / "receptor_ph_fixed.pdb"
        io = PDB.PDBIO()
        io.set_structure(tgt_structure)
        io.save(str(receptor_dock_pdb_fixed))
        receptor_dock_pdb = receptor_dock_pdb_fixed

except Exception:
    pass

# ------------------------------------------------------------------------
# Prepare ligand for docking (protonate SMILES + embed + minimize)
# ------------------------------------------------------------------------
lig_in_sdf = RUN_DIR / "gnina" / "ligand_mmff.sdf"
smiles = LIGAND_SMILES.strip()
ph_smiles = smiles

try:
    from dimorphite_dl import protonate_smiles
    prot_list = protonate_smiles(smiles, ph_min=PH-0.5, ph_max=PH+0.5)
    if prot_list:
        ph_smiles = prot_list[0]
except Exception:
    pass

mol = Chem.MolFromSmiles(ph_smiles)
if mol is None:
    raise ValueError("Failed to parse ligand SMILES after protonation attempt.")

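mol = Chem.AddHs(mol)
embed_params = AllChem.ETKDGv3()
embed_params.randomSeed = 42  # fixed seed for a reproducible starting conformer
if AllChem.EmbedMolecule(mol, embed_params) != 0:
    raise RuntimeError("RDKit could not embed a 3D conformer for the ligand.")

# MMFFOptimizeMolecule returns -1 (instead of raising) when MMFF parameters are missing,
# so check for parameters explicitly and fall back to UFF.
if AllChem.MMFFHasAllMoleculeParams(mol):
    AllChem.MMFFOptimizeMolecule(mol, mmffVariant="MMFF94s")
else:
    AllChem.UFFOptimizeMolecule(mol)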

with Chem.SDWriter(str(lig_in_sdf)) as w:
    w.write(mol)

# ------------------------------------------------------------------------
# Run GNINA docking
# ------------------------------------------------------------------------
dock_sdf = RUN_DIR / "gnina" / "gnina_docked.sdf"
summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"

print("Running GNINA docking...")
gnina_cmd = (
    f"{GNINA_EXE} --receptor {receptor_dock_pdb} --ligand {lig_in_sdf} "
    f"--autobox_ligand {pocket_pdb} --autobox_add {AUTOBOX_ADD} "
    f"--num_modes {GNINA_POSES} --exhaustiveness {GNINA_EXHAUSTIVENESS} "
    f"--out {dock_sdf}"
)
run_cmd(gnina_cmd, gnina_log)

# ------------------------------------------------------------------------
# Score OpenFold3 pose (score_only)
# ------------------------------------------------------------------------
of3_pose_ph_sdf = RUN_DIR / "analysis" / "of3_ligand_pose_ph.sdf"
run_cmd(f"obabel -isdf {of3_pose_sdf} -osdf -O {of3_pose_ph_sdf} -p {PH}", ligprep_log)

print("Scoring OpenFold3 pose (gnina --score_only)...")
run_cmd(f"{GNINA_EXE} --score_only --receptor {receptor_dock_pdb} --ligand {of3_pose_ph_sdf}", gnina_score_log)

# ------------------------------------------------------------------------
# Parse results
# ------------------------------------------------------------------------
def _prop(m, key):
    return float(m.GetProp(key)) if (m and m.HasProp(key)) else float("nan")

rows = []
sup = Chem.SDMolSupplier(str(dock_sdf), removeHs=False)
for i, m in enumerate(sup):
    if m:
        rows.append({
            "pose_id": i + 1,
            "CNNscore": _prop(m, "CNNscore"),
            "CNNaffinity": _prop(m, "CNNaffinity"),
            "minimizedAffinity": _prop(m, "minimizedAffinity"),
        })

df_dock = pd.DataFrame(rows)
df_dock.to_csv(summary_csv, index=False)

txt = Path(gnina_score_log).read_text(errors="ignore")
def get_score(name):
    m = re.findall(rf"{name}[:\s]+([-+]?[0-9]*\.?[0-9]+)", txt)
    return float(m[-1]) if m else float("nan")

vina_score = get_score("minimizedAffinity")
if pd.isna(vina_score):
    vina_score = get_score("Affinity")

of3_scores = {
    "CNNscore": get_score("CNNscore"),
    "CNNaffinity": get_score("CNNaffinity"),
    "minimizedAffinity": vina_score,
}
(RUN_DIR / "analysis" / "of3_pose_gnina_scores.json").write_text(json.dumps(of3_scores, indent=2))

# Record the new output paths in the shared state file
paths_p = RUN_DIR / "state" / "paths.json"
if paths_p.exists():
    d = json.loads(paths_p.read_text())
    d.update({
        "receptor_pdb_ph": str(receptor_dock_pdb),
        "of3_pose_sdf_ph": str(of3_pose_ph_sdf),
        "gnina_docked_sdf": str(dock_sdf),
        "gnina_pose_summary_csv": str(summary_csv),
        "of3_pose_gnina_scores_json": str(RUN_DIR / "analysis" / "of3_pose_gnina_scores.json"),
    })
    paths_p.write_text(json.dumps(d, indent=2))

print(f"Done.\n  GNINA summary: {summary_csv}\n  OpenFold3 score-only: {of3_scores}")
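F1 reads the OpenFold3 scores back from the `--score_only` log with a regex and falls back from `minimizedAffinity` to `Affinity`, since a score-only run does not minimize the pose. The log text below is illustrative only; the exact layout may differ between GNINA versions. The `\b` word boundary is a small addition that stops `Affinity` from matching inside `minimizedAffinity`, and matching is case-sensitive, so `Affinity` does not match `CNNaffinity`.

```python
import math
import re

# Illustrative score-only output; not copied from a real run.
LOG = """Affinity: -7.21400 (kcal/mol)
CNNscore: 0.61234
CNNaffinity: 5.43210
"""

NUM = r"([-+]?\d*\.?\d+)"

def last_value(text, name):
    """Return the last number printed after `name` (':' or whitespace separated); NaN if absent."""
    hits = re.findall(rf"\b{re.escape(name)}[:\s]+{NUM}", text)
    return float(hits[-1]) if hits else math.nan

vina = last_value(LOG, "minimizedAffinity")
if math.isnan(vina):
    vina = last_value(LOG, "Affinity")
print(vina, last_value(LOG, "CNNscore"))
```

Because the OpenFold3 value is an unminimized score while the docked poses report `minimizedAffinity` after local minimization, the comparison in F2 is indicative rather than strictly like-for-like.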


In [None]:
#@title **F2) Plot GNINA scores**
#@markdown Plot a metric from the saved GNINA summary and overlay the OpenFold3 score.

PLOT_METRIC = "minimizedAffinity"  #@param ["CNNscore","CNNaffinity","minimizedAffinity"]
TOP_N = 10  #@param {type:"integer"}

import pandas as pd
import matplotlib.pyplot as plt
import json
import numpy as np
import re
from pathlib import Path

summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"
of3_scores_json = RUN_DIR / "analysis" / "of3_pose_gnina_scores.json"
plot_png = RUN_DIR / "analysis" / f"gnina_{PLOT_METRIC}_plot.png"

df = pd.read_csv(summary_csv)
of3_scores = json.loads(Path(of3_scores_json).read_text())

gnina_score_log = RUN_DIR / "logs" / "05_gnina_score_only.log"
if any(pd.isna(v) for v in of3_scores.values()) and gnina_score_log.exists():
    txt = gnina_score_log.read_text(errors="ignore")
    def find_val(name):
        # Matches "Name 1.23" or "Name: 1.23"; returns the last occurrence
        m = re.findall(rf"{name}[:\s]+([-+]?[0-9]*\.?[0-9]+)", txt)
        return float(m[-1]) if m else float("nan")

    for key in ("CNNscore", "CNNaffinity"):
        s = find_val(key)
        if pd.notna(s):
            of3_scores[key] = s

    # score_only does not minimize, so the log may report "Affinity" instead
    s = find_val("minimizedAffinity")
    if pd.isna(s):
        s = find_val("Affinity")
    if pd.notna(s):
        of3_scores["minimizedAffinity"] = s

    print(f"Recovered OpenFold3 scores: {of3_scores}")

ascending = (PLOT_METRIC == "minimizedAffinity")
dff = df.sort_values(PLOT_METRIC, ascending=ascending).head(TOP_N).copy()

x = np.arange(len(dff))
y = dff[PLOT_METRIC].astype(float).values

# --- viz ---
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.weight'] = 'semibold'
plt.rcParams['axes.labelweight'] = 'semibold'

fig, ax = plt.subplots(figsize=(8.0, 4.0))

ax.plot(x, y, marker='o', linestyle='None',
        color='#367588', markersize=8, label='GNINA poses')

b = of3_scores.get(PLOT_METRIC, float("nan"))
if pd.notna(b):
    # grey dashed line = OpenFold3 pose scored with gnina --score_only
    ax.axhline(b, linestyle="--", linewidth=1, color="#808080", label="OpenFold3 pose")
else:
    print(f"(Note: the OpenFold3 value for {PLOT_METRIC} is NaN, so the dashed reference line is hidden.)")

ax.set_xticks(x)
ax.set_xticklabels(dff["pose_id"].astype(int).values)

ax.set_xlabel("Pose ID")

if PLOT_METRIC == "minimizedAffinity":
    ax.set_ylabel("minimizedAffinity (kcal/mol)")
elif PLOT_METRIC == "CNNaffinity":
    ax.set_ylabel("CNNaffinity")
else:
    ax.set_ylabel("CNNscore")

# spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# legend
ax.legend(bbox_to_anchor=(1.02, 0.9), loc='upper left', frameon=False)

plt.tight_layout()
plt.savefig(plot_png, dpi=400, bbox_inches="tight")  # keep the outside legend in the saved PNG
plt.show()

print(f"Saved plot: {plot_png}")
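F2 and H1 both flip the sort order depending on the metric: lower is better for the Vina-style `minimizedAffinity` (kcal/mol), higher is better for `CNNscore` and `CNNaffinity`. The sketch below makes that rule explicit with a lookup table; the pose rows are made-up values, and the helper name is hypothetical.

```python
import math

# Direction of "better" for each GNINA metric used in this notebook
HIGHER_IS_BETTER = {"minimizedAffinity": False, "CNNscore": True, "CNNaffinity": True}

def rank_poses(rows, metric, top_n=10):
    """Sort pose dicts best-first by `metric`, dropping poses where the metric is NaN."""
    valid = [r for r in rows if not math.isnan(r[metric])]
    return sorted(valid, key=lambda r: r[metric], reverse=HIGHER_IS_BETTER[metric])[:top_n]

# Hypothetical GNINA summary rows
rows = [
    {"pose_id": 1, "minimizedAffinity": -6.0, "CNNscore": 0.9},
    {"pose_id": 2, "minimizedAffinity": -8.0, "CNNscore": 0.4},
    {"pose_id": 3, "minimizedAffinity": float("nan"), "CNNscore": 0.7},
]
print([r["pose_id"] for r in rank_poses(rows, "minimizedAffinity")])
```

pandas `sort_values` keeps NaN rows and places them last by default; this sketch drops them instead, so a missing score can never be reported as the "best" pose.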


## **G. Quick visual comparison**

In [None]:
#@title **G1) Compare OpenFold3 vs GNINA poses (animated)**
#@markdown Animate GNINA poses (red) over the OpenFold3 pose (green).

MAX_MODES = 10  #@param {type:"integer"}
INTERVAL_MS = 600  #@param {type:"integer"}

from pathlib import Path
from rdkit import Chem
import py3Dmol

receptor_pdb   = RUN_DIR / "analysis" / "receptor.pdb"
pocket_pdb     = RUN_DIR / "analysis" / "pocket_residues.pdb"
of3_pose_sdf = RUN_DIR / "analysis" / "of3_ligand_pose.sdf"
docked_sdf     = RUN_DIR / "gnina" / "gnina_docked.sdf"

PINE_GREEN = "#01796f"
ROSE_RED   = "#ff033e"

poses = [m for m in Chem.SDMolSupplier(str(docked_sdf), removeHs=False) if m is not None]
if not poses:
    raise RuntimeError("No GNINA poses found. Run F1 first.")

n = max(1, min(int(MAX_MODES), len(poses)))
tmp_sdf = RUN_DIR / "analysis" / f"gnina_first_{n}_poses.sdf"
w = Chem.SDWriter(str(tmp_sdf))
for m in poses[:n]:
    w.write(m)
w.close()

view = py3Dmol.view(width=900, height=520)
view.setViewStyle({'style':'outline','color':'black','width':0.05})

# --- receptor viz ---
view.addModel(open(receptor_pdb).read(), "pdb")
view.setStyle({'model': 0}, {"cartoon": {"color": "white"}})

# --- surface viz ---
view.addSurface(py3Dmol.VDW, {'opacity': 0.6, 'color': 'silver'}, {'model': 0})

current_idx = 0

# --- pocket ---
if pocket_pdb.exists():
    view.addModel(open(pocket_pdb).read(), "pdb")
    current_idx += 1
    view.setStyle({'model': current_idx}, {"stick": {"color": "grey", "radius": 0.25}})

# --- OpenFold3 pose viz ---
view.addModel(open(of3_pose_sdf).read(), "sdf")
current_idx += 1
view.setStyle({'model': current_idx}, {"stick": {"color": PINE_GREEN, "radius": 0.28}})

# --- gnina poses viz ---
gnina_data = open(tmp_sdf).read()
view.addModelsAsFrames(gnina_data, "sdf")
current_idx += 1

view.setStyle({'model': current_idx}, {"stick": {"color": ROSE_RED, "radius": 0.22}})

view.zoomTo()
view.animate({'loop': 'forward', 'interval': int(INTERVAL_MS)})
view.show()

print(f"Animating {n} GNINA poses (rose red) over OpenFold3 pose (pine green).")
print(f"Frames source: {tmp_sdf}")
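G1 re-reads the docked SDF with RDKit and writes the first N molecules to a new file for animation. An alternative, sketched here under the assumption that every record ends with the standard `$$$$` delimiter, is to copy the first N records as plain text. This keeps GNINA's SD tags byte-for-byte and does not silently skip records RDKit fails to parse, so frame k stays aligned with pose k. The helper name is hypothetical.

```python
def first_sdf_records(sdf_text, n):
    """Return the first n records of an SD file as text; each record ends with '$$$$'."""
    records, current = [], []
    for line in sdf_text.splitlines():
        current.append(line)
        if line.strip() == "$$$$":  # end of one molecule record
            records.append("\n".join(current) + "\n")
            current = []
            if len(records) == n:
                break
    return "".join(records)

# Minimal fake SD text with three records
text = "mol1\n  RDKit\n\nM  END\n$$$$\nmol2\n\n\nM  END\n$$$$\nmol3\n\n\nM  END\n$$$$\n"
print(first_sdf_records(text, 2))
```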


## **H. QC + interactions**
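The H1 cell below saves one PoseBusters CSV per pose (`posebusters_openfold3_<timestamp>.csv` and `posebusters_gnina_<timestamp>.csv`). A quick way to read such a report is to list the checks whose value is `False`. The column names in this sketch are a hypothetical excerpt; a real `--full-report` has many more columns, and their exact names depend on the PoseBusters version.

```python
import csv
import io

# Hypothetical one-row report excerpt; real PoseBusters reports have more columns.
REPORT = """file,mol_pred_loaded,sanitization,bond_lengths,minimum_distance_to_protein
pose.sdf,True,True,True,False
"""

def failed_checks(csv_text):
    """Return the column names whose value is 'False' in the first data row."""
    row = next(csv.DictReader(io.StringIO(csv_text)))
    return [k for k, v in row.items() if v.strip() == "False"]

def all_passed(csv_text):
    """True when no boolean check in the first row is 'False'."""
    return not failed_checks(csv_text)

print(failed_checks(REPORT))
```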


In [None]:
#@title **H1) Run PoseBusters and ProLIF for QC and interaction comparison**

POSE_SELECTION = "minimizedAffinity"  #@param ["Manual", "CNNscore","CNNaffinity","minimizedAffinity"]
MANUAL_POSE_ID = 1  #@param {type:"integer"}
CONTACT_CUTOFF_A = 6.0

import sys
import json
import datetime, html, subprocess, shutil
from pathlib import Path
from IPython.display import display, HTML
import MDAnalysis as mda
import pandas as pd
import numpy as np

# --- setup & checks ---
chem_path = RUN_DIR / "state" / "chem.json"
PH = float(json.loads(chem_path.read_text()).get("pH", 7.4)) if chem_path.exists() else 7.4
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

receptor_pdb = RUN_DIR / "analysis" / "receptor.pdb"
of3_pose_sdf = RUN_DIR / "analysis" / "of3_ligand_pose.sdf"
dock_sdf = RUN_DIR / "gnina" / "gnina_docked.sdf"
summary_csv = RUN_DIR / "analysis" / "gnina_pose_summary.csv"

if not summary_csv.exists():
    raise RuntimeError("Missing GNINA summary. Run F1 first.")

df_summary = pd.read_csv(summary_csv)
if len(df_summary) == 0:
    raise RuntimeError("No GNINA poses found.")

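# pose selection
if POSE_SELECTION == "Manual":
    pose_id = int(MANUAL_POSE_ID)
    valid_ids = sorted(df_summary["pose_id"].astype(int).tolist())
    if pose_id not in valid_ids:
        raise ValueError(f"MANUAL_POSE_ID={pose_id} is not a GNINA pose ID. Valid IDs: {valid_ids}")
    print(f"Selected pose_id: {pose_id} (Manual)")
else:
    # lower is better for minimizedAffinity; higher is better for CNNscore/CNNaffinity
    ascending = (POSE_SELECTION == "minimizedAffinity")
    pose_id = int(df_summary.sort_values(POSE_SELECTION, ascending=ascending).iloc[0]["pose_id"])
    print(f"Selected pose_id: {pose_id} (Best by {POSE_SELECTION})")

best_pose_sdf = RUN_DIR / "analysis" / f"gnina_best_pose_{pose_id}.sdf"

# extract the selected pose (pose_id is the 1-based record index in the docked SDF)
from rdkit import Chem
sup = Chem.SDMolSupplier(str(dock_sdf), removeHs=False)
m = sup[pose_id-1]
if m is None:
    raise RuntimeError(f"RDKit could not parse pose {pose_id} in {dock_sdf}.")
w = Chem.SDWriter(str(best_pose_sdf))
w.write(m); w.close()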

# --- PoseBusters CLI ---
print("\nRunning PoseBusters...")
for name, sdf_path in [("OpenFold3", of3_pose_sdf), ("GNINA", best_pose_sdf)]:
    csv_out = RUN_DIR / "analysis" / f"posebusters_{name.lower()}_{ts}.csv"
    cmd = [sys.executable, "-m", "posebusters", str(sdf_path), "-p", str(receptor_pdb), "--outfmt", "csv", "--full-report"]
    try:
        p = subprocess.run(cmd, capture_output=True, text=True)
        csv_out.write_text(p.stdout)
        if p.returncode == 0 and len(p.stdout.strip()) > 0:
            print(f"  {name} report saved: {csv_out}")
        else:
            print(f"  WARNING: PoseBusters failed for {name} (exit code {p.returncode}, empty or partial output).")
            if p.stderr.strip():
                print(f"    {p.stderr.strip().splitlines()[-1]}")
    except Exception as e:
        print(f"  Error running PoseBusters for {name}: {e}")

# --- add hydrogens at the target pH (Open Babel) ---
print(f"\nAdding hydrogens at pH {PH}...")
receptor_h_pdb = RUN_DIR / "analysis" / "receptor_prolif_H.pdb"
of3_h_sdf = RUN_DIR / "analysis" / "of3_pose_prolif_H.sdf"
gnina_h_sdf = RUN_DIR / "analysis" / f"gnina_pose_{pose_id}_prolif_H.sdf"
hlog = RUN_DIR / "logs" / "07_hydrogens.log"

def run_obabel_h(inp, outp, ph):
    run_cmd(f"obabel {inp} -O {outp} -p {ph}", hlog)

run_obabel_h(receptor_pdb, receptor_h_pdb, PH)

# restore residue names that Open Babel may have rewritten (e.g. to UNK)
try:
    from Bio import PDB
    parser = PDB.PDBParser(QUIET=True)
    ref_st = parser.get_structure("ref", str(receptor_pdb))
    tgt_st = parser.get_structure("tgt", str(receptor_h_pdb))

    mapping = {(c.id, r.id[1], r.id[2]): r.resname for m in ref_st for c in m for r in c}
    fallback = {(r.id[1], r.id[2]): r.resname for m in ref_st for c in m for r in c}

    fixed = 0
    for m in tgt_st:
        for c in m:
            for r in c:
                val = mapping.get((c.id, r.id[1], r.id[2])) or fallback.get((r.id[1], r.id[2]))
                if val and r.resname != val:
                    r.resname = val
                    fixed += 1
    if fixed > 0:
        io = PDB.PDBIO()
        io.set_structure(tgt_st)
        io.save(str(receptor_h_pdb))
        #print(f"  [Patch] Restored {fixed} residue names.")
except Exception as e:
    #print(f"  [Patch] Warning: {e}")
    pass

run_obabel_h(of3_pose_sdf, of3_h_sdf, PH)
run_obabel_h(best_pose_sdf, gnina_h_sdf, PH)

# --- ProLIF fxns ---
print("\nRunning ProLIF (Interactions)...")
try:
    import prolif as plf
    from prolif.interactions import VdWContact

    prot_mol = Chem.MolFromPDBFile(str(receptor_h_pdb), removeHs=False, sanitize=False)
    Chem.SanitizeMol(prot_mol, Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    prot = plf.Molecule.from_rdkit(prot_mol)

    # interaction types included in the fingerprint
    fp = plf.Fingerprint(interactions=[
        "Hydrophobic", "HBDonor", "HBAcceptor", "PiStacking",
        "Anionic", "Cationic", "CationPi", "PiCation",
        "VdWContact",
        "XBDonor", "XBAcceptor",
        "MetalDonor", "MetalAcceptor"
    ])
    lig_files = {"OpenFold3": of3_h_sdf, "GNINA": gnina_h_sdf}
    html_outputs = []

    for name, fpath in lig_files.items():
        try:
            suppl = Chem.SDMolSupplier(str(fpath), removeHs=False)
            lig_mol = suppl[0]
            # label every ligand atom as residue LIG so ProLIF treats the ligand as one residue
            for atom in lig_mol.GetAtoms():
                mi = Chem.AtomPDBResidueInfo()
                mi.SetResidueName("LIG")
                mi.SetResidueNumber(1)
                mi.SetIsHeteroAtom(True)
                atom.SetMonomerInfo(mi)

            lig = plf.Molecule.from_rdkit(lig_mol)

            # compute the interaction fingerprint for this single pose
            fp.run_from_iterable([lig], prot)
            df = fp.to_dataframe()

            csv_path = RUN_DIR / "analysis" / f"prolif_{name.lower()}_ifp.csv"
            html_path = RUN_DIR / "analysis" / f"prolif_{name.lower()}.html"

            if df.empty or (df.shape[1] == 0):
                print(f"  {name}: No interactions found (not even VdW contacts). Check the pose geometry.")
                html_outputs.append(f"<div style='flex:1; border:1px solid red; padding:10px;'><h4>{name}</h4><p>No interactions detected.</p></div>")
            else:
                df.to_csv(csv_path)
                net = fp.plot_lignetwork(ligand_mol=lig, kind="frame", frame=0)
                net.save(str(html_path))
                src = html.escape(Path(html_path).read_text())
                html_outputs.append(f"""
                <div style="flex: 1; border: 1px solid #ccc; padding: 5px;">
                  <h4 style="text-align:center;">{name} Pose</h4>
                  <iframe srcdoc="{src}" width="100%" height="600px" style="border:none;"></iframe>
                </div>""")
                print(f"  {name}: Success! ({int(df.sum(axis=1).sum())} interactions)")

        except Exception as e:
            print(f"  {name} failed: {e}")
            html_outputs.append(f"<div style='flex:1;'><h4>{name} Error</h4><p>{e}</p></div>")

    display(HTML(f'<div style="display: flex; flex-direction: row; gap: 20px; width: 100%;">{ "".join(html_outputs) }</div>'))

except Exception as e:
    print(f"CRITICAL ProLIF ERROR: {e}")

# --- debugging: count receptor residues within CONTACT_CUTOFF_A of each ligand pose ---
#def count_contacts(lig_p):
#    try:
#        tmp = lig_p.parent / (lig_p.stem + "_tmp.pdb")
#        run_cmd(f"obabel -isdf {lig_p} -opdb -O {tmp}", hlog)
#        u_p = mda.Universe(str(receptor_pdb))
#        u_l = mda.Universe(str(tmp))
#        d = mda.lib.distances.distance_array(u_l.atoms.positions, u_p.select_atoms("protein").positions)
#        return len(u_p.select_atoms("protein")[np.any(d < CONTACT_CUTOFF_A, axis=0)].residues)
#    except Exception:
#        return -1

#print(f"\nResidues within {CONTACT_CUTOFF_A} Å:")
#print(f"  OpenFold3 : {count_contacts(of3_pose_sdf)}")
#print(f"  GNINA : {count_contacts(best_pose_sdf)}")
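H1 saves one interaction fingerprint per pose (`prolif_openfold3_ifp.csv` and `prolif_gnina_ifp.csv`). ProLIF's `to_dataframe()` columns are a MultiIndex of (ligand, protein residue, interaction), so each pose can be reduced to a set of (residue, interaction) pairs that are present. One simple way to compare two poses, sketched below, is the Jaccard (Tanimoto) similarity of those sets. The residue labels are hypothetical and the pair extraction from the DataFrame is left out.

```python
def interaction_overlap(a, b):
    """Jaccard similarity of two sets of (residue, interaction) pairs; 1.0 if both are empty."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 1.0

# Hypothetical interaction sets for the two poses
of3 = {("ASP86.A", "HBDonor"), ("LEU134.A", "Hydrophobic"), ("PHE80.A", "PiStacking")}
gnina = {("ASP86.A", "HBDonor"), ("LEU134.A", "Hydrophobic"), ("LYS33.A", "Cationic")}

print(f"overlap = {interaction_overlap(of3, gnina):.2f}")
print("OpenFold3 only:", sorted(of3 - gnina))
print("GNINA only:   ", sorted(gnina - of3))
```

A value near 1.0 means the re-docked pose recovers most contacts of the OpenFold3 pose; listing the set differences shows which specific contacts were gained or lost.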


## **I. Export results**

Compress the current run or the full project folder (structures, SDFs, CSVs, logs) and download the archive.


In [None]:
#@title **I1) Export project folder**
#@markdown Choose what to export, then choose compression format.

EXPORT_SCOPE = "current_run" #@param ["current_run", "full_project"]
COMPRESSION_FORMAT = "zip" #@param ["zip", "tar", "gztar"]

import shutil
from pathlib import Path

# --- project-level name ---
try:
    PROJECT_NAME
except NameError:
    # best-effort fallback
    PROJECT_NAME = Path(PROJECT_ROOT).name if "PROJECT_ROOT" in globals() else "cloudbind_project"

ext_map = {"zip": "zip", "tar": "tar", "gztar": "tar.gz"}
ext = ext_map.get(COMPRESSION_FORMAT, "zip")

if EXPORT_SCOPE == "current_run":
    root_dir = RUN_DIR.parent
    base_dir = RUN_DIR.name
    out_base = Path("/content") / f"{RUN_NAME}_run"
    what = f"run folder: {RUN_DIR}"
else:
    root_dir = PROJECT_ROOT
    base_dir = None
    out_base = Path("/content") / f"{PROJECT_NAME}_full_project"
    what = f"full project: {PROJECT_ROOT}"

archive_path = Path(f"{out_base}.{ext}")

if archive_path.exists():
    archive_path.unlink()

print(f"Compressing {what} -> {archive_path} ...")

if base_dir is None:
    shutil.make_archive(str(out_base), COMPRESSION_FORMAT, root_dir=str(root_dir))
else:
    shutil.make_archive(str(out_base), COMPRESSION_FORMAT, root_dir=str(root_dir), base_dir=str(base_dir))

if archive_path.exists():
    sz_mb = archive_path.stat().st_size / (1024 * 1024)
    print(f"Created: {archive_path}")
    print(f"Size   : {sz_mb:.2f} MB")
else:
    print("Error: Archive creation failed.")

# download prompt
try:
    from google.colab import files
    files.download(str(archive_path))
except Exception:
    print("(Download is available in Colab only. Check the Files sidebar if the prompt doesn't appear.)")
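Before sharing the archive from I1, it can help to list what actually went in. The sketch below builds a throwaway folder, archives it the way I1 does for `EXPORT_SCOPE = "current_run"` (`root_dir` = parent, `base_dir` = run folder name), and lists the members. The folder and file names are made up; `gztar` produces a `.tar.gz` file, matching I1's `ext_map`.

```python
import shutil
import tarfile
import tempfile
import zipfile
from pathlib import Path

def archive_members(archive_path):
    """List member names of a .zip, .tar or .tar.gz archive."""
    p = str(archive_path)
    if p.endswith(".zip"):
        with zipfile.ZipFile(p) as zf:
            return sorted(zf.namelist())
    with tarfile.open(p) as tf:  # transparently handles plain and gzip-compressed tar
        return sorted(tf.getnames())

# Self-contained demo with a hypothetical run folder
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "project" / "run_demo"
    (run_dir / "logs").mkdir(parents=True)
    (run_dir / "logs" / "05_gnina.log").write_text("ok\n")
    out = shutil.make_archive(str(Path(tmp) / "run_demo_run"), "gztar",
                              root_dir=str(run_dir.parent), base_dir=run_dir.name)
    names = archive_members(out)

print(out.rsplit("/", 1)[-1], names)
```

Passing `base_dir` keeps the run folder name as the top-level entry, so extracting the archive recreates `run_demo/...` instead of scattering files into the current directory.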


## **M. Manifest (Optional)**


In [None]:
#@title **M1) Write a manifest for reproducibility and debugging**

import json, platform, subprocess, datetime
from pathlib import Path

manifest_path = RUN_DIR / "export" / "manifest.json"
manifest_path.parent.mkdir(parents=True, exist_ok=True)

def _sh(cmd):
    try:
        return subprocess.check_output(cmd, shell=True, text=True).strip()
    except Exception:
        return ""

paths_state = RUN_DIR / "state" / "paths.json"
paths = json.loads(paths_state.read_text()) if paths_state.exists() else {}

manifest = {
    "timestamp": datetime.datetime.now().isoformat(),
    "platform": {
        "python": platform.python_version(),
        "os": platform.platform(),
        "uname": platform.uname()._asdict(),
    },
    "tools": {
        "openfold3": OPENFOLD3_VERSION,
        "gnina": GNINA_VERSION,
        "p2rank": P2RANK_VERSION,
        "gnina_path": str(GNINA_PATH),
        "p2rank_prank": str(P2RANK_PRANK),
        "run_openfold": str(RUN_OPENFOLD),
    },
    "run": json.loads((RUN_DIR / "state" / "run_settings.json").read_text()) if (RUN_DIR / "state" / "run_settings.json").exists() else {},
    "paths": paths,
    "logs_tail": {
        "install": _sh(f"tail -n 80 {Path('/content/cloudbind_project/logs/01_install.log')}") if Path("/content/cloudbind_project/logs/01_install.log").exists() else "",
        "openfold3": _sh(f"tail -n 120 {RUN_DIR / 'logs' / '04_openfold3.log'}") if (RUN_DIR / "logs" / "04_openfold3.log").exists() else "",
        "gnina": _sh(f"tail -n 120 {RUN_DIR / 'logs' / '05_gnina.log'}") if (RUN_DIR / "logs" / "05_gnina.log").exists() else "",
    },
    "versions": {
        "torch": _sh("python -c \"import torch; print(torch.__version__)\""),
        "rdkit": _sh("python -c \"from rdkit import Chem; import rdkit; print(rdkit.__version__)\""),
        "MDAnalysis": _sh("python -c \"import MDAnalysis as mda; print(mda.__version__)\""),
    },
}

manifest_path.write_text(json.dumps(manifest, indent=2))
print(f"Wrote: {manifest_path}")
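M1 collects log tails by shelling out to `tail`. A pure-Python alternative, sketched here with a hypothetical helper name, reads the last lines with `collections.deque`. It needs no shell, and a missing log comes back as an explicit empty string instead of being hidden by `_sh`'s blanket `except`.

```python
import tempfile
from collections import deque
from pathlib import Path

def tail_lines(path, n=80):
    """Return the last n lines of a text file, or '' if the file does not exist."""
    path = Path(path)
    if not path.exists():
        return ""
    with path.open(errors="ignore") as fh:
        return "".join(deque(fh, maxlen=n)).rstrip("\n")  # deque keeps only the last n lines

# Demo on a temporary 200-line log
with tempfile.TemporaryDirectory() as tmp:
    log = Path(tmp) / "demo.log"
    log.write_text("".join(f"line {i}\n" for i in range(1, 201)))
    demo_tail = tail_lines(log, n=3)
    demo_missing = tail_lines(Path(tmp) / "missing.log")

print(demo_tail)
```

Because `deque(fh, maxlen=n)` streams the file line by line, memory use stays bounded by `n` lines even for long OpenFold3 or GNINA logs.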
