## Basic Setup

Run the cells below for the basic setup of this notebook.

In [1]:
try:
    from google.colab import drive # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print('No colab environment, assuming local setup.')

if IN_COLAB:
    drive.mount('/content/drive')

    # TODO: Enter the foldername in your Drive where you have saved the unzipped
    # tutorials folder, e.g. 'alphafold-decoded/tutorials'
    FOLDERNAME = None
    assert FOLDERNAME is not None, "[!] Enter the foldername."

    # Now that we've mounted your Drive, this ensures that
    # the Python interpreter of the Colab VM can load
    # python files from within it.
    import sys
    sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
    %cd /content/drive/My\ Drive/$FOLDERNAME

    print('Connected COLAB to Google Drive.')

import os

base_folder = '../feature_extraction'
control_folder = f'{base_folder}/control_values'

assert os.path.isdir(control_folder), 'Folder "control_values" not found, make sure that FOLDERNAME is set correctly.' if IN_COLAB else 'Folder "control_values" not found, make sure that your root folder is set correctly.'

No colab environment, assuming local setup.


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn
import math
import os

In [4]:
from control_values.obfuscated_solution import create_control_values

# Some of the assert statements depend on the result of random operations,
# which aren't always the same over different versions and operating systems.
# Therefore, the control values are created directly on your system.
create_control_values(".")

# Feature Extraction

In this Notebook, we will implement the feature extraction pipeline for AlphaFold. The pipeline consists of the following steps:

- Parse the a3m file
- Count and remove deletions (residues that are present in the aligned sequences, but aren't present in the query sequence)
- Randomly select cluster centers
- Randomly change some residues from the cluster center (this is called masking)
- Assign non-cluster sequences to their closest cluster center by Hamming Distance
- Summarize the features of all sequences assigned to the cluster
- Select a fixed number of non-cluster sequences as extra sequences
- Assemble the features

Most of the work is done to create the `msa_feat` and the `extra_msa_feat`. The input features additionally consist of the `target_feat` and the `residue_index`, but these are easy to implement.

For an overview of the features, you can read Section 1.2.9 from [AlphaFold's Supplement](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf). 

We will skip over the template features in this series, but predictions often work great without them (ColabFold has the option to use templates disabled by default).

## File parsing

Take a look at the file `alignment_tautomerase.a3m`. In it, you will find the alignment data of the 2-Hydroxymuconate Tautomerase in a3m format. The alignment was generated by ColabFold, which uses MMseqs2 to create alignments.

Note the format of the file: It consists of lines starting with '>' containing an identifier and some values from the alignment, followed by a sequence. The first of these sequences is the query sequence.

In the sequence string, there are upper-case and lower-case letters. Upper-case letters denote aligned residues, while lower-case letters denote residues that are present in the aligned sequence, but not in the query sequence (in other formats, this might be denoted by a dash in the query sequence). These are called deletions.
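To make the format concrete, here is a minimal sketch that reads a made-up a3m snippet from a string. The function name `parse_a3m_text` and the toy alignment are hypothetical, and the sketch assumes that every sequence sits on a single line. It is meant as an illustration of the file layout, not as the reference solution for `load_a3m_file`.

```python
def parse_a3m_text(text):
    """Return the sequences of an a3m-formatted string, in file order."""
    seqs = []
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines and the '>' description lines.
        if not line or line.startswith('>'):
            continue
        seqs.append(line)
    return seqs


# Made-up toy alignment, not taken from alignment_tautomerase.a3m.
toy_a3m = """>query
PIAQ
>hit_1 score=10
PVkrAQ
>hit_2 score=7
P-AQ
"""

toy_seqs = parse_a3m_text(toy_a3m)
print(toy_seqs[0])   # the query sequence comes first
print(toy_seqs[1:])  # aligned sequences, still containing lower-case deletions
```

Note that every aligned sequence has the same number of upper-case letters and dashes as the query; only the lower-case letters change the string length.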

With this knowledge of the file format, implement the method `load_a3m_file` in the file `feature_extraction.py` and test your implementation with the following code cell.

In [5]:
from feature_extraction import load_a3m_file

seqs = load_a3m_file(f'{base_folder}/alignment_tautomerase.a3m')

first_expected = ['PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK', 'PVVTIELWEGRTPEQKRELVRAVSSAISRVLGCPEEAVHVILHEVPKANWGIGGRLASE', 'PVVTIEMWEGRTPEQKKALVEAVTSAVAGAIGCPPEAVEVIIHEVPKVNWGIGGQIASE', 'PIIQVQMLKGRSPELKKQLISEITDTISRTLGSPPEAVRVILTEVPEENWGVGGVPINE', 'PFVQIHMLEGRTPEQKKAVIEKVTQALVQAVGVPASAVRVLIQEVPKEHWGIGGVSARE']

assert len(seqs) == 8361 and seqs[:5] == first_expected

Now, we will parse the individual sequences to remove deletions and one-hot encode them. For one-hot encoding, the classes must have a predetermined order. The usual way to order the residues is to sort the 3-letter codes alphabetically and then use this order for the 1-letter codes.
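The ordering rule can be checked with a few lines of plain Python. The names `three_to_one`, `restypes`, `restype_to_index` and `onehot_row` are hypothetical and chosen only for this sketch; the tutorial file defines its own. Extra classes such as an unknown residue or a gap token would be appended after the 20 amino acids, and how exactly that is done is left to the tutorial code.

```python
# Standard 3-letter to 1-letter codes of the 20 amino acids.
three_to_one = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}

# Sort by 3-letter code, then read off the 1-letter codes in that order.
restypes = [three_to_one[code] for code in sorted(three_to_one)]
restype_to_index = {aa: i for i, aa in enumerate(restypes)}


def onehot_row(letter, num_classes):
    """One-hot encode a single residue as a plain list."""
    row = [0] * num_classes
    row[restype_to_index[letter]] = 1
    return row


print(''.join(restypes))    # ARNDCQEGHILKMFPSTWYV
print(onehot_row('N', 21))  # a single 1 at index 2
```

The resulting order is the same string that the test cell for `onehot_encode_aa_type` uses as its test sequence.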

The order of the amino acids is provided at the top of `feature_extraction.py`. Initialize the two dictionaries as maps from the letter to the index. Then, implement `onehot_encode_aa_type` and check your implementation with the following cell.

In [6]:
from feature_extraction import onehot_encode_aa_type

test_seq = "ARNDCQEGHILKMFPSTWYV"

enc1 = onehot_encode_aa_type(test_seq, include_gap_token=False)
enc2 = onehot_encode_aa_type(test_seq, include_gap_token=True)
enc3 = onehot_encode_aa_type(test_seq+'-', include_gap_token=True)

assert torch.allclose(enc1, nn.functional.one_hot(torch.arange(20), num_classes=21))
assert torch.allclose(enc2, nn.functional.one_hot(torch.arange(20), num_classes=22))
enc3_exp = nn.functional.one_hot(torch.cat((torch.arange(20),torch.tensor([21]))), num_classes=22)
assert torch.allclose(enc3, enc3_exp)

Now implement `initial_data_from_seqs`. The method counts and removes deletions, then removes sequences that are duplicates once the deletions are gone. After that, the method one-hot encodes the residues and calculates the distribution of the residues at each position.
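The three steps can be sketched on a toy alignment with plain Python. Two assumptions are made here and are not confirmed by the notebook: the deletion count is attached to the aligned column that follows the lower-case letters, and of several duplicates the first one is kept. The helper name `split_deletions` is hypothetical.

```python
from collections import Counter


def split_deletions(seq):
    """Return (aligned sequence, deletion count per aligned column)."""
    aligned, counts, pending = [], [], 0
    for ch in seq:
        if ch.islower():
            pending += 1            # a deletion, not an aligned column
        else:
            aligned.append(ch)      # upper-case letter or '-' gap
            counts.append(pending)  # lower-case letters since the last column
            pending = 0
    return ''.join(aligned), counts


raw = ['PIAQ', 'PVkrAQ', 'P-AQ', 'PVAQ']  # made-up sequences

seen, unique, deletion_counts = set(), [], []
for seq in raw:
    aligned, counts = split_deletions(seq)
    if aligned in seen:  # duplicate after removing deletions
        continue
    seen.add(aligned)
    unique.append(aligned)
    deletion_counts.append(counts)

# Relative frequency of each symbol per aligned column.
column_distribution = []
for column in zip(*unique):
    freq = Counter(column)
    column_distribution.append({aa: n / len(unique) for aa, n in freq.items()})

print(unique)                  # ['PIAQ', 'PVAQ', 'P-AQ']
print(deletion_counts[1])      # [0, 0, 2, 0]
print(column_distribution[0])  # {'P': 1.0}
```

Here `'PVAQ'` is dropped because `'PVkrAQ'` already reduces to the same aligned string.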

Test your code by running the following cell:

In [7]:
from feature_extraction import initial_data_from_seqs
seqs = load_a3m_file(f'{base_folder}/alignment_tautomerase.a3m')

features = initial_data_from_seqs(seqs)

expected_features = torch.load(f'{control_folder}/initial_data.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'



## Clustering

A subset of the sequences is randomly selected as cluster centers, always including the query sequence as the first cluster center. Implement `select_cluster_centers` and test your implementation by running the following cell.

In [8]:
from feature_extraction import select_cluster_centers

inp = torch.load(f'{control_folder}/initial_data.pt')

features = select_cluster_centers(inp, seed=0)

expected_features = torch.load(f'{control_folder}/clusters_selected.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'
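
The selection step tested above boils down to a random permutation that keeps index 0 in front. The following sketch shows this on indices only. The function name and the parameter `max_msa_clusters` are hypothetical, and since it uses Python's `random` module instead of torch's generator, it will not reproduce the control values.

```python
import random


def split_clusters_and_extra(n_seqs, max_msa_clusters, seed=None):
    """Return (cluster center indices, extra sequence indices)."""
    rng = random.Random(seed)
    others = list(range(1, n_seqs))
    rng.shuffle(others)     # shuffle everything except the query
    order = [0] + others    # the query always stays first
    return order[:max_msa_clusters], order[max_msa_clusters:]


centers, extra = split_clusters_and_extra(n_seqs=10, max_msa_clusters=4, seed=0)
print(centers)  # starts with 0, followed by three random indices
print(extra)    # the remaining six indices
```

Every per-sequence feature would then be indexed with `centers` to obtain the cluster features and with `extra` to obtain the extra features.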




AlphaFold uses a regularization strategy they call 'masking' on the cluster centers (it is used during training and inference). 

This operation replaces some of the residues randomly according to the following rules:

- 15% of the residues are selected. Of these,
    - 10% are replaced with a uniformly sampled random amino acid.
    - 10% are replaced with an amino acid sampled from the MSA distribution for this position.
    - 10% are not replaced.
    - 70% are replaced with a special token (masked_msa_token).

For this, we will create a probability distribution over the replacement options for each residue, given that it is among the selected 15%. For example, the first option contributes $[0.1\cdot 1/20,\, 0.1\cdot 1/20,\, 0.1\cdot 1/20,\, \ldots]$ to this distribution, and the third option contributes $[0,\, 0,\, 0.1,\, 0,\, 0,\, \ldots]$ if the original residue was N.

After creating the distribution, we will sample a replacement from it and create a mask that selects each position with probability 15%. Every residue selected by the mask is replaced by the value sampled from the distribution.
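As a sketch of how the four rules combine into one distribution, consider a toy layout with 20 amino acid classes followed by a single mask token. This layout, the names `N_AA`, `MASK`, `replacement_distribution` and `mask_sequence`, and the use of Python's `random` module are all assumptions made for illustration. The real method works on tensors and has its own class layout.

```python
import random

N_AA = 20
MASK = N_AA  # index of the mask token in this toy layout (an assumption)


def replacement_distribution(original_idx, profile):
    """Replacement probabilities for one selected position."""
    probs = [0.0] * (N_AA + 1)
    for i in range(N_AA):
        probs[i] += 0.1 * (1.0 / N_AA)  # uniformly sampled amino acid
        probs[i] += 0.1 * profile[i]    # sampled from the MSA distribution
    probs[original_idx] += 0.1          # not replaced
    probs[MASK] += 0.7                  # mask token
    return probs


def mask_sequence(seq_idx, profiles, mask_prob=0.15, seed=None):
    """Replace each position with probability mask_prob."""
    rng = random.Random(seed)
    out = []
    for idx, profile in zip(seq_idx, profiles):
        if rng.random() < mask_prob:
            probs = replacement_distribution(idx, profile)
            idx = rng.choices(range(N_AA + 1), weights=probs, k=1)[0]
        out.append(idx)
    return out


# Original residue N (index 2), MSA column that only ever shows index 5.
profile = [0.0] * N_AA
profile[5] = 1.0
probs = replacement_distribution(2, profile)
print(round(probs[2], 3), round(probs[5], 3), probs[MASK])  # 0.105 0.105 0.7
```

The sketch samples only at the selected positions, while the text describes sampling everywhere and applying the mask afterwards. Both lead to the same distribution of outcomes.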

Implement the method `mask_cluster_centers`. You will find step-by-step instructions in the method body. After you're done, test your implementation by running the following cell.


In [9]:
from feature_extraction import mask_cluster_centers

inp = torch.load(f'{control_folder}/clusters_selected.pt')

features = mask_cluster_centers(inp, seed=1)

expected_features = torch.load(f'{control_folder}/clusters_masked.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'



Every sequence from the ExtraMSA is assigned to the cluster center it shares the most residues with, i.e. the center with the smallest Hamming distance (the number of positions at which two sequences differ). Implement the method `cluster_assignment` and check your implementation with the following cell.

In [10]:
from feature_extraction import cluster_assignment

inp = torch.load(f'{control_folder}/clusters_masked.pt')

features = cluster_assignment(inp)

expected_features = torch.load(f'{control_folder}/clusters_assigned.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'
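
To illustrate the assignment rule tested above, the following sketch works on short made-up strings. It counts every matching position, resolves ties in favor of the first center, and leaves out any special treatment of gaps or masked positions. These are simplifying assumptions, and the function names are hypothetical.

```python
def agreement(a, b):
    """Number of positions at which two equally long sequences agree."""
    return sum(x == y for x, y in zip(a, b))


def assign_to_clusters(centers, extras):
    """Return (center index per extra sequence, count per center)."""
    assignment = []
    for seq in extras:
        scores = [agreement(seq, c) for c in centers]
        # Most shared residues is the same as smallest Hamming distance.
        assignment.append(scores.index(max(scores)))
    counts = [assignment.count(i) for i in range(len(centers))]
    return assignment, counts


centers = ['PIAQ', 'PVAQ']
extras = ['PIAE', 'AVAQ', 'PVCQ']
assignment, counts = assign_to_clusters(centers, extras)
print(assignment)  # [0, 1, 1]
print(counts)      # [1, 2]
```

The two returned lists correspond to the assignment indices and assignment counts that the averaging step consumes.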



Clustering is used to reduce the computational cost of running the prediction to a manageable level. However, the information of the non-selected sequences should still contribute to the prediction. To do so, we calculate averages of the deletion_counts and the one-hot-encoded residues over the clusters.
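A small worked example may help to see what such an average looks like. The sketch below uses one scalar feature per sequence and plain lists, and it assumes that the cluster center itself is included in the average, so the sum is divided by the assignment count plus one. That assumption matches my understanding of AlphaFold but is not stated in this notebook. The function name is hypothetical.

```python
def cluster_average_1d(feature, extra_feature, assignment, assignment_count):
    """Average a scalar feature over each center and its assigned extras."""
    totals = list(feature)  # each center contributes its own value once
    for value, cluster in zip(extra_feature, assignment):
        totals[cluster] += value
    # Divide by the number of assigned extras plus one for the center.
    return [t / (n + 1) for t, n in zip(totals, assignment_count)]


feature = [1.0, 4.0, 10.0]     # one value per cluster center
extra_feature = [3.0, 5.0, 8.0]  # one value per extra sequence
assignment = [0, 0, 1]         # center index for each extra sequence
assignment_count = [2, 1, 0]   # extras assigned to each center

print(cluster_average_1d(feature, extra_feature, assignment, assignment_count))
# [3.0, 6.0, 10.0]
```

A center without assigned sequences keeps its own value. A tensor version would do the same accumulation along the first dimension, for example with `torch.Tensor.index_add_`, while all trailing dimensions are carried along unchanged.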

As the first step, implement the method `cluster_average`, which takes a feature, an extra_feature, the assignment indices and the assignment counts (how many extra sequences are assigned to each of the centers). You will find a step-by-step guide for the method in the method body.

After you're done, check your implementation by running the following cell.

In [11]:
from feature_extraction import cluster_average

# Check for cluster_average

N_clust = 10
N_res = 3
N_extra = 20
dim1 = 5
dim2 = 7
assignment = torch.tensor([7, 1, 1, 8, 3, 4, 7, 1, 4, 4, 9, 8, 4, 8, 1, 5, 8, 8, 8, 5])
assignment_count = torch.tensor([0, 4, 0, 1, 4, 2, 0, 2, 6, 1])

ft1_shape = (N_clust, N_res, dim1)
eft1_shape = (N_extra, N_res, dim1)
ft2_shape = (N_clust, N_res, dim1, dim2)
eft2_shape = (N_extra, N_res, dim1, dim2)

ft1 = torch.linspace(-2, 2, math.prod(ft1_shape)).reshape(ft1_shape)
eft1 = torch.linspace(-2, 2, math.prod(eft1_shape)).reshape(eft1_shape)
ft2 = torch.linspace(-2, 2, math.prod(ft2_shape)).reshape(ft2_shape)
eft2 = torch.linspace(-2, 2, math.prod(eft2_shape)).reshape(eft2_shape)

res1 = cluster_average(ft1, eft1, assignment, assignment_count)
res2 = cluster_average(ft2, eft2, assignment, assignment_count)

expected_res1 = torch.load(f'{control_folder}/cluster_average_res1.pt')
expected_res2 = torch.load(f'{control_folder}/cluster_average_res2.pt')

assert torch.allclose(res1, expected_res1)
assert torch.allclose(res2, expected_res2)



Now, we'll use the method we just created to compute the average deletion counts and the cluster profiles. Implement the method `summarize_clusters` and check your code with the following cell.

In [12]:
from feature_extraction import summarize_clusters

inp = torch.load(f'{control_folder}/clusters_assigned.pt')

features = summarize_clusters(inp)

expected_features = torch.load(f'{control_folder}/clusters_summarized.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'



Up to this point, the number of extra sequences was determined by the MSA algorithm (in our case MMseqs2) and is, therefore, arbitrary. Of these extra sequences, we will randomly select a fixed number to directly contribute to the inference in the form of the extra_msa_feat. This is done by the method `crop_extra_msa`. Implement it and check your code by running the following cell.

In [13]:
from feature_extraction import crop_extra_msa

inp = torch.load(f'{control_folder}/clusters_summarized.pt')

features = crop_extra_msa(inp, seed=2)

expected_features = torch.load(f'{control_folder}/extra_msa_cropped.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in computation of feature {key}.'
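
The cropping tested above can be pictured as drawing one set of random indices and applying it to every extra feature. The sketch below uses hypothetical names (`crop_indices`, `max_extra`) and made-up feature values, and it relies on Python's `random` module, so it does not reproduce the control values.

```python
import random


def crop_indices(n_extra, max_extra, seed=None):
    """Pick at most max_extra distinct indices out of n_extra."""
    rng = random.Random(seed)
    perm = list(range(n_extra))
    rng.shuffle(perm)
    return perm[:max_extra]


extra_seqs = ['PIAE', 'AVAQ', 'PVCQ', 'P-AQ', 'PIAQ']
extra_deletions = [0, 2, 0, 1, 0]  # made-up per-sequence values

idx = crop_indices(len(extra_seqs), max_extra=3, seed=2)
cropped_seqs = [extra_seqs[i] for i in idx]
cropped_deletions = [extra_deletions[i] for i in idx]  # same indices for every feature
print(cropped_seqs, cropped_deletions)
```

If fewer sequences than `max_extra` are available, the slice simply returns all of them.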



## MSA Feat and ExtraMSA Feat

We've successfully populated the feature dict with all the features needed for the msa_feat. In the function `calculate_msa_feat`, we apply some final processing to the features and concatenate them to get the finished msa_feat. The method description contains a step-by-step guide to assemble the feature. After you're done, check your implementation by running the following cell:

In [14]:
from feature_extraction import calculate_msa_feat

inp = torch.load(f'{control_folder}/extra_msa_cropped.pt')

msa_feat = calculate_msa_feat(inp)

expected_feat = torch.load(f'{control_folder}/msa_feat.pt')

assert torch.allclose(msa_feat, expected_feat)
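
One part of the final processing tested above is the treatment of the deletion counts. As far as I recall from the feature table in AlphaFold's Supplement, raw counts are turned into a binary "has deletion" flag and a squashed value $\frac{2}{\pi}\arctan(d/3)$ that stays between 0 and 1. The sketch below shows these two transforms on plain numbers. Treat the formula as an assumption and check it against the method description. The function names are hypothetical.

```python
import math


def has_deletion(count):
    """1.0 if there is at least one deletion, else 0.0."""
    return 1.0 if count > 0 else 0.0


def deletion_value(count):
    """Squash a non-negative deletion count into the range [0, 1)."""
    return 2.0 / math.pi * math.atan(count / 3.0)


for d in [0, 1, 3, 30]:
    print(d, has_deletion(d), round(deletion_value(d), 3))
```

The squashing keeps a few sequences with very long insertions from dominating the feature, while small counts remain distinguishable.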



The ExtraMSA feat is basically a simpler version of the MSA feat, as it doesn't include cluster averages. Implement `calculate_extra_msa_feat` and check your code with the following cell:

In [15]:
from feature_extraction import calculate_extra_msa_feat

inp = torch.load(f'{control_folder}/extra_msa_cropped.pt')

msa_feat = calculate_extra_msa_feat(inp)

expected_feat = torch.load(f'{control_folder}/extra_msa_feat.pt')

assert torch.allclose(msa_feat, expected_feat)



## Putting it all together

We've got all the methods to build the complete input for AlphaFold. `create_features_from_a3m` walks you through putting together all the methods we defined so far. Note that we'll construct two new features:
- target_feat: One-hot encoded query sequence
- residue_index: Index of each residue, simply [0,...,N_res-1] for our prediction.

After you are done with your implementation, check your code by running the following cell:

In [16]:
from feature_extraction import create_features_from_a3m

inp = torch.load(f'{control_folder}/extra_msa_cropped.pt')

batch = create_features_from_a3m(f'{base_folder}/alignment_tautomerase.a3m', seed=0)

expected_batch = torch.load(f'{control_folder}/full_batch.pt')

for key, param in batch.items():
    assert torch.allclose(param, expected_batch[key]), f'Error in computation of feature {key}.'

assert batch['target_feat'].dtype == torch.float32, f"Target feat isn't a float, but {batch['target_feat'].dtype}."



## Conclusion

You've completed the feature extraction for AlphaFold – nice work! That step can get a bit tricky. Next up is the Evoformer. Since we've already implemented the MultiHeadAttention module, this part should feel more straightforward. It's where your features start turning into the insights that drive the Structure Module.  See you there!

## My notes

### mask_cluster_centers

- The TODO says `Create a copy of the original 'msa_aatype' data under the key 'true_msa_atype'.`, but following it makes the cell that tests the method fail.
- I feel uneasy about adding the padding category to `msa_aatype`. It changes the meaning the one-hot encoding originally had, so callers have to know when this method may be called, that it cannot be called more than once, etc. It could be worth considering making the mask part of the input `msa_aatype` instead of changing it during the call.

### cluster_average

- It was difficult to get my head around what the inputs to the method were. I think an example usage in the tests would have helped me understand it. In general, loading values from the control folder and asserting against them makes it hard to use the tests as documentation or, in this instance, as a guide to the code to be implemented.