diff --git a/.gitignore b/.gitignore index 6a3e68da..f7e9c036 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -**/.DS_Store \ No newline at end of file +**/.DS_Store +*.pyc diff --git a/af/README.md b/af/README.md index 0d95345d..da2eaf5e 100644 --- a/af/README.md +++ b/af/README.md @@ -1,4 +1,4 @@ -# AfDesign (v1.0.5) +# AfDesign (v1.0.6) ### Google Colab Open In Colab @@ -16,6 +16,7 @@ Minor changes changes include renaming intra_pae/inter_con to pae/con and inter_ - **11July2022** - v1.0.3 - Improved homo-oligomeric support. RMSD and dgram losses have been refactored to automatically save aligned coordinates. Multimeric coordinates now saved with chain identifiers. - **23July2022** - v1.0.4 - Adding support for openfold weights. To enable set `mk_afdesign_model(..., use_openfold=True)`. - **31July2022** - v1.0.5 - Refactoring to add support for swapping batch features without recompile. Allowing for implementation of [AF2Rank](https://github.com/sokrypton/ColabDesign/blob/main/af/examples/AF2Rank.ipynb)! +- **19Aug2022** - v1.0.6 - Adding support for alphafold-multimer. To enable set `mk_afdesign_model(..., use_multimer=True)`. For multimer mode, multiple recycles may be needed! ### setup ```bash @@ -23,7 +24,7 @@ pip install git+https://github.com/sokrypton/ColabDesign.git # download alphafold weights mkdir params -curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params +curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params # download openfold weights (optional) for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1 @@ -97,14 +98,15 @@ model.opt["weights"]["pae"] = 0.0 #### How do I control number of recycles used during design? 
```python model = mk_afdesign_model(num_recycles=1, recycle_mode="average") -# if recycle_mode in ["average","last","sample"] the number of recycles can change during optimization +# if recycle_mode in ["average","last","sample","first"] the number of recycles can change during optimization model.set_opt(num_recycles=1) ``` - `num_recycles` - number of recycles to use during design (for denovo proteins we find 0 is often enough) - `recycle_mode` - optimizing across all recycles can be tricky, we experiment with a couple of ways: - - *last* - use loss from last recycle. (Not recommended, unless you increase number optimization) - - *sample* - Same as *last* but each iteration a different number of recycles are used. (Previous default). - - *average* - compute loss at each recycle and average gradients. (Default; Recommended). + - *last* - use loss from last recycle. (Default) + - *average* - compute loss at each recycle and average gradients. (Previous default from v.1.0.5) + - *sample* - Same as *last* but each iteration a different number of recycles are used. + - *first* - use loss from first recycle. - *add_prev* - average the outputs (dgram, plddt, pae) across all recycles before computing loss. - *backprop* - use loss from last recycle, but backprop through all recycles. @@ -122,17 +124,6 @@ model.set_opt(num_models=1) #### Can I use OpenFold model params for design instead of AlphaFold? ```python model = mk_afdesign_model(use_openfold=True, use_alphafold=False) -# OR -model.set_opt(use_openfold=True, use_alphafold=False) -``` -#### How is contact defined? How do I change it? -By default, 2 [con]tacts per positions are optimized to be within cβ-cβ < 14.0Å and sequence seperation ≥ 9. This can be changed with: -```python -model.set_opt(con=dict(cutoff=8, seqsep=5, num=1)) -``` -For interface: -```python -model.set_opt(i_con=dict(...)) ``` #### For binder hallucination, can I specify the site I want to bind? 
```python @@ -142,12 +133,6 @@ model.prep_inputs(..., hotspot="1-10,15,3") ```python model.prep_inputs(..., chain="A,B") ``` -#### Can I design homo-oligomers? -```python -model.prep_inputs(..., copies=2) -# specify interface specific contact and/or pae loss -model.set_weights(i_con=1, i_pae=0) -``` #### For fixed backbone design, how do I force the sequence to be the same for homo-dimer optimization? ```python model.prep_inputs(pdb_filename="6Q40.pdb", chain="A,B", copies=2, homooligomer=True) @@ -168,14 +153,13 @@ model.restart(seed=0) - `design_hard()` - optimize *one_hot(logits)* inputs (discrete) - For complex topologies, we find directly optimizing one_hot encoded sequence `design_hard()` to be very challenging. -To get around this problem, we propose optimizing in 2 or 3 stages. - - `design_2stage()` - *soft* → *hard* +To get around this problem, we propose optimizing in 3 stages. - `design_3stage()` - *logits* → *soft* → *hard* + #### What are all the different losses being optimized? - general losses - *pae* - minimizes the predicted alignment error - *plddt* - maximizes the predicted LDDT - - *msa_ent* - minimize entropy for MSA design (see example at the end of notebook) - *pae* and *plddt* values are between 0 and 1 (where lower is better for both) - fixbb specific losses @@ -184,18 +168,26 @@ To get around this problem, we propose optimizing in 2 or 3 stages. - we find *dgram_cce* loss to be more stable for design (compared to *fape*) - hallucination specific losses - - *con* - maximize number of contacts. (We find just minimizing *plddt* results in single long helix, -and maximizing *pae* results in a two helix bundle. To encourage compact structures we add a `con` term) + - *con* - maximize `1` contacts per position. 
`model.set_opt("con",num=1)` - binder specific losses - - *i_pae* - minimize PAE interface of the proteins - - *pae* - minimize PAE within binder - - *i_con* - maximize number of contacts at the interface of the proteins - - *con* - maximize number of contacts within binder + - *pae* - minimize PAE at interface and within binder + - *con* - maximize `2` contacts per binder position, within binder. `model.set_opt("con",num=2)` + - *i_con* - maximize `1` contacts per binder position. `model.set_opt("i_con",num=1)` - partial hallucination specific losses - *sc_fape* - sidechain-specific fape +#### How is contact defined? How do I change it? +By default, 2 [con]tacts per position are optimized to be within cβ-cβ < 14.0Å and sequence separation ≥ 9. This can be changed with: +```python +model.set_opt(con=dict(cutoff=8, seqsep=5, num=1)) +``` +For interface: +```python +model.set_opt(i_con=dict(...)) +``` + # Advanced FAQ #### loss during Gradient descent is too jumpy, can I do some kind of greedy search towards the end? Gradient descent updates multiple positions each iteration, which can be a little too aggressive during hard (discrete) mode. 
diff --git a/af/design.ipynb b/af/design.ipynb index 90e3a32b..61e380c1 100644 --- a/af/design.ipynb +++ b/af/design.ipynb @@ -16,7 +16,7 @@ "id": "OA2k3sAYuiXe" }, "source": [ - "#AfDesign (v1.0.5)\n", + "#AfDesign (v1.0.6)\n", "Backprop through AlphaFold for protein design.\n", "\n", "**WARNING**\n", @@ -42,7 +42,7 @@ " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", " # download params\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" @@ -246,7 +246,9 @@ }, "source": [ "# binder hallucination\n", - "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure. To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder." + "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure.\n", + "To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.\n", + "By default, AlphaFold-ptm with residue index offset hack is used. 
To enable AlphaFold-multimer set: mk_afdesign_model(use_multimer=True).\n" ] }, { @@ -275,12 +277,6 @@ "outputs": [], "source": [ "af_model.restart()\n", - "\n", - "# settings we find work best for helical peptide binder hallucination\n", - "af_model.set_weights(plddt=0.1, pae=0.1, i_pae=1.0, con=0.1, i_con=0.5)\n", - "af_model.set_opt(con=dict(binary=True, cutoff=21.6875, num=af_model._binder_len, seqsep=0))\n", - "af_model.set_opt(i_con=dict(binary=True, cutoff=21.6875, num=af_model._binder_len))\n", - "\n", "af_model.design_3stage(100,100,10)" ] }, @@ -373,8 +369,7 @@ "af_model.prep_inputs(pdb_filename=get_pdb(\"6MRR\"),\n", " chain=\"A\",\n", " pos=\"3-30,33-68\", # define positions to contrain\n", - " length=100, # total length if different from input pdb\n", - " fix_seq=False) # set True to constrain sequence in the specified positions\n", + " length=100) # total length if different from input pdb\n", "\n", "af_model.rewire(loops=[36]) # set loop length between segments " ], @@ -452,4 +447,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/af/examples/2stage_binder_hallucination.ipynb b/af/examples/2stage_binder_hallucination.ipynb deleted file mode 100644 index 0202d96d..00000000 --- a/af/examples/2stage_binder_hallucination.ipynb +++ /dev/null @@ -1,281 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OA2k3sAYuiXe" - }, - "source": [ - "# AfDesign - two-stage binder hallucination\n", - "For a given protein target and protein binder length, generate/hallucinate a protein binder sequence AlphaFold thinks will bind to the target structure. To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.\n", - "\n", - "**WARNING**\n", - "1. 
This notebook is in active development and was designed for demonstration purposes only.\n", - "2. Using AfDesign as the only \"loss\" function for design might be a bad idea, you may find adversarial sequences (aka. sequences that trick AlphaFold)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "-AXy0s_4cKaK" - }, - "outputs": [], - "source": [ - "#@title install\n", - "%%bash\n", - "if [ ! -d params ]; then\n", - " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", - " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", - " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", - " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", - "fi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Vt7G_nbNeSQ3" - }, - "outputs": [], - "source": [ - "#@title #import libraries\n", - "import warnings\n", - "warnings.simplefilter(action='ignore', category=FutureWarning)\n", - "\n", - "import os\n", - "from colabdesign import mk_afdesign_model, clear_mem\n", - "from IPython.display import HTML\n", - "from google.colab import files\n", - "import numpy as np\n", - "\n", - "#########################\n", - "def get_pdb(pdb_code=\"\"):\n", - " if pdb_code is None or pdb_code == \"\":\n", - " upload_dict = files.upload()\n", - " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", - " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", - " return \"tmp.pdb\"\n", - " else:\n", - " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", - " return f\"{pdb_code}.pdb\"" - ] - }, - { - "cell_type": "code", - "source": [ - "#@title # Prep Inputs\n", - "pdb = \"4N5T\" #@param {type:\"string\"}\n", - "chain = \"A\" #@param {type:\"string\"}\n", - "binder_len = 50#@param 
{type:\"integer\"}\n", - "hotspot = \"\" #@param {type:\"string\"}\n", - "if hotspot == \"\": hotspot = None\n", - "\n", - "x = {\"pdb_filename\":pdb, \"chain\":chain, \"binder_len\":binder_len, \"hotspot\":hotspot}\n", - "if \"x_prev\" not in dir() or x != x_prev:\n", - " x[\"pdb_filename\"] = get_pdb(x[\"pdb_filename\"])\n", - " \n", - " clear_mem()\n", - " model = mk_afdesign_model(protocol=\"binder\")\n", - " model.prep_inputs(**x)\n", - "\n", - " pre_model = mk_afdesign_model(protocol=\"hallucination\")\n", - " pre_model.prep_inputs(length=binder_len)\n", - "\n", - " x_prev = x\n", - " print(\"target length:\", model._target_len)\n", - " print(\"binder length:\", model._binder_len)" - ], - "metadata": { - "cellView": "form", - "id": "HSgE99WALOE-" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@title #stage 1 - Pre-hallucinate binder scaffold\n", - "#@markdown ---\n", - "#@markdown ####Weights\n", - "#@markdown - Minimizing `pae` or maximizing `plddt` often results in a single helix.\n", - "#@markdown To avoid this, we start with a random sequence and instead try to optimize \n", - "#@markdown defined `num`ber of `con`tacts per position. \n", - "pae = 0.1 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "plddt = 0.1 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "helix = 0.0 \n", - "con = 1.0 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "#@markdown ####Contact Definition\n", - "#@markdown - The contact definition is based on Cb-Cb diststance `cutoff`. 
To avoid \n", - "#@markdown biasing towards helical contact, only contacts with sequence seperation > \n", - "#@markdown `seqsep` are considered.\n", - "\n", - "seqsep = 9 #@param [\"0\",\"5\",\"9\"] {type:\"raw\"}\n", - "cutoff = \"14\" #@param [\"8\", \"14\", \"max\"]\n", - "num = \"2\" #@param [\"1\", \"2\", \"3\", \"4\", \"8\", \"max\"]\n", - "binary = True #@param {type:\"boolean\"}\n", - "if cutoff == \"max\": cutoff = 21.6875\n", - "if num == \"max\": num = binder_len\n", - "\n", - "pre_opt = {\"con\":{\"seqsep\":int(seqsep),\"cutoff\":float(cutoff),\"num\":int(num),\n", - " \"binary\":binary}}\n", - "pre_weights = {\"con\":float(con),\"helix\":float(helix),\n", - " \"pae\":float(pae),\"plddt\":float(plddt)}\n", - "\n", - "# pre-design with gumbel initialization and softmax activation\n", - "pre_model.restart(mode=\"gumbel\", opt=pre_opt, weights=pre_weights)\n", - "pre_model.design_soft(50)\n", - "save_seq = np.asarray(pre_model.aux[\"seq\"][\"pseudo\"])\n", - "\n", - "# refine\n", - "pre_model.restart(seq=save_seq, opt=pre_opt, weights=pre_weights, keep_history=True)\n", - "pre_model.design(50, soft=0.0, e_soft=1.0)\n", - "save_seq = np.asarray(pre_model.aux[\"seq\"][\"pseudo\"])" - ], - "metadata": { - "cellView": "form", - "id": "TX1aPX0fXa6D" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@markdown ## display pre-hallucinated binder scaffold {run: \"auto\"}\n", - "color = \"pLDDT\" #@param [\"chain\", \"pLDDT\", \"rainbow\"]\n", - "show_sidechains = False #@param {type:\"boolean\"}\n", - "show_mainchains = False #@param {type:\"boolean\"}\n", - "pre_model.plot_pdb(show_sidechains=show_sidechains,\n", - " show_mainchains=show_mainchains,\n", - " color=color)" - ], - "metadata": { - "cellView": "form", - "id": "gz7wRJaGXs9e" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "HTML(pre_model.animate())" - ], - "metadata": { - "id": "5OJdtq8trTB4" - }, - 
"execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@title #state 2 - binder design\n", - "#@markdown ---\n", - "#@markdown ####interface Weights\n", - "i_pae = 1.0 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "i_con = 0.5 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "weights = {\"i_pae\":float(i_pae),\n", - " \"i_con\":float(i_con),\n", - " **pre_weights}\n", - "\n", - "#@markdown ####interface Contact Definition\n", - "cutoff = \"max\" #@param [\"8\", \"14\", \"max\"]\n", - "num = \"max\" #@param [\"1\", \"2\", \"4\", \"8\", \"max\"]\n", - "binary = True #@param {type:\"boolean\"}\n", - "if cutoff == \"max\": cutoff = 21.6875\n", - "if num == \"max\": num = binder_len\n", - "\n", - "opt = {\"i_con\":{\"cutoff\":float(cutoff),\"num\":int(num),\n", - " \"binary\":binary},\n", - " **pre_opt}\n", - "\n", - "model.restart(seq=save_seq, opt=opt, weights=weights)\n", - "model.design_3stage(100,100,10)" - ], - "metadata": { - "cellView": "form", - "id": "eCGc3J663NGz" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "#@markdown ## display hallucinated binder {run: \"auto\"}\n", - "color = \"chain\" #@param [\"chain\", \"pLDDT\", \"rainbow\"]\n", - "show_sidechains = False #@param {type:\"boolean\"}\n", - "show_mainchains = False #@param {type:\"boolean\"}\n", - "model.save_pdb(f\"{model.protocol}.pdb\")\n", - "model.plot_pdb(show_sidechains=show_sidechains,\n", - " show_mainchains=show_mainchains,\n", - " color=color)" - ], - "metadata": { - "cellView": "form", - "id": "Ec0BnP1ehH5w" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "L2E9Tn2Acchj" - }, - "outputs": [], - "source": [ - "HTML(model.animate())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YSKWYu0_GlUH" - }, - "outputs": [], - "source": [ - "model.get_seqs()" 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [ - "q4qiU9I0QHSz" - ], - "name": "2stage_binder_hallucination.ipynb", - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/af/examples/AF2Rank.ipynb b/af/examples/AF2Rank.ipynb index c2e0d402..2682cf09 100644 --- a/af/examples/AF2Rank.ipynb +++ b/af/examples/AF2Rank.ipynb @@ -26,7 +26,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "cellView": "form", "id": "zk6_tVpg9Bdi" @@ -61,7 +61,7 @@ "\n", " # alphafold params\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", "\n", " # download openfold weights (optional)\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2\n", @@ -108,7 +108,7 @@ " return o\n", " \n", "def plot_me(scores, x=\"tm_i\", y=\"composite\", \n", - " title=None, diag=False, scale_axis=True, dpi=100):\n", + " title=None, diag=False, scale_axis=True, dpi=100, **kwargs):\n", " def rescale(a,amin=None,amax=None): \n", " a = np.copy(a)\n", " if amin is None: amin = a.min()\n", @@ -122,14 +122,16 @@ " x_vals = np.array([k[x] for k in scores])\n", " y_vals = np.array([k[y] for k in scores])\n", " c = rescale(np.array([k[\"plddt\"] for k in scores]),0.5,0.9)\n", - " plt.scatter(x_vals, y_vals, c=c*0.75, s=5, vmin=0, vmax=1, cmap=\"gist_rainbow\")\n", + " plt.scatter(x_vals, y_vals, c=c*0.75, s=5, vmin=0, vmax=1, cmap=\"gist_rainbow\",\n", + " **kwargs)\n", " if diag:\n", " plt.plot([0,1],[0,1],color=\"black\")\n", " \n", " labels = 
{\"tm_i\":\"TMscore of Input\",\n", - " \"tm_o\":\"TMscore of Ouput\",\n", + " \"tm_o\":\"TMscore of Output\",\n", " \"tm_io\":\"TMscore between Input and Output\",\n", " \"ptm\":\"Predicted TMscore (pTM)\",\n", + " \"i_ptm\":\"Predicted interface TMscore (ipTM)\",\n", " \"plddt\":\"Predicted LDDT (pLDDT)\",\n", " \"composite\":\"Composite\"}\n", "\n", @@ -138,19 +140,31 @@ " if x in labels: plt.xlim(-0.1, 1.1)\n", " if y in labels: plt.ylim(-0.1, 1.1)\n", " \n", - " plt.show()\n", " print(spearmanr(x_vals,y_vals).correlation)\n", "\n", - "\n", "class af2rank:\n", - " def __init__(self, pdb, chain=None):\n", - " self.model = mk_af_model(protocol=\"fixbb\", use_templates=True,\n", - " use_alphafold=True, use_openfold=True)\n", - " self.model.prep_inputs(pdb, chain=chain)\n", + " def __init__(self, pdb, chain=None, model_name=\"model_1_ptm\", model_names=None):\n", + " self.args = {\"pdb\":pdb, \"chain\":chain,\n", + " \"use_multimer\":(\"multimer\" in model_name),\n", + " \"model_name\":model_name,\n", + " \"model_names\":model_names}\n", + " self.reset()\n", + "\n", + " def reset(self):\n", + " self.model = mk_af_model(protocol=\"fixbb\",\n", + " use_templates=True,\n", + " use_multimer=self.args[\"use_multimer\"],\n", + " use_alphafold=True, use_openfold=True,\n", + " debug=False,\n", + " model_names=self.args[\"model_names\"])\n", + " \n", + " self.model.prep_inputs(self.args[\"pdb\"], chain=self.args[\"chain\"])\n", " self.model.set_seq(mode=\"wildtype\")\n", " self.wt_batch = copy_dict(self.model._inputs[\"batch\"])\n", + " self.wt = self.model._wt_aatype\n", "\n", " def set_pdb(self, pdb, chain=None):\n", + " if chain is None: chain = self.args[\"chain\"]\n", " self.model.prep_inputs(pdb, chain=chain)\n", " self.model.set_seq(mode=\"wildtype\")\n", " self.wt = self.model._wt_aatype\n", @@ -183,51 +197,54 @@ " return score\n", " \n", " def predict(self, pdb=None, seq=None, chain=None, \n", - " input_template=True, \n", - " rm_tm_seq=True, rm_tm_sc=True, 
recycles=1,\n", - " iterations=1, model_name=\"model_1_ptm\",\n", - " tm_aatype=21, save_pdb=False,\n", - " output_dir=\"tmp\",output_pdb=None,\n", - " extras=None, verbose=True):\n", + " input_template=True, model_name=None,\n", + " rm_seq=True, rm_sc=True, rm_ic=False,\n", + " recycles=1, iterations=1,\n", + " output_pdb=None, extras=None, verbose=True):\n", + " \n", + " if model_name is not None:\n", + " self.args[\"model_name\"] = model_name\n", + " if \"multimer\" in model_name: \n", + " if not self.args[\"use_multimer\"]:\n", + " self.args[\"use_multimer\"] = True\n", + " self.reset()\n", + " else:\n", + " if self.args[\"use_multimer\"]:\n", + " self.args[\"use_multimer\"] = False\n", + " self.reset()\n", " \n", " if pdb is not None: self.set_pdb(pdb, chain)\n", " if seq is not None: self.set_seq(seq)\n", "\n", " # set template sequence\n", - " tm_aatype = np.full_like(self.wt, tm_aatype) if rm_tm_seq else self.wt\n", - " self.model.opt[\"template\"][\"aatype\"] = tm_aatype\n", + " self.model._inputs[\"batch\"][\"aatype\"] = self.wt\n", "\n", " # set other options\n", - " self.model.set_opt(template=dict(dropout=not input_template),\n", - " rm_template_seq=True, rm_template_sc=rm_tm_sc,\n", - " num_recycles=recycles)\n", + " self.model.set_opt(\n", + " template=dict(dropout=not input_template,\n", + " rm_ic=rm_ic,\n", + " rm_sc=rm_sc,\n", + " rm_seq=rm_seq),\n", + " num_recycles=recycles)\n", " \n", " # \"manual\" recycles using templates\n", " ini_atoms = self.model._inputs[\"batch\"][\"all_atom_positions\"].copy()\n", " for i in range(iterations):\n", - " self.model.predict(models=[model_name], verbose=False)\n", + " self.model.predict(models=self.args[\"model_name\"], verbose=False)\n", " if i < iterations - 1:\n", " self.model._inputs[\"batch\"][\"all_atom_positions\"] = self.model.aux[\"atom_positions\"]\n", " else:\n", " self.model._inputs[\"batch\"][\"all_atom_positions\"] = ini_atoms\n", " \n", - " if save_pdb:\n", - " os.makedirs(output_dir, 
exist_ok=True)\n", - " if output_pdb is None:\n", - " if pdb is None:\n", - " aatype = np.array(self.model._params[\"seq\"])[0].argmax(-1)\n", - " diff = np.where(aatype != self.wt)[0]\n", - " seq_diff = np.array(list(seq))[diff]\n", - " output_pdb = \"_\".join([f\"{i}{a}\" for i,a in zip(diff, seq_diff)]) + \".pdb\"\n", - " else:\n", - " output_pdb = os.path.basename(pdb)\n", - " self.model.save_pdb(os.path.join(output_dir, output_pdb))\n", - "\n", " score = self._get_score()\n", - " if extras is not None: score.update(extras)\n", + " if extras is not None:\n", + " score.update(extras)\n", + "\n", + " if output_pdb is not None:\n", + " self.model.save_pdb(output_pdb)\n", " \n", " if verbose:\n", - " print_list = [\"tm_i\",\"tm_o\",\"tm_io\",\"composite\",\"ptm\",\"plddt\",\"fitness\",\"id\"]\n", + " print_list = [\"tm_i\",\"tm_o\",\"tm_io\",\"composite\",\"ptm\",\"i_ptm\",\"plddt\",\"fitness\",\"id\"]\n", " print_score = lambda k: f\"{k} {score[k]:.4f}\" if isinstance(score[k],float) else f\"{k} {score[k]}\"\n", " print(*[print_score(k) for k in print_list if k in score])\n", " \n", @@ -237,38 +254,46 @@ "cellView": "form", "id": "1o-_Rl4hFfkR" }, - "execution_count": 2, + "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ - "#@markdown ### settings\n", - "seq_relacement = \"gap\" #@param [\"gap\", \"X\", \"A\", \"none\"]\n", - "mask_sidechains = True #@param {type:\"boolean\"}\n", + "#@markdown ### **settings**\n", "recycles = 1 #@param [\"0\", \"1\", \"2\", \"3\", \"4\"] {type:\"raw\"}\n", "iterations = 1 \n", - "model_name = \"model_1_ptm\" #@param [\"model_1_ptm\", \"model_2_ptm\", \"openfold_model_ptm_1\", \"openfold_model_ptm_2\"]\n", + "\n", + "# decide what model to use\n", + "model_mode = \"alphafold\" #@param [\"alphafold\", \"alphafold-multimer\", \"openfold\"]\n", + "model_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n", + "\n", + "if model_mode == \"alphafold\":\n", + " model_name = 
f\"model_{model_num}_ptm\"\n", + "if model_mode == \"alphafold-multimer\":\n", + " model_name = f\"model_{model_num}_multimer_v2\"\n", + "if model_mode == \"openfold\":\n", + " model_name = f\"openfold_model_ptm_{model_num}\"\n", + "\n", "save_output_pdbs = False #@param {type:\"boolean\"}\n", "\n", - "tm_aa = -1\n", - "if seq_relacement == \"A\": tm_aa = 0\n", - "if seq_relacement == \"gap\": tm_aa = 21\n", - "if seq_relacement == \"X\": tm_aa = 20\n", + "#@markdown ### **advanced**\n", + "mask_sequence = True #@param {type:\"boolean\"}\n", + "mask_sidechains = True #@param {type:\"boolean\"}\n", + "mask_interchain = False #@param {type:\"boolean\"}\n", "\n", - "SETTINGS = {\"rm_tm_seq\":seq_relacement != \"none\",\n", - " \"rm_tm_sc\":not mask_sidechains,\n", + "SETTINGS = {\"rm_seq\":mask_sequence,\n", + " \"rm_sc\":mask_sidechains,\n", + " \"rm_ic\":mask_interchain,\n", " \"recycles\":int(recycles),\n", " \"iterations\":int(iterations),\n", - " \"model_name\":model_name, \n", - " \"tm_aatype\":tm_aa, \n", - " \"save_pdb\":save_output_pdbs}" + " \"model_name\":model_name}" ], "metadata": { "cellView": "form", "id": "6G7XWsStB1sB" }, - "execution_count": 3, + "execution_count": null, "outputs": [] }, { @@ -295,12 +320,12 @@ "\n", "# setup model\n", "clear_mem()\n", - "af = af2rank(NATIVE_PATH, CHAIN)" + "af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS[\"model_name\"])" ], "metadata": { "id": "iDCRJjdSIG0g" }, - "execution_count": 4, + "execution_count": 42, "outputs": [] }, { @@ -313,16 +338,16 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "id": "nKGf21nk6Ofx", - "outputId": "3c800748-71ac-4480-b1b7-25456109c929" + "id": "UCUZxJdbBjZt", + "outputId": "5d290158-bcc1-4f02-f3a9-af13d1946248" }, - "execution_count": 5, + "execution_count": 43, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "tm_i 1.0000 tm_o 0.6650 tm_io 0.6650 composite 0.2399 ptm 0.5467 plddt 0.6599\n" + "tm_i 1.0000 tm_o 0.6650 tm_io 0.6650 composite 0.2399 
ptm 0.5467 i_ptm 0.0000 plddt 0.6599\n" ] } ] @@ -337,11 +362,16 @@ "\n", "# score the decoy sctructures\n", "for decoy_pdb in os.listdir(DECOY_DIR):\n", - " decoy_path = os.path.join(DECOY_DIR, decoy_pdb)\n", - " SCORES.append(af.predict(pdb=decoy_path, **SETTINGS, extras={\"id\":decoy_path}))" + " input_pdb = os.path.join(DECOY_DIR, decoy_pdb)\n", + " if save_output_pdbs:\n", + " output_pdb = os.path.join(\"tmp\",decoy_pdb)\n", + " else:\n", + " output_pdb = None\n", + " SCORES.append(af.predict(pdb=input_pdb, output_pdb=output_pdb,\n", + " **SETTINGS, extras={\"id\":decoy_pdb}))" ], "metadata": { - "id": "ye1ScsVBajSo" + "id": "ChgI637YCArk" }, "execution_count": null, "outputs": [] @@ -357,29 +387,29 @@ "base_uri": "https://localhost:8080/", "height": 497 }, - "id": "Ouy6R-DoC-zO", - "outputId": "01b3fb93-ebfb-4416-8641-43b4da802c37" + "id": "ZUEaAlP5CR8h", + "outputId": "7322efa2-96d6-4866-d1ad-8c4d46a771f8" }, - "execution_count": 7, + "execution_count": null, "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9286667300616703\n" + ] + }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAHPCAYAAAAWIwD+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd5xU1d3H8c9vO22X3kVQFFSs2HuvMdaoiUk0msSaxxJ9fDSJisZoorHHGCsaS2LUGLHGhgUVFBULCIiA0ouwy7LL1vP8cc7s3B1m2+zuzC5836/XvHbnzC3n3pm5vzn1mnMOERERaX9Zmc6AiIjIxkJBV0REJE0UdEVERNJEQVdERCRNFHRFRETSREFXREQkTRR0RURE0kRBV0REJE0UdEVERNJEQbcBZna1mW0w03WZ2elm5sxs5yaW26COu6PT+W45M9s/fJb3z3ReNkb6zLZOpwq6ZtbdzMaZ2Utm9l344p2e6XxJ2zCz8WZWmpA2MbzPE5IsPzy8dkkkLXZBjj2qzOxrM3vYzDZLstyJDeTlztiFJXaRacZjYpudDJFOxMyuMLNjM52PziAn0xloob7AlcA3wDRg/3bc1++BG9px+x1VRz3u75nZWOfc1GYufzvwAZAL7AT8EjjKzLZ1zi1q4b6fBr6KPO8O/BX4d3gtZmkLtyupeQvoAlRmOiMbqWTXiCuAJ4Fn0p+dzqWzBd3FwCDn3JJQTfpBe+3IOVcNVLfX9pMxMwMKnHPl6dxvVCaOuxm+AXoAVwHfb+Y6bzvnngz/P2hms/CB+DTg+pbs3Dn3KfBp7LmZ9cUH3U+dc4+0ZFvSes65WmBdpvOxseqg14h2YWZZQJ5zrs0+b52qetk5V+GcW9KcZc1snpk9F6oRPzSzcjP7LNYOZGbHh+frzGyqme2YsH7Sdgsz+7GZTTGzMjNbZWZvmdmhkdeLzGy0mRW1II+HmdmHQDlwVnjtZ2b2upktM7MKM5tuZuc0so29Q77WherUnzZj/73COgvMbFRDxx2qTu80s2PN7POQny/M7PAk24yd73VmNsfMzmqDNqA1wC3A0Wa2U4rbeD38HdGKfLRKeI8+iJ6bRpb9cfhcloemlH+Y2SZJltvNzF4In8W1ZvapmV2QsMyBZvZ2eH21mf3HzLaKvH5AeI+PS7L9H4XX9gjPB5rZg+EzU2Fmi8P2hjdx7BOTVb+bb1KYl5B2Sjj2NWZWEr6nF0ReX69NN2z/czPb2szeCN/PhWb2v0n2uamZPRvOxzIzuyV8B5vVTmxmQ8zsfjNbFM7BXDP7q5nlRZbZzMz+Fd67MjN738yOSthO7DhOMrOrQn7XmNmT4TqSb2a3hjyWhvOen7CN2HfzVDObafHr2b5J8r2jmb0Yzmmpmb1mZrsnLJMb8jI7bGulmb1jZodElqn3fQ7/dwNOs3hTy/iE8/WAmS21+LXjjKbOc1j3kLD/1SHPM83sDwnLFIQ8zQp5XmxmT5vZ5pFlupnZn83s25CHmWZ2iZlZI+fzC6ACOLy1xxHV2Uq6LTUSeAz4G/AIcAkwwczOBv4A3BWWuxx4wsxGhV/RSZnZVcDVwLv4au5KYDfgQOC/YbHjgAeBnwHjm5HHUcDjIY/3AjND+jnAF8Cz+F+VRwN3mVmWc+4vSY7zSeB+4CHgDGC8mU11zn3RwLH0BV4BegP7OefmNJHPvYHj8edsDfA/wFNmNsw5tzJsc0fgJXyNxFVANv48LW/GeWjKbcBF+PPf3NJuVOwLuLIN8tJiZrYt/jOyHH8MOcA4klRJm9lvgGuBJ4D7gH7Ar4C3zGxH59zqsNwhwHP4830bsATYCvheeI6ZHQy8CHwd9tslbGuSme3knJsHTAS+BU7FV5lHnQrMcc69F54/BWwD3AHMA/oDhwDDwvNWCcf0OPAacFlI3grYK3ZMjeiF//w9jT93JwJ/NLPPnHMvhu13w/8AG0T
8nP0IOKCZ+RsMTAF6AvcAXwJDwr66ApVmNgB/jeiKr11Zia9hedbMTnTOJZ7jy/E/uG/Af5d/BVQBteGYrgZ2B04H5gLXJKy/H3By2FcFcC7wkpnt6pz7POR7G+BtoAT4U9j+WcBEM9vPOTc5bOvqkJ/7wnEWAjvjm2heaeC0/CSy/D0hbU7Y7wDgfcABd+I//0cA95tZoXPu1ga2Gcvzc/hapivDsY3EfxZiy2SHZQ4C/oF/T3vgP5NjgDkhsD6Lf4/vBz4BDgNuxL93FyXs+kDgpJDfFcC81hzHepxznfKB/yA44PQGXp8XXt8jknZoSCsDhkXSfxnS94+kXe1PT93zkUAN/gudlbAvi/x/emP5aiCPhyV5rUuStJfwF8Bk29gnktYPX/12U5J87QwMBD7HfzE2TdheveMOaQ7/gd88krZdSD8/kvYssBYYnHDeqhK32cD5GA+UJqRNBD4P/18Z9rlTeD48PL8ksvz+Ie1n+D4Ag4Aj8RerWmDnhOVObCAvdzaU57BdB1zdgs/rv/EX1ujnbiv8D6ro52zTkHZFwvpjwnm8IjzPxgfSeUDPRj6PH+MDe++E964GeCiS9ofwmSlK+BxVxY4TH2jqne8WHP9EYGID7/m8yPNbgWIgu5Ftxd67/RO274CfRNLy8D9InoykXRyWOyaSVgDMSNxmA/t+KJy7nZO8ZuHvLWFbe0de6x7er7mE60fkOD4DciPLPhY+qy8kbP/d6LmKfDcdMDaSNix81p5O+PxVAJtF0gbhg/CbkbRPgOeaOAdXs/41ohQYn2TZ+4BFQJ+E9MeB1SS5zkWWuTAcW99GlvlZWOaiRt6PY8Iyv0l4/V/hPEevay68v1u31XEkPjpV9XIKprv4L3SA2K+5151z3yRJ34yGHYuvjr/GJZSGXTj74f/xzjlzzo1vZh7nOudeTkx0kXbdUNXUF3gT2MzWr7qe7px7O7LucnyJOdnxDA3byQX2dc7Nb2Y+X3WR0rDz7ZwlsX2EX5wHA8+4SEcl59xX+JJWW7gNWIUvRTflAfyv0UXA84TqL+fch22Ul2YL5+Yw/Lmp+9w552YAie/98fjP2RNm1jf2wJfIZhMvke2Iryq/1YWSb2S7sV7Xg4Ad8BfD7yKvf4ovtRwZWe1hIB9fYos5GV8ij7Vbl+Nrd/Y3s14tOgnNtxr/Xh3S1IJJlBLPK865SnzpK/o9OBxYiP+BGFtuHb6WqVHm2/eOBSYk+xxFrgNHAlOcc+9EXivFlwKHA1snrPqwc64q8nwyYPjPMAnpm5hZYg3ley7SwTB8xv4DHGZm2eHzdyj+8/d1ZLnF+AC/t5kVhuTVwDZmtkUDp6HZQgnzBGBCeBr9PL8MFOFL0A2Jfa6PCec+mRPwpdE7El9IeD9q8DUBUX/Gn+cjEtLfdM5Nb8PjqGdDD7rRwIpzrjj8+23CcrH0xi4km+N/FU1vZJlUzE2WaGZ7mdmrZrYW/+Fbji+NgH+To75hfatIfjx/x1cJ7uecW9iCfDa1j/74qsuvkiyXLK3Fwvt3K/B9S2iDT+Ia/IX7QHzJbrBz7u9tkY8U9MOfm9lJXpuZ8HwL/IVgNv49jz62wp9niFeXf97IfjdtYB/gS3Z9Q3Urzrkv8R0TT40scyrwfvjhhHOuAl/lewSw1Hx/hv81s4GN5KGl7gJmAS+abzd+wJL0HWjAgugP4CDxe7ApvrYocbnmfEb74atbGzvnsX00dM5jr0clfrcau05lsf73P9nnaha+ertfeHRtJE9ZQKy/wJX4Go1Z5tvSbzSz7ZKs1xz9wrZ+yfqf5QfDMv2TrwrAP4FJ+FLmUvP9Gk5KCMCbAzOd79zVkE2BRc65NQnpDb0fidfk1h5HPRt6m25NC9OtgfT2tF5P5dAB4DV8e9HF+C9fJf4X20Ws/2OpJcfzNPBT4AJ8201zdZRzFmvbvQpf/dSQz5xzrzbyeqw3YpcGXu9KZnrIZuGruI4g+TkvTZLWVh4GbjOzofh
S7+7A+dEFnHO3mh8zfSy+9H4tcLmZHeic+7iRbTuSf1ayE7a/zMx2CNs+Ijx+ZmYPO+dOayL/HeUz2lId5jrlnHsrXH+OwZeOfw5cZGZnO+fua+HmYtepR/DV8sl82kA6zrly8x3CDgCOwtdSnAy8bmaHOucaOj+tlXhNbtVxJNrQg25bmoM/+Vvj2z3a09H4i973o9WRZtaszh5NuAP/q/4aMyt2zrXVmNxl+CA1MslrydJS4pwrNrNb8e1KDX0BmiNWrT6qgddHRZZpreX4L3KyKrvE/c/BX1TnOudmNbLNWFX/GKChHxeNHeNoYIVzbm0k7R/AzcAP8T9GqvCljXpCM8OfgT+HashPgF8DP24kv6tI3tyRWMqIVQtPwHd6zMKXfs8ys2tjpe5WmA9sbWaWUNptzmd0Ob5JZUwz9tHQOY+93paSfa62xPddiXViLGskT7VEStWhKeJB/FC77vhx0VfjS5wNSaw5IOx7Db59vrEfwA1v1DflvRYeF5vZFcB1+ED8Kv57sJuZ5SZU0UfNBw42sx4Jpd3mvh+tPo6oDb16uS09g/9wXpnYvhDtdm4tGDLUiNgvuHrbxXcaaDXn3LXATcD1lmQYUorbrMF/CY4NPTwBMLORrN9m0lq34qvcr0x1A6E96xPgx2bWM/qamY3Fl/LapC06nJuX8edmWGQ/W+FLdFFP49//q5IMZzAz6xOefoSvBrswSf4t7Dd2jKdFlzGzMfhSzAsJ+VyBP+Yf46uWXwppsfW6mllBQn7n4C9I+TRuDjDazPpFtrc9kZ6oIa1P9Hm46MZKEU3tozlexvdYresBH47pF02tGPLyDH7o2nrTqUberxeAXS0MswqvdcNXT86j7Zuo9rDIUDrzQ8uOAf7rnKsJn7//4ttGh0eWG4Dvuf2Oc64kpCWe/1L8j/Smzv1afBVsdN0afG/3E8Jnrp7oZyEZM+udJDlW4Inl5yl8x8bzExdMeD+ykyxzEf7HQqPf89YeR6JOV9I1s/Pxb27swn50qA4DuCPSbtumnHNfmdl1wO+At83saXxvwF3wnXViVbUtHTKUzH/x1ckTzOxv+J6Pv8CXJgelegxRzrlLQyD/i5mtcW0zycPV+Iv5JDP7K/EP+uf4Dj1tIpR2b6N5HaoaczH+IvyJ+XGFi/Dtpr/E93pt0SQaTbgKXz32tpndhf/u/Qo/LKyuzcw5N8fMfhv2PdzMnsEHtRH4z9Y9+F7pteEH04SQ/wdDnkfjh/TEgvml+IvKe2Z2P/EhQ8X49yvRw/jhZ+A/61FbAq+Z2RP4wFEd8jQAX0puzAOE8x3y0R84Oxx/YWS5+8LF9nVgAb4k/Cv8xXYGrfc3/Gfy8fAZWoz/gRFrSkhWYou6Av8Zf9PM7gl5GgT8AD+sbjV+6M8P8e3StwPf4YcMjQBOSOyI2QY+x5/X6JAhqP/9+C2+j8M74fNXjR8ylA9ExzJPNz+eemrI9874znV3NpGHqfjS5MX479Fc54ch/R++VDrZzO7Ff2564zseHRz+b8iVoXr5eXxptH84tgVArJPaw/jmspvNbFf8sKhuYdt34TuUTQDeAK4LPzqm4d/DY/AdEZsaLkkrj6O+5nZz7igP4kNkkj2GJyy3Xtf3sNydCWnDWX/oydUkGTKCD6Yf4b+k3+GHKhwcef10WjZkKGn3fHwV8zR8teRc/Bcj1j2+Occ5kcgQjUi+do6kZeF7L1YRhlAkO+5k5yyy7/EJaQeG81OB/4V8Jr5UXd6M8zGeRoYMJaT3xF/gEt+3/WlkKFCS7eyG/1J+F87DAnxP1iGNrNPiIUNhvX2BD8O5mYO/6DX0OTsefwEpDY8Z+AvflgnL7YX/kVYSlptGZBhXWOYg/EWqDB9snwW2aiCPeeFcrMbPjhZ9rU/Iw4ywr9X4sYs/aObxnxqOuwI/lOlQ1h8ydAL+h9DSsNx84G5gYJL3eP9mfE7qbT+kjcCP7SzD/5C9KZxvB+zWjOMYhm/aiDWpzAn
nJS+yzGb4ISmr8N/hycBRCdtJ+lklyXc1+t0kMoSG+LjRU/Gdp9bhv3/7J8l3bBz9GnzJ9HUiQyrDMr8JeV0Vzs8M/A+N3MR8JKw3Cj8qoizkaXzktf4hj9/gCxOL8bViv2jiPB+Ir1lYGD4LC/HXqy0SluuCn5ry68j2/0X94VHd8U0nC8Mys/DzNljCtpJe61pzHImP2DgmSWBm1wKXO+c6XW1ARxNKa9s451o9DEHal/nhKIvww2LOzHR+0sXMLsSPrx3qWtarP6PMzwb1F+fcetWr0jGpTbdhg/Djv6QFzKxLwvMt8L2uJ2YkQ9JSx+KHSDyc6Yy0lySf0QJ8rcPszhRwpXNSKS6B+du/HYdvo3kuw9npjL4O7aNf49vjzsFXxfwpk5mSxpnZbvi25d8BHzvn3sxwltrT02b2Db6duAjfcWw09ccoi7QLBd317YvvgDAR3/FDWuYlfCeSgfh2mPfwUxcmG8AvHcc5+ODzCb5NcUP2Mn786an4zn7TgVOcc+sNjxJpa2rTFRERSRO16YqIiKSJgq6IiEiabHRtumGWksH4sWoiIrJx64G/IUJa2lo3uqCLD7gLMp0JERHpMIbiJ85odxtj0F0D8O2331JYWNjUsiIisoEqKSlhk002gTTWfG6MQReAwsJCBV0REUkrdaQSERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0yWjQNbN9zWyCmS0yM2dmxzZjnf3N7CMzqzCzr8zs9DRkVUREpNUyXdLtBkwDzmvOwmY2AngeeAPYAbgVuM/MDmu3HIqIiLSRnEzu3Dn3IvAigJk1Z5WzgbnOuV+H5zPMbG/gIuDldsmkiIhIG8l0Sbel9gBeTUh7OaSLiIh0aBkt6aZgILA0IW0pUGhmXZxz5YkrmFk+kB9J6tGO+RMREWlQZyvppuJyoDjyWJDZ7IiIyMaqswXdJcCAhLQBQEmyUm5wPVAUeQxtv+yJiIg0rLNVL78HHJmQdkhIT8o5VwFUxJ43s8OWiIhIm8v0ON3uZraDme0QkkaE58PC69eb2cORVe4GNjOzP5nZaDM7FzgJuCXNWRcREWmxTFcv7wx8HB4AN4f/rwnPBwHDYgs75+YCR+FLt9OAXwM/d85puJCIiHR45pzLdB7SyswKgeLi4mIKCwsznR0REcmQkpISioqKAIqccyXp2GemS7oiIiIbDQVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTRR0BUREUkTBV0REZE0UdAVERFJEwVdERGRNFHQFRERSRMFXRERkTTJeNA1s/PMbJ6ZrTOzyWa2axPLX2hmM82s3My+NbNbzKwgXfkVERFJVUaDrpmdDNwMjAN2AqYBL5tZ/waW/xFwQ1h+K+BM4GTgD2nJsIiISCtkuqR7MXCvc+5B59x04GygDDijgeX3BCY55x5zzs1zzv0XeBxotHQsIiLSEWQs6JpZHjAWeDWW5pyrDc/
3aGC1d4GxsSpoM9sMOBJ4oX1zKyIi0no5Gdx3XyAbWJqQvhQYnWwF59xjZtYXeMfMDJ//u51zDVYvm1k+kB9J6tGqXIuIiKQo09XLLWJm+wNXAOfi24CPB44ys981strlQHHksaCdsykiIpJUJku6K4AaYEBC+gBgSQPrXAv83Tl3X3j+mZl1A+4xs+tC9XSi6/GdtWJ6oMArIiIZkLGSrnOuEpgKHBRLM7Os8Py9BlbrCiQG1prY6g3sp8I5VxJ7AGtalXEREZEUZbKkC74E+pCZfQhMAS4EugEPApjZw8BC59zlYfkJwMVm9jEwGRiJL/1OcM7VJG5cRESkI8lo0HXO/dPM+gHXAAOBT4DDnXOxzlXDqF+y/T3gwt8hwHJ8IP5N2jItIiKSInPOZToPaWVmhUBxcXExhYWFmc6OiIhkSElJCUVFRQBFofmx3XWq3ssiIiKdmYKuiIhImijoioiIpImCroiISJoo6IqIiKSJgq6IiEiaKOiKiIikiYKuiIhImijoioiIpImCroiISJoo6IqIiKSJgq6IiEiaKOiKiIikiYKuiIhImijoioiIpImCroiISJoo6IqIiKSJgq6IiEiaKOiKiIikiYKuiIhImijoioiIpImCroiISJoo6IqIiKSJgq6IiEiaKOiKiIikiYKuiIhImijoioiIpImCroiISJoo6IqIiKSJgq6IiEiaKOiKiIikiYKuiIhImijoioiIpImCroiIpN2yFTDxXVi+ItM5SS8FXRERSav3p8KIXeGAE2DzPWDqtEznKH0UdEVEJK2uuxXWVfj/15bBDXdmNj/ppKArIiJplZ0d/9+A7I0oEm1EhyoiIh3B1ZdAYXf/f1ER/PaizOYnnXIynQEREdm47DAG5n8IX82FLTaDHt0znaP0UdAVEZG0K+wBO22X6Vykn6qXRUQEgNpauPluOO5ncMMdUF2d6RxteFTSFRERAG65By4ZB2bwn5ehohKu+nWmc7VhUUlXREQAmDjJ/3XOP157O7P52RAp6IqICAC77OhLuQBZBnvsnNn8bIhUvSwiIgBc/iuoqIA33oW9doFrLs10jjY85pzLdB7SyswKgeLi4mIKCwsznR0R2QCsKfXVsYU9Mp0TaYmSkhKKiooAipxzJenYp6qXRURa4fo7oec20HMMXPnnzOVj1Wr4YiZUVmYuD9I0BV0RkRTN/Qau+BPUho5H194GM2anNw/OwSNPwaDtYcz+sO0BG9+dezoTBV0RkRStWbt+Wklp+vZfWQlH/Ah+cr4f3gMwZx7c8UD68iAto6ArIpKiMaPg4L3jz/fZFcZum779P/EsvDyxfpoDqqrSlwdpGfVeFhFJUVYWvPAQvPiGr2I+Yn/ISeNVdW1Z+Ccb6AIYdMuFs3/a8DrfLoK7Hvb5PP80GNAvDRmVOuq9LCLSSX23CsYeBvOK8ffIA6iCHIOu3eDhm+GYg+LLl6yB0QfAspWAg+GbwBevQn5+BjLfAWyUvZfN7Dwzm2dm68xsspnt2sTyPc3sL2a22MwqzGyWmR2ZrvyKiHQUvXvBy//EX8kNqAaqoLoSSlbBCefUb3f+6HNYvAxqaqCmFubMhy/nZCbvG6uUg66Z5ZjZwWZ2lpn1CGmDzazZN2kys5OBm4FxwE7ANOBlM+vfwPJ5wCvAcOBEYBTwC2BhqschItKZuSwYMASsAKit/1pNJSyJ9GQesUn8BvJmkJ8HQwelLatCikHXzDYFPgP+A/wFiLUKXAbc1IJNXQzc65x70Dk3HTgbKAPOaGD5M4DewLHOuUnOuXnOuTedc9NSOQ4Rkc7s26WwyxmwohxcAfQZUf/1HkUwYkj8+aZD4R93wuabwqjN4N/3Qp9e6c3zxi7Vku5twIdAL6A8kv5v4KCkayQIpdaxwKuxNOdcbXi+RwOrfR94D/iLmS01s8/N7Aozy275IYiIdC5ry+Drb/3fJ16E718Ia8p8VTHAyjLI7QnkAgVQZvDNkvrbOPEoeOg
OOO5EWLbW385P0ifVfnb7AHs65yotNju2Nw8YknSN9fXF97lbmpC+FBjdwDqbAQcCjwJHAiOBu/AfsXHJVjCzfCDaTUATtYlIp1G6Fq57AF6aDJ/Nhpp1UJAF6yrxV74u8WVzs6EKoBtgUAO8Ow1WFcPy72DfneGD6XDg2WBZvm135nz4w3mZOLKNU6pBNwsfMBMNBdaknp1m7XcZ8EvnXA0w1cyGAJfSQNAFLgeuasc8iYi0iweehTOvwV/5Yp2lusK6svB/LRALvi4yPjebuiv0OTdAaWjXHb0Z7LeXH+pUXePTHn1RQTedUq1e/i9wYeS5Cx2oxgEvNHMbK/A/xAYkpA8Alqy/OACLgVkh4MbMAAaG6upkrgeKIo+hzcyfiEjGrFgdAq7DB9jYw+GLS7FiTxW+J8zayGuRK3tpOXXFqy+/huKSeHV0djaM1BUxrVINur8G9jKz6UAB8BjxquXLmrMB51wlMJVIG7CZZYXn7zWw2iRgZFguZktgcdhesv1UOOdKYg/atyQuItIqzsGNj8G+5xLvjezCg8jzKMNfzRu62UFk+SP3hDO/D70LYbdt4AHVA6ZVypNjmFkOcDKwPdAd+Ah41DlX3uiK9bdxMvAQcBYwBV96PgkY7ZxbamYPAwudc5eH5TcBvgjr3AFsATwA3O6cu66Z+9TkGCKScavWwt1vQkU1/HxvGNrbp9/+JFxwB74EW4Efe9uVeCm3lvrdVxNlEW84NBg1HGbNAFcDh+4FE+6CvIbqBTcymZgcI6U2XTPbF3jXOfcovlNTLD3HzPZ1zr3VnO045/5pZv2Aa4CBwCfA4c65WOeqYURGnjnnvjWzw4BbgE/x43NvA/6YynGIiGRCdQ3sdyN8sciPl/3bW/D+pTDpA7jnSXzDm8O31cYesaBbnWSDsSro2P/gr+75MHOZH0r072tg7x38/iRzUirpmlkNMMg5tywhvQ+wzDnXYYfwqKQrIpk2cwmMvgJYjQ+SXaDLl1AeK8F2w89+0B1f4o0tBz4gFyds0BFvLIyVdLvE07Kz4IofwzVntsvhdFqdpqRL/DdXoj745nwREQmcg8XLoVchdCmAvt2ABfgAWgPMg/KKsHA3oGd4GPGr9Hfh+Tp8aTcWZGvDOhk0ms8AACAASURBVLH/oyXZUAJ2DrpspPMrdzQtCrpm9nT41wHjzawi8nI2sB3wbhvlTUSk0ytdC4f/EiZ9BN26wDN3QmVsSI/h22cL8LMJlOGrkqNtrhaWrcJfsWMNbrG/WcTbfNcSH1pUQ90VftQmcO6x7XWE0hItLenGKjUM3ws42pxfCbwP3NsG+RIR6dTWrIU/3guvvgdTPvNpZevg9N9Al374K2hOeMTG13bBl2LX4ku6sfrE2HCgWnyVcxY+qBL+j12Jc+DAsTB/KXy9CJzB8KEw6S4oavas+NKeUm3TvQq4yTnX6aqS1aYrIulwyBnw+mRftVt3mQ2TW9Qx/Bx5Dh84HfGqYsMH4Qp8gI4NC8ohfkeh7LC9srCddbDsX3D2A/D0e/Ftbd4fplwFvRV46+k0t/Zzzo3rjAFXRKStfbsKLngaznkCvgzjLmpq4LX3/bzGdQG3C76UGhUbfxuras7HB8/u+ICZhS/lxgJudlimEH/rl374GfD7UNeW+/zH8PRU6gX3Ocvgn5Pb8KAlZc2uXjazj4CDnHOrzOxjknekAsA5t1NbZE5EpCOrqIa9b4eFoeHtnx/D7N9Cn25+fOzs+WH2p8H4+fAqWX++PcO34WYTv6pGO0PF2nNjgTcy1zJV+PbgPGAtWA7kRK/qkaFEBbmtOlRpIy1p0/0PvqID4Jl2yIuISKcyezl8syr+fFU5fLQADhkFE/4K54yDL1bC4ljAy8eXSlf6AOny8CXbbsRLvbF5lSE+tzLUHxYUVQuUUteB6n//BvuNgjdnUTdV5P6j4Ye7t9VRS2ukPCNVZ6U2XRFpKyXrYMhVsLYCXDVkV8MXV8Oogf71dZVw0o0
w4UPivYodsAoG9Iel84AR1J/YYi2+k1ReWKcWWE68F3PvkB4L0gX4cbwR//w/GNwfKmtgWF/fpqtJMdbXadp0zWwTMxsaeb6rmd1qZr9su6yJiHRsPfLh8NHgSoAvoeYrOOIS+GwO/HE8bHc+TJiCD5zVxGeU6gJLYz2OY+kQD8rRmxxkEe9cFW7ZF1OQD9/fZv18De0Le4+GA7eBkQMUcDuSVG948BhwAICZDcTfeH5X4Dozu7KN8iYi0qF9sQSenEa9dtr5y2Cvc+Dyu2F24t3Cq4lPYOHwpdYyfMNdLfGJLwhp+fgOUd3937xog6DB4N5w6A7Uv9FqFuw6qo0OUNpcqkF3DP4GBeBvUPCZc25P4FTg9DbIl4hIh7euav20WgdrysAVsP7dgSBeeo3NNpWDD7Rl+KAbrWquDct2hx23g0/ugMLQKznL4LofwhFjoWsXyM4Hy4Xj94KcDjsRr6Q6DWQu8U5VBwPPhv+/BAa1NlMiIp3BTkPhqK3h+WLg25CYT3ze4+gkFjGW5P+Gqn8jV+hj94THJ0NJtU8f1MeXcnt3hw//DI+/BQN6wi8ObeVBSbtKNeh+AZxtZs8DhwC/C+mDgZVtkTERkY4uKwv+83N4+wCYuwzOeAH4PLzo8AE41o6bpFRct1ws6BYSH5fbg7oOUznZMKgIzo7N95cFC1fBvybDWQfBVpvANae2wwFKm0s16F4G/Bu4FHjIOTctpH+feLWziMgGr6YG3pkG0+ZDzqLInfcs8oiNw60mfts+iAfc2ExUsXmX1+GDbujJfNyO8Mu7IjsNt/rrppsYdDopBV3n3EQz6wsUOucio9S4B98yISKyUbjsEbjthRBHo71kHD7AZuOvtNn4YBp9PfY3drOCAnxgXoOfe7kLnLcL3P1Kwk5r4cid4CSNve10Ui3p4pyrCTet3zskzXTOzWubbImIdA7PfgQuNuVirBRbRXyoT6zjVGJxJLEdtyqyTE5Ypxq2GApd82FN5PYyd/4Mzj1MQ4E6o1TH6XYzsweAxcBb4bHIzO43s66Nry0ismEoKYcFVfjq3lx8STV2A/lu+DbdWE/ihqZhzAmvdSU+7/IQfAAvh7tehfHnxadxPG5XOOsQBdzOKtWS7s3AfsDRwKSQtjdwO/Bn4JzWZ01EpON66Vt4drqf9aleqbUr8aE/sakba/CBtRA/7iN2B6G1IT12n9wafLAOJd7sLN+B6vjdYdVOfuarPj3ScHDSblINuicAJzrnJkbSXjCzcuAJFHRFZAN293Q4ZxJYuL1eVq0fnwv4gFqAv1VfFfFOVLX4IDwIH1wrwmuxq3Cs3rEAsit8B63h/eDuM0Nynn9I55Zq0O0KJM61ArCM+neLFBHZ4Nw/0/91OWBbQe9lsGJZZAHDzzaVB6wIf9cAq4CB+CBcgb8CR6d8zIZ3z4fthvvZp3JT7nUjHVWqM1K9B4wzs4JYgpl1Aa4Kr4mIbLCGdYfsWJVyV6jqnWQhw5d0S4F5wHf4mQwK8ffBHYIvEceqoYHcLNhjNHQrUMDdUKX6tl4AvAwsMLPYGN3t8S0Zh7VFxkREOqLqGij+CGqygEJ/s4Pib/BX09jsU7Gb0lfjb+VXE2Jr9H64OcBIYA7+9n01cNHhaTwQyYhUx+l+bmZb4OdaHh2SHwcedc6VN7ymiEjn9vxUeO2DhMSBxOdShviN5/Fp2QNgswKYHbsFX6z9Nza3cj707wE3nNSeOZeOoDXjdMuAe5tcUERkAzBlDtz/FixbleTFWMAN01FZTf37HNRkQ8/e+Krm5UARPuCujK+/vBSc01CgDV3KQdfMRgG/ArYKSTOAO51zX7ZFxkREUvHFalhUDnv2g25t1C46czHsdT1UDyc+rjZWku2KD7qxquU8cJX11x/3PXhsCn4cbiW+G2r0PrrAmMF+LmfZsKU6OcYJ+Gm9xwLTwmMn4LPwmohI2t0+E8a8AIe+Adu9ACsrml6nOV6fAdW7AaPwbbI7AwPwQTQ
fH4AdPhgPot5EGD27wM/3hNnFkQ2GOxAV5kFePgzpB39W1fJGIdXfVX8CrnfO7eGcuzg89gT+EF4TEUm7K6bF/59bCo/Mbfk21lRAba3/f+kaeONrePlboD/wFb6KeB2+J3K44cCZu8NbF+FLvVn4wNsTKIReBdC3O+Ql3OP20oPhhD2hKheWVsAR98Pkb1qeX+lcUq18GQQ8nCT9Efydh0RE0i4vy0/yVPe8BTdzX1MBRz8Gb86H/t3gDwfAryZAeawaeUt8O2xULr7X8cGwxQDo0wVWxmah6glUwdx58PlCuPx4uPor3267exX8/nvQ7+pw579aPwTpuemw27AUD146hVRLuhOBfZKk7w28nXJuRGSDUVwOpz0E2/4eLn8GahIDVju4a2fICR2Rdu0DPx3R/HX/NMkHXIBlpXDes5GAC36sbbSYUuvH1W45BhZXwhOfwsrYBBhr8ZNhLPCLZuXBLZWQNRiyhsDMkVDmYMt+8fG+NQ5G9Wv5MUvnYs65ppdKXMnsbOAa/JSP74fk3YEf4CfIWBRb1jn3bOuz2XbMrBAoLi4uprCwMNPZEdlg/XQ8PPaBDyYG3HQCXHxQ++93xTpYUQFb9PBzFzfHvxbASf/AB8pY7+FKfEk2dhP6vPjzogIojnWEqsJXN8fmTiY8X+df32Vz2PEguKek/j7f2wX6VcHPnoCvV8JPxsJ1h6szVTqVlJRQVFQEUOScK2lq+baQavVy7HbK54ZHstfAfyRbUMEjIhuKqd/4gAs+kExbkJ799i3wj5Y4/QP87FBrQoLDzxwVu3rFpmyshQP6wrFj4YKXw2tZxOdajimALXtBdRf4wMEHCaX8HtmwZVfonQtvJV5BZYOW6uQY+i0mIo06cgzMWAJZ5quWD92q6XXa29MfwbgJkJ8DZ+8Ps76DTXpC2Vf4mePBl2gLqF9ciD03mJENs6KTY4Q5kykAKuDS3eGHu8CLZfCbV4jfRag7dTexf2kHH3Bl45NS9XJnpuplkfSoroFbX4dPF/qA++PdMpufOctg1O9Cz2QLE1F0BxcJhnW64ku9ZfhgOpT12nPrlWwN2BSohhf2huuqYFJpeG0N/nZ93/p1eveHFbofbofQmaqXMbNdgAPwHenrlXydcxe3Ml8i0snlZMMlh2Q6F76U/ZuX4NHJkc5cDigA1wt/9UoMusXEu0GvBZbgA29MFvEJMmJjdBcAg+DeFTApWkruAUylLkh/twL+sxkcO7KNDlA6lZSCrpldAfwemImfWyVaXN64is4i0qHdMQn+9Aa4UMI1B+SD60P8frYlJMzbWH8b2dUJSQ4sD6gFF5tPuQyohJwt6904CID+Fq+9zjF4fYGC7saqNXcZOsM5N74N8yIinYRzsLoCivJ9m21H9tli35GrBqA7FGXBqOEwZQW4GvwEF4PxPZTX4Ntnc/Cl3xA9a/pCfg+oXA2uArr0goM3hdzV8PTKyM7KoGi475W8vByyS+EPg+HdxTDha9+xrNrBThoatNFKtUNULTCpLTMiIp3D0jLY4Qno/QCM+DvMTHYDgA7k8FG+WjkrG+gCo0bC/tuCK8C3tcbk4O9zOxDfPnswZG+OnxRjKFTmwMgRULAZlBfBhNWwySbx1bMMttoP7quEZQ6yCuCIoXDpQLj/EDh1NOw8AK7fCw4dDVWqE9wopTpO93+Bwc65C9s+S+1LHalEWudXb8NfP/eltmyDw4fBc0dlOleN++ObcPkb8Wpf1xXfm7ia+rNMhV7IuVnwyY/gwknwysL4yzl5/m91uGyO6gH/NxD+OQ026w3r9oaH19XdbIjh2TB3SHz9JbVwYAnMqIEBBq8Uwra6WX3GdKaOVDcBz5vZHGA68fttAOCcO761GRORjunjFfHxtzUOvkvshNQBZYcrXW0WvmdyHv7GBQ7fnhsThvE4B6tzoTSPeg201X3BVvj/s4CeeXD6Lv4BMKEMHljnY3ct8L0u1HN9OcwKjcMrHFy0Fl4tasMDlQ4
v1aB7O77n8hv4O0KqokQkBVOp4lr82JIr6c5OdKzBmwvK4akl0DcPThkMU1fApITq5Au2y0zeWmL7geCygCHUTXJBLKAaUFp/+VqDvcrAtsdXOdfgJ8sogpxJULUa+uTDX3apv97RXeHffWFCOWyVCxf0qP96SeRKWQsU68q50Um1enkNcIpz7vm2z1L7UvWydBQzqWRHllGOAdUYFWRRy64U8G+GMiD1EX1tYtE62O5t+K7K/6o+dTDs0wXOfh9fSiwHaqDsp9ClE1SR/vS/8PfFkYTYsJ5lxLsmG35cbhf8pLZAbi1UhQCdjZ/v9hnzpdycFvaK+bAa9i32p86AJ7rDifmpHY+0Xiaql1PtSPUdMKctMyKyMfmYdezIXMopxY81qcBRSw0whXX8X90Ak8x5bhmsrIpXYz26CLbtA+wFHAQcBZuO6RwBF+ClWC/jLHxQzcE3vkbHAjmgHGxsPGlgNpyW45uAtzcYn++nmWxpwAXYOQem94THu8OnRQq4G6NUg+7VwDgz69qGeRHZaPyelZSTj7/6ZxFtoakBvq3ripM5gyIBwYBeOTC9G9Arnrhsc6jt4FWk7y6HoY/A8lJ8nW534lM9JlzBsg2eOQk229w/zwVuKYDxebCmC0wtgJGtnAR3eDackg9jOsmPFWlbqb7t/wNsDiw1s3ms35Fqp1bmS2SDdTcr+Q+V+MpKh48E9adTOIPM9675Xn+4YDjcNR965cLlm8Lb8/FBK3QQSsPd+lrlg5Ww1/P4K1Rs/G00aOb753nAVn3g2r3h6E3hMAef18JQg4GaaV7aUKpB95k2zYXIRuIDyjiHhfjI5YDVxAaY9KGQPenCxfRi/3oDSDPDDG7dGm7ZCm6bCReFSf4tG9zuQHe4oWtqk2PU4BjPYuZTznH0Z0d6NL1SCv5nSvinB3U/FNabP6/KV5F/cno8ucBg5xbcH+09qhgXWufH0ZVdM9weLx1XqncZGtfWGRHZGMymEl8+rCB+M1Zfyu1HBc+yRSazl5QZ3DQj/jyrFk5aAVcPhS1TvHHnr5jJX1lINnAD85nCLuzQxoF3ZQVMWYUvxtYSr8V3PqiuK4OspVBbA6e0ogf2cmo5mDV10zdPooT59KRXyq13siFr1afCzMaa2Y/DY8e2ypTIhmo/ulFAFr7zVFnkFcfX682633H0yY9fLJyD7bukHnABHmUJ4Nuva3H8m+Wtyt83pfDZd/Xbl/80PaH6uxT/W6cK1lXCCX38MKLc3rDXlqnveyY1lOH3VYufSfKrDl/xLpmSUtA1s/5m9jrwAX7M7u3AVDN7zcw0q6hsNOaUw+uroLSm6WUBhpDLFLZgBPnEu0JUACVUUsyvmYXrgMPe79vNB16A/QbA+a0IUgDD6VI3YqcGGEEL7zofcdOnsOnjsN1TcNgLUBXi3eoq3zGqTsJpfWoZuCKo6gpnvgsllantfxuy6YmRjW827o0xilb8IpENWqol3TvwrSTbOOd6O+d6A2Pww8dvb6vMiXRkDy2BLabAQZ/C1h/AkmZctB9gET9kGsOo4A6GkUM50Ruz3sy3vMPqhjeQIbv0gSXHQ8kP4PWDoFsrmywfZxvG0J1Cl8PYqqG8smYQL6UQ9Mqq4bIp8eevLoLn5vv/zxoJedErXGIcjBRGqxyUVJGSXmTxJj04mTx+SB5v0YNCOvhdICRjUp0coxg42Dn3QUL6rsB/nXM92yh/bU6TY0hbGfIeLAqBIgsYNxx+u2nDy79PMXvwYd3yg8jnXkZxJO/jS71ZQDf2oTdvMbbhDW1AjiqBl0Kwc8CkQtijBZNylVZB0fj61chPHgwnjPCdtd4rrWTm8lx2LMpiVjlc/El4r7aC6z+BOWv8OkcOgecO0o3lNzadae7lLBKGCQWxK4fIBi/H6g/0yWnigj2R7+r+rwUWUsFDfA3Eing1QBmfbEQ9X1+tigfMLOCNqpYF3e65cOVOcPVH/vke/eGoTWAV1RzAdKZ1L6Nb9ywmMIp
TKOKUyM0Hjh8CT873pfYfDFfAlfRI9dv9OnCbmf3QObcIwMyGALcAr7VV5kQ6sttGwknTfdXk6K5w1qDGl19OOfUbFmv5J4sSlqphRzaeGpjts+GjmliHKtghhSvSVWPhB5vB6kro1a+cU7Pm8hHlzA/lgjJqOZe5zGCHeuv1yodftLJtWqSlUi2Vno9vv51nZnPC3YbmhrRftXRjZnaemc0zs3VmNjlUUzdnvVPMzJmZxg1L2h3bFxbtAZ+OhU/G+gkkGjOabvg7o1fjK4Vid0mPG0kR/2BM+2S4A3q6BxydCztmw53d4Mi81LazdS/YfYDjyKzP+A8rmce6ug5pDh94RTqCVMfpfmtmOwEHA6ND8gzn3Kst3ZaZnQzcDJwNTAYuBF42s1HOuQYnoDWz4fhbDL7d0n2KtJW+uf7RHKexCS+zjKeIzrofu81NNdlU8xO6Mp+VDGJwO+S24xmaDf9uoGBfXQsvz/V/Dx8B+U1crVZTzTwqwrPYwFx/dq9iaFtlWaRVWtSRyswOBO4Edk9sdDazIuBd4GznXLMDoZlNBj5wzp0fnmcB3wJ3OOduaGCdbOAt4AFgH6Cnc+7YZu5PHakko5ZRwWG8yycU44cL1YS/vs3XgAkcx1FslrlMZphzcPTT8PzX/vmeg2HiKZDbyEgch2MbpjKLchyQhXEbm7MvRYxJnGRZhM5xl6ELgXuTZc45Vwz8Dbi4uRszszxgLFBXQnbO1YbnezSy6pXAMufc/c3YR76ZFcYe0E7zzUmn5XBcSzWDWcfOVPB5O1dF9iefl9iVoSwEvgbmAyvqLfMwX7RrHjq6L7+LB1xy4d1h0G8unLIE1tbCXNZxKV/zG+axNHREM4xX2JYzGMjx9OUVxnAugxRwpUNpafXy9sBljbz+X+CSFmyvL3703NKE9KXEq63rMbO9gTMhoVdEwy4HrmpBnmQj8xy1XBnmP16G41iq+Ir2vefabXzAYtZEUqrxs+87DGMw3dt1/x1VaSXc8QnMj/6sHwsMhGLgX6XQP6eG8X2nUhp+HP2DZcxgZ/LIYgj53NMBp9IUiWlp0B1A8qFCMdVAu81IZWY9gL8Dv3DOrWhq+eB6fJtxTA9gQVvnTTqv2bi6oT81wFwcLgS/9lJJDbXrzTxVBeSwH0O5stGKns7J4XgImIrjQIzjkpzfo5+Ftxb44Tu5+VDVH39Fsdg24JWqctZEaiO+poLZlLNNB7hJhEhTWlq9vBAa7Vq5HdTrJdKUFfjr3ICE9AEQJmetb3NgODDBzKrNrBr4KfD98HzzxBWccxXOuZLYA+oVL0Q4gixy8VUuBhxLVrsGXIDzGEtevSmS/FfxI37C65xEryamRVxKMX/hNR7lPapp5hyUGXYj8DMcdwPH43g04UfHmkqYuMB3gapxfjrHopHEhzE7H3T7d19B3Z0LwmNIO9dMiLSVlpZ0XwCuNbOXnHP1Zmc3sy7AOOC55m7MOVdpZlOBgwi3CwwdqQ7Cd9hK9CWwbULa7/Gl1wvwHbBEWmQrsnifPB6jhkEY56Vh3tyllLIT/XiP5fjb4BiGkd2M38ErWMMOXM1SinHAM3zMvzi3vbPcak/hgHKqw+/ey8jjVHrXvd4tF/p3gRXr4jcu6Jfj50R2DsiG03vCPj2yeIsSCCXb/Sig50Y0oYh0bi39pP4eOB6YZWZ3AjND+mjgPHxh4boWbvNm4CEz+xCYgu+s1Q14EMDMHgYWOucuD4H+8+jKZrYawDlXL12kJXYkix3bYDK1b2vgnjLIBc7tBn2TbHIhJezPg1RSCfQmWuH0GDPZjr6N7uMFPmUJxXXPn+RDVrOWt/iCb1jB99iZ4fRv9bGscv4WtAVtVOjfklqmRCqaFlLJX1nLOSF4Zhk8fyz84lVYWQ6XjIW9toDvzYYlVXBoF7h8UBX7sA5/3ko4gF48w05tk0GRNGhR0HXOLTWzPYG/4ttKY19HB7wMnOecS+wU1dQ2/xnuTHQNMBD4BDg8sp1
hoJHt0vEV18JuK2BZrf9CPF4On/aD3ISgNYEvqaCUeCtyFmA4HMWsbXI/M+u14DjAuI4nuYn/AHA5j/Axf2YkTUyR1YC3neNwHGVAtoPHMX4QmSNxMuuYSDlbkcNclrCOWn7KMAbV3SU+uUtxPJKQNjWhi8jOA+DjU+svs2h7WOegSxZczWpWUosPur35kmwKacG8kSIZ1uI6GefcfOBIM+sFjMQH3tnOuVWpZsI5dyfJq5Nxzu3fxLqnp7pfkbb0QRUsjvw8/LIGZlXDNgkx4ct6nfVXAP3xX8UyxlLU5H4GUUg8WEMWNdzNy3Wvl1HBP3iH3/KDFh+Dw/G9EHAxqDH4SQ2ckOVLoi+ylqNCdwtfA7yALL7jdr5iOodSlCQAVuKYRiUDMIqwUCnunUlXvqaaH7OS2VRzCl25lZ5kh9/zDketQZcQ9AvIqlvbgC5JaieKqeAXvMm7LOUABnM3+9JNgVk6iJTr05xzq5xzHzjnprQm4IpsKEZkx79Qhh8ANChJ8/Doug7+hu+xvABjPn0o42i2bnI/R7E93agGVgLfcRzbMoAiskKgqsUxgNRu9FWBo8TKwcqI9VyqNMKAKvg7pZGbPDigD7XAItbxDusPKFhNLTuyhF1ZxgiWcik59KWYXFZzJAvZmWx+wkqmUMkKarmTUu4Lpf33WMNgPiKPyfyYr6jGcQ492S50mirA+Mt6fTDhMibzNHNZyFoe4yvGhTs7NWUd1TzFLJ5hNlWdpHOadD7qfSDSRjbPgYd7wm/W+Dbd2wuhd5KftWeyC28xlyeYxlD6cho7UkgBJ7MDA5oxd8vljGctqyAEhqd5hRs5i5t4jqWs5mT24nQOaHH+HY4zWExdB3/XFRjC6c7q7ks7uK4MGhO/Ce6wJJNQPMhaZoSQXQuMo4RaJlKD40XgdnKYTd+6EJcDzAnLn8pXLKOKWuBRVnAghZxBfz5kOPOooj/Z9EjS6W0Gq6gJPwtqccxsxv2Jq6jhQJ7gvVB1fzDDeJkT637IiLQVBV2RNnRqF/9oTC7ZPMYP+V92ZxHfsTdbUdiCWZOeZTJESmIO+AevsYj7qaGWnBR7X8+iksejI+qsjN+5CsZlxfP2W3oxjUomUs6mZFHKSirJ5Rq2ZttmVI3X4OrGJ2dhzGQ1P2QYt1NKDj4wHxvahpeGgOuXhSWh/TcbY3MavjPC8YzgLRaTg1GN41hGNJmvKSypC7gAr/IN01nJmCY6tYm0lIKuSAbcwgQu9h30GUZfPuRG+jUjaAGMZigfM5t4/0LHh0znHT5lH7ZPOU95SUp1x1i8rPcW5bxEGafRg/8yKIxlbnx+6NPpxj2U8iXVZAFHUMrzUBcQj2E4h9OTMeTyFdUcQxf2DNXH5zGAG1mMAd3I4mT6NOs4/odt6UsBk1nGvgziRNYbvr+enknG+RY1EthFUtWiGx5sCHTDA+kIevAjSvFD3Q3jJk7jYr7frHW/Zgkn83s+ZFZI8TdMuJfL+DlHtypf/8cy/hhuvPBLiribgRjGq5RxKIvJCnu7il5cHRlj25gKHB9TyWCyGUo29zKDL/iOIxjGEQxrcD2H41lW8Q2VHE1PhjcxYUhrXcUkruV9DOOP7MMl7NKu+5PMy8QND1TSFcmAfHLrgi44ClrQu/Zu/s6HvAKA0Y0sepBHPge2wXjVG+jP+fSiBtg0kqcnWUs28Q5Vj7KmWUH3O0o4hzv4iK84it24kZ9zVjM6i4H/MXJMZB/V1PI/TOFfzGcUhTzKPmzahnNUj2MvLmUXDFNvZ2k3rZ8NQERa7G+cTW74zbsro/mMTdmcSZzEZ6xqZHrzz5nFjdxX99yxlp9yCO9xN5sxpE3yNpTcegEXYDNy6lqRs4GRzQxK53MXT/EOX7GI23mGW3g6pTw5HL/gdf7KDFZQwSSWMZp/8Pl690pp2FxK+DXvchnvs6iB8dDdyVPAlXalkq5IBpzAHrxKN07lD0whm8ksAoz5rCMP45EGpjhfy/+zd99
xVtTX/8efc+/dxu7SexUURRQBQQUVeywRSzTYE9M0RqOJJUaTX6xJ1BhLYozRRL+xJbZorLFgQyzYOF0uGwAAIABJREFUUJBeBKRJZ3u7d35/zN3LLkWxAJZ58biP3Tvzmc987ixzz5zP55z3qV5n2zmOseMmrqzzM61NVudRVXaQ75as4tX7VnvfakN10jK7Lnp7DdfV0D5gbnGVdCpae04ITDL3E523WoO/meRKT1uicfavM0rVSPihh4xz2sf2s1qtYR6yPDu7cL9ZpjhOwWaQ/IyJaUpsdGNithCnudZCy4SGaVSWSmPiR6hSDbWj/QzznNfASPvYYTOUsisQ+OdaObH3muoET8gIdVFsnBMsqG/pe9nhJ1Fcfj6tj5YMEtIyvvkJ1klDocP9z2gz0HS57UMUYpX5G5lP+7ZlljR5YHlfuWlW2Wkjg7NiYj4vYqMbE7OFWGi5jAxmYoCEUEbgyI+ojpmU9KR/eNarEhL2N3yTV0TaEL8yNpf+s0SVv5lgq/Seuf1plGVK/NGZppnpQEN824iN7n+5GqMtRGusrb+zAFTIN89qPT8m8ntrLaUEufzdQkk9vqY1i2O2LLHRjYnZQpxqpKvdK2GKlIQjfMe+uvjxx6zN5slzsL02yxgzQi+rk8Qw+c3EIlLZd2HufWCvvKhmUqP/OSTJucHI3DErVbnXOwqkHG9wLoBshRqPmquNAiP1khBoKV/SdtIKUZN9JbM/a0Glen82zh8d+JGfo6dS9zvQr70uT8I1hmsTlwOM2QLERjcmZgtxlVPtqp85Fjvc7rbVY0sPqRlpGX0t8342H3h3KWO1y3nW19vXtzysVlofrfzUYB2SvNiSv9fSLuCCJkIhlWrt6nozs3KRt3vD80430XIjPKwsG0D2TT1srURnpUJtRFPLjVPDDWiDFRrVsDbW0z9S740SytgQodAN7nS3R21rK9e6UIeNTJvaUH9bapYiZssR5+nGxGxhQqEb/dfT3jTINn7tRAVfAGGG61U4W0WzbW9qa4h80yy0RJmtdbVMrX7ayt9AUNJU1W6zxIc+cMda0cuvO8d+HlexVkH7hHoZ7UTlsyeQLbIQkZIUSFulu5Ze8X09NlJY5LPwgCeN8jNE0/wHGO5Jt37ifmb4wBEuMM08B9nN/S5X/DEVmmI2DXGebkzM15C/e9yZbgCPeU2ZKtc7Y4uNZ2Emeq1Isj5H7FpPONe/wAA9vOyiDRrc+WrtaqIqGZmmEpOi6empVqtQR7M0nVBGnjWKW2un8ASusY/h2tlRRy02U4rPm96TktQgLS3tDRM/VT+nudp0H8gIPWWcq/3LJX74OY825otKnKcbE7OFGWOCZPZWDIWeM36LjeXeWnquZJfV/GtVsaLMGqu7vaSBki50X27bRB+4z7gN9veicuUy0mSnigfIl1SiwD8dbycdye6NCLA19sFAeZaLpCZbNNkfWC3PRea51AzVm6Ai0DwLLVgrB3hfu2mQlhDkgtg+DQssk84+UAQCiyz/zOON+fIQe7oxMVuY3fRzt9EgIWEPO2yxsZxduSYI6v1M4MLaDloUVWkrcKoWCNepvJP8iGf3vk2kGxNobUcLfUe+pEDgblNEZ6zKthpGbp00pURbK40WreWmETjaEBebAZ6xVJl6N30GzemmhEInOc+/PAaOdpAH/BkcZIT7XO9+T+qjh984/VOd4zRHOtufBdl/33HQ5zL2mC8HsdGNidnCnOFI5ao96XVDbOt3W3Cqce0Ij2IJFzZLrQn8yXf8xP/JCA2ztWMN22B/uypxk96utEBrKX/TW0GTr53/eV+gjVBbkeGt1XROe6VAF4NsZb7+2rvMXi40XVJZtmIRL36OnuLbJuUMLvzHU570koOzqU6jHGKUQz7TOX7uGP30NMn7DjDUwM2QZx3zxSE2ujExW5iEhF850a+cuMnO8bi0v6nXBb+Rr8cGvNMziyv8uqIYge6JtOPyk8KQoIlze6r9HGqQ5SrsoPt
Herpwmk5OW0+xeSjSXqijyNwnUN4sDYlyi4TucZi9tAHDtXWH+WSPGPE5CVykpd26HpnKCaY62Ai1ap3hPI952iA7usPfdPyInOqP4mDDHPwRDysxX11ioxsT8xXnEhmXIsqgbTBalemKpdaaJl6uxlUFj0qk8oXpYh+kKvSpPcLQIM+TBVEKUCPdtNXtM6TLTPCeG/3DG4YLbC3MGu62WtnFKk9Ji6aUG7Lb1wRL/VgvlRo8aamBWrpcv089jqZc6iY3uafZtkDgSAeAa/zFbe4WCo32otOd5wG3fy7njvn6EBvdmJivMNVCl2clJiNS3tdgoVDPtYzuRKuiXNlkPcmslmNQZnzYzuX1oWvzrbOe+2lYaJHhjlHl+yI5x3KUCgRWqPCUBdntJRICl+hjxyZT3IHAubZxrm0+81ia8rDnrIktzeiorXtca9tsbu90syQksrHLaVNM+1zPH/P1II5ejonZTIRCd7jdyU5ynWs05ArlbToyCNcylC0FOq/HePbTSlHjZHEYEOYRlsoI3WCxYi+4OTut+1l4zZuqHIP+KECNFqoFashFDNdgmecN8JvPIGjxSRisX/bTJwTynO0H9m0yBfwtI6WlpbLpUcc6arOMK+arRezpxsRsJu50ux/5voSEe/zLcsv91u832fmeVmWWeqcrdGN2eraljGfly1+P0e2syFMOcJkJloSBCXWDkC+UESYXqJFxumkO1V73z1BQvr/tsIhcbm/GkUqMttiSJg8iAXor/tTn+aRc7wJ1GrxlkkOMcK6Tm+0/wjf9z/2e9ryd7OBkx2+2scV8dYgVqWJiNhPfcYL73ZstcsBAg7yxiXJyf2el/2cFaCFwt27ayDcc9ZnAlSuZ28BxJXxzA3ZtXIb/ZKpcnXiXRFVu+zt2NVDpZxrfSR5zt5RARijhIYMVChzpFbVq5Em42VDf1+czneeT0qBBKvZFvjZsCUWqeHo5JmYzMdAgYTYuNylpsJ0/9phQaJLFplqSO3Zj+JNVud9rhd5UYW+BfIGTP+T3K/lXOSMXMWbdEr1gtwSXpwrtlFjzNTFMy2brq5+GOvVO18f5OigVosLvvWsXraxymMWOVGvUJjG45Soc5Uda6qun7RzneK961Tsm62UP+bZ1pFPVZAsqxMR83sSPdDExm4mfO8dyyz3tSUMMdY3rP7J9KPRd97or6w3/xG7O0Vdag20N+kix/I6SucKBGXRoItP4VNUagcUERlex1wakfwskjDXE/ZZI4BidJD9DMFWlGns63ztmg8AQbO9tK11hsj8arHATFpa/3PUe9pSMD5Wrd68Z7nefVvpZrUYo9IjRbnKXs2NpxphNQOzpxsRsJlJSrnCVt7zrFrcqzU7RTjfNWU5zltPMMjPX/m0L3GW89ubY2z+86zgn2sl37exCo3LT1Ovj/3TUMWu8DtXCj61ZSunSpJZCBnM/prbCGGUmq1EgpegTfGVMNM+hrrSvSz3t3ey4XvJOtsoQhMZrfARYnq0atCl53wfZ61afG0FGxkplueuZlLDY0k0+lpivJ7GnGxOzhVhosUMdb5oXJURf9o940ESzlGYnXpNqDfUgGjSdBX7ef0zyugEbEFjYRaGFeqlDwVqead+OzFguSoEtYXoTLzct9Jxqr5vtNe8ol/CidtmqP4H5BviFnh/72WrU2c9lVqqQwcumucqvnacax2Eh/qdRBiOBU2y9Udfts3CswzzgMVEQ1xrN5kB5Vhs6ejg6yZGbfCwxX09ioxsTs4X4pcu9Z6KCrOFJS1tmqakm28VudtbVkbqr1bBeSf/G6eV69c7xC4953GCD3OIm7bUXCNZbpn1wHk92ivzLJDml51DoKIs8YgEetUYlqqtM9qvid6qcp8fH1oFdaKVlTaoK1Uv7ldlNfPOu2NrPDLKj3eypg342fWDjt430tH97wMOe87iZpkkipcJWejrDOQ6yl203cwBXzNeH2OjGxGwhFlqsQZCrnBugUJGtsl/4CQn
/dpbT/MMS0xUJc95uO63tYFdwnT+50U1CoQ/MlyfPvdnSe02pD6NJ1V8HLAwZLSov8MfsjPEl3vCIF0StGoO2CjT9mlitzAcq9fyYYKoe2umpvQVWCJEvtdZkeKi1GSZ43G+9qeQzRkN/Er5hL9+wF67xb/9yq1t01c2VrtZV1802jpivJ/GabkzMFuJU30FCrdYy8vS1vQc9oUMTPd88+a7zsv6Ga4UO6CRhF0Ny3uYUUyWyt3Ja2nsmrXOu+0NaZSjO8PMMtyaYl+S+JK0D5lrtcs+ylk8drGUqE0irWaf/Kg3O9JZhnnGxiRISXnSxk+3tWMO96GKXNZFrTFmo1PPeN90YT32q6/d5cLwTjPaCO9wdG9yYzULs6cbEbCGOdaTuunrd23a3i90MWafNAhUu9I6FztPSPZLu01FPp7ox12akb/qnO6SkNGjwLUc066Mu5LsZOVN5C46kWa2cJSqbJCQl0EYkIpm0k5YmWC1a8X3dIcaY4B/ys4Ibs0xxkTHuVSQt3+uWK5XnPP3c6rRcr7tib6WOcpg8EyWyaTkttf7U1zAm5stGbHRjYjYxFWq8Ya6e2tp6rao0e9jVHtlp4vXxTU+YZKW0EPsa4yp76tVsTfVoR/mvBzzlGQPs6MdOadZHLev4piuayjFjkE4G62R8ToZxqcj4djfBclHZvdFCNabhdq86xV7+6Rp/dB7YSnfvuxQtvGG5Mqs87X75Ch3sGPkKDNPZDc50oVPU40Sn2cP+G3klY2K+/MSKVDExm5DFVtvNleZZISFwm5OdbPhGHVsnrcDfm20botCbvrfe9hkZS8xRqq3rPeZ37tJCoVv9wn/Te/hntl0vjE/wfrDAzV7TUqHz7a1Ivn4escA4kZnuizyRdc5gCiZl349wv2P9ziD1TVJ9FvihVfbxJ4M95Sjvm265PB31cpcHDTQA1KpRp07pZgieionZEFtCkSr2dGNiNiH/MNZ8K0FG6EIPrWN0y31oick62VFJE084X9JWis1Rkdv2llmWqdR+LU3iWtUuc6ApxkpIedV2qnVUrc5xLrck8ZBvKbIy5PCAsmClPd2kLqt1/LQZxjvLT+3sQmPX+RyBhEBKRgG6Y7wLvK/lOmEh9ajwineMV2yxs4TqLfGsYfZxiYvd7E4lit3oSiPimrIxXzPiQKqYmE1Iqom6UoDUWrfcXK/6gz7+bj9X29p8bzXb/wc7owLVWCJPg6ImtWUbedGdpmSNZUaDwU3KztWoUx5UOjzg5ARtAl42R7V6aaG00ASLLFPll3o7wZ7ZIxfl+uijpaPsJWE4xmGa2V63UvfcVHeVPlZn16XvtdwihwqVitaHj1Aj7QK/8755JplqpJNU24AGZUzMV5TY042J2YScaoQ7vGqKxVKS/uTYZvuf81sN2RXXOpVedJUT3Zfb/239/cR0N3lNvqS/+7Zi60pI1asRCJpoO4e59wcYoot2zdrvoJOArABHoJ1ibRR61UQ/spWjbe1xE+1mG0P0tYM2Vqs13WQTsgFQodAcCaM9Z74G3zPdmuf4psIToVABTbKGM0Jlyi21XE/dP93FjYn5EhIb3ZiYTUhbxd51kSkW6aKVDmvloyaa3YJB7v1zXvEn12qlzk+d5SqXyZdU0KR9hUnqrdbKLkY4wSOutdRcMMpF9rODYoW+4xvriFkM1NVdjnOVF7VW6E8Od4rfu93jYB87e9oN8pqcr1DK//xIb/9Wn52WLlRgkIH20crtKj2f847rlVqhj7sVWK1cN0c73Y3uVqZCKLST/rrHaToxXzNioxsTs4nJk7TTBry5b7jUHC+ptlKR1gY3/D8/XD3HbcEbtConWeENx7nYMRqstpOj7ORoM11slstAa7sb6lnXm2iKsdroqreBKtRbpGotw84ila7ylmoN7nSSnbQ316KcwYUXvG2sd+xraLNju+rkETf5teskJFzlPO2y8ok7CjxvuWhdd45unlAgik1paZGeQuM86e/uki9PLx3
8xwOOcKT89XjvMTFfReLo5ZiYLUyNMsvNUpzpa+jsErPqs4IUBbPp/QMjgwm2VZadLs74vgcsMIommbWD/EcnR+Xej7XYIZ5SocG2WhpjpE6KNMjo7y6zrQZFUqb7Lqp1dahIFLInZnjZL22ruzv9R748JxulZANF5W/2jNPcmz12JlYY4AVFKhGVMjzSj53vRnXqjDDceG+DfezrSaNzAh8xMZuLuJ5uTMwXmDoVHnacP+vkP45U06Rm7WehUEvdDPZudYlZ9US3ZYLabajtrYcK0QpqRkLKBM8Yra+JOjfppfmt/DOvqcpOAc9S7loTwQfKzbAqF0BVod7rFuuivZHOx6k4EGearY1dHOpclznTb+zvWOn1qkDze/9FOZaQjbb+0FZNRpd0qJPB68blDC684HlTTPl0Fy8m5ktGPL0cE7ORjHWJqe4XypjpMc873yFu+cT9lFlqtrd1t71SpRLyBRLkPSNwmDAX4pTWKrVS0iAZ70hIy2hwrTnesg8YKjDb+dpr4zb19shGNtdI62qybb2uUlvVeoMuirVRYLU6GRmBJW7zmBUGWW47kcFMCHCLl8zxQXbUbbxupbP9TY3+7jZLTyXu9Q07aaeVIgmBjKXICLBUb9VKnWGY4x3nFveb73r72qPZ9QgEWseqVDFfE2KjGxOzkawwQ5jVIg6lrWiSlrOxzPOeH5X/xrQV39Yt734ndLpccbJSaz3V5c92Sucf+veSaxUGpf7YKXR06r86e9MIl2trtukGGK9Vrr83FaC1VUKHqbBEaymBczDedbkY5t4y2FehlKcd6WxjzDLZInM97gOPeMOuTpOUkhalN/XRxUugtUhSI/QX/xOajAFmWO14o01yrL85xUhXWanSLra3n+3MtthIu/iu/X3DYZ43RkbGQx51ih+53W0SEq52ra66ysjEU8wxX3lioxsTs5Fs5ygzPSIhJaNBP8es0yYTMrYm+jmiiORaFfD+WvWQZ2b/RyBjPlZWDfKLbQ6w0mwtsF+bW+3X5lYHeF07u1gtrUI7/3N9toe38CxkzWlhdF6sFCoXWqTC254VCCSyDwkfZI+BoTp5ySi7usgipGWy3uabUlpL6yXjPcu09FOXutk/1euKLkINmI8p0lqYp4+MjEF6muMG1ep11GqdaOkXjc1NTSckbGU7q1QKBJ43TgfDlSn3cye7yi8+tnRgTMyXlfixMiZmIxngZEd5yBBnOsK9dnZ6s/1hyKjF7L2AfRdy+KLI+DZlVkX/qK2UjJRplftojGVs2rRAW9BK0tk5zzYtqTvZNJsiBYqNkhR5pvtIqVVrdy95SptchaCEpC52XOfzDNFbsolxa/CBWn/DvWjhcXX+gg6+IVKhSiIfvUXBUm/LGCtpuGJ7aWM3Z7rIe+uZARhisGRWKCQjYxdDFCqUlDTKz6ywSr0GV7vVM17e4N8gJubLTmx0Y2LWw1sZzq7jD/VUN7GG2zrS/q61vWPW8cYm1fFg5Zr3T1TxVm3zfo8v2lN024USGnQvnCAIaG+wpJRAykB/VGLr3DHXaGecbg6UzprRA3G8nX3feINdoND5ihQqsLe01dqbZVcv+54Vuutpfz/0wDqf8WrH+5697KCbnzkoW9igP7bLtojGuTAnahFmXwF2AVW5vNxARsr9HjXYSC94rdm5HvQvxzraHoa5zU0OsC+oUatCVU7UAxZZuu4fJCbmK0I8vRwTsxaTM+xeG2kqZfByhocLPu4oChttcCKkJE06kBckNC3nc1RpJ3/rXu+m5fV65CX8sttCvTyhu4NkstHGr1nkD16ym+721lsgMESBCSYJpURFCFobi0ekXa5IHxXmCWQkMQBV3vUts4xyiwMVN7nVl6h2pQlKJZxlpD4mechkyyRFRrUGpU3GXZH9vdEwBmib/Zlusi3yZEOhv7vHPk10lbvo7G63rXPNShQb5WD3e1Ig0EFbh9jr4y92TMyXlNjoxsSsxVOZSN6h0cQ8lommiRMfs8y
4TT5ntgnd0LJWozzyVWHCv7PCD4960MteNLjdLsa3OzHrKR+aOz4p30MmO9q/RaYrkPS2jsr1cKTFeoqMXBrLMcd5AkO9bY7jm42lo0JbS7rWDusY3J4eV5s1qJd7W2i66KsgX1QIcI5IsrFUVOJvGrYhpyMdoiw7jqomZ410lAN0yE6Pbwz/co3D7WeF1UY5WMe1JCtjYr5KxEY3JmYttgvWGNwktgo+3uA2Mrxdxg1N3t8TZNwk9KS7/MR3c4Xml1niDOesc/yf3J+dao08ywIdnWmMl/SX0D3ryUKdaPr3LdP92db6mZ0tmxcI/c4LfuSCdfq/zjRNZ7xDbUVPCBlRlHKdyNNdibmiwK2WIgPb6LWH2X1FZBWnjrWPMV60SLUhdvIbZ27cBUNKykmO2Oj2MTFfZuI13ZiYtTgkwRUpemBowMPrUSisMcFs25mmwCLfE6oH7Zu1imKLi/CE/woEGrJTyPd4wnPeV67WPB+aaYFQKGGiIGfyQ9/xuqEqDPe2jGQ2OCpEg97G6mOczqa6zhF2c7MB/uP7DvNNB4AyNe423n9NkpbxrqXWTBtHucBRNaF5uBUPaq9BoFxkcOtFXvUUax5FVokimKOC99vr626/t8Crqk01zkM5aciYmJjmxJ5uTMxaBAEX5EWvDbHISerNQlqZ2xUZprXTdFQmYb6MbZB2sEUK9NXH1rmqPw0GGWuE/d2pVEK5sahzogN803xzlZitG1bpYKEAe3jDpa7ylP29YqD9Xa6LyWA09tTHNfqoNEngeAlp5Wrs6i+mZQOT9re9Z5WJHg0KREZ0vMjwvidhuYxyy0xFIc3UpxZjusgIL7Ojwbo73ABdXe7EXGRyoY1Y/I6J+RoTG92YmE9BnYVu9FOjHWRb01xtida4zxyB9/AuQi/JV6213h40UMYsLHWwRm+zXFqUAjTH3Ua7206S+qGzlLTFVgvcLRTY0xumOcokS3MGF1bgTAdaYq4i96oyH/X6KjXDKJGvnfas8dnf6zAj+7Nx9TohIyMSwWh07Vug0hoPdyEykpL2NNBNfr1pLm5MzFeY2OjGxHwKHnGtG7JpL9NsLyXtAXTXQjqbWpMU6KbIe25UY45jRGbuF2pFnmbj6k7TDN120tk83AYJdzvCFU5Rboy07f3EEBdJ+22zaGJqlCo0WpUPyU51z1SujTFW+oZIFzlqGa3Drm1Qs0UWmqVBJVHiR76vQqWZ7vGeCfYwzO9c/NkuYEzM15QvxJpuEARnBEEwJwiCmiAIxgVBsOtHtD0lCIKXgiBYmX2N/qj2MTEbQ7Xpphlpkt0tb1JEfm2qpN1skYfsK5k1WBlJryhwuHeNUedIPeVJ6KPEnUZkVZwixiHP4+SEKyqwILu3lGZFDKiWVmEn3xbq6wkjXGm4fzjYrlKSAknjHWu1Htk+10wJh+ihSGRom5LWUcqfnWyYre1vew/7P7e5zRkOa1ZD9ycO9Xcn+rdTveE51ZYZ7TFtP0F0ckxMzBq2eGm/IAiOxR04TfSd9HOMwnZhGC5ZT/u78TJeEX2b/BLfwg5hGC5Yu/16jo9L+8U0I5Txjj7qzCerPLyjNxXbuVm7jNBe3vWKcqEW6JiVj4BVAksE6KbAJD1UGm+VbZxiiSnm2M3TCt3neYEGBUKl2gvc43nTfOgnnlSpm2jNNfJk20o4SrV/GJMbR6DA98xRorW/OkAoPxvVPF/CzTLZAKcEWkqpkJF2XLaQwhpe8gt72mad61GuymNe019PA9ezPybmq8LXtbTfOfh7GIb/F4bhZJHxrcIP1tc4DMMTwzD8axiG74RhOBU/En2O/TfbiGO+UqSVqTPXGi8xVOXdddq9r8bLyrNGtgpLjZA0SlJoSc7PbGG8V/Tzjm+ZabAqYy3T1uP2tYP7tNZGQrWkJc53ocny/NgCNXYwUDfR9PB7+L2VLnGPBzSdSg6lLdTRDQ6RlicjlNSAhTKG62SAPnbURyfVygXKJbywzufppHS916N
UC8fbz0DbKFPuZOfawYHO83v12anrmJiYT8cWXdMNgiAfQ3BF47YwDDNBEIzG8I3spoUo0XDFBs5RQLOQyvV/08R8bUlqpchOqk0iK0lRYvd12rWTJ0+gPmcAK/1JS8USHhJoEErgBP8VqgMJad/2gLc8jZYus582LvMzrRxuWz0M0Nn9Mtk+31VhB0mTPEh2dbjCDFFwU0eRWV9lrOah1WlM90+vq/ETGdX6KrIyGxxF0kwHG+kF9TJCVzpKX50+9tr83GXu9l9pGVPM1EFbv3TaJ77GMTExEVs6kKq9KFrjw7W2f4h+G9nHVaKwytEb2H8hcdRHzIYJBLb3jAUu12CFjk5TlNMfXkNrKXfbzmlmqpFxmV4GKgH/M8h15imVtL+uynNTuYFa1SKzeBJYKeNq1c43yALVOYPbyCQzNU/XIYo6biHyeEs0mChhhYyVmGYPlba2Wm+RntRlpihp8qyZEDpSe4+6MPeZN4a3s/m9jcdMMHWjjouJiVk/W9rofiaCILgAx2GfMAzXjhZp5Apc2+R9qSizPyYmR56OtmqmJbV+RulglA7rbN9fW/tng4uqXOYNL6rxgZSO3nES3mzSOqFKgxVqba3Uvjp5PvfcWS+amOlBroB8C3SypthAFApV4jfKsjEZLwdcJ+VnGvTHT8NQvRr3B429JRRIf+KSeSPtZ4KpEgJpGQca8YmOj4mJac6WNrrLRI/0a89zdRJl42+QIAjOwwU4IAzDCRtqF4ZhLWuU74IgrtMZ8/kRCv3XRLd61jiTddbKbU4zwixl5jrCFGMsQTeRNnFUiG+I1l51tVc8p70BLnCCK72lRd0qF73xrLu2K/Re+86KTJWyQIMG1fYWTTOHUhapU0eQ0miIbw3z9A8KLQ8TLlSuVGhEhgsSpCX0N8xqS7Vaz0PDhrjU2Yrku91dEqqsNDOrCh3fRzExn4YtanTDMKwLguAtURDUfyEIgsagqL9s6LggCM7Hr3FQGIZvbqhdTMym5jf+53ceFz1TrJJsAAAgAElEQVQ/skK5kf5gsZu9oiBrcIkMY5GdLPWeOfp7UtI8uwuEXvZ7c7Gn257/jz6rP3DBHnso8pJ8Ndn6Pe9ppUAno6w2S4WXVEnknN9AaEHQ0TeVEHCVGpeb7YogVBRSEDT4nQPBcS51jIs2+Jnq1UpISWb/vespK7wtLe1SF+iqu2OcuAmvakzMV5cvQvTytTglCIKTgyDYHjehGP8HQRDcEQRBLtAqCIJf4nJRdPOcIAg6Z18lW2DsMV9hwuy/psyx0nmedL6nLFDmb16hSR5uRmiJ1arVSa3lDUZroq/JmG/f7EROMpvIM8DLmGa/+bO0rqvHbIlcLaCIPXX0lO8ZpEp52ErLdFLfdFp+GArDfGXW3AIzFTpXidqQ/EAzJeR7XGypeet81hpVbnKGoxQ5TmuvRs/B3vGmdHaNOSllovGf5nLGxMT4AhjdMAzvxXm4DO9gEA4Ow7BxkasnujQ55CeiObYHRErtja/zNteYY776PGqqDq5Q6FIXeRYsVWk3f3edsa7xvF39VSelAoUap3gDHGSAIqH9dXWYHrk+91OqMQ94qSLpkOJKCmtDSyXxotc65+tdVqn/8iUawuj2bDT7B5jgFFt71ONqg4zFiQJnVoeuqqzWJmhgrQeEhLQwsN6J4LpsGT6YZ7If2sooxZ7wV4RqVLjGiRrU289BEtl/aQ1G2O9zuMIxMV9Ptrg4xuYmFseI+Thq1GvnCtVNkoOuNNIlJqtRTTa1iEL9dLW7l4z0gsf00ksvp3lKkbTQ/1PiYud6yfW5ij2voU7v9GovvPGqnsujpJ5Tdmrnth6l2lU3uOmlFV7tXO+6nUL5QkkZnYV+hd8qNTdbXCAVhnZM11mdjAzo6qC1Fboi0NIKnSzKlZbvKQrHgjxJ57nPLo4CF9jbFC/LrBMxzb3KBFKud6WZpjvc0Y7w7U1x2WNiNjtbQhwjNroxMWuxTKUOrmy2LbBYqB8
qsq8asrm4hZJecbNCmSZ+bcRIB3vZkGxJv1Zo0M8Y35k/2a/eXZhrV5Og1cFd7FBb61sVK61MhJ5pE2UNlwQUZssEdlDgboUSIekgsEO6Qs+wwYxkJAJ5UJg0KM21qYwwCHXSzW/82ase9ro7lIqMcIEWbrFaUsqPbWOhWetch32d5Bx3fk5XNSbmi8eWMLpbOno5JuYLR3vFRtrWY6aLPNp6oYWicndbiVZl1tyfNdJe08t+3l+nryrva9Az+64SLXRV6oRwUbN2qZD+4SLbFDCpIDKgvQNqhCoxQWPMVK3jpI0O8pSq1zZocFUZhSGjWnNAJu2ockYmzrWi9fH6GSglpUCZGe7Ina9OlXq1klJ6CSxsMpYR9ren0w0yyEqPaGGQgtxniImJ+Sxs8TXdmJgvIv9xvOP0FBV3nyQKluqa3Ztcp/0TtrWf73uliVLUGxLGKxLdZvmiZ9x6z9nKXzsPtbCkONf2r30pTSAQrcMGkYmvEinFrClrzzgNuqvRSoOqgOtK2DrNnyv4VhmWB7pm2uhnRw842UVKjfcn7ZoUU9jbDxWKzt9FnV2wPfaUtJv+dtTWe7Y33RHetV225m9MTMxnJZ5ejvlaMsM8P3OtD61wlmOd7NB12kw1T38/1ttivS3zkh+o0/h/Zhlho/JoB4JKVMs3zaGWSqj2uFo1viVSKV0jbEFGYIbS9LNGray1Ip/pLaOQ/a2yLUJRcvksrBJl+GZE5rsLiiSlpQUh/Rq4e9WacVfUddWi89teSt7hf35JNq+2s8528CNtDTXISInsM/drLvaGywSSCB3lRWWusNr/NNbabeNw23ros1/4mJgvEPH0ckzMZiAUOsTPzbFIWsb3XKavHna3U7N2/fQ0WLF9vCMj0M4dHvFjNfJsu5xuq+ptu2K5qe1rVPZeqI8JlmjhcT3USYrWcFMi81knCmVKIKGVD6xOJjzTPqk0p7HMdF11CZfpurrOW62jceyfbmFKssoMUerPBRL+K6mVtFocu5YW2/MdF8okvyltiISkjAahUJlFOrvSNsbnDG4oY6Bva6OjVZbq5Zs621VFzkOPoq0TuTCsmJiYz0JsdGO+dtSoNUvzKpATzGxmdEOhlcqNVK9MICHU3Rznu8V5M/+s9M41GWonnbSPbb0gRH/0UObOXF/LMD37e4E8rbUwXwsT9ZUwS1KFXRVbqEY7ywy2NBirS/k0354TaN/lSD3bBRLVD+qYR36KWY7QPut15gu8U5TvsJpaMqxaRX2CVe3e0dMuQplcufutQYMKoxXaQUaduQ5W6XkEtnad9qLS1D38VoWx6n0oTyfdXfI5/xViYr6exEY35mtHkULD7eh1kxGJVowwKLd/qjl2c5oyNYb50PDGYvVhoCJop7IgMKdjvu2X1ZnYsZPZffNsZ83kcV8rBNJCk0Xrv0WiKeYahd5SYqG26CWjv9BrGkxzEGgjra05JvRImdCjwcj0a95LLiLFyJUMqxthWpuGaIkYYRBalKw1vlUPfe5boPXCjKPw9p4hu3d2ule94hAlVuopRKjA9qDcw1mDC6HFztPWjyUUKtLPIHPVmS9fD4nGE8bExHwmYqMb87XkMdf6vX9aapUfOVwfXXzfKGOMNltvjaFL47TWIt3NXmULdfyQv3SodW3333LGjoYsWOby5xM+bKJlnMFKBUJzrJH8DkRBWCmtROu2PcIo4rgqEdrPG44wV6ki++hnnit8YLaeWngl+Ydc30+0Zv/3XxK2WmPgQ7QLSMz5QKsmIciDx4Yqhx6nJL+fI71sgVM1WKitM5Rm5SDDdfJyM5oKbCQUKMz6xzExMZ8PsdGN+VrSVit/9DMwwxw7GWK1KWq10rxgfI3nEns46clxHtuh0OS2pVFosUJvdW1j9rFnaafOaDMN845QkQf1pZnMYohqefL1yCwxdDVdlrHtNC47VNaCLnFsAxOWz/Vc0Ri9EgPsVnsA7db0EmRfLfI4bD59XkaSqXuvLw0hUKI
7KNBPH2PWadHS4QoNVZOtgNTRJRKKPvU1jYmJ+XhioxvzpaZOvTr1Sj5loE86Xev/nj7QWYve9/QwxvQPRLHCjSqkSfTyo2OH2WbZdHnBMoXmIVAddCOvwDka/NbrHkdGvQ9z68VrIpYTFhqsXJ+G0LXXR1v/31Fy661nLeOEpTQkaFFU7fLi1y1sWGi395mVdbyPXk4yTYeVDPonyYbo2K1m895ZVLbNV7wiEuwI+p1C/kfLkSe00MfLqr0uqY1CO3yqaxgTE7PxxEY35kvLvz3l+36rVp0f+5ab/PITl5yr+cfpfvfEDA0JTn+EQ64s90w/EhoUqBLIaAieUxccJlW7XIsmHmyx2Rb7j394NLctEGqrwmKtRapVGYWqbadSX9GUcqMpblcRHXPOTE6KbKW8NCct4LHteLv0Q3f9vxYWd6rSooDSXRLCHhk9Z5GqbzwfhZV0n99OiyfLo+K5DQFv3sW+V1Da9iM/f0K+Ynt+omsWExPz6YnFMWK+lNSo9T2Xq81KMd7sIc8Y94n7aTH2EQHyMmQC9htXCgpUCmS0RB/ldqy9z8n3zMtN8UavjPsyDykLG/x8PGPv58HH2WZlvUAlMlqoMkKlrbLHtKpMuGY/Rm/LDyd3d9qHnZz0wZrxBCJ1qnYNDEiEKi79tW59Ttdm6wukdllCxR9N7UM6FYlohAHpIgoriwX1dVmFypDaKhbPVmameR5V9dHlqWNiYjYTsacb86WkRp069c22rfDJctvnmC6/W6FOFYFkJpSXYWaXAqQFQgXWLKmGeQ3+9VOKKikvRhg5lcekOWIRQ6NAaKX1/HNMje2PqBGiBO9m+xiA59tmBHtFxvJXmfl2mkNYTJCkMa5peSGdl3NBQ4PFPX6jxYlz5GVVndOPX6/617z5Xfq8RCZJOJBOefMYUcw72epBifbm9pzpOScJpeVpaaSXtbHjJ7pGMTExny+x0Y35UnGdl9xinJ5aO8oBHjQa9NHNIXYnneaB3zPxebYbxnGXqMvjGreYbLqRDnCswz3nKWc6RNezQxf/lccPDry5I3NKqgQ6SmiQ1ETmKUF9EUULWFSMgLYhA/OYVs7OURPJkB6V0SGBaIK5MNvFlOy2MDsD/q8ErduzR4qwdSRqFWZr3146i5oWzO2R0WBhzuimqht0fovFQ3i7J3m17LGQoAHfqGPvTNR5ybe8U/BHoeh9g0qT3WAPN2+iv0xMTMzGEBvdmC8Nj5niHI+BGZbZQSePusZqFUbaUyslPHQl/7oYIe+9QCbj3JNTbnS7hMBdHpTR4EzfVSY0txOnX0pptph8O5XmKFShlzrF2ocLpbJSqclVzO62ZjzzAhoepWt1JEiRDKPX/b3WtGm6wtw4LR0iCOm2kkwLZremT6OAVUhiRVYGspg8fRUYuKaTkb8z4K4f6jCJ+jb5Om1Vp6BQ5CXXNfH8K25WWjPE8sLGM5LMmf+YmJgtRWx0Y740TLRYQiAjlBaaYqmRawcBTX9NLuUnDJk81hMywuwxKUnXu06ZyB2twwfk4nbLlGjIFi2o08HMoJXOwVTFmYzaNuuOaecP2HoVy1px2x4sbcEt20T7BhukgxJvGCsh4QjMkDEjZPA8bryLlvVMPbazsh6LtcygjHAOdQWl0n3O0cuZEk2N5bAfSJSs1nX8edHoEwlk1q5fDwZlzrHQGeqsUqKXAc7/FFc9Jibm8yQ2ujFfSG73snPcA651nJPt4Rv6+o2nJCWEQgfbdt0Dt9+TcQ9HvwcBO+5tiOXmWiAtrUHaAhOaHZINHJZBoKHJnlCDPLUyShKRrlQbrMzuPec1ts2+6bSKm/tQ0DIwxhj9DVWoUEbGdJNMsq9iyxu7NeJ+irPT0APuWiJ92UMaFj4hOW2xRMlOCgZdoCCVTflJV1P+HkU9KejE4mtom51GnpOhvWjxOGt/QesjtC063nGOVGWhEr0kmlRAiomJ2TLERjfmC8cHVviB22Sy7tsP3GZf/QzV3bNOdZe
3ddfKefZe9+AjziGTjqaWtxvGt3/lbyrkyTPRFLvo4z1/t9Qa53BvpXZWbjHGqdLRh5ZoheW6WamnNfYsJbJx/VZz5Quks/PH81txTMn+fuYfuuVqBZGQsJV25jYa3Gijsu4UZ4sUBXktpILD6XGk7NLtGmoW8sruVM8lyGPIg1lxjiad5Q0lOY1WvWn9M/J7UboXQSClhZa2+eR/hJiYmE1CbHRjvnAssipncCEjtMgqPbWztz721ie37053usQlWmjhL/5i78TeHP3L6JWlrTbudgN4z1jn+ruhWI5+GCBKmO2MvmiwyApLFQr1kM6ty0Y6VJG4Y/tWXHkCB71OYQHP70XbxHNK11NrN19HhXqoMV8oStRNleyCN0gVMeqO7DTxepjzF6rnR7+HDUw5j+3/yDsnEqZp0ZtZe7K8Pwf+lPZDPvkFj4mJ2WzERjfmC8dAPWyvi6nZ3NLtdDZIz3XaTTbZyU4WCiUkjDTSYosVK16nbSM72MPRzvGg67SWZ6hStVkvNBSlAdVIaBAozAZXNSUjqm/7Frr05IWeUXn7xkJ51ZZruZa7mpCyi+fN8GsNyvQKztZ+5DfYfxV5RaQKPuJqrL1YG9L1WNqOoOoDrjiR+ddGu168ncteoe+wj+gvJiZmSxKLY8R84SiQ52W/co1jXeNYr/iVgvWsR84yK/IckZFRocKHOfnG9RMInOoaD6v0sAp7rxVc9AHaKpEvT7nEWpnAsj4xvXQx2C+0NzB3E3Wxmw7ry4Nd9awW035j4OzOhtTfqb1vRNuLWn+MwUWvMyjskh18in7ZAgiFXakoYv6sJo1Dnrrho/uLiYnZosSebswXkjaKnZ2thrMhhhuurbZWWy0U6q+/Xnp95DGNFGSF/Uc4XycDLDNNN3s4QUJP2xjnTRe5VKWJ+oSr9M5wbJKead6aTufqRV4ddLXhiT8JlQgkbO8YibVvqYp3mHSQnPhjxevs9MrGX4ii7uw9lbJ3abFVZGxzH6J4TQ5SI+3WXhSOiYn5IhGE4XpyDb7CBEHQEqtXr16tZcuWW3o4MZ+RmWa6xS2KFDnLWdo1LcvzOfC+q7RZcYFUa9G8UEhiJS1G8/wuzOndzvct23AHC2/g/bOabxteR+JziiT+5+k8dVP0e7d+XPkuqbj2bUzMxlBWVqZVq1bQKgzDTyZp9ymJPd2YLzXb2MYf/OHjG24koYyMvwq9JDBMkecUFkYBw6Fo/TbM6h63rCD5ccXdS3Zu8iZJ0Xafn8GF7/2Voy+LcpJbdvj49jExMVuU2OjGxDQh40/SzhG5tfcpCNsIitZk6SSRNwUhc7uljHDjR3fYcg/63sniv5LXid7Xff6DLm3/+fcZExOzSYiNbsyXmzCDKoKPrh27sWSyWs5kBJWUWCnMluoNkKwhOYm6pQf7Ztu7Ffro0nmg40nRKyYm5mtPHL0c8+UlfIN0F9KlNOxPWPnZ+quvF8zvFs0dN5BaRqJSs2ClZD1Bp0DBgRdvnMGNiYmJaUJsdGO+vKRPIRfE9ALhX5rvz9RGghLrHFfOyn+z+pFIYAIqqxj2bck+z0r8saXEks6RZ1tNagmJcpIrSFRgp54MjnNhY2JiPjmx0Y35ErPcGrHhgDCrqxiGjD6NXbdju3787YJoeybDikVM3425J/D+Ecw9Ptr3wJO8PUlQH0id315q25Yk+yApUR1IrYwMcABtztmsnzImJuarQ7ymG/PlJXEumbOzbwpJnBz9WjaG/9/enUdLUZ55HP/+LgjigjuKIqNxd1xQJyqoCQ7ucc3IEZdRGY+7njFHo7hF0EncTVxjjAsuiIwzE0aNGUY8IQmbURSFcQniUUHAXRHCzjN/VHVo2ttwu293VXPv73NOnXvrrfftet739u2nq7q63gEvwpddkpx8wX/DxvvBpXfBdu/CL2aueIyvnoYls0vuZ0wyV99mY2HBQ+nDHwRLJ8Nafw+dD6l3z8ysjXLStTVX0yWgvSGmgQ4BpTfG+Hw
mfF70tZwAbrgbZn8Km5XeG7kjNK0L/Y+Eex+HP6czEN1+JXTuDp2vKarbzAQLZmYV8OllW7Ppe9B01oqEC9DzGNg4WOlWTR/MhWXL4bUu8GR6AZQ6wdYPQIeusHQpdEzfg3ZdD3rtmlkXzKz9cNK1tqdjV/j+wVA8XcGmhRnoBXdtDYvGw+5zYZOBSfG9T8DEycnv8+bDOVdnGbGZtRM+vWxtx8JFcMhZMGkqbLxh8jmtgOUBg86GvXaBqdPgwL1hh21Wbjt33orPdZcHfD2v9NHNzFrNSdfajtMuh3GvJr/P+gR6bA4Xnwa9dobDDkjKv7t7820Hngj3DYOvv0nWrzy3/vGaWbvjCQ+sMUTAa3fCzOehx6Gw948rf4ztD4fpM1asd1oLFr3e8vazP4ExL8F2PWHfPSvfv5mtUTzhgbVf434M025PTgd//gLMnQ5976/sMfrtv3LS3bIbfDMf1i8/qf1KuneDk4+pbJ9mZhXwhVTWGN4btvL6h0+1rN3CRXDDL2HgVXBcPzj9OFg/vVny+x/BHsfDZ1/WNlYzsyo56Vpj6FIyLV2nDVrW7qxrYPA98PgzcPT5MPAEWFb8VaFZMOy52sVpZtYKTrrWGL4/HJZ0Tn5f1hEOGt6yds//MbnaeNlyaGqCF8bD2kVz3EasvG5mliMnXWsMtz4P528HF+8A5+0AD0xqWbs9doIO6V2mli2D3XeE+6+DtdLLFQ7cOznlbGbWAHwhlTWGV6YmR6zz0qfkpDdb1m74bXDB9TDtAzjtGDjpyOT7tof2gc+/gm17JEfAZmYNwEnXGsOhfWDMy9CkJPn2a+HUeVt2g5H3fLt8w67JYmbWQJx0rTEMOhvW6QITJsNB+8CFpySfx/5mNEz5S3Jzi9698o7SzKxVfHMMa1w3PwiD7khOD0fA8FvhpKPyjsrM2og8bo7hD7uscT3yX8nP5cuTpHvq5TD+tXxjMjNrBSdda1zb9lhpoiCWB9z6cG7hmJm1lpOuNa5fDV75YigBXdbOKxozs1Zz0rXG1XNLGPsEdNskWe+2CQy+MN+YzMxawVcvW2PbdXv4YDR8OBt6doe1O+cdkZlZ1Zx0rfGt3Rl23CbvKMzMWs2nl83MzDLSEElX0oWS3pe0UNJLkvZdTf3+kt5O60+R5C9vmplZw8s96Uo6CbgDGALsDbwOjJLUrUz9PsBw4CFgL2AkMFLSbtlEbGZmVp3c70gl6SXg5Yi4KF1vAmYAd0fETc3UHwGsGxFHF5VNBCZHxHkt2J/vSGVmZu3vjlSSOgH7AKMLZRGxPF3vXaZZ7+L6qVHl6kvqLKlrYQHWb3XgZmZmVcj79PKmQAfg45Lyj4EtyrTZosL6VwJfFy0zq4rUzMyslfJOulm4EdigaOmRbzhmZtZe5f093c+AZcDmJeWbA3PKtJlTSf2IWAQsKqxLaq6amZlZ3eV6pBsRi4FJQL9CWXohVT9gQplmE4rrpw5dRX0zM7OGkPeRLiRfF3pU0ivAn4FLgHWBRwAkPQZ8FBFXpvXvBP4g6VLgt8AA4B+Ac7IO3MzMrBK5J92IGCFpM+B6kouhJgNHREThYqmewPKi+uMlnQL8G/AzYBpwfERMzTZyMzOzyuT+Pd2s+Xu6ZmYG7fB7umZmZu2Jk66ZmVlGnHTNzMwy4qRrZmaWESddMzOzjDjpmpmZZcRJ18zMLCNOumZmZhlx0jUzM8uIk66ZmVlGnHTNzMwy4qRrZmaWESddMzOzjDjpmpmZZcRJ18zMLCNOumZmZhlx0jUzM8uIk66ZmVlGnHTNzMwy4qRrZmaWESddMzOzjDjpmpmZZcRJ18zMLCNOumZmZhnpmHcAeZk7d27eIZiZWY7yyAOKiMx3midJWwEz847DzMwaRo+I+CiLHbXHpCtgS+CbHMNYnyTx98g5jry09/6Dx6C99x88Bo3S//WBWZFRMmx3p5fTgc3kHU05Sd4H4JuIaHfnudt7/8Fj0N77Dx6DBup/pvv2hVR
mZmYZcdI1MzPLiJNuPhYBQ9Kf7VF77z94DNp7/8Fj0C773+4upDIzM8uLj3TNzMwy4qRrZmaWESddMzOzjDjpmpmZZcRJt04kXSjpfUkLJb0kad/V1O8v6e20/hRJR2UVaz1U0n9JZ0v6k6Qv02X06sZrTVDpc6Co3QBJIWlkvWOspyr+BzaUdK+k2ZIWSfpLe/o/SOtfIukdSQskzZD0c0lrZxVvLUn6nqRnJc1Kn8/Ht6BNX0mvpn//dyWdmUGomXLSrQNJJwF3kFwOvzfwOjBKUrcy9fsAw4GHgL2AkcBISbtlE3FtVdp/oC9J/w8GegMzgP9N75O9RqpiDArttgFuA/5U5xDrqor/gU7AC8A2wInATsDZ5Hz3uNaoYgxOAW5K6+8CnAWcBPwsk4Brb12SPl/YksqStgV+C/we6AX8AnhQ0uF1izAPEeGlxgvwEnBP0XoTyYvHoDL1RwDPlZRNBO7Puy9Z9L+Z9h1Ibs12et59yXIM0n6PI3mxHQqMzLsfWfUfOA+YDqyVd+w5jsE9wIslZbcDY/PuSw3GIoDjV1PnZmBqSdlTwP/kHX8tFx/p1lj6jn0fYHShLCKWp+u9yzTrXVw/NWoV9RtWlf0vtQ6wFvBFzQPMQCvG4CfAJxHxUH0jrK8q+38sMAG4V9LHkqZKukpSh7oHXAdVjsF4YJ/CKWhJ3wGOAp6vb7QNo828Dq5Ku5vwIAObkhyxfFxS/jGwc5k2W5Spv0VtQ8tENf0vdTMwi2//A64pKh4DSQeSHOH2qm9omajmOfAd4B+BYSSJZnvgPpI3X0PqE2ZdVTwGEfGkpE2BselsaB1JznatqaeXK1XudbCrpC4RsSCHmGrOR7rWUCQNAgYAJ0TEwrzjyYKk9YHHgbMj4rO848lJE/AJcE5ETIqIEcBPSU47twuS+gJXAReQfAb8Q+AHkq7NMy6rLR/p1t5nwDJg85LyzYE5ZdrMqbB+I6um/wBIugwYBBwSEW/UJ7xMVDoG25FcQPRs0XRnTQCSlgI7RcT0ukRaH9U8B2YDSyJiWVHZW8AWkjpFxOLah1lX1YzBDcDjEfFguj5F0rrAA5J+mp6ebsvKvQ7ObStHueAj3ZpLXxwmAf0KZZKa0vUJZZpNKK6fOnQV9RtWlf1H0uXAtcAREfFKveOspyrG4G1gd5JTy4XlGVZcxTmjziHXVJXPgXHA9mm9gh2B2Wtgwq12DNYBShNr4U2IaPvazOvgKuV9JVdbXEgu818InEFy6f+vgC+BzdPtjwE3FtXvAywBLiX5vGcwsBjYLe++ZNT/K0hmGvknks91Cst6efclqzFopv1Q1uyrlyt9DmxNcsX63STJ9gckn+ddnXdfMhyDwekYDAC2JUk47wIj8u5Llf1fjxVvIgP4Ufp7z3T7jcBjRfW3BeYDt6SvgxcAS4HD8+5LTccl7wDa6gJcBHyQJpOXgP2Kto0BhpbU7w+8k9afChyVdx+y6j/wfvpPWboMzrsfWT4HStqu0Um3mv6TXKU6MU1U00k+3+yQdz+yGgOSj/uuSxPtAuBD4F5gw7z7UWXf+5b5vx6abh8KjGmmzWvpeE0Hzsy7H7VePLWfmZlZRvyZrpmZWUacdM3MzDLipGtmZpYRJ10zM7OMOOmamZllxEnXzMwsI066ZmZmGXHSNbOKSNpZ0kRJCyVNzjseszWJk65ZCUmxmmWwpG3S35dJ2qqkfXdJS9Pt2+TTi7oaQnK7vp349r1yAZA0VNLITKNK9numpK+y3q9ZSznpmn1b96LlEpL74RaX3VZU9yPg9JL2Z6TluZG0Vh0ffjtgbER8EBGf13E/Zm2Ok65ZiYiYU1iAr5OiFWURMa+o+qPAwJKHGJiW/42kjSQNk/SppAWSpkkaWLS9h6Thkr6QNF/SK5L2K9p+vqTpkhZLekfSP5c8fqR1npE0H7g6LT9O0qvpqeD3JF0nqTePf7IAAAP4SURBVOyUnpKaJP1E0kxJiyRNlnRE8X6AfYC
fFI76WzKmksZIukvSLWkf55S2LerD79Ixek/SiUXb+6Z1Niwq61U4o5DOR/sIsEHxWYmWxGeWFSdds9Z5BthI0oEA6c+NgGdL6t0A7AocSTLjzPkkc64iaT3gD8BWwLHAniQzrRTm1D0BuBO4HdiNZLaaRyQdXLKPwcBvSKYJfFjSQSQz2dyZ7vtc4EzShFzGv5LMdnUZsAcwCnhG0g7p9u7A/6WxlB71r84ZJKel9wMuJ0nch5bUuQH4T5IxGAY8JWmXFj7+eL59ZqKS+MzqzpPYm7XOEuAJ4F+AsenPJ9LyYj2B12LFXMHvF207BdgM+G5EfJGWvVu0/TKSmVnuS9fvkLR/Wv77onpPRsQjhRVJDwM3RUThqPs9SdeSJPQhZfpzGXBzRDyVrl+RJvdLgAsjYo6kpcC89ExAJd6IiMJ+p0m6iOQz4ReK6jwdKyZxvzZNyheTTPO2ShGxWNLfzkxUGJtZJnyka9Z6DwP9JW1BMkXjw83U+SUwID1de4ukPkXbepEk5C+aaQfJkfG4krJxaXmxV0rW9yQ5mpxXWIBfA90lrVO6E0ldgS1buK9qvFGyPhvoVlJWOmH5hBrt26whOOmatVJETAHeBoYDb0XE1Gbq/A74O+DnJIntRUmFU58LahTK/JL19UjmZ+1VtOwO7EAyZ23WSo/+g8peg5anP1VUVs8LxsxqzknXrDYeJpmAu7mjXAAi4tOIeDQiTiM5XXtOuukNoJekjcs0fQs4oKTsAODN1cT0KrBTRLzbzLK8tHJEzAVmVbmvWtm/mfW30t8/TX92L9req6T+YqBDHeIyqwl/pmtWG78Gngaa/Y6opOuBSSQXIXUGjmZFMhkOXAWMlHQlyWnXvYBZETEBuBX4d0mvAaOBY4AfAoesJqbrgeckfQj8B8mR4p7AbhFxTZk2twJDJE0HJpNcid0LOHU1+6qV/pJeIfl8/FRgX+CsdNu7wAxgsKSrgR1JLvoq9j6wnqR+wOvAXyPir1kEbtYSPtI1q4GIWBoRn0XE0jJVFgM3khzV/hFYBgxI2y4GDgM+AZ4HpgCD0jpExEiSq4ovI0na5wIDI2LMamIaRZLcDwNeBiYCPwI+WEWzu4A7SK5OngIcARwbEdNWta8auo5kXN4g+f7zyRHxJkBELAFOBnZOt18BrPTmISLGA/cDI0iOjC/PKG6zFlFE5B2DmVnhO8AnpG8yzNokH+mamZllxEnXzMwsIz69bGZmlhEf6ZqZmWXESdfMzCwjTrpmZmYZcdI1MzPLiJOumZlZRpx0zczMMuKka2ZmlhEnXTMzs4w46ZqZmWXk/wFcEo/V8evrGAAAAABJRU5ErkJggg==\n" + "image/png": "\n" }, "metadata": { "needs_background": "light" } - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0.9286664636709675\n" - ] } ] }, @@ -395,28 +425,28 @@ "height": 497 }, "id": "yi_ztCc3kRXy", - "outputId": "531e7145-e65f-4042-b367-940a30ff2f2f" + "outputId": "6d6cd607-5818-4ed8-be63-f192816030aa" }, - "execution_count": 8, + "execution_count": null, "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9143006449691364\n" + ] + }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], - "image/png": "\n" + "image/png": "\n" }, "metadata": { "needs_background": "light" } - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0.9142998039353573\n" - ] } ] }, @@ -432,28 +462,28 @@ "height": 497 }, "id": "_Ix9KPR3kuLu", - "outputId": "f2582a38-46d0-4db4-bda7-f3fdc985144e" + "outputId": "2ed55002-b417-4526-88fd-ea844544d7a5" }, - "execution_count": 9, + "execution_count": null, "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.8287990333595944\n" + ] + }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], - "image/png": "\n" + "image/png": "\n" }, "metadata": { "needs_background": "light" } - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0.8287968094269524\n" - ] } ] }, @@ -492,7 +522,7 @@ "metadata": { "id": "Ed3SBjR0Defv" }, - "execution_count": 10, + "execution_count": null, "outputs": [] }, { @@ -500,23 +530,33 @@ "source": [ "# setup model\n", "clear_mem()\n", - "af = af2rank(\"5CEG_AD_trim.pdb\", chain=\"A,B\")" + "af = af2rank(\"5CEG_AD_trim.pdb\", chain=\"A,B\", model_name=SETTINGS[\"model_name\"])\n", + "SCORES,LABELS = [],[]" ], "metadata": { "id": "lySWA526TtUQ" }, - "execution_count": 11, + "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ - "SCORES = []\n", - "for mut,x in seqs.items():\n", - " SCORES.append(af.predict(seq=x[\"seq\"], **SETTINGS, extras={\"fitness\":x[\"sco\"], \"mut\":mut}))" + "for label,x in seqs.items():\n", + " if label not in LABELS:\n", + "\n", + " if save_output_pdbs:\n", + " output_pdb = os.path.join(\"tmp\",f\"{label}.pdb\")\n", + " else:\n", + " output_pdb = None\n", + "\n", + " score = af.predict(seq=x[\"seq\"], **SETTINGS, output_pdb=output_pdb,\n", + " extras={\"fitness\":x[\"sco\"], \"id\":label})\n", + " SCORES.append(score)\n", + " LABELS.append(label)" ], "metadata": { - "id": "cMfdjfIbr_PV" + "id": "fYUmIaJBL4j7" }, "execution_count": null, "outputs": [] @@ -524,82 +564,21 @@ { "cell_type": "code", "source": [ - "plot_me(SCORES, x=\"fitness\", y=\"composite\", scale_axis=False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 474 - }, - "id": "HaYwiXeop_is", - "outputId": "202c8345-3c82-4938-849d-e8425823b2c5" - }, - "execution_count": 18, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "\n" - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0.37408836098630754\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "plot_me(SCORES, x=\"fitness\", y=\"dgram_cce\", scale_axis=False)" + "SCORES[1]" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 475 - }, - "id": "Mybxni5mqcLD", - "outputId": "90ae7780-cd3e-4f91-ac9e-81f3a44b899b" + "id": "3qKYfNDWvpLd" }, - "execution_count": 21, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "\n" - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "-0.37727701197990876\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "code", "source": [ - "" + "plot_me(SCORES, x=\"fitness\", y=\"composite\", scale_axis=False)" ], "metadata": { - "id": "n65JnuQaMYlA" + "id": "mpBcceKdSlOG" }, "execution_count": null, "outputs": [] diff --git a/af/examples/afdesign_hotspot_test.ipynb b/af/examples/afdesign_hotspot_test.ipynb index ef4d6b49..9ec33d4a 100644 --- a/af/examples/afdesign_hotspot_test.ipynb +++ b/af/examples/afdesign_hotspot_test.ipynb @@ -33,7 +33,7 @@ "if [ ! -d params ]; then\n", " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" diff --git a/af/examples/binder_hallucination.ipynb b/af/examples/binder_hallucination.ipynb index 9889a7a0..6faa2c05 100644 --- a/af/examples/binder_hallucination.ipynb +++ b/af/examples/binder_hallucination.ipynb @@ -38,7 +38,7 @@ "if [ ! 
-d params ]; then\n", " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" @@ -123,27 +123,26 @@ "#@markdown - The contact definition is based on Cb-Cb diststance `cutoff`. To avoid \n", "#@markdown biasing towards helical contact, only contacts with sequence seperation > \n", "#@markdown `seqsep` are considered.\n", - "seqsep = 0 #@param [\"0\",\"5\",\"9\"] {type:\"raw\"}\n", - "cutoff = \"max\" #@param [\"8\", \"14\", \"max\"]\n", - "num = \"max\" #@param [\"1\", \"2\", \"4\", \"8\", \"max\"]\n", - "binary = True #@param {type:\"boolean\"}\n", + "seqsep = 9 #@param [\"0\",\"5\",\"9\"] {type:\"raw\"}\n", + "cutoff = \"14\" #@param [\"8\", \"14\", \"max\"]\n", + "num = \"2\" #@param [\"1\", \"2\", \"4\", \"8\"]\n", + "binary = False #@param {type:\"boolean\"}\n", "if cutoff == \"max\": cutoff = 21.6875\n", "if num == \"max\": num = binder_len\n", "\n", "opt = {\"con\":{\"seqsep\":int(seqsep),\"cutoff\":float(cutoff),\"num\":int(num),\"binary\":binary}}\n", "#@markdown ---\n", "#@markdown ####interface Weights\n", - "i_pae = 1.0 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", - "i_con = 0.5 #@param [\"0.01\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", + "tb_con = 0.0 #@param [\"0.0\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", + "bt_con = 1.0 #@param [\"0.0\", \"0.1\", \"0.5\", \"1.0\"] {type:\"raw\"}\n", "\n", - "weights.update({\"i_pae\":float(i_pae),\"i_con\":float(i_con)})\n", + "weights.update({\"bt_pae\":float(bt_pae),\"tb_con\":float(tb_con)})\n", "\n", "#@markdown ####interface Contact Definition\n", "cutoff = 
\"max\" #@param [\"8\", \"14\", \"max\"]\n", - "num = \"max\" #@param [\"1\", \"2\", \"4\", \"8\", \"max\"]\n", - "binary = True #@param {type:\"boolean\"}\n", + "num = \"1\" #@param [\"1\", \"2\", \"4\", \"8\"]\n", + "binary = False #@param {type:\"boolean\"}\n", "if cutoff == \"max\": cutoff = 21.6875\n", - "if num == \"max\": num = binder_len\n", "\n", "opt.update({\"i_con\":{\"cutoff\":float(cutoff),\"num\":int(num),\"binary\":binary}})\n", "\n", diff --git a/af/examples/disulfide_design.ipynb b/af/examples/disulfide_design.ipynb index 8ec2a835..5a0231cc 100644 --- a/af/examples/disulfide_design.ipynb +++ b/af/examples/disulfide_design.ipynb @@ -38,7 +38,7 @@ " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", " # download params\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" @@ -86,7 +86,7 @@ "import random\n", "from jax.lax import dynamic_slice\n", "import jax.numpy as jnp\n", - "from colabdesign.af.loss import get_con_loss\n", + "from colabdesign.af.loss import _get_con_loss\n", "\n", "def generate_disulfide_pattern(L, disulfide_num, min_sep=5):\n", " disulfide_pattern = []\n", @@ -121,7 +121,7 @@ " for pair in disulfide_pattern:\n", " i,j = pair\n", " pair_dgram = dynamic_slice(dgram, (i,j,0), (1,1,len(dgram_bins))) + dynamic_slice(dgram, (j,i,0), (1,1,len(dgram_bins)))\n", - " disulfide_loss += get_con_loss(pair_dgram, dgram_bins, cutoff=7.0, binary=False, num=1)\n", + " disulfide_loss += _get_con_loss(pair_dgram, dgram_bins, cutoff=7.0, binary=False, num=1)\n", " return disulfide_loss.mean()\n", "\n", " # add disulfide loss here:\n", diff --git 
a/af/examples/hallucination.ipynb b/af/examples/hallucination.ipynb index c199e114..48b65f97 100644 --- a/af/examples/hallucination.ipynb +++ b/af/examples/hallucination.ipynb @@ -38,7 +38,7 @@ "if [ ! -d params ]; then\n", " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" diff --git a/af/examples/hallucination_custom_loss.ipynb b/af/examples/hallucination_custom_loss.ipynb index 208ece8a..42f057bd 100644 --- a/af/examples/hallucination_custom_loss.ipynb +++ b/af/examples/hallucination_custom_loss.ipynb @@ -38,7 +38,7 @@ " ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\n", " # download params\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" diff --git a/af/examples/partial_hallucination_rewire.ipynb b/af/examples/partial_hallucination_rewire.ipynb index 284d7796..ad72619b 100644 --- a/af/examples/partial_hallucination_rewire.ipynb +++ b/af/examples/partial_hallucination_rewire.ipynb @@ -33,7 +33,7 @@ "if [ ! 
-d params ]; then\n", " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" diff --git a/af/examples/use_esm_1b_bias.ipynb b/af/examples/use_esm_1b_bias.ipynb index 8d32b896..4da84648 100644 --- a/af/examples/use_esm_1b_bias.ipynb +++ b/af/examples/use_esm_1b_bias.ipynb @@ -147,7 +147,7 @@ "if [ ! -d params ]; then\n", " pip -q install git+https://github.com/sokrypton/ColabDesign.git\n", " mkdir params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params\n", " for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1\n", " do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done\n", "fi" diff --git a/colabdesign/af/README.md b/colabdesign/af/README.md index d3068458..3c905288 100644 --- a/colabdesign/af/README.md +++ b/colabdesign/af/README.md @@ -3,4 +3,5 @@ - loss.py - configure the loss - prep.py - prep features - design.py - gradient update loop -- utils.py - various tools for saving/plotting \ No newline at end of file +- utils.py - various tools for saving/plotting +- crop.py - functions specific to cropping \ No newline at end of file diff --git a/colabdesign/af/alphafold/common/confidence_jax.py b/colabdesign/af/alphafold/common/confidence_jax.py index 209c6f96..6e041ac1 100644 --- a/colabdesign/af/alphafold/common/confidence_jax.py +++ b/colabdesign/af/alphafold/common/confidence_jax.py @@ -20,8 +20,6 @@ import jax from 
jax import jit - -@jit def compute_plddt_jax(logits): """Port of confidence.compute_plddt to jax @@ -40,7 +38,6 @@ def compute_plddt_jax(logits): predicted_lddt_ca = jax.numpy.sum(probs * bin_centers[None, :], axis=-1) return predicted_lddt_ca * 100 -@jit def _calculate_bin_centers(breaks): """Gets the bin centers from the bin edges. @@ -59,13 +56,12 @@ def _calculate_bin_centers(breaks): axis=0) return bin_centers -@jit def predicted_tm_score_jax( logits, breaks, residue_weights=None, asym_id=None, - interface: bool = False): + interface=False): """Computes predicted TM alignment or predicted interface TM alignment score. Args: @@ -114,12 +110,10 @@ def predicted_tm_score_jax( pair_residue_weights = pair_mask * ( residue_weights[None, :] * residue_weights[:, None]) - normed_residue_mask = pair_residue_weights / (1e-8 + jax.numpy.sum( - pair_residue_weights, axis=-1, keepdims=True)) + normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1,keepdims=True)) per_alignment = jax.numpy.sum(predicted_tm_term * normed_residue_mask, axis=-1) return jax.numpy.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) - def get_confidence_metrics( prediction_result, multimer_mode: bool): @@ -152,8 +146,6 @@ def get_confidence_metrics( return confidence_metrics - -@jit def _calculate_expected_aligned_error( alignment_confidence_breaks, aligned_distance_error_probs): @@ -175,8 +167,6 @@ def _calculate_expected_aligned_error( return (jax.numpy.sum(aligned_distance_error_probs * bin_centers, axis=-1), jax.numpy.asarray(bin_centers[-1])) - -@jit def compute_predicted_aligned_error( logits, breaks): diff --git a/colabdesign/af/alphafold/data/pipeline_multimer.py b/colabdesign/af/alphafold/data/pipeline_multimer.py new file mode 100644 index 00000000..012408cd --- /dev/null +++ b/colabdesign/af/alphafold/data/pipeline_multimer.py @@ -0,0 +1,284 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the features for the AlphaFold multimer model.""" + +import collections +import contextlib +import copy +import dataclasses +import json +import os +import tempfile +from typing import Mapping, MutableMapping, Sequence + +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.data import feature_processing +from alphafold.data import msa_pairing +from alphafold.data import parsers +from alphafold.data import pipeline +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). + + +@dataclasses.dataclass(frozen=True) +class _FastaChain: + sequence: str + description: str + + +def _make_chain_id_map(*, + sequences: Sequence[str], + descriptions: Sequence[str], + ) -> Mapping[str, _FastaChain]: + """Makes a mapping from PDB-format chain ID to sequence and description.""" + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError('Cannot process more chains than the PDB format supports. 
' + f'Got {len(sequences)} chains.') + chain_id_map = {} + for chain_id, sequence, description in zip( + protein.PDB_CHAIN_IDS, sequences, descriptions): + chain_id_map[chain_id] = _FastaChain( + sequence=sequence, description=description) + return chain_id_map + + +@contextlib.contextmanager +def temp_fasta_file(fasta_str: str): + with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file: + fasta_file.write(fasta_str) + fasta_file.seek(0) + yield fasta_file.name + + +def convert_monomer_features( + monomer_features: pipeline.FeatureDict, + chain_id: str) -> pipeline.FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + converted = {} + converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + unnecessary_leading_dim_feats = { + 'sequence', 'domain_name', 'num_alignments', 'seq_length'} + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + elif feature_name == 'template_all_atom_mask': + feature_name = 'template_all_atom_mask' + converted[feature_name] = feature + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. 
+ """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features( + all_chain_features: MutableMapping[str, pipeline.FeatureDict], + ) -> MutableMapping[str, pipeline.FeatureDict]: + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_id, chain_features in all_chain_features.items(): + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = {} + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + for sym_id, chain_features in enumerate(group_chain_features, start=1): + new_all_chain_features[ + f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_id += 1 + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'): + np_example[feat] = np.pad( + np_example[feat], ((0, min_num_seq - 
num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),)) + return np_example + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + monomer_data_pipeline: pipeline.DataPipeline, + jackhmmer_binary_path: str, + uniprot_database_path: str, + max_uniprot_hits: int = 50000, + use_precomputed_msas: bool = False): + """Initializes the data pipeline. + + Args: + monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs + the data pipeline for the monomer AlphaFold system. + jackhmmer_binary_path: Location of the jackhmmer binary. + uniprot_database_path: Location of the unclustered uniprot sequences, that + will be searched with jackhmmer and used for MSA pairing. + max_uniprot_hits: The maximum number of hits to return from uniprot. + use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold. + """ + self._monomer_data_pipeline = monomer_data_pipeline + self._uniprot_msa_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self._max_uniprot_hits = max_uniprot_hits + self.use_precomputed_msas = use_precomputed_msas + + def _process_single_chain( + self, + chain_id: str, + sequence: str, + description: str, + msa_output_dir: str, + is_homomer_or_monomer: bool) -> pipeline.FeatureDict: + """Runs the monomer pipeline on a single chain.""" + chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n' + chain_msa_output_dir = os.path.join(msa_output_dir, chain_id) + if not os.path.exists(chain_msa_output_dir): + os.makedirs(chain_msa_output_dir) + with temp_fasta_file(chain_fasta_str) as chain_fasta_path: + logging.info('Running monomer pipeline on chain %s: %s', + chain_id, description) + chain_features = self._monomer_data_pipeline.process( + input_fasta_path=chain_fasta_path, + msa_output_dir=chain_msa_output_dir) + + # We only construct the pairing features if 
there are 2 or more unique + # sequences. + if not is_homomer_or_monomer: + all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path, + chain_msa_output_dir) + chain_features.update(all_seq_msa_features) + return chain_features + + def _all_seq_msa_features(self, input_fasta_path, msa_output_dir): + """Get MSA features for unclustered uniprot, for pairing.""" + out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + result = pipeline.run_msa_tool( + self._uniprot_msa_runner, input_fasta_path, out_path, 'sto', + self.use_precomputed_msas) + msa = parsers.parse_stockholm(result['sto']) + msa = msa.truncate(max_seqs=self._max_uniprot_hits) + all_seq_features = pipeline.make_msa_features([msa]) + valid_feats = msa_pairing.MSA_FEATURES + ( + 'msa_species_identifiers', + ) + feats = {f'{k}_all_seq': v for k, v in all_seq_features.items() + if k in valid_feats} + return feats + + def process(self, + input_fasta_path: str, + msa_output_dir: str) -> pipeline.FeatureDict: + """Runs alignment tools on the input sequences and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + + chain_id_map = _make_chain_id_map(sequences=input_seqs, + descriptions=input_descs) + chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain) + for chain_id, fasta_chain in chain_id_map.items()} + json.dump(chain_id_map_dict, f, indent=4, sort_keys=True) + + all_chain_features = {} + sequence_features = {} + is_homomer_or_monomer = len(set(input_seqs)) == 1 + for chain_id, fasta_chain in chain_id_map.items(): + if fasta_chain.sequence in sequence_features: + all_chain_features[chain_id] = copy.deepcopy( + sequence_features[fasta_chain.sequence]) + continue + chain_features = self._process_single_chain( + chain_id=chain_id, + sequence=fasta_chain.sequence, + 
description=fasta_chain.description, + msa_output_dir=msa_output_dir, + is_homomer_or_monomer=is_homomer_or_monomer) + + chain_features = convert_monomer_features(chain_features, + chain_id=chain_id) + all_chain_features[chain_id] = chain_features + sequence_features[fasta_chain.sequence] = chain_features + + all_chain_features = add_assembly_features(all_chain_features) + + np_example = feature_processing.pair_and_merge( + all_chain_features=all_chain_features) + + # Pad MSA to avoid zero-sized extra_msa. + np_example = pad_msa(np_example, 512) + + return np_example diff --git a/colabdesign/af/alphafold/model/all_atom.py b/colabdesign/af/alphafold/model/all_atom.py index 1e6dbeef..c3bc5780 100644 --- a/colabdesign/af/alphafold/model/all_atom.py +++ b/colabdesign/af/alphafold/model/all_atom.py @@ -1086,9 +1086,9 @@ def frame_aligned_point_error( normed_error *= jnp.expand_dims(frames_mask, axis=-1) normed_error *= jnp.expand_dims(positions_mask, axis=-2) - normalization_factor = ( - jnp.sum(frames_mask, axis=-1) * - jnp.sum(positions_mask, axis=-1)) + mask = (jnp.expand_dims(frames_mask, axis=-1) * + jnp.expand_dims(positions_mask, axis=-2)) + normalization_factor = jnp.sum(mask, axis=(-1, -2)) return (jnp.sum(normed_error, axis=(-2, -1)) / (epsilon + normalization_factor)) diff --git a/colabdesign/af/alphafold/model/all_atom_multimer.py b/colabdesign/af/alphafold/model/all_atom_multimer.py new file mode 100644 index 00000000..382f5948 --- /dev/null +++ b/colabdesign/af/alphafold/model/all_atom_multimer.py @@ -0,0 +1,983 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ops for all atom representations.""" + +from typing import Dict, Text + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import utils + +import jax +import jax.numpy as jnp +import numpy as np + +def squared_difference(x, y): + return jnp.square(x - y) + +def _make_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. 
+ + return np.array(chi_atom_indices) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. 
+ all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +def _make_restype_atom37_mask(): + """Mask of which atoms are present for which residue type in atom37.""" + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def _make_restype_atom14_mask(): + """Mask of which atoms are present for which residue type in atom14.""" + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + restype_atom14_mask.append([0.] 
* 14) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + return restype_atom14_mask + + +def _make_restype_atom37_to_atom14(): + """Map from atom37 to atom14 per residue type.""" + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom37_to_atom14.append([0] * 37) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + return restype_atom37_to_atom14 + + +def _make_restype_atom14_to_atom37(): + """Map from atom14 to atom37 per residue type.""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + return restype_atom14_to_atom37 + + +def _make_restype_atom14_is_ambiguous(): + """Mask which atoms are ambiguous in atom14.""" + # create an ambiguous atoms mask. 
shape: (21, 14) + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + return restype_atom14_is_ambiguous + + +def _make_restype_rigidgroup_base_atom37_idx(): + """Create Map from rigidgroups to atom37 indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(residue_constants.restypes): + resname = residue_constants.restype_1to3[restype_letter] + for chi_idx in range(4): + if residue_constants.chi_angles_mask[restype][chi_idx]: + atom_names = residue_constants.chi_angles_atoms[resname][chi_idx] + base_atom_names[restype, chi_idx + 4, :] = atom_names[1:] + + # Translate atom names into atom37 indices. 
+ lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + base_atom_names) + return restype_rigidgroup_base_atom37_idx + + +CHI_ATOM_INDICES = _make_chi_atom_indices() +RENAMING_MATRICES = _make_renaming_matrices() +RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37() +RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14() +RESTYPE_ATOM37_MASK = _make_restype_atom37_mask() +RESTYPE_ATOM14_MASK = _make_restype_atom14_mask() +RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous() +RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx() + +# Create mask for existing rigid groups. +RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32) +RESTYPE_RIGIDGROUP_MASK[:, 0] = 1 +RESTYPE_RIGIDGROUP_MASK[:, 3] = 1 +RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask + + +def get_atom37_mask(aatype): + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_MASK), aatype) + +def get_atom14_mask(aatype): + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + +def get_atom14_is_ambiguous(aatype): + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS), aatype) + +def get_atom14_to_atom37_map(aatype): + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + +def get_atom37_to_atom14_map(aatype): + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_TO_ATOM14), aatype) + +def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...) 
+ aatype: jnp.ndarray + ) -> jnp.ndarray: # (N, 37, ...) + """Convert atom14 to atom37 representation.""" + if not jnp.issubdtype(aatype.dtype, jnp.integer): + aatype = aatype.argmax(-1) + + assert len(atom14_data.shape) in [2, 3] + idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype) + atom37_data = utils.batched_gather( + atom14_data, idx_atom37_to_atom14, batch_dims=1) + atom37_mask = get_atom37_mask(aatype) + if len(atom14_data.shape) == 2: + atom37_data *= atom37_mask + elif len(atom14_data.shape) == 3: + atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype) + return atom37_data + + +def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): + """Convert Atom37 positions to Atom14 positions.""" + residx_atom14_to_atom37 = utils.batched_gather( + jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + atom14_mask = utils.batched_gather( + all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32) + # create a mask for known groundtruth positions + atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + # gather the groundtruth positions + atom14_positions = jax.tree_map( + lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1), + all_atom_pos) + atom14_positions = atom14_mask * atom14_positions + return atom14_positions, atom14_mask + + +def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask): + """Get alternative atom14 positions.""" + # pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14) + renaming_transform = utils.batched_gather( + jnp.asarray(RENAMING_MATRICES), aatype) + + alternative_positions = jax.tree_map( + lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform) + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position) + alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1) + + return 
alternative_positions, alternative_mask + + +def atom37_to_frames( + aatype: jnp.ndarray, # (...) + all_atom_positions: geometry.Vec3Array, # (..., 37) + all_atom_mask: jnp.ndarray, # (..., 37) +) -> Dict[Text, jnp.ndarray]: + if not jnp.issubdtype(aatype.dtype, jnp.integer): aatype = aatype.argmax(-1) + + """Computes the frames for the up to 8 rigid groups for each residue.""" + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + aatype_in_shape = aatype.shape + + # If there is a batch axis, just flatten it away, and reshape everything + # back at the end of the function. + aatype = jnp.reshape(aatype, [-1]) + all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]), + all_atom_positions) + all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) + + # Compute the gather indices for all residues in the chain. + # shape (N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + x, residx_rigidgroup_base_atom37_idx, batch_dims=1), + all_atom_positions) + + # Compute the Rigids. + point_on_neg_x_axis = base_atom_pos[:, :, 0] + origin = base_atom_pos[:, :, 1] + point_on_xy_plane = base_atom_pos[:, :, 2] + gt_rotation = geometry.Rot3Array.from_two_vectors( + origin - point_on_neg_x_axis, point_on_xy_plane - origin) + + gt_frames = geometry.Rigid3Array(gt_rotation, origin) + + # Compute a mask whether the group exists. 
+ # (N, 8) + group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) + all_atom_mask.astype(jnp.float32), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + gt_frames = gt_frames.compose_rotation( + geometry.Rot3Array.from_array(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype) + ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot) + + # Create the alternative ground truth frames. 
+ alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot) + + fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,)) + + # reshape back to original residue layout + gt_frames = jax.tree_map(fix_shape, gt_frames) + gt_exists = fix_shape(gt_exists) + group_exists = fix_shape(group_exists) + residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous) + alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames) + + return { + 'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8) + 'rigidgroups_gt_exists': gt_exists, # (..., 8) + 'rigidgroups_group_exists': group_exists, # (..., 8) + 'rigidgroups_group_is_ambiguous': + residx_rigidgroup_is_ambiguous, # (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8) + } + + +def torsion_angles_to_frames( + aatype: jnp.ndarray, # (N) + backb_to_global: geometry.Rigid3Array, # (N) + torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) +) -> geometry.Rigid3Array: # (N, 8) + """Compute rigid group frames from torsion angles.""" + if not jnp.issubdtype(aatype.dtype, jnp.integer): + aatype = aatype.argmax(-1) + + assert len(aatype.shape) == 1, ( + f'Expected array of rank 1, got array with shape: {aatype.shape}.') + assert len(backb_to_global.rotation.shape) == 1, ( + f'Expected array of rank 1, got array with shape: ' + f'{backb_to_global.rotation.shape}') + assert len(torsion_angles_sin_cos.shape) == 3, ( + f'Expected array of rank 3, got array with shape: ' + f'{torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[1] == 7, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[2] == 2, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + + # Gather the default frames for all rigid groups. 
+ # geometry.Rigid3Array with shape (N, 8) + m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame, + aatype) + default_frames = geometry.Rigid3Array.from_array4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_residues, = aatype.shape + sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles], + axis=-1) + cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles], + axis=-1) + zeros = jnp.zeros_like(sin_angles) + ones = jnp.ones_like(sin_angles) + + # all_rots are geometry.Rot3Array with shape (N, 8) + all_rots = geometry.Rot3Array(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = default_frames.compose_rotation(all_rots) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. + + chi1_frame_to_backb = all_frames[:, 4] + chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5] + chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6] + chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7] + + all_frames_to_backb = jax.tree_map( + lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5], + chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None], + chi4_frame_to_backb[:, None]) + + # Create the global frames. 
+ # shape (N, 8) + all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: jnp.ndarray, # (N) + all_frames_to_global: geometry.Rigid3Array # (N, 8) +) -> geometry.Vec3Array: # (N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group.""" + + if not jnp.issubdtype(aatype.dtype, jnp.integer): + aatype = aatype.argmax(-1) + + # Pick the appropriate transform for every atom. + residx_to_group_idx = utils.batched_gather( + residue_constants.restype_atom14_to_rigid_group, aatype) + group_mask = jax.nn.one_hot( + residx_to_group_idx, num_classes=8) # shape (N, 14, 8) + + # geometry.Rigid3Array with shape (N, 14) + map_atoms_to_global = jax.tree_map( + lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), + all_frames_to_global) + + # Gather the literature atom positions for each residue. + # geometry.Vec3Array with shape (N, 14) + lit_positions = geometry.Vec3Array.from_array( + utils.batched_gather( + residue_constants.restype_atom14_rigid_group_positions, aatype)) + + # Transform each atom from its local frame to the global frame. + # geometry.Vec3Array with shape (N, 14) + pred_positions = map_atoms_to_global.apply_to_point(lit_positions) + + # Mask out non-existing atoms. 
+ mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) + pred_positions = pred_positions * mask + + return pred_positions + + +def extreme_ca_ca_distance_violations( + positions: geometry.Vec3Array, # (N, 37(14)) + mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + max_angstrom_tolerance=1.5 + ) -> jnp.ndarray: + """Counts residues whose Ca is a large distance from its neighbor.""" + this_ca_pos = positions[:-1, 1] # (N - 1,) + this_ca_mask = mask[:-1, 1] # (N - 1) + next_ca_pos = positions[1:, 1] # (N - 1,) + next_ca_mask = mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, 1e-6) + violations = (ca_ca_distance - + residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + return utils.mask_mean(mask=mask, value=violations) + + +def between_residue_bond_loss( + pred_atom_positions: geometry.Vec3Array, # (N, 37(14)) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + aatype: jnp.ndarray, # (N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]: + """Flat-bottom loss to penalize structural violations between residues.""" + if not jnp.issubdtype(aatype.dtype, jnp.integer): + aatype = aatype.argmax(-1) + + assert len(pred_atom_positions.shape) == 2 + assert len(pred_atom_mask.shape) == 2 + assert len(residue_index.shape) == 1 + assert len(aatype.shape) == 1 + + # Get the positions of the relevant backbone atoms. 
+ this_ca_pos = pred_atom_positions[:-1, 1] # (N - 1) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + this_c_pos = pred_atom_positions[:-1, 2] # (N - 1) + this_c_mask = pred_atom_mask[:-1, 2] # (N - 1) + next_n_pos = pred_atom_positions[1:, 0] # (N - 1) + next_n_mask = pred_atom_mask[1:, 0] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1] # (N - 1) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + + # Compute loss for the C--N bond. + c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6) + + # The C-N bond to proline has slightly different length because of the ring. + next_is_proline = ( + aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32) + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. - next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = jnp.sqrt(1e-6 + + jnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = jax.nn.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. 
+ c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6) + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length + n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6) + + ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = jax.nn.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = jax.nn.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + + ca_c_n_loss_per_residue + + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. 
+ violation_mask = jnp.max( + jnp.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + violation_mask = jnp.maximum( + jnp.pad(violation_mask, [[0, 1]]), + jnp.pad(violation_mask, [[1, 0]])) + + return {'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask # shape (N) + } + + +def between_residue_clash_loss( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + atom_radius: jnp.ndarray, # (N, 14) + residue_index: jnp.ndarray, # (N) + asym_id: jnp.ndarray, # (N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]: + """Loss to penalize steric clashes between residues.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(atom_radius.shape) == 2 + assert len(residue_index.shape) == 1 + + # Create the distance matrix. + # (N, N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], 1e-10) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom_exists[:, None, :, None] * atom_exists[None, :, None, :]) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask *= ( + residue_index[:, None, None, None] < residue_index[None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. 
+ c_one_hot = jax.nn.one_hot(2, num_classes=14) + n_one_hot = jax.nn.one_hot(0, num_classes=14) + neighbour_mask = ((residue_index[:, None] + 1) == residue_index[None, :]) + neighbour_mask &= (asym_id[:, None] == asym_id[None, :]) + neighbour_mask = neighbour_mask[..., None, None] + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, + None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * + cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom_radius[:, None, :, None] + atom_radius[None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * jax.nn.relu( + dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = (jnp.sum(dists_to_low_error) + / (1e-6 + jnp.sum(dists_mask))) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) + + jnp.sum(dists_to_low_error, axis=[1, 3])) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * ( + dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. 
+ # shape (N, 14) + per_atom_clash_mask = jnp.maximum( + jnp.max(clash_mask, axis=[0, 2]), + jnp.max(clash_mask, axis=[1, 3])) + + return {'mean_loss': mean_loss, # shape () + 'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14) + } + + +def within_residue_violations( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + dists_lower_bound: jnp.ndarray, # (N, 14, 14) + dists_upper_bound: jnp.ndarray, # (N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[Text, jnp.ndarray]: + """Find within-residue violations.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(dists_lower_bound.shape) == 3 + assert len(dists_upper_bound.shape) == 3 + + # Compute the mask for each residue. + # shape (N, 14, 14) + dists_masks = (1. - jnp.eye(14, 14)[None]) + dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :]) + + # Distance matrix + # shape (N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, :, None], + pred_positions[:, None, :], 1e-10) + + # Compute the loss. + # shape (N, 14, 14) + dists_to_low_error = jax.nn.relu( + dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = jax.nn.relu( + dists + tighten_bounds_for_loss - dists_upper_bound) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(loss, axis=1) + + jnp.sum(loss, axis=2)) + + # Compute the violations mask. + # shape (N, 14, 14) + violations = dists_masks * ((dists < dists_lower_bound) | + (dists > dists_upper_bound)) + + # Compute the per atom violations. 
+ # shape (N, 14) + per_atom_violations = jnp.maximum( + jnp.max(violations, axis=1), jnp.max(violations, axis=2)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_violations': per_atom_violations # shape (N, 14) + } + + +def find_optimal_renaming( + gt_positions: geometry.Vec3Array, # (N, 14) + alt_gt_positions: geometry.Vec3Array, # (N, 14) + atom_is_ambiguous: jnp.ndarray, # (N, 14) + gt_exists: jnp.ndarray, # (N, 14) + pred_positions: geometry.Vec3Array, # (N, 14) +) -> jnp.ndarray: # (N): + """Find optimal renaming for ground truth that maximizes LDDT.""" + assert len(gt_positions.shape) == 2 + assert len(alt_gt_positions.shape) == 2 + assert len(atom_is_ambiguous.shape) == 2 + assert len(gt_exists.shape) == 2 + assert len(pred_positions.shape) == 2 + + # Create the pred distance matrix. + # shape (N, N, 14, 14) + pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], + 1e-10) + + # Compute distances for ground truth with original and alternative names. + # shape (N, N, 14, 14) + gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None], + gt_positions[None, :, None, :], 1e-10) + + alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :, None], + alt_gt_positions[None, :, None, :], + 1e-10) + + # Compute LDDT's. + # shape (N, N, 14, 14) + lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (N ,N, 14, 14) + mask = ( + gt_exists[:, None, :, None] * # rows + atom_is_ambiguous[:, None, :, None] * # rows + gt_exists[None, :, None, :] * # cols + (1. - atom_is_ambiguous[None, :, None, :])) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. 
+ # shape (N) + per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3]) + alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3]) + + # Decide for each residue, whether alternative naming is better. + # shape (N) + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32) + + return alt_naming_is_better # shape (N) + + +def frame_aligned_point_error( + pred_frames: geometry.Rigid3Array, # shape (num_frames) + target_frames: geometry.Rigid3Array, # shape (num_frames) + frames_mask: jnp.ndarray, # shape (num_frames) + pred_positions: geometry.Vec3Array, # shape (num_positions) + target_positions: geometry.Vec3Array, # shape (num_positions) + positions_mask: jnp.ndarray, # shape (num_positions) + pair_mask: jnp.ndarray, # shape (num_frames, num_posiitons) + l1_clamp_distance: float, + length_scale=20., + epsilon=1e-4) -> jnp.ndarray: # shape () + """Measure point error under different alignements. + + Computes error between two structures with B points + under A alignments derived form the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. + target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + pair_mask: A (num_frames, num_positions) mask to use in the loss, useful + for separating intra from inter chain losses. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + length_scale: length scale to divide loss by. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame aligned point error. + """ + # For now we do not allow any batch dimensions. 
+ assert len(pred_frames.rotation.shape) == 1 + assert len(target_frames.rotation.shape) == 1 + assert frames_mask.ndim == 1 + assert pred_positions.x.ndim == 1 + assert target_positions.x.ndim == 1 + assert positions_mask.ndim == 1 + + # Compute array of predicted positions in the predicted frames. + # geometry.Vec3Array (num_frames, num_positions) + local_pred_pos = pred_frames[:, None].inverse().apply_to_point( + pred_positions[None, :]) + + # Compute array of target positions in the target frames. + # geometry.Vec3Array (num_frames, num_positions) + local_target_pos = target_frames[:, None].inverse().apply_to_point( + target_positions[None, :]) + + # Compute errors between the structures. + # jnp.ndarray (num_frames, num_positions) + error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos, + epsilon) + + clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance) + + normed_error = clipped_error_dist / length_scale + normed_error *= jnp.expand_dims(frames_mask, axis=-1) + normed_error *= jnp.expand_dims(positions_mask, axis=-2) + if pair_mask is not None: + normed_error *= pair_mask + + mask = (jnp.expand_dims(frames_mask, axis=-1) * + jnp.expand_dims(positions_mask, axis=-2)) + if pair_mask is not None: + mask *= pair_mask + normalization_factor = jnp.sum(mask, axis=(-1, -2)) + return (jnp.sum(normed_error, axis=(-2, -1)) / + (epsilon + normalization_factor)) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. 
+ """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return jnp.asarray(chi_atom_indices) + + +def compute_chi_angles(positions: geometry.Vec3Array, + mask: geometry.Vec3Array, + aatype: geometry.Vec3Array): + """Computes the chi angles given all atom positions and the amino acid type. + + Args: + positions: A Vec3Array of shape + [num_res, residue_constants.atom_type_num], with positions of + atoms needed to calculate chi angles. Supports up to 1 batch dimension. + mask: An optional tensor of shape + [num_res, residue_constants.atom_type_num] that masks which atom + positions are set for each residue. If given, then the chi mask will be + set to 1 for a chi angle only if the amino acid has that chi angle and all + the chi atoms needed to calculate that chi angle are set. If not given + (set to None), the chi mask will be set to 1 for a chi angle if the amino + acid has that chi angle and whether the actual atoms needed to calculate + it were set will be ignored. + aatype: A tensor of shape [num_res] with amino acid type integer + code (0 to 21). Supports up to 1 batch dimension. + + Returns: + A tuple of tensors (chi_angles, mask), where both have shape + [num_res, 4]. The mask masks out unused chi angles for amino acid + types that have less than 4 chi angles. If atom_positions_mask is set, the + chi mask will also mask out uncomputable chi angles. 
+ """ + + if not jnp.issubdtype(aatype.dtype, jnp.integer): + aatype = aatype.argmax(-1) + # Don't assert on the num_res and batch dimensions as they might be unknown. + + assert positions.shape[-1] == residue_constants.atom_type_num + assert mask.shape[-1] == residue_constants.atom_type_num + + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + # Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0) + # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. + chi_angle_atoms = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + params=x, indices=atom_indices, axis=-1, batch_dims=1), positions) + a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] + + chi_angles = geometry.dihedral_angle(a, b, c, d) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = jnp.asarray(chi_angles_mask) + # Compute the chi angle mask. Shape [num_res, chis=4]. + chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0) + + # The chi_mask is set to 1 only when all necessary chi angle atoms were set. + # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=mask, indices=atom_indices, axis=-1, batch_dims=1) + # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. 
+ chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1]) + chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32) + + return chi_angles, chi_mask + +def make_transform_from_reference( + a_xyz: geometry.Vec3Array, + b_xyz: geometry.Vec3Array, + c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + coordinates in the non-standard way, the A atom will end up in the negative + y-axis rather than in the positive y-axis. You need to take care of such + cases in your code. + + Args: + a_xyz: A Vec3Array. + b_xyz: A Vec3Array. + c_xyz: A Vec3Array. + + Returns: + A Rigid3Array which, when applied to coordinates in a canonicalized + reference frame, will give coordinates approximately equal + the original coordinates (in the global frame). + """ + rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, + a_xyz - b_xyz) + return geometry.Rigid3Array(rotation, b_xyz) diff --git a/colabdesign/af/alphafold/model/common_modules.py b/colabdesign/af/alphafold/model/common_modules.py index f239c870..8c405eee 100644 --- a/colabdesign/af/alphafold/model/common_modules.py +++ b/colabdesign/af/alphafold/model/common_modules.py @@ -13,72 +13,113 @@ # limitations under the License. """A collection of common Haiku modules for use in protein folding.""" +import numbers +from typing import Union, Sequence + import haiku as hk import jax.numpy as jnp +import numpy as np -class Linear(hk.Module): - """Protein folding specific Linear Module. +# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
+TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, + dtype=np.float32) + + +def get_initializer_scale(initializer_name, input_shape): + """Get Initializer for weights and scale to multiply activations by.""" + + if initializer_name == 'zeros': + w_init = hk.initializers.Constant(0.0) + else: + # fan-in scaling + scale = 1. + for channel_dim in input_shape: + scale /= channel_dim + if initializer_name == 'relu': + scale *= 2 + + noise_scale = scale + + stddev = np.sqrt(noise_scale) + # Adjust stddev for truncation. + stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR + w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev) + + return w_init + +class Linear(hk.Module): + """Protein folding specific Linear module. This differs from the standard Haiku Linear in a few ways: - * It supports inputs of arbitrary rank + * It supports inputs and outputs of arbitrary rank * Initializers are specified by strings """ def __init__(self, - num_output: int, + num_output: Union[int, Sequence[int]], initializer: str = 'linear', + num_input_dims: int = 1, use_bias: bool = True, bias_init: float = 0., + precision = None, name: str = 'linear'): """Constructs Linear Module. - Args: - num_output: number of output channels. + num_output: Number of output channels. Can be tuple when outputting + multiple dimensions. initializer: What initializer to use, should be one of {'linear', 'relu', 'zeros'} + num_input_dims: Number of dimensions from the end to project. use_bias: Whether to include trainable bias bias_init: Value used to initialize bias. - name: name of module, used for name scopes. + precision: What precision to use for matrix multiplication, defaults + to None. + name: Name of module, used for name scopes. 
""" - super().__init__(name=name) - self.num_output = num_output + if isinstance(num_output, numbers.Integral): + self.output_shape = (num_output,) + else: + self.output_shape = tuple(num_output) self.initializer = initializer self.use_bias = use_bias self.bias_init = bias_init + self.num_input_dims = num_input_dims + self.num_output_dims = len(self.output_shape) + self.precision = precision - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + def __call__(self, inputs): """Connects Module. - Args: - inputs: Tensor of shape [..., num_channel] - + inputs: Tensor with at least num_input_dims dimensions. Returns: - output of shape [..., num_output] + output of shape [...] + num_output. """ - n_channels = int(inputs.shape[-1]) - weight_shape = [n_channels, self.num_output] - if self.initializer == 'linear': - weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) - elif self.initializer == 'relu': - weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) 
- elif self.initializer == 'zeros': - weight_init = hk.initializers.Constant(0.0) + num_input_dims = self.num_input_dims + if self.num_input_dims > 0: + in_shape = inputs.shape[-self.num_input_dims:] + else: + in_shape = () + + weight_init = get_initializer_scale(self.initializer, in_shape) + + in_letters = 'abcde'[:self.num_input_dims] + out_letters = 'hijkl'[:self.num_output_dims] + + weight_shape = in_shape + self.output_shape weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weight_init) - # this is equivalent to einsum('...c,cd->...d', inputs, weights) - # but turns out to be slightly faster - inputs = jnp.swapaxes(inputs, -1, -2) - output = jnp.einsum('...cb,cd->...db', inputs, weights) - output = jnp.swapaxes(output, -1, -2) + equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}' + + output = jnp.einsum(equation, inputs, weights, precision=self.precision) if self.use_bias: - bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, + bias = hk.get_parameter('bias', self.output_shape, inputs.dtype, hk.initializers.Constant(self.bias_init)) output += bias - return output + return output \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/config.py b/colabdesign/af/alphafold/model/config.py index 03612cda..0076f803 100644 --- a/colabdesign/af/alphafold/model/config.py +++ b/colabdesign/af/alphafold/model/config.py @@ -23,40 +23,34 @@ NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES - def model_config(name: str) -> ml_collections.ConfigDict: """Get the ConfigDict of a CASP14 model.""" + if 'multimer' in name: + return CONFIG_MULTIMER + if name not in CONFIG_DIFFS: raise ValueError(f'Invalid model name {name}.') cfg = copy.deepcopy(CONFIG) cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) return cfg - CONFIG_DIFFS = { 'model_1': { # Jumper et al. (2021) Suppl. 
Table 5, Model 1.1.1 - 'data.common.max_extra_msa': 5120, - 'data.common.reduce_msa_clusters_by_max_templates': True, - 'data.common.use_templates': True, 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, 'model.embeddings_and_evoformer.template.enabled': True }, 'model_2': { # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2 - 'data.common.reduce_msa_clusters_by_max_templates': True, - 'data.common.use_templates': True, 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, 'model.embeddings_and_evoformer.template.enabled': True }, 'model_3': { # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1 - 'data.common.max_extra_msa': 5120, }, 'model_4': { # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2 - 'data.common.max_extra_msa': 5120, }, 'model_5': { # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3 @@ -66,26 +60,19 @@ def model_config(name: str) -> ml_collections.ConfigDict: # with an additional predicted_aligned_error head that can produce # predicted TM-score (pTM) and predicted aligned errors. 
'model_1_ptm': { - 'data.common.max_extra_msa': 5120, - 'data.common.reduce_msa_clusters_by_max_templates': True, - 'data.common.use_templates': True, 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, 'model.embeddings_and_evoformer.template.enabled': True, 'model.heads.predicted_aligned_error.weight': 0.1 }, 'model_2_ptm': { - 'data.common.reduce_msa_clusters_by_max_templates': True, - 'data.common.use_templates': True, 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, 'model.embeddings_and_evoformer.template.enabled': True, 'model.heads.predicted_aligned_error.weight': 0.1 }, 'model_3_ptm': { - 'data.common.max_extra_msa': 5120, 'model.heads.predicted_aligned_error.weight': 0.1 }, 'model_4_ptm': { - 'data.common.max_extra_msa': 5120, 'model.heads.predicted_aligned_error.weight': 0.1 }, 'model_5_ptm': { @@ -95,29 +82,6 @@ def model_config(name: str) -> ml_collections.ConfigDict: CONFIG = ml_collections.ConfigDict({ 'data': { - 'common': { - 'masked_msa': { - 'profile_prob': 0.1, - 'same_prob': 0.1, - 'uniform_prob': 0.1 - }, - 'max_extra_msa': 1024, - 'msa_cluster_features': True, - 'num_recycle': 3, - 'reduce_msa_clusters_by_max_templates': False, - 'resample_msa_in_recycling': True, - 'template_features': [ - 'template_all_atom_positions', 'template_sum_probs', - 'template_aatype', 'template_all_atom_masks', - 'template_domain_names' - ], - 'unsupervised_features': [ - 'aatype', 'residue_index', 'sequence', 'msa', 'domain_name', - 'num_alignments', 'seq_length', 'between_segment_residues', - 'deletion_matrix' - ], - 'use_templates': False, - }, 'eval': { 'feat': { 'aatype': [NUM_RES], @@ -161,7 +125,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'seq_mask': [NUM_RES], 'target_feat': [NUM_RES, None], 'template_aatype': [NUM_TEMPLATES, NUM_RES], - 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_mask': [NUM_TEMPLATES, NUM_RES, None], 'template_all_atom_positions': [ 
NUM_TEMPLATES, NUM_RES, None, None], 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], @@ -171,14 +135,11 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], 'template_sum_probs': [NUM_TEMPLATES, None], - 'true_msa': [NUM_MSA_SEQ, NUM_RES] + 'true_msa': [NUM_MSA_SEQ, NUM_RES], + 'asym_id': [NUM_RES], + 'sym_id': [NUM_RES], + 'entity_id': [NUM_RES] }, - 'fixed_size': True, - 'subsample_templates': False, # We want top templates. - 'masked_msa_replace_fraction': 0.15, - 'max_msa_clusters': 512, - 'max_templates': 4, - 'num_ensemble': 1, }, }, 'model': { @@ -206,6 +167,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'shared_dropout': True }, 'outer_product_mean': { + 'first': False, 'chunk_size': 128, 'dropout_rate': 0.0, 'num_outer_channel': 32, @@ -323,14 +285,13 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'shared_dropout': True } }, - 'max_templates': 4, 'subbatch_size': 128, 'use_template_unit_vector': False, } }, 'global_config': { - 'mixed_precision': False, 'deterministic': False, + 'multimer_mode': False, 'subbatch_size': 4, 'use_remat': False, 'zero_init': True @@ -406,9 +367,227 @@ def model_config(name: str) -> ml_collections.ConfigDict: }, }, 'num_recycle': 3, - 'backprop_recycle': False, - 'resample_msa_in_recycling': True, - 'add_prev': False, 'use_struct': True, }, }) + +CONFIG_MULTIMER = ml_collections.ConfigDict({ + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 
'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'first': True, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'num_extra_msa': 1, + 'masked_msa': { + 'profile_prob': 0.1, + 'replace_fraction': 0.15, + 'same_prob': 0.1, + 'uniform_prob': 0.1 + }, + 'use_chain_relative': True, + 'max_relative_chain': 2, + 'max_relative_idx': 32, + 'seq_channel': 384, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'max_bin': 20.75, + 'min_bin': 3.25, + 'num_bins': 15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'template': { + 'attention': { + 'gating': False, + 'num_head': 4 + }, + 'dgram_features': { + 'max_bin': 50.75, + 'min_bin': 3.25, + 'num_bins': 39 + }, + 'enabled': True, + 'num_channels': 64, + 'subbatch_size': 128, + 'template_pair_stack': { + 'num_block': 2, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': 
True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + } + } + }, + }, + 'global_config': { + 'deterministic': False, + 'multimer_mode': True, + 'subbatch_size': 4, + 'use_remat': False, + 'zero_init': True + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 'weight': 0.3 + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'masked_msa': { + 'weight': 2.0 + }, + 'predicted_aligned_error': { + 'filter_by_resolution': True, + 'max_error_bin': 31.0, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 64, + 'num_channels': 128, + 'weight': 0.1 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'structure_module': { + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'dropout': 0.1, + 'interface_fape': { + 'atom_clamp_distance': 1000.0, + 'loss_unit_distance': 20.0 + }, + 'intra_chain_fape': { + 'atom_clamp_distance': 10.0, + 'loss_unit_distance': 10.0 + }, + 'num_channel': 384, + 'num_head': 12, + 'num_layer': 8, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 20.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 
'loss_unit_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5 + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + } + }, + 'num_recycle': 3, + 'use_struct': True, + } +}) \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/folding.py b/colabdesign/af/alphafold/model/folding.py index 92c4d040..85a87b64 100644 --- a/colabdesign/af/alphafold/model/folding.py +++ b/colabdesign/af/alphafold/model/folding.py @@ -463,7 +463,7 @@ def fold_iter(act, key): class dummy(hk.Module): - def __init__(self, config, global_config, compute_loss=True): + def __init__(self, config, global_config): super().__init__(name="dummy") def __call__(self, representations, batch, is_training, safe_key=None): if safe_key is None: @@ -476,12 +476,11 @@ class StructureModule(hk.Module): Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" """ - def __init__(self, config, global_config, compute_loss=True, + def __init__(self, config, global_config, name='structure_module'): super().__init__(name=name) self.config = config self.global_config = global_config - self.compute_loss = compute_loss def __call__(self, representations, batch, is_training, safe_key=None): @@ -515,46 +514,6 @@ def __call__(self, representations, batch, is_training, return ret - def loss(self, value, batch): - ret = {'loss': 0.} - - ret['metrics'] = {} - # If requested, compute in-graph metrics. - if self.config.compute_in_graph_metrics: - atom14_pred_positions = value['final_atom14_positions'] - # Compute renaming and violations. 
- value.update(compute_renamed_ground_truth(batch, atom14_pred_positions)) - value['violations'] = find_structural_violations( - batch, atom14_pred_positions, self.config) - - # Several violation metrics: - violation_metrics = compute_violation_metrics( - batch=batch, - atom14_pred_positions=atom14_pred_positions, - violations=value['violations']) - ret['metrics'].update(violation_metrics) - - backbone_loss(ret, batch, value, self.config) - - if 'renamed_atom14_gt_positions' not in value: - value.update(compute_renamed_ground_truth( - batch, value['final_atom14_positions'])) - sc_loss = sidechain_loss(batch, value, self.config) - - ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] + - self.config.sidechain.weight_frac * sc_loss['loss']) - ret['sidechain_fape'] = sc_loss['fape'] - - supervised_chi_loss(ret, batch, value, self.config) - - if self.config.structural_violation_loss_weight: - if 'violations' not in value: - value['violations'] = find_structural_violations( - batch, value['final_atom14_positions'], self.config) - structural_violation_loss(ret, batch, value, self.config) - - return ret - def compute_renamed_ground_truth( batch: Dict[str, jnp.ndarray], @@ -613,7 +572,7 @@ def compute_renamed_ground_truth( } -def backbone_loss(ret, batch, value, config): +def backbone_loss(batch, value, config): """Backbone FAPE Loss. Jumper et al. (2021) Suppl. Alg. 
20 "StructureModule" line 17 @@ -672,9 +631,8 @@ def backbone_loss(ret, batch, value, config): backbone_mask) fape_loss = (fape_loss * use_clamped_fape + fape_loss_unclamped * (1 - use_clamped_fape)) - - ret['fape'] = fape_loss[-1] - ret['loss'] += jnp.mean(fape_loss) + + return jnp.mean(fape_loss), fape_loss[-1] def sidechain_loss(batch, value, config): diff --git a/colabdesign/af/alphafold/model/folding_multimer.py b/colabdesign/af/alphafold/model/folding_multimer.py new file mode 100644 index 00000000..33a87b9b --- /dev/null +++ b/colabdesign/af/alphafold/model/folding_multimer.py @@ -0,0 +1,1031 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Modules and utilities for the structure module in the multimer system.""" + +import functools +import numbers +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom_multimer +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import modules +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import utils +from colabdesign.af.alphafold.model.geometry import utils as geometry_utils +import haiku as hk +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np + + +EPSILON = 1e-8 +Float = Union[float, jnp.ndarray] + + +def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Computes Squared difference between two arrays.""" + return jnp.square(x - y) + + +def make_backbone_affine( + positions: geometry.Vec3Array, + mask: jnp.ndarray, + aatype: jnp.ndarray, + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Make backbone Rigid3Array and mask.""" + del aatype + a = residue_constants.atom_order['N'] + b = residue_constants.atom_order['CA'] + c = residue_constants.atom_order['C'] + + rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype( + jnp.float32) + + rigid = all_atom_multimer.make_transform_from_reference( + a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c]) + + return rigid, rigid_mask + + +class QuatRigid(hk.Module): + """Module for projecting Rigids via a quaternion.""" + + def __init__(self, + global_config: ml_collections.ConfigDict, + rigid_shape: Union[int, Iterable[int]] = tuple(), + full_quat: bool = False, + init: str = 'zeros', + name: str = 'quat_rigid'): + """Module projecting a Rigid Object. 
+ + For this Module the Rotation is parametrized as a quaternion, + If 'full_quat' is True a 4 vector is produced for the rotation which is + normalized and treated as a quaternion. + When 'full_quat' is False a 3 vector is produced and the 1st component of + the quaternion is set to 1. + + Args: + global_config: Global Config, used to set certain properties of underlying + Linear module, see common_modules.Linear for details. + rigid_shape: Shape of Rigids relative to shape of activations, e.g. when + activations have shape (n,) and this is (m,) output will be (n, m) + full_quat: Whether to parametrize rotation using full quaternion. + init: initializer to use, see common_modules.Linear for details + name: Name to use for module. + """ + self.init = init + self.global_config = global_config + if isinstance(rigid_shape, int): + self.rigid_shape = (rigid_shape,) + else: + self.rigid_shape = tuple(rigid_shape) + self.full_quat = full_quat + super(QuatRigid, self).__init__(name=name) + + def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array: + """Executes Module. + + This returns a set of rigid with the same shape as activations, projecting + the channel dimension, rigid_shape controls the trailing dimensions. + For example when activations is shape (12, 5) and rigid_shape is (3, 2) + then the shape of the output rigids will be (12, 3, 2). + This also supports passing in an empty tuple for rigid shape, in that case + the example would produce a rigid of shape (12,). + + Args: + activations: Activations to use for projection, shape [..., num_channel] + Returns: + Rigid transformations with shape [...] 
+ rigid_shape + """ + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + linear_dims = self.rigid_shape + (rigid_dim,) + rigid_flat = common_modules.Linear( + linear_dims, + initializer=self.init, + precision=jax.lax.Precision.HIGHEST, + name='rigid')( + activations) + rigid_flat = geometry_utils.unstack(rigid_flat) + if self.full_quat: + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = jnp.ones_like(qx) + translation = rigid_flat[3:] + rotation = geometry.Rot3Array.from_quaternion( + qw, qx, qy, qz, normalize=True) + translation = geometry.Vec3Array(*translation) + return geometry.Rigid3Array(rotation, translation) + + +class PointProjection(hk.Module): + """Given input reprensentation and frame produces points in global frame.""" + + def __init__(self, + num_points: Union[Iterable[int], int], + global_config: ml_collections.ConfigDict, + return_local_points: bool = False, + name: str = 'point_projection'): + """Constructs Linear Module. + + Args: + num_points: number of points to project. Can be tuple when outputting + multiple dimensions + global_config: Global Config, passed through to underlying Linear + return_local_points: Whether to return points in local frame as well. + name: name of module, used for name scopes. 
+ """ + if isinstance(num_points, numbers.Integral): + self.num_points = (num_points,) + else: + self.num_points = tuple(num_points) + + self.return_local_points = return_local_points + + self.global_config = global_config + + super().__init__(name=name) + + def __call__( + self, activations: jnp.ndarray, rigids: geometry.Rigid3Array + ) -> Union[geometry.Vec3Array, Tuple[geometry.Vec3Array, geometry.Vec3Array]]: + output_shape = self.num_points + output_shape = output_shape[:-1] + (3 * output_shape[-1],) + points_local = common_modules.Linear( + output_shape, + precision=jax.lax.Precision.HIGHEST, + name='point_projection')( + activations) + points_local = jnp.split(points_local, 3, axis=-1) + points_local = geometry.Vec3Array(*points_local) + rigids = rigids[(...,) + (None,) * len(output_shape)] + points_global = rigids.apply_to_point(points_local) + if self.return_local_points: + return points_global, points_local + else: + return points_global + + +class InvariantPointAttention(hk.Module): + """Invariant point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + dist_epsilon: float = 1e-8, + name: str = 'invariant_point_attention'): + """Initialize. + + Args: + config: iterative Fold Head Config + global_config: Global Config of Model. + dist_epsilon: Small value to avoid NaN in distance calculation. + name: Sonnet name. 
+ """ + super().__init__(name=name) + + self._dist_epsilon = dist_epsilon + self._zero_initialize_last = global_config.zero_init + + self.config = config + + self.global_config = global_config + + def __call__( + self, + inputs_1d: jnp.ndarray, + inputs_2d: jnp.ndarray, + mask: jnp.ndarray, + rigid: geometry.Rigid3Array, + ) -> jnp.ndarray: + """Compute geometric aware attention. + + Given a set of query residues (defined by affines and associated scalar + features), this function computes geometric aware attention between the + query residues and target residues. + + The residues produce points in their local reference frame, which + are converted into the global frame to get attention via euclidean distance. + + Equivalently the target residues produce points in their local frame to be + used as attention values, which are converted into the query residues local + frames. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases values in the + attention between query_inputs_1d and target_inputs_1d. + mask: (N, 1) mask to indicate query_inputs_1d that participate in + the attention. + rigid: Rigid object describing the position and orientation of + every element in query_inputs_1d. + + Returns: + Transformation of the input embedding. + """ + + num_head = self.config.num_head + + attn_logits = 0. + + num_point_qk = self.config.num_point_qk + # Each point pair (q, k) contributes Var [0.5 ||q||^2 - ] = 9 / 2 + point_variance = max(num_point_qk, 1) * 9. / 2 + point_weights = np.sqrt(1.0 / point_variance) + + # This is equivalent to jax.nn.softplus, but avoids a bug in the test... + softplus = lambda x: jnp.logaddexp(x, jnp.zeros_like(x)) + raw_point_weights = hk.get_parameter( + 'trainable_point_weights', + shape=[num_head], + # softplus^{-1} (1) + init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))) + + # Trainable per-head weights for points. 
+ trainable_point_weights = softplus(raw_point_weights) + point_weights *= trainable_point_weights + q_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='q_point_projection')(inputs_1d, + rigid) + + k_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='k_point_projection')(inputs_1d, + rigid) + + dist2 = geometry.square_euclidean_distance( + q_point[:, None, :, :], k_point[None, :, :, :], epsilon=0.) + attn_qk_point = -0.5 * jnp.sum(point_weights[:, None] * dist2, axis=-1) + attn_logits += attn_qk_point + + num_scalar_qk = self.config.num_scalar_qk + # We assume that all queries and keys come iid from N(0, 1) distribution + # and compute the variances of the attention logits. + # Each scalar pair (q, k) contributes Var q*k = 1 + scalar_variance = max(num_scalar_qk, 1) * 1. + scalar_weights = np.sqrt(1.0 / scalar_variance) + q_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='q_scalar_projection')( + inputs_1d) + + k_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='k_scalar_projection')( + inputs_1d) + q_scalar *= scalar_weights + attn_logits += jnp.einsum('qhc,khc->qkh', q_scalar, k_scalar) + + attention_2d = common_modules.Linear( + num_head, name='attention_2d')(inputs_2d) + attn_logits += attention_2d + + mask_2d = mask * jnp.swapaxes(mask, -1, -2) + attn_logits -= 1e5 * (1. - mask_2d[..., None]) + + attn_logits *= np.sqrt(1. 
/ 3) # Normalize by number of logit terms (3) + attn = jax.nn.softmax(attn_logits, axis=-2) + + num_scalar_v = self.config.num_scalar_v + + v_scalar = common_modules.Linear([num_head, num_scalar_v], + use_bias=False, + name='v_scalar_projection')( + inputs_1d) + + # [num_query_residues, num_head, num_scalar_v] + result_scalar = jnp.einsum('qkh, khc->qhc', attn, v_scalar) + + num_point_v = self.config.num_point_v + v_point = PointProjection([num_head, num_point_v], + self.global_config, + name='v_point_projection')(inputs_1d, + rigid) + + result_point_global = jax.tree_map( + lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None]) + + # Features used in the linear output projection. Should have the size + # [num_query_residues, ?] + output_features = [] + num_query_residues, _ = inputs_1d.shape + + flat_shape = [num_query_residues, -1] + + result_scalar = jnp.reshape(result_scalar, flat_shape) + output_features.append(result_scalar) + + result_point_global = jax.tree_map(lambda r: jnp.reshape(r, flat_shape), + result_point_global) + result_point_local = rigid[..., None].apply_inverse_to_point( + result_point_global) + output_features.extend( + [result_point_local.x, result_point_local.y, result_point_local.z]) + + point_norms = result_point_local.norm(self._dist_epsilon) + output_features.append(point_norms) + + # Dimensions: h = heads, i and j = residues, + # c = inputs_2d channels + # Contraction happens over the second residue dimension, similarly to how + # the usual attention is performed. 
+ result_attention_over_2d = jnp.einsum('ijh, ijc->ihc', attn, inputs_2d) + output_features.append(jnp.reshape(result_attention_over_2d, flat_shape)) + + final_init = 'zeros' if self._zero_initialize_last else 'linear' + + final_act = jnp.concatenate(output_features, axis=-1) + + return common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='output_projection')(final_act) + + +class FoldIteration(hk.Module): + """A single iteration of iterative folding. + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'fold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__( + self, + activations: Mapping[str, Any], + aatype: jnp.ndarray, + sequence_mask: jnp.ndarray, + update_rigid: bool, + is_training: bool, + initial_act: jnp.ndarray, + safe_key: Optional[prng.SafeKey] = None, + static_feat_2d: Optional[jnp.ndarray] = None, + dropout_scale=1.0, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + c = self.config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + def safe_dropout_fn(tensor, safe_key): + return modules.apply_dropout( + tensor=tensor, + safe_key=safe_key, + rate=0.0 if self.global_config.deterministic else (c.dropout * dropout_scale), + is_training=is_training) + + rigid = activations['rigid'] + + act = activations['act'] + attention_module = InvariantPointAttention( + self.config, self.global_config) + # Attention + act += attention_module( + inputs_1d=act, + inputs_2d=static_feat_2d, + mask=sequence_mask, + rigid=rigid) + + safe_key, *sub_keys = safe_key.split(3) + sub_keys = iter(sub_keys) + act = safe_dropout_fn(act, 
next(sub_keys)) + act = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='attention_layer_norm')( + act) + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Transition + input_act = act + for i in range(c.num_layer_in_transition): + init = 'relu' if i < c.num_layer_in_transition - 1 else final_init + act = common_modules.Linear( + c.num_channel, + initializer=init, + name='transition')( + act) + if i < c.num_layer_in_transition - 1: + act = jax.nn.relu(act) + act += input_act + act = safe_dropout_fn(act, next(sub_keys)) + act = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='transition_layer_norm')(act) + if update_rigid: + # Rigid update + rigid_update = QuatRigid( + self.global_config, init=final_init)( + act) + rigid = rigid @ rigid_update + + sc = MultiRigidSidechain(c.sidechain, self.global_config)( + rigid.scale_translation(c.position_scale), [act, initial_act], aatype) + + outputs = {'rigid': rigid, 'sc': sc} + + rotation = rigid.rotation #jax.tree_map(jax.lax.stop_gradient, rigid.rotation) + rigid = geometry.Rigid3Array(rotation, rigid.translation) + + new_activations = { + 'act': act, + 'rigid': rigid + } + return new_activations, outputs + + +def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, jnp.ndarray], + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + is_training: bool, + safe_key: prng.SafeKey + ) -> Dict[str, Any]: + """Generate predicted Rigid's for a single chain. + + This is the main part of the iterative fold head - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Embeddings dictionary. + batch: Batch dictionary. + config: config for the iterative fold head. + global_config: global config. + is_training: is training. + safe_key: A prng.SafeKey object that wraps a PRNG key. 
+ + Returns: + A dictionary containing residue Rigid's and sidechain positions. + """ + c = config + sequence_mask = batch['seq_mask'][:, None] + act = hk.LayerNorm( + axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')( + representations['single']) + + initial_act = act + act = common_modules.Linear( + c.num_channel, name='initial_projection')(act) + + # Sequence Mask has extra 1 at the end. + rigid = geometry.Rigid3Array.identity(sequence_mask.shape[:-1]) + + fold_iteration = FoldIteration( + c, global_config, name='fold_iteration') + + assert len(batch['seq_mask'].shape) == 1 + + activations = { + 'act': + act, + 'rigid': + rigid + } + + act_2d = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='pair_layer_norm')( + representations['pair']) + + outputs = [] + def fold_iter(act, key): + act, out = fold_iteration( + act, + initial_act=initial_act, + static_feat_2d=act_2d, + aatype=batch['aatype'], + safe_key=prng.SafeKey(key), + sequence_mask=sequence_mask, + update_rigid=True, + is_training=is_training, + dropout_scale=batch["dropout_scale"]) + return act, out + + keys = jax.random.split(safe_key.get(), c.num_layer) + activations, output = hk.scan(fold_iter, activations, keys) + output['act'] = activations['act'] + + return output + + +class StructureModule(hk.Module): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 
20 "StructureModule" + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'structure_module'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, Any], + is_training: bool, + safe_key: Optional[prng.SafeKey] = None, + ) -> Dict[str, Any]: + c = self.config + ret = {} + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = generate_monomer_rigids( + representations=representations, + batch=batch, + config=self.config, + global_config=self.global_config, + is_training=is_training, + safe_key=safe_key) + + ret['traj'] = output['rigid'].scale_translation(c.position_scale).to_array() + ret['sidechains'] = output['sc'] + ret['sidechains']['atom_pos'] = ret['sidechains']['atom_pos'].to_array() + ret['sidechains']['frames'] = ret['sidechains']['frames'].to_array() + if 'local_atom_pos' in ret['sidechains']: + ret['sidechains']['local_atom_pos'] = ret['sidechains'][ + 'local_atom_pos'].to_array() + ret['sidechains']['local_frames'] = ret['sidechains'][ + 'local_frames'].to_array() + + aatype = batch['aatype'] + seq_mask = batch['seq_mask'] + + atom14_pred_mask = all_atom_multimer.get_atom14_mask( + aatype) * seq_mask[:, None] + atom14_pred_positions = output['sc']['atom_pos'][-1] + ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3) + ret['final_atom14_mask'] = atom14_pred_mask # (N, 14) + + atom37_mask = all_atom_multimer.get_atom37_mask(aatype) * seq_mask[:, None] + atom37_pred_positions = all_atom_multimer.atom14_to_atom37( + atom14_pred_positions, aatype) + atom37_pred_positions *= atom37_mask[:, :, None] + ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + ret['final_atom_mask'] = atom37_mask # (N, 37) + ret['final_rigids'] = ret['traj'][-1] + + ret['act'] = output['act'] + + return ret + + +def compute_atom14_gt( 
+ aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + pred_pos: geometry.Vec3Array +) -> Tuple[geometry.Vec3Array, jnp.ndarray, jnp.ndarray]: + """Find atom14 positions, this includes finding the correct renaming.""" + gt_positions, gt_mask = all_atom_multimer.atom37_to_atom14( + aatype, all_atom_positions, + all_atom_mask) + alt_gt_positions, alt_gt_mask = all_atom_multimer.get_alt_atom14( + aatype, gt_positions, gt_mask) + atom_is_ambiguous = all_atom_multimer.get_atom14_is_ambiguous(aatype) + + alt_naming_is_better = all_atom_multimer.find_optimal_renaming( + gt_positions=gt_positions, + alt_gt_positions=alt_gt_positions, + atom_is_ambiguous=atom_is_ambiguous, + gt_exists=gt_mask, + pred_positions=pred_pos) + + use_alt = alt_naming_is_better[:, None] + + gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask + gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions + + return gt_positions, alt_gt_mask, alt_naming_is_better + + +def backbone_loss(gt_rigid: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions_mask: jnp.ndarray, + target_rigid: geometry.Rigid3Array, + config: ml_collections.ConfigDict, + pair_mask: jnp.ndarray + ) -> Tuple[Float, jnp.ndarray]: + """Backbone FAPE Loss.""" + loss_fn = functools.partial( + all_atom_multimer.frame_aligned_point_error, + l1_clamp_distance=config.atom_clamp_distance, + length_scale=config.loss_unit_distance) + + loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None)) + fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask, + target_rigid.translation, gt_rigid.translation, + gt_positions_mask, pair_mask) + + return jnp.mean(fape), fape[-1] + + +def compute_frames( + aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + use_alt: jnp.ndarray + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Compute Frames from all atom positions. 
+ + Args: + aatype: array of aatypes, int of [N] + all_atom_positions: Vector of all atom positions, shape [N, 37] + all_atom_mask: mask, shape [N] + use_alt: whether to use alternative orientation for ambiguous aatypes + shape [N] + Returns: + Rigid corresponding to Frames w shape [N, 8], + mask which Rigids are present w shape [N, 8] + """ + frames_batch = all_atom_multimer.atom37_to_frames(aatype, all_atom_positions, + all_atom_mask) + gt_frames = frames_batch['rigidgroups_gt_frames'] + alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames'] + use_alt = use_alt[:, None] + + renamed_gt_frames = jax.tree_map( + lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames) + + return renamed_gt_frames, frames_batch['rigidgroups_gt_exists'] + + +def sidechain_loss(gt_frames: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions: geometry.Vec3Array, + gt_mask: jnp.ndarray, + pred_frames: geometry.Rigid3Array, + pred_positions: geometry.Vec3Array, + config: ml_collections.ConfigDict + ) -> Dict[str, jnp.ndarray]: + """Sidechain Loss using cleaned up rigids.""" + + flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames) + flat_frames_mask = jnp.ravel(gt_frames_mask) + + flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions) + flat_positions_mask = jnp.ravel(gt_mask) + + # Compute frame_aligned_point_error score for the final layer. 
+ def _slice_last_layer_and_flatten(x): + return jnp.ravel(x[-1]) + + flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, + pred_positions) + fape = all_atom_multimer.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + pair_mask=None, + length_scale=config.sidechain.loss_unit_distance, + l1_clamp_distance=config.sidechain.atom_clamp_distance) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(mask: jnp.ndarray, + violations: Mapping[str, Float], + config: ml_collections.ConfigDict + ) -> Float: + """Computes Loss for structural Violations.""" + # Put all violation losses together to one large loss. + num_atoms = jnp.sum(mask).astype(jnp.float32) + 1e-6 + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + return (config.structural_violation_loss_weight * + (between_residues['bonds_c_n_loss_mean'] + + between_residues['angles_ca_c_n_loss_mean'] + + between_residues['angles_c_n_ca_loss_mean'] + + jnp.sum(between_residues['clashes_per_atom_loss_sum'] + + within_residues['per_atom_loss_sum']) / num_atoms + )) + + +def find_structural_violations( + aatype: jnp.ndarray, + residue_index: jnp.ndarray, + mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + config: ml_collections.ConfigDict, + asym_id: jnp.ndarray, + ) -> Dict[str, Any]: + """Computes several checks for structural Violations.""" + + # Compute between residue backbone violations of bonds and angles. 
+ connection_violations = all_atom_multimer.between_residue_bond_loss( + pred_atom_positions=pred_positions, + pred_atom_mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32), + aatype=aatype, + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # shape (N, 14) + atomtype_radius = jnp.array([ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ]) + residx_atom14_to_atom37 = all_atom_multimer.get_atom14_to_atom37_map(aatype) + atom_radius = mask * utils.batched_gather(atomtype_radius, + residx_atom14_to_atom37) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom_multimer.between_residue_clash_loss( + pred_positions=pred_positions, + atom_exists=mask, + atom_radius=atom_radius, + residue_index=residue_index, + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance, + asym_id=asym_id) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). + restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + dists_lower_bound = utils.batched_gather(restype_atom14_bounds['lower_bound'], + aatype) + dists_upper_bound = utils.batched_gather(restype_atom14_bounds['upper_bound'], + aatype) + within_residue_violations = all_atom_multimer.within_residue_violations( + pred_positions=pred_positions, + atom_exists=mask, + dists_lower_bound=dists_lower_bound, + dists_upper_bound=dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). 
+ per_residue_violations_mask = jnp.max(jnp.stack([ + connection_violations['per_residue_violation_mask'], + jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + jnp.max(within_residue_violations['per_atom_violations'], + axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # () + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # () + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # () + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # () + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (N) + } + + +def compute_violation_metrics( + residue_index: jnp.ndarray, + mask: jnp.ndarray, + seq_mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + violations: Mapping[str, jnp.ndarray], +) -> Dict[str, jnp.ndarray]: + """Compute several metrics to assess the structural violations.""" + ret = {} + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + extreme_ca_ca_violations = all_atom_multimer.extreme_ca_ca_distance_violations( + positions=pred_positions, + mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32)) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + ret['violations_between_residue_bond'] = utils.mask_mean( + 
mask=seq_mask, + value=between_residues['connections_per_residue_violation_mask']) + ret['violations_between_residue_clash'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(between_residues['clashes_per_atom_clash_mask'], axis=-1)) + ret['violations_within_residue'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(within_residues['per_atom_violations'], axis=-1)) + ret['violations_per_residue'] = utils.mask_mean( + mask=seq_mask, value=violations['total_per_residue_violations_mask']) + return ret + + +def supervised_chi_loss( + sequence_mask: jnp.ndarray, + target_chi_mask: jnp.ndarray, + aatype: jnp.ndarray, + target_chi_angles: jnp.ndarray, + pred_angles: jnp.ndarray, + unnormed_angles: jnp.ndarray, + config: ml_collections.ConfigDict) -> Tuple[Float, Float, Float]: + """Computes loss for direct chi angle supervision.""" + eps = 1e-6 + chi_mask = target_chi_mask.astype(jnp.float32) + + pred_angles = pred_angles[:, :, 3:] + + residue_type_one_hot = jax.nn.one_hot( + aatype, residue_constants.restype_num + 1, dtype=jnp.float32)[None] + chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot, + jnp.asarray(residue_constants.chi_pi_periodic)) + + true_chi = target_chi_angles[None] + sin_true_chi = jnp.sin(true_chi) + cos_true_chi = jnp.cos(true_chi) + sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1) + + # This is -1 if chi is pi periodic and +1 if it's 2 pi periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = jnp.sum( + squared_difference(sin_cos_true_chi, pred_angles), -1) + sq_chi_error_shifted = jnp.sum( + squared_difference(sin_cos_true_chi_shifted, pred_angles), -1) + sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error) + angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps) + norm_error = jnp.abs(angle_norm - 1.) 
+ angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None], + value=norm_error) + loss = (config.chi_weight * sq_chi_loss + + config.angle_norm_weight * angle_norm_loss) + return loss, sq_chi_loss, angle_norm_loss + + +def l2_normalize(x: jnp.ndarray, + axis: int = -1, + epsilon: float = 1e-12 + ) -> jnp.ndarray: + return x / jnp.sqrt( + jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon)) + + +def get_renamed_chi_angles(aatype: jnp.ndarray, + chi_angles: jnp.ndarray, + alt_is_better: jnp.ndarray + ) -> jnp.ndarray: + """Return renamed chi angles.""" + chi_angle_is_ambiguous = utils.batched_gather( + jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype) + alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous + # Map back to [-pi, pi]. + alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype( + jnp.float32) + alt_is_better = alt_is_better[:, None] + return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles + + +class MultiRigidSidechain(hk.Module): + """Class to make side chain atoms.""" + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'rigid_sidechain'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + rigid: geometry.Rigid3Array, + representations_list: Iterable[jnp.ndarray], + aatype: jnp.ndarray + ) -> Dict[str, Any]: + """Predict sidechains using multi-rigid representations. + + Args: + rigid: The Rigid's for each residue (translations in angstoms) + representations_list: A list of activations to predict sidechains from. + aatype: amino acid types. 
+ + Returns: + dict containing atom positions and frames (in angstrom) + """ + act = [ + common_modules.Linear( # pylint: disable=g-complex-comprehension + self.config.num_channel, + name='input_projection')(jax.nn.relu(x)) + for x in representations_list] + # Sum the activation list (equivalent to concat then Conv1D) + act = sum(act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Mapping with some residual blocks. + for _ in range(self.config.num_residual_block): + old_act = act + act = common_modules.Linear( + self.config.num_channel, + initializer='relu', + name='resblock1')( + jax.nn.relu(act)) + act = common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='resblock2')( + jax.nn.relu(act)) + act += old_act + + # Map activations to torsion angles. + # [batch_size, num_res, 14] + num_res = act.shape[0] + unnormalized_angles = common_modules.Linear( + 14, name='unnormalized_angles')( + jax.nn.relu(act)) + unnormalized_angles = jnp.reshape( + unnormalized_angles, [num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # jnp.ndarray (N, 7, 2) + } + + # Map torsion angles to frames. + # geometry.Rigid3Array with shape (N, 8) + all_frames_to_global = all_atom_multimer.torsion_angles_to_frames( + aatype, + rigid, + angles) + + # Use frames and literature positions to create the final atom coordinates. 
+ # geometry.Vec3Array with shape (N, 14) + pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + outputs.update({ + 'atom_pos': pred_positions, # geometry.Vec3Array (N, 14) + 'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8) + }) + return outputs diff --git a/colabdesign/af/alphafold/model/geometry/__init__.py b/colabdesign/af/alphafold/model/geometry/__init__.py new file mode 100644 index 00000000..761b886e --- /dev/null +++ b/colabdesign/af/alphafold/model/geometry/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Geometry Module.""" + +from colabdesign.af.alphafold.model.geometry import rigid_matrix_vector +from colabdesign.af.alphafold.model.geometry import rotation_matrix +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +StructOfArray = struct_of_array.StructOfArray + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py b/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py new file mode 100644 index 00000000..4c7bb105 --- /dev/null +++ b/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py @@ -0,0 +1,106 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations +from typing import Union + +from colabdesign.af.alphafold.model.geometry import rotation_matrix +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import vector +import jax +import jax.numpy as jnp + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rigid3Array: + """Rigid Transformation, i.e. element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), + self.translation) + return Rigid3Array(rot, trans) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls( + rotation_matrix.Rot3Array.identity(shape, dtype=dtype), + vector.Vec3Array.zeros(shape, dtype=dtype)) 
# pytype: disable=wrong-arg-count # trace-all-classes + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_array(self): + rot_array = self.rotation.to_array() + vec_array = self.translation.to_array() + return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) + vec = vector.Vec3Array.from_array(array[..., -1]) + return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes + + @classmethod + def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: + """Construct Rigid3Array from homogeneous 4x4 array.""" + assert array.shape[-1] == 4 + assert array.shape[-2] == 4 + rotation = rotation_matrix.Rot3Array( + array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], + array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], + array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) + return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes + + def __getstate__(self): + return (VERSION, (self.rotation, self.translation)) + + def __setstate__(self, state): + version, (rot, trans) = state + del version + object.__setattr__(self, 'rotation', rot) + object.__setattr__(self, 'translation', trans) diff --git a/colabdesign/af/alphafold/model/geometry/rotation_matrix.py b/colabdesign/af/alphafold/model/geometry/rotation_matrix.py new file mode 100644 index 00000000..846ea5d2 --- /dev/null +++ b/colabdesign/af/alphafold/model/geometry/rotation_matrix.py @@ -0,0 +1,157 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']

VERSION = '0.1'


@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
  """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.

  Each of the nine fields holds one matrix entry; the first letter names the
  row and the second the column (see `to_array`).
  """

  xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  xy: jnp.ndarray
  xz: jnp.ndarray
  yx: jnp.ndarray
  yy: jnp.ndarray
  yz: jnp.ndarray
  zx: jnp.ndarray
  zy: jnp.ndarray
  zz: jnp.ndarray

  __array_ufunc__ = None

  def inverse(self) -> Rot3Array:
    """Returns inverse of Rot3Array (the transpose of the stored entries)."""
    return Rot3Array(self.xx, self.yx, self.zx,
                     self.xy, self.yy, self.zy,
                     self.xz, self.yz, self.zz)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies Rot3Array to point."""
    return vector.Vec3Array(
        self.xx * point.x + self.xy * point.y + self.xz * point.z,
        self.yx * point.x + self.yy * point.y + self.yz * point.z,
        self.zx * point.x + self.zy * point.y + self.zz * point.z)

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies inverse Rot3Array to point."""
    return self.inverse().apply_to_point(point)

  def __matmul__(self, other: Rot3Array) -> Rot3Array:
    """Composes two Rot3Arrays."""
    # Rotate each column of `other`, then reassemble the columns row by row.
    c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
    c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
    c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
    return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

  @classmethod
  def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
    """Returns identity of given shape."""
    ones = jnp.ones(shape, dtype=dtype)
    zeros = jnp.zeros(shape, dtype=dtype)
    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_two_vectors(cls, e0: vector.Vec3Array,
                       e1: vector.Vec3Array) -> Rot3Array:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
      e0: Vector
      e1: Vector
    Returns:
      Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0.
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_array(cls, array: jnp.ndarray) -> Rot3Array:
    """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
    rows = utils.unstack(array, axis=-2)
    # Flatten row-major so the order matches the field order xx, xy, ..., zz.
    components = [entry for row in rows
                  for entry in utils.unstack(row, axis=-1)]
    return cls(*components)

  def to_array(self) -> jnp.ndarray:
    """Convert Rot3Array to array of shape [..., 3, 3]."""
    return jnp.stack(
        [jnp.stack([self.xx, self.xy, self.xz], axis=-1),
         jnp.stack([self.yx, self.yy, self.yz], axis=-1),
         jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
        axis=-2)

  @classmethod
  def from_quaternion(cls,
                      w: jnp.ndarray,
                      x: jnp.ndarray,
                      y: jnp.ndarray,
                      z: jnp.ndarray,
                      normalize: bool = True,
                      epsilon: float = 1e-6) -> Rot3Array:
    """Construct Rot3Array from components of quaternion.

    Args:
      w: Scalar component of the quaternion.
      x: First vector component.
      y: Second vector component.
      z: Third vector component.
      normalize: If True, rescale the quaternion to unit length first.
      epsilon: Lower clamp on the squared norm used for the rescaling.

    Returns:
      Rot3Array built from the (optionally normalized) quaternion.
    """
    if normalize:
      inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
      # Rebind instead of using `*=`: an augmented assignment would modify a
      # caller-owned mutable buffer (e.g. a NumPy array) in place.
      w = w * inv_norm
      x = x * inv_norm
      y = y * inv_norm
      z = z * inv_norm
    xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
    xy = 2 * (x * y - w * z)
    xz = 2 * (x * z + w * y)
    yx = 2 * (x * y + w * z)
    yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
    yz = 2 * (y * z - w * x)
    zx = 2 * (x * z - w * y)
    zy = 2 * (y * z + w * x)
    zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
    """Samples uniform random Rot3Array according to Haar Measure."""
    # Draw four normal components per element and normalize them as a
    # quaternion (normalization happens inside `from_quaternion`).
    quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
    quats = utils.unstack(quat_array)
    return cls.from_quaternion(*quats)

  def __getstate__(self):
    """Return a versioned pickle state with the components as NumPy arrays."""
    return (VERSION,
            [np.asarray(getattr(self, field)) for field in COMPONENTS])

  def __setstate__(self, state):
    """Restore the nine components from a state made by `__getstate__`."""
    version, state = state
    del version
    for i, field in enumerate(COMPONENTS):
      # `object.__setattr__` bypasses the normal attribute-assignment path.
      object.__setattr__(self, field, state[i])
def get_item(instance, key):
  """Index every array field of `instance` with `key`.

  Fields may declare `num_trailing_dims` in their dataclass metadata. When
  `key` is a tuple containing an Ellipsis, that many full slices are appended
  so the Ellipsis does not swallow the field's trailing dimensions.
  """
  # Identity test rather than `Ellipsis in key`: membership uses `==`, which
  # raises for array-valued index entries (ambiguous truth value).
  has_ellipsis = isinstance(key, tuple) and any(k is Ellipsis for k in key)
  sliced = {}
  for field in get_array_fields(instance):
    num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
    this_key = key
    if has_ellipsis:
      this_key += (slice(None),) * num_trailing_dims
    sliced[field.name] = getattr(instance, field.name)[this_key]
  return dataclasses.replace(instance, **sliced)


@property
def get_shape(instance):
  """Returns Shape for given instance of dataclass.

  The shape is taken from the first dataclass field, with that field's
  `num_trailing_dims` (if any) stripped from the end.
  """
  first_field = dataclasses.fields(instance)[0]
  num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
  value = getattr(instance, first_field.name)
  if num_trailing_dims:
    return value.shape[:-num_trailing_dims]
  else:
    return value.shape


def get_len(instance):
  """Returns length for given instance of dataclass."""
  shape = instance.shape
  if shape:
    return shape[0]
  else:
    raise TypeError('len() of unsized object')  # Match jax.numpy behavior.


@property
def get_dtype(instance):
  """Returns Dtype for given instance of dataclass.

  The dtype comes from the single field whose metadata has `sets_dtype`, or,
  failing that, from the first field when the class has `same_dtype` set.

  Raises:
    AttributeError: if neither source is available or the chosen field value
      has no `dtype` attribute.
  """
  fields = dataclasses.fields(instance)
  sets_dtype = [
      field.name for field in fields if field.metadata.get('sets_dtype', False)
  ]
  if sets_dtype:
    assert len(sets_dtype) == 1, 'at most one field can set dtype'
    field_value = getattr(instance, sets_dtype[0])
  elif instance.same_dtype:
    field_value = getattr(instance, fields[0].name)
  else:
    # Should this be Value Error?
    raise AttributeError('Trying to access Dtype on Struct of Array without '
                         'either "same_dtype" or field setting dtype')

  if hasattr(field_value, 'dtype'):
    return field_value.dtype
  else:
    # Should this be Value Error?
    raise AttributeError(f'field_value {field_value} does not have dtype')


def replace(instance, **kwargs):
  """Return a copy of `instance` with the given fields replaced."""
  return dataclasses.replace(instance, **kwargs)


def post_init(instance):
  """Validate instance has same shapes & dtypes."""
  array_fields = get_array_fields(instance)
  arrays = list(get_array_fields(instance, return_values=True).values())
  first_field = array_fields[0]
  # These slightly weird constructions about checking whether the leaves are
  # actual arrays is since e.g. vmap internally relies on being able to
  # construct pytree's with object() as leaves, this would break the checking
  # as such we are only validating the object when the entries in the dataclass
  # Are arrays or other dataclasses of arrays.
  try:
    dtype = instance.dtype
  except AttributeError:
    dtype = None
  if dtype is None:
    return

  first_shape = instance.shape
  for array, field in zip(arrays, array_fields):
    field_shape = array.shape
    num_trailing_dims = field.metadata.get('num_trailing_dims', None)
    if num_trailing_dims:
      array_shape = array.shape
      field_shape = array_shape[:-num_trailing_dims]
      # Both halves must be f-strings, otherwise the placeholder is printed
      # literally instead of the number.
      msg = (f'field {field} should have number of trailing dims'
             f' {num_trailing_dims}')
      assert len(array_shape) == len(first_shape) + num_trailing_dims, msg

    shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
                 f"match shape {first_shape} of field {first_field}")
    assert field_shape == first_shape, shape_msg

    field_dtype = array.dtype

    allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
    if allowed_metadata_dtypes:
      msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
      assert field_dtype in allowed_metadata_dtypes, msg

    # A per-field 'dtype' in the metadata overrides the instance-wide dtype.
    if 'dtype' in field.metadata:
      target_dtype = field.metadata['dtype']
    else:
      target_dtype = dtype

    msg = f'Dtype is {field_dtype} but must be {target_dtype}'
    assert field_dtype == target_dtype, msg


def flatten(instance):
  """Flatten Struct of Array instance.

  Returns:
    A pair `(leaves, aux)`: `leaves` is the flat list of leaves of all array
    fields; `aux` is `(inner_treedefs, metadata, num_arrays)`, where
    `num_arrays` records how many leaves each array field contributed.
  """
  array_likes = list(get_array_fields(instance, return_values=True).values())
  flat_array_likes = []
  inner_treedefs = []
  num_arrays = []
  for array_like in array_likes:
    # `jax.tree_util.tree_flatten` is the stable spelling; the top-level
    # `jax.tree_flatten` alias has been removed from recent JAX releases.
    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
    inner_treedefs.append(inner_treedef)
    flat_array_likes += flat_array_like
    num_arrays.append(len(flat_array_like))
  metadata = get_metadata_fields(instance, return_values=True)
  metadata = type(instance).metadata_cls(**metadata)
  # Tuples (not lists) keep the auxiliary data hashable.
  return flat_array_likes, (tuple(inner_treedefs), metadata, tuple(num_arrays))


def make_metadata_class(cls):
  """Build a frozen dataclass holding only the `is_metadata` fields of `cls`."""
  metadata_fields = get_fields(cls,
                               lambda x: x.metadata.get('is_metadata', False))
  metadata_cls = dataclasses.make_dataclass(
      cls_name='Meta' + cls.__name__,
      fields=[(field.name, field.type, field) for field in metadata_fields],
      frozen=True,
      eq=True)
  return metadata_cls


def get_fields(cls_or_instance, filterfn, return_values=False):
  """Return the dataclass fields accepted by `filterfn`.

  With `return_values=True` a `{field name: attribute value}` dict is returned
  instead of the list of `dataclasses.Field` objects.
  """
  fields = dataclasses.fields(cls_or_instance)
  fields = [field for field in fields if filterfn(field)]
  if return_values:
    return {
        field.name: getattr(cls_or_instance, field.name) for field in fields
    }
  else:
    return fields


def get_array_fields(cls, return_values=False):
  """Return the fields (or values) not marked `is_metadata`."""
  return get_fields(
      cls,
      lambda x: not x.metadata.get('is_metadata', False),
      return_values=return_values)


def get_metadata_fields(cls, return_values=False):
  """Return the fields (or values) marked `is_metadata`."""
  return get_fields(
      cls,
      lambda x: x.metadata.get('is_metadata', False),
      return_values=return_values)


class StructOfArray:
  """Class Decorator for Struct Of Arrays.

  Turns the decorated class into a frozen dataclass, attaches `shape`,
  `dtype`, `__len__`, `__getitem__`, `replace` and a validating
  `__post_init__`, and registers the result as a JAX pytree node.
  """

  def __init__(self, same_dtype=True):
    self.same_dtype = same_dtype

  def __call__(self, cls):
    cls.__array_ufunc__ = None
    cls.replace = replace
    cls.same_dtype = self.same_dtype
    cls.dtype = get_dtype
    cls.shape = get_shape
    cls.__len__ = get_len
    cls.__getitem__ = get_item
    # Note: this replaces any `__post_init__` the class defined itself.
    cls.__post_init__ = post_init
    new_cls = dataclasses.dataclass(cls, frozen=True, eq=False)  # pytype: disable=wrong-keyword-args
    # pytree claims to require metadata to be hashable, not sure why,
    # But making derived dataclass that can just hold metadata
    new_cls.metadata_cls = make_metadata_class(new_cls)

    def unflatten(aux, data):
      """Rebuild an instance from `flatten`'s auxiliary data and leaves."""
      inner_treedefs, metadata, num_arrays = aux
      array_fields = [field.name for field in get_array_fields(new_cls)]
      value_dict = {}
      array_start = 0
      for num_array, inner_treedef, array_field in zip(num_arrays,
                                                       inner_treedefs,
                                                       array_fields):
        value_dict[array_field] = jax.tree_util.tree_unflatten(
            inner_treedef, data[array_start:array_start + num_array])
        array_start += num_array
      metadata_fields = get_metadata_fields(new_cls)
      for field in metadata_fields:
        value_dict[field.name] = getattr(metadata, field.name)

      return new_cls(**value_dict)

    jax.tree_util.register_pytree_node(
        nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
    return new_cls
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
                                 matrix2: rotation_matrix.Rot3Array):
  """Assert that two rotation matrices agree exactly, field by field."""
  for component in dataclasses.fields(rotation_matrix.Rot3Array):
    name = component.name
    lhs = getattr(matrix1, name)
    rhs = getattr(matrix2, name)
    np.testing.assert_array_equal(lhs, rhs)


def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
                                 mat2: rotation_matrix.Rot3Array):
  """Assert that two rotation matrices agree to 6 decimals."""
  actual = mat1.to_array()
  desired = mat2.to_array()
  np.testing.assert_array_almost_equal(actual, desired, 6)


def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
                                          matrix: rotation_matrix.Rot3Array):
  """Check that array and Matrix match."""
  axes = 'xyz'
  # Field 'rc' is compared with entry [..., row, col], in row-major order.
  for row, row_letter in enumerate(axes):
    for col, col_letter in enumerate(axes):
      entry = getattr(matrix, row_letter + col_letter)
      np.testing.assert_array_equal(entry, array[..., row, col])


def assert_array_close_to_rotation_matrix(array: jnp.ndarray,
                                          matrix: rotation_matrix.Rot3Array):
  """Assert that a matrix, converted to an array, matches `array` to 6 decimals."""
  as_array = matrix.to_array()
  np.testing.assert_array_almost_equal(as_array, array, 6)


def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
  """Assert that two vectors agree exactly in x, y and z."""
  for axis in 'xyz':
    np.testing.assert_array_equal(getattr(vec1, axis), getattr(vec2, axis))


def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
  """Assert that two vectors agree in x, y and z within atol=1e-6."""
  for axis in 'xyz':
    np.testing.assert_allclose(
        getattr(vec1, axis), getattr(vec2, axis), atol=1e-6, rtol=0.)


def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
  """Assert that a vector, converted to an array, matches within atol=1e-6."""
  as_array = vec.to_array()
  np.testing.assert_allclose(as_array, array, atol=1e-6, rtol=0.)


def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
  """Assert that a vector, converted to an array, matches `array` exactly."""
  as_array = vec.to_array()
  np.testing.assert_array_equal(as_array, array)


def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
  """Assert that two rigid transforms agree exactly."""
  rot, trans = rigid1.rotation, rigid1.translation
  assert_rot_trans_equal_to_rigid(rot, trans, rigid2)


def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
  """Assert that two rigid transforms agree approximately."""
  rot, trans = rigid1.rotation, rigid1.translation
  assert_rot_trans_close_to_rigid(rot, trans, rigid2)


def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
  """Assert that (rot, trans) exactly equals the parts of `rigid`."""
  # Rotation is checked first, then translation.
  assert_rotation_matrix_equal(rot, rigid.rotation)
  assert_vectors_equal(trans, rigid.translation)


def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
  """Assert that (rot, trans) approximately equals the parts of `rigid`."""
  # Rotation is checked first, then translation.
  assert_rotation_matrix_close(rot, rigid.rotation)
  assert_vectors_close(trans, rigid.translation)
def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]:
  """Split `value` along `axis` into a list of arrays with that axis removed."""
  num_slices = value.shape[axis]
  pieces = jnp.split(value, num_slices, axis=axis)
  result = []
  for piece in pieces:
    # Each piece still carries a length-1 `axis`; drop it.
    result.append(jnp.squeeze(piece, axis=axis))
  return result
Float = Union[float, jnp.ndarray]

VERSION = '0.1'


@struct_of_array.StructOfArray(same_dtype=True)
class Vec3Array:
  """Vec3Array in 3 dimensional Space implemented as struct of arrays.

  This is done in order to improve performance and precision.
  On TPU small matrix multiplications are very suboptimal and will waste large
  compute resources, furthermore any matrix multiplication on TPU happens in
  mixed bfloat16/float32 precision, which is often undesirable when handling
  physical coordinates.
  In most cases this will also be faster on CPUs/GPUs since it allows for
  easier use of vector instructions.
  """

  x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  y: jnp.ndarray
  z: jnp.ndarray

  def __post_init__(self):
    """Check that x, y and z share dtype and agree on their leading dims.

    NOTE(review): the StructOfArray decorator assigns its own `__post_init__`
    to the class, so this method appears to be replaced at decoration time.
    """
    if hasattr(self.x, 'dtype'):
      assert self.x.dtype == self.y.dtype
      assert self.x.dtype == self.z.dtype
      assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
      assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])

  # The arithmetic operators map over the three components. They use
  # `jax.tree_util.tree_map`: the top-level `jax.tree_map` alias has been
  # removed from recent JAX releases.

  def __add__(self, other: Vec3Array) -> Vec3Array:
    """Component-wise sum of two vectors."""
    return jax.tree_util.tree_map(lambda x, y: x + y, self, other)

  def __sub__(self, other: Vec3Array) -> Vec3Array:
    """Component-wise difference of two vectors."""
    return jax.tree_util.tree_map(lambda x, y: x - y, self, other)

  def __mul__(self, other: Float) -> Vec3Array:
    """Multiply every component by `other`."""
    return jax.tree_util.tree_map(lambda x: x * other, self)

  def __rmul__(self, other: Float) -> Vec3Array:
    """Support `other * vec` by delegating to `__mul__`."""
    return self * other

  def __truediv__(self, other: Float) -> Vec3Array:
    """Divide every component by `other`."""
    return jax.tree_util.tree_map(lambda x: x / other, self)

  def __neg__(self) -> Vec3Array:
    """Negate every component."""
    return jax.tree_util.tree_map(lambda x: -x, self)

  def __pos__(self) -> Vec3Array:
    """Return a vector with the same components."""
    return jax.tree_util.tree_map(lambda x: x, self)

  def cross(self, other: Vec3Array) -> Vec3Array:
    """Compute cross product between 'self' and 'other'."""
    new_x = self.y * other.z - self.z * other.y
    new_y = self.z * other.x - self.x * other.z
    new_z = self.x * other.y - self.y * other.x
    return Vec3Array(new_x, new_y, new_z)

  def dot(self, other: Vec3Array) -> Float:
    """Compute dot product between 'self' and 'other'."""
    return self.x * other.x + self.y * other.y + self.z * other.z

  def norm(self, epsilon: float = 1e-6) -> Float:
    """Compute Norm of Vec3Array, clipped to epsilon."""
    # To avoid NaN on the backward pass, we must use maximum before the sqrt
    norm2 = self.dot(self)
    if epsilon:
      norm2 = jnp.maximum(norm2, epsilon**2)
    return jnp.sqrt(norm2)

  def norm2(self) -> Float:
    """Compute the squared norm (no clipping)."""
    return self.dot(self)

  def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
    """Return unit vector with optional clipping."""
    return self / self.norm(epsilon)

  @classmethod
  def zeros(cls, shape, dtype=jnp.float32):
    """Return Vec3Array corresponding to zeros of given shape."""
    return cls(
        jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
        jnp.zeros(shape, dtype))  # pytype: disable=wrong-arg-count  # trace-all-classes

  def to_array(self) -> jnp.ndarray:
    """Stack x, y and z along a new last axis."""
    return jnp.stack([self.x, self.y, self.z], axis=-1)

  @classmethod
  def from_array(cls, array):
    """Build a Vec3Array by unstacking the last axis of `array`."""
    return cls(*utils.unstack(array))

  def __getstate__(self):
    """Return a versioned pickle state with the components as NumPy arrays."""
    return (VERSION,
            [np.asarray(self.x),
             np.asarray(self.y),
             np.asarray(self.z)])

  def __setstate__(self, state):
    """Restore x, y and z from a state produced by `__getstate__`."""
    version, state = state
    del version
    for i, letter in enumerate('xyz'):
      # `object.__setattr__` bypasses the normal attribute-assignment path.
      object.__setattr__(self, letter, state[i])


def square_euclidean_distance(vec1: Vec3Array,
                              vec2: Vec3Array,
                              epsilon: float = 1e-6) -> Float:
  """Computes square of euclidean distance between 'vec1' and 'vec2'.

  Args:
    vec1: Vec3Array to compute  distance to
    vec2: Vec3Array to compute  distance from, should be
          broadcast compatible with 'vec1'
    epsilon: distance is clipped from below to be at least epsilon

  Returns:
    Array of square euclidean distances;
    shape will be result of broadcasting 'vec1' and 'vec2'
  """
  difference = vec1 - vec2
  distance = difference.dot(difference)
  if epsilon:
    distance = jnp.maximum(distance, epsilon)
  return distance


def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
  """Functional form of `Vec3Array.dot`."""
  return vector1.dot(vector2)


def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
  """Functional form of `Vec3Array.cross`; the result is a Vec3Array."""
  return vector1.cross(vector2)


def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
  """Functional form of `Vec3Array.norm`."""
  return vector.norm(epsilon)


def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
  """Functional form of `Vec3Array.normalized`."""
  return vector.normalized(epsilon)


def euclidean_distance(vec1: Vec3Array,
                       vec2: Vec3Array,
                       epsilon: float = 1e-6) -> Float:
  """Computes euclidean distance between 'vec1' and 'vec2'.

  Args:
    vec1: Vec3Array to compute euclidean distance to
    vec2: Vec3Array to compute euclidean distance from, should be
          broadcast compatible with 'vec1'
    epsilon: distance is clipped from below to be at least epsilon

  Returns:
    Array of euclidean distances;
    shape will be result of broadcasting 'vec1' and 'vec2'
  """
  # Clip the squared distance at epsilon**2 so the distance is >= epsilon.
  distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
  distance = jnp.sqrt(distance_sq)
  return distance


def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
                   d: Vec3Array) -> Float:
  """Computes torsion angle for a quadruple of points.

  For points (a, b, c, d), this is the angle between the planes defined by
  points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

  Arguments:
    a: A Vec3Array of coordinates.
    b: A Vec3Array of coordinates.
    c: A Vec3Array of coordinates.
    d: A Vec3Array of coordinates.

  Returns:
    A tensor of angles in radians: [-pi, pi].
  """
  v1 = a - b
  v2 = b - c
  v3 = d - c

  c1 = v1.cross(v2)
  c2 = v3.cross(v2)
  c3 = c2.cross(c1)

  v2_mag = v2.norm()
  return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2))


def random_gaussian_vector(shape, key, dtype=jnp.float32):
  """Sample a Vec3Array whose components are standard-normal draws.

  Args:
    shape: Shape of the resulting Vec3Array (any sequence of ints).
    key: JAX PRNG key.
    dtype: dtype of the sampled components.
  """
  # `tuple(shape)` lets list shapes work too, matching
  # `Rot3Array.random_uniform`.
  vec_array = jax.random.normal(key, tuple(shape) + (3,), dtype)
  return Vec3Array.from_array(vec_array)
use_multimer=False): self.config = config self.params = params self.mode = recycle_mode - if self.mode is None: - self.mode = [] - # backward compatibility - if self.config.model.add_prev: - self.mode.append("add_prev") - if self.config.model.backprop_recycle: - self.mode.append("backprop") + if self.mode is None: self.mode = [] def _forward_fn(batch): - model = modules.AlphaFold(self.config.model) + if use_multimer: + model = modules_multimer.AlphaFold(self.config.model) + else: + model = modules.AlphaFold(self.config.model) return model( batch, is_training=is_training, - compute_loss=False, - ensemble_representations=False, return_representations=return_representations) self.init = jax.jit(hk.transform(_forward_fn).init) @@ -84,34 +61,19 @@ def apply(params, key, feat): if "prev" in feat: prev = feat["prev"] else: - L = feat['aatype'].shape[1] + L = feat['aatype'].shape[0] prev = {'prev_msa_first_row': np.zeros([L,256]), 'prev_pair': np.zeros([L,L,128])} if self.config.model.use_struct: prev['prev_pos'] = np.zeros([L,37,3]) else: prev['prev_dgram'] = np.zeros([L,L,64]) + feat["prev"] = prev ################################ # decide how to run recycles ################################ - if "num_iter_recycling" in feat: - # use while_loop() - num_recycles = feat.pop("num_iter_recycling")[0] - def body(x): - i,prev,key = x - key, sub_key = jax.random.split(key) - feat["prev"] = prev - prev = self.apply_fn(params, sub_key, feat)["prev"] - prev = jax.lax.stop_gradient(prev) - return (i+1, prev, key) - - init = (0,prev,key) - _, feat["prev"], key = jax.lax.while_loop(lambda x: x[0] < num_recycles, body, init) - key, sub_key = jax.random.split(key) - results = self.apply_fn(params, sub_key, feat) - - elif self.config.model.num_recycle: + if self.config.model.num_recycle: # use scan() def loop(prev, sub_key): feat["prev"] = prev @@ -136,75 +98,4 @@ def loop(prev, sub_key): return results - self.apply = jax.jit(apply) - - def init_params(self, feat: features.FeatureDict, 
random_seed: int = 0): - """Initializes the model parameters. - - If none were provided when this class was instantiated then the parameters - are randomly initialized. - - Args: - feat: A dictionary of NumPy feature arrays as output by - RunModel.process_features. - random_seed: A random seed to use to initialize the parameters if none - were set when this class was initialized. - """ - if not self.params: - # Init params randomly. - rng = jax.random.PRNGKey(random_seed) - self.params = hk.data_structures.to_mutable_dict( - self.init(rng, feat)) - logging.warning('Initialized parameters randomly') - - def process_features( - self, - raw_features: Union[tf.train.Example, features.FeatureDict], - random_seed: int) -> features.FeatureDict: - """Processes features to prepare for feeding them into the model. - - Args: - raw_features: The output of the data pipeline either as a dict of NumPy - arrays or as a tf.train.Example. - random_seed: The random seed to use when processing the features. - - Returns: - A dict of NumPy feature arrays suitable for feeding into the model. - """ - if isinstance(raw_features, dict): - return features.np_example_to_features( - np_example=raw_features, - config=self.config, - random_seed=random_seed) - else: - return features.tf_example_to_features( - tf_example=raw_features, - config=self.config, - random_seed=random_seed) - - def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: - self.init_params(feat) - logging.info('Running eval_shape with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) - shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) - logging.info('Output shape was %s', shape) - return shape - - def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: - """Makes a prediction by inferencing the model on the provided features. - - Args: - feat: A dictionary of NumPy feature arrays as output by - RunModel.process_features. 
- - Returns: - A dictionary of model outputs. - """ - self.init_params(feat) - logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) - - result = self.apply(self.params, jax.random.PRNGKey(0), feat) - if self.config.model.use_struct: - result.update(get_confidence_metrics(result)) - - logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) - return result + self.apply = jax.jit(apply) \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py index ebfc9ecb..5ed5306d 100644 --- a/colabdesign/af/alphafold/model/modules.py +++ b/colabdesign/af/alphafold/model/modules.py @@ -33,22 +33,6 @@ from colabdesign.af.alphafold.model.r3 import Rigids, Rots, Vecs - -def softmax_cross_entropy(logits, labels): - """Computes softmax cross entropy given logits and one-hot class labels.""" - loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) - return jnp.asarray(loss) - - -def sigmoid_cross_entropy(logits, labels): - """Computes sigmoid cross entropy given logits and multiple class labels.""" - log_p = jax.nn.log_sigmoid(logits) - # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable - log_not_p = jax.nn.log_sigmoid(-logits) - loss = -labels * log_p - (1. - labels) * log_not_p - return jnp.asarray(loss) - - def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None): """Applies dropout to a tensor.""" if is_training: # and rate != 0.0: @@ -128,8 +112,7 @@ class AlphaFoldIteration(hk.Module): Computes ensembled (averaged) representations from the provided features. These representations are then passed to the various heads - that have been requested by the configuration file. Each head also returns a - loss which is combined as a weighted sum to produce the total loss. + that have been requested by the configuration file. Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 3-22 """ @@ -140,17 +123,13 @@ def __init__(self, config, global_config, name='alphafold_iteration'): self.global_config = global_config def __call__(self, - ensembled_batch, - non_ensembled_batch, + batch, is_training, - compute_loss=False, - ensemble_representations=False, return_representations=False): # Compute representations for each batch element and average. evoformer_module = EmbeddingsAndEvoformer(self.config.embeddings_and_evoformer, self.global_config) - batch0 = {**ensembled_batch, **non_ensembled_batch} - representations = evoformer_module(batch0, is_training) + representations = evoformer_module(batch, is_training) # MSA representations are not ensembled so # we don't pass tensor into the loop. @@ -158,12 +137,11 @@ def __call__(self, del representations['msa'] representations['msa'] = msa_representation - batch = batch0 # We are not ensembled from here on. - if jnp.issubdtype(ensembled_batch['aatype'].dtype, jnp.integer): - num_residues = ensembled_batch['aatype'].shape + if jnp.issubdtype(batch['aatype'].dtype, jnp.integer): + num_residues = batch['aatype'].shape else: - num_residues, _ = ensembled_batch['aatype'].shape + num_residues, _ = batch['aatype'].shape if self.config.use_struct: struct_module = folding.StructureModule @@ -177,7 +155,7 @@ def __call__(self, head_factory = { 'masked_msa': MaskedMsaHead, 'distogram': DistogramHead, - 'structure_module': functools.partial(struct_module, compute_loss=compute_loss), + 'structure_module': struct_module, 'predicted_lddt': PredictedLDDTHead, 'predicted_aligned_error': PredictedAlignedErrorHead, 'experimentally_resolved': ExperimentallyResolvedHead, @@ -185,20 +163,9 @@ def __call__(self, heads[head_name] = (head_config, head_factory(head_config, self.global_config)) - total_loss = 0. 
ret = {} ret['representations'] = representations - def loss(module, head_config, ret, name, filter_ret=True): - if filter_ret: - value = ret[name] - else: - value = ret - loss_output = module.loss(value, batch) - ret[name].update(loss_output) - loss = head_config.weight * ret[name]['loss'] - return loss - for name, (head_config, module) in heads.items(): # Skip PredictedLDDTHead and PredictedAlignedErrorHead until # StructureModule is executed. @@ -210,8 +177,6 @@ def loss(module, head_config, ret, name, filter_ret=True): # Extra representations from the head. Used by the structure module # to provide activations for the PredictedLDDTHead. representations.update(ret[name].pop('representations')) - if compute_loss: - total_loss += loss(module, head_config, ret, name) if self.config.use_struct: if self.config.heads.get('predicted_lddt.weight', 0.0): @@ -220,8 +185,6 @@ def loss(module, head_config, ret, name, filter_ret=True): # Feed all previous results to give access to structure_module result. head_config, module = heads[name] ret[name] = module(representations, batch, is_training) - if compute_loss: - total_loss += loss(module, head_config, ret, name, filter_ret=False) if ('predicted_aligned_error' in self.config.heads and self.config.heads.get('predicted_aligned_error.weight', 0.0)): @@ -230,13 +193,8 @@ def loss(module, head_config, ret, name, filter_ret=True): # Feed all previous results to give access to structure_module result. head_config, module = heads[name] ret[name] = module(representations, batch, is_training) - if compute_loss: - total_loss += loss(module, head_config, ret, name, filter_ret=False) - if compute_loss: - return ret, total_loss - else: - return ret + return ret class AlphaFold(hk.Module): """AlphaFold model with recycling. @@ -253,36 +211,28 @@ def __call__( self, batch, is_training, - compute_loss=False, - ensemble_representations=False, return_representations=False): """Run the AlphaFold model. 
Arguments: batch: Dictionary with inputs to the AlphaFold model. is_training: Whether the system is in training or inference mode. - compute_loss: Whether to compute losses (requires extra features - to be present in the batch and knowing the true structure). - ensemble_representations: Whether to use ensembling of representations. return_representations: Whether to also return the intermediate representations. Returns: - When compute_loss is True: - a tuple of loss and output of AlphaFoldIteration. - When compute_loss is False: - just output of AlphaFoldIteration. + just output of AlphaFoldIteration. The output of AlphaFoldIteration is a nested dictionary containing predictions from the various heads. """ - if "dropout_scale" not in batch: batch["dropout_scale"] = jnp.ones((1,)) - impl = AlphaFoldIteration(self.config, self.global_config) - if jnp.issubdtype(batch['aatype'].dtype, jnp.integer): - batch_size, num_residues = batch['aatype'].shape + num_res = batch['aatype'].shape else: - batch_size, num_residues, _ = batch['aatype'].shape + num_res, _ = batch['aatype'].shape + + impl = AlphaFoldIteration(self.config, self.global_config) + def get_prev(ret): new_prev = { @@ -293,33 +243,12 @@ def get_prev(ret): new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] else: new_prev['prev_dgram'] = ret["distogram"]["logits"] - return new_prev emb_config = self.config.embeddings_and_evoformer - prev = { - 'prev_msa_first_row': jnp.zeros([num_residues, emb_config.msa_channel]), - 'prev_pair': jnp.zeros([num_residues, num_residues, emb_config.pair_channel]) - } - if self.config.use_struct: - prev['prev_pos'] = jnp.zeros([num_residues, residue_constants.atom_type_num, 3]) - else: - prev['prev_dgram'] = jnp.zeros([num_residues, num_residues, 64]) - - # copy previous from input batch (if defined) - if "prev" in batch: - prev.update(batch.pop("prev")) - - # backward compatibility - for k in ["pos","msa_first_row","pair","dgram"]: - if f"init_{k}" in batch: - 
prev[f"prev_{k}"] = batch.pop(f"init_{k}")[0] - - ret = impl(ensembled_batch=jax.tree_map(lambda x:x[0], batch), - non_ensembled_batch=prev, - is_training=is_training, - compute_loss=compute_loss, - ensemble_representations=ensemble_representations) + prev = batch.pop("prev") + ret = impl(batch={**batch, **prev}, + is_training=is_training) ret["prev"] = get_prev(ret) return ret @@ -899,6 +828,11 @@ def __init__(self, config, global_config, name='masked_msa_head'): self.config = config self.global_config = global_config + if global_config.multimer_mode: + self.num_output = len(residue_constants.restypes_with_x_and_gap) + else: + self.num_output = config.num_output + def __call__(self, representations, batch, is_training): """Builds MaskedMsaHead module. @@ -915,21 +849,12 @@ def __call__(self, representations, batch, is_training): """ del batch logits = common_modules.Linear( - self.config.num_output, + self.num_output, initializer=utils.final_init(self.global_config), name='logits')( representations['msa']) return dict(logits=logits) - def loss(self, value, batch): - errors = softmax_cross_entropy( - labels=jax.nn.one_hot(batch['true_msa'], num_classes=23), - logits=value['logits']) - loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) / - (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1)))) - return {'loss': loss} - - class PredictedLDDTHead(hk.Module): """Head to predict the per-residue LDDT to be used as a confidence measure. 
@@ -988,51 +913,6 @@ def __call__(self, representations, batch, is_training): # Shape (batch_size, num_res, num_bins) return dict(logits=logits) - def loss(self, value, batch): - # Shape (num_res, 37, 3) - pred_all_atom_pos = value['structure_module']['final_atom_positions'] - # Shape (num_res, 37, 3) - true_all_atom_pos = batch['all_atom_positions'] - # Shape (num_res, 37) - all_atom_mask = batch['all_atom_mask'] - - # Shape (num_res,) - lddt_ca = lddt.lddt( - # Shape (batch_size, num_res, 3) - predicted_points=pred_all_atom_pos[None, :, 1, :], - # Shape (batch_size, num_res, 3) - true_points=true_all_atom_pos[None, :, 1, :], - # Shape (batch_size, num_res, 1) - true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32), - cutoff=15., - per_residue=True)[0] - lddt_ca = jax.lax.stop_gradient(lddt_ca) - - num_bins = self.config.num_bins - bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32) - - # protect against out of range for lddt_ca == 1 - bin_index = jnp.minimum(bin_index, num_bins - 1) - lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins) - - # Shape (num_res, num_channel) - logits = value['predicted_lddt']['logits'] - errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) - - # Shape (num_res,) - mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']] - mask_ca = mask_ca.astype(jnp.float32) - loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8) - - if self.config.filter_by_resolution: - # NMR & distillation have resolution = 0 - loss *= ((batch['resolution'] >= self.config.min_resolution) - & (batch['resolution'] <= self.config.max_resolution)).astype( - jnp.float32) - - output = {'loss': loss} - return output - class PredictedAlignedErrorHead(hk.Module): """Head to predict the distance errors in the backbone alignment frames. 
@@ -1074,55 +954,6 @@ def __call__(self, representations, batch, is_training): 0., self.config.max_error_bin, self.config.num_bins - 1) return dict(logits=logits, breaks=breaks) - def loss(self, value, batch): - # Shape (num_res, 7) - predicted_affine = quat_affine.QuatAffine.from_tensor( - value['structure_module']['final_affines']) - # Shape (num_res, 7) - true_affine = quat_affine.QuatAffine.from_tensor( - batch['backbone_affine_tensor']) - # Shape (num_res) - mask = batch['backbone_affine_mask'] - # Shape (num_res, num_res) - square_mask = mask[:, None] * mask[None, :] - num_bins = self.config.num_bins - # (1, num_bins - 1) - breaks = value['predicted_aligned_error']['breaks'] - # (1, num_bins) - logits = value['predicted_aligned_error']['logits'] - - # Compute the squared error for each alignment. - def _local_frame_points(affine): - points = [jnp.expand_dims(x, axis=-2) for x in affine.translation] - return affine.invert_point(points, extra_dims=1) - error_dist2_xyz = [ - jnp.square(a - b) - for a, b in zip(_local_frame_points(predicted_affine), - _local_frame_points(true_affine))] - error_dist2 = sum(error_dist2_xyz) - # Shape (num_res, num_res) - # First num_res are alignment frames, second num_res are the residues. 
- error_dist2 = jax.lax.stop_gradient(error_dist2) - - sq_breaks = jnp.square(breaks) - true_bins = jnp.sum(( - error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1) - - errors = softmax_cross_entropy( - labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits) - - loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) / - (1e-8 + jnp.sum(square_mask, axis=(-2, -1)))) - - if self.config.filter_by_resolution: - # NMR & distillation have resolution = 0 - loss *= ((batch['resolution'] >= self.config.min_resolution) - & (batch['resolution'] <= self.config.max_resolution)).astype( - jnp.float32) - - output = {'loss': loss} - return output - class ExperimentallyResolvedHead(hk.Module): """Predicts if an atom is experimentally resolved in a high-res structure. @@ -1158,28 +989,6 @@ def __call__(self, representations, batch, is_training): name='logits')(representations['single']) return dict(logits=logits) - def loss(self, value, batch): - logits = value['logits'] - assert len(logits.shape) == 2 - - # Does the atom appear in the amino acid? - atom_exists = batch['atom37_atom_exists'] - # Is the atom resolved in the experiment? Subset of atom_exists, - # *except for OXT* - all_atom_mask = batch['all_atom_mask'].astype(jnp.float32) - - xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) - loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists)) - - if self.config.filter_by_resolution: - # NMR & distillation examples have resolution = 0. - loss *= ((batch['resolution'] >= self.config.min_resolution) - & (batch['resolution'] <= self.config.max_resolution)).astype( - jnp.float32) - - output = {'loss': loss} - return output - class TriangleMultiplication(hk.Module): """Triangle multiplication layer ("outgoing" or "incoming"). 
@@ -1308,42 +1117,6 @@ def __call__(self, representations, batch, is_training): return dict(logits=logits, bin_edges=breaks) - def loss(self, value, batch): - return _distogram_log_loss(value['logits'], value['bin_edges'], - batch, self.config.num_bins) - - -def _distogram_log_loss(logits, bin_edges, batch, num_bins): - """Log loss of a distogram.""" - - assert len(logits.shape) == 3 - positions = batch['pseudo_beta'] - mask = batch['pseudo_beta_mask'] - - assert positions.shape[-1] == 3 - - sq_breaks = jnp.square(bin_edges) - - dist2 = jnp.sum( - jnp.square( - jnp.expand_dims(positions, axis=-2) - - jnp.expand_dims(positions, axis=-3)), - axis=-1, - keepdims=True) - - true_bins = jnp.sum(dist2 > sq_breaks, axis=-1) - - errors = softmax_cross_entropy( - labels=jax.nn.one_hot(true_bins, num_bins), logits=logits) - - square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1) - - avg_error = ( - jnp.sum(errors * square_mask, axis=(-2, -1)) / - (1e-6 + jnp.sum(square_mask, axis=(-2, -1)))) - dist2 = dist2[..., 0] - return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2)) - class OuterProductMean(hk.Module): """Computes mean outer product. 
@@ -1465,7 +1238,7 @@ def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0): o = o/(o.sum(-1,keepdims=True) + 1e-8) return o[...,1:] -def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): """Create pseudo beta features.""" ca_idx = residue_constants.atom_order['CA'] @@ -1476,8 +1249,8 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): is_gly_tile = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]) pseudo_beta = jnp.where(is_gly_tile, all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) - if all_atom_masks is not None: - pseudo_beta_mask = jnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + if all_atom_mask is not None: + pseudo_beta_mask = jnp.where(is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]) pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32) return pseudo_beta, pseudo_beta_mask else: @@ -1487,9 +1260,9 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): ca_pos = all_atom_positions[...,ca_idx,:] cb_pos = all_atom_positions[...,cb_idx,:] pseudo_beta = is_gly[...,None] * ca_pos + (1-is_gly[...,None]) * cb_pos - if all_atom_masks is not None: - ca_mask = all_atom_masks[...,ca_idx] - cb_mask = all_atom_masks[...,cb_idx] + if all_atom_mask is not None: + ca_mask = all_atom_mask[...,ca_idx] + cb_mask = all_atom_mask[...,cb_idx] pseudo_beta_mask = is_gly * ca_mask + (1-is_gly) * cb_mask return pseudo_beta, pseudo_beta_mask else: @@ -1638,22 +1411,15 @@ def __call__(self, batch, is_training, safe_key=None): # Embed clustered MSA. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 # Jumper et al. (2021) Suppl. Alg. 
3 "InputEmbedder" - preprocess_1d = common_modules.Linear( - c.msa_channel, name='preprocess_1d')( - batch['target_feat']) - - preprocess_msa = common_modules.Linear( - c.msa_channel, name='preprocess_msa')( - batch['msa_feat']) - - msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa - - left_single = common_modules.Linear( - c.pair_channel, name='left_single')( - batch['target_feat']) - right_single = common_modules.Linear( - c.pair_channel, name='right_single')( - batch['target_feat']) + + target_feat = batch["msa_feat"][0,:,:21] + target_feat = jnp.pad(target_feat,[[0,0],[1,0]]) + preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(batch['msa_feat']) + msa_activations = preprocess_1d[None] + preprocess_msa + + left_single = common_modules.Linear(c.pair_channel, name='left_single')(target_feat) + right_single = common_modules.Linear(c.pair_channel, name='right_single')(target_feat) pair_activations = left_single[:, None] + right_single[None] mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] @@ -1719,10 +1485,15 @@ def __call__(self, batch, is_training, safe_key=None): if c.template.enabled: template_batch = {k: batch[k] for k in batch if k.startswith('template_')} + + multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + multichain_mask = jnp.where(batch["mask_template_interchain"], multichain_mask, 1) + template_pair_representation = TemplateEmbedding(c.template, gc)( pair_activations, template_batch, mask_2d, + multichain_mask, is_training=is_training, dropout_scale=batch["dropout_scale"]) @@ -1774,7 +1545,7 @@ def extra_msa_stack_fn(x): ret = all_atom.atom37_to_torsion_angles( aatype=aatype, all_atom_pos=batch['template_all_atom_positions'], - all_atom_mask=batch['template_all_atom_masks'], + all_atom_mask=batch['template_all_atom_mask'], # Ensure consistent behaviour during testing: 
placeholder_for_undefined=not gc.zero_init) @@ -1849,7 +1620,7 @@ def __init__(self, config, global_config, name='single_template_embedding'): self.config = config self.global_config = global_config - def __call__(self, query_embedding, batch, mask_2d, is_training, dropout_scale=1.0): + def __call__(self, query_embedding, batch, mask_2d, multichain_mask_2d, is_training, dropout_scale=1.0): """Build the single template embedding. Arguments: query_embedding: Query pair representation, shape [N_res, N_res, c_z]. @@ -1868,6 +1639,7 @@ def __call__(self, query_embedding, batch, mask_2d, is_training, dropout_scale=1 .triangle_attention_ending_node.value_dim) template_mask = batch['template_pseudo_beta_mask'] template_mask_2d = template_mask[:, None] * template_mask[None, :] + template_mask_2d = template_mask_2d * multichain_mask_2d template_mask_2d = template_mask_2d.astype(dtype) if "template_dgram" in batch: @@ -1880,8 +1652,8 @@ def __call__(self, query_embedding, batch, mask_2d, is_training, dropout_scale=1 else: template_dgram = dgram_from_positions(batch['template_pseudo_beta'], **self.config.dgram_features) + template_dgram *= template_mask_2d[..., None] template_dgram = template_dgram.astype(dtype) - to_concat = [template_dgram, template_mask_2d[:, :, None]] if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): @@ -1896,9 +1668,9 @@ def __call__(self, query_embedding, batch, mask_2d, is_training, dropout_scale=1 # (the template mask defined above only considers pseudo CB). 
n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] template_mask = ( - batch['template_all_atom_masks'][..., n] * - batch['template_all_atom_masks'][..., ca] * - batch['template_all_atom_masks'][..., c]) + batch['template_all_atom_mask'][..., n] * + batch['template_all_atom_mask'][..., ca] * + batch['template_all_atom_mask'][..., c]) template_mask_2d = template_mask[:, None] * template_mask[None, :] # compute unit_vector (not used by default) @@ -1957,7 +1729,8 @@ def __init__(self, config, global_config, name='template_embedding'): self.config = config self.global_config = global_config - def __call__(self, query_embedding, template_batch, mask_2d, is_training, dropout_scale=1.0): + def __call__(self, query_embedding, template_batch, mask_2d, multichain_mask_2d, + is_training, dropout_scale=1.0): """Build TemplateEmbedding module. Arguments: query_embedding: Query pair representation, shape [N_res, N_res, c_z]. @@ -1986,7 +1759,8 @@ def __call__(self, query_embedding, template_batch, mask_2d, is_training, dropou template_embedder = SingleTemplateEmbedding(self.config, self.global_config) def map_fn(batch): - return template_embedder(query_embedding, batch, mask_2d, is_training, dropout_scale=dropout_scale) + return template_embedder(query_embedding, batch, mask_2d, multichain_mask_2d, + is_training, dropout_scale=dropout_scale) template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(template_batch) diff --git a/colabdesign/af/alphafold/model/modules_multimer.py b/colabdesign/af/alphafold/model/modules_multimer.py new file mode 100644 index 00000000..4f31d6ac --- /dev/null +++ b/colabdesign/af/alphafold/model/modules_multimer.py @@ -0,0 +1,827 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core modules, which have been refactored in AlphaFold-Multimer. + +The main difference is that MSA sampling pipeline is moved inside the JAX model +for easier implementation of recycling and ensembling. + +Lower-level modules up to EvoformerIteration are reused from modules.py. +""" + +import functools +from typing import Sequence + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom_multimer +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import folding_multimer +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import layer_stack +from colabdesign.af.alphafold.model import modules +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import utils + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +def create_extra_msa_feature(batch, num_extra_msa): + """Expand extra_msa into 1hot and concat with other extra msa features. + We do this as late as possible as the one_hot extra msa can be very large. + Args: + batch: a dictionary with the following keys: + * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster + centre. Note - This isn't one-hotted. + * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given + position. + num_extra_msa: Number of extra msa to use. + Returns: + Concatenated tensor of extra MSA features. 
+ """ + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_value'][:num_extra_msa] + msa_1hot = jax.nn.one_hot(extra_msa, 23) + has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + return jnp.concatenate([msa_1hot, has_deletion, deletion_value], + axis=-1), extra_msa_mask + +class AlphaFoldIteration(hk.Module): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. + """ + + def __init__(self, config, global_config, name='alphafold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + batch, + is_training, + return_representations=False, + safe_key=None): + + + # Compute representations for each MSA sample and average. + embedding_module = EmbeddingsAndEvoformer( + self.config.embeddings_and_evoformer, self.global_config) + + safe_key, safe_subkey = safe_key.split() + representations = embedding_module(batch, is_training, safe_key=safe_subkey) + + self.representations = representations + self.batch = batch + self.heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if not head_config.weight: + continue # Do not instantiate zero-weight heads. 
+ + head_factory = { + 'masked_msa': + modules.MaskedMsaHead, + 'distogram': + modules.DistogramHead, + 'structure_module': + folding_multimer.StructureModule, + 'predicted_aligned_error': + modules.PredictedAlignedErrorHead, + 'predicted_lddt': + modules.PredictedLDDTHead, + 'experimentally_resolved': + modules.ExperimentallyResolvedHead, + }[head_name] + self.heads[head_name] = (head_config, + head_factory(head_config, self.global_config)) + + structure_module_output = None + if 'entity_id' in batch and 'all_atom_positions' in batch: + _, fold_module = self.heads['structure_module'] + structure_module_output = fold_module(representations, batch, is_training) + + + ret = {} + ret['representations'] = representations + + for name, (head_config, module) in self.heads.items(): + if name == 'structure_module' and structure_module_output is not None: + ret[name] = structure_module_output + representations['structure_module'] = structure_module_output.pop('act') + # Skip confidence heads until StructureModule is executed. + elif name in {'predicted_lddt', 'predicted_aligned_error', + 'experimentally_resolved'}: + continue + else: + ret[name] = module(representations, batch, is_training) + + + # Add confidence heads after StructureModule is executed. + if self.config.heads.get('predicted_lddt.weight', 0.0): + name = 'predicted_lddt' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + + if self.config.heads.experimentally_resolved.weight: + name = 'experimentally_resolved' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + + if self.config.heads.get('predicted_aligned_error.weight', 0.0): + name = 'predicted_aligned_error' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + # Will be used for ipTM computation. 
+ ret[name]['asym_id'] = batch['asym_id'] + + return ret + +class AlphaFold(hk.Module): + """AlphaFold-Multimer model with recycling. + """ + + def __init__(self, config, name='alphafold'): + super().__init__(name=name) + self.config = config + self.global_config = config.global_config + + def __call__( + self, + batch, + is_training, + return_representations=False, + safe_key=None): + + c = self.config + impl = AlphaFoldIteration(c, self.global_config) + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + elif isinstance(safe_key, jnp.ndarray): + safe_key = prng.SafeKey(safe_key) + + assert isinstance(batch, dict) + num_res = batch['aatype'].shape[0] + + def get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + return new_prev + + def apply_network(prev, safe_key): + recycled_batch = {**batch, **prev} + return impl( + batch=recycled_batch, + is_training=is_training, + safe_key=safe_key) + + ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) + ret["prev"] = get_prev(ret) + + if not return_representations: + del ret['representations'] + return ret + +class EmbeddingsAndEvoformer(hk.Module): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + """ + + def __init__(self, config, global_config, name='evoformer'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def _relative_encoding(self, batch): + """Add relative position encodings. + + For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted. + + When not using 'use_chain_relative' the residue indices are used as is, e.g. + for heteromers relative positions will be computed using the positions in + the corresponding chains. + + When using 'use_chain_relative' we add an extra bin that denotes + 'different chain'. 
Furthermore we also provide the relative chain index + (i.e. sym_id) clipped and one-hotted to the network. And an extra feature + which denotes whether they belong to the same chain type, i.e. it's 0 if + they are in different heteromer chains and 1 otherwise. + + Args: + batch: batch. + Returns: + Feature embedding using the features as described before. + """ + c = self.config + rel_feats = [] + asym_id = batch['asym_id'] + asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :]) + + if "offset" in batch: + offset = batch['offset'] + else: + pos = batch['residue_index'] + offset = pos[:, None] - pos[None, :] + + clipped_offset = jnp.clip( + offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx) + + if c.use_chain_relative: + + final_offset = jnp.where(asym_id_same, clipped_offset, + (2 * c.max_relative_idx + 1) * + jnp.ones_like(clipped_offset)) + + rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2) + + rel_feats.append(rel_pos) + + entity_id = batch['entity_id'] + entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :]) + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + + sym_id = batch['sym_id'] + rel_sym_id = sym_id[:, None] - sym_id[None, :] + + max_rel_chain = c.max_relative_chain + + clipped_rel_chain = jnp.clip( + rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain) + + final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain, + (2 * max_rel_chain + 1) * + jnp.ones_like(clipped_rel_chain)) + rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2) + + rel_feats.append(rel_chain) + + else: + rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1) + rel_feats.append(rel_pos) + + rel_feat = jnp.concatenate(rel_feats, axis=-1) + + return common_modules.Linear( + c.pair_channel, + name='position_activations')( + rel_feat) + + def __call__(self, batch, is_training, safe_key=None): + + c = self.config + gc = self.global_config + + batch = dict(batch) + + if 
safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = {} + + target_feat = batch['msa_feat'][0,:,:21] + msa_feat = batch['msa_feat'] + preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat) + msa_activations = preprocess_1d[None] + preprocess_msa + + left_single = common_modules.Linear(c.pair_channel, name='left_single')(target_feat) + right_single = common_modules.Linear(c.pair_channel, name='right_single')(target_feat) + pair_activations = left_single[:, None] + right_single[None] + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + mask_2d = mask_2d.astype(jnp.float32) + + if c.recycle_pos: + prev_pseudo_beta = modules.pseudo_beta_fn( + batch['aatype'], batch['prev_pos'], None) + + dgram = modules.dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += common_modules.Linear( + c.pair_channel, name='prev_pos_linear')( + dgram) + + if c.recycle_features: + prev_msa_first_row = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_msa_first_row_norm')( + batch['prev_msa_first_row']) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + pair_activations += hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_pair_norm')( + batch['prev_pair']) + + if c.max_relative_idx: + pair_activations += self._relative_encoding(batch) + + if c.template.enabled: + template_module = TemplateEmbedding(c.template, gc) + template_batch = { + 'template_aatype': batch['template_aatype'], + 'template_all_atom_positions': batch['template_all_atom_positions'], + 'template_all_atom_mask': batch['template_all_atom_mask'] + } + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. 
+ multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + multichain_mask = jnp.where(batch["mask_template_interchain"], multichain_mask, 1) + + safe_key, safe_subkey = safe_key.split() + template_act = template_module( + query_embedding=pair_activations, + template_batch=template_batch, + padding_mask_2d=mask_2d, + multichain_mask_2d=multichain_mask, + is_training=is_training, + dropout_scale=batch["dropout_scale"], + safe_key=safe_subkey) + pair_activations += template_act + + # Extra MSA stack. + (extra_msa_feat, + extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) + extra_msa_activations = common_modules.Linear( + c.extra_msa_channel, + name='extra_msa_activations')( + extra_msa_feat) + extra_msa_mask = extra_msa_mask.astype(jnp.float32) + + extra_evoformer_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, + } + extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} + + extra_evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') + + def extra_evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_evoformer_iteration( + activations=act, + masks=extra_masks, + is_training=is_training, + dropout_scale=batch["dropout_scale"], + safe_key=safe_subkey) + return (extra_evoformer_output, safe_key) + + if gc.use_remat: + extra_evoformer_fn = hk.remat(extra_evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + extra_evoformer_stack = layer_stack.layer_stack( + c.extra_msa_stack_num_block)( + extra_evoformer_fn) + extra_evoformer_output, safe_key = extra_evoformer_stack( + (extra_evoformer_input, safe_subkey)) + + pair_activations = extra_evoformer_output['pair'] + + # Get the size of the MSA before potentially adding templates, so we + # can crop out the templates later. 
+ num_msa_sequences = msa_activations.shape[0] + evoformer_input = { + 'msa': msa_activations, + 'pair': pair_activations, + } + evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32), + 'pair': mask_2d} + + if c.template.enabled: + template_features, template_masks = ( + template_embedding_1d(batch=batch, num_channel=c.msa_channel)) + + evoformer_input['msa'] = jnp.concatenate( + [evoformer_input['msa'], template_features], axis=0) + evoformer_masks['msa'] = jnp.concatenate( + [evoformer_masks['msa'], template_masks], axis=0) + + evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + is_training=is_training, + dropout_scale=batch["dropout_scale"], + safe_key=safe_subkey) + return (evoformer_output, safe_key) + + if gc.use_remat: + evoformer_fn = hk.remat(evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( + evoformer_fn) + + def run_evoformer(evoformer_input): + evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) + return evoformer_output + + evoformer_output = run_evoformer(evoformer_input) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + + single_activations = common_modules.Linear( + c.seq_channel, name='single_activations')( + msa_activations[0]) + + output.update({ + 'single': + single_activations, + 'pair': + pair_activations, + # Crop away template rows such that they are not used in MaskedMsaHead. 
+ 'msa': + msa_activations[:num_msa_sequences, :, :], + 'msa_first_row': + msa_activations[0], + }) + + return output + + +class TemplateEmbedding(hk.Module): + """Embed a set of templates.""" + + def __init__(self, config, global_config, name='template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, padding_mask_2d, + multichain_mask_2d, is_training, dropout_scale, + safe_key=None): + """Generate an embedding for a set of templates. + + Args: + query_embedding: [num_res, num_res, num_channel] a query tensor that will + be used to attend over the templates to remove the num_templates + dimension. + template_batch: A dictionary containing: + `template_aatype`: [num_templates, num_res] aatype for each template. + `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom + positions for all templates. + `template_all_atom_mask`: [num_templates, num_res, 37] mask for each + template. + padding_mask_2d: [num_res, num_res] Pair mask for attention operations. + multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs + are intra-chain, used to mask out residue distance based features + between chains. + is_training: bool indicating where we are running in training mode. + safe_key: random key generator. + + Returns: + An embedding of size [num_res, num_res, num_channels] + """ + c = self.config + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + num_templates = template_batch['template_aatype'].shape[0] + num_res, _, query_num_channels = query_embedding.shape + + # Embed each template separately. 
+ template_embedder = SingleTemplateEmbedding(self.config, self.global_config) + def partial_template_embedder(template_aatype, + template_all_atom_positions, + template_all_atom_mask, + unsafe_key): + safe_key = prng.SafeKey(unsafe_key) + return template_embedder(query_embedding, + template_aatype, + template_all_atom_positions, + template_all_atom_mask, + padding_mask_2d, + multichain_mask_2d, + is_training, + dropout_scale, + safe_key) + + safe_key, unsafe_key = safe_key.split() + unsafe_keys = jax.random.split(unsafe_key._key, num_templates) + + def scan_fn(carry, x): + return carry + partial_template_embedder(*x), None + + scan_init = jnp.zeros((num_res, num_res, c.num_channels), + dtype=query_embedding.dtype) + summed_template_embeddings, _ = hk.scan( + scan_fn, scan_init, + (template_batch['template_aatype'], + template_batch['template_all_atom_positions'], + template_batch['template_all_atom_mask'], unsafe_keys)) + + embedding = summed_template_embeddings / num_templates + embedding = jax.nn.relu(embedding) + embedding = common_modules.Linear( + query_num_channels, + initializer='relu', + name='output_linear')(embedding) + + return embedding + + +class SingleTemplateEmbedding(hk.Module): + """Embed a single template.""" + + def __init__(self, config, global_config, name='single_template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + padding_mask_2d, multichain_mask_2d, is_training, dropout_scale, + safe_key): + """Build the single template embedding graph. + + Args: + query_embedding: (num_res, num_res, num_channels) - embedding of the + query sequence/msa. + template_aatype: [num_res] aatype for each template. + template_all_atom_positions: [num_res, 37, 3] atom positions for all + templates. + template_all_atom_mask: [num_res, 37] mask for each template. 
+ padding_mask_2d: Padding mask (Note: this doesn't care if a template + exists, unlike the template_pseudo_beta_mask). + multichain_mask_2d: A mask indicating intra-chain residue pairs, used + to mask out between chain distances/features when templates are for + single chains. + is_training: Are we in training mode. + safe_key: Random key generator. + + Returns: + A template embedding (num_res, num_res, num_channels). + """ + gc = self.global_config + c = self.config + assert padding_mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_channels = self.config.num_channels + + def construct_input(query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + multichain_mask_2d): + + # Compute distogram feature for the template. + template_positions, pseudo_beta_mask = modules.pseudo_beta_fn( + template_aatype, template_all_atom_positions, template_all_atom_mask) + pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] * + pseudo_beta_mask[None, :]) + pseudo_beta_mask_2d *= multichain_mask_2d + template_dgram = modules.dgram_from_positions( + template_positions, **self.config.dgram_features) + template_dgram *= pseudo_beta_mask_2d[..., None] + template_dgram = template_dgram.astype(dtype) + pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) + to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)] + + aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype) + to_concat.append((aatype[None, :, :], 1)) + to_concat.append((aatype[:, None, :], 1)) + + # Compute a feature representing the normalized vector between each + # backbone affine - i.e. in each residues local frame, what direction are + # each of the other residues. 
+ raw_atom_pos = template_all_atom_positions + + atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) + rigid, backbone_mask = folding_multimer.make_backbone_affine( + atom_pos, + template_all_atom_mask, + template_aatype) + points = rigid.translation + rigid_vec = rigid[:, None].inverse().apply_to_point(points) + unit_vector = rigid_vec.normalized() + unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] + + backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] + backbone_mask_2d *= multichain_mask_2d + unit_vector = [x*backbone_mask_2d for x in unit_vector] + + # Note that the backbone_mask takes into account C, CA and N (unlike + # pseudo beta mask which just needs CB) so we add both masks as features. + to_concat.extend([(x, 0) for x in unit_vector]) + to_concat.append((backbone_mask_2d, 0)) + + query_embedding = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='query_embedding_norm')( + query_embedding) + # Allow the template embedder to see the query embedding. Note this + # contains the position relative feature, so this is how the network knows + # which residues are next to each other. 
+ to_concat.append((query_embedding, 1)) + + act = 0 + + for i, (x, n_input_dims) in enumerate(to_concat): + + act += common_modules.Linear( + num_channels, + num_input_dims=n_input_dims, + initializer='relu', + name=f'template_pair_embedding_{i}')(x) + return act + + act = construct_input(query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + multichain_mask_2d) + + template_iteration = TemplateEmbeddingIteration( + c.template_pair_stack, gc, name='template_embedding_iteration') + + def template_iteration_fn(x): + act, safe_key = x + + safe_key, safe_subkey = safe_key.split() + act = template_iteration( + act=act, + pair_mask=padding_mask_2d, + is_training=is_training, + dropout_scale=dropout_scale, + safe_key=safe_subkey) + return (act, safe_key) + + if gc.use_remat: + template_iteration_fn = hk.remat(template_iteration_fn) + + safe_key, safe_subkey = safe_key.split() + template_stack = layer_stack.layer_stack( + c.template_pair_stack.num_block)( + template_iteration_fn) + act, safe_key = template_stack((act, safe_subkey)) + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='output_layer_norm')( + act) + return act + + +class TemplateEmbeddingIteration(hk.Module): + """Single Iteration of Template Embedding.""" + + def __init__(self, config, global_config, + name='template_embedding_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, pair_mask, is_training=True, dropout_scale=1.0, + safe_key=None): + """Build a single iteration of the template embedder. + + Args: + act: [num_res, num_res, num_channel] Input pairwise activations. + pair_mask: [num_res, num_res] padding mask. + is_training: Whether to run in training mode. + safe_key: Safe pseudo-random generator key. + + Returns: + [num_res, num_res, num_channel] tensor of activations. 
+ """ + c = self.config + gc = self.global_config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + dropout_wrapper_fn = functools.partial( + modules.dropout_wrapper, + is_training=is_training, + dropout_scale=dropout_scale, + global_config=gc) + + safe_key, *sub_keys = safe_key.split(20) + sub_keys = iter(sub_keys) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.Transition(c.pair_transition, gc, + name='pair_transition'), + act, + pair_mask, + safe_key=next(sub_keys)) + + return act + + +def template_embedding_1d(batch, num_channel): + """Embed templates into an (num_res, num_templates, num_channels) embedding. + + Args: + batch: A batch containing: + template_aatype, (num_templates, num_res) aatype for the templates. + template_all_atom_positions, (num_templates, num_residues, 37, 3) atom + positions for the templates. + template_all_atom_mask, (num_templates, num_residues, 37) atom mask for + each template. + num_channel: The number of channels in the output. + + Returns: + An embedding of shape (num_templates, num_res, num_channels) and a mask of + shape (num_templates, num_res). + """ + + # Embed the templates aatypes. 
+ aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) + + num_templates = batch['template_aatype'].shape[0] + all_chi_angles = [] + all_chi_masks = [] + for i in range(num_templates): + atom_pos = geometry.Vec3Array.from_array( + batch['template_all_atom_positions'][i, :, :, :]) + template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles( + atom_pos, + batch['template_all_atom_mask'][i, :, :], + batch['template_aatype'][i, :]) + all_chi_angles.append(template_chi_angles) + all_chi_masks.append(template_chi_mask) + chi_angles = jnp.stack(all_chi_angles, axis=0) + chi_mask = jnp.stack(all_chi_masks, axis=0) + + template_features = jnp.concatenate([ + aatype_one_hot, + jnp.sin(chi_angles) * chi_mask, + jnp.cos(chi_angles) * chi_mask, + chi_mask], axis=-1) + + template_mask = chi_mask[:, :, 0] + + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_single_embedding')( + template_features) + template_activations = jax.nn.relu(template_activations) + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_projection')( + template_activations) + return template_activations, template_mask diff --git a/colabdesign/af/alphafold/model/tf/data_transforms.py b/colabdesign/af/alphafold/model/tf/data_transforms.py index 699e4ccf..3e9c7e31 100644 --- a/colabdesign/af/alphafold/model/tf/data_transforms.py +++ b/colabdesign/af/alphafold/model/tf/data_transforms.py @@ -119,7 +119,7 @@ def squeeze_features(protein): for k in [ 'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix', 'resolution', - 'between_segment_residues', 'residue_index', 'template_all_atom_masks']: + 'between_segment_residues', 'residue_index', 'template_all_atom_mask']: if k in protein: final_dim = shape_helpers.shape_list(protein[k])[-1] if isinstance(final_dim, int) and final_dim == 1: @@ -318,7 +318,7 @@ def make_msa_mask(protein): return 
protein -def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): """Create pseudo beta features.""" is_gly = tf.equal(aatype, residue_constants.restype_order['G']) ca_idx = residue_constants.atom_order['CA'] @@ -328,9 +328,9 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) - if all_atom_masks is not None: + if all_atom_mask is not None: pseudo_beta_mask = tf.where( - is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]) pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32) return pseudo_beta, pseudo_beta_mask else: @@ -345,7 +345,7 @@ def make_pseudo_beta(protein, prefix=''): pseudo_beta_fn( protein['template_aatype' if prefix else 'all_atom_aatype'], protein[prefix + 'all_atom_positions'], - protein['template_all_atom_masks' if prefix else 'all_atom_mask'])) + protein['template_all_atom_mask' if prefix else 'all_atom_mask'])) return protein diff --git a/colabdesign/af/alphafold/model/tf/protein_features.py b/colabdesign/af/alphafold/model/tf/protein_features.py index eb58bd0c..6498e578 100644 --- a/colabdesign/af/alphafold/model/tf/protein_features.py +++ b/colabdesign/af/alphafold/model/tf/protein_features.py @@ -59,7 +59,7 @@ class FeatureType(enum.Enum): "template_all_atom_positions": (tf.float32, [ NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 ]), - "template_all_atom_masks": (tf.float32, [ + "template_all_atom_mask": (tf.float32, [ NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 ]), } diff --git a/colabdesign/af/alphafold/model/utils.py b/colabdesign/af/alphafold/model/utils.py index 8ed5361e..ae83ba06 100644 --- a/colabdesign/af/alphafold/model/utils.py +++ b/colabdesign/af/alphafold/model/utils.py @@ -30,10 +30,9 @@ def final_init(config): else: return 'linear' - def batched_gather(params, 
indices, axis=0, batch_dims=0): """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" - take_fn = lambda p, i: jnp.take(p, i, axis=axis) + take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode="clip") for _ in range(batch_dims): take_fn = jax.vmap(take_fn) return take_fn(params, indices) diff --git a/colabdesign/af/crop.py b/colabdesign/af/crop.py new file mode 100644 index 00000000..6e8050e1 --- /dev/null +++ b/colabdesign/af/crop.py @@ -0,0 +1,98 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from colabdesign.shared.utils import copy_dict +from colabdesign.af.alphafold.model import config + +class _af_crop: + def _crop(self): + ''' determine positions to crop ''' + (L, max_L, mode) = (sum(self._lengths), self._args["crop_len"], self._args["crop_mode"]) + + if max_L is None or max_L >= L: crop = False + elif self._args["copies"] > 1 and not self._args["repeat"]: crop = False + elif self.protocol in ["partial","binder"]: crop = False + elif mode == "dist" and not hasattr(self,"_dist"): crop = False + else: crop = True + + if crop: + if self.protocol == "fixbb": + self._tmp["cmap"] = self._dist < self.opt["cmap_cutoff"] + + if mode == "slide": + i = jax.random.randint(self.key(),[],0,(L-max_L)+1) + p = np.arange(i,i+max_L) + + if mode == "roll": + i = jax.random.randint(self.key(),[],0,L) + p = np.sort(np.roll(np.arange(L),L-i)[:max_L]) + + if mode == "dist": + i = jax.random.randint(self.key(),[],0,(L-max_L)+1) + p = np.sort(self._dist[i].argsort()[1:][:max_L]) + + if mode == "pair": + # pick random pair of interactig crops + max_L = max_L // 2 + + # pick first crop + i_range = np.append(np.arange(0,(L-2*max_L)+1),np.arange(max_L,(L-max_L)+1)) + i = jax.random.choice(self.key(),i_range,[]) + + # pick second crop + j_range = np.append(np.arange(0,(i-max_L)+1),np.arange(i+max_L,(L-max_L)+1)) + if "cmap" in self._tmp: + # if contact map defined, bias to interacting pairs + w = 
np.array([self._tmp["cmap"][i:i+max_L,j:j+max_L].sum() for j in j_range]) + 1e-8 + j = jax.random.choice(self.key(), j_range, [], p=w/w.sum()) + else: + j = jax.random.choice(self.key(), j_range, []) + + p = np.sort(np.append(np.arange(i,i+max_L),np.arange(j,j+max_L))) + + def callback(self): + # function to apply after run + cmap, pae = (np.array(self.aux[k]) for k in ["cmap","pae"]) + mask = np.isnan(pae) + + b = 0.9 + _pae = self._tmp.get("pae",np.full_like(pae, 31.0)) + self._tmp["pae"] = np.where(mask, _pae, (1-b)*pae + b*_pae) + + if self.protocol == "hallucination": + _cmap = self._tmp.get("cmap",np.zeros_like(cmap)) + self._tmp["cmap"] = np.where(mask, _cmap, (1-b)*cmap + b*_cmap) + + self.aux.update(self._tmp) + + else: + callback = None + p = np.arange(sum(self._lengths)) + + self.opt["crop_pos"] = p + return callback + +def crop_feat(feat, pos): + ''' + crop features to specified [pos]itions + ''' + if feat is None: return None + + def find(x,k): + i = [] + for j,y in enumerate(x): + if y == k: i.append(j) + return i + + shapes = config.CONFIG.data.eval.feat + NUM_RES = "num residues placeholder" + idx = {k:find(v,NUM_RES) for k,v in shapes.items()} + new_feat = copy_dict(feat) + for k in new_feat.keys(): + if k == "batch": + new_feat[k] = crop_feat(feat[k], pos) + if k in idx: + for i in idx[k]: new_feat[k] = jnp.take(new_feat[k], pos, i) + + return new_feat \ No newline at end of file diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index c23f2819..c1cd76b3 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -67,21 +67,15 @@ def restart(self, seed=None, optimizer="sgd", opt=None, weights=None, self._traj = {"log":[],"seq":[],"xyz":[],"plddt":[],"pae":[]} self._best, self._tmp = {}, {} - def run(self, backprop=True, callback=None): + def run(self, num_recycles=None, backprop=True, callback=None): '''run model to get outputs, losses and gradients''' - callbacks = [self._crop(), callback] + callbacks = [callback] + if 
self._args["use_crop"]: callbacks.append(self._crop()) # decide which model params to use - ns,ns_name = [],[] - count = {"openfold":0,"alphafold":0} - for n,name in enumerate(self._model_names): - if "openfold" in name: - if self._args["use_openfold"]: ns.append(n); ns_name.append(name); count["openfold"] += 1 - else: - if self._args["use_alphafold"]: ns.append(n); ns_name.append(name); count["alphafold"] += 1 - for k in count: - if self._args[f"use_{k}"] and count[k] == 0: print(f"ERROR: {k} params not found") + ns_name = self._model_names.copy() + ns = list(range(len(ns_name))) # sub select number of model params if self._args["models"] is not None: @@ -101,7 +95,7 @@ def run(self, backprop=True, callback=None): aux = [] for n in model_num: p = self._model_params[n] - aux.append(self._recycle(p, backprop=backprop)) + aux.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop)) aux = jax.tree_map(lambda *x: jnp.stack(x), *aux) # update aux @@ -118,16 +112,21 @@ def run(self, backprop=True, callback=None): if callback is not None: callback(self) # update log - self.aux["log"] = {**self.aux["losses"], "loss":self.aux["loss"], "ptm":self.aux["ptm"]} + self.aux["log"] = {**self.aux["losses"], "loss":self.aux["loss"], + "ptm":self.aux["ptm"], "i_ptm":self.aux["i_ptm"]} self.aux["log"].update({k:self.opt[k] for k in ["hard","soft","temp"]}) # compute sequence recovery if self.protocol in ["fixbb","partial"] or (self.protocol == "binder" and self._args["redesign"]): - if self.protocol == "partial" and "pos" in self.opt: + if self.protocol == "partial": aatype = self.aux["aatype"].argmax(-1)[...,self.opt["pos"]] else: aatype = self.aux["seq"]["pseudo"].argmax(-1) - self.aux["log"]["seqid"] = (aatype == self._wt_aatype).mean() + + mask = self._wt_aatype != -1 + true = self._wt_aatype[mask] + pred = aatype[...,mask] + self.aux["log"]["seqid"] = (true == pred).mean() self.aux["log"] = to_float(self.aux["log"]) 
self.aux["log"].update({"recycles":int(self.aux["num_recycles"]), @@ -144,66 +143,72 @@ def _single(self, model_params, backprop=True): aux.update({"loss":loss,"grad":grad}) return aux - def _recycle(self, model_params, backprop=True): + def _recycle(self, model_params, num_recycles=None, backprop=True): '''multiple passes through the model (aka recycle)''' - mode = self._args["recycle_mode"] + if num_recycles is None: + num_recycles = self.opt["num_recycles"] + if mode in ["backprop","add_prev"]: - # recycles compiled into model, only need single-pass - num_recycles = self.opt["num_recycles"] = self._cfg.model.num_recycle aux = self._single(model_params, backprop) else: - - # configure number of recycle to run - num_recycles = self.opt["num_recycles"] - if mode == "average": - # run recycles manually, average gradients - if "crop_pos" in self.opt: L = self.opt["crop_pos"].shape[0] - else: L = self._inputs["residue_index"].shape[-1] - self._inputs["prev"] = {'prev_msa_first_row': np.zeros([L,256]), - 'prev_pair': np.zeros([L,L,128]), - 'prev_pos': np.zeros([L,37,3])} - grad = [] - for _ in range(num_recycles+1): - aux = self._single(model_params, backprop) - grad.append(aux["grad"]) - self._inputs["prev"] = aux["prev"] - # average gradients across - aux["grad"] = jax.tree_map(lambda *x: jnp.stack(x).mean(0), *grad) + L = self._inputs["residue_index"].shape[0] + if self._args["use_crop"]: L = self.opt["crop_pos"].shape[0] - elif mode == "sample": - # randomly select number of recycles to run - self.set_opt(num_recycles=jax.random.randint(self.key(),[],0,num_recycles+1)) - aux = self._single(model_params, backprop) - (self.opt["num_recycles"],num_recycles) = (num_recycles,self.opt["num_recycles"]) + # intialize previous + self._inputs["prev"] = {'prev_msa_first_row': np.zeros([L,256]), + 'prev_pair': np.zeros([L,L,128]), + 'prev_pos': np.zeros([L,37,3])} + + # decide which layers to compute gradients for + cycles = (num_recycles + 1) + mask = [0] * cycles + if mode 
== "sample": mask[jax.random.randint(self.key(),[],0,cycles)] = 1 + if mode == "average": mask = [1/cycles] * cycles + if mode == "last": mask[-1] = 1 + if mode == "first": mask[0] = 1 - else: - aux = self._single(model_params, backprop) + # gather gradients across recycles + grad = [] + for m in mask: + if m == 0: + aux = self._single(model_params, backprop=False) + else: + aux = self._single(model_params, backprop) + grad.append(jax.tree_map(lambda x:x*m, aux["grad"])) + self._inputs["prev"] = aux["prev"] + + aux["grad"] = jax.tree_map(lambda *x: jnp.stack(x).sum(0), *grad) aux["num_recycles"] = num_recycles return aux - def step(self, lr_scale=1.0, backprop=True, repredict=False, - callback=None, save_best=False, verbose=1): + def step(self, lr_scale=1.0, num_recycles=None, backprop=True, + callback=None, stats_correct=False, save_best=False, verbose=1): '''do one step of gradient descent''' # run - self.run(backprop=backprop, callback=callback) + self.run(num_recycles=num_recycles, backprop=backprop, callback=callback) - # normalize gradient + # apply gradient g = self.aux["grad"]["seq"] - gn = jnp.linalg.norm(g,axis=(-1,-2),keepdims=True) - eff_len = (jnp.square(g).sum(-1,keepdims=True) > 0).sum(-2,keepdims=True) - self.aux["grad"]["seq"] *= jnp.sqrt(eff_len)/(gn+1e-7) + + # statistical correction - doi:10.1101/2022.04.29.490102 + if stats_correct: + g = g - g.sum(-2,keepdims=True) / eff_len + + # normalize gradient + gn = jnp.linalg.norm(g,axis=(-1,-2),keepdims=True) + self.aux["grad"]["seq"] = g * jnp.sqrt(eff_len)/(gn+1e-7) # set learning rate lr = self.opt["lr"] * lr_scale self.aux["grad"] = jax.tree_map(lambda x:x*lr, self.aux["grad"]) - # apply gradient + # update state/params self._state = self._update_fun(self._k, self.aux["grad"], self._state) self._params = self._get_params(self._state) @@ -211,7 +216,6 @@ def step(self, lr_scale=1.0, backprop=True, repredict=False, self._k += 1 # save results - if repredict: self.predict(models=None, verbose=False) 
self._save_results(save_best=save_best, verbose=verbose) def _update_traj(self): @@ -225,8 +229,10 @@ def _update_traj(self): def _print_log(self, print_str=None): keys = ["models","recycles","hard","soft","temp","seqid","loss", - "msa_ent","plddt","pae","helix","con","i_pae","i_con", - "sc_fape","sc_rmsd","dgram_cce","fape","ptm","rmsd"] + "seq_ent","mlm","plddt","pae","exp_res","con","i_con", + "sc_fape","sc_rmsd","dgram_cce","fape","ptm"] + if sum(self._lengths) > 1: keys.append("i_ptm") + keys.append("rmsd") print(dict_to_str(self.aux["log"], filt=self.opt["weights"], print_str=print_str, keys=keys, ok="rmsd")) @@ -235,92 +241,29 @@ def _save_best(self): if "metric" not in self._best or metric < self._best["metric"]: self._best.update({"metric":metric, "aux":self.aux}) - def clear_best(self): self._best = {} - def _save_results(self, save_best=False, verbose=True): self._update_traj() if save_best: self._save_best() if verbose and (self._k % verbose) == 0: self._print_log(f"{self._k}") - def _crop(self): - ''' determine positions to crop ''' - (L, max_L, mode) = (sum(self._lengths), self._args["crop_len"], self._args["crop_mode"]) - - if max_L is None or max_L >= L: crop = False - elif self._args["copies"] > 1 and not self._args["repeat"]: crop = False - elif self.protocol in ["partial","binder"]: crop = False - elif mode == "dist" and not hasattr(self,"_dist"): crop = False - else: crop = True + def predict(self, seq=None, num_recycles=None, num_models=None, models=None, verbose=True): + '''predict structure for input sequence (if provided)''' - if crop: - if self.protocol == "fixbb": - self._tmp["cmap"] = self._dist < self.opt["cmap_cutoff"] - - if mode == "slide": - i = jax.random.randint(self.key(),[],0,(L-max_L)+1) - p = np.arange(i,i+max_L) - - if mode == "roll": - i = jax.random.randint(self.key(),[],0,L) - p = np.sort(np.roll(np.arange(L),L-i)[:max_L]) - - if mode == "dist": - i = jax.random.randint(self.key(),[],0,(L-max_L)+1) - p = 
np.sort(self._dist[i].argsort()[1:][:max_L]) - - if mode == "pair": - # pick random pair of interactig crops - max_L = max_L // 2 - - # pick first crop - i_range = np.append(np.arange(0,(L-2*max_L)+1),np.arange(max_L,(L-max_L)+1)) - i = jax.random.choice(self.key(),i_range,[]) - - # pick second crop - j_range = np.append(np.arange(0,(i-max_L)+1),np.arange(i+max_L,(L-max_L)+1)) - if "cmap" in self._tmp: - # if contact map defined, bias to interacting pairs - w = np.array([self._tmp["cmap"][i:i+max_L,j:j+max_L].sum() for j in j_range]) + 1e-8 - j = jax.random.choice(self.key(), j_range, [], p=w/w.sum()) - else: - j = jax.random.choice(self.key(), j_range, []) - - p = np.sort(np.append(np.arange(i,i+max_L),np.arange(j,j+max_L))) - - def callback(self): - # function to apply after run - cmap, pae = (np.array(self.aux[k]) for k in ["cmap","pae"]) - mask = np.isnan(pae) - - b = 0.9 - _pae = self._tmp.get("pae",np.full_like(pae, 31.0)) - self._tmp["pae"] = np.where(mask, _pae, (1-b)*pae + b*_pae) - - if self.protocol == "hallucination": - _cmap = self._tmp.get("cmap",np.zeros_like(cmap)) - self._tmp["cmap"] = np.where(mask, _cmap, (1-b)*cmap + b*_cmap) - - self.aux.update(self._tmp) - - else: - callback = None - p = np.arange(sum(self._lengths)) - - self.opt["crop_pos"] = p - return callback - - def predict(self, seq=None, models=None, verbose=True): # save settings (opt, args, params) = (copy_dict(x) for x in [self.opt, self._args, self._params]) # set settings if seq is not None: self.set_seq(seq=seq, set_state=False) - if models is not None: self.set_opt(num_models=len(models) if isinstance(models,list) else 1) - self.set_opt(hard=True, dropout=False, crop=False, sample_models=False, models=models) + if models is None: + models = self._model_names if num_models is None else self._model_names[:num_models] + num_models = len(models) if isinstance(models,list) else 1 + self.set_opt(hard=True, dropout=False, sample_models=False, + models=models, num_models=num_models, + 
mlm_dropout=0.0, use_crop=False) # run - self.run(backprop=False) + self.run(num_recycles=num_recycles, backprop=False) if verbose: self._print_log("predict") # reset settings @@ -335,11 +278,12 @@ def design(self, iters=100, hard=0.0, e_hard=None, step=1.0, e_step=None, dropout=True, opt=None, weights=None, - repredict=False, backprop=True, callback=None, + mlm_dropout=0.05, num_recycles=None, + backprop=True, callback=None, save_best=False, verbose=1): # update options/settings (if defined) - self.set_opt(opt, dropout=dropout) + self.set_opt(opt, dropout=dropout, mlm_dropout=mlm_dropout) self.set_weights(weights) m = {"soft":[soft,e_soft],"temp":[temp,e_temp], @@ -358,7 +302,7 @@ def design(self, iters=100, # decay learning rate based on temperature lr_scale = step * ((1 - self.opt["soft"]) + (self.opt["soft"] * self.opt["temp"])) - self.step(lr_scale=lr_scale, backprop=backprop, repredict=repredict, + self.step(lr_scale=lr_scale, num_recycles=num_recycles, backprop=backprop, callback=callback, save_best=save_best, verbose=verbose) def design_logits(self, iters=100, **kwargs): @@ -376,34 +320,41 @@ def design_hard(self, iters=100, **kwargs): # --------------------------------------------------------------------------------- # experimental # --------------------------------------------------------------------------------- - - def design_2stage(self, soft_iters=100, temp_iters=100, hard_iters=10, - num_models=1, **kwargs): - '''two stage design (soft→hard)''' - self.set_opt(num_models=num_models, sample_models=True) # sample models - self.design_soft(soft_iters, **kwargs) - self.design_soft(temp_iters, e_temp=1e-2, **kwargs) - self.set_opt(num_models=len(self._model_params)) # use all models - self.design_hard(hard_iters, temp=1e-2, dropout=False, save_best=True, **kwargs) - def design_3stage(self, soft_iters=300, temp_iters=100, hard_iters=10, - num_models=1, **kwargs): + ramp_recycles=True, num_recycles=None, num_models=1, **kwargs): '''three stage design 
(logits→soft→hard)''' - self.set_opt(num_models=num_models, sample_models=True) # sample models - self.design_logits(soft_iters, e_soft=1, **kwargs) + + # set starting options + kwargs["num_recycles"] = num_recycles + self.set_opt(num_models=num_models, sample_models=True) + + # logits -> softmax(logits/1.0) + if ramp_recycles and self._args["recycle_mode"] not in ["add_prev","backprop"]: + R = self.opt["num_recycles"] if num_recycles is None else num_recycles + p = 1.0 / (R + 1) + iters = soft_iters // (R + 1) + for r in range(R + 1): + kwargs["num_recycles"] = r + self.design_logits(iters, soft=r*p, e_soft=(r+1)*p, **kwargs) + else: + self.design_logits(soft_iters, e_soft=1, **kwargs) + + # softmax(logits/1.0) -> softmax(logits/0.01) self.design_soft(temp_iters, e_temp=1e-2, **kwargs) + self.set_opt(num_models=len(self._model_params)) # use all models - self.design_hard(hard_iters, temp=1e-2, dropout=False, save_best=True, **kwargs) + self.design_hard(hard_iters, temp=1e-2, dropout=False, mlm_dropout=0.0, save_best=True, + num_recycles=kwargs["num_recycles"], verbose=kwargs.get("verbose",1)) - def design_semigreedy(self, iters=100, tries=20, num_models=1, + def design_semigreedy(self, iters=100, tries=20, num_recycles=None, num_models=1, use_plddt=True, save_best=True, verbose=1): '''semigreedy search''' - self.set_opt(hard=True, dropout=False, crop=False, + self.set_opt(hard=True, dropout=False, use_crop=False, num_models=num_models, sample_models=False) if self._k == 0: - self.run(backprop=False) + self.run(num_recycles=num_recycles, backprop=False) def mut(seq, plddt=None): '''mutate random position''' @@ -438,7 +389,7 @@ def get_seq(): buff = [] for _ in range(tries): self.set_seq(seq=mut(seq, plddt), set_state=False) - self.run(backprop=False) + self.run(num_recycles=num_recycles, backprop=False) buff.append({"aux":self.aux, "seq":self._params["seq"]}) # accept best diff --git a/colabdesign/af/inputs.py b/colabdesign/af/inputs.py index 0292991e..39a07096 
100644 --- a/colabdesign/af/inputs.py +++ b/colabdesign/af/inputs.py @@ -5,7 +5,7 @@ from colabdesign.shared.utils import copy_dict from colabdesign.shared.model import soft_seq from colabdesign.af.alphafold.common import residue_constants -from colabdesign.af.alphafold.model import model +from colabdesign.af.alphafold.model import model, config ############################################################################ # AF_INPUTS - functions for modifying inputs before passing to alphafold @@ -14,14 +14,16 @@ class _af_inputs: def _get_seq(self, inputs, params, opt, aux, key): '''get sequence features''' seq = soft_seq(params["seq"], opt, key) - if "pos" in opt and "fix_seq" in opt: - seq_ref = jax.nn.one_hot(self._wt_aatype,20) - p = opt["pos"] - if self.protocol == "partial": - fix_seq = lambda x:jnp.where(opt["fix_seq"],x.at[...,p,:].set(seq_ref),x) + if "fix_pos" in opt: + if "pos" in self.opt: + seq_ref = jax.nn.one_hot(self._wt_aatype_sub,20) + p = opt["pos"][opt["fix_pos"]] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref) else: - fix_seq = lambda x:jnp.where(opt["fix_seq"],x.at[...,p,:].set(seq_ref[...,p,:]),x) - seq = jax.tree_map(fix_seq,seq) + seq_ref = jax.nn.one_hot(self._wt_aatype,20) + p = opt["fix_pos"] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref[...,p,:]) + seq = jax.tree_map(fix_seq, seq) aux.update({"seq":seq, "seq_pseudo":seq["pseudo"]}) # protocol specific modifications to seq features @@ -38,66 +40,68 @@ def _get_seq(self, inputs, params, opt, aux, key): def _update_template(self, inputs, opt, key): ''''dynamically update template features''' + + o = opt["template"] - # aatype = is used to define template's CB coordinates (CA in case of glycine) - # template_aatype = is used as template's sequence + # enable templates + inputs["template_mask"] = inputs["template_mask"].at[:].set(1) batch = inputs["batch"] - if self.protocol in ["partial","fixbb","binder"]: - L = batch["aatype"].shape[0] - if self.protocol in ["partial","fixbb"]: - rt = 
opt["rm_template_seq"] - aatype = jnp.where(rt,0,batch["aatype"]) - template_aatype = jnp.where(rt,opt["template"]["aatype"],batch["aatype"]) + if self.protocol in ["partial","fixbb","binder"]: + + L = batch["aatype"].shape[0] + # decide which position to remove sequence and/or sidechains + rm = jnp.logical_or(o["rm_seq"],o["rm_sc"]) + rm_seq = jnp.full(L,o["rm_seq"]) + rm_sc = jnp.full(L,rm) if self.protocol == "binder": - if self._args["redesign"]: - rt = opt["rm_template_seq"] - aatype = jnp.where(rt,batch["aatype"].at[self._target_len:].set(0),batch["aatype"]) - template_aatype = jnp.where(rt,batch["aatype"].at[self._target_len:].set(opt["template"]["aatype"]),batch["aatype"]) - else: - aatype = template_aatype = batch["aatype"] - + rm_seq = rm_seq.at[:self._target_len].set(False) + rm_sc = rm_sc.at[:self._target_len].set(False) + + # aatype = is used to define template's CB coordinates (CA in case of glycine) + # template_aatype = is used as template's sequence + aatype = jnp.where(rm_seq,0,batch["aatype"]) + template_aatype = jnp.where(rm_seq,21,batch["aatype"]) + # get pseudo-carbon-beta coordinates (carbon-alpha for glycine) - pb, pb_mask = model.modules.pseudo_beta_fn(aatype, - batch["all_atom_positions"], - batch["all_atom_mask"]) + cb, cb_mask = model.modules.pseudo_beta_fn(aatype, batch["all_atom_positions"], batch["all_atom_mask"]) # define template features template_feats = {"template_aatype": template_aatype, "template_all_atom_positions": batch["all_atom_positions"], - "template_all_atom_masks": batch["all_atom_mask"], - "template_pseudo_beta": pb, - "template_pseudo_beta_mask": pb_mask} + "template_all_atom_mask": batch["all_atom_mask"], + "template_pseudo_beta": cb, + "template_pseudo_beta_mask": cb_mask} - # protocol specific template injection - for k,v in template_feats.items(): - if self.protocol == "binder": - n = self._target_len - inputs[k] = inputs[k].at[:,0,:n].set(v[:n]) - inputs[k] = inputs[k].at[:,-1,n:].set(v[n:]) - - if self.protocol 
== "fixbb": - inputs[k] = inputs[k].at[:,0].set(v) + # inject template features + if self.protocol == "partial": + pos = opt["pos"] + if self._args["repeat"] or self._args["homooligomer"]: + C,L = self._args["copies"], self._len + pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + for k,v in template_feats.items(): if self.protocol == "partial": - inputs[k] = inputs[k].at[:,0,opt["pos"]].set(v) + inputs[k] = inputs[k].at[0,pos].set(v) + else: + inputs[k] = inputs[k].at[0].set(v) - if k == "template_all_atom_masks": - rt = jnp.logical_or(opt["rm_template_seq"],opt["rm_template_sc"]) - if self.protocol == "binder": - inputs[k] = jnp.where(rt,inputs[k].at[:,-1,n:,5:].set(0),inputs[k]) + # remove sidechains (mask anything beyond CB) + if k == "template_all_atom_mask": + if self.protocol == "partial": + inputs[k] = inputs[k].at[:,pos,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][:,pos,5:])) else: - inputs[k] = jnp.where(rt,inputs[k].at[:,0,:,5:].set(0),inputs[k]) + inputs[k] = inputs[k].at[:,:,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][:,:,5:])) # dropout template input features - L = inputs["template_aatype"].shape[2] + L = inputs["template_aatype"].shape[1] n = self._target_len if self.protocol == "binder" else 0 - pos_mask = jax.random.bernoulli(key, 1-opt["template"]["dropout"],(L,)) - inputs["template_all_atom_masks"] = inputs["template_all_atom_masks"].at[:,:,n:].multiply(pos_mask[n:,None]) - inputs["template_pseudo_beta_mask"] = inputs["template_pseudo_beta_mask"].at[:,:,n:].multiply(pos_mask[n:]) + pos_mask = jax.random.bernoulli(key, 1-o["dropout"],(L,)) + inputs["template_all_atom_mask"] = inputs["template_all_atom_mask"].at[:,n:].multiply(pos_mask[n:,None]) + inputs["template_pseudo_beta_mask"] = inputs["template_pseudo_beta_mask"].at[:,n:].multiply(pos_mask[n:]) -def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, msa_input=None): +def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None): '''update the sequence 
features''' if seq_1hot is None: seq_1hot = seq @@ -107,12 +111,13 @@ def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, msa_input=None): seq_pssm = jnp.pad(seq_pssm,[[0,0],[0,0],[0,22-seq_pssm.shape[-1]]]) msa_feat = jnp.zeros_like(inputs["msa_feat"]).at[...,0:22].set(seq_1hot).at[...,25:47].set(seq_pssm) - if seq.ndim == 3: - target_feat = jnp.zeros_like(inputs["target_feat"]).at[...,1:21].set(seq[0,...,:20]) - else: - target_feat = jnp.zeros_like(inputs["target_feat"]).at[...,1:21].set(seq[...,:20]) + + if mlm is not None: + X = jax.nn.one_hot(22,23) + X = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X) + msa_feat = jnp.where(mlm[None,:,None],X,msa_feat) - inputs.update({"target_feat":target_feat,"msa_feat":msa_feat}) + inputs.update({"msa_feat":msa_feat}) def update_aatype(aatype, inputs): if jnp.issubdtype(aatype.dtype, jnp.integer): @@ -147,28 +152,4 @@ def expand_copies(x, copies, block_diag=True): y = (seq + gap_seq).swapaxes(0,1).reshape(-1,L,22) return jnp.concatenate([x[:1],y],0) else: - return x - -def crop_feat(feat, pos, cfg, add_batch=True): - ''' - crop features to specified [pos]itions - ''' - if feat is None: return None - - def find(x,k): - i = [] - for j,y in enumerate(x): - if y == k: i.append(j) - return i - - shapes = cfg.data.eval.feat - NUM_RES = "num residues placeholder" - idx = {k:find(v,NUM_RES) for k,v in shapes.items()} - new_feat = copy_dict(feat) - for k in new_feat.keys(): - if k == "batch": - new_feat[k] = crop_feat(feat[k], pos, cfg, add_batch=False) - if k in idx: - for i in idx[k]: new_feat[k] = jnp.take(new_feat[k], pos, i + add_batch) - - return new_feat \ No newline at end of file + return x \ No newline at end of file diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py index 7f98e789..f95b232d 100644 --- a/colabdesign/af/loss.py +++ b/colabdesign/af/loss.py @@ -2,61 +2,73 @@ import jax.numpy as jnp import numpy as np -from colabdesign.shared.utils import Key +from 
colabdesign.shared.utils import Key, copy_dict from colabdesign.shared.protein import jnp_rmsd_w, _np_kabsch, _np_rmsd, _np_get_6D_loss from colabdesign.af.alphafold.model import model, folding, all_atom -from colabdesign.af.alphafold.common import confidence_jax +from colabdesign.af.alphafold.common import confidence_jax, residue_constants #################################################### # AF_LOSS - setup loss function #################################################### + class _af_loss: # protocol specific loss functions - def _loss_hallucination(self, inputs, outputs, opt, aux): - plddt_prob = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) - plddt_loss = (plddt_prob * jnp.arange(plddt_prob.shape[-1])[::-1]).mean(-1) - aux["losses"]["plddt"] = plddt_loss.mean() - self._get_pairwise_loss(inputs, outputs, opt, aux) - def _loss_fixbb(self, inputs, outputs, opt, aux): '''get losses''' - plddt_prob = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) - plddt_loss = (plddt_prob * jnp.arange(plddt_prob.shape[-1])[::-1]).mean(-1) - self._get_pairwise_loss(inputs, outputs, opt, aux) - - copies = self._args["copies"] - if self._args["repeat"] or not self._args["homooligomer"]: copies = 1 - + copies = self._args["copies"] if self._args["homooligomer"] else 1 # rmsd loss aln = get_rmsd_loss(inputs, outputs, copies=copies) - rmsd, aux["atom_positions"] = aln["rmsd"], aln["align"](aux["atom_positions"]) - - # dgram loss - aatype = inputs["aatype"][0] - dgram_cce = get_dgram_loss(inputs, outputs, copies=copies, aatype=aatype) + aux["atom_positions"] = aln["align"](aux["atom_positions"]) - aux["losses"].update({"fape": get_fape_loss(inputs, outputs, model_config=self._cfg), - "rmsd": rmsd, "dgram_cce": dgram_cce, "plddt":plddt_loss.mean()}) + # supervised losses + aux["losses"].update({ + "fape": get_fape_loss(inputs, outputs, copies=copies, clamp=opt["fape_cutoff"]), + "dgram_cce": get_dgram_loss(inputs, outputs, copies=copies, aatype=inputs["aatype"]), + "rmsd": 
aln["rmsd"], + }) + + # unsupervised losses + self._loss_unsupervised(inputs, outputs, opt, aux) def _loss_binder(self, inputs, outputs, opt, aux): '''get losses''' - plddt_prob = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) - plddt_loss = (plddt_prob * jnp.arange(plddt_prob.shape[-1])[::-1]).mean(-1) - aux["losses"]["plddt"] = plddt_loss[...,self._target_len:].mean() - self._get_pairwise_loss(inputs, outputs, opt, aux, interface=True) - + zeros = jnp.zeros(sum(self._lengths)) + binder_id = zeros.at[self._target_len:].set(1) + if "hotspot" in opt: + target_id = zeros.at[opt["hotspot"]].set(1) + else: + target_id = zeros.at[:self._target_len].set(1) + + # unsupervised losses + aux["losses"].update({ + "plddt": get_plddt_loss(outputs, mask_1d=binder_id), # plddt over binder + "exp_res": get_exp_res_loss(outputs, mask_1d=binder_id), + "pae": get_pae_loss(outputs, mask_1d=binder_id), # pae over binder + interface + "con": get_con_loss(inputs, outputs, opt["con"], mask_1d=binder_id, mask_1b=binder_id), + # interface + "i_con": get_con_loss(inputs, outputs, opt["i_con"], mask_1d=binder_id, mask_1b=target_id), + "i_pae": get_pae_loss(outputs, mask_1d=binder_id, mask_1b=target_id), + }) + + # supervised losses if self._args["redesign"]: + aln = get_rmsd_loss(inputs, outputs, L=self._target_len, include_L=False) align_fn = aln["align"] - - fape = get_fape_loss(inputs, outputs, model_config=self._cfg) # compute cce of binder + interface - aatype = inputs["aatype"][0] - cce = get_dgram_loss(inputs, outputs, aatype=aatype, return_cce=True) + aatype = inputs["aatype"] + cce = get_dgram_loss(inputs, outputs, aatype=aatype, return_mtx=True) + + # compute fape + fape = get_fape_loss(inputs, outputs, clamp=opt["fape_cutoff"], return_mtx=True) + + aux["losses"].update({ + "rmsd": aln["rmsd"], + "dgram_cce": cce[self._target_len:,:].mean(), + "fape": fape[self._target_len:,:].mean() + }) - aux["losses"].update({"rmsd":aln["rmsd"], "fape":fape, 
"dgram_cce":cce[self._target_len:,:].mean()}) - else: align_fn = get_rmsd_loss(inputs, outputs, L=self._target_len)["align"] @@ -64,157 +76,100 @@ def _loss_binder(self, inputs, outputs, opt, aux): def _loss_partial(self, inputs, outputs, opt, aux): '''get losses''' - batch = inputs["batch"] - plddt_prob = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) - plddt_loss = (plddt_prob * jnp.arange(plddt_prob.shape[-1])[::-1]).mean(-1) - aux["losses"]["plddt"] = plddt_loss.mean() - self._get_pairwise_loss(inputs, outputs, opt, aux) - - def sub(x, p, axis=0): - fn = lambda y:jnp.take(y,p,axis) - # fn = lambda y:jnp.tensordot(p,y,(-1,axis)).swapaxes(axis,0) - return jax.tree_map(fn, x) - pos = opt["pos"] - aatype = inputs["aatype"][0] - _config = self._cfg.model.heads.structure_module - - # dgram - dgram = sub(sub(outputs["distogram"]["logits"],pos),pos,1) - if aatype is not None: aatype = sub(aatype,pos,0) - aux["losses"]["dgram_cce"] = get_dgram_loss(inputs, pred=dgram, copies=1, aatype=aatype) - - # rmsd - true = batch["all_atom_positions"] - pred = sub(outputs["structure_module"]["final_atom_positions"],pos) - aln = _get_rmsd_loss(true[:,1], pred[:,1]) - aux["losses"]["rmsd"] = aln["rmsd"] - - # fape - fape_loss = {"loss":0.0} - struct = outputs["structure_module"] - traj = {"traj":sub(struct["traj"],pos,-2)} - folding.backbone_loss(fape_loss, batch, traj, _config) - aux["losses"]["fape"] = fape_loss["loss"] + if self._args["repeat"] or self._args["homooligomer"]: + C,L = self._args["copies"], self._len + pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + + def sub(x, axis=0): return jnp.take(x,pos,axis) + + copies = self._args["copies"] if self._args["homooligomer"] else 1 + aatype = sub(inputs["aatype"]) + dgram = {"logits":sub(sub(outputs["distogram"]["logits"]),1), + "bin_edges":outputs["distogram"]["bin_edges"]} + atoms = sub(outputs["structure_module"]["final_atom_positions"]) + + I = {"aatype": aatype, "batch": inputs["batch"]} + O = 
{"distogram": dgram, "structure_module": {"final_atom_positions": atoms}} + aln = get_rmsd_loss(I, O, copies=copies) + + # supervised losses + aux["losses"].update({ + "dgram_cce": get_dgram_loss(I, O, copies=copies, aatype=I["aatype"]), + "fape": get_fape_loss(I, O, copies=copies, clamp=opt["fape_cutoff"]), + "rmsd": aln["rmsd"], + }) + + # unsupervised losses + self._loss_unsupervised(inputs, outputs, opt, aux) # sidechain specific losses - if self._args["use_sidechains"]: - # sc_fape - pred_pos = sub(struct["final_atom14_positions"],pos) - sc_struct = {**folding.compute_renamed_ground_truth(self._sc["batch"], pred_pos), - "sidechains":{k: sub(struct["sidechains"][k],pos,1) for k in ["frames","atom_pos"]}} - - aux["losses"]["sc_fape"] = folding.sidechain_loss(batch, sc_struct, _config)["loss"] + if self._args["use_sidechains"] and copies == 1: + + struct = outputs["structure_module"] + pred_pos = sub(struct["final_atom14_positions"]) + true_pos = all_atom.atom37_to_atom14(inputs["batch"]["all_atom_positions"], self._sc["batch"]) # sc_rmsd - true_pos = all_atom.atom37_to_atom14(batch["all_atom_positions"], self._sc["batch"]) - aln = get_sc_rmsd(true_pos, pred_pos, self._sc["pos"]) + aln = _get_sc_rmsd_loss(true_pos, pred_pos, self._sc["pos"]) aux["losses"]["sc_rmsd"] = aln["rmsd"] + + # sc_fape + if not self._args["use_multimer"]: + sc_struct = {**folding.compute_renamed_ground_truth(self._sc["batch"], pred_pos), + "sidechains":{k: sub(struct["sidechains"][k],1) for k in ["frames","atom_pos"]}} + batch = {**inputs["batch"], + **all_atom.atom37_to_frames(**inputs["batch"])} + aux["losses"]["sc_fape"] = folding.sidechain_loss(batch, sc_struct, + self._cfg.model.heads.structure_module)["loss"] + + else: + # TODO + print("ERROR: 'sc_fape' not currently supported for 'multimer' mode") + aux["losses"]["sc_fape"] = 0.0 # align final atoms aux["atom_positions"] = aln["align"](aux["atom_positions"]) - def _get_pairwise_loss(self, inputs, outputs, opt, aux, 
interface=False): - '''get pairwise loss features''' - - # decide on what offset to use - if "offset" in inputs: - offset = inputs["offset"][0] - else: - idx = inputs["residue_index"][0] - offset = idx[:,None] - idx[None,:] - - # pae loss - pae_prob = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"]) - pae = (pae_prob * jnp.arange(pae_prob.shape[-1])).mean(-1) + def _loss_hallucination(self, inputs, outputs, opt, aux): + # unsupervised losses + self._loss_unsupervised(inputs, outputs, opt, aux) + + def _loss_unsupervised(self, inputs, outputs, opt, aux): + + # define masks + mask_1d = jnp.ones_like(inputs["asym_id"]) + if "pos" in opt: + C,L = self._args["copies"], self._len + pos = opt["pos"] + if C > 1: pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + mask_1d = mask_1d.at[pos].set(0) - # define distogram - dgram = outputs["distogram"]["logits"] - dgram_bins = jnp.append(0,outputs["distogram"]["bin_edges"]) - if not interface: - aux["losses"].update({"con":get_con_loss(dgram, dgram_bins, offset=offset, **opt["con"]).mean(), - "helix":get_helix_loss(dgram, dgram_bins, offset=offset, **opt["con"]), - "pae":pae.mean()}) - else: - # split pae/con into inter/intra - if self.protocol == "binder": - (L,H) = (self._target_len, opt.get("pos",None)) - else: - (L,H) = (self._len, None) - - def split_feats(v): - '''split pairwise features into intra and inter features''' - if v is None: - return None,None - - (aa,bb) = (v[:L,:L],v[L:,L:]) - if H is None: - (ab,ba) = (v[:L,L:],v[L:,:L]) - else: - (ab,ba) = (v[H, L:],v[L:, H]) - - abba = (ab + ba.swapaxes(0,1)) / 2 - if self.protocol == "binder": - return bb,abba.swapaxes(0,1) - else: - return aa,abba - - x_offset, ix_offset = split_feats(jnp.abs(offset)) - for k,v in zip(["pae","con"], [pae,dgram]): - x, ix = split_feats(v) - if k == "con": - aux["losses"]["helix"] = get_helix_loss(x, dgram_bins, x_offset) - x = get_con_loss(x, dgram_bins, offset=x_offset, **opt["con"]) - ix = 
get_con_loss(ix, dgram_bins, offset=ix_offset, **opt["i_con"]) - - aux["losses"].update({k:x.mean(),f"i_{k}":ix.mean()}) + mask_2d = inputs["asym_id"][:,None] == inputs["asym_id"][None,:] + masks = {"mask_1d":mask_1d, + "mask_2d":mask_2d} + + # define losses + losses = { + "exp_res": get_exp_res_loss(outputs, mask_1d=mask_1d), + "plddt": get_plddt_loss(outputs, mask_1d=mask_1d), + "pae": get_pae_loss(outputs, **masks), + "con": get_con_loss(inputs, outputs, opt["con"], **masks) + } + + # define losses at interface + if self._args["copies"] > 1 and not self._args["repeat"]: + masks = {"mask_1d": mask_1d if self._args["homoligomer"] else jnp.ones_like(mask_1d), + "mask_2d": mask_2d == False} + losses.update({ + "i_pae": get_pae_loss(outputs, **masks), + "i_con": get_con_loss(inputs, outputs, opt["i_con"], **masks), + }) + + aux["losses"].update(losses) ##################################################################################### -##################################################################################### -def get_pw_con_loss(dgram, dgram_bins, cutoff=None, binary=True): - '''dgram to contacts''' - if cutoff is None: cutoff = dgram_bins[-1] - bins = dgram_bins < cutoff - px = jax.nn.softmax(dgram) - px_ = jax.nn.softmax(dgram - 1e7 * (1-bins)) - # binary/cateogorical cross-entropy - con_loss_cat_ent = -(px_ * jax.nn.log_softmax(dgram)).sum(-1) - con_loss_bin_ent = -jnp.log((bins * px + 1e-8).sum(-1)) - return jnp.where(binary, con_loss_bin_ent, con_loss_cat_ent) - -def get_con_loss(dgram, dgram_bins, cutoff=None, binary=True, - num=1, seqsep=0, offset=None): - '''convert distogram into contact loss''' - x = get_pw_con_loss(dgram, dgram_bins, cutoff, binary) - a,b = x.shape - if offset is None: - mask = jnp.abs(jnp.arange(a)[:,None] - jnp.arange(b)[None,:]) >= seqsep - else: - mask = jnp.abs(offset) >= seqsep - x = jnp.sort(jnp.where(mask,x,jnp.nan)) - k_mask = (jnp.arange(b) < num) * (jnp.isnan(x) == False) - return jnp.where(k_mask,x,0.0).sum(-1) / 
(k_mask.sum(-1) + 1e-8) - -def get_helix_loss(dgram, dgram_bins, offset=None, **kwargs): - '''helix bias loss''' - x = get_pw_con_loss(dgram, dgram_bins, cutoff=6.0, binary=True) - if offset is None: - return jnp.diagonal(x,3).mean() - else: - mask = offset == 3 - return jnp.where(mask,x,0.0).sum() / (mask.sum() + 1e-8) - -def get_contact_map(outputs, dist=8.0): - '''get contact map from distogram''' - dist_logits = outputs["distogram"]["logits"] - dist_bins = jax.numpy.append(0,outputs["distogram"]["bin_edges"]) - dist_mtx = dist_bins[dist_logits.argmax(-1)] - return (jax.nn.softmax(dist_logits) * (dist_bins < dist)).sum(-1) - -#################### -# confidence metrics -#################### def get_plddt(outputs): logits = outputs["predicted_lddt"]["logits"] num_bins = logits.shape[-1] @@ -231,52 +186,190 @@ def get_pae(outputs): bin_centers = jnp.append(bin_centers,bin_centers[-1]+step) return (prob*bin_centers).sum(-1) -def get_ptm(outputs): +def get_ptm(inputs, outputs, interface=False): pae = outputs["predicted_aligned_error"] - return confidence_jax.predicted_tm_score_jax(**pae) + if "asym_id" not in pae: + pae["asym_id"] = inputs["asym_id"] + return confidence_jax.predicted_tm_score_jax(**pae, interface=interface) + +def get_contact_map(outputs, dist=8.0): + '''get contact map from distogram''' + dist_logits = outputs["distogram"]["logits"] + dist_bins = jax.numpy.append(0,outputs["distogram"]["bin_edges"]) + dist_mtx = dist_bins[dist_logits.argmax(-1)] + return (jax.nn.softmax(dist_logits) * (dist_bins < dist)).sum(-1) + +#################### +# confidence metrics +#################### +def mask_loss(x, mask=None, mask_grad=False): + if mask is None: + return x.mean() + else: + x_masked = (x * mask).sum() / (1e-8 + mask.sum()) + if mask_grad: + return jax.lax.stop_gradient(x.mean() - x_masked) + x_masked + else: + return x_masked + +def get_exp_res_loss(outputs, mask_1d=None): + p = jax.nn.sigmoid(outputs["experimentally_resolved"]["logits"]) + p = 1 - 
p[...,residue_constants.atom_order["CA"]] + return mask_loss(p, mask_1d) + +def get_plddt_loss(outputs, mask_1d=None): + p = jax.nn.softmax(outputs["predicted_lddt"]["logits"]) + p = (p * jnp.arange(p.shape[-1])[::-1]).mean(-1) + return mask_loss(p, mask_1d) + +def get_pae_loss(outputs, mask_1d=None, mask_1b=None, mask_2d=None): + p = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"]) + p = (p * jnp.arange(p.shape[-1])).mean(-1) + p = (p + p.T)/2 + L = p.shape[0] + if mask_1d is None: mask_1d = jnp.ones(L) + if mask_1b is None: mask_1b = jnp.ones(L) + if mask_2d is None: mask_2d = jnp.ones((L,L)) + mask_2d = mask_2d * mask_1d[:,None] * mask_1b[None,:] + return mask_loss(p, mask_2d) + +def get_con_loss(inputs, outputs, opt, + mask_1d=None, mask_1b=None, mask_2d=None): + + # get top k + def min_k(x, k=1, mask=None): + y = jnp.sort(x if mask is None else jnp.where(mask,x,jnp.nan)) + k_mask = jnp.logical_and(jnp.arange(y.shape[-1]) < k, jnp.isnan(y) == False) + return jnp.where(k_mask,y,0).sum(-1) / (k_mask.sum(-1) + 1e-8) + + # decide on what offset to use + if "offset" in inputs: + offset = inputs["offset"] + else: + idx = inputs["residue_index"].flatten() + offset = idx[:,None] - idx[None,:] + + # define distogram + dgram = outputs["distogram"]["logits"] + dgram_bins = jnp.append(0,outputs["distogram"]["bin_edges"]) + + p = _get_con_loss(dgram, dgram_bins, cutoff=opt["cutoff"], binary=opt["binary"]) + if "seqsep" in opt: + m = jnp.abs(offset) >= opt["seqsep"] + else: + m = jnp.ones_like(offset) + + # mask results + if mask_1d is None: mask_1d = jnp.ones(m.shape[0]) + if mask_1b is None: mask_1b = jnp.ones(m.shape[0]) + + if mask_2d is None: + m = jnp.logical_and(m, mask_1b) + else: + m = jnp.logical_and(m, mask_2d) + + p = min_k(p, opt["num"], m) + return min_k(p, opt["num_pos"], mask_1d) + +def _get_con_loss(dgram, dgram_bins, cutoff=None, binary=True): + '''dgram to contacts''' + if cutoff is None: cutoff = dgram_bins[-1] + bins = dgram_bins < cutoff + 
px = jax.nn.softmax(dgram) + px_ = jax.nn.softmax(dgram - 1e7 * (1-bins)) + # binary/cateogorical cross-entropy + con_loss_cat_ent = -(px_ * jax.nn.log_softmax(dgram)).sum(-1) + con_loss_bin_ent = -jnp.log((bins * px + 1e-8).sum(-1)) + return jnp.where(binary, con_loss_bin_ent, con_loss_cat_ent) + +def _get_helix_loss(dgram, dgram_bins, offset=None, **kwargs): + '''helix bias loss''' + x = _get_con_loss(dgram, dgram_bins, cutoff=6.0, binary=True) + if offset is None: + return jnp.diagonal(x,3).mean() + else: + mask = offset == 3 + return jnp.where(mask,x,0.0).sum() / (mask.sum() + 1e-8) #################### # loss functions #################### -def get_dgram_loss(inputs, outputs=None, copies=1, aatype=None, pred=None, return_cce=False): +def get_dgram_loss(inputs, outputs, copies=1, aatype=None, return_mtx=False): batch = inputs["batch"] # gather features if aatype is None: aatype = batch["aatype"] - if pred is None: pred = outputs["distogram"]["logits"] + pred = outputs["distogram"]["logits"] # get true features x, weights = model.modules.pseudo_beta_fn(aatype=aatype, all_atom_positions=batch["all_atom_positions"], - all_atom_masks=batch["all_atom_mask"]) + all_atom_mask=batch["all_atom_mask"]) dm = jnp.square(x[:,None]-x[None,:]).sum(-1,keepdims=True) bin_edges = jnp.linspace(2.3125, 21.6875, pred.shape[-1] - 1) true = jax.nn.one_hot((dm > jnp.square(bin_edges)).sum(-1), pred.shape[-1]) - return _get_dgram_loss(true, pred, weights, copies, return_cce=return_cce) - -def _get_dgram_loss(true, pred, weights=None, copies=1, return_cce=False): + def loss_fn(t,p,m): + cce = -(t*jax.nn.log_softmax(p)).sum(-1) + return cce, (cce*m).sum((-1,-2))/(m.sum((-1,-2))+1e-8) + return _get_pw_loss(true, pred, loss_fn, weights=weights, copies=copies, return_mtx=return_mtx) + +def get_fape_loss(inputs, outputs, copies=1, clamp=10.0, return_mtx=False): + + def robust_norm(x, axis=-1, keepdims=False, eps=1e-8): + return jnp.sqrt(jnp.square(x).sum(axis=axis, keepdims=keepdims) + eps) 
+ + def get_R(N, CA, C): + (v1,v2) = (C-CA, N-CA) + e1 = v1 / robust_norm(v1, axis=-1, keepdims=True) + c = jnp.einsum('li, li -> l', e1, v2)[:,None] + e2 = v2 - c * e1 + e2 = e2 / robust_norm(e2, axis=-1, keepdims=True) + e3 = jnp.cross(e1, e2, axis=-1) + return jnp.concatenate([e1[:,:,None], e2[:,:,None], e3[:,:,None]], axis=-1) + + def get_ij(R,T): + return jnp.einsum('rji,rsj->rsi',R,T[None,:]-T[:,None]) + + def loss_fn(t,p,m): + fape = robust_norm(t-p) + fape = jnp.clip(fape, 0, clamp) / 10.0 + return fape, (fape*m).sum((-1,-2))/(m.sum((-1,-2)) + 1e-8) + + true = inputs["batch"]["all_atom_positions"] + pred = outputs["structure_module"]["final_atom_positions"] + + N,CA,C = (residue_constants.atom_order[k] for k in ["N","CA","C"]) + + true_mask = inputs["batch"]["all_atom_mask"] + weights = true_mask[:,N] * true_mask[:,CA] * true_mask[:,C] + + true = get_ij(get_R(true[:,N],true[:,CA],true[:,C]),true[:,CA]) + pred = get_ij(get_R(pred[:,N],pred[:,CA],pred[:,C]),pred[:,CA]) + + return _get_pw_loss(true, pred, loss_fn, weights=weights, copies=copies, return_mtx=return_mtx) + +def _get_pw_loss(true, pred, loss_fn, weights=None, copies=1, return_mtx=False): length = true.shape[0] - if weights is None: weights = jnp.ones(length) + + if weights is None: + weights = jnp.ones(length) + F = {"t":true, "p":pred, "m":weights[:,None] * weights[None,:]} - def cce_fn(t,p,m): - cce = -(t*jax.nn.log_softmax(p)).sum(-1) - return cce, (cce*m).sum((-1,-2))/(m.sum((-1,-2))+1e-8) - if copies > 1: (L,C) = (length//copies, copies-1) # intra (L,L,F) intra = jax.tree_map(lambda x:x[:L,:L], F) - cce, cce_loss = cce_fn(**intra) + mtx, loss = loss_fn(**intra) # inter (C*L,L,F) inter = jax.tree_map(lambda x:x[L:,:L], F) if C == 0: - i_cce, i_cce_loss = cce_fn(**inter) + i_mtx, i_loss = loss_fn(**inter) else: # (C,L,L,F) @@ -286,20 +379,20 @@ def cce_fn(t,p,m): "m":inter["m"][:,None,:,:,0]} # (C,1,L,L) # (C,C,L,L,F) → (C,C,L,L) → (C,C) → (C) → () - i_cce, i_cce_loss = cce_fn(**inter) - 
i_cce_loss = sum([i_cce_loss.min(i).sum() for i in [0,1]]) / 2 + i_mtx, i_loss = loss_fn(**inter) + i_loss = sum([i_loss.min(i).sum() for i in [0,1]]) / 2 - total_loss = (cce_loss + i_cce_loss) / copies - return (cce, i_cce) if return_cce else total_loss + total_loss = (loss + i_loss) / copies + return (mtx, i_mtx) if return_mtx else total_loss else: - cce, cce_loss = cce_fn(**F) - return cce if return_cce else cce_loss - + mtx, loss = loss_fn(**F) + return mtx if return_mtx else loss + def get_rmsd_loss(inputs, outputs, L=None, include_L=True, copies=1): batch = inputs["batch"] - true = batch["all_atom_positions"][:,1,:] - pred = outputs["structure_module"]["final_atom_positions"][:,1,:] + true = batch["all_atom_positions"][:,1] + pred = outputs["structure_module"]["final_atom_positions"][:,1] weights = batch["all_atom_mask"][:,1] return _get_rmsd_loss(true, pred, weights=weights, L=L, include_L=include_L, copies=copies) @@ -355,7 +448,7 @@ def _get_rmsd_loss(true, pred, weights=None, L=None, include_L=True, copies=1): return {"rmsd":rmsd, "align":align_fn} -def get_sc_rmsd(true, pred, sc): +def _get_sc_rmsd_loss(true, pred, sc): '''get sidechain rmsd + alignment function''' # select atoms @@ -387,21 +480,22 @@ def get_sc_rmsd(true, pred, sc): rmsd = jnp.sqrt(msd + 1e-8) return {"rmsd":rmsd, "align":align_fn} -#-------------------------------------- -# TODO (make copies friendly) -#-------------------------------------- -def get_fape_loss(inputs, outputs, model_config, use_clamped_fape=False): - batch = inputs["batch"] - sub_batch = jax.tree_map(lambda x: x, batch) - sub_batch["use_clamped_fape"] = use_clamped_fape - loss = {"loss":0.0} - _config = model_config.model.heads.structure_module - folding.backbone_loss(loss, sub_batch, outputs["structure_module"], _config) - return loss["loss"] - -def get_6D_loss(inputs, outputs, **kwargs): - batch = inputs["batch"] - true = batch["all_atom_positions"] - pred = outputs["structure_module"]["final_atom_positions"] - mask 
= batch["all_atom_mask"] - return _np_get_6D_loss(true, pred, mask, **kwargs) \ No newline at end of file +def get_seq_ent_loss(inputs, outputs, opt): + x = inputs["seq"]["logits"] / opt["temp"] + ent = -(jax.nn.softmax(x) * jax.nn.log_softmax(x)).sum(-1) + mask = jnp.ones(ent.shape[-1]) + if "fix_pos" in opt: + if "pos" in opt: + p = opt["pos"][opt["fix_pos"]] + else: + p = opt["fix_pos"] + mask = mask.at[p].set(0) + ent = (ent * mask).sum() / (mask.sum() + 1e-8) + return {"seq_ent":ent.mean()} + +def get_mlm_loss(outputs, mask, truth=None): + x = outputs["masked_msa"]["logits"] + if truth is None: truth = jax.nn.softmax(x[...,:20]) + ent = -(truth[...,:20] * jax.nn.log_softmax(x)[...,:20]).sum(-1) + ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8) + return {"mlm":ent.mean()} \ No newline at end of file diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index fd46276c..c843cb4e 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -9,26 +9,30 @@ from colabdesign.shared.utils import Key from colabdesign.af.prep import _af_prep -from colabdesign.af.loss import _af_loss, get_plddt, get_pae, get_contact_map, get_ptm +from colabdesign.af.loss import _af_loss, get_plddt, get_pae, get_contact_map, get_ptm, get_seq_ent_loss, get_mlm_loss from colabdesign.af.utils import _af_utils from colabdesign.af.design import _af_design -from colabdesign.af.inputs import _af_inputs, update_seq, update_aatype, crop_feat +from colabdesign.af.inputs import _af_inputs, update_seq, update_aatype +from colabdesign.af.crop import _af_crop, crop_feat ################################################################ # MK_DESIGN_MODEL - initialize model, and put it all together ################################################################ -class mk_af_model(design_model, _af_inputs, _af_loss, _af_prep, _af_design, _af_utils): +class mk_af_model(design_model, _af_inputs, _af_loss, _af_prep, _af_design, _af_utils, _af_crop): def __init__(self, protocol="fixbb", 
num_seq=1, num_models=1, sample_models=True, - recycle_mode="average", num_recycles=0, + recycle_mode="last", num_recycles=0, use_templates=False, best_metric="loss", - crop_len=None, crop_mode="slide", - debug=False, use_alphafold=True, use_openfold=False, - loss_callback=None, data_dir="."): + model_names=None, + use_openfold=False, use_alphafold=True, + use_multimer=False, + use_mlm=False, + use_crop=False, crop_len=None, crop_mode="slide", + debug=False, loss_callback=None, data_dir="."): assert protocol in ["fixbb","hallucination","binder","partial"] - assert recycle_mode in ["average","add_prev","backprop","last","sample"] + assert recycle_mode in ["average","first","last","sample","add_prev","backprop"] assert crop_mode in ["slide","roll","pair","dist"] # decide if templates should be used @@ -37,23 +41,25 @@ def __init__(self, protocol="fixbb", num_seq=1, self.protocol = protocol self._loss_callback = loss_callback self._num = num_seq - self._args = {"use_templates":use_templates, - "recycle_mode":recycle_mode, + self._args = {"use_templates":use_templates, "use_multimer":use_multimer, + "recycle_mode":recycle_mode, "use_mlm": use_mlm, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "best_metric":best_metric, - 'use_alphafold':use_alphafold, 'use_openfold':use_openfold, - "crop":False, "crop_len":crop_len,"crop_mode":crop_mode, + "use_crop":use_crop, "crop_len":crop_len, "crop_mode":crop_mode, "models":None} - self.opt = {"dropout":True, "lr":1.0, "use_pssm":False, + self.opt = {"dropout":True, "lr":1.0, "use_pssm":False, "mlm_dropout":0.05, "num_recycles":num_recycles, "num_models":num_models, "sample_models":sample_models, "temp":1.0, "soft":0.0, "hard":0.0, "bias":0.0, "alpha":2.0, - "con": {"num":2, "cutoff":14.0, "binary":False, "seqsep":9}, - "i_con": {"num":1, "cutoff":20.0, "binary":False}, - "template": {"aatype":21, "dropout":0.0}, - "weights": {"helix":0.0, "plddt":0.01, "pae":0.01}, - "cmap_cutoff": 10.0} + "con": {"num":2, 
"cutoff":14.0, "binary":False, "seqsep":9, "num_pos":float("inf")}, + "i_con": {"num":1, "cutoff":21.6875, "binary":False, "num_pos":float("inf")}, + "template": {"dropout":0.0, "rm_ic":False, "rm_seq":True, "rm_sc":True}, + "weights": {"seq_ent":0.0, "plddt":0.0, "pae":0.0, "exp_res":0.0}, + "cmap_cutoff": 10.0, "fape_cutoff":10.0} + + if self._args["use_mlm"]: + self.opt["weights"]["mlm"] = 0.1 self._params = {} self._inputs = {} @@ -61,43 +67,40 @@ def __init__(self, protocol="fixbb", num_seq=1, ############################# # configure AlphaFold ############################# - cfg = config.model_config("model_1_ptm" if use_templates else "model_3_ptm") - cfg.model.global_config.use_remat = True - # number of sequences - if use_templates: - cfg.data.eval.max_templates = 1 - cfg.data.eval.max_msa_clusters = num_seq + 1 + if use_multimer: + cfg = config.model_config("model_1_multimer") else: - cfg.data.eval.max_templates = 0 - cfg.data.eval.max_msa_clusters = num_seq - cfg.data.common.max_extra_msa = 1 - cfg.data.eval.masked_msa_replace_fraction = 0 - - # number of recycles - if recycle_mode == "average": num_recycles = 0 - cfg.data.common.num_recycle = 0 # for feature processing - cfg.model.num_recycle = num_recycles # for model configuration + cfg = config.model_config("model_1_ptm" if use_templates else "model_3_ptm") + if recycle_mode in ["average","first","last","sample"]: num_recycles = 0 + cfg.model.num_recycle = num_recycles + cfg.model.global_config.use_remat = True # setup model self._cfg = cfg # load model_params - model_names = [] - if use_templates: - model_names += [f"model_{k}_ptm" for k in [1,2]] - model_names += [f"openfold_model_ptm_{k}" for k in [1,2]] - else: - model_names += [f"model_{k}_ptm" for k in [1,2,3,4,5]] - model_names += [f"openfold_model_ptm_{k}" for k in [1,2]] + ["openfold_model_no_templ_ptm_1"] + if model_names is None: + model_names = [] + if use_multimer: + model_names += [f"model_{k}_multimer_v2" for k in [1,2,3,4,5]] + else: 
+ if use_templates: + if use_alphafold: model_names += [f"model_{k}_ptm" for k in [1,2]] + if use_openfold: model_names += [f"openfold_model_ptm_{k}" for k in [1,2]] + else: + if use_alphafold: model_names += [f"model_{k}_ptm" for k in [1,2,3,4,5]] + if use_openfold: model_names += [f"openfold_model_ptm_{k}" for k in [1,2]] + ["openfold_model_no_templ_ptm_1"] self._model_params, self._model_names = [],[] for model_name in model_names: params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir) if params is not None: - if not use_templates: + if not use_multimer and not use_templates: params = {k:v for k,v in params.items() if "template" not in k} self._model_params.append(params) self._model_names.append(model_name) + else: + print(f"WARNING: '{model_name}' not found") ##################################### # set protocol specific functions @@ -108,7 +111,10 @@ def __init__(self, protocol="fixbb", num_seq=1, def _get_model(self, cfg, callback=None): - runner = model.RunModel(cfg, is_training=True, recycle_mode=self._args["recycle_mode"]) + a = self._args + runner = model.RunModel(cfg, is_training=True, + recycle_mode=a["recycle_mode"], + use_multimer=a["use_multimer"]) # setup function to get gradients def _model(params, model_params, inputs, key, opt): @@ -120,41 +126,37 @@ def _model(params, model_params, inputs, key, opt): # INPUTS ####################################################################### + L = inputs["aatype"].shape[0] + # get sequence seq = self._get_seq(inputs, params, opt, aux, key()) # update sequence features pssm = jnp.where(opt["use_pssm"], seq["pssm"], seq["pseudo"]) - update_seq(seq["pseudo"], inputs, seq_pssm=pssm) + if a["use_mlm"]: + mlm = jax.random.bernoulli(key(), opt["mlm_dropout"], (L,)) + update_seq(seq["pseudo"], inputs, seq_pssm=pssm, mlm=mlm) + else: + update_seq(seq["pseudo"], inputs, seq_pssm=pssm) # update amino acid sidechain identity - B,L = inputs["aatype"].shape[:2] aatype = 
jax.nn.one_hot(seq["pseudo"][0].argmax(-1),21) - update_aatype(jnp.broadcast_to(aatype,(B,L,21)), inputs) + update_aatype(jnp.broadcast_to(aatype,(L,21)), inputs) # update template features - if self._args["use_templates"]: + if a["use_templates"]: self._update_template(inputs, opt, key()) + inputs["mask_template_interchain"] = opt["template"]["rm_ic"] # set dropout - inputs["dropout_scale"] = jnp.array([opt["dropout"]]).astype(float) + inputs["dropout_scale"] = jnp.array(opt["dropout"], dtype=float) - # decide number of recycles to do - if self._args["recycle_mode"] in ["last","sample"]: - inputs["num_iter_recycling"] = jnp.array([opt["num_recycles"]]) - - # crop inputs - if opt["crop_pos"].shape[0] < L: - inputs = crop_feat(inputs, opt["crop_pos"], self._cfg) - - if "batch" in inputs: - # need frames for fape - batch = inputs.pop("batch") - batch.update(all_atom.atom37_to_frames(**batch)) - else: - batch = None - - inputs["batch"] = batch + # experimental - crop inputs + if a["use_crop"] and opt["crop_pos"].shape[0] < L: + inputs = crop_feat(inputs, opt["crop_pos"]) + + if "batch" not in inputs: + inputs["batch"] = None ####################################################################### # OUTPUTS @@ -165,34 +167,45 @@ def _model(params, model_params, inputs, key, opt): # add aux outputs aux.update({"atom_positions":outputs["structure_module"]["final_atom_positions"], "atom_mask":outputs["structure_module"]["final_atom_mask"], - "residue_index":inputs["residue_index"][0], "aatype":inputs["aatype"][0], - "plddt":get_plddt(outputs),"pae":get_pae(outputs), "ptm":get_ptm(outputs), - "cmap":get_contact_map(outputs, opt["cmap_cutoff"])}) - - # experimental - # crop outputs (TODO) - if opt["crop_pos"].shape[0] < L: + "residue_index":inputs["residue_index"], + "aatype":inputs["aatype"], + "plddt": get_plddt(outputs), + "pae": get_pae(outputs), + "ptm": get_ptm(inputs, outputs), + "i_ptm": get_ptm(inputs, outputs, interface=True), + "cmap": get_contact_map(outputs, 
opt["cmap_cutoff"]), + "prev": outputs["prev"]}) + + # experimental - uncrop outputs + if a["use_crop"] and opt["crop_pos"].shape[0] < L: p = opt["crop_pos"] aux["cmap"] = jnp.zeros((L,L)).at[p[:,None],p[None,:]].set(aux["cmap"]) aux["pae"] = jnp.full((L,L),jnp.nan).at[p[:,None],p[None,:]].set(aux["pae"]) - - if self._args["recycle_mode"] == "average": aux["prev"] = outputs["prev"] ####################################################################### # LOSS ####################################################################### - aux["losses"] = {} + + # add protocol specific losses self._get_loss(inputs=inputs, outputs=outputs, opt=opt, aux=aux) + # add user defined losses inputs["seq"] = aux["seq"] if self._loss_callback is not None: loss_fns = self._loss_callback if isinstance(self._loss_callback,list) else [self._loss_callback] for loss_fn in loss_fns: aux["losses"].update(loss_fn(inputs, outputs, opt)) - if self._args["debug"]: + if a["debug"]: aux["debug"] = {"inputs":inputs, "outputs":outputs, "opt":opt} + + # sequence entropy loss + aux["losses"].update(get_seq_ent_loss(inputs, outputs, opt)) + + # experimental masked-language-modeling + if a["use_mlm"]: + aux["losses"].update(get_mlm_loss(outputs, mask=mlm, truth=seq["pssm"])) # weighted loss w = opt["weights"] diff --git a/colabdesign/af/prep.py b/colabdesign/af/prep.py index 4d81b469..1f9da429 100644 --- a/colabdesign/af/prep.py +++ b/colabdesign/af/prep.py @@ -9,6 +9,8 @@ from colabdesign.af.alphafold.data import pipeline, prep_inputs from colabdesign.af.alphafold.common import protein, residue_constants from colabdesign.af.alphafold.model.tf import shape_placeholders +from colabdesign.af.alphafold.model import config + from colabdesign.shared.protein import _np_get_cb, pdb_to_string from colabdesign.shared.prep import prep_pos @@ -35,182 +37,109 @@ def _prep_model(self, **kwargs): self._opt = copy_dict(self.opt) self.restart(**kwargs) - def _prep_features(self, length, num_seq=None, num_templates=1, 
template_features=None): + def _prep_features(self, num_res, num_seq=None, num_templates=1): '''process features''' if num_seq is None: num_seq = self._num - return prep_input_features(L=length, N=num_seq, T=num_templates, - use_templates=self._args["use_templates"]) - - # prep functions specific to protocol - def _prep_binder(self, pdb_filename, chain="A", - binder_len=50, binder_chain=None, - use_binder_template=False, split_templates=False, - hotspot=None, rm_template_seq=True, rm_template_sc=True, **kwargs): - ''' - prep inputs for binder design - --------------------------------------------------- - -binder_len = length of binder to hallucinate (option ignored if binder_chain is defined) - -binder_chain = chain of binder to redesign - -use_binder_template = use binder coordinates as template input - -split_templates = use target and binder coordinates as seperate template inputs - -hotspot = define position/hotspots on target - -rm_template_seq = for binder redesign protocol, remove sequence info from binder template - --------------------------------------------------- - ''' - - redesign = binder_chain is not None - - self.opt.update({"rm_template_seq":rm_template_seq,"rm_template_sc":rm_template_sc}) - self._args.update({"redesign":redesign}) - - self.opt["template"]["dropout"] = 0.0 if use_binder_template else 1.0 - num_templates = 1 - - # get pdb info - chains = f"{chain},{binder_chain}" if redesign else chain - pdb = prep_pdb(pdb_filename, chain=chains) - - if redesign: - target_len = sum([(pdb["idx"]["chain"] == c).sum() for c in chain.split(",")]) - binder_len = sum([(pdb["idx"]["chain"] == c).sum() for c in binder_chain.split(",")]) - if split_templates: num_templates = 2 - # get input features - self._inputs = self._prep_features(target_len + binder_len, num_templates=num_templates) - - else: - target_len = pdb["residue_index"].shape[0] - self._inputs = self._prep_features(target_len) - - self._inputs["residue_index"][...,:] = pdb["residue_index"] - - 
# gather hotspot info - hotspot = kwargs.pop("pos", hotspot) - if hotspot is not None: - self.opt["pos"] = prep_pos(hotspot, **pdb["idx"])["pos"] - - # add batch - self._inputs["batch"] = pdb["batch"] - - if redesign: - self._wt_aatype = self._inputs["batch"]["aatype"][target_len:] - self.opt["weights"].update({"dgram_cce":1.0, "fape":0.0, "rmsd":0.0, - "con":0.0, "i_pae":0.01, "i_con":0.0}) - else: # binder hallucination - # pad inputs - total_len = target_len + binder_len - self._inputs = make_fixed_size(self._inputs, self._cfg, total_len) - - # offset residue index for binder - self._inputs["residue_index"] = self._inputs["residue_index"].copy() - self._inputs["residue_index"][:,target_len:] = pdb["residue_index"][-1] + np.arange(binder_len) + 50 - for k in ["seq_mask","msa_mask"]: self._inputs[k] = np.ones_like(self._inputs[k]) - self.opt["weights"].update({"con":0.5, "i_pae":0.01, "i_con":0.5}) - - self._target_len = target_len - self._binder_len = self._len = binder_len - self._lengths = [self._target_len, self._binder_len] + return prep_input_features(L=num_res, N=num_seq, T=num_templates) - self._prep_model(**kwargs) - - def _prep_fixbb(self, pdb_filename, chain=None, copies=1, homooligomer=False, - repeat=False, block_diag=True, rm_template_seq=True, rm_template_sc=True, - pos=None, fix_seq=True, **kwargs): + def _prep_fixbb(self, pdb_filename, chain=None, + copies=1, repeat=False, homooligomer=False, + rm_template_seq=True, rm_template_sc=True, rm_template_ic=False, + fix_pos=None, **kwargs): ''' prep inputs for fixed backbone design --------------------------------------------------- if copies > 1: - -homooligomer=True - input pdb chains are parsed as homo-olgiomeric units - -block_diag=True - each copy is it's own sequence in the MSA - -repeat=True - tie the repeating sequence within single chain - -rm_template_seq - if template is defined, remove information about template sequence - if fix_seq: - -pos="1,2-10" - specify which positions to keep fixed 
in the sequence + -homooligomer=True - input pdb chains are parsed as homo-oligomeric units + -repeat=True - tie the repeating sequence within single chain + -rm_template_seq - if template is defined, remove information about template sequence + -fix_pos="1,2-10" - specify which positions to keep fixed in the sequence note: supervised loss is applied to all positions, use "partial" protocol to apply supervised loss to only subset of positions --------------------------------------------------- - ''' - self.opt.update({"rm_template_seq":rm_template_seq,"rm_template_sc":rm_template_sc}) - # block_diag the msa features - if block_diag and not repeat and copies > 1: - max_msa_clusters = 1 + self._num * copies - self._cfg.data.eval.max_msa_clusters = max_msa_clusters - else: - max_msa_clusters = self._num - block_diag = False + ''' + self.opt["template"].update({"rm_seq":rm_template_seq,"rm_sc":rm_template_sc, "rm_ic":rm_template_ic}) - pdb = prep_pdb(pdb_filename, chain=chain) - if chain is not None and homooligomer and copies == 1: - copies = len(chain.split(",")) + # prep features + pdb = prep_pdb(pdb_filename, chain=chain, + lengths=kwargs.pop("pdb_lengths",None), + offsets=kwargs.pop("pdb_offsets",None)) self._len = pdb["residue_index"].shape[0] - self._inputs = self._prep_features(self._len, num_seq=max_msa_clusters) - self._inputs["batch"] = pdb["batch"] - - self._args.update({"repeat":repeat, - "block_diag":block_diag, - "homooligomer":homooligomer, - "copies":copies}) + self._lengths = [self._len] - # set weights - self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "con":0.0, "fape":0.0}) + # feat dims + num_seq = self._num + res_idx = pdb["residue_index"] + + # get [pos]itions of interests + if fix_pos is not None: + self._pos_info = prep_pos(fix_pos, **pdb["idx"]) + self.opt["fix_pos"] = self._pos_info["pos"] - # update residue index from pdb + # repeat/homo-oligomeric support + if chain is not None and copies == 1: copies = len(chain.split(",")) if 
copies > 1: - if repeat: + + if repeat or homooligomer: self._len = self._len // copies - block_diag = False + if "fix_pos" in self.opt: + self.opt["fix_pos"] = self.opt["fix_pos"][self.opt["fix_pos"] < self._len] + + if repeat: self._lengths = [self._len * copies] + block_diag = False + else: - if homooligomer: - self._len = self._len // copies - self._inputs["residue_index"] = repeat_idx(pdb["residue_index"][:self._len], copies)[None] - else: - self._inputs = make_fixed_size(self._inputs, self._cfg, self._len * copies) - self._inputs["residue_index"] = repeat_idx(pdb["residue_index"], copies)[None] - for k in ["seq_mask","msa_mask"]: self._inputs[k] = np.ones_like(self._inputs[k]) self._lengths = [self._len] * copies + block_diag = not self._args["use_multimer"] + + res_idx = repeat_idx(res_idx[:self._len], copies) + num_seq = (self._num * copies + 1) if block_diag else self._num + self.opt["weights"].update({"i_pae":0.0, "i_con":0.0}) + + self._args.update({"copies":copies, "repeat":repeat, "homooligomer":homooligomer, "block_diag":block_diag}) + homooligomer = not repeat else: - self._inputs["residue_index"] = pdb["residue_index"][None] - self._lengths = [self._len] + self._lengths = pdb["lengths"] - # fix certain positions - self.opt["fix_seq"] = fix_seq - if pos is not None: - self._pos_info = prep_pos(pos, **pdb["idx"]) - self.opt["pos"] = self._pos_info["pos"] + # configure input features + self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = make_fixed_size(pdb["batch"], num_res=sum(self._lengths)) + self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) + # configure options/weights + self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":0.0}) self._wt_aatype = self._inputs["batch"]["aatype"][:self._len] + self._prep_model(**kwargs) - # undocumented: for dist cropping (for Shihao) - cb_atoms = pdb["cb_feat"]["atoms"] - 
cb_atoms[pdb["cb_feat"]["mask"] == 0,:] = np.nan - self._dist = np.sqrt(np.square(cb_atoms[:,None] - cb_atoms[None,:]).sum(-1)) + # undocumented: dist cropping (for Shihao) + if self._args["use_crop"]: + cb_atoms = pdb["cb_feat"]["atoms"] + cb_atoms[pdb["cb_feat"]["mask"] == 0,:] = np.nan + self._dist = np.sqrt(np.square(cb_atoms[:,None] - cb_atoms[None,:]).sum(-1)) - def _prep_hallucination(self, length=100, copies=1, - repeat=False, block_diag=True, **kwargs): + def _prep_hallucination(self, length=100, copies=1, repeat=False, **kwargs): ''' prep inputs for hallucination --------------------------------------------------- if copies > 1: - -homooligomer=True - input pdb chains are parsed as homo-olgiomeric units - -block_diag=True - each copy is it's own sequence in the MSA -repeat=True - tie the repeating sequence within single chain --------------------------------------------------- ''' - # set [arg]uments - if block_diag and not repeat and copies > 1: - max_msa_clusters = 1 + self._num * copies - self._cfg.data.eval.max_msa_clusters = max_msa_clusters + # define num copies (for repeats/ homo-oligomers) + if not repeat and copies > 1 and not self._args["use_multimer"]: + (num_seq, block_diag) = (self._num * copies + 1, True) else: - max_msa_clusters = self._num - block_diag = False - self._args.update({"block_diag":block_diag, "repeat":repeat, "copies":copies}) + (num_seq, block_diag) = (self._num, False) + + self._args.update({"repeat":repeat,"block_diag":block_diag,"copies":copies}) # prep features self._len = length - self._inputs = self._prep_features(length * copies, num_seq=max_msa_clusters) # set weights self.opt["weights"].update({"con":1.0}) @@ -220,65 +149,179 @@ def _prep_hallucination(self, length=100, copies=1, self._lengths = [self._len * copies] else: offset = 50 - self.opt["weights"].update({"i_pae":0.01, "i_con":0.1}) self._lengths = [self._len] * copies - self._inputs["residue_index"] = repeat_idx(np.arange(length), copies, 
offset=offset)[None] + self.opt["weights"].update({"i_pae":0.0, "i_con":1.0}) + res_idx = repeat_idx(np.arange(length), copies, offset=offset) else: self._lengths = [self._len] + res_idx = np.arange(length) + + # configure input features + self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) + self._inputs["residue_index"] = res_idx + self._inputs.update(get_multi_id(self._lengths, homooligomer=True)) + + self._prep_model(**kwargs) + + def _prep_binder(self, pdb_filename, chain="A", + binder_len=50, binder_chain=None, + use_binder_template=False, + rm_template_ic=False,rm_template_seq=True, rm_template_sc=True, + hotspot=None, **kwargs): + ''' + prep inputs for binder design + --------------------------------------------------- + -binder_len = length of binder to hallucinate (option ignored if binder_chain is defined) + -binder_chain = chain of binder to redesign + -use_binder_template = use binder coordinates as template input + -rm_template_ic = use target and binder coordinates as seperate template inputs + -hotspot = define position/hotspots on target + -rm_template_seq = for binder redesign protocol, remove sequence info from binder template + --------------------------------------------------- + ''' + + redesign = binder_chain is not None + + self.opt["template"].update({"rm_seq":rm_template_seq, "rm_sc":rm_template_sc, "rm_ic":rm_template_ic, + "dropout":(0.0 if use_binder_template else 1.0)}) + self._args.update({"redesign":redesign}) + + # get pdb info + chains = f"{chain},{binder_chain}" if redesign else chain + pdb = prep_pdb(pdb_filename, chain=chains, ignore_missing=True) + res_idx = pdb["residue_index"] + + if redesign: + self._target_len = sum([(pdb["idx"]["chain"] == c).sum() for c in chain.split(",")]) + self._binder_len = self._len = sum([(pdb["idx"]["chain"] == c).sum() for c in binder_chain.split(",")]) + else: + self._target_len = pdb["residue_index"].shape[0] + self._binder_len = self._len = binder_len + res_idx = 
np.append(res_idx, res_idx[-1] + np.arange(binder_len) + 50) + self._lengths = [self._target_len, self._binder_len] + + # gather hotspot info + if hotspot is not None: + self.opt["hotspot"] = prep_pos(hotspot, **pdb["idx"])["pos"] + + if redesign: + # binder redesign + self._wt_aatype = pdb["batch"]["aatype"][target_len:] + self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, + "con":0.0, "i_con":0.0, "i_pae":0.0}) + else: + # binder hallucination + pdb["batch"] = make_fixed_size(pdb["batch"], num_res=sum(self._lengths)) + self.opt["weights"].update({"plddt":0.1, "con":0.0, "i_con":1.0, "i_pae":0.0}) + + # configure input features + self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=1) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = pdb["batch"] + self._inputs.update(get_multi_id(self._lengths)) + self._prep_model(**kwargs) def _prep_partial(self, pdb_filename, chain=None, length=None, - pos=None, fix_seq=True, use_sidechains=False, atoms_to_exclude=None, - rm_template_seq=False, rm_template_sc=False, **kwargs): + copies=1, repeat=False, homooligomer=False, + pos=None, fix_pos=None, use_sidechains=False, atoms_to_exclude=None, + rm_template_seq=False, rm_template_sc=False, rm_template_ic=False, **kwargs): ''' prep input for partial hallucination --------------------------------------------------- -length=100 - total length of protein (if different from input PDB) -pos="1,2-10" - specify which positions to apply supervised loss to - -fix_seq=True - keep sequence fixed in the specified positions -use_sidechains=True - add a sidechain supervised loss to the specified positions -atoms_to_exclude=["N","C","O"] (for sc_rmsd loss, specify which atoms to exclude) -rm_template_seq - if template is defined, remove information about template sequence --------------------------------------------------- ''' - self.opt.update({"rm_template_seq":rm_template_seq,"rm_template_sc":rm_template_sc}) + 
self.opt["template"].update({"rm_seq":rm_template_seq,"rm_sc":rm_template_sc,"rm_ic":rm_template_ic}) # prep features - pdb = prep_pdb(pdb_filename, chain=chain) - - self._len = pdb["residue_index"].shape[0] if length is None else length + pdb = prep_pdb(pdb_filename, chain=chain, + lengths=kwargs.pop("pdb_lengths",None), + offsets=kwargs.pop("pdb_offsets",None)) + + pdb["len"] = sum(pdb["lengths"]) + + self._len = pdb["len"] if length is None else length self._lengths = [self._len] - self._inputs = self._prep_features(self._len) - self._inputs["batch"] = pdb["batch"] - # undocumented: experimental repeat support - if kwargs.pop("repeat",False): - copies = kwargs.pop("copies",1) - if copies > 1: + # feat dims + num_seq = self._num + res_idx = np.arange(self._len) + + # get [pos]itions of interests + if pos is None: + self.opt["pos"] = pdb["pos"] = np.arange(pdb["len"]) + self._pos_info = {"length":np.array([pdb["len"]]), "pos":pdb["pos"]} + else: + self._pos_info = prep_pos(pos, **pdb["idx"]) + self.opt["pos"] = pdb["pos"] = self._pos_info["pos"] + + # repeat/homo-oligomeric support + if chain is not None and copies == 1: copies = len(chain.split(",")) + if copies > 1: + + if repeat or homooligomer: self._len = self._len // copies + pdb["len"] = pdb["len"] // copies + self.opt["pos"] = pdb["pos"][pdb["pos"] < pdb["len"]] + + # repeat positions across copies + pdb["pos"] = repeat_pos(self.opt["pos"], copies, pdb["len"]) + + if repeat: self._lengths = [self._len * copies] - self._args.update({"copies":copies, "repeat":True, "block_diag":False}) + block_diag = False - # configure options/weights - self.opt["pos"] = np.arange(pdb["residue_index"].shape[0]) - self.opt["weights"].update({"dgram_cce":1.0,"con":1.0, "fape":0.0, "rmsd":0.0}) - self.opt["fix_seq"] = fix_seq + else: + self._lengths = [self._len] * copies + block_diag = not self._args["use_multimer"] - # get [pos]itions of interests - if pos is not None: - self._pos_info = prep_pos(pos, **pdb["idx"]) - 
self.opt["pos"] = self._pos_info["pos"] - self._inputs["batch"] = jax.tree_map(lambda x:x[self.opt["pos"]], pdb["batch"]) - self._wt_aatype = self._inputs["batch"]["aatype"] + num_seq = (self._num * copies + 1) if block_diag else self._num + res_idx = repeat_idx(np.arange(self._len), copies) + + self.opt["weights"].update({"i_pae":0.0, "i_con":1.0}) + + self._args.update({"copies":copies, "repeat":repeat, "homooligomer":homooligomer, "block_diag":block_diag}) + homooligomer = not repeat + + # configure input features + self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = jax.tree_map(lambda x:x[pdb["pos"]], pdb["batch"]) + self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) + + # configure options/weights + self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":1.0}) + self._wt_aatype = pdb["batch"]["aatype"][self.opt["pos"]] # configure sidechains - self._args["use_sidechains"] = kwargs.pop("sidechain", use_sidechains) - if self._args["use_sidechains"]: + self._args["use_sidechains"] = use_sidechains + if use_sidechains: self._sc = {"batch":prep_inputs.make_atom14_positions(self._inputs["batch"]), "pos":get_sc_pos(self._wt_aatype, atoms_to_exclude)} self.opt["weights"].update({"sc_rmsd":0.1, "sc_fape":0.1}) - self.opt["fix_seq"] = True + self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) + self._wt_aatype_sub = self._wt_aatype + + elif fix_pos is not None: + sub_fix_pos = [] + sub_i = [] + pos = self.opt["pos"].tolist() + for i in prep_pos(fix_pos, **pdb["idx"])["pos"]: + if i in pos: + sub_i.append(i) + sub_fix_pos.append(pos.index(i)) + self.opt["fix_pos"] = np.array(sub_fix_pos) + self._wt_aatype_sub = pdb["batch"]["aatype"][sub_i] + + elif kwargs.pop("fix_seq",False): + self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) + self._wt_aatype_sub = self._wt_aatype self._prep_model(**kwargs) @@ -289,12 +332,17 @@ 
def repeat_idx(idx, copies=1, offset=50): idx_offset = np.repeat(np.cumsum([0]+[idx[-1]+offset]*(copies-1)),len(idx)) return np.tile(idx,copies) + idx_offset -def prep_pdb(pdb_filename, chain=None, for_alphafold=True): +def repeat_pos(pos, copies, length): + return (np.repeat(pos,copies).reshape(-1,copies) + np.arange(copies) * length).T.flatten() + +def prep_pdb(pdb_filename, chain=None, + offsets=None, lengths=None, + ignore_missing=False): '''extract features from pdb''' def add_cb(batch): '''add missing CB atoms based on N,CA,C''' - p,m = batch["all_atom_positions"],batch["all_atom_mask"] + p,m = batch["all_atom_positions"], batch["all_atom_mask"] atom_idx = residue_constants.atom_order atoms = {k:p[...,atom_idx[k],:] for k in ["N","CA","C"]} cb = atom_idx["CB"] @@ -308,64 +356,70 @@ def add_cb(batch): chains = [None] if chain is None else chain.split(",") o,last = [],0 residue_idx, chain_idx = [],[] - for chain in chains: + full_lengths = [] + + for n,chain in enumerate(chains): protein_obj = protein.from_pdb_string(pdb_to_string(pdb_filename), chain_id=chain) batch = {'aatype': protein_obj.aatype, 'all_atom_positions': protein_obj.atom_positions, - 'all_atom_mask': protein_obj.atom_mask} + 'all_atom_mask': protein_obj.atom_mask, + 'residue_index': protein_obj.residue_index} cb_feat = add_cb(batch) # add in missing cb (in the case of glycine) + + if ignore_missing: + r = batch["all_atom_mask"][:,0] == 1 + batch = jax.tree_map(lambda x:x[r], batch) + residue_index = batch["residue_index"] + last - has_ca = batch["all_atom_mask"][:,0] == 1 - batch = jax.tree_map(lambda x:x[has_ca], batch) - seq = "".join([order_aa[a] for a in batch["aatype"]]) - residue_index = protein_obj.residue_index[has_ca] + last - last = residue_index[-1] + 50 + else: + # pad values + offset = 0 if offsets is None else (offsets[n] if isinstance(offsets,list) else offsets) + r = offset + (protein_obj.residue_index - protein_obj.residue_index.min()) + length = (r.max()+1) if lengths is None 
else (lengths[n] if isinstance(lengths,list) else lengths) + def scatter(x, value=0): + shape = (length,) + x.shape[1:] + y = np.full(shape, value, dtype=x.dtype) + y[r] = x + return y + + batch = {"aatype":scatter(batch["aatype"],-1), + "all_atom_positions":scatter(batch["all_atom_positions"]), + "all_atom_mask":scatter(batch["all_atom_mask"]), + "residue_index":scatter(batch["residue_index"],-1)} + + residue_index = np.arange(length) + last - if for_alphafold: - template_aatype = residue_constants.sequence_to_onehot(seq, residue_constants.HHBLITS_AA_TO_ID) - template_features = {"template_aatype":template_aatype, - "template_all_atom_masks":batch["all_atom_mask"], - "template_all_atom_positions":batch["all_atom_positions"]} - o.append({"batch":batch, - "template_features":template_features, - "residue_index": residue_index, - "cb_feat":cb_feat}) - else: - o.append({"batch":batch, - "residue_index": residue_index, - "cb_feat":cb_feat}) + last = residue_index[-1] + 50 + o.append({"batch":batch, + "residue_index": residue_index, + "cb_feat":cb_feat}) - residue_idx.append(protein_obj.residue_index[has_ca]) + residue_idx.append(batch.pop("residue_index")) chain_idx.append([chain] * len(residue_idx[-1])) + full_lengths.append(len(residue_index)) # concatenate chains o = jax.tree_util.tree_map(lambda *x:np.concatenate(x,0),*o) - if for_alphafold: - o["template_features"] = jax.tree_map(lambda x:x[None],o["template_features"]) - o["template_features"]["template_domain_names"] = np.asarray(["None"]) - # save original residue and chain index o["idx"] = {"residue":np.concatenate(residue_idx), "chain":np.concatenate(chain_idx)} + o["lengths"] = full_lengths return o -def make_fixed_size(feat, cfg, length, batch_axis=True): +def make_fixed_size(feat, num_res, num_seq=1, num_templates=1): '''pad input features''' - if batch_axis: - shape_schema = {k:[None]+v for k,v in dict(cfg.data.eval.feat).items()} - else: - shape_schema = {k:v for k,v in dict(cfg.data.eval.feat).items()} 
- num_msa_seq = cfg.data.eval.max_msa_clusters - cfg.data.eval.max_templates + shape_schema = {k:v for k,v in config.CONFIG.data.eval.feat.items()} + pad_size_map = { - shape_placeholders.NUM_RES: length, - shape_placeholders.NUM_MSA_SEQ: num_msa_seq, - shape_placeholders.NUM_EXTRA_SEQ: cfg.data.common.max_extra_msa, - shape_placeholders.NUM_TEMPLATES: cfg.data.eval.max_templates - } + shape_placeholders.NUM_RES: num_res, + shape_placeholders.NUM_MSA_SEQ: num_seq, + shape_placeholders.NUM_EXTRA_SEQ: 1, + shape_placeholders.NUM_TEMPLATES: num_templates + } for k,v in feat.items(): if k == "batch": - feat[k] = make_fixed_size(v, cfg, length, batch_axis=False) + feat[k] = make_fixed_size(v, num_res) else: shape = list(v.shape) schema = shape_schema[k] @@ -415,14 +469,19 @@ def get_sc_pos(aa_ident, atoms_to_exclude=None): return {"pos":pos, "pos_alt":pos_alt, "non_amb":non_amb, "weight":w, "weight_non_amb":w_na[:,None]} -def prep_input_features(L, N=1, T=1, use_templates=False, eN=1): +def prep_input_features(L, N=1, T=1, eN=1): ''' given [L]ength, [N]umber of sequences and number of [T]emplates return dictionary of blank features ''' inputs = {'aatype': np.zeros(L,int), - 'target_feat': np.zeros((L,22)), 'msa_feat': np.zeros((N,L,49)), + # 23 = one_hot -> (20, UNK, GAP, MASK) + # 1 = has deletion + # 1 = deletion_value + # 23 = profile + # 1 = deletion_mean_value + 'seq_mask': np.ones(L), 'msa_mask': np.ones((N,L)), 'msa_row_mask': np.ones(N), @@ -435,14 +494,27 @@ def prep_input_features(L, N=1, T=1, use_templates=False, eN=1): 'extra_has_deletion': np.zeros((eN,L)), 'extra_msa': np.zeros((eN,L),int), 'extra_msa_mask': np.zeros((eN,L)), - 'extra_msa_row_mask': np.zeros(eN)} - - if use_templates: - inputs.update({'template_aatype': np.zeros((T,L),int), - 'template_all_atom_masks': np.zeros((T,L,37)), - 'template_all_atom_positions': np.zeros((T,L,37,3)), - 'template_mask': np.ones(T), - 'template_pseudo_beta': np.zeros((T,L,3)), - 'template_pseudo_beta_mask': 
np.zeros((T,L))}) - - return jax.tree_map(lambda x:x[None], inputs) \ No newline at end of file + 'extra_msa_row_mask': np.zeros(eN), + + # for template inputs + 'template_aatype': np.zeros((T,L),int), + 'template_all_atom_mask': np.zeros((T,L,37)), + 'template_all_atom_positions': np.zeros((T,L,37,3)), + 'template_mask': np.zeros(T), + 'template_pseudo_beta': np.zeros((T,L,3)), + 'template_pseudo_beta_mask': np.zeros((T,L)), + + # for alphafold-multimer + 'asym_id': np.zeros(L), + 'sym_id': np.zeros(L), + 'entity_id': np.zeros(L), + 'all_atom_positions': np.zeros((N,37,3))} + return inputs + +def get_multi_id(lengths, homooligomer=False): + '''set info for alphafold-multimer''' + i = np.concatenate([[n]*l for n,l in enumerate(lengths)]) + if homooligomer: + return {"asym_id":i, "sym_id":i, "entity_id":np.zeros_like(i)} + else: + return {"asym_id":i, "sym_id":i, "entity_id":i} \ No newline at end of file diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py index 6d7b6324..657c7d4c 100644 --- a/colabdesign/af/utils.py +++ b/colabdesign/af/utils.py @@ -39,15 +39,13 @@ def set_args(self, **kwargs): ''' set [arg]uments ''' - for k in ["best_metric","crop","crop_mode","crop_len", - "use_openfold","use_alphafold","models"]: + for k in ["best_metric","use_crop","crop_mode","crop_len","models"]: if k in kwargs: self._args[k] = kwargs.pop(k) - if k == "crop" and not self._args[k]: - self._args["crop_len"] = None if "recycle_mode" in kwargs: - if kwargs["recycle_mode"] in ["sample","last"] and self._args["recycle_mode"] in ["sample","last"]: + ok_recycle_mode_swap = ["average","sample","first","last"] + if kwargs["recycle_mode"] in ok_recycle_mode_swap and self._args["recycle_mode"] in ok_recycle_mode_swap: self._args["recycle_mode"] = kwargs.pop("recycle_mode") else: print(f"ERROR: use {self.__class__.__name__}(recycle_mode=...) 
to set the recycle_mode") @@ -56,7 +54,6 @@ def set_args(self, **kwargs): if len(ks) > 0: print(f"ERROR: the following args were not set: {ks}") - def get_loss(self, x="loss"): '''output the loss (for entire trajectory)''' return np.array([float(loss[x]) for loss in self._traj["log"]]) @@ -109,11 +106,10 @@ def animate(self, s=0, e=None, dpi=100, get_best=True): if self.protocol == "hallucination": return make_animation(**sub_traj, pos_ref=pos_ref, length=self._lengths, dpi=dpi) else: - return make_animation(**sub_traj, pos_ref=pos_ref, length=self._lengths, align_xyz=False, dpi=dpi) + return make_animation(**sub_traj, pos_ref=pos_ref, length=self._lengths, align_xyz=False, dpi=dpi) def plot_pdb(self, show_sidechains=False, show_mainchains=False, - color="pLDDT", color_HP=False, size=(800,480), - animate=False, get_best=True): + color="pLDDT", color_HP=False, size=(800,480), animate=False, get_best=True): ''' use py3Dmol to plot pdb coordinates - color=["pLDDT","chain","rainbow"] @@ -165,4 +161,19 @@ def plot_traj(self, dpi=100): ax2.legend(loc='center left') else: print("TODO") - plt.show() \ No newline at end of file + plt.show() + + def clear_best(self): + self._best = {} + + def save_current_pdb(self, filename=None): + '''save pdb coordinates (if filename provided, otherwise return as string)''' + self.save_pdb(filename=filename, get_best=False) + + def plot_current_pdb(self, show_sidechains=False, show_mainchains=False, + color="pLDDT", color_HP=False, size=(800,480), animate=False): + '''use py3Dmol to plot pdb coordinates + - color=["pLDDT","chain","rainbow"] + ''' + self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, + color_HP=color_HP, size=size, animate=animate, get_best=False) \ No newline at end of file diff --git a/colabdesign/shared/model.py b/colabdesign/shared/model.py index 49b95c18..3358f021 100644 --- a/colabdesign/shared/model.py +++ b/colabdesign/shared/model.py @@ -139,14 +139,16 @@ def rewire(self, 
order=None, offset=0, loops=0): ''' self.opt["pos"] = rewire(length=self._pos_info["length"], order=order, offset=offset, loops=loops) - if hasattr(self,"_opt"): - self._opt["pos"] = self.opt["pos"] + + # make default + if hasattr(self,"_opt"): self._opt["pos"] = self.opt["pos"] def soft_seq(x, opt, key=None): seq = {"input":x} # shuffle msa (randomly pick which sequence is query) if x.ndim == 3 and x.shape[0] > 1 and key is not None: - n = jax.random.randint(key,[],0,x.shape[0]) + key, sub_key = jax.random.split(key) + n = jax.random.randint(sub_key,[],0,x.shape[0]) seq["input"] = seq["input"].at[0].set(seq["input"][n]).at[n].set(seq["input"][0]) # straight-through/reparameterization @@ -158,5 +160,9 @@ def soft_seq(x, opt, key=None): # create pseudo sequence seq["pseudo"] = opt["soft"] * seq["soft"] + (1-opt["soft"]) * seq["input"] - seq["pseudo"] = opt["hard"] * seq["hard"] + (1-opt["hard"]) * seq["pseudo"] + + # key, sub_key = jax.random.split(key) + # hard_mask = jax.random.bernoulli(sub_key, opt["hard"], seq["hard"].shape[:-1] + (1,)) + hard_mask = opt["hard"] + seq["pseudo"] = hard_mask * seq["hard"] + (1-hard_mask) * seq["pseudo"] return seq \ No newline at end of file diff --git a/colabdesign/shared/utils.py b/colabdesign/shared/utils.py index 17eefd2e..b3ebca17 100644 --- a/colabdesign/shared/utils.py +++ b/colabdesign/shared/utils.py @@ -10,27 +10,28 @@ def clear_mem(): def update_dict(D, *args, **kwargs): '''robust function for updating dictionary''' - def set_dict(d, x): + def set_dict(d, x, override=False): for k,v in x.items(): if v is not None: if k in d: if isinstance(v, dict): - set_dict(d[k], x[k]) + set_dict(d[k], x[k], override=override) + elif override or d[k] is None: + d[k] = v elif isinstance(d[k],(np.ndarray,jnp.ndarray)): d[k] = np.asarray(v) - elif d[k] is None: - d[k] = v elif isinstance(d[k], dict): d[k] = jax.tree_map(lambda x: type(x)(v), d[k]) else: d[k] = type(d[k])(v) else: print(f"ERROR: '{k}' not found in {list(d.keys())}") + 
override = kwargs.pop("override", False) while len(args) > 0 and isinstance(args[0],str): D,args = D[args[0]],args[1:] for a in args: - if isinstance(a, dict): set_dict(D, a) - set_dict(D, kwargs) + if isinstance(a, dict): set_dict(D, a, override=override) + set_dict(D, kwargs, override=override) def copy_dict(x): '''deepcopy dictionary''' diff --git a/colabdesign/tr/joint_model.py b/colabdesign/tr/joint_model.py index b8cfc801..592feafa 100644 --- a/colabdesign/tr/joint_model.py +++ b/colabdesign/tr/joint_model.py @@ -22,17 +22,17 @@ def _prep_inputs(pdb_filename, chain, binder_len=50, binder_chain=None, **kwargs self.tr = mk_tr_model(protocol=protocol) if protocol == "fixbb": - def _prep_inputs(pdb_filename, chain, pos=None, fix_seq=False, **kwargs): - flags = dict(pdb_filename=pdb_filename, chain=chain, pos=pos, fix_seq=fix_seq) + def _prep_inputs(pdb_filename, chain, fix_pos=None, **kwargs): + flags = dict(pdb_filename=pdb_filename, chain=chain, fix_pos=fix_pos) self.af.prep_inputs(**flags, **kwargs) self.tr.prep_inputs(**flags, chain=chain) if protocol == "partial": def _prep_inputs(pdb_filename, chain, pos=None, length=None, - fix_seq=True, use_sidechains=False, atoms_to_exclude=None, **kwargs): + fix_pos=None, use_sidechains=False, atoms_to_exclude=None, **kwargs): if use_sidechains: fix_seq = True flags = dict(pdb_filename=pdb_filename, chain=chain, - length=length, pos=pos, fix_seq=fix_seq) + length=length, pos=pos, fix_pos=fix_pos) af_a2e = kwargs.pop("af_atoms_to_exclude",atoms_to_exclude) tr_a2e = kwargs.pop("tr_atoms_to_exclude",atoms_to_exclude) self.af.prep_inputs(**flags, use_sidechains=use_sidechains, atoms_to_exclude=af_a2e, **kwargs) diff --git a/colabdesign/tr/model.py b/colabdesign/tr/model.py index 5acd8087..81fd5cd2 100644 --- a/colabdesign/tr/model.py +++ b/colabdesign/tr/model.py @@ -63,7 +63,7 @@ def _get_loss(inputs, outputs, opt): # cce loss if self.protocol in ["fixbb","partial"]: - if self.protocol in ["partial"] and "pos" in opt: + if 
"pos" in opt: pos = opt["pos"] log_p = jax.tree_map(lambda x:x[:,pos][pos,:], log_p) @@ -84,14 +84,17 @@ def _get_loss(inputs, outputs, opt): def _model(params, model_params, opt, inputs, key): seq = soft_seq(params["seq"], opt) - if "pos" in opt: - seq_ref = jax.nn.one_hot(inputs["batch"]["aatype"],20) - p = opt["pos"] - if self.protocol == "partial": - fix_seq = lambda x:jnp.where(opt["fix_seq"],x.at[...,p,:].set(seq_ref),x) + if "fix_pos" in opt: + if "pos" in self.opt: + seq_ref = jax.nn.one_hot(inputs["batch"]["aatype_sub"],20) + p = opt["pos"][opt["fix_pos"]] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref) else: - fix_seq = lambda x:jnp.where(opt["fix_seq"],x.at[...,p,:].set(seq_ref[...,p,:]),x) - seq = jax.tree_map(fix_seq,seq) + seq_ref = jax.nn.one_hot(inputs["batch"]["aatype"],20) + p = opt["fix_pos"] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref[...,p,:]) + seq = jax.tree_map(fix_seq, seq) + inputs.update({"seq":seq["pseudo"][0], "prf":jnp.where(opt["use_pssm"],seq["pssm"],seq["pseudo"])[0]}) rate = jnp.where(opt["dropout"],0.15,0.0) @@ -104,23 +107,33 @@ def _model(params, model_params, opt, inputs, key): "fn":jax.jit(_model)} def prep_inputs(self, pdb_filename=None, chain=None, length=None, - pos=None, fix_seq=True, atoms_to_exclude=None, **kwargs): + pos=None, fix_pos=None, atoms_to_exclude=None, **kwargs): ''' prep inputs for TrDesign ''' if self.protocol in ["fixbb", "partial"]: # parse PDB file and return features compatible with TrRosetta - pdb = prep_pdb(pdb_filename, chain, for_alphafold=False) + pdb = prep_pdb(pdb_filename, chain) self._inputs["batch"] = pdb["batch"] - if pos is not None: + if fix_pos is not None: + self.opt["fix_pos"] = prep_pos(fix_pos, **pdb["idx"])["pos"] + + if self.protocol == "partial" and pos is not None: self._pos_info = prep_pos(pos, **pdb["idx"]) p = self._pos_info["pos"] - if self.protocol == "partial": - self._inputs["batch"] = jax.tree_map(lambda x:x[p], self._inputs["batch"]) - + aatype = 
self._inputs["batch"]["aatype"] + self._inputs["batch"] = jax.tree_map(lambda x:x[p], self._inputs["batch"]) self.opt["pos"] = p - self.opt["fix_seq"] = fix_seq + if "fix_pos" in self.opt: + sub_i,sub_p = [],[] + p = p.tolist() + for i in self.opt["fix_pos"].tolist(): + if i in p: + sub_i.append(i) + sub_p.append(p.index(i)) + self.opt["fix_pos"] = np.array(sub_p) + self._inputs["batch"]["aatype_sub"] = aatype[sub_i] self._inputs["6D"] = _np_get_6D_binned(self._inputs["batch"]["all_atom_positions"], self._inputs["batch"]["all_atom_mask"]) @@ -236,9 +249,9 @@ def step(self, models=None, backprop=True, g = self.aux["grad"]["seq"] gn = jnp.linalg.norm(g,axis=(-1,-2),keepdims=True) - if "pos" in self.opt and self.opt.get("fix_seq",False): + if "fix_pos" in self.opt: # note: gradients only exist in unconstrained positions - eff_len = self._len - self.opt["pos"].shape[0] + eff_len = self._len - self.opt["fix_pos"].shape[0] else: eff_len = self._len self.aux["grad"]["seq"] *= jnp.sqrt(eff_len)/(gn+1e-7) diff --git a/setup.py b/setup.py index c95d7206..a0863749 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( name='colabdesign', - version='1.0.5', + version='1.0.6', packages=find_packages(include=['colabdesign*']), install_requires=['py3Dmol','absl-py','biopython', 'chex','dm-haiku','dm-tree', diff --git a/tr/design.ipynb b/tr/design.ipynb index 21e08ddf..344f15ba 100644 --- a/tr/design.ipynb +++ b/tr/design.ipynb @@ -214,7 +214,7 @@ "if [ ! -d params/af ]; then\n", " # download alphafold weights\n", " mkdir -p params/af/params\n", - " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params/af/params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params/af/params\n", "fi" ], "metadata": {