<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion_foldcond.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title setup **RFdiffusion** (~1min)
%%time
import os, time
# Start the checkpoint download in the background so it overlaps with the
# RFdiffusion clone/pip installs below. The subshell touches params/done.txt
# once aria2c exits (whether or not the transfer succeeded).
download_started = False
if not os.path.isdir("params"):
  # -y: never stall waiting for an interactive confirmation prompt.
  os.system("apt-get install -y aria2")
  os.system("mkdir params")
  # send param download into background
  os.system("(\
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt; \
  touch params/done.txt \
  ) &")
  download_started = True

if not os.path.isdir("RFdiffusion"):
  print("installing RFdiffusion...")
  os.system("git clone https://github.com/sokrypton/RFdiffusion.git")
  os.system("pip -q install jedi omegaconf hydra-core icecream py3Dmol")
  os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu117/repo.html")
  os.system("cd RFdiffusion/env/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .")

if not os.path.isdir("RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  os.makedirs("RFdiffusion/models", exist_ok=True)
  models = ["Complex_Fold_base_ckpt.pt"]
  if download_started:
    # aria2c only creates its "<file>.aria2" control file once the transfer
    # is under way, so polling for that file alone can race and move a
    # missing/partial checkpoint. The done.txt marker is written after
    # aria2c has exited.
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)
  for m in models:
    # Still honour an in-flight download started by an earlier execution.
    while os.path.isfile(f"{m}.aria2"):
      time.sleep(5)
  os.system(f"mv {' '.join(models)} RFdiffusion/models")

In [None]:
#@title create **RFdiffusion** GUI {run: "auto"}
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import GridBox, Button, BoundedIntText, Output, Label, Layout, HBox, VBox
import numpy as np
import torch
import sys, os, random, string
import matplotlib.pyplot as plt

import py3Dmol
from google.colab import files

# Make the cloned RFdiffusion repo importable and select DGL's PyTorch
# backend; the sys.path check keeps this from repeating on cell re-runs.
if 'RFdiffusion' not in sys.path:
  os.environ["DGLBACKEND"] = "pytorch"
  sys.path.append('RFdiffusion')

# Base name for outputs; diffuse() switches to name + "_<random suffix>"
# when outputs/<path>_0.pdb already exists.
name = "test"
elements = 3 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"] {type:"raw"}
#@markdown - `elements` number of secondary structure elements (SSE)
#@markdown - diagonal defines the SS (`H`=Alpha-Helix, `E`=Beta-Sheet, `?`=Undefined)
#@markdown - off-diagonal defines which SSE should be in contact
#@markdown - the number to right defines the minimum length of each SSE
#@markdown - `loop_length` defines the max loop length between SSEs

# `path` is the current output prefix (outputs/<path>...). A `global`
# statement at module scope is a no-op; functions that rebind `path`
# (diffuse) declare it global themselves.
global path
path = name
# Scaffold directory that receives tmp_ss.pt / tmp_adj.pt for this prefix.
os.makedirs(f"outputs/{path}", exist_ok=True)

import os
import re
import sys
import time
import subprocess
from IPython.display import display
from ipywidgets import FloatProgress

def run_command_and_monitor_progress(command, trigger, total_timesteps):
  """Run a shell command, advancing a progress bar on each matching output line.

  Args:
    command: shell command string (executed with shell=True).
    trigger: literal text; each output line containing it counts as one step.
    total_timesteps: number of trigger lines that corresponds to a full bar.

  Returns:
    The command's exit code (0 on success).
  """
  from collections import deque

  progress = FloatProgress(min=0, max=1, description='running', bar_style='info')
  display(progress)

  # Guard against a zero/negative total so the progress update cannot divide by zero.
  total = max(total_timesteps, 1)
  progress_counter = 0
  # Keep only the tail of the output, which is enough to report a failure.
  tail = deque(maxlen=200)

  # stderr is merged into stdout. With two separate PIPEs and only stdout being
  # read, a chatty stderr can fill its OS pipe buffer and deadlock the child.
  with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                        shell=True, text=True) as process:
    for line in process.stdout:
      tail.append(line)
      if trigger in line:
        progress_counter += 1
        progress.value = min(progress_counter / total, 1.0)
    return_code = process.wait()

  if return_code != 0:
    print('Error:', ''.join(tail))
    progress.bar_style = 'danger'
    progress.description = "failed"
  else:
    progress.description = "done"
  return return_code

def make_buttons(elements, mask_loops=1):
  """Build the interactive SSE-topology editor widgets.

  The grid has one row per secondary-structure element (SSE):
    * diagonal buttons cycle the SSE type H -> E -> ? -> H;
    * off-diagonal buttons toggle a 0/1 contact between two SSEs, kept symmetric;
    * a trailing BoundedIntText per row holds that SSE's length.

  Args:
    elements: number of SSEs (the grid is elements x elements plus one column).
    mask_loops: initial state of the `mask_loops` checkbox; `reset` restores it.

  Returns:
    dict with keys "rows", "grid", "loop_length", "iterations", "mask_loops",
    "reset" and "diffuse". The same dict is also bound to the global `buttons`,
    which the click handlers read.
  """
  global buttons
  buttons = {"rows": []}

  def on_click(button):
    i, j = button.row, button.col
    if i == j:
      length = buttons["rows"][i][-1]
      if button.description == "H":
        button.description = 'E'
        button.style.button_color = 'yellow'
        length.value = 5
      elif button.description == "E":
        # An undefined ('?') SSE takes no part in contacts: clear and lock
        # every contact in its row and column.
        button.description = '?'
        button.style.button_color = 'lightgray'
        length.value = 5
        for k in range(elements):
          if i != k:
            for c in (buttons["rows"][i][k], buttons["rows"][k][i]):
              c.disabled = True
              c.style.button_color = 'lightgray'
              c.description = "0"
      else:
        # '?' -> 'H': re-enable contacts with every SSE that is not '?'.
        button.description = 'H'
        button.style.button_color = 'red'
        length.value = 19
        for k in range(elements):
          if i != k and buttons["rows"][k][k].description != "?":
            for c in (buttons["rows"][i][k], buttons["rows"][k][i]):
              c.disabled = False
              c.style.button_color = 'white'
    else:
      # Toggle the contact and mirror it onto (j, i) so the map stays symmetric.
      button.description = '0' if button.description == '1' else '1'
      button.style.button_color = 'lightblue' if button.style.button_color == 'white' else 'white'
      symmetric_button = buttons["rows"][j][i]
      symmetric_button.style.button_color = button.style.button_color
      symmetric_button.description = button.description

  for i in range(elements):
    row = []
    for j in range(elements):
      button = Button(description='H' if i == j else '0',
                      layout=widgets.Layout(width='35px', height='35px', border='2px solid black'))
      button.row, button.col = i, j
      button.style.button_color = 'red' if i == j else 'white'
      button.on_click(on_click)
      row.append(button)
    # Trailing column: length of SSE i.
    length = widgets.BoundedIntText(value=19, min=0, max=100,
                                    layout=widgets.Layout(width='50px'))
    length.row, length.col = i, elements
    row.append(length)
    buttons["rows"].append(row)

  buttons["grid"] = GridBox([btn for row in buttons["rows"] for btn in row],
                            layout=widgets.Layout(grid_template_columns=f"repeat({elements+1}, 37px)",
                                                  grid_template_rows=f"repeat({elements}, 37px)",
                                                  grid_gap="2px"))

  buttons["loop_length"] = widgets.BoundedIntText(description='loop_length', value=5, min=0, max=20)
  buttons["iterations"] = widgets.Dropdown(description='iterations', options=[25, 50, 100, 200], value=50)
  # Checkbox.value is a Bool trait; coerce the (historically int) argument.
  buttons["mask_loops"] = widgets.Checkbox(description='mask_loops', value=bool(mask_loops))

  def reset(button):
    # Restore every control to its initial state.
    for i in range(elements):
      for j in range(elements):
        b = buttons["rows"][i][j]
        if i == j:
          b.style.button_color = 'red'
          b.description = 'H'
        else:
          b.style.button_color = 'white'
          b.description = '0'
          b.disabled = False
      buttons["rows"][i][-1].value = 19
    buttons["loop_length"].value = 5
    buttons["iterations"].value = 50
    # Reset to the caller-supplied default rather than a hard-coded value.
    buttons["mask_loops"].value = bool(mask_loops)

  buttons["reset"] = Button(description='reset', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
  buttons["reset"].on_click(reset)

  buttons["diffuse"] = Button(description='diffuse', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
  buttons["diffuse"].on_click(diffuse)
  return buttons

def download(button=None):
  """Zip the current design's files and trigger a browser download.

  The archive <path>.result.zip contains the model files
  (outputs/<path>_0.*), the scaffold inputs directory (outputs/<path>/) and
  the trajectories (outputs/traj/<path>_0_*).

  Args:
    button: the clicked widget, if invoked from a button (unused).
  """
  # Anchor the globs on "<path>_0." / "<path>/" / "<path>_0_". Later runs are
  # named f"{name}_<suffix>", so for the initial path (== name) a bare
  # "<path>*" glob would also sweep every later design into this archive.
  os.system(f"zip -r {path}.result.zip outputs/{path}_0.* outputs/{path} outputs/traj/{path}_0_*")
  files.download(f"{path}.result.zip")

def animate_pdb(button=None):
  """Show the current design's pX0 trajectory as a looping py3Dmol animation.

  Replaces the contents of the shared `output` widget with the viewer plus
  "freeze" (switches to plot_pdb) and "download" buttons.

  Args:
    button: the clicked widget, if invoked from a button (unused).
  """
  with output:
    output.clear_output(wait=True)
    ############################
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    # Relative path, matching where diffuse() writes its outputs.
    pdb = f"outputs/traj/{path}_0_pX0_traj.pdb"
    with open(pdb, 'r') as handle:
      pdb_str = handle.read()
    view.addModelsAsFrames(pdb_str, 'pdb')
    # Cartoon coloured by secondary-structure selector: h red, c lime, s yellow.
    for ss, color in (("h", "red"), ("c", "lime"), ("s", "yellow")):
      view.setStyle({"ss": ss}, {'cartoon': {'color': color}})
    view.zoomTo()
    view.animate({'loop': 'backAndForth'})
    out = Output()
    with out:
      view.show()
    ############################
    plot_pdb_button = Button(description='freeze', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
    plot_pdb_button.on_click(plot_pdb)
    download_button = Button(description='download', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
    download_button.on_click(download)
    display(VBox([out, HBox([plot_pdb_button, download_button])]))

def plot_pdb(button=None):
  """Show the current design's final structure in a static py3Dmol view.

  Replaces the contents of the shared `output` widget with the viewer plus
  "animate" (switches to animate_pdb) and "download" buttons.

  Args:
    button: the clicked widget, if invoked from a button (unused).
  """
  with output:
    output.clear_output(wait=True)
    ############################
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    # Relative path, matching where diffuse() writes its outputs.
    pdb = f"outputs/{path}_0.pdb"
    with open(pdb, 'r') as handle:
      pdb_str = handle.read()
    view.addModel(pdb_str, 'pdb')
    # Cartoon coloured by secondary-structure selector: h red, c lime, s yellow.
    for ss, color in (("h", "red"), ("c", "lime"), ("s", "yellow")):
      view.setStyle({"ss": ss}, {'cartoon': {'color': color}})
    view.zoomTo()
    out = Output()
    with out:
      view.show()
    ############################
    animate_pdb_button = Button(description='animate', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
    animate_pdb_button.on_click(animate_pdb)
    download_button = Button(description='download', layout=widgets.Layout(width='84px', height='35px', border='2px solid black'))
    download_button.on_click(download)
    display(VBox([out, HBox([animate_pdb_button, download_button])]))

def diffuse(button):
  """Run RFdiffusion fold-conditioned design for the topology drawn in the GUI.

  Steps:
    1. Picks an output prefix that does not overwrite an earlier design.
    2. Converts the grid into a per-residue secondary-structure vector
       (tmp_ss.pt) and an L x L block-contact matrix (tmp_adj.pt) under
       outputs/<path>/.
    3. Runs ./RFdiffusion/run_inference.py with that directory as the
       scaffold.
    4. Shows the resulting structure if one was written.

  Args:
    button: the clicked widget (unused; required by ipywidgets on_click).
  """
  global path
  # Never overwrite a previous design: pick a fresh random suffix if needed.
  while os.path.exists(f"outputs/{path}_0.pdb"):
    path = name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
  # Scaffold directory must exist before the torch.save calls below.
  os.makedirs(f"outputs/{path}", exist_ok=True)

  ########################
  # Residue layout: loop, SSE_0, loop, SSE_1, ..., SSE_{n-1}, loop.
  loop = buttons["loop_length"].value
  sse_L = [row[-1].value for row in buttons["rows"]]
  n_sse = len(sse_L)
  L = (n_sse + 1) * loop + sum(sse_L)
  adj = np.zeros((L, L))
  # Per-residue SS code: 0 for H, 1 for E, 2 for loops and '?'.
  sse = np.full((L,), 2)
  ss_code = {"H": 0, "E": 1, "C": 2, "?": 2}
  n = loop
  for i in range(n_sse):
    sse[n:n+sse_L[i]] = ss_code[buttons["rows"][i][i].description]
    m = loop
    for j in range(n_sse):
      if i != j:
        # Fill the whole (SSE_i, SSE_j) residue block with the 0/1 contact flag.
        adj[n:n+sse_L[i], m:m+sse_L[j]] = int(buttons["rows"][i][j].description)
      m += sse_L[j] + loop
    n += sse_L[i] + loop
  ########################
  torch.save(torch.from_numpy(sse).float(), f"outputs/{path}/tmp_ss.pt")
  torch.save(torch.from_numpy(adj).float(), f"outputs/{path}/tmp_adj.pt")
  with output:
    output.clear_output()
    iterations = buttons["iterations"].value
    # Read the checkbox (previously read from the iterations dropdown,
    # which made mask_loops always True).
    mask_loops = bool(buttons["mask_loops"].value)
    cmd = ["./RFdiffusion/run_inference.py",
           "inference.num_designs=1",
           f"inference.output_prefix=outputs/{path}",
           "scaffoldguided.scaffoldguided=True",
           "scaffoldguided.target_pdb=False",
           f"scaffoldguided.mask_loops={mask_loops}",
           f"scaffoldguided.scaffold_dir=outputs/{path}",
           f"diffuser.T={iterations}"]
    cmd_str = " ".join(cmd)
    # NOTE(review): assumes RFdiffusion logs about T-1 "Timestep" lines per design.
    steps = iterations - 1
    run_command_and_monitor_progress(cmd_str, "Timestep", steps)
  # Only render when inference actually produced a structure; otherwise keep
  # the error output visible instead of raising FileNotFoundError.
  if os.path.exists(f"outputs/{path}_0.pdb"):
    plot_pdb()
  else:
    with output:
      print(f"No structure found at outputs/{path}_0.pdb")

# Shared output area that diffuse / plot_pdb / animate_pdb render into.
output = Output()
buttons = make_buttons(elements)

# Stack: topology grid, an "Options" header, the option widgets, then actions.
option_widgets = [buttons[key] for key in ("iterations", "loop_length", "mask_loops")]
action_row = HBox([buttons["diffuse"], buttons["reset"]])
panel = VBox([buttons["grid"], widgets.Label("Options"), *option_widgets, action_row])
display(panel)
display(output)