diff --git a/.gitignore b/.gitignore
index 6a3e68da..f7e9c036 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-**/.DS_Store
\ No newline at end of file
+**/.DS_Store
+*.pyc
diff --git a/af/README.md b/af/README.md
index 0d95345d..da2eaf5e 100644
--- a/af/README.md
+++ b/af/README.md
@@ -1,4 +1,4 @@
-# AfDesign (v1.0.5)
+# AfDesign (v1.0.6)
### Google Colab
@@ -16,6 +16,7 @@ Minor changes changes include renaming intra_pae/inter_con to pae/con and inter_
- **11July2022** - v1.0.3 - Improved homo-oligomeric support. RMSD and dgram losses have been refactored to automatically save aligned coordinates. Multimeric coordinates now saved with chain identifiers.
- **23July2022** - v1.0.4 - Adding support for openfold weights. To enable set `mk_afdesign_model(..., use_openfold=True)`.
- **31July2022** - v1.0.5 - Refactoring to add support for swapping batch features without recompile. Allowing for implementation of [AF2Rank](https://github.com/sokrypton/ColabDesign/blob/main/af/examples/AF2Rank.ipynb)!
+- **19Aug2022** - v1.0.6 - Adding support for alphafold-multimer. To enable set `mk_afdesign_model(..., use_multimer=True)`. For multimer mode, multiple recycles may be needed!
### setup
```bash
@@ -23,7 +24,7 @@ pip install git+https://github.com/sokrypton/ColabDesign.git
# download alphafold weights
mkdir params
-curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
+curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params
# download openfold weights (optional)
for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
@@ -97,14 +98,15 @@ model.opt["weights"]["pae"] = 0.0
#### How do I control number of recycles used during design?
```python
model = mk_afdesign_model(num_recycles=1, recycle_mode="average")
-# if recycle_mode in ["average","last","sample"] the number of recycles can change during optimization
+# if recycle_mode in ["average","last","sample","first"] the number of recycles can change during optimization
model.set_opt(num_recycles=1)
```
- `num_recycles` - number of recycles to use during design (for denovo proteins we find 0 is often enough)
- `recycle_mode` - optimizing across all recycles can be tricky, we experiment with a couple of ways:
- - *last* - use loss from last recycle. (Not recommended, unless you increase number optimization)
- - *sample* - Same as *last* but each iteration a different number of recycles are used. (Previous default).
- - *average* - compute loss at each recycle and average gradients. (Default; Recommended).
+ - *last* - use loss from last recycle. (Default)
+   - *average* - compute loss at each recycle and average gradients. (Previous default from v1.0.5)
+ - *sample* - Same as *last* but each iteration a different number of recycles are used.
+ - *first* - use loss from first recycle.
- *add_prev* - average the outputs (dgram, plddt, pae) across all recycles before computing loss.
- *backprop* - use loss from last recycle, but backprop through all recycles.
@@ -122,17 +124,6 @@ model.set_opt(num_models=1)
#### Can I use OpenFold model params for design instead of AlphaFold?
```python
model = mk_afdesign_model(use_openfold=True, use_alphafold=False)
-# OR
-model.set_opt(use_openfold=True, use_alphafold=False)
-```
-#### How is contact defined? How do I change it?
-By default, 2 [con]tacts per positions are optimized to be within cβ-cβ < 14.0Å and sequence seperation ≥ 9. This can be changed with:
-```python
-model.set_opt(con=dict(cutoff=8, seqsep=5, num=1))
-```
-For interface:
-```python
-model.set_opt(i_con=dict(...))
```
#### For binder hallucination, can I specify the site I want to bind?
```python
@@ -142,12 +133,6 @@ model.prep_inputs(..., hotspot="1-10,15,3")
```python
model.prep_inputs(..., chain="A,B")
```
-#### Can I design homo-oligomers?
-```python
-model.prep_inputs(..., copies=2)
-# specify interface specific contact and/or pae loss
-model.set_weights(i_con=1, i_pae=0)
-```
#### For fixed backbone design, how do I force the sequence to be the same for homo-dimer optimization?
```python
model.prep_inputs(pdb_filename="6Q40.pdb", chain="A,B", copies=2, homooligomer=True)
@@ -168,14 +153,13 @@ model.restart(seed=0)
- `design_hard()` - optimize *one_hot(logits)* inputs (discrete)
- For complex topologies, we find directly optimizing one_hot encoded sequence `design_hard()` to be very challenging.
-To get around this problem, we propose optimizing in 2 or 3 stages.
- - `design_2stage()` - *soft* → *hard*
+To get around this problem, we propose optimizing in 3 stages.
- `design_3stage()` - *logits* → *soft* → *hard*
+
#### What are all the different losses being optimized?
- general losses
- *pae* - minimizes the predicted alignment error
- *plddt* - maximizes the predicted LDDT
- - *msa_ent* - minimize entropy for MSA design (see example at the end of notebook)
- *pae* and *plddt* values are between 0 and 1 (where lower is better for both)
- fixbb specific losses
@@ -184,18 +168,26 @@ To get around this problem, we propose optimizing in 2 or 3 stages.
- we find *dgram_cce* loss to be more stable for design (compared to *fape*)
- hallucination specific losses
- - *con* - maximize number of contacts. (We find just minimizing *plddt* results in single long helix,
-and maximizing *pae* results in a two helix bundle. To encourage compact structures we add a `con` term)
+ - *con* - maximize `1` contacts per position. `model.set_opt("con",num=1)`
- binder specific losses
- - *i_pae* - minimize PAE interface of the proteins
- - *pae* - minimize PAE within binder
- - *i_con* - maximize number of contacts at the interface of the proteins
- - *con* - maximize number of contacts within binder
+ - *pae* - minimize PAE at interface and within binder
+  - *con* - maximize `2` contacts per binder position, within binder. `model.set_opt("con",num=2)`
+ - *i_con* - maximize `1` contacts per binder position `model.set_opt("i_con",num=1)`
- partial hallucination specific losses
- *sc_fape* - sidechain-specific fape
+#### How is contact defined? How do I change it?
+By default, 2 [con]tacts per position are optimized to be within cβ-cβ < 14.0Å and sequence separation ≥ 9. This can be changed with:
+```python
+model.set_opt(con=dict(cutoff=8, seqsep=5, num=1))
+```
+For interface:
+```python
+model.set_opt(i_con=dict(...))
+```
+
# Advanced FAQ
#### loss during Gradient descent is too jumpy, can I do some kind of greedy search towards the end?
Gradient descent updates multiple positions each iteration, which can be a little too aggressive during hard (discrete) mode.
diff --git a/af/design.ipynb b/af/design.ipynb
index 90e3a32b..61e380c1 100644
--- a/af/design.ipynb
+++ b/af/design.ipynb
@@ -16,7 +16,7 @@
"id": "OA2k3sAYuiXe"
},
"source": [
- "#AfDesign (v1.0.5)\n",
+ "#AfDesign (v1.0.6)\n",
"Backprop through AlphaFold for protein design.\n",
"\n",
"**WARNING**\n",
@@ -42,7 +42,7 @@
" ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n",
" # download params\n",
" mkdir params\n",
- " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n",
+ " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n",
" for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n",
" do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n",
"fi"
@@ -246,7 +246,9 @@
},
"source": [
"# binder hallucination\n",
- "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure. To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder."
+ "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure.\n",
+ "To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.\n",
+ "By default, AlphaFold-ptm with residue index offset hack is used. To enable AlphaFold-multimer set: mk_afdesign_model(use_multimer=True).\n"
]
},
{
@@ -275,12 +277,6 @@
"outputs": [],
"source": [
"af_model.restart()\n",
- "\n",
- "# settings we find work best for helical peptide binder hallucination\n",
- "af_model.set_weights(plddt=0.1, pae=0.1, i_pae=1.0, con=0.1, i_con=0.5)\n",
- "af_model.set_opt(con=dict(binary=True, cutoff=21.6875, num=af_model._binder_len, seqsep=0))\n",
- "af_model.set_opt(i_con=dict(binary=True, cutoff=21.6875, num=af_model._binder_len))\n",
- "\n",
"af_model.design_3stage(100,100,10)"
]
},
@@ -373,8 +369,7 @@
"af_model.prep_inputs(pdb_filename=get_pdb(\"6MRR\"),\n",
" chain=\"A\",\n",
" pos=\"3-30,33-68\", # define positions to contrain\n",
- " length=100, # total length if different from input pdb\n",
- " fix_seq=False) # set True to constrain sequence in the specified positions\n",
+ " length=100) # total length if different from input pdb\n",
"\n",
"af_model.rewire(loops=[36]) # set loop length between segments "
],
@@ -452,4 +447,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
diff --git a/af/examples/2stage_binder_hallucination.ipynb b/af/examples/2stage_binder_hallucination.ipynb
deleted file mode 100644
index 0202d96d..00000000
--- a/af/examples/2stage_binder_hallucination.ipynb
+++ /dev/null
@@ -1,281 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "OA2k3sAYuiXe"
- },
- "source": [
- "# AfDesign - two-stage binder hallucination\n",
- "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure. To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.\n",
- "\n",
- "**WARNING**\n",
- "1. This notebook is in active development and was designed for demonstration purposes only.\n",
- "2. Using AfDesign as the only \"loss\" function for design might be a bad idea, you may find adversarial sequences (aka. sequences that trick AlphaFold)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "-AXy0s_4cKaK"
- },
- "outputs": [],
- "source": [
- "#@title install\n",
- "%%bash\n",
- "if [ ! -d params ]; then\n",
- " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n",
- " mkdir params\n",
- " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n",
- " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n",
- " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n",
- "fi"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "Vt7G_nbNeSQ3"
- },
- "outputs": [],
- "source": [
- "#@title #import libraries\n",
- "import warnings\n",
- "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
- "\n",
- "import os\n",
- "from colabdesign import mk_afdesign_model, clear_mem\n",
- "from IPython.display import HTML\n",
- "from google.colab import files\n",
- "import numpy as np\n",
- "\n",
- "#########################\n",
- "def get_pdb(pdb_code=\"\"):\n",
- " if pdb_code is None or pdb_code == \"\":\n",
- " upload_dict = files.upload()\n",
- " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n",
- " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n",
- " return \"tmp.pdb\"\n",
- " else:\n",
- " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
- " return f\"{pdb_code}.pdb\""
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "#@title # Prep Inputs\n",
- "pdb = \"4N5T\" #@param {type:\"string\"}\n",
- "chain = \"A\" #@param {type:\"string\"}\n",
- "binder_len = 50#@param {type:\"integer\"}\n",
- "hotspot = \"\" #@param {type:\"string\"}\n",
- "if hotspot == \"\": hotspot = None\n",
- "\n",
- "x = {\"pdb_filename\":pdb, \"chain\":chain, \"binder_len\":binder_len, \"hotspot\":hotspot}\n",
- "if \"x_prev\" not in dir() or x != x_prev:\n",
- " x[\"pdb_filename\"] = get_pdb(x[\"pdb_filename\"])\n",
- " \n",
- " clear_mem()\n",
- " model = mk_afdesign_model(protocol=\"binder\")\n",
- " model.prep_inputs(**x)\n",
- "\n",
- " pre_model = mk_afdesign_model(protocol=\"hallucination\")\n",
- " pre_model.prep_inputs(length=binder_len)\n",
- "\n",
- " x_prev = x\n",
- " print(\"target length:\", model._target_len)\n",
- " print(\"binder length:\", model._binder_len)"
- ],
- "metadata": {
- "cellView": "form",
- "id": "HSgE99WALOE-"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title #stage 1 - Pre-hallucinate binder scaffold\n",
- "#@markdown ---\n",
- "#@markdown ####Weights\n",
- "#@markdown - Minimizing `pae` or maximizing `plddt` often results in a single helix.\n",
- "#@markdown To avoid this, we start with a random sequence and instead try to optimize \n",
- "#@markdown defined `num`ber of `con`tacts per position. \n",
- "pae = 0.1 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n",
- "plddt = 0.1 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n",
- "helix = 0.0 \n",
- "con = 1.0 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n",
- "#@markdown ####Contact Definition\n",
- "#@markdown - The contact definition is based on Cb-Cb diststance `cutoff`. To avoid \n",
- "#@markdown biasing towards helical contact, only contacts with sequence seperation > \n",
- "#@markdown `seqsep` are considered.\n",
- "\n",
- "seqsep = 9 #@param [\"0\",\"5\",\"9\"] {type:\"raw\"}\n",
- "cutoff = \"14\" #@param [\"8\", \"14\", \"max\"]\n",
- "num = \"2\" #@param [\"1\", \"2\", \"3\", \"4\", \"8\", \"max\"]\n",
- "binary = True #@param {type:\"boolean\"}\n",
- "if cutoff == \"max\": cutoff = 21.6875\n",
- "if num == \"max\": num = binder_len\n",
- "\n",
- "pre_opt = {\"con\":{\"seqsep\":int(seqsep),\"cutoff\":float(cutoff),\"num\":int(num),\n",
- " \"binary\":binary}}\n",
- "pre_weights = {\"con\":float(con),\"helix\":float(helix),\n",
- " \"pae\":float(pae),\"plddt\":float(plddt)}\n",
- "\n",
- "# pre-design with gumbel initialization and softmax activation\n",
- "pre_model.restart(mode=\"gumbel\", opt=pre_opt, weights=pre_weights)\n",
- "pre_model.design_soft(50)\n",
- "save_seq = np.asarray(pre_model.aux[\"seq\"][\"pseudo\"])\n",
- "\n",
- "# refine\n",
- "pre_model.restart(seq=save_seq, opt=pre_opt, weights=pre_weights, keep_history=True)\n",
- "pre_model.design(50, soft=0.0, e_soft=1.0)\n",
- "save_seq = np.asarray(pre_model.aux[\"seq\"][\"pseudo\"])"
- ],
- "metadata": {
- "cellView": "form",
- "id": "TX1aPX0fXa6D"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@markdown ## display pre-hallucinated binder scaffold {run: \"auto\"}\n",
- "color = \"pLDDT\" #@param [\"chain\", \"pLDDT\", \"rainbow\"]\n",
- "show_sidechains = False #@param {type:\"boolean\"}\n",
- "show_mainchains = False #@param {type:\"boolean\"}\n",
- "pre_model.plot_pdb(show_sidechains=show_sidechains,\n",
- " show_mainchains=show_mainchains,\n",
- " color=color)"
- ],
- "metadata": {
- "cellView": "form",
- "id": "gz7wRJaGXs9e"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "HTML(pre_model.animate())"
- ],
- "metadata": {
- "id": "5OJdtq8trTB4"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title #state 2 - binder design\n",
- "#@markdown ---\n",
- "#@markdown ####interface Weights\n",
- "i_pae = 1.0 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n",
- "i_con = 0.5 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n",
- "weights = {\"i_pae\":float(i_pae),\n",
- " \"i_con\":float(i_con),\n",
- " **pre_weights}\n",
- "\n",
- "#@markdown ####interface Contact Definition\n",
- "cutoff = \"max\" #@param [\"8\", \"14\", \"max\"]\n",
- "num = \"max\" #@param [\"1\", \"2\", \"4\", \"8\", \"max\"]\n",
- "binary = True #@param {type:\"boolean\"}\n",
- "if cutoff == \"max\": cutoff = 21.6875\n",
- "if num == \"max\": num = binder_len\n",
- "\n",
- "opt = {\"i_con\":{\"cutoff\":float(cutoff),\"num\":int(num),\n",
- " \"binary\":binary},\n",
- " **pre_opt}\n",
- "\n",
- "model.restart(seq=save_seq, opt=opt, weights=weights)\n",
- "model.design_3stage(100,100,10)"
- ],
- "metadata": {
- "cellView": "form",
- "id": "eCGc3J663NGz"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@markdown ## display hallucinated binder {run: \"auto\"}\n",
- "color = \"chain\" #@param [\"chain\", \"pLDDT\", \"rainbow\"]\n",
- "show_sidechains = False #@param {type:\"boolean\"}\n",
- "show_mainchains = False #@param {type:\"boolean\"}\n",
- "model.save_pdb(f\"{model.protocol}.pdb\")\n",
- "model.plot_pdb(show_sidechains=show_sidechains,\n",
- " show_mainchains=show_mainchains,\n",
- " color=color)"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Ec0BnP1ehH5w"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "L2E9Tn2Acchj"
- },
- "outputs": [],
- "source": [
- "HTML(model.animate())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "YSKWYu0_GlUH"
- },
- "outputs": [],
- "source": [
- "model.get_seqs()"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "collapsed_sections": [
- "q4qiU9I0QHSz"
- ],
- "name": "2stage_binder_hallucination.ipynb",
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
\ No newline at end of file
diff --git a/af/examples/AF2Rank.ipynb b/af/examples/AF2Rank.ipynb
index c2e0d402..2682cf09 100644
--- a/af/examples/AF2Rank.ipynb
+++ b/af/examples/AF2Rank.ipynb
@@ -26,7 +26,7 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
@@ -44,7 +44,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"cellView": "form",
"id": "zk6_tVpg9Bdi"
@@ -61,7 +61,7 @@
"\n",
" # alphafold params\n",
" mkdir params\n",
- " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n",
+ " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n",
"\n",
" # download openfold weights (optional)\n",
" for W in openfold_model_ptm_1 openfold_model_ptm_2\n",
@@ -108,7 +108,7 @@
" return o\n",
" \n",
"def plot_me(scores, x=\"tm_i\", y=\"composite\", \n",
- " title=None, diag=False, scale_axis=True, dpi=100):\n",
+ " title=None, diag=False, scale_axis=True, dpi=100, **kwargs):\n",
" def rescale(a,amin=None,amax=None): \n",
" a = np.copy(a)\n",
" if amin is None: amin = a.min()\n",
@@ -122,14 +122,16 @@
" x_vals = np.array([k[x] for k in scores])\n",
" y_vals = np.array([k[y] for k in scores])\n",
" c = rescale(np.array([k[\"plddt\"] for k in scores]),0.5,0.9)\n",
- " plt.scatter(x_vals, y_vals, c=c*0.75, s=5, vmin=0, vmax=1, cmap=\"gist_rainbow\")\n",
+ " plt.scatter(x_vals, y_vals, c=c*0.75, s=5, vmin=0, vmax=1, cmap=\"gist_rainbow\",\n",
+ " **kwargs)\n",
" if diag:\n",
" plt.plot([0,1],[0,1],color=\"black\")\n",
" \n",
" labels = {\"tm_i\":\"TMscore of Input\",\n",
- " \"tm_o\":\"TMscore of Ouput\",\n",
+ " \"tm_o\":\"TMscore of Output\",\n",
" \"tm_io\":\"TMscore between Input and Output\",\n",
" \"ptm\":\"Predicted TMscore (pTM)\",\n",
+ " \"i_ptm\":\"Predicted interface TMscore (ipTM)\",\n",
" \"plddt\":\"Predicted LDDT (pLDDT)\",\n",
" \"composite\":\"Composite\"}\n",
"\n",
@@ -138,19 +140,31 @@
" if x in labels: plt.xlim(-0.1, 1.1)\n",
" if y in labels: plt.ylim(-0.1, 1.1)\n",
" \n",
- " plt.show()\n",
" print(spearmanr(x_vals,y_vals).correlation)\n",
"\n",
- "\n",
"class af2rank:\n",
- " def __init__(self, pdb, chain=None):\n",
- " self.model = mk_af_model(protocol=\"fixbb\", use_templates=True,\n",
- " use_alphafold=True, use_openfold=True)\n",
- " self.model.prep_inputs(pdb, chain=chain)\n",
+ " def __init__(self, pdb, chain=None, model_name=\"model_1_ptm\", model_names=None):\n",
+ " self.args = {\"pdb\":pdb, \"chain\":chain,\n",
+ " \"use_multimer\":(\"multimer\" in model_name),\n",
+ " \"model_name\":model_name,\n",
+ " \"model_names\":model_names}\n",
+ " self.reset()\n",
+ "\n",
+ " def reset(self):\n",
+ " self.model = mk_af_model(protocol=\"fixbb\",\n",
+ " use_templates=True,\n",
+ " use_multimer=self.args[\"use_multimer\"],\n",
+ " use_alphafold=True, use_openfold=True,\n",
+ " debug=False,\n",
+ " model_names=self.args[\"model_names\"])\n",
+ " \n",
+ " self.model.prep_inputs(self.args[\"pdb\"], chain=self.args[\"chain\"])\n",
" self.model.set_seq(mode=\"wildtype\")\n",
" self.wt_batch = copy_dict(self.model._inputs[\"batch\"])\n",
+ " self.wt = self.model._wt_aatype\n",
"\n",
" def set_pdb(self, pdb, chain=None):\n",
+ " if chain is None: chain = self.args[\"chain\"]\n",
" self.model.prep_inputs(pdb, chain=chain)\n",
" self.model.set_seq(mode=\"wildtype\")\n",
" self.wt = self.model._wt_aatype\n",
@@ -183,51 +197,54 @@
" return score\n",
" \n",
" def predict(self, pdb=None, seq=None, chain=None, \n",
- " input_template=True, \n",
- " rm_tm_seq=True, rm_tm_sc=True, recycles=1,\n",
- " iterations=1, model_name=\"model_1_ptm\",\n",
- " tm_aatype=21, save_pdb=False,\n",
- " output_dir=\"tmp\",output_pdb=None,\n",
- " extras=None, verbose=True):\n",
+ " input_template=True, model_name=None,\n",
+ " rm_seq=True, rm_sc=True, rm_ic=False,\n",
+ " recycles=1, iterations=1,\n",
+ " output_pdb=None, extras=None, verbose=True):\n",
+ " \n",
+ " if model_name is not None:\n",
+ " self.args[\"model_name\"] = model_name\n",
+ " if \"multimer\" in model_name: \n",
+ " if not self.args[\"use_multimer\"]:\n",
+ " self.args[\"use_multimer\"] = True\n",
+ " self.reset()\n",
+ " else:\n",
+ " if self.args[\"use_multimer\"]:\n",
+ " self.args[\"use_multimer\"] = False\n",
+ " self.reset()\n",
" \n",
" if pdb is not None: self.set_pdb(pdb, chain)\n",
" if seq is not None: self.set_seq(seq)\n",
"\n",
" # set template sequence\n",
- " tm_aatype = np.full_like(self.wt, tm_aatype) if rm_tm_seq else self.wt\n",
- " self.model.opt[\"template\"][\"aatype\"] = tm_aatype\n",
+ " self.model._inputs[\"batch\"][\"aatype\"] = self.wt\n",
"\n",
" # set other options\n",
- " self.model.set_opt(template=dict(dropout=not input_template),\n",
- " rm_template_seq=True, rm_template_sc=rm_tm_sc,\n",
- " num_recycles=recycles)\n",
+ " self.model.set_opt(\n",
+ " template=dict(dropout=not input_template,\n",
+ " rm_ic=rm_ic,\n",
+ " rm_sc=rm_sc,\n",
+ " rm_seq=rm_seq),\n",
+ " num_recycles=recycles)\n",
" \n",
" # \"manual\" recycles using templates\n",
" ini_atoms = self.model._inputs[\"batch\"][\"all_atom_positions\"].copy()\n",
" for i in range(iterations):\n",
- " self.model.predict(models=[model_name], verbose=False)\n",
+ " self.model.predict(models=self.args[\"model_name\"], verbose=False)\n",
" if i < iterations - 1:\n",
" self.model._inputs[\"batch\"][\"all_atom_positions\"] = self.model.aux[\"atom_positions\"]\n",
" else:\n",
" self.model._inputs[\"batch\"][\"all_atom_positions\"] = ini_atoms\n",
" \n",
- " if save_pdb:\n",
- " os.makedirs(output_dir, exist_ok=True)\n",
- " if output_pdb is None:\n",
- " if pdb is None:\n",
- " aatype = np.array(self.model._params[\"seq\"])[0].argmax(-1)\n",
- " diff = np.where(aatype != self.wt)[0]\n",
- " seq_diff = np.array(list(seq))[diff]\n",
- " output_pdb = \"_\".join([f\"{i}{a}\" for i,a in zip(diff, seq_diff)]) + \".pdb\"\n",
- " else:\n",
- " output_pdb = os.path.basename(pdb)\n",
- " self.model.save_pdb(os.path.join(output_dir, output_pdb))\n",
- "\n",
" score = self._get_score()\n",
- " if extras is not None: score.update(extras)\n",
+ " if extras is not None:\n",
+ " score.update(extras)\n",
+ "\n",
+ " if output_pdb is not None:\n",
+ " self.model.save_pdb(output_pdb)\n",
" \n",
" if verbose:\n",
- " print_list = [\"tm_i\",\"tm_o\",\"tm_io\",\"composite\",\"ptm\",\"plddt\",\"fitness\",\"id\"]\n",
+ " print_list = [\"tm_i\",\"tm_o\",\"tm_io\",\"composite\",\"ptm\",\"i_ptm\",\"plddt\",\"fitness\",\"id\"]\n",
" print_score = lambda k: f\"{k} {score[k]:.4f}\" if isinstance(score[k],float) else f\"{k} {score[k]}\"\n",
" print(*[print_score(k) for k in print_list if k in score])\n",
" \n",
@@ -237,38 +254,46 @@
"cellView": "form",
"id": "1o-_Rl4hFfkR"
},
- "execution_count": 2,
+ "execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
- "#@markdown ### settings\n",
- "seq_relacement = \"gap\" #@param [\"gap\", \"X\", \"A\", \"none\"]\n",
- "mask_sidechains = True #@param {type:\"boolean\"}\n",
+ "#@markdown ### **settings**\n",
"recycles = 1 #@param [\"0\", \"1\", \"2\", \"3\", \"4\"] {type:\"raw\"}\n",
"iterations = 1 \n",
- "model_name = \"model_1_ptm\" #@param [\"model_1_ptm\", \"model_2_ptm\", \"openfold_model_ptm_1\", \"openfold_model_ptm_2\"]\n",
+ "\n",
+ "# decide what model to use\n",
+ "model_mode = \"alphafold\" #@param [\"alphafold\", \"alphafold-multimer\", \"openfold\"]\n",
+ "model_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
+ "\n",
+ "if model_mode == \"alphafold\":\n",
+ " model_name = f\"model_{model_num}_ptm\"\n",
+ "if model_mode == \"alphafold-multimer\":\n",
+ " model_name = f\"model_{model_num}_multimer_v2\"\n",
+ "if model_mode == \"openfold\":\n",
+ " model_name = f\"openfold_model_ptm_{model_num}\"\n",
+ "\n",
"save_output_pdbs = False #@param {type:\"boolean\"}\n",
"\n",
- "tm_aa = -1\n",
- "if seq_relacement == \"A\": tm_aa = 0\n",
- "if seq_relacement == \"gap\": tm_aa = 21\n",
- "if seq_relacement == \"X\": tm_aa = 20\n",
+ "#@markdown ### **advanced**\n",
+ "mask_sequence = True #@param {type:\"boolean\"}\n",
+ "mask_sidechains = True #@param {type:\"boolean\"}\n",
+ "mask_interchain = False #@param {type:\"boolean\"}\n",
"\n",
- "SETTINGS = {\"rm_tm_seq\":seq_relacement != \"none\",\n",
- " \"rm_tm_sc\":not mask_sidechains,\n",
+ "SETTINGS = {\"rm_seq\":mask_sequence,\n",
+ " \"rm_sc\":mask_sidechains,\n",
+ " \"rm_ic\":mask_interchain,\n",
" \"recycles\":int(recycles),\n",
" \"iterations\":int(iterations),\n",
- " \"model_name\":model_name, \n",
- " \"tm_aatype\":tm_aa, \n",
- " \"save_pdb\":save_output_pdbs}"
+ " \"model_name\":model_name}"
],
"metadata": {
"cellView": "form",
"id": "6G7XWsStB1sB"
},
- "execution_count": 3,
+ "execution_count": null,
"outputs": []
},
{
@@ -295,12 +320,12 @@
"\n",
"# setup model\n",
"clear_mem()\n",
- "af = af2rank(NATIVE_PATH, CHAIN)"
+ "af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS[\"model_name\"])"
],
"metadata": {
"id": "iDCRJjdSIG0g"
},
- "execution_count": 4,
+ "execution_count": 42,
"outputs": []
},
{
@@ -313,16 +338,16 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "id": "nKGf21nk6Ofx",
- "outputId": "3c800748-71ac-4480-b1b7-25456109c929"
+ "id": "UCUZxJdbBjZt",
+ "outputId": "5d290158-bcc1-4f02-f3a9-af13d1946248"
},
- "execution_count": 5,
+ "execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
- "tm_i 1.0000 tm_o 0.6650 tm_io 0.6650 composite 0.2399 ptm 0.5467 plddt 0.6599\n"
+ "tm_i 1.0000 tm_o 0.6650 tm_io 0.6650 composite 0.2399 ptm 0.5467 i_ptm 0.0000 plddt 0.6599\n"
]
}
]
@@ -337,11 +362,16 @@
"\n",
"# score the decoy sctructures\n",
"for decoy_pdb in os.listdir(DECOY_DIR):\n",
- " decoy_path = os.path.join(DECOY_DIR, decoy_pdb)\n",
- " SCORES.append(af.predict(pdb=decoy_path, **SETTINGS, extras={\"id\":decoy_path}))"
+ " input_pdb = os.path.join(DECOY_DIR, decoy_pdb)\n",
+ " if save_output_pdbs:\n",
+ " output_pdb = os.path.join(\"tmp\",decoy_pdb)\n",
+ " else:\n",
+ " output_pdb = None\n",
+ " SCORES.append(af.predict(pdb=input_pdb, output_pdb=output_pdb,\n",
+ " **SETTINGS, extras={\"id\":decoy_pdb}))"
],
"metadata": {
- "id": "ye1ScsVBajSo"
+ "id": "ChgI637YCArk"
},
"execution_count": null,
"outputs": []
@@ -357,29 +387,29 @@
"base_uri": "https://localhost:8080/",
"height": 497
},
- "id": "Ouy6R-DoC-zO",
- "outputId": "01b3fb93-ebfb-4416-8641-43b4da802c37"
+ "id": "ZUEaAlP5CR8h",
+ "outputId": "7322efa2-96d6-4866-d1ad-8c4d46a771f8"
},
- "execution_count": 7,
+ "execution_count": null,
"outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "0.9286667300616703\n"
+ ]
+ },
{
"output_type": "display_data",
"data": {
"text/plain": [
"