<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion_foldcond.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**RFdiffusion** - conditional fold generation
---

**<font color="red">NOTE</font>** This notebook is in development; we are still working on adding all of the options from the [manuscript](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v2).

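**instructions**:
 - set the number of secondary structure elements (SSEs) with `elements`
 - set the buffer length (`buff_length`) inserted between consecutive SSEs (the same buffer is also added before the first and after the last SSE)
 - use the diagonal to define the type of each SSE (click a cell to cycle H → E → C → ?)
  - `H:helix E:strand (sheet) C:coil ?:undefined`
 - use the off-diagonal cells to define contacts between pairs of SSEs (click to cycle 0 → 1 → ?; these cells are disabled when either SSE is a coil or undefined)
  - `0:no_contact 1:contact ?:undefined`
 - use the text box at the end of each row to set the length (in residues) of that SSE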


In [None]:
#@title setup **RFdiffusion** {run: "auto"}
#@markdown Note, **RFdiffusion** takes ~1min to setup, next time you run this cell it will take seconds!
name = "test"
elements = 7 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"] {type:"raw"}

import os, time, sys

# One-time environment setup. Each step is skipped when its target directory
# already exists, so re-running this cell is cheap.

if not os.path.isdir("params"):
  # -y: never stop at apt-get's interactive confirmation prompt
  os.system("apt-get install -y aria2")
  os.makedirs("params", exist_ok=True)
  # send param download into background; params/done.txt is touched after
  # aria2c exits (the `;` means this happens whether or not it succeeded)
  os.system("(\
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt; \
  touch params/done.txt \
  ) &")

if not os.path.isdir("RFdiffusion"):
  print("installing RFdiffusion...")
  os.system("git clone https://github.com/sokrypton/RFdiffusion.git")
  os.system("pip -q install jedi omegaconf hydra-core icecream py3Dmol")
  os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu117/repo.html")
  os.system("cd RFdiffusion/env/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .")

if not os.path.isdir("RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  models = ["Complex_Fold_base_ckpt.pt"]
  # Wait for the background download subshell to finish. Polling only for the
  # "<model>.aria2" control file is racy: if aria2c has not created it yet,
  # the wait ends immediately and the move below finds nothing.
  # NOTE: if an earlier download was interrupted (params/ exists without
  # done.txt), this waits forever -- delete params/ and re-run the cell.
  while not os.path.isfile("params/done.txt"):
    time.sleep(5)
  # aria2c removes the .aria2 control file only once a download completes
  missing = [m for m in models if not os.path.isfile(m) or os.path.isfile(f"{m}.aria2")]
  if missing:
    # RFdiffusion/models is not created, so the next run retries this step
    print(f"ERROR: could not find downloaded params: {' '.join(missing)}")
    print("delete the 'params' folder and re-run this cell to download them again")
  else:
    os.makedirs("RFdiffusion/models", exist_ok=True)
    os.system(f"mv {' '.join(models)} RFdiffusion/models")
  print("----------------------------------")

if 'RFdiffusion' not in sys.path:
  os.environ["DGLBACKEND"] = "pytorch"
  sys.path.append('RFdiffusion')

from IPython.display import display
import ipywidgets as widgets
import numpy as np
import torch
import sys, os, random, string, re, time
import subprocess
import matplotlib.pyplot as plt
import py3Dmol
from google.colab import files, output

class RFdiff_js:
  """Interactive HTML/JS grid for choosing SSE types, lengths and pairwise contacts.

  Users of this class must set ``self.grid_size`` (number of SSEs) before
  calling ``reset_callback``. The state lives in two attributes, which the JS
  produced by ``create_html_code`` keeps in sync through
  ``google.colab.kernel.invokeFunction`` callbacks:

  - ``self.adj``: grid_size x grid_size list of one-character strings.
    Diagonal cells hold the SSE type ("H", "E", "C" or "?"). Off-diagonal
    cells hold the contact state ("0", "1" or "?") and are kept symmetric.
  - ``self.txt``: per-SSE length (int).
  """

  def reset_callback(self):
    """Reset every SSE to a 19-residue helix with no contacts."""
    self.adj = [["H" if row == col else "0" for col in range(self.grid_size)] for row in range(self.grid_size)]
    self.txt = [19 for _ in range(self.grid_size)]

  def grid_callback(self, row, col, new_value):
    """Mirror a grid-button click from the JS side into ``self.adj``/``self.txt``.

    A diagonal click sets the SSE type of ``row`` and resets its length to that
    type's default. An off-diagonal click sets the (symmetric) contact state.
    """
    if row == col:
      self.txt[row] = {"H": 19, "E": 5, "C": 3, "?": 0}[new_value]
      self.adj[row][col] = new_value
      for i in range(self.grid_size):
        if i != row:
          if new_value == "?":
            # an undefined SSE has undefined contacts with every other SSE
            self.adj[row][i] = "?"
            self.adj[i][row] = "?"
          elif self.adj[i][i] != "?":
            # entering C (from E) or H (from ?) resets contacts to "0";
            # pairs with an undefined SSE keep their "?" state
            if new_value in ["C","H"]:
              self.adj[row][i] = '0'
              self.adj[i][row] = '0'
    else:
      self.adj[row][col] = new_value
      self.adj[col][row] = new_value

  def text_callback(self, row, new_value):
    """Store the length typed into the text box of SSE ``row``.

    Negative values are clamped to 0, because a segment cannot have a negative
    number of residues. Input that is not an integer is ignored, and the
    previous length is kept instead of raising inside the JS callback.
    """
    try:
      length = int(str(new_value).strip())
    except ValueError:
      return
    self.txt[row] = max(0, length)

  def create_html_code(self):
    """Render the grid, length boxes, reset button and JS into ``self.html_code``.

    The JS mirrors ``grid_callback``/``reset_callback`` on the client side
    (colors, enabled state, length boxes) and forwards every change to Python
    with ``google.colab.kernel.invokeFunction``.
    """
    def style(row, col):
      # visual state of one grid button; a contact cell is editable only when
      # neither of its two SSEs is a coil or undefined
      state = self.adj[row][col]
      if row == col:
        color = {"H":"red","E":"yellow","C":"lime","?":"lightgray"}[state]
        disabled = ""
      else:
        color = {"0":"white","1":"lightblue","?":"lightgray"}[state]
        if self.adj[row][row] in ["?","C"] or self.adj[col][col] in ["?","C"]:
          disabled = "disabled"
        else:
          disabled = ""
      return {"color":color,
              "text":state,
              "id":f"button-{row}-{col}",
              "disabled":disabled,
              "opacity":1 if disabled == "" else 0.2}
    html_grid = ""
    for row in range(self.grid_size):
      for col in range(self.grid_size):
        button = style(row,col)
        html_grid += f"""
        <button id="{button['id']}" style="opacity:{button["opacity"]};width:30px;height:30px;background-color:{button['color']};border: 2px solid #000;color:#000;padding:0;font-weight:bold;" onclick="buttonClick('{button['id']}', {row}, {col})" {button["disabled"]}>{button['text']}</button>
        """
      # one length text box at the end of each grid row
      text_value = self.txt[row]
      html_grid += f"""
      <input id="text-{row}" type="text" value="{text_value}" style="width:50px;height:24px; background-color:#ffffff; text-align:center; border:2px solid lightgray;" onchange="textFieldChanged({row}, this)">
      """

    self.html_code = f"""
    <div style="display: grid; grid-template-columns: repeat({self.grid_size + 1}, 30px); grid-gap: 2px;">{html_grid}</div>
    <button id="reset_button" style="width:62px;height:30px;background-color:#ffffff;border: 2px solid #000;color:#000;padding:0;font-weight:bold;margin-top: 2px;" onclick="reset()">reset</button>
    <script>
    function buttonClick(button_id, row, col) {{
        var button = document.getElementById(button_id);
        if (row === col) {{
            var state_mapping = {{
                "H": {{ "text": "E", "color": "yellow",    "length": "5" }},
                "E": {{ "text": "C", "color": "lime",      "length": "3" }},
                "C": {{ "text": "?", "color": "lightgray", "length": "0" }},
                "?": {{ "text": "H", "color": "red",       "length": "19"}},
            }};
            var current_state = button.textContent;
            var update = state_mapping[current_state];

            button.textContent = update.text;
            button.style.backgroundColor = update.color;
            google.colab.kernel.invokeFunction("grid_callback", [row, col, update.text], {{}});

            // Update the corresponding text field value
            var textField = document.getElementById("text-" + row);
            textField.value = update.length;

            // Enable/disable off-diagonal buttons
            for (var i = 0; i < {self.grid_size}; i++) {{
                if (i !== row) {{
                    var row_button = document.getElementById("button".concat("-", row, "-", i));
                    var col_button = document.getElementById("button".concat("-", i, "-", row));
                    var diag_button = document.getElementById("button".concat("-", i, "-", i));

                    if (button.textContent === "C" || button.textContent === "?") {{
                        row_button.disabled = col_button.disabled = true;
                        row_button.style.opacity = col_button.style.opacity = 0.2;

                        if (button.textContent === "?") {{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'lightgray';
                            row_button.textContent = col_button.textContent = '?';
                        }} else if (button.textContent === "C" && diag_button.textContent !== "?") {{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                        }}
                    }} else if (button.textContent === "H") {{
                        if (diag_button.textContent == "C"){{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                        }} else if (diag_button.textContent !== "?"){{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                            row_button.disabled = col_button.disabled = false;
                            row_button.style.opacity = col_button.style.opacity = 1;
                        }}
                    }}
                }}
            }}

        }} else {{
            var off_diag_state_mapping = {{
                "0": {{ "text": "1", "color": "lightblue" }},
                "1": {{ "text": "?", "color": "lightgray" }},
                "?": {{ "text": "0", "color": "white" }},
            }};
            var current_state = button.textContent;
            var update = off_diag_state_mapping[current_state];
            var sym_button = document.getElementById("button".concat("-", col, "-", row));
            button.textContent = sym_button.textContent = update.text;
            button.style.backgroundColor = sym_button.style.backgroundColor = update.color;
            google.colab.kernel.invokeFunction("grid_callback", [row, col, update.text], {{}});
        }}
    }}
    function textFieldChanged(row, textField) {{
        var newValue = textField.value;
        google.colab.kernel.invokeFunction("text_callback", [row, newValue], {{}});
    }}
    function reset() {{
        for (var row = 0; row < {self.grid_size}; row++) {{
            for (var col = 0; col < {self.grid_size}; col++) {{
                var button = document.getElementById("button".concat("-", row, "-", col));
                if (row === col) {{
                    button.textContent = "H";
                    button.style.backgroundColor = "red";
                }} else {{
                    button.textContent = "0";
                    button.style.backgroundColor = "white";
                }}
                button.disabled = false;
                button.style.opacity = 1;
            }}
            var textField = document.getElementById("text-" + row);
            textField.value = "19";
        }}
        google.colab.kernel.invokeFunction('reset_callback', [], {{}});
    }}
    </script>
    """
class RFdiff_gui(RFdiff_js):
  """Colab front end for RFdiffusion fold-conditioned generation.

  Shows the SSE/contact grid inherited from ``RFdiff_js`` plus option widgets.
  It converts the grid into per-residue SS and contact tensors, runs
  ``RFdiffusion/run_inference.py`` in scaffold-guided mode and displays the
  resulting structure with py3Dmol.

  Args:
    elements: number of secondary structure elements (size of the grid).
    name: prefix for everything written under ``outputs/``.
  """
  def __init__(self, elements, name="test"):

    self.elements = self.grid_size = elements
    self.path = self.name = name
    os.makedirs(f"outputs/{self.path}", exist_ok=True)
    self.input = widgets.Output()   # grid + option widgets (see display)
    self.output = widgets.Output()  # progress bar, errors and the 3D view

    self.reset_callback()

    # names invoked by google.colab.kernel.invokeFunction in the grid's JS
    output.register_callback("reset_callback", self.reset_callback)
    output.register_callback("grid_callback", self.grid_callback)
    output.register_callback("text_callback", self.text_callback)

    button_style = widgets.Layout(width='84px', height='35px', border='2px solid black')
    # NOTE: "reset" and "diffuse" are created here but not wired to handlers
    self.buttons = {
        "buff_length": widgets.BoundedIntText(description='buff_length', value=5, min=0, max=20),
        "iterations":  widgets.Dropdown(description='iterations', options=[25, 50, 100, 200],value=50),
        "reset":       widgets.Button(description='reset',    layout=button_style),
        "diffuse":     widgets.Button(description='diffuse',  layout=button_style),
        "animate":     widgets.Button(description='animate',  layout=button_style),
        "freeze":      widgets.Button(description='freeze',   layout=button_style),
        "download":    widgets.Button(description='download', layout=button_style),
    }
    self.buttons["animate"].on_click(self._plot_pdb)
    self.buttons["freeze"].on_click(self._plot_pdb)
    self.buttons["download"].on_click(self._download)

  def display(self):
    """(Re)render the grid and the option widgets into ``self.input`` and show it."""
    self.create_html_code()
    with self.input:
      self.input.clear_output()
      display(
          widgets.VBox([
          widgets.HTML(self.html_code),
          widgets.Label("Options"),
          self.buttons["iterations"],
          self.buttons["buff_length"],
        ])
      )
    display(self.input)

  def _download(self, button):
    """Zip this design's outputs and trajectory and send the archive to the browser."""
    os.system(f"zip -r {self.path}.result.zip outputs/{self.path}* outputs/traj/{self.path}*")
    files.download(f"{self.path}.result.zip")

  def _plot_pdb(self, button):
    """Show the latest design in a py3Dmol viewer inside ``self.output``.

    ``button.description`` selects the mode. "animate" plays the pX0
    trajectory back and forth; anything else shows the final model. The
    opposite toggle button and the download button are shown below the view.
    """
    animate = button.description == "animate"
    with self.output:
      self.output.clear_output(wait=True)
      view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
      if animate:
        pdb = f"outputs/traj/{self.path}_0_pX0_traj.pdb"
        with open(pdb,'r') as handle:
          view.addModelsAsFrames(handle.read(),'pdb')
      else:
        pdb = f"outputs/{self.path}_0.pdb"
        with open(pdb,'r') as handle:
          view.addModel(handle.read(),'pdb')
      # color the cartoon by secondary structure, using the grid's colors
      view.setStyle({"ss":"h"},{'cartoon': {'color':'red'}})
      view.setStyle({"ss":"c"},{'cartoon': {'color':'lime'}})
      view.setStyle({"ss":"s"},{'cartoon': {'color':'yellow'}})
      view.zoomTo()
      if animate:
        view.animate({'loop': 'backAndForth'})
      out = widgets.Output()
      with out:
        view.show()
      toggle = self.buttons["freeze"] if animate else self.buttons["animate"]
      display(widgets.VBox([out, widgets.HBox([toggle, self.buttons["download"]])]))

  def _get_adj_ss(self, button):
    """Build the per-residue SS and contact maps from the grid and save them.

    Layout: ``buff`` buffer residues, then each SSE followed by another ``buff``
    residues, so the total length is ``(elements + 1) * buff + sum(lengths)``.
    Buffer residues stay undefined in both maps.

    Encodings (saved as float tensors, also kept as ``self._sse``/``self._adj``):
      - ss  (L,):   0 = helix, 1 = strand, 2 = coil, 3 = undefined
      - adj (L, L): 0 = no contact, 1 = contact, 2 = undefined

    Output files are ``outputs/<path>/tmp_ss.pt`` and
    ``outputs/<path>/tmp_adj.pt``. If a design already exists for the current
    path, ``self.path`` is first switched to a fresh random suffix.
    """
    # get unique path
    while os.path.exists(f"outputs/{self.path}_0.pdb"):
      self.path = self.name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
      os.makedirs(f"outputs/{self.path}", exist_ok=True)

    # parse results from form
    buff = self.buttons["buff_length"].value
    sse_L = self.txt
    L = (self.elements + 1) * buff + sum(sse_L)

    adj = np.full((L,L),2)
    sse = np.full((L,),3)

    n = buff  # start of SSE i
    for i in range(self.elements):
      ss = {"H":0, "E":1, "C":2, "?":3}[self.adj[i][i]]
      sse[n:n+sse_L[i]] = ss
      m = buff  # start of SSE j
      for j in range(self.elements):
        k = str(self.adj[i][j])
        if i == j:
          # within a defined SSE: no contact; within an undefined one: masked
          val = {"H":0,"E":0,"C":0,"?":2}[k]
        else:
          val = {"0":0,"1":1,"?":2}[k]
        adj[n:n+sse_L[i],m:m+sse_L[j]] = val
        m += sse_L[j] + buff
      n += sse_L[i] + buff

    self._sse = sse
    self._adj = adj

    # save results (presumably picked up by RFdiffusion's scaffold-guided
    # mode through scaffold_dir -- see diffuse; TODO confirm file naming)
    loc = [f"outputs/{self.path}/tmp_ss.pt",
           f"outputs/{self.path}/tmp_adj.pt"]
    torch.save(torch.from_numpy(sse).float(),loc[0])
    torch.save(torch.from_numpy(adj).float(),loc[1])

  def diffuse(self):
    """Run RFdiffusion on the current grid, then show the design if the run succeeded."""
    self._get_adj_ss(None)
    # run
    with self.output:
      self.output.clear_output()
      iterations = self.buttons["iterations"].value
      cmd = ["./RFdiffusion/run_inference.py",
            "inference.num_designs=1",
            f"inference.output_prefix=outputs/{self.path}",
            "scaffoldguided.scaffoldguided=True",
            "scaffoldguided.target_pdb=False",
            f"scaffoldguided.scaffold_dir=outputs/{self.path}",
            f"diffuser.T={iterations}"]
      self.cmd_str = " ".join(cmd)
      steps = iterations - 1
      return_code = self._run(self.cmd_str, "Timestep", steps)
    # on failure _run has already printed the error; there is no PDB to plot
    if return_code == 0:
      self._plot_pdb(self.buttons["freeze"])

  def _run(self, command, trigger, total_timesteps):
    """Run ``command`` in a shell and advance a progress bar on lines containing ``trigger``.

    stderr is merged into stdout so that the single pipe is always drained. A
    separate ``stderr=PIPE`` that is never read blocks the child once its
    buffer fills. On a non-zero exit the bar turns red and the last output
    lines are printed.

    Returns:
      The process exit code.
    """
    progress = widgets.FloatProgress(min=0, max=1, description='running', bar_style='info')
    display(progress)
    progress_counter = 0
    lines = []
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, text=True)
    with process.stdout:
      for line in process.stdout:
        lines.append(line)
        if trigger in line:
          progress_counter += 1
          # max(..., 1) guards against a zero step count
          progress.value = progress_counter / max(total_timesteps, 1)
    return_code = process.wait()
    if return_code == 0:
      progress.description = "done"
    else:
      progress.description = "failed"
      progress.bar_style = "danger"
      print(f"command failed with exit code {return_code}; last lines of output:")
      print("".join(lines[-20:]))
    return return_code

# Build a new GUI on the first run or when the number of elements changed;
# otherwise re-display the existing one so the grid state is preserved.
if not ("rfdiff" in globals() and rfdiff.elements == elements):
  rfdiff = RFdiff_gui(elements, name=name)
rfdiff.display()

In [None]:
#@title run **RFdiffusion**
if "rfdiff" not in globals():
  print("Error, looks like you didn't run the cell above")
else:
  # show the output area first so the progress bar and results appear here
  display(rfdiff.output)
  rfdiff.diffuse()