#**RFdiffusion v1.1.1**
RFdiffusion is a method for structure generation, with or without conditional information (a motif, a target, etc.). It can tackle a whole range of protein design challenges, as outlined in the RFdiffusion [manuscript](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v2).


For **instructions**, see end of Notebook.

**<font color="red">NOTE:</font>** This is the tagged v1.1.1 release of the notebook; it may break in the future when Colab updates. For the latest version, see the [main](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/rf/examples/diffusion.ipynb) branch.

Additional Notebooks:

- See [diffusion_foldcond](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion_foldcond.ipynb) for fold conditioning functionality.

- See [original version](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion_ori.ipynb) of this notebook (from 31Mar2023).


In [ ]:
import numpy as np
import os
import sys
import time
import json
import matplotlib.pyplot as plt
import py3Dmol
from colabdesign.shared.plot import plot_pseudo_3D

# Installation root; every other location below is derived from it.
BASE_DIR = "/clusterfs/nilah/sergio/RFdifussion"
RFDIFFUSION_DIR = os.path.join(BASE_DIR, "sokrypton", "RFdiffusion")
MODEL_DIR = os.path.join(RFDIFFUSION_DIR, "models")
SCHEDULES_DIR = os.path.join(RFDIFFUSION_DIR, "schedules")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")

# The output directory is the only one this cell creates; the others must
# already exist and are merely reported below.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Diagnostics: import search path, resolved locations, then existence flags.
print("sys.path:", sys.path)
for _label, _location in (
    ("BASE_DIR", BASE_DIR),
    ("RFDIFFUSION_DIR", RFDIFFUSION_DIR),
    ("MODEL_DIR", MODEL_DIR),
    ("OUTPUT_DIR", OUTPUT_DIR),
    ("SCHEDULES_DIR", SCHEDULES_DIR),
):
    print(f"{_label}:", _location)
for _label, _location in (
    ("BASE_DIR", BASE_DIR),
    ("RFDIFFUSION_DIR", RFDIFFUSION_DIR),
    ("MODEL_DIR", MODEL_DIR),
    ("SCHEDULES_DIR", SCHEDULES_DIR),
):
    print(f"Exists {_label}:", os.path.exists(_location))

# Verify ananas binary
# The executable is expected at BASE_DIR/ananas. It is downloaded and marked
# executable only when nothing exists at that path yet. The exit status of the
# two shell commands is not checked, so a failed download goes unnoticed here.
ANANAS_PATH = os.path.join(BASE_DIR, "ananas")
if not os.path.exists(ANANAS_PATH):
    print("Downloading ananas...")
    os.system(f"wget -qnc https://files.ipd.uw.edu/krypton/ananas -O {ANANAS_PATH}")
    os.system(f"chmod +x {ANANAS_PATH}")

# Verify model files
# Fail fast: both checkpoints must be present in MODEL_DIR, otherwise the cell
# stops with FileNotFoundError before anything is imported.
required_models = ["Base_ckpt.pt", "Complex_base_ckpt.pt"]
for model in required_models:
    model_path = os.path.join(MODEL_DIR, model)
    print(f"Checking model: {model_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model} not found at {model_path}")

# Verify schedules directory
# Only the directory's existence is checked, not its contents.
if not os.path.exists(SCHEDULES_DIR):
    raise FileNotFoundError(f"Schedules directory not found at {SCHEDULES_DIR}")

# Add RFdiffusion to sys.path
# DGLBACKEND is set only on the first run of this cell, i.e. while
# RFDIFFUSION_DIR is not yet on sys.path.
if RFDIFFUSION_DIR not in sys.path:
    os.environ["DGLBACKEND"] = "pytorch"
    sys.path.append(RFDIFFUSION_DIR)
    print(f"Added {RFDIFFUSION_DIR} to sys.path")

# The package sub-directory is appended as well.
# NOTE(review): presumably this is what lets `inference.utils` below resolve
# as a top-level module -- confirm against the repository layout.
if os.path.join(RFDIFFUSION_DIR, "rfdiffusion") not in sys.path:
    sys.path.append(os.path.join(RFDIFFUSION_DIR, "rfdiffusion"))


# Import RFdiffusion utilities
# On failure, print installation hints and the directory listing, then
# re-raise the original ImportError so the cell still fails visibly.
try:
    from inference.utils import parse_pdb
    from colabdesign.rf.utils import get_ca, fix_contigs, fix_partial_contigs, fix_pdb, sym_it
    from colabdesign.shared.protein import pdb_to_string
except ImportError as e:
    print(f"Error importing utilities: {e}")
    print("Ensure ColabDesign is installed: pip install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
    print("Verify RFdiffusion directory at", RFDIFFUSION_DIR)
    print("Contents of RFDIFFUSION_DIR:", os.listdir(RFDIFFUSION_DIR))
    raise

def get_pdb(pdb_code=None):
    """Return a local path to a PDB file, downloading it when necessary.

    Resolution order:
      1. ``pdb_code`` is an existing file path: returned unchanged.
      2. ``pdb_code`` has four characters: treated as an RCSB entry; the
         biological assembly ``<code>.pdb1`` is cached in ``BASE_DIR``.
      3. anything else: treated as an AlphaFold DB accession; the file
         ``AF-<code>-F1-model_v3.pdb`` is cached in ``BASE_DIR``.

    Args:
        pdb_code: File path, 4-character PDB code, or AlphaFold accession.

    Returns:
        Path of the local PDB file.

    Raises:
        ValueError: ``pdb_code`` is ``None`` or empty.
        FileNotFoundError: the download did not produce a usable file.
    """
    # Local import so the function does not depend on which cells ran before.
    import subprocess

    def _fetch(url, dest):
        """Download url to dest with wget; return True when dest is non-empty."""
        # Arguments are passed as a list (no shell), so the code is never
        # interpreted as shell syntax.
        status = subprocess.run(["wget", "-q", url, "-O", dest], check=False)
        ok = (status.returncode == 0
              and os.path.isfile(dest)
              and os.path.getsize(dest) > 0)
        if not ok and os.path.isfile(dest):
            # `wget -O` creates the target even when the request fails; drop
            # it so an empty file is never mistaken for a cached download.
            os.remove(dest)
        return ok

    if pdb_code is None or pdb_code == "":
        raise ValueError("No PDB code provided.")
    if os.path.isfile(pdb_code):
        return pdb_code

    if len(pdb_code) == 4:
        pdb_file = os.path.join(BASE_DIR, f"{pdb_code}.pdb1")
        if not os.path.exists(pdb_file):
            archive = f"{pdb_file}.gz"
            if _fetch(f"https://files.rcsb.org/download/{pdb_code}.pdb1.gz", archive):
                subprocess.run(["gunzip", "-f", archive], check=False)
    else:
        pdb_file = os.path.join(BASE_DIR, f"AF-{pdb_code}-F1-model_v3.pdb")
        if not os.path.exists(pdb_file):
            _fetch(f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb", pdb_file)

    if not os.path.isfile(pdb_file):
        raise FileNotFoundError(
            f"Could not obtain a structure for '{pdb_code}' (expected {pdb_file})")
    return pdb_file

def run_ananas(pdb_str, path, sym=None):
    """Run AnAnaS for symmetry detection and symmetrise the coordinates.

    The structure is written to ``OUTPUT_DIR/<path>/ananas_input.pdb``, the
    AnAnaS binary is run on it, and its JSON report is read back from
    ``OUTPUT_DIR/<path>/ananas.json``.

    Args:
        pdb_str: PDB file contents.
        path: Sub-directory of ``OUTPUT_DIR`` used for the scratch files.
        sym: Optional extra argument appended to the AnAnaS command line.

    Returns:
        ``(results, new_pdb_str)`` on success, where ``results`` is the first
        entry of the JSON report. ATOM records whose chain is not listed in
        the report's asymmetric unit are dropped; the remaining ones have
        their coordinates passed through ``sym_it`` for cyclic ("c...") and
        dihedral ("d...") groups. On any failure ``(None, pdb_str)`` is
        returned, i.e. the input is handed back unchanged.
    """
    os.makedirs(os.path.join(OUTPUT_DIR, path), exist_ok=True)
    pdb_filename = os.path.join(OUTPUT_DIR, path, "ananas_input.pdb")
    out_filename = os.path.join(OUTPUT_DIR, path, "ananas.json")
    with open(pdb_filename, "w") as handle:
        handle.write(pdb_str)

    cmd = f"{ANANAS_PATH} {pdb_filename} -u -j {out_filename}"
    if sym:
        cmd += f" {sym}"
    os.system(cmd)

    try:
        # Context manager so the report handle is closed even on parse errors.
        with open(out_filename, "r") as handle:
            out = json.load(handle)
        results, AU = out[0], out[-1]["AU"]
        group = AU["group"]
        chains = AU["chain names"]
        rmsd = results["Average_RMSD"]
        print(f"AnAnaS detected {group} symmetry at RMSD:{rmsd:.3}")

        C = np.array(results['transforms'][0]['CENTER'])
        A = [np.array(t["AXIS"]) for t in results['transforms']]

        new_lines = []
        for line in pdb_str.split("\n"):
            if line.startswith("ATOM"):
                chain = line[21:22]
                if chain in chains:
                    # Fixed-width x/y/z fields: columns 31-38, 39-46, 47-54.
                    x = np.array([float(line[i:(i+8)]) for i in [30, 38, 46]])
                    if group[0] == "c":
                        x = sym_it(x, C, A[0])
                    if group[0] == "d":
                        x = sym_it(x, C, A[1], A[0])
                    coord_str = "".join(["{:8.3f}".format(a) for a in x])
                    new_lines.append(line[:30] + coord_str + line[54:])
            else:
                new_lines.append(line)
        return results, "\n".join(new_lines)
    except Exception:
        # Best effort: a missing or malformed report means "no symmetry found".
        # Exception (not a bare except) keeps Ctrl-C and SystemExit working.
        return None, pdb_str

def run(command, steps, num_designs=1, visual="none"):
    """Execute RFdiffusion command and monitor progress.

    The command is started in the background through the shell with its output
    redirected to ``BASE_DIR/run.log``. For every design and every step ``n``
    in ``range(steps)`` the function polls once per second for
    ``BASE_DIR/<n>.pdb``, optionally renders it, and deletes it afterwards.

    Args:
        command: Shell command line to launch.
        steps: Number of per-step PDB files expected for each design.
        num_designs: Number of designs the command produces.
        visual: "none" (no rendering), "image" (PNG through matplotlib) or
            "interactive" (HTML through py3Dmol); renderings are written to
            ``OUTPUT_DIR/step_<n>.*``.

    Returns:
        None. Returns early, after printing a message, when the process is
        gone before an expected step file is complete.

    NOTE(review): run_diffusion passes ``inference.dump_pdb_path`` pointing at
    its output directory, while this loop polls ``BASE_DIR`` -- confirm that
    the two locations are meant to differ.
    """
    def run_command_and_get_pid(command):
        # `echo $!` writes the PID of the backgrounded job to a scratch file,
        # which is read back and removed immediately.
        pid_file = os.path.join(BASE_DIR, "pid")
        os.system(f'{command} > {BASE_DIR}/run.log 2>&1 & echo $! > {pid_file}')
        with open(pid_file, 'r') as f:
            pid = int(f.read().strip())
        os.remove(pid_file)
        return pid

    def is_process_running(pid):
        # Signal 0 delivers nothing; os.kill raises OSError when the PID
        # cannot be signalled.
        try:
            os.kill(pid, 0)
            return True
        except OSError:
            return False

    print(f"Running command: {command}")
    pid = run_command_and_get_pid(command)
    try:
        for _ in range(num_designs):
            for n in range(steps):
                wait = True
                while wait:
                    time.sleep(1)
                    pdb_file = os.path.join(BASE_DIR, f"{n}.pdb")
                    if os.path.exists(pdb_file):
                        with open(pdb_file, "r") as f:
                            pdb_str = f.read()
                        # A trailing "TER" is taken as the sign that the file
                        # has been written completely.
                        if pdb_str.strip().endswith("TER"):
                            wait = False
                        elif not is_process_running(pid):
                            print("Process failed.")
                            return
                    elif not is_process_running(pid):
                        print("Process terminated unexpectedly.")
                        return

                if visual != "none":
                    if visual == "image":
                        # get_ca returns coordinates and B-factors; the
                        # B-factor column drives the colouring.
                        xyz, bfact = get_ca(pdb_file, get_bfact=True)
                        fig = plt.figure(figsize=(6, 6), dpi=100)
                        ax = fig.add_subplot(111)
                        ax.set_xticks([]); ax.set_yticks([])
                        plot_pseudo_3D(xyz, c=bfact, cmin=0.5, cmax=0.9, ax=ax)
                        plt.savefig(os.path.join(OUTPUT_DIR, f"step_{n}.png"))
                        plt.close()
                    elif visual == "interactive":
                        view = py3Dmol.view()
                        view.addModel(pdb_str, 'pdb')
                        view.setStyle({'cartoon': {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0.5, 'max': 0.9}}})
                        view.zoomTo()
                        # NOTE(review): confirm that this py3Dmol version
                        # provides save_html for writing the view to disk.
                        view.save_html(os.path.join(OUTPUT_DIR, f"step_{n}.html"))

                # The step file is consumed so that the same name can be
                # detected again for the next design.
                if os.path.exists(pdb_file):
                    os.remove(pdb_file)

        # All expected step files were seen; wait for the process to exit.
        while is_process_running(pid):
            time.sleep(1)

    except KeyboardInterrupt:
        # Interrupting the notebook also stops the background process.
        os.system(f"kill -TERM {pid}")
        print("Process stopped.")

def run_diffusion(contigs, path, pdb=None, iterations=50,
                  symmetry="none", order=1, hotspot=None,
                  chains=None, add_potential=False,
                  num_designs=1, visual="none"):
    """Run RFdiffusion inference.

    Builds the option list for ``run_inference.py``, launches it through
    ``run`` and finally post-processes the written PDB files with ``fix_pdb``.

    Args:
        contigs: Contig specification; segments may be separated by spaces,
            commas or colons.
        path: Run name; outputs use the prefix ``OUTPUT_DIR/<path>``.
        pdb: Input structure (file path or code accepted by ``get_pdb``);
            required for the "fixed" and "partial" modes.
        iterations: Number of diffusion steps. In "partial" mode it is rescaled
            to ``int(80 * iterations / 200)``.
        symmetry: "auto", "cyclic" or "dihedral"; anything else disables
            symmetry.
        order: Symmetry order for "cyclic"/"dihedral".
        hotspot: Comma-separated hotspot residues for ``ppi.hotspot_res``.
        chains: Chains to keep when loading ``pdb``; "" is treated as None.
        add_potential: Add the oligomer-contact guiding potential (applied
            only when a symmetry is active).
        num_designs: Number of designs to generate.
        visual: Forwarded to ``run``.

    Returns:
        ``(contigs, copies)``: the processed contig list and the number of
        symmetric copies (1 without symmetry).

    Mode selection:
        * "partial": the contigs contain no purely numeric segment
        * "fixed":   they contain both chain-prefixed and numeric segments
        * "free":    only numeric segments
    """
    full_path = os.path.join(OUTPUT_DIR, path)
    os.makedirs(full_path, exist_ok=True)
    opts = [f"inference.output_prefix={full_path}",
            f"inference.num_designs={num_designs}"]

    if chains == "":
        chains = None

    # Determine symmetry type
    if symmetry in ["auto", "cyclic", "dihedral"]:
        if symmetry == "auto":
            sym, copies = None, 1
        else:
            sym, copies = {"cyclic": (f"c{order}", order),
                          "dihedral": (f"d{order}", order*2)}[symmetry]
    else:
        symmetry = None
        sym, copies = None, 1

    # Determine mode
    contigs = contigs.replace(",", " ").replace(":", " ").split()
    is_fixed, is_free = False, False
    fixed_chains = []
    for contig in contigs:
        for x in contig.split("/"):
            a = x.split("-")[0]
            if not a:
                # Empty segment (stray "/" or leading "-"): nothing to
                # classify, and indexing a[0] would raise IndexError.
                continue
            if a[0].isalpha():
                is_fixed = True
                if a[0] not in fixed_chains:
                    fixed_chains.append(a[0])
            if a.isnumeric():
                is_free = True
    if len(contigs) == 0 or not is_free:
        mode = "partial"
    elif is_fixed:
        mode = "fixed"
    else:
        mode = "free"

    # Fix input contigs
    if mode in ["partial", "fixed"]:
        pdb_str = pdb_to_string(get_pdb(pdb), chains=chains)
        if symmetry == "auto":
            a, pdb_str = run_ananas(pdb_str, path)
            if a is None:
                print("ERROR: no symmetry detected")
                symmetry = None
                sym, copies = None, 1
            else:
                if a["group"][0] == "c":
                    symmetry = "cyclic"
                    sym, copies = a["group"], int(a["group"][1:])
                elif a["group"][0] == "d":
                    symmetry = "dihedral"
                    sym, copies = a["group"], 2 * int(a["group"][1:])
                else:
                    print(f"ERROR: detected symmetry ({a['group']}) not supported")
                    symmetry = None
                    sym, copies = None, 1

        elif mode == "fixed":
            # Keep only the chains that the contigs actually reference.
            pdb_str = pdb_to_string(pdb_str, chains=fixed_chains)

        pdb_filename = os.path.join(full_path, "input.pdb")
        with open(pdb_filename, "w") as handle:
            handle.write(pdb_str)

        parsed_pdb = parse_pdb(pdb_filename)
        opts.append(f"inference.input_pdb={pdb_filename}")
        if mode == "partial":
            iterations = int(80 * (iterations / 200))
            opts.append(f"diffuser.partial_T={iterations}")
            contigs = fix_partial_contigs(contigs, parsed_pdb)
        else:
            opts.append(f"diffuser.T={iterations}")
            contigs = fix_contigs(contigs, parsed_pdb)
    else:
        opts.append(f"diffuser.T={iterations}")
        parsed_pdb = None
        contigs = fix_contigs(contigs, parsed_pdb)

    if hotspot:
        opts.append(f"ppi.hotspot_res=[{hotspot}]")

    # Setup symmetry
    if sym is not None:
        sym_opts = ["--config-name symmetry", f"inference.symmetry={sym}"]
        if add_potential:
            sym_opts += ["'potentials.guiding_potentials=[\"type:olig_contacts,weight_intra:1,weight_inter:0.1\"]'",
                        "potentials.olig_intra_all=True", "potentials.olig_inter_all=True",
                        "potentials.guide_scale=2", "potentials.guide_decay=quadratic"]
        opts = sym_opts + opts
        # One contig set per symmetric copy.
        contigs = sum([contigs] * copies, [])

    opts.append(f"'contigmap.contigs=[{' '.join(contigs)}]'")
    opts.append("inference.dump_pdb=True")
    opts.append(f"inference.dump_pdb_path={full_path}")

    print("Mode:", mode)
    print("Output:", full_path)
    print("Contigs:", contigs)

    opts_str = " ".join(opts)
    cmd = f"python {RFDIFFUSION_DIR}/run_inference.py {opts_str}"
    print("Command:", cmd)

    # Run
    run(cmd, iterations, num_designs, visual=visual)

    # Fix PDBs
    # RFdiffusion's run_inference.py names trajectories
    # <dirname(output_prefix)>/traj/<basename(output_prefix)>_<n>_*_traj.pdb,
    # i.e. next to the prefix rather than below it. The location used
    # previously (<output_prefix>/traj/...) is still checked in case a file
    # exists there; paths that do not exist are skipped.
    prefix_dir, prefix_name = os.path.split(full_path)
    traj_prefixes = [os.path.join(prefix_dir, "traj", prefix_name),
                     os.path.join(full_path, "traj", path)]
    for n in range(num_designs):
        candidates = []
        for traj_prefix in traj_prefixes:
            candidates.append(f"{traj_prefix}_{n}_pX0_traj.pdb")
            candidates.append(f"{traj_prefix}_{n}_Xt-1_traj.pdb")
        candidates.append(f"{full_path}_{n}.pdb")
        # A dedicated name is used so the `pdb` argument is not shadowed.
        for pdb_file in candidates:
            if os.path.exists(pdb_file):
                with open(pdb_file, "r") as handle:
                    pdb_str = handle.read()
                with open(pdb_file, "w") as handle:
                    handle.write(fix_pdb(pdb_str, contigs))

    return contigs, copies

import os
import subprocess

# Re-declared here so that the cell can be run on its own.
BASE_DIR = "/clusterfs/nilah/sergio/RFdifussion"
PARAM_DIR = os.path.join(BASE_DIR, "params")

# Create param dir if not exists
os.makedirs(PARAM_DIR, exist_ok=True)

# Download AlphaFold params only if needed
# done.txt is a marker written after a successful download and extraction;
# while it is absent, both steps are (re)run.
done_file = os.path.join(PARAM_DIR, "done.txt")
if not os.path.isfile(done_file):
    print("Downloading AlphaFold parameters...")
    alphafold_tar = os.path.join(PARAM_DIR, "alphafold_params_2022-12-06.tar")

    # check=True raises CalledProcessError on a non-zero exit status, so the
    # marker below is never written after a failed step.
    subprocess.run([
        "wget",
        "-O", alphafold_tar,
        "https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
    ], check=True)

    print("Extracting AlphaFold parameters...")
    subprocess.run(["tar", "-xf", alphafold_tar, "-C", PARAM_DIR], check=True)

    # The tar archive itself is left in PARAM_DIR.
    with open(done_file, "w") as f:
        f.write("Done.")

def visualize_epitope_mhc(pdb_path, output_png, epitope_chain="D"):
    """
    Visualize the epitope-MHC complex from a PDB file and save as PNG.

    A py3Dmol view and a matplotlib pseudo-3D plot are produced; the latter is
    written next to ``output_png`` with the suffix ``_pseudo3D.png``.

    NOTE: a later cell defines another ``visualize_epitope_mhc`` with a
    different signature; whichever definition was executed last is in effect.

    Args:
        pdb_path (str): Path to the RFdiffusion output PDB (e.g., RUN2_0.pdb).
        output_png (str): Path to save the PNG image.
        epitope_chain (str): Chain ID of the epitope (assumed to be the designed peptide).
    """
    # Read PDB file
    with open(pdb_path, "r") as f:
        pdb_str = f.read()

    # Initialize py3Dmol view
    view = py3Dmol.view(width=800, height=600)

    # Add the PDB model
    view.addModel(pdb_str, "pdb")

    # Style for MHC (chains A, B, C). The selection must be a list: py3Dmol
    # serialises its arguments to JSON, and a Python set is not serialisable.
    view.setStyle({"chain": ["A", "B", "C"]},
                  {"cartoon": {"color": "lightblue", "opacity": 0.8}})

    # Style for epitope (assumed to be chain D or the designed peptide)
    view.setStyle({"chain": epitope_chain},
                  {"cartoon": {"color": "red", "thickness": 1.0},
                   "stick": {"color": "red"}})

    # Zoom and center
    view.zoomTo()
    view.setBackgroundColor("white")

    # Save as PNG
    # NOTE(review): confirm that view.png accepts a file name in the installed
    # py3Dmol version; it may only render the image inline in the notebook.
    view.render()
    view.png(output_png)
    print(f"Saved PNG to {output_png}")

    # Alternative: Pseudo-3D plot with matplotlib
    pseudo_png = output_png.replace(".png", "_pseudo3D.png")
    xyz, bfact = get_ca(pdb_path, get_bfact=True)
    fig = plt.figure(figsize=(6, 6), dpi=100)
    ax = fig.add_subplot(111)
    ax.set_xticks([]); ax.set_yticks([])
    plot_pseudo_3D(xyz, c=bfact, cmin=0.5, cmax=0.9, ax=ax)
    plt.savefig(pseudo_png)
    plt.close()
    print(f"Saved pseudo-3D PNG to {pseudo_png}")
# Set parameters from arguments
# name = args.name
# contigs = args.contigs



In [ ]:
# #@title run **RFdiffusion** to generate a backbone
# Imported here because this cell is the only visible user of these modules.
import random
import string

name = "SECONDTRY"

OUTPUT_DIR = '/clusterfs/nilah/sergio/RFdifussion/outputs'

# determine where to save
# Start from the plain run name and append a random 5-character suffix until
# the candidate is unused. The test must look at `path` (the candidate), not
# `name`; otherwise the condition never changes and the loop cannot terminate.
# NOTE(review): the check is relative to the current working directory
# ("outputs/..."), whereas results are written below OUTPUT_DIR -- confirm
# which location is intended.
path = name
while os.path.exists(f"outputs/{path}_0.pdb"):
  path = name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))


# From here on OUTPUT_DIR points at the run-specific sub-directory.
OUTPUT_DIR = os.path.join(OUTPUT_DIR, path)

In [ ]:
# Run configuration: input structure, design targets and sampling settings.
pdb = "/clusterfs/nilah/sergio/RFdifussion/RFdiffusion/structures/7RTD.pdb"
chains = "A,B,C"
contigs = "A1-274 B1-100 0/9"
hotspot = "A7,A9,A59,A63,A66,A70,A99,A159,A167"
symmetry = "none"
order = 1
add_potential = False
iterations = 25
num_designs = 1
visual = "none"

# Keyword arguments for run_diffusion, keyed by parameter name.
flags = dict(
    contigs=contigs,
    pdb=pdb,
    order=order,
    iterations=iterations,
    symmetry=symmetry,
    hotspot=hotspot,
    path=path,
    chains=chains,
    add_potential=add_potential,
    num_designs=num_designs,
    visual=visual,
)

# Strip single and double quotes from every string value in one pass.
_strip_quotes = str.maketrans("", "", "'\"")
for key in flags:
  if isinstance(flags[key], str):
    flags[key] = flags[key].translate(_strip_quotes)
# %%time
contigs, copies = run_diffusion(**flags)
#cat /clusterfs/nilah/sergio/RFdifussion/run.log




In [ ]:
import os, subprocess

# Locations of the diffusion result that is handed to the designability test.
OUTDIR = "/clusterfs/nilah/sergio/RFdifussion/outputs"
loc_dir = f"{OUTDIR}/{name}"
rf_pdb = f"{loc_dir}/{name}_0.pdb"
contigs_str = ":".join(contigs)
copies  = 1

# Sequence design / validation settings.
num_seqs = 4
num_recycles = 3
rm_aa   = "C"
temp    = 0.1
num_designs = 1

# Refuse to continue without a non-empty backbone file.
if not (os.path.isfile(rf_pdb) and os.stat(rf_pdb).st_size > 0):
    raise RuntimeError(f"{rf_pdb} missing or empty – diffusion probably failed")

os.environ["COLABDESIGN_AF_DIR"] = "/clusterfs/nilah/sergio/RFdifussion/params"

# Command-line options as (flag, value) pairs, rendered as --flag=value.
_cli_args = (
    ("pdb", rf_pdb),
    ("loc", loc_dir),
    ("contig", contigs_str),
    ("copies", copies),
    ("num_seqs", num_seqs),
    ("num_recycles", num_recycles),
    ("rm_aa", rm_aa),
    ("mpnn_sampling_temp", temp),
    ("num_designs", num_designs),
)
opts = [f"--{flag}={value}" for flag, value in _cli_args]

_designability_script = "/clusterfs/nilah/sergio/miniconda3/envs/SE3nv/lib/python3.9/site-packages/colabdesign/rf/designability_test.py"
subprocess.run(["python", _designability_script, *opts], check=True)


In [ ]:
import os, time, subprocess

# Settings forwarded to colabdesign's designability_test.py.
num_seqs = 2
initial_guess = False
num_recycles = 1
use_multimer = False
rm_aa = "C"
mpnn_sampling_temp = 0.1

# Wait for AlphaFold params
# NOTE: these two values are not used anywhere in this cell; no waiting
# actually takes place.
max_wait = 300
elapsed = 0

# Absolute path setup
# The trailing slash lets the file name be appended directly in --pdb below.
output_dir = f'{OUTPUT_DIR}/'
contigs_str = ":".join(contigs)

# `copies` and `num_designs` are taken from earlier cells.
opts = [
    f"--pdb={output_dir}{name}_0.pdb",
    f"--loc={output_dir}",
    f"--contig={contigs_str}",
    f"--copies={copies}",
    f"--num_seqs={num_seqs}",
    f"--num_recycles={num_recycles}",
    f"--rm_aa={rm_aa}",
    f"--mpnn_sampling_temp={mpnn_sampling_temp}",
    f"--num_designs={num_designs}",
]
# Boolean switches are passed only when enabled.
if initial_guess: opts.append("--initial_guess")
if use_multimer: opts.append("--use_multimer")
# opts.append("--num_models=1")

# Set AlphaFold weights path for ColabDesign
os.environ["COLABDESIGN_AF_DIR"] = "/clusterfs/nilah/sergio/RFdifussion/params"

# check=True raises CalledProcessError when the script exits non-zero.
subprocess.run(["python", "/clusterfs/nilah/sergio/miniconda3/envs/SE3nv/lib/python3.9/site-packages/colabdesign/rf/designability_test.py"] + opts, check=True)


In [ ]:
'A:9'

In [ ]:
contigs

In [ ]:
":".join(contigs)

In [ ]:
from Bio.PDB import PDBParser, Superimposer, PDBIO

alt_pdb_input_path = "/clusterfs/nilah/sergio/RFdifussion/RFdiffusion/structures/7RTD.pdb"

# Check MHC file
if os.path.exists(alt_pdb_input_path):
    pdb_input_path = alt_pdb_input_path
else:
    # `pdb_input_path` / `pdb_input` are not defined by this cell. Values left
    # behind by an earlier cell are reused; otherwise the PDB code is derived
    # from the file name above. Reading them unconditionally raised NameError.
    pdb_input = globals().get("pdb_input") or os.path.splitext(
        os.path.basename(alt_pdb_input_path))[0]
    pdb_input_path = globals().get("pdb_input_path")
    if not pdb_input_path or not os.path.exists(pdb_input_path):
        # Try downloading using get_pdb logic
        pdb_file = os.path.join(BASE_DIR, f"{pdb_input}.pdb1")
        if not os.path.exists(pdb_file):
            os.system(f"wget -qnc https://files.rcsb.org/download/{pdb_input}.pdb1.gz -O {pdb_file}.gz")
            os.system(f"gunzip {pdb_file}.gz")
        pdb_input_path = pdb_file
if not os.path.exists(pdb_input_path):
    raise FileNotFoundError(f"MHC PDB file not found at {pdb_input_path} or alternatives")

# Ensure OUTPUT_DIR exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

def align_epitope_to_mhc(epitope_pdb, mhc_pdb, epitope_chain="A", ref_chain="C",
                         peptide_len=9):
    """
    Align epitope to the reference peptide in MHC (chain C).

    The first ``peptide_len`` C-alpha atoms of both chains are superimposed
    and the transform is applied to the whole first model of the epitope
    structure, which is then saved next to the input with an ``_aligned``
    suffix.

    Args:
        epitope_pdb (str): Path to epitope PDB (e.g., EPITOPE_TRIAL8/0.pdb).
        mhc_pdb (str): Path to MHC PDB (7rtd.pdb).
        epitope_chain (str): Chain ID of epitope in epitope_pdb.
        ref_chain (str): Chain ID of reference peptide in mhc_pdb.
        peptide_len (int): Number of C-alpha atoms to superimpose. The epitope
            must have exactly this many, the reference chain at least this many.

    Returns:
        str: Path to aligned epitope PDB. The unmodified ``epitope_pdb`` is
        returned when parsing fails, a chain is missing, or the C-alpha counts
        do not match the requirement above.
    """
    parser = PDBParser(QUIET=True)
    try:
        mhc_struct = parser.get_structure("mhc", mhc_pdb)
        epitope_struct = parser.get_structure("epitope", epitope_pdb)
    except Exception as e:
        print(f"Error parsing PDB files: {e}")
        return epitope_pdb

    # Get C-alpha atoms
    try:
        ref_atoms = [a for a in mhc_struct[0][ref_chain].get_atoms() if a.name == "CA"]
        moving_atoms = [a for a in epitope_struct[0][epitope_chain].get_atoms() if a.name == "CA"]
    except KeyError as e:
        print(f"Chain {ref_chain} not found in MHC or {epitope_chain} in epitope: {e}")
        return epitope_pdb

    # Ensure the expected number of residues
    if len(ref_atoms) < peptide_len or len(moving_atoms) != peptide_len:
        print(f"Warning: Reference chain {ref_chain} has {len(ref_atoms)} CA atoms, epitope has {len(moving_atoms)}")
        return epitope_pdb

    # Align
    sup = Superimposer()
    sup.set_atoms(ref_atoms[:peptide_len], moving_atoms[:peptide_len])
    sup.apply(epitope_struct[0])

    # Save aligned epitope
    # splitext touches only the final extension. str.replace(".pdb", ...) would
    # rewrite every occurrence in the path and, for a file without ".pdb",
    # return the input path so that the original was overwritten.
    root, ext = os.path.splitext(epitope_pdb)
    aligned_pdb = f"{root}_aligned{ext}"
    io = PDBIO()
    io.set_structure(epitope_struct)
    io.save(aligned_pdb)
    return aligned_pdb

def combine_epitope_mhc(epitope_pdb, mhc_pdb, epitope_chain="D"):
    """
    Combine epitope and MHC structures into a single PDB string.

    Chains A, B and C of the MHC file come first, followed by the ATOM records
    of the epitope file with their chain column overwritten by
    ``epitope_chain``. Every TER/END line of the epitope file is emitted as a
    ``TER`` record; all other epitope lines (e.g. HETATM) are dropped.

    Args:
        epitope_pdb (str): Path to epitope PDB.
        mhc_pdb (str): Path to MHC PDB.
        epitope_chain (str): New chain ID for epitope.

    Returns:
        str: Combined PDB string.

    Raises:
        RuntimeError: the MHC file could not be loaded (original error chained).
        FileNotFoundError: ``epitope_pdb`` does not exist.
    """
    try:
        mhc_str = pdb_to_string(mhc_pdb, chains=["A", "B", "C"])
    except Exception as e:
        # `from e` records the root cause explicitly on the new exception.
        raise RuntimeError(f"Failed to load MHC PDB {mhc_pdb}: {e}") from e

    if not os.path.exists(epitope_pdb):
        raise FileNotFoundError(f"Epitope PDB {epitope_pdb} not found")

    ter_record = f"TER   {epitope_chain}\n"
    records = []
    with open(epitope_pdb, "r") as f:
        for line in f:
            if line.startswith("ATOM"):
                # Column 22 (index 21) holds the chain identifier.
                records.append(line[:21] + epitope_chain + line[22:])
            elif line.startswith(("TER", "END")):
                records.append(ter_record)

    # Joined once instead of growing a string inside the loop.
    return mhc_str + "TER\n" + "".join(records) + "END\n"

def visualize_epitope_mhc(epitope_pdb, mhc_pdb, output_base, epitope_chain="D"):
    """
    Visualize the epitope-MHC complex and save as PNG using PyMOL.

    Steps: align the epitope (its chain "A") onto chain "C" of the MHC file,
    write the merged structure to ``<output_base>_combined.pdb``, render it
    with PyMOL to ``<output_base>.png`` and finally draw a pseudo-3D plot of
    the epitope to ``<output_base>_pseudo3D.png``.

    If the PyMOL step fails, the error is printed and the function returns
    without drawing the pseudo-3D plot.

    Args:
        epitope_pdb (str): Path to epitope PDB.
        mhc_pdb (str): Path to MHC PDB.
        output_base (str): Base path for output files.
        epitope_chain (str): Chain ID for the epitope.
    """
    # Align epitope
    aligned_epitope_pdb = align_epitope_to_mhc(epitope_pdb, mhc_pdb, epitope_chain="A", ref_chain="C")

    try:
        # Combine structures
        combined_pdb_str = combine_epitope_mhc(aligned_epitope_pdb, mhc_pdb, epitope_chain)
        combined_pdb_file = output_base + "_combined.pdb"
        with open(combined_pdb_file, "w") as f:
            f.write(combined_pdb_str)

        # Visualize with PyMOL
        try:
            from pymol import cmd
            cmd.reinitialize()
            cmd.load(combined_pdb_file, "complex")

            # Style MHC
            cmd.set("cartoon_color", "lightblue", "complex and chain A+B+C")
            cmd.set("cartoon_transparency", 0.2, "complex and chain A+B+C")

            # Style epitope
            cmd.set("cartoon_color", "red", f"complex and chain {epitope_chain}")
            cmd.show("sticks", f"complex and chain {epitope_chain}")

            # Orient on the epitope, then apply a fixed camera so that every
            # frame of a series is rendered from the same viewpoint.
            cmd.orient(f"complex and chain {epitope_chain}")
            cmd.set_view((0.5800932049751282, 0.019943242892622948, -0.81430584192276, 0.7175763845443726, 0.4605588912963867, 0.5224648714065552, 0.38545551896095276, -0.8874051570892334, 0.2528563439846039, 0.0, 0.0, -109.06678771972656, 13.777585983276367, 12.794111251831055, 66.28702545166016, 85.98908996582031, 132.1444854736328, -20.0))

            # Save
            png_path = output_base + ".png"
            cmd.png(png_path, width=800, height=600)

            # Clean up the PyMOL session
            cmd.delete("all")
        except Exception as e:
            print(f"Error visualizing with PyMOL: {e}")
            return
    finally:
        # The aligned copy is a temporary file. Removing it here (instead of
        # only after a successful render) prevents it from being left behind
        # when combining or rendering fails.
        if aligned_epitope_pdb != epitope_pdb and os.path.exists(aligned_epitope_pdb):
            os.remove(aligned_epitope_pdb)

    # Pseudo-3D plot
    try:
        xyz, bfact = get_ca(epitope_pdb, get_bfact=True)
        fig = plt.figure(figsize=(6, 6), dpi=100)
        ax = fig.add_subplot(111)
        ax.set_xticks([]); ax.set_yticks([])
        plot_pseudo_3D(xyz, c=bfact, cmin=0.5, cmax=0.9, ax=ax)
        pseudo_png_path = output_base + "_pseudo3D.png"
        plt.savefig(pseudo_png_path)
        plt.close()
    except Exception as e:
        print(f"Error creating pseudo-3D plot: {e}")

In [ ]:
# cmd.reinitialize()
# cmd.load("/clusterfs/nilah/sergio/RFdifussion/outputs/EPITOPE_TRIAL8/EPITOPE_TRIAL8/step_0_epitope_mhc_combined.pdb", "complex")

# # Focus and orient on epitope
# cmd.orient("complex and chain D")
# cmd.turn("x", 90)
# cmd.turn("y", 45)
# cmd.zoom("complex and chain D", 15)

# # Now pause here, look at it visually in PyMOL viewer
# # If you like it — store that view
# ref_view = cmd.get_view()

# png_path = 'TRIALVIEW' + ".png"
# cmd.png(png_path, width=800, height=600)
# print(f"Saved PNG to {png_path}")


In [ ]:
# Render one epitope/MHC image per diffusion step; steps without a dumped
# structure are reported and skipped.
for step in range(iterations):
    epitope_pdb = os.path.join(OUTPUT_DIR, name, f"{step}.pdb")
    output_base = os.path.join(OUTPUT_DIR, f"step_{step}_epitope_mhc")
    if not os.path.exists(epitope_pdb):
        print(f"Missing epitope: {epitope_pdb}")
        continue
    try:
        visualize_epitope_mhc(epitope_pdb, pdb_input_path, output_base, epitope_chain="D")
    except Exception as err:
        # A failing frame must not stop the remaining ones.
        print(f"Failed to visualize {epitope_pdb}: {err}")

In [ ]:
from PIL import Image
import os

# Directory and output
img_dir = f"{OUTPUT_DIR}"
output_gif = os.path.join(img_dir, "epitope_mhc_rotating.gif")

# Remove old GIF if it exists
if os.path.exists(output_gif):
    os.remove(output_gif)
    print(f"🗑️ Removed old GIF at {output_gif}")

# Load frames
# One frame per rendered step; steps without a PNG are skipped silently.
frames = []
for i in range(iterations):  # Adjust range as needed
    img_path = os.path.join(img_dir, f"step_{i}_epitope_mhc.png")
    if os.path.exists(img_path):
        # GIF is palette based: convert to RGB first, then to an adaptive palette.
        img = Image.open(img_path).convert("RGB").convert("P", palette=Image.ADAPTIVE)
        frames.append(img)

# Save as GIF with disposal=2 to fully replace previous frame
if frames:
    # Write to `output_gif`, the same path that is cleaned up above and
    # reported below. The previous hard-coded relative "GIFS/..." target
    # depended on the working directory and did not match either.
    frames[0].save(
        output_gif,
        format="GIF",
        save_all=True,
        append_images=frames[1:],
        duration=200,  # ms per frame
        loop=0,
        disposal=2  # <-- This ensures the previous frame is cleared
    )
    print(f"✅ Saved clean GIF to {output_gif}")
else:
    print("❌ No PNGs found.")

In [ ]:
# print(pdb_str[:500])  # See if the atoms look correct

#@title Display 3D structure {run: "auto"}
# Viewer settings read by plot_pdb below.
# animate: "none" shows the final design, "interactive" plays the trajectory
#          in py3Dmol, any other value takes the make_animation branch.
animate = "none" #@param ["none", "movie", "interactive"]
# color: "rainbow", "chain", or anything else for B-factor colouring.
color = "chain" #@param ["rainbow", "chain", "plddt"]
# denoise: True selects the pX0 trajectory file, False the Xt-1 one.
denoise = True
dpi = 100 #@param ["100", "200", "400"] {type:"raw"}
from colabdesign.shared.plot import pymol_color_list
from colabdesign.rf.utils import get_ca, get_Ls, make_animation
from string import ascii_uppercase,ascii_lowercase
# Chain letters in the order A-Z, a-z; paired with pymol_color_list in plot_pdb.
alphabet_list = list(ascii_uppercase+ascii_lowercase)

def plot_pdb(num=0):
  """Display design `num` of the current run.

  Behaviour is controlled by the module-level settings `animate`, `color`,
  `denoise` and `dpi`:

  * animate "none": the final design is shown in py3Dmol
  * animate "interactive": the trajectory is shown as animated py3Dmol frames
  * any other value: the trajectory is rendered through make_animation

  NOTE(review): the last branch calls `display` and `HTML`, which are not
  imported in the visible part of this notebook -- confirm they are in scope.
  """
  out_prefix = f"/clusterfs/nilah/sergio/RFdifussion/outputs/{path}/{path}_{num}"
  if denoise:
    pdb_traj = f"{out_prefix}_pX0_traj.pdb"
  else:
    pdb_traj = f"{out_prefix}_Xt-1_traj.pdb"
  if animate in ["none","interactive"]:
    hbondCutoff = 4.0
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')

    # Files are read inside `with` blocks so the handles are closed promptly.
    if animate == "interactive":
      with open(pdb_traj,'r') as handle:
        pdb_str = handle.read()
      view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    else:
      # A dedicated local name avoids shadowing the notebook-level `pdb`.
      with open(f"{out_prefix}.pdb",'r') as handle:
        pdb_str = handle.read()
      view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    if color == "rainbow":
      view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
      # One colour per contig, assigned to chains A, B, C, ... in order.
      for chain,c in zip(alphabet_list[:len(contigs)], pymol_color_list):
          view.setStyle({'chain':chain},{'cartoon': {'color':c}})
    else:
      view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}})
    view.zoomTo()
    if animate == "interactive":
      view.animate({'loop': 'backAndForth'})
    view.show()
  else:
    Ls = get_Ls(contigs)
    xyz, bfact = get_ca(pdb_traj, get_bfact=True)
    # Split the flat CA arrays into frames of sum(Ls) residues and reverse
    # the frame order.
    xyz = xyz.reshape((-1,sum(Ls),3))[::-1]
    bfact = bfact.reshape((-1,sum(Ls)))[::-1]
    if color == "chain":
      display(HTML(make_animation(xyz, Ls=Ls, dpi=dpi, ref=-1)))
    elif color == "rainbow":
      display(HTML(make_animation(xyz, dpi=dpi, ref=-1)))
    else:
      display(HTML(make_animation(xyz, plddt=bfact*100, dpi=dpi, ref=-1)))


# With several designs, a dropdown selects which one plot_pdb renders into a
# dedicated output area; with a single design it is plotted directly.
# NOTE(review): `widgets` and `display` are not imported in the visible part
# of this notebook -- confirm they are in scope before using num_designs > 1.
if num_designs > 1:
  output = widgets.Output()
  def on_change(change):
    # Only react to changes of the selected value, not to other trait events.
    if change['name'] == 'value':
      with output:
        output.clear_output(wait=True)
        plot_pdb(change['new'])
  dropdown = widgets.Dropdown(
      options=[(f'{k}',k) for k in range(num_designs)],
      value=0, description='design:',
  )
  dropdown.observe(on_change)
  display(widgets.VBox([dropdown, output]))
  # Initial render for the default selection.
  with output:
    plot_pdb(dropdown.value)
else:
  plot_pdb()

  

In [ ]:
from colabdesign.af.alphafold.model import data

def get_model_haiku_params(model_name, data_dir=".", fuse=True):
    """Report whether ``<data_dir>/<model_name>.npz`` exists (debug helper).

    The candidate path is always printed; a second message follows when the
    file is missing. Nothing is loaded: the function returns ``None`` in both
    cases, and ``fuse`` is accepted but not used.
    """
    npz_path = os.path.join(data_dir, model_name + ".npz")
    print("[DEBUG] Attempting to load: " + npz_path)
    if os.path.isfile(npz_path):
        # File present: there is no loading code yet.
        return None
    print("[ERROR] File not found: " + npz_path)
    return None


In [ ]:
contigs_str

In [ ]:
import os, time, subprocess

# Settings forwarded to colabdesign's designability_test.py.
# NOTE: this cell repeats an earlier designability cell line for line.
num_seqs = 2
initial_guess = False
num_recycles = 1
use_multimer = False
rm_aa = "C"
mpnn_sampling_temp = 0.1

# Wait for AlphaFold params
# NOTE: these two values are not used anywhere in this cell; no waiting
# actually takes place.
max_wait = 300
elapsed = 0

# Absolute path setup
# The trailing slash lets the file name be appended directly in --pdb below.
output_dir = f'{OUTPUT_DIR}/'
contigs_str = ":".join(contigs)

# `copies` and `num_designs` are taken from earlier cells.
opts = [
    f"--pdb={output_dir}{name}_0.pdb",
    f"--loc={output_dir}",
    f"--contig={contigs_str}",
    f"--copies={copies}",
    f"--num_seqs={num_seqs}",
    f"--num_recycles={num_recycles}",
    f"--rm_aa={rm_aa}",
    f"--mpnn_sampling_temp={mpnn_sampling_temp}",
    f"--num_designs={num_designs}",
]
# Boolean switches are passed only when enabled.
if initial_guess: opts.append("--initial_guess")
if use_multimer: opts.append("--use_multimer")
# opts.append("--num_models=1")


# Set AlphaFold weights path for ColabDesign
os.environ["COLABDESIGN_AF_DIR"] = "/clusterfs/nilah/sergio/RFdifussion/params"

# check=True raises CalledProcessError when the script exits non-zero.
subprocess.run(["python", "/clusterfs/nilah/sergio/miniconda3/envs/SE3nv/lib/python3.9/site-packages/colabdesign/rf/designability_test.py"] + opts, check=True)


In [ ]:
#@title Display best result
import py3Dmol
def plot_pdb(num = "best"):
  """Overlay two structures in a py3Dmol viewer for the chosen design.

  Model 0 is drawn as a plain cartoon; model 1 is a cartoon coloured by
  B-factor on a 0-100 ``roygb`` gradient.

  Args:
    num: Design to show. ``"best"`` (default) takes the design index from
      the REMARK header of ``{OUTPUT_DIR}/best.pdb``; any other value is
      used directly as the index in ``best_design{num}.pdb``.

  Raises:
    FileNotFoundError: if any of the PDB files is missing.
  """
  if num == "best":
    with open(f"{OUTPUT_DIR}/best.pdb","r") as f:
      # REMARK 001 design {m} N {n} RMSD {rmsd}
      info = f.readline().strip('\n').split()
    # Token 3 of the header is the design index {m}.
    num = info[3]
  hbondCutoff = 4.0
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')

  # NOTE(review): this first model does not depend on `num`; it is always
  # the file named after the last iteration -- confirm this is intended
  # when more than one design exists.
  with open(f"{OUTPUT_DIR}/{name}/{iterations - 1}.pdb",'r') as f:
    view.addModel(f.read(),'pdb',{'hbondCutoff':hbondCutoff})

  # Load the design that was actually selected. Previously this was
  # hard-coded to design 0, so the selection had no effect.
  with open(f"{OUTPUT_DIR}/best_design{num}.pdb",'r') as f:
    view.addModel(f.read(),'pdb',{'hbondCutoff':hbondCutoff})

  view.setStyle({"model":0},{'cartoon':{}})
  view.setStyle({"model":1},{'cartoon':{'colorscheme': {'prop':'b','gradient': 'roygb','min':0,'max':100}}})
  view.zoomTo()
  view.show()

# Offer a design picker when several designs were generated; with a single
# design, show the best result directly.
if num_designs <= 1:
  plot_pdb()
else:
  output = widgets.Output()

  def on_change(change):
    """Redraw the viewer for the newly selected dropdown entry."""
    # Skip notifications that are not a change of the selected value.
    if change['name'] != 'value':
      return
    with output:
      output.clear_output(wait=True)
      plot_pdb(change['new'])

  dropdown = widgets.Dropdown(
    description='design:',
    value="best",
    options=["best", *(str(k) for k in range(num_designs))],
  )
  dropdown.observe(on_change)
  display(widgets.VBox([dropdown, output]))

  # Initial render for the default selection.
  with output:
    plot_pdb(dropdown.value)

In [ ]:
import logomaker
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Load mpnn_results.csv from the output directory into a DataFrame.
# NOTE(review): this rebinds `data`, which an earlier cell imported as a
# module (`from colabdesign.af.alphafold.model import data`).
data = pd.read_csv(f'{OUTPUT_DIR}/mpnn_results.csv')

In [ ]:
# Display the loaded results table as this cell's output.
data

In [ ]:
# Store the second "/"-separated segment of each `seq` entry in a new
# `epitope` column (raises IndexError for an entry that contains no "/").
data['epitope'] = data['seq'].apply(lambda x: x.split("/")[1])

In [ ]:
# Filter top sequences by structural confidence
# filtered = data[
#     (data['plddt'] > 0.8) &
#     (data['i_ptm'] > 0.8) &
#     (data['i_pae'] < 5.0) &
#     (data['rmsd'] < 2.0)
# ].sort_values(by='plddt', ascending=False)

# print(filtered[['seq', 'plddt', 'ptm', 'pae', 'rmsd', 'mpnn']].head(10))


In [ ]:
# Display the table again, now including the `epitope` column if the cell
# that adds it has been run.
data

In [ ]:
# Build a per-position amino-acid frequency matrix from the epitope
# sequences and draw it as a sequence logo.
sequences = data['epitope'].tolist()

# Standard 20 amino acids, plus a lookup from residue letter to column.
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
aa_to_col = {aa: col for col, aa in enumerate(amino_acids)}

# Size the matrix from the data instead of assuming 9-mers, so peptides of
# other lengths do not index past the last row. Falls back to 9 rows when
# there are no sequences.
num_positions = max((len(seq) for seq in sequences), default=9)
freq_matrix = np.zeros((num_positions, len(amino_acids)))

# Count occurrences of each amino acid at each position.
for seq in sequences:
    for i, aa in enumerate(seq):
        if aa not in aa_to_col:
            raise ValueError(
                f"Unexpected residue {aa!r} at position {i} of sequence {seq!r}"
            )
        freq_matrix[i, aa_to_col[aa]] += 1

# Convert counts to fractions of all sequences; skipped for empty input to
# avoid a division by zero.
if sequences:
    freq_matrix = freq_matrix / len(sequences)

# Rows are positions, columns are amino acids.
freq_df = pd.DataFrame(freq_matrix, columns=amino_acids)

# Generate the sequence logo
logo = logomaker.Logo(freq_df, color_scheme='chemistry') #, font_name='Arial Rounded MT Bold')


**Instructions**
---
---

Use `contigs` to define continuous chains. Use a `:` to define multiple contigs and a `/` to define multiple segments within a contig.
For example:

**unconditional**
- `contigs='100'` - diffuse **monomer** of length 100
- `contigs='50:100'` - diffuse **hetero-oligomer** of lengths 50 and 100
- `contigs='50'` `symmetry='cyclic'` `order=2` - make two copies of the defined contig(s) and add a symmetry constraint, for **homo-oligomeric** diffusion.

**binder design**
- `contigs='A:50'` `pdb='4N5T'` - diffuse a **binder** of length 50 to chain A of defined PDB.
- `contigs='E6-155:70-100'` `pdb='5KQV'` `hotspot='E64,E88,E96'` - diffuse a **binder** of length 70 to 100 (sampled randomly) to chain E and defined hotspot(s).

**motif scaffolding**
 - `contigs='40/A163-181/40'` `pdb='5TPN'`
 - `contigs='A3-30/36/A33-68'` `pdb='6MRR'` - diffuse a loop of length 36 between two segments of defined PDB ranges.

**partial diffusion**
- `contigs=''` `pdb='6MRR'` - noise all coordinates
- `contigs='A1-10'` `pdb='6MRR'` - keep first 10 positions fixed, noise the rest
- `contigs='A'` `pdb='1SSC'` - fix chain A, noise the rest

*hints and tips*
- `pdb=''` leave blank to get an upload prompt
- `contigs='50-100'` use dash to specify a range of lengths to sample from

In [ ]:
# #@title run **RFdiffusion** to generate a backbone
# name = "EPITOPE_1_16042025" #@param {type:"string"}
# contigs = "9" #@param {type:"string"}
# pdb = "7rtd.pdb" #@param {type:"string"}
# iterations = 200 #@param ["25", "50", "100", "150", "200"] {type:"raw"}
# hotspot = "A7,A9,A59,A63,A66,A70,A99,A159,A167" #@param {type:"string"}
# num_designs = 1 #@param ["1", "2", "4", "8", "16", "32"] {type:"raw"}
# visual = "image" #@param ["none", "image", "interactive"]

In [ ]:
import itertools
import os
from datetime import datetime


# Generate one SLURM submission script per (replicate, contig, iterations,
# num_seqs) combination. Scripts are only written here, not submitted.
name = '_RUN_'
# Define paths
script_path = "/global/scratch/users/sergiomar10/ESMCBA/ESMCBA/ESMCBA/run_RFDiffusionMHC_epitope.py"
sh_file_dir = "/clusterfs/nilah/sergio/RFdifussion/slurm_jobs/EPITOPE_SERIES/"
log_dir = "/clusterfs/nilah/sergio/RFdifussion/logs/EPITOPE_SERIES/"
os.makedirs(sh_file_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# Hyperparameter ranges
contigs = ["0/8", "0/9", "0/10", "0/11"]  # Peptide lengths
iterations = [50]  # RFdiffusion iterations
num_seqs = [32]   # AlphaFold sequences
num_replicates = 1000  # Repeats (v0..v999) of every combination

# Generate all combinations
combinations = list(itertools.product(contigs, iterations, num_seqs))
# Every combination is written once per replicate.
print(f"Total jobs: {len(combinations) * num_replicates}")

# Resource request shared by every job.
# NOTE(review): a GPU partition is requested without a --gres line; confirm
# whether the cluster requires one.
sbatch_resources = [
    "#SBATCH --account=co_nilah",
    "#SBATCH --partition=savio3_gpu",
    "#SBATCH --qos=savio_lowprio",
    "#SBATCH --cpus-per-task=4",
    "#SBATCH --mem=16G",
    "#SBATCH --time=00:50:00",
]

for n in range(num_replicates):
    for contig, iteration, num_seq in combinations:
        # Label the job with the whole segment after the last "/", so
        # "0/10" gives "10". Taking contig[-1] gave "0" for "0/10" and "1"
        # for "0/11".
        contig_name = contig.split("/")[-1]
        file_name = f"RFD_{name}v{n}_C{contig_name}_I{iteration}_NS{num_seq}"
        sh_filename = f"{file_name}.sh"
        sh_filepath = os.path.join(sh_file_dir, sh_filename)

        # Construct the command
        cmd = (
            f"python {script_path} "
            f"--name {file_name} "
            f"--contigs {contig} "
            f"--iterations {iteration} "
            f"--num_seqs {num_seq}"
        )

        script_lines = [
            "#!/bin/bash",
            *sbatch_resources,
            f"#SBATCH --job-name={file_name}",
            # %j is left for SLURM to fill in.
            f"#SBATCH --output={log_dir}/{file_name}_%j.out",
            f"#SBATCH --error={log_dir}/{file_name}_%j.err",
            "source /clusterfs/nilah/sergio/miniconda3/etc/profile.d/conda.sh",
            "conda activate SE3nv",
            cmd,
        ]

        # Write the SLURM script in a single call.
        with open(sh_filepath, "w") as sh_file:
            sh_file.write("\n".join(script_lines) + "\n")

        # Make executable
        os.chmod(sh_filepath, 0o755)
        print(f"Created shell script: {sh_filename}")
        

In [ ]:
# Display the last character of `contig` (presumably the loop variable left
# over from the job-generation cell). Note this is a single character, so
# "0/10" yields "0", not "10".
contig[-1]