<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion_foldcond.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**RFdiffusion** - conditional fold generation
---

**<font color="red">NOTE</font>** This notebook is in development; we are still working on adding all the options from the [manuscript](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v2).

**instructions**:
 - define the number of secondary structure `elements` (SSE)
 - the diagonal defines the type of each SSE
  - `H:alpha_helix E:beta_sheet ?:undefined`
 - off-diagonal cells define the interactions between SSEs
  - `0:no_contact 1:contact ?:undefined`
 - the number to the right of each row defines the minimum length of that SSE
 - `loop_length` defines the maximum loop length between SSEs
 - `mask_loops=True` treats loops as masks, allowing the secondary structure to extend into them.


In [None]:
#@title setup **RFdiffusion** {run: "auto"}
#@markdown Note, **RFdiffusion** takes ~1min to setup, next time you run this cell it will take seconds!
name = "test"
elements = 3 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"] {type:"raw"}

import os, time

# Set when this run of the cell launches the background checkpoint download, so
# the model-install step below knows it may wait on the params/done.txt sentinel
# (which is only ever written by that background job).
download_started = False
if not os.path.isdir("params"):
  # -y: without a TTY, apt-get aborts at its "Do you want to continue?" prompt
  # whenever additional dependency packages have to be installed.
  os.system("apt-get install -y aria2")
  os.makedirs("params", exist_ok=True)
  # send param download into background; params/done.txt is touched once aria2c exits
  os.system("(\
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt; \
  touch params/done.txt \
  ) &")
  download_started = True

if not os.path.isdir("RFdiffusion"):
  print("installing RFdiffusion...")
  os.system("git clone https://github.com/sokrypton/RFdiffusion.git")
  os.system("pip -q install jedi omegaconf hydra-core icecream py3Dmol")
  os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu117/repo.html")
  os.system("cd RFdiffusion/env/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .")

if not os.path.isdir("RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  os.system("mkdir RFdiffusion/models")
  models = ["Complex_Fold_base_ckpt.pt"]
  if download_started:
    # The aria2c control file may not exist yet right after the background job
    # was launched, so polling it alone can race past a download that has not
    # started. The sentinel is written only after aria2c has exited.
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)
  for m in models:
    # aria2c keeps a "<file>.aria2" control file next to an unfinished download
    while os.path.isfile(f"{m}.aria2"):
      time.sleep(5)
  os.system(f"mv {' '.join(models)} RFdiffusion/models")
  print("----------------------------------")

from IPython.display import display
import ipywidgets as widgets
import numpy as np
import torch
import sys, os, random, string, re, time
import subprocess
import matplotlib.pyplot as plt
import py3Dmol
from google.colab import files

# Make the cloned RFdiffusion repo importable and select DGL's PyTorch backend;
# both are done only once per kernel session.
if 'RFdiffusion' not in sys.path:
  sys.path.append('RFdiffusion')
  os.environ["DGLBACKEND"] = "pytorch"


class RFdiff_gui():
  """Interactive ipywidgets front-end for fold-conditioned RFdiffusion runs.

  The user edits an `elements` x `elements` grid of buttons. Diagonal buttons
  cycle each element's secondary-structure type (H / E / ?). Off-diagonal
  buttons cycle the pairwise contact state (0 / 1 / ?). An extra column holds
  each element's length. `diffuse()` converts the grid into per-residue
  secondary-structure and block-adjacency tensors, runs
  RFdiffusion/run_inference.py on them and shows the result with py3Dmol.

  Args:
    elements: number of secondary-structure elements (grid size).
    name: prefix for output files written under outputs/.
  """
  def __init__(self, elements, name="test"):
    self.elements = elements
    # self.path is the active output prefix; diffuse() may switch it to a
    # randomized variant of self.name so earlier results are not overwritten.
    self.path = self.name = name
    os.makedirs(f"outputs/{self.path}", exist_ok=True)

    self.input = widgets.Output()
    self.output = widgets.Output()
    self.adj = []  # adj[i][j]: grid buttons (diagonal = SSE type, off-diagonal = contact)
    self.len = []  # len[i]: length widget of element i
    grid = []
    for i in range(self.elements):
      row = []
      for j in range(self.elements):
        button = widgets.Button(description='H' if i == j else '0', layout=widgets.Layout(width='35px', height='35px', border='2px solid black'))
        button.style.button_color = 'red' if i == j else 'white'
        button.row, button.col = i, j
        button.on_click(self._on_click)
        row.append(button)
        grid.append(button)
      self.adj.append(row)
      button = widgets.BoundedIntText(value=19, min=0, max=100, layout=widgets.Layout(width='50px'))
      button.row, button.col = i, self.elements
      grid.append(button)
      self.len.append(button)

    grid = widgets.GridBox(grid, layout=widgets.Layout(grid_template_columns=f"repeat({self.elements + 1}, 37px)",
                                                       grid_template_rows=f"repeat({self.elements}, 37px)",
                                                       grid_gap="2px"))
    button_style = widgets.Layout(width='84px', height='35px', border='2px solid black')
    self.buttons = {
        "grid": grid,
        "loop_length": widgets.BoundedIntText(description='loop_length', value=5, min=0, max=20),
        "iterations":  widgets.Dropdown(description='iterations', options=[25, 50, 100, 200],value=50),
        "mask_loops":  widgets.Checkbox(description='mask_loops', value=True),
        "reset":       widgets.Button(description='reset',    layout=button_style),
        "diffuse":     widgets.Button(description='diffuse',  layout=button_style),
        "animate":     widgets.Button(description='animate',  layout=button_style),
        "freeze":      widgets.Button(description='freeze',   layout=button_style),
        "download":    widgets.Button(description='download', layout=button_style)
    }
    self.buttons["animate"].on_click(self._plot_pdb)
    self.buttons["freeze"].on_click(self._plot_pdb)
    self.buttons["download"].on_click(self._download)
    self.buttons["reset"].on_click(self._reset)
    with self.input:
      display(
          widgets.VBox([
          self.buttons["grid"],
          widgets.Label("Options"),
          self.buttons["iterations"],
          self.buttons["loop_length"],
          self.buttons["mask_loops"],
          self.buttons["reset"]
        ])
      )

  def _on_click(self, button):
    """Cycle the state of a clicked grid button.

    Diagonal buttons cycle H -> E -> ? -> H and set the element's length
    (19 for H, 5 for E and ?). Switching an element to "?" disables its
    contact buttons and clears them to "0". Switching back to "H" re-enables
    the contacts with partners that are not "?" themselves. Off-diagonal
    buttons cycle 0 -> 1 -> ? -> 0 and are mirrored to keep the grid symmetric.
    """
    i, j = button.row, button.col
    if i == j:
      if button.description == "H":
        button.description = 'E'
        button.style.button_color = 'yellow'
        self.len[i].value = 5
      elif button.description == "E":
        button.description = '?'
        button.style.button_color = 'lightgray'
        self.len[i].value = 5
        # use this grid's size, not the notebook-global `elements`
        for k in range(self.elements):
          if i != k:
            for c in (self.adj[i][k], self.adj[k][i]):
              c.disabled = True
              c.style.button_color = 'lightgray'
              c.description = "0"
      else:
        button.description = 'H'
        button.style.button_color = 'red'
        self.len[i].value = 19
        for k in range(self.elements):
          if i != k and self.adj[k][k].description != "?":
            for c in (self.adj[i][k], self.adj[k][i]):
              c.disabled = False
              c.style.button_color = 'white'
    else:
      if button.description == "0":
        button.description = '1'
        button.style.button_color = 'lightblue'
      elif button.description == '1':
        button.description = '?'
        button.style.button_color = 'lightgray'
      else:
        button.description = '0'
        button.style.button_color = 'white'
      sym_button = self.adj[j][i]
      sym_button.style.button_color = button.style.button_color
      sym_button.description = button.description

  def _reset(self, button):
    """Restore the grid and option widgets to their initial values."""
    for i in range(self.elements):
      for j in range(self.elements):
        if i == j:
          self.adj[i][j].style.button_color = 'red'
          self.adj[i][j].description = 'H'
        else:
          self.adj[i][j].style.button_color = 'white'
          self.adj[i][j].description = '0'
          self.adj[i][j].disabled = False
      self.len[i].value = 19
    self.buttons["loop_length"].value = 5
    self.buttons["iterations"].value = 50
    self.buttons["mask_loops"].value = True

  def _download(self, button):
    """Zip the current run's outputs and trajectories and send them to the browser."""
    os.system(f"zip -r {self.path}.result.zip outputs/{self.path}* outputs/traj/{self.path}*")
    files.download(f"{self.path}.result.zip")

  def _plot_pdb(self, button):
    """Render the current design in self.output.

    The "animate" button shows the pX0 trajectory as animated frames. Any other
    button (the "freeze" button) shows the final model. A toggle button and the
    download button are displayed underneath.
    """
    animate = button.description == "animate"
    with self.output:
      self.output.clear_output(wait=True)
      view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
      if animate:
        with open(f"outputs/traj/{self.path}_0_pX0_traj.pdb", 'r') as handle:
          view.addModelsAsFrames(handle.read(), 'pdb')
      else:
        with open(f"outputs/{self.path}_0.pdb", 'r') as handle:
          view.addModel(handle.read(), 'pdb')
      view.setStyle({"ss":"h"},{'cartoon': {'color':'red'}})
      view.setStyle({"ss":"c"},{'cartoon': {'color':'lime'}})
      view.setStyle({"ss":"s"},{'cartoon': {'color':'yellow'}})
      view.zoomTo()
      if animate:
        view.animate({'loop': 'backAndForth'})
      out = widgets.Output()
      with out: view.show()
      toggle = self.buttons["freeze"] if animate else self.buttons["animate"]
      display(widgets.VBox([out, widgets.HBox([toggle, self.buttons["download"]])]))

  def diffuse(self):
    """Build the conditioning tensors from the grid, run RFdiffusion and show the design.

    Writes outputs/<path>/tmp_ss.pt and tmp_adj.pt and runs run_inference.py
    with scaffold guidance pointed at that directory. The final model is
    displayed only if the run succeeded and wrote outputs/<path>_0.pdb.
    """
    # get unique path so an existing outputs/<path>_0.pdb is not overwritten
    while os.path.exists(f"outputs/{self.path}_0.pdb"):
      self.path = self.name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
      os.makedirs(f"outputs/{self.path}", exist_ok=True)

    # parse results from form
    loop = self.buttons["loop_length"].value
    sse_L = [x.value for x in self.len]
    # residue layout: loop, SSE_0, loop, SSE_1, ..., SSE_{n-1}, loop
    L = (self.elements + 1) * loop + sum(sse_L)
    adj = np.zeros((L,L))
    sse = np.full((L,),2)  # 2 wherever no H (0) / E (1) element is placed
    n = loop
    for i in range(self.elements):
      ss = {"H":0,"E":1,"C":2,"?":2}[self.adj[i][i].description]
      sse[n:n+sse_L[i]] = ss
      m = loop
      for j in range(self.elements):
        if i != j:
          # contact state 0 / 1 / ? of pair (i, j) -> 0 / 1 / 2 over their residue block
          val = {"0":0,"1":1,"?":2}[self.adj[i][j].description]
          adj[n:n+sse_L[i],m:m+sse_L[j]] = val
        m += sse_L[j] + loop
      n += sse_L[i] + loop

    # save results
    torch.save(torch.from_numpy(sse).float(),f"outputs/{self.path}/tmp_ss.pt")
    torch.save(torch.from_numpy(adj).float(),f"outputs/{self.path}/tmp_adj.pt")

    # run
    with self.output:
      self.output.clear_output()
      iterations = self.buttons["iterations"].value
      mask_loops = self.buttons["mask_loops"].value
      cmd = ["./RFdiffusion/run_inference.py",
            "inference.num_designs=1",
            f"inference.output_prefix=outputs/{self.path}",
            "scaffoldguided.scaffoldguided=True",
            "scaffoldguided.target_pdb=False",
            f"scaffoldguided.mask_loops={mask_loops}",
            f"scaffoldguided.scaffold_dir=outputs/{self.path}",
            f"diffuser.T={iterations}"]
      self.cmd_str = " ".join(cmd)
      steps = iterations - 1
      return_code = self._run(self.cmd_str, "Timestep", steps)

    pdb = f"outputs/{self.path}_0.pdb"
    if return_code == 0 and os.path.isfile(pdb):
      self._plot_pdb(self.buttons["freeze"])
    else:
      # plotting would only raise FileNotFoundError; report the failure instead
      with self.output:
        print(f"no design was written to {pdb} (exit code {return_code})")

  def _run(self, command, trigger, total_timesteps):
    """Run `command` in a shell and advance a progress bar on each `trigger` line.

    Args:
      command: shell command string.
      trigger: regex fragment marking one completed step in the command output.
      total_timesteps: expected number of trigger lines (progress-bar scale).

    Returns:
      The process exit code. On failure the last output lines are printed.
    """
    from collections import deque
    progress = widgets.FloatProgress(min=0, max=1, description='running', bar_style='info')
    display(progress)
    pattern = re.compile(f'.*{trigger}.*')
    total = max(total_timesteps, 1)  # avoid division by zero for tiny runs
    progress_counter = 0
    tail = deque(maxlen=50)  # recent output, shown if the run fails
    # stderr is merged into stdout: a separate stderr pipe that is never read
    # can fill up and block the child process indefinitely.
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, text=True)
    for line in process.stdout:
      tail.append(line)
      if pattern.match(line):
        progress_counter += 1
        progress.value = min(progress_counter / total, 1.0)
    return_code = process.wait()
    if return_code == 0:
      progress.description = "done"
    else:
      progress.description = "failed"
      progress.bar_style = "danger"
      print("".join(tail), end="")
    return return_code

# (Re)build the GUI only when none exists yet or when the requested grid size or
# output name changed; otherwise re-running this cell keeps the current grid state.
if "rfdiff" not in dir() or elements != rfdiff.elements or name != rfdiff.name:
  rfdiff = RFdiff_gui(elements, name=name)
display(rfdiff.input)

In [None]:
#@title run **RFdiffusion**
# The GUI object comes from the setup cell above. Without it there is nothing
# to run, so only point the user back to that cell.
if "rfdiff" not in dir():
  print("Error, looks like you didn't run the cell above")
else:
  display(rfdiff.output)
  rfdiff.diffuse()