# DiffSBDD: Structure-based Drug Design with Equivariant Diffusion Models

[**[Paper]**](https://arxiv.org/abs/2210.13695)
[**[Code]**](https://github.com/arneschneuing/DiffSBDD)

Make sure to select `Runtime` -> `Change runtime type` -> `GPU` before you run the script.

<img src="https://raw.githubusercontent.com/arneschneuing/DiffSBDD/main/img/overview.png" height=250>

## Choose target PDB

In [None]:
from google.colab import files
from google.colab import output
output.enable_custom_widget_manager()
import os.path
from pathlib import Path
import urllib
import os

# Select the target structure: either download the 3RFM example from RCSB or
# let the user upload a PDB file. Afterwards `pdbfile` points to the structure
# inside `input_dir`, and `output_dir` exists for the generated ligands.
import urllib.request  # a bare `import urllib` does not guarantee the `request` submodule is loaded

input_dir = Path("/content/input_pdbs/")
output_dir = Path("/content/output_sdfs/")
# parents=True also creates missing parent directories; exist_ok=True makes re-runs harmless
for directory in (input_dir, output_dir):
  directory.mkdir(parents=True, exist_ok=True)

target = "example (3rfm)" #@param ["example (3rfm)", "upload structure"]

if target == "example (3rfm)":
  pdbfile = Path(input_dir, '3rfm.pdb')
  # Fetch the example structure over HTTPS
  urllib.request.urlretrieve('https://files.rcsb.org/download/3rfm.pdb', pdbfile)

elif target == "upload structure":
  uploaded = files.upload()
  # Fail with a clear message instead of an IndexError when nothing was uploaded
  if not uploaded:
    raise ValueError("No file was uploaded; re-run this cell and select a PDB file.")
  # Only the first uploaded file is used
  fn = next(iter(uploaded))
  pdbfile = Path(input_dir, fn)
  # Move the upload from the working directory into input_dir
  Path(fn).rename(pdbfile)

## Define binding pocket

You can choose between two options to define the binding pocket:
1. **list of residues:** provide a list where each residue is specified as `<chain_id>:<res_id>`, e.g., `A:1 A:2 A:3 A:4 A:5 A:6 A:7`
2. **reference ligand:** if the uploaded PDB structure contains a reference ligand in the target pocket, you can specify its location as `<chain_id>:<res_id>` and the pocket will be extracted automatically

In [None]:
#@title { run: "auto" }
import ipywidgets as widgets

pocket_definition = "reference ligand" #@param ["list of residues", "reference ligand"]

# For each way of defining the pocket:
# (label printed above the widget, pre-filled widget text, command-line flag stored in pocket_flag)
pocket_modes = {
  "list of residues": (
    'pocket_residues:',
    'A:9 A:59 A:60 A:62 A:63 A:64 A:66 A:67 A:80 A:81 A:84 A:85 A:88 A:167 A:168 A:169 A:170 A:172 A:174 A:177 A:181 A:246 A:249 A:250 A:252 A:253 A:256 A:265 A:267 A:270 A:271 A:273 A:274 A:275 A:277 A:278',
    "--resi_list",
  ),
  "reference ligand": (
    'reference_ligand:',
    'A:330',
    "--ref_ligand",
  ),
}

# An unknown choice leaves `w` and `pocket_flag` untouched
if pocket_definition in pocket_modes:
  label, default_text, pocket_flag = pocket_modes[pocket_definition]
  print(label)
  w = widgets.Text(value=default_text)

# Show the editable text box; a later cell reads its content via `w.value`
display(w)

reference_ligand:


Text(value='A:330')

## Settings

Notes: 
- `timesteps < 1000` is an experimental feature
- `resamplings` and `jump_length` only pertain to the inpainting model

In [None]:
#@markdown ## Sampling
n_samples = 10 #@param {type:"slider", min:1, max:100, step:1}
ligand_nodes = 20 #@param {type:"integer"}

model = "conditional_full_atom" #@param ["conditional_full_atom", "inpaint_ca"]
# Any model other than "conditional_full_atom" uses the `ca_inpaint.ckpt` checkpoint
if model == "conditional_full_atom":
  checkpoint_name = 'full_atom.ckpt'
else:
  checkpoint_name = 'ca_inpaint.ckpt'
checkpoint = Path('DiffSBDD', 'checkpoints', checkpoint_name)

timesteps = 100 #@param {type:"slider", min:1, max:1000, step:1}

#@markdown  ## Inpainting parameters
resamplings = 1 #@param {type:"integer"}
jump_length = 1 #@param {type:"integer"}

#@markdown  ## Post-processing
keep_all_fragments = False #@param {type:"boolean"}
sanitize = True #@param {type:"boolean"}
relax = True #@param {type:"boolean"}

# Replace each checkbox value by its command-line flag; an unticked box
# becomes an empty string so that it vanishes from the command line.
keep_all_fragments, sanitize, relax = (
  flag if enabled else ""
  for flag, enabled in (
    ("--all_frags", keep_all_fragments),
    ("--sanitize", sanitize),
    ("--relax", relax),
  )
)

In [None]:
#@title Install dependencies (this will take about 10 minutes)
%%capture
%%bash

set -e # Exit immediately if a command exits with a non-zero status.

# The READY marker is only written after everything succeeded, so re-running
# this cell after a successful installation does nothing.
if [ ! -f READY ]; then

  # Unset PYTHONPATH for the commands below
  # (`env PYTHONPATH=` without a command only prints the environment)
  unset PYTHONPATH

  # Install Miniconda into /usr/local
  # NOTE(review): "latest" does not pin the Python version of the base environment
  wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh >/dev/null
  bash Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local >/dev/null

  # Fetch the DiffSBDD code; skip the clone if an earlier, interrupted run
  # already created the directory (git clone would fail and abort the cell)
  if [ ! -d DiffSBDD ]; then
    git clone https://github.com/arneschneuing/DiffSBDD.git
  fi

  # Install dependencies (-y: never wait for a confirmation prompt)
  conda install -y pytorch=1.12.1 cudatoolkit=10.2 -c pytorch >/dev/null
  conda install -y -c conda-forge pytorch-lightning=1.7.4 >/dev/null
  conda install -y -c conda-forge wandb=0.13.1 >/dev/null
  conda install -y -c conda-forge rdkit=2022.03.2 >/dev/null
  conda install -y -c conda-forge biopython=1.79 >/dev/null
  conda install -y -c conda-forge imageio=2.21.2 >/dev/null
  conda install -y -c anaconda scipy=1.7.3 >/dev/null
  conda install -y -c pyg pytorch-scatter=2.0.9 >/dev/null
  conda install -y -c conda-forge openbabel=3.1.1 >/dev/null
  conda install -y -c conda-forge networkx=2.8.6 >/dev/null

  # enforce boost-cpp version
  conda install -y -c conda-forge boost-cpp=1.74.0=h359cf19_5 >/dev/null

  # Molecule viewer used by the visualization cell
  pip install py3Dmol==1.8.1 >/dev/null

  touch READY
fi

In [None]:
#@title Run sampling (this will take a few minutes; runtime depends on the input parameters `n_samples`, `timesteps` etc.)

pocket = w.value
!python -W ignore DiffSBDD/generate_ligands.py $checkpoint --pdbfile $pdbfile --outdir $output_dir $pocket_flag $pocket --n_samples $n_samples --num_nodes_lig $ligand_nodes --resamplings $resamplings --jump_length $jump_length $keep_all_fragments $sanitize $relax >/dev/null

In [None]:
#@title Show generated molecules

import sys

# The install cell places Miniconda under /usr/local and pip-installs py3Dmol.
# Look for the site-packages directory that actually contains py3Dmol instead
# of hard-coding one Python version, because "Miniconda3-latest" does not pin it.
for site_dir in sorted(Path("/usr/local/lib").glob("python3*/site-packages")):
  if any(site_dir.glob("py3Dmol*")) and str(site_dir) not in sys.path:
    sys.path.append(str(site_dir))
import py3Dmol

# File with the generated ligands, named after the input structure
sdf_file = Path(output_dir, f"{pdbfile.stem}_mol.sdf")

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
# Path.read_text() closes the file again, unlike a bare open(...).read()
view.addModel(Path(pdbfile).read_text(), 'pdb')
# Model -1 refers to the most recently added model: the protein here ...
view.setStyle({'model': -1}, {'cartoon': {'color': 'lime'}})
# Optional surface rendering (disabled):
# view.addSurface(py3Dmol.VDW, {'opacity': 0.4, 'color': 'lime'})
view.addModelsAsFrames(sdf_file.read_text())
# ... and the generated ligands here
view.setStyle({'model': -1}, {'stick': {}})
view.zoomTo({'model': -1})
view.zoom(0.5)
# Extra rotation applied only to the bundled example structure
if target == "example (3rfm)":
  view.rotate(90, 'y')
# Cycle through the generated molecules, one frame per 1000 ms interval
view.animate({'loop': "forward", 'interval': 1000})
view.show()

In [None]:
#@title Download .sdf file
# The sampling output is named after the input structure: <stem>_mol.sdf inside output_dir
sdf_download_path = Path(output_dir, f"{pdbfile.stem}_mol.sdf")
files.download(sdf_download_path)