## Basic Setup

Run the cells below for the basic setup of this notebook.

In [1]:
try:
    from google.colab import drive # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print('No colab environment, assuming local setup.')

if IN_COLAB:
    drive.mount('/content/drive')

    # TODO: Enter the foldername in your Drive where you have saved the unzipped
    # tutorials folder, e.g. 'alphafold-decoded/tutorials'
    FOLDERNAME = None
    assert FOLDERNAME is not None, "[!] Enter the foldername."

    # Now that we've mounted your Drive, this ensures that
    # the Python interpreter of the Colab VM can load
    # python files from within it.
    import sys
    sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
    %cd /content/drive/My\ Drive/$FOLDERNAME

    print('Connected COLAB to Google Drive.')

import os

base_folder = 'evoformer'
control_folder = f'{base_folder}/control_values'

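# VSCode specific: locally, the notebook runs from its own folder, so move up to the tutorials root.
# In Colab, the %cd above already points to the tutorials root.
if not IN_COLAB:
    os.chdir(os.path.join(os.getcwd(), '..'))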

assert os.path.isdir(control_folder), 'Folder "control_values" not found, make sure that FOLDERNAME is set correctly.' if IN_COLAB else 'Folder "control_values" not found, make sure that your root folder is set correctly.'

No colab environment, assuming local setup.


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import math
import torch
import os

# The Evoformer

The Evoformer is one of the two major building blocks of AlphaFold. It consists of 48 blocks with identical architecture (each with its own trained parameters) and uses a Transformer-like architecture to create semantically meaningful features for the Structure Module.

The Evoformer consists of quite a number of steps, but each of them is pretty straightforward to implement, given the flexible MultiHeadAttention module we already implemented.

We will implement the Evoformer before going into the input embedding, which creates the initial MSA representation m and pair representation z for the Evoformer, as the ExtraMSAStack involved in the input embedding is much easier to implement once you understand the Evoformer. From here on, we will often refer to [AlphaFold's supplement](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf) for the implementation, as translating its extensive pseudocode into real code yourself is good practice. Still, we'll clarify as needed, and you can always refer to the solutions if you're stuck.

With that out of the way, let's get started!

## The MSA Stack
Take a look at Algorithm 6. Lines 2 to 10 correspond to one block of the Evoformer, which is structured in two parts: the MSA stack and the pair stack. The MSA stack works mostly on the MSA representation, while the pair stack works mostly on the pair representation. Communication between the two stacks happens in line 2 (z is used as a bias for the row attention) and in line 5 (the OuterProductMean of m is added to z). In this section, we will implement the MSA stack, i.e. lines 2 to 5.
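As an orientation, lines 2 to 10 of Algorithm 6 can be read as a sequence of residual updates. The following is a minimal sketch of that ordering only, not the tutorial's `EvoformerBlock`: the function `evoformer_block_sketch` and the dictionary of callables are hypothetical placeholders for the modules we implement below, and dropout is left out.

```python
import torch

def evoformer_block_sketch(m, z, ops):
    """Residual update order of one Evoformer block (Algorithm 6, lines 2-10).

    `ops` is a hypothetical dict of callables standing in for the real modules;
    dropout is omitted because it is inactive during inference.
    """
    # MSA stack, lines 2-4: update m (the row attention reads z as a bias)
    m = m + ops['msa_row_att'](m, z)
    m = m + ops['msa_col_att'](m)
    m = m + ops['msa_transition'](m)
    # Line 5, last step of the MSA stack: information flows from m into z
    z = z + ops['outer_product_mean'](m)
    # Pair stack, lines 6-10: update z only
    z = z + ops['tri_mul_out'](z)
    z = z + ops['tri_mul_in'](z)
    z = z + ops['tri_att_start'](z)
    z = z + ops['tri_att_end'](z)
    z = z + ops['pair_transition'](z)
    return m, z

# Usage with placeholder ops that return zeros and record the call order
N_seq, N_res, c_m, c_z = 4, 6, 8, 5
msa_ops = ['msa_row_att', 'msa_col_att', 'msa_transition']
pair_ops = ['tri_mul_out', 'tri_mul_in', 'tri_att_start', 'tri_att_end', 'pair_transition']
calls = []

def make_zero_op(name):
    def op(x, *rest):
        calls.append(name)
        return torch.zeros_like(x)   # same shape as the representation being updated
    return op

def zero_outer_product_mean(m):
    calls.append('outer_product_mean')
    return torch.zeros(m.shape[-2], m.shape[-2], c_z)   # (N_res, N_res, c_z)

ops = {name: make_zero_op(name) for name in msa_ops + pair_ops}
ops['outer_product_mean'] = zero_outer_product_mean

m = torch.randn(N_seq, N_res, c_m)
z = torch.randn(N_res, N_res, c_z)
m_out, z_out = evoformer_block_sketch(m, z, ops)
print(calls)
```

Since every update is residual, zero-output placeholders leave m and z unchanged, which makes this skeleton a convenient way to check the ordering in isolation.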

We will start with MSARowAttentionWithPairBias, described in Algorithm 7. Note that, as always when working with attention in this context, the algorithm spells out the attention mechanism explicitly, but we don't need to worry too much about it, since we already implemented it.

Specifically, lines 2, 4, 5, 6 and 7 are already covered by our MultiHeadAttention module: all we need to check is along which dimension the attention is computed, whether it is gated, and whether it uses a bias.

Regarding the attention dimension: the input feature m for the attention mechanism has shape (*, N_seq, N_res, c_m), where N_seq indexes the rows and N_res indexes the columns. The attention is computed row-wise, which means that within each row, the residues attend to each other: the index that is actually iterated over in the attention mechanism is the column index. All other dimensions are simply broadcast. As the * dimensions are unknown, we will use negative indexing to specify the attention dimension.
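To make the attention dimension concrete, here is a minimal, self-contained sketch of Algorithm 7 written directly with `torch.einsum`, without our MultiHeadAttention module. The class name `RowAttentionSketch` and its parameter layout are assumptions, so its weights are not compatible with the AlphaFold parameters. The attention logits have shape (*, N_seq, N_head, N_res, N_res): for each row s, every residue i attends over all residues j of that row, and the pair bias is broadcast over N_seq.

```python
import math
import torch
from torch import nn

class RowAttentionSketch(nn.Module):
    """Gated row-wise self-attention over N_res with a pair bias (cf. Algorithm 7).

    Hypothetical sketch; not the tutorial's MSARowAttentionWithPairBias.
    """
    def __init__(self, c_m, c_z, c=32, N_head=8):
        super().__init__()
        self.c, self.N_head = c, N_head
        self.layer_norm_m = nn.LayerNorm(c_m)
        self.layer_norm_z = nn.LayerNorm(c_z)
        self.linear_qkv = nn.Linear(c_m, 3 * N_head * c, bias=False)
        self.linear_bias = nn.Linear(c_z, N_head, bias=False)
        self.linear_gate = nn.Linear(c_m, N_head * c)
        self.linear_out = nn.Linear(N_head * c, c_m)

    def forward(self, m, z):
        # m: (*, N_seq, N_res, c_m), z: (*, N_res, N_res, c_z)
        m = self.layer_norm_m(m)
        q, k, v = self.linear_qkv(m).chunk(3, dim=-1)
        # Split heads: (*, N_seq, N_res, N_head, c)
        q, k, v = (t.unflatten(-1, (self.N_head, self.c)) for t in (q, k, v))
        # Pair bias b_ij^h: (*, 1, N_head, N_res, N_res), broadcast over N_seq
        b = self.linear_bias(self.layer_norm_z(z)).movedim(-1, -3).unsqueeze(-4)
        # Logits for residue pairs (i, j) within each row s: (*, N_seq, N_head, N_res, N_res)
        logits = torch.einsum('...ihc,...jhc->...hij', q, k) / math.sqrt(self.c) + b
        a = torch.softmax(logits, dim=-1)                  # normalize over the keys j
        o = torch.einsum('...hij,...jhc->...ihc', a, v)    # (*, N_seq, N_res, N_head, c)
        g = torch.sigmoid(self.linear_gate(m)).unflatten(-1, (self.N_head, self.c))
        return self.linear_out((g * o).flatten(-2))

torch.manual_seed(0)
att = RowAttentionSketch(c_m=8, c_z=4, c=6, N_head=2)
m = torch.randn(3, 5, 8)       # (N_seq, N_res, c_m)
z = torch.randn(5, 5, 4)       # (N_res, N_res, c_z)
out = att(m, z)
print(out.shape)               # torch.Size([3, 5, 8])

# Rows are processed independently: replacing row 0 leaves rows 1 and 2 unchanged
m2 = m.clone()
m2[0] = torch.randn(5, 8)
out2 = att(m2, z)
print(torch.allclose(out[1:], out2[1:], atol=1e-6))
```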

With all of that in mind, head over to `msa_stack.py` and implement `MSARowAttentionWithPairBias`. After you are done, check your code by running the following cell.

In [4]:
from evoformer.msa_stack import MSARowAttentionWithPairBias
from evoformer.control_values.evoformer_checks import c_m, c_z, c, N_head
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

msa_row_att = MSARowAttentionWithPairBias(c_m, c_z, c, N_head)

test_module_shape(msa_row_att, 'msa_row_att', control_folder)

test_module(msa_row_att, 'msa_row_att', ('m', 'z'), 'out', control_folder)




Next up is MSAColumnAttention, which is just like the row attention we just implemented, but with a different attention dimension and without a pair bias. It is described in Algorithm 8. Read through it carefully and try to identify where it differs from Algorithm 7. Can you tell from the pseudocode that Algorithm 7 uses row-wise attention, while Algorithm 8 uses column-wise attention?
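One way to see the relationship: column-wise attention is row-wise attention applied to the MSA with its N_seq and N_res dimensions swapped. The sketch below checks this with a bare, unparameterized single-head attention (no projections, gating or bias); the function `attend` is a hypothetical helper, not part of the tutorial code.

```python
import math
import torch

def attend(x, dim):
    """Plain single-head self-attention along `dim` (q = k = v = x), no parameters."""
    x = x.movedim(dim, -2)                                   # (..., L, c), L = attended dim
    logits = x @ x.transpose(-1, -2) / math.sqrt(x.shape[-1])
    out = torch.softmax(logits, dim=-1) @ x
    return out.movedim(-2, dim)                              # restore the original layout

m = torch.randn(4, 6, 3)         # (N_seq, N_res, c)

row_wise = attend(m, dim=-2)     # within each sequence, attend over residues
col_wise = attend(m, dim=-3)     # within each residue column, attend over sequences

# Column attention == row attention on the transposed MSA, transposed back
col_via_rows = attend(m.transpose(-2, -3), dim=-2).transpose(-2, -3)
print(torch.allclose(col_wise, col_via_rows))  # True
```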

Implement `MSAColumnAttention` and check your implementation with the following cell.

In [5]:
from evoformer.msa_stack import MSAColumnAttention
from evoformer.control_values.evoformer_checks import c_m, c, N_head
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

msa_col_att = MSAColumnAttention(c_m, c, N_head)

test_module_shape(msa_col_att, 'msa_col_att', control_folder)

test_module(msa_col_att, 'msa_col_att', 'm', 'out', control_folder)

The MSATransition is a two-layer feed-forward neural network, just like the one we implemented in the intro to ML for handwritten digit recognition.

There is a major difference, however: for handwritten digit recognition, we flattened the whole image and fed it into the feed-forward network. This way, the network could combine information by comparing the values at different positions with each other. Here, the different positions are processed completely independently of each other. For the shape (*, N_seq, N_res, c_m) of m, the dimensions N_seq and N_res are simply broadcast, and only the embedding dimension c_m is transformed. This is important to understand: so far, the only cross-talk between different residues happens through the attention mechanism.
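A minimal sketch of such a transition layer, assuming the structure of Algorithm 9 (LayerNorm, a Linear layer expanding c_m to n·c_m, ReLU, and a Linear layer back to c_m). The class name `TransitionSketch` is hypothetical. The check afterwards illustrates the point above: changing the features at one position leaves the output at every other position untouched.

```python
import torch
from torch import nn

class TransitionSketch(nn.Module):
    """Position-wise two-layer feed-forward network (cf. Algorithm 9)."""
    def __init__(self, c, n=4):
        super().__init__()
        self.layer_norm = nn.LayerNorm(c)
        self.linear_1 = nn.Linear(c, n * c)
        self.linear_2 = nn.Linear(n * c, c)

    def forward(self, x):
        # nn.LayerNorm and nn.Linear act on the last dimension only,
        # so all leading dimensions (*, N_seq, N_res) are simply broadcast.
        return self.linear_2(torch.relu(self.linear_1(self.layer_norm(x))))

torch.manual_seed(0)
trans = TransitionSketch(c=8, n=4)
m = torch.randn(3, 5, 8)               # (N_seq, N_res, c_m)
m_changed = m.clone()
m_changed[1, 2] = torch.randn(8)       # modify a single position (s=1, i=2)

out, out_changed = trans(m), trans(m_changed)
# (N_seq, N_res) mask of positions whose output changed
differs = ~torch.isclose(out, out_changed).all(dim=-1)
print(differs.nonzero().tolist())
```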

Implement the `MSATransition` as described in Algorithm 9 and check your implementation by running the following cell.

In [6]:
from evoformer.msa_stack import MSATransition
from evoformer.control_values.evoformer_checks import c_m
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

n = 3
msa_trans = MSATransition(c_m, n)

test_module_shape(msa_trans, 'msa_transition', control_folder)

test_module(msa_trans, 'msa_transition', 'm', 'out', control_folder)

Next up is OuterProductMean. It carries information from the MSA representation to the pair representation. The MSA representation is of shape (\*, N_seq, N_res, c_m), and the pair representation is of shape (\*, N_res, N_res, c_z). The difference between the embedding dimensions c_m and c_z is easy to handle, as switching embedding dimensions is exactly what our Linear modules do. Thinking back to the tensor introduction, to duplicate a dimension (N_res), we need to broadcast it in an outer, each-with-each fashion. To lose N_seq, we need to contract along this dimension, which is done by computing the mean over the sequences. Hence the name: OuterProductMean.
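The outer-product-and-mean step can be written as a single `torch.einsum`. This sketch only shows the shape bookkeeping: `a` and `b` are random stand-ins for the two projections of m from Algorithm 10, and an explicit loop over residue pairs is included for comparison.

```python
import torch

N_seq, N_res, c = 4, 6, 3
a = torch.randn(N_seq, N_res, c)   # stand-in for the projection a_si
b = torch.randn(N_seq, N_res, c)   # stand-in for the projection b_sj

# Outer product for every residue pair (i, j), contracted over the sequences s
outer_sum = torch.einsum('...sic,...sjd->...ijcd', a, b)   # (N_res, N_res, c, c)
o = outer_sum.flatten(-2) / N_seq                          # mean over s: (N_res, N_res, c*c)
print(o.shape)  # torch.Size([6, 6, 9])

# The same result with an explicit loop over residue pairs
o_loop = torch.empty(N_res, N_res, c * c)
for i in range(N_res):
    for j in range(N_res):
        outers = torch.stack([torch.outer(a[s, i], b[s, j]) for s in range(N_seq)])
        o_loop[i, j] = outers.mean(dim=0).flatten()
print(torch.allclose(o, o_loop, atol=1e-6))  # True
```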

This is also the first point where our implementation differs from the pseudocode in the paper. This happens whenever the open-source implementation differs from the paper: we need to stick with the code rather than the supplement to be able to use their trained parameters in our model. The pseudocode computes the mean in line 3, but the code only sums over the sequences at this point. Instead, the output of line 4 is divided by N_seq.

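That is a real difference, since the bias b of the Linear layer in line 4 is divided by n = N_seq as well (here, a denotes the summed outer products and W the weight matrix):
- Paper: $W\frac{a}{n} + b$
- Code: $\frac{Wa + b}{n} = W\frac{a}{n} + \frac{b}{n}$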
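The two orderings can be compared numerically. In this sketch, a randomly initialized `nn.Linear` stands in for the final projection of Algorithm 10 and `a` for the summed outer products (both assumptions with arbitrary values); the difference between the two results should be exactly b - b/n.

```python
import torch
from torch import nn

torch.manual_seed(0)
n = 4                            # number of sequences N_seq
linear = nn.Linear(9, 5)         # stand-in for the final Linear layer (W, b)
a = torch.randn(6, 6, 9)         # summed (not averaged) outer products

paper = linear(a / n)            # W (a/n) + b
code = linear(a) / n             # (W a + b) / n

# The two versions differ only in how the bias is scaled
expected_diff = (linear.bias - linear.bias / n).expand_as(paper)
print(torch.allclose(paper - code, expected_diff, atol=1e-6))  # True
```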

Implement `OuterProductMean` and check your code by running the following cell.

In [7]:
from evoformer.msa_stack import OuterProductMean
from evoformer.control_values.evoformer_checks import c_m, c_z, c
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

opm = OuterProductMean(c_m, c_z, c)
test_module_shape(opm, 'outer_product_mean', control_folder)

test_module(opm, 'outer_product_mean', 'm', 'z_out', control_folder)


## The Pair Stack

At its core, the pair stack consists of the TriangleMultiplication and TriangleAttention updates. These have an interesting justification which essentially boils down to "compute along the rows" or "compute along the columns" again. Still, given the huge success of AlphaFold, it seems worthwhile to take a closer look at their interpretation, as it might help in other scenarios as well.

The pair representation describes the relationships between the residues. For each residue pair i and j, it has two different entries: z_ij and z_ji. The authors of AlphaFold describe the residues as nodes in a graph, where the elements of the pair representation correspond to the directed edges of this graph.

<figure align=center style="padding: 30px">
<img src='images/graph_representation_alphafold.png' height=300px>
<figcaption>Source: Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure prediction with AlphaFold. Nature 596, 583–589 (2021).</figcaption>
</figure>

If you wanted to, you could imagine z_ij as information sent from residue i to residue j, maybe a glutamate telling a glycine that it is negatively charged, while z_ji is information traveling from j to i, like the glycine responding that it doesn't care since it is neutral. This would obviously be an overinterpretation: the values don't encode any such concrete information, they are just numbers that get the job done.

One point the paper stresses is that for the computation of an edge update (let's say z_ij), other edges z_ik aren't considered on their own, but jointly with the third edge z_jk that closes the triangle. We will see how this is concretely implemented in the TriangleMultiplication and TriangleAttention sections. The authors suggest that considering this third edge might prod the network into learning consistencies that matter for this geometric problem, like the triangle inequality (if i is close to k and k is close to j, then i can't be far from j).

### Triangle Multiplication

Triangle Multiplication is algorithmically simple. There are two different versions, using "outgoing" edges and using "incoming" edges.

For "outgoing" edges, z_ij is computed by multiplying the i-th row elementwise with the j-th row and then contracting over the column dimension.

For "incoming" edges, z_ij is computed by multiplying the i-th column elementwise with the j-th column and then contracting over the row dimension.

Take a look at Algorithm 11 and Algorithm 12 and see if you can confirm that this is indeed the case. Going back to our graph interpretation, the following picture from the AlphaFold paper illustrates how this refers to incoming and outgoing edges.

<figure align=center style="padding: 30px">
<img src='images/graph_representation_multiplicative.png' height=300px>
<figcaption>Source: Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure prediction with AlphaFold. Nature 596, 583–589 (2021).</figcaption>
</figure>

We will go through it for outgoing edges. The i-th row consists of the elements z_ik for every k. The algorithm multiplies the edge z_ik with z_jk and sums these products over all k. This way, the outgoing edges z_ik and z_jk are always used jointly for the computation of the third edge z_ij.
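Stripped of the projections, gating and normalization of Algorithms 11 and 12, the core contractions are two einsums. In this sketch, `a` and `b` are random stand-ins for the two gated projections of the pair representation; the check confirms that the incoming version equals the outgoing version applied to the transposed inputs.

```python
import torch

N_res, c = 5, 4
a = torch.randn(N_res, N_res, c)   # stand-in for the gated projection a_ij
b = torch.randn(N_res, N_res, c)   # stand-in for the gated projection b_ij

# Outgoing edges: row i times row j, summed over the column index k
out_edges = torch.einsum('ikc,jkc->ijc', a, b)
# Incoming edges: column i times column j, summed over the row index k
in_edges = torch.einsum('kic,kjc->ijc', a, b)

# Incoming on (a, b) equals outgoing on the transposed (a, b)
a_t, b_t = a.transpose(0, 1), b.transpose(0, 1)
print(torch.allclose(in_edges, torch.einsum('ikc,jkc->ijc', a_t, b_t)))  # True

# A single entry, written as an explicit sum over k
i, j = 1, 3
print(torch.allclose(out_edges[i, j], (a[i] * b[j]).sum(dim=0)))  # True
```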

In the file `pair_stack.py`, implement the `__init__` and `forward` method of `TriangleMultiplication`. After you're done, check your code by running the following cell.

In [8]:
from evoformer.pair_stack import TriangleMultiplication
from evoformer.control_values.evoformer_checks import c_z, c
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

tri_mul_in = TriangleMultiplication(c_z, 'incoming', c)
tri_mul_out = TriangleMultiplication(c_z, 'outgoing', c)

test_module_shape(tri_mul_in, 'tri_mul_in', control_folder)
test_module_shape(tri_mul_out, 'tri_mul_out', control_folder)

test_module(tri_mul_in, 'tri_mul_in', 'z', 'z_out', control_folder)
test_module(tri_mul_out, 'tri_mul_out', 'z', 'z_out', control_folder)


### Triangle Attention

Just like TriangleMultiplication, TriangleAttention comes in two versions: triangle self-attention around the starting node, and triangle self-attention around the ending node. Attention around the starting node is row-wise attention using the (embedded) pair representation as bias, while attention around the ending node is column-wise attention using the transposed pair representation as bias.

The following image shows how this relates to the starting and ending nodes.

<figure align=center style="padding: 30px">
<img src='images/graph_representation_tri_attention.png' height=300px>
<figcaption>Source: Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure prediction with AlphaFold. Nature 596, 583–589 (2021).</figcaption>
</figure>

For attention around the starting node, we update the edge z_ij with an attention mechanism that looks at all the other edges z_ik starting from i. This is row-wise attention. How much z_ik contributes to the update of z_ij is determined by the query-key similarity of z_ij and z_ik, plus the pair bias. Given the structure of the attention matrix, this is the entry jk (our order for the attention weights in MultiHeadAttention is (*, q, k), meaning that the query index determines the first index and the key index the second). The bias added at entry jk is computed from z_jk, so z_jk influences how much z_ik contributes to the update of z_ij. This is the third edge of the triangle in the graph.

For attention around the ending node, the update of z_ij looks at all the other edges z_kj ending in j. This is column-wise attention, and the relevant entry in the attention matrix is ik (as z_ij is the query and z_kj is the key). Since we transpose the pair bias, the bias at entry ik is computed from z_ki, so z_ki additionally influences how much z_kj contributes to the update of z_ij. As can be seen in the picture, z_ki is the third edge of the triangle.

The choice of transposing the bias for the ending node and not for the starting node is somewhat arbitrary; it only determines the direction of the gray edge in the image. You could say that attention around the starting node is more about outgoing edges, and z_jk is an outgoing edge of the triangle (when we focus on i and j), while attention around the ending node focuses on incoming edges, and z_ki is an incoming edge. In any case, it seems sensible to transpose the bias for one of the two operations, since the direction of information flow is "inverted" between them.
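The relationship between the two versions can be checked numerically. This sketch uses a single head with q = k = v = z and a bias b = z·w from a random vector w (simplifications, not the parametrization of Algorithms 13 and 14), implements both formulas directly, and confirms that attention around the ending node equals attention around the starting node applied to the transposed pair representation, transposed back. Computing the bias from the transposed z is what makes the entry b_ki appear.

```python
import math
import torch

def tri_att_start(z, w):
    """Around the starting node: z_ij attends over z_ik, biased by b_jk."""
    bias = z @ w                                                  # (N, N)
    logits = torch.einsum('ijc,ikc->ijk', z, z) / math.sqrt(z.shape[-1])
    logits = logits + bias[None, :, :]                            # entry (i, j, k) gets b_jk
    attn = torch.softmax(logits, dim=-1)                          # normalize over k
    return torch.einsum('ijk,ikc->ijc', attn, z)

def tri_att_end(z, w):
    """Around the ending node: z_ij attends over z_kj, biased by b_ki."""
    bias = z @ w                                                  # (N, N)
    logits = torch.einsum('ijc,kjc->ijk', z, z) / math.sqrt(z.shape[-1])
    logits = logits + bias.T[:, None, :]                          # entry (i, j, k) gets b_ki
    attn = torch.softmax(logits, dim=-1)
    return torch.einsum('ijk,kjc->ijc', attn, z)

torch.manual_seed(0)
z = torch.randn(5, 5, 4)
w = torch.randn(4)

via_transpose = tri_att_start(z.transpose(0, 1), w).transpose(0, 1)
print(torch.allclose(tri_att_end(z, w), via_transpose, atol=1e-6))  # True
```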

Implement the `__init__` and `forward` method of `TriangleAttention`. After you're done, check your code with the following cell.

In [9]:
from evoformer.pair_stack import TriangleAttention
from evoformer.control_values.evoformer_checks import c_z, c, N_head
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

tri_att_start = TriangleAttention(c_z, 'starting_node', c, N_head)
tri_att_end = TriangleAttention(c_z, 'ending_node', c, N_head)

test_module_shape(tri_att_start, 'tri_att_start', control_folder)
test_module_shape(tri_att_end, 'tri_att_end', control_folder)

test_module(tri_att_start, 'tri_att_start', 'z', 'z_out', control_folder)
test_module(tri_att_end, 'tri_att_end', 'z', 'z_out', control_folder)


### Pair Transition
Just like the MSA stack, the pair stack ends in a 2-layer feed-forward network, the pair transition.

Implement the `__init__` and `forward` method in `PairTransition` and test your code with the following cell.

In [10]:
from evoformer.pair_stack import PairTransition
from evoformer.control_values.evoformer_checks import c_z
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

n = 3
pair_trans = PairTransition(c_z, n)

test_module_shape(pair_trans, 'pair_transition', control_folder)

test_module(pair_trans, 'pair_transition', 'z', 'z_out', control_folder)


### Assembling the Pair Stack

Put together the TriangleMultiplication, TriangleAttention and PairTransition modules according to Algorithm 6 by implementing the `__init__` and `forward` methods in `PairStack`. You can leave the dropout layers out, as they are only active during training and not during inference. Test your code by running the following cell.

In [11]:
from evoformer.pair_stack import PairStack
from evoformer.control_values.evoformer_checks import c_z
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

pair_stack = PairStack(c_z)

test_module_shape(pair_stack, 'pair_stack', control_folder)

test_module(pair_stack, 'pair_stack', 'z', 'z_out', control_folder)


## Assembling the Evoformer

We are close to finishing off the Evoformer. First, implement the `__init__` and `forward` method for `EvoformerBlock` in the file `evoformer.py`. These correspond to lines 2 to 10 of Algorithm 6. Check your code by running the following cell.

In [12]:
from evoformer.evoformer import EvoformerBlock
from evoformer.control_values.evoformer_checks import c_m, c_z
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

evo_block = EvoformerBlock(c_m, c_z)

test_module_shape(evo_block, 'evo_block', control_folder)

test_module(evo_block, 'evo_block', ('m', 'z'), ('m_out', 'z_out'), control_folder)


Last, implement the `__init__` and `forward` methods for `EvoformerStack`. The Evoformer is basically just a list of EvoformerBlocks, with an additional embedding at the end that computes the single representation from the first row of the MSA representation (line 12 in Algorithm 6).
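A minimal sketch of this structure, assuming each block maps (m, z) to (m, z) and that the single representation is a Linear projection of the first MSA row, as in line 12 of Algorithm 6. The class `EvoformerStackSketch` and its `make_block` factory are hypothetical, and `ToyBlock` merely stands in for `EvoformerBlock`.

```python
import torch
from torch import nn

class EvoformerStackSketch(nn.Module):
    """A list of blocks followed by the single-representation projection."""
    def __init__(self, make_block, num_blocks, c_m, c_s):
        super().__init__()
        # Each call creates a new block with its own parameters
        self.blocks = nn.ModuleList([make_block() for _ in range(num_blocks)])
        self.linear_single = nn.Linear(c_m, c_s)

    def forward(self, m, z):
        for block in self.blocks:
            m, z = block(m, z)
        # Line 12: s_i = Linear(m_1i), using the first row of the MSA representation
        s = self.linear_single(m[..., 0, :, :])
        return m, z, s

class ToyBlock(nn.Module):
    """Stand-in for EvoformerBlock: residual position-wise updates only."""
    def __init__(self, c_m, c_z):
        super().__init__()
        self.update_m = nn.Linear(c_m, c_m)
        self.update_z = nn.Linear(c_z, c_z)

    def forward(self, m, z):
        return m + self.update_m(m), z + self.update_z(z)

N_seq, N_res, c_m, c_z, c_s = 4, 6, 8, 5, 7
stack = EvoformerStackSketch(lambda: ToyBlock(c_m, c_z), num_blocks=3, c_m=c_m, c_s=c_s)
m_out, z_out, s_out = stack(torch.randn(N_seq, N_res, c_m), torch.randn(N_res, N_res, c_z))
print(m_out.shape, z_out.shape, s_out.shape)
# torch.Size([4, 6, 8]) torch.Size([6, 6, 5]) torch.Size([6, 7])
```

Using `nn.ModuleList` (rather than a plain Python list) registers the blocks as submodules, so their parameters show up in `parameters()` and `state_dict()`.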

After you're done, check your implementation by running the following cell.

In [13]:
from evoformer.evoformer import EvoformerStack
from evoformer.control_values.evoformer_checks import c_m, c_z
from evoformer.control_values.evoformer_checks import test_module_shape, test_module

num_blocks = 3
c_s = 5

evoformer = EvoformerStack(c_m, c_z, num_blocks, c_s)

test_module_shape(evoformer, 'evoformer', control_folder)

test_module(evoformer, 'evoformer', ('m', 'z'), ('m_out', 'z_out', 's_out'), control_folder)


### Optional: Dropout

We are only using AlphaFold for inference. During inference, dropout layers are replaced by identity mappings, so they don't affect the results. They are only active during training, where they set a random subset of the feature values to zero (they also scale all other values by 1/(1-p), where p is the dropout probability, so that the expected value of every feature is the same as without dropout).
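Shared dropout can be expressed by sampling a mask with size 1 along the shared dimension and letting broadcasting apply it to a whole row or column. The sketch below is one possible approach, not necessarily the interface expected in `dropout.py`; the function `shared_dropout` and its `shared_dim` parameter are hypothetical. Following the checks in the test cell of this section, row-wise dropout shares the mask along dimension -2 and column-wise dropout along dimension -3.

```python
import torch

def shared_dropout(x, p, shared_dim, training=True):
    """Dropout whose mask is shared (broadcast) along `shared_dim`."""
    if not training or p == 0.0:
        return x                                          # identity during inference
    mask_shape = list(x.shape)
    mask_shape[shared_dim] = 1                            # one mask value per slice along shared_dim
    keep = torch.rand(mask_shape, device=x.device) >= p   # keep with probability 1 - p
    # Rescale kept values by 1/(1-p) so the expected value is unchanged
    return x * keep.to(x.dtype) / (1.0 - p)

torch.manual_seed(0)
x = torch.ones(8, 25, 30, 4)
rows_dropped = shared_dropout(x, p=0.2, shared_dim=-2)   # whole rows set to zero
cols_dropped = shared_dropout(x, p=0.3, shared_dim=-3)   # whole columns set to zero

p_rows = torch.count_nonzero(rows_dropped).item() / rows_dropped.numel()
p_cols = torch.count_nonzero(cols_dropped).item() / cols_dropped.numel()
print(p_rows, p_cols)   # fractions of kept values, roughly 0.8 and 0.7
```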

AlphaFold specifically uses shared dropout, where whole rows or columns of the feature tensor are set to zero. If you want to try implementing this yourself, go to `dropout.py` and implement the modules in there. When you are done, set `test_dropout` in the following cell to `True` and run it to test your implementation.

In [14]:
# Set this to `True` if you want to test your dropout implementation.
test_dropout = False

if test_dropout:
    from evoformer.dropout import DropoutRowwise, DropoutColumnwise
    test_shape = (8, 25, 30, 4)
    dropout_rowwise = DropoutRowwise(p=0.2)
    dropout_columnwise = DropoutColumnwise(p=0.3)
    dropout_rowwise.train()
    dropout_columnwise.train()

    test_inp = torch.ones(test_shape)
    rows_dropped = dropout_rowwise(test_inp)
    cols_dropped = dropout_columnwise(test_inp)

    p_nonzero_rows = torch.count_nonzero(rows_dropped).item()/rows_dropped.numel()
    p_nonzero_cols = torch.count_nonzero(cols_dropped).item()/cols_dropped.numel()

    assert abs(p_nonzero_rows - 0.8) < 0.1
    assert abs(p_nonzero_cols - 0.7) < 0.1

    assert torch.std(rows_dropped, dim=-2).sum() == 0
    assert torch.std(cols_dropped, dim=-3).sum() == 0


## Conclusion

We are through with the Evoformer, well done! It consists of a lot of pieces, but with the MultiHeadAttention module we already implemented, each individual element takes only a few lines.

Next up is a quick chapter on feature embedding, which converts the features we extracted from the MSA into the MSA representation m and the pair representation z that we've worked with so much in the Evoformer. Structurally, feature embedding happens before the Evoformer, but since the ExtraMSAStack of feature embedding is basically a modified MSAStack, we've placed that chapter after this one.

Keep up the good work!