## Take Home Assessment

**Disclaimer**: This assessment is a work in progress, so we apologise in advance for any hiccups. Any feedback is valuable!

**Setup**: You are provided with some training code for a model that takes protein 3D structure and predicts the associated amino acid sequence. This notebook provides the required steps to download the code repository and training data (a subset of the Protein Data Bank), alongside minimal code to call the training loop. Please fork the repository that you can find below and edit your own version.

**Compute**: You will be provided with a [Lambda](https://cloud.lambdalabs.com/) instance with an A10 GPU on an agreed day. For this, we need your public key, and we will share an IP address for accessing the compute instance.

**Evaluation**: The following questions are deliberately open-ended, and no specific answer is expected. The aim is to provide a semi-realistic setup that you might encounter if you were to join our team. We want to assess your ability to probe deep learning models and to come up with solutions that alleviate any limitations you identify. Please write down your answers (e.g. with plots, tables, etc.) in your copy of the repository (e.g. in this notebook or in any other format of your choice) and push them to your fork. Please also document what you did to arrive at your answers. We will discuss them during the onsite interview. Please keep the time commitment under 4 hours.

**Questions**:
1. Log and profile the training loop.  What would you recommend if we wanted to train more quickly? Implement some of your proposals.
2. What kinds of issues will arise as model size increases? How could these be partially alleviated? Implement some of your proposed solutions.
3. The way the dataloader is organized in this project is unusual.  What will happen as we increase the size of the training dataset (e.g. using the AlphaFold database)?  How would you re-organize the code to avoid these issues?  What techniques would you consider using to ensure training scales efficiently with the dataset size?
4. Log the average norm of the weights & activations through training. How would you organize this information to help diagnose training dynamics?  How would you characterize the values you observe here?

# Initial observations

- Hey, this code looks strangely familiar... looks like ProteinMPNN!
- In the cells below, I will answer the questions out-of-order but in a way that I think makes sense.

# Data download

In [1]:
# Download subset of training data
!wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
!tar xvf "pdb_2021aug02_sample.tar.gz"
!rm pdb_2021aug02_sample.tar.gz

--2024-10-04 21:56:24--  https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
Resolving files.ipd.uw.edu (files.ipd.uw.edu)... 128.95.160.134, 128.95.160.135, 2607:4000:406::160:135, ...
Connecting to files.ipd.uw.edu (files.ipd.uw.edu)|128.95.160.134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49690915 (47M) [application/octet-stream]
Saving to: ‘pdb_2021aug02_sample.tar.gz’


2024-10-04 21:56:28 (16.3 MB/s) - ‘pdb_2021aug02_sample.tar.gz’ saved [49690915/49690915]

./pdb_2021aug02_sample/
./pdb_2021aug02_sample/README
./pdb_2021aug02_sample/list.csv
./pdb_2021aug02_sample/pdb/
./pdb_2021aug02_sample/pdb/l3/
./pdb_2021aug02_sample/pdb/l3/5l3p.pt
./pdb_2021aug02_sample/pdb/l3/5l3g_A.pt
./pdb_2021aug02_sample/pdb/l3/5l3f.pt
./pdb_2021aug02_sample/pdb/l3/5l3r_B.pt
./pdb_2021aug02_sample/pdb/l3/4l3o_G.pt
./pdb_2021aug02_sample/pdb/l3/1l3b_E.pt
./pdb_2021aug02_sample/pdb/l3/3l3t_C.pt
./pdb_2021aug02_sample/pdb/l3/6l3y_A.pt
./pdb_2021aug02_sam

# 3. Data loading
The way the dataloader is organized in this project is unusual.  What will happen as we increase the size of the training dataset (e.g. using the AlphaFold database)?  How would you re-organize the code to avoid these issues?  What techniques would you consider using to ensure training scales efficiently with the dataset size?

- The dataloader is already well-organized to minimize the I/O time spent loading from disk.
  - PDBs are preprocessed so that the smallest unit of training data (a single chain from a single PDB ID) lives in its own file.
  - Each of these has a corresponding lightweight metadata file that allows for fast selection of training instances without loading the full example.
  - The dataloader also generates assemblies on the fly. The auxiliary DataLoader, `train_loader`, provides randomly transformed assemblies every `args.reload_data_every_n_epochs` epochs to the primary DataLoader, `loader_train`. This could be a bottleneck, but it runs asynchronously on the CPU during model training, so it does not interfere with training.
- However, if the auxiliary dataloader were used on a huge dataset such as the AlphaFold database, training would outpace it and the loop would stall at `pdb_dict_train = q.get().result()`, waiting for each reload to finish.
  - You could alleviate this by increasing `args.reload_data_every_n_epochs`, but that would probably still be insufficient.
  - I don't fully understand what the transforms here do, but I think they generate the bioassembly from the asymmetric unit. I'm not sure why this is necessary: the copies are rotations and translations of the same asymmetric unit, which on their own offer no additional diversity to a model that is invariant to SE(3) transforms. One caveat I would want to check is that this invariance only covers a global transform of the whole input, whereas an assembly places several transformed copies relative to one another and so can expose inter-chain contacts. The AlphaFold database also doesn't provide such transforms with its predicted structures, so I don't think this information would be available there. I would probably dispense with the auxiliary dataloader and use only the primary dataloader.
  - If these transforms are necessary, I would push them into the normal dataloader as long as I didn't see a data loading bottleneck.
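
The stall described above can be illustrated with a small, self-contained sketch. It assumes the reload works roughly as the `q.get().result()` call suggests: a background worker builds the next training dictionary and a queue hands futures to the training loop. `build_epoch_data`, `train_with_prefetch`, and the sleep durations are hypothetical stand-ins, not code from the repository, and a thread pool is used only to keep the example light.

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor


def build_epoch_data(reload_index, build_seconds):
    """Stand-in for the expensive step that assembles a fresh training dict."""
    time.sleep(build_seconds)
    return {"reload_index": reload_index}


def train_with_prefetch(num_reloads, build_seconds, train_seconds):
    """Consume prefetched data; return what was seen and the seconds spent blocked."""
    q = queue.Queue(maxsize=2)
    stalled = 0.0
    seen = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        q.put(pool.submit(build_epoch_data, 0, build_seconds))
        for i in range(num_reloads):
            # Schedule the next build before consuming the current one.
            if i + 1 < num_reloads:
                q.put(pool.submit(build_epoch_data, i + 1, build_seconds))
            start = time.perf_counter()
            data = q.get().result()  # blocks if the build has not finished yet
            stalled += time.perf_counter() - start
            seen.append(data["reload_index"])
            time.sleep(train_seconds)  # stand-in for the epochs of training
    return seen, stalled


# Slow builds with fast training stall the loop; fast builds with slow training barely do.
_, stalled_slow_build = train_with_prefetch(3, build_seconds=0.05, train_seconds=0.0)
_, stalled_fast_build = train_with_prefetch(3, build_seconds=0.01, train_seconds=0.05)
print(f"stalled {stalled_slow_build:.2f}s vs {stalled_fast_build:.2f}s")
```

In this toy setup the time spent blocked grows with the build time once training is faster than the builder, which is the regime I would expect with a much larger dataset.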

# 1. Profiling the training loop
Log and profile the training loop.  What would you recommend if we wanted to train more quickly? Implement some of your proposals.

## Profiling existing training loop

In [2]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = True
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = True
    self.cross_attention_num_heads = 4
    self.mixed_precision = False
    self.compute_categorical_jacobian = False

args = MyArgs()
run_training(args)

  scaler = torch.cuda.amp.GradScaler()
  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))

epoch: 1, step: 7, time: 1.5, train: 58.141, valid: 50.003, train_acc: 0.016, valid_acc: 0.013
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
    autograd::engine::evaluate_function: AddmmBackward0         0.47%       5.231ms         5.12%      57.143ms     181.407us       0.000us         0.00%      97.829ms     310.567us           315  
                                               aten::mm         1.23%      13.739ms         2.89%      32.194ms 

100%|██████████| 7/7 [00:00<00:00, 17.12it/s]


epoch: 2, step: 14, time: 0.4, train: 44.589, valid: 34.875, train_acc: 0.026, valid_acc: 0.024
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
    autograd::engine::evaluate_function: AddmmBackward0         1.52%       5.035ms        13.07%      43.259ms     137.331us       0.000us         0.00%      99.454ms     315.726us           315  
                                               aten::mm         4.19%      13.870ms         6.12%      20.263ms

## Observations
CPU time is roughly equal to GPU time, so without more extensive profiling, it looks like the dataloader is doing a decent job of keeping the GPU fed. This means that speeding up the model's computation is likely to help. I implemented automatic mixed precision (AMP), with the caveat that I was directly inspired by the ProteinMPNN implementation.
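
For reference, a mixed-precision training step generally looks like the sketch below. This is not the code in the repository: `amp_train_step`, the toy `nn.Linear` model, and the random data are assumptions made for illustration, and the step falls back to full precision when no GPU is present.

```python
import torch
from torch import nn


def amp_train_step(model, optimizer, scaler, inputs, targets, use_amp):
    """Run one optimizer step, with autocast and loss scaling when use_amp is True."""
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, eligible ops run in reduced precision.
    with torch.autocast(device_type=inputs.device.type, enabled=use_amp):
        logits = model(inputs)
        loss = nn.functional.cross_entropy(logits, targets)
    # Loss scaling guards against float16 gradient underflow; it is a no-op when disabled.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # fall back to full precision on CPU
model = nn.Linear(16, 20).to(device)  # toy stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Same constructor that appears in the notebook output; recent PyTorch releases
# suggest torch.amp.GradScaler("cuda", ...) instead.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
inputs = torch.randn(8, 16, device=device)
targets = torch.randint(0, 20, (8,), device=device)
loss_value = amp_train_step(model, optimizer, scaler, inputs, targets, use_amp)
```

Keeping `use_amp` as a flag makes it easy to compare the two settings with otherwise identical code, which is how the `mixed_precision` argument is used in the cells here.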

## Profile with AMP

In [1]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = True
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = True
    self.cross_attention_num_heads = 4
    self.mixed_precision = True
    self.compute_categorical_jacobian = False

args = MyArgs()
run_training(args)

  scaler = torch.cuda.amp.GradScaler()
  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))

epoch: 1, step: 6, time: 1.3, train: 58.912, valid: 51.906, train_acc: 0.018, valid_acc: 0.009
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               Optimizer.step#Adam.step         0.00%       0.000us         0.00%       0.000us       0.000us      79.188ms        27.98%      79.188ms      13.198ms             6  
                                            aten::copy_         3.81%      45.533ms        10.45%     124.743ms 

100%|██████████| 6/6 [00:00<00:00, 14.97it/s]


epoch: 2, step: 12, time: 0.4, train: 47.865, valid: 38.174, train_acc: 0.021, valid_acc: 0.018
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         5.45%      17.711ms        18.41%      59.862ms      25.670us      37.024ms        17.82%      39.670ms      17.011us          2332  
                                           aten::linear         1.13%       3.681ms        15.60%      50.725ms

## Observations
Training with AMP noticeably reduced CUDA time. It looks like we may now be bottlenecked by the dataloader. I would like to do more profiling to see whether we are stalling there, but I ran out of time. I would also like to train for longer to assess the effect of AMP on model quality.
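
One lightweight way to check for dataloader stalls would be to time how long the loop waits for each batch versus how long each step takes. The sketch below is an assumption about how that could be done, not something I ran on the real loop: `measure_data_wait` and `slow_loader` are hypothetical helpers. On a GPU, the step function would need to end with `torch.cuda.synchronize()`, since CUDA kernels launch asynchronously and the compute time would otherwise be under-counted.

```python
import time


def measure_data_wait(loader, train_step):
    """Return (seconds spent waiting for batches, seconds spent inside train_step)."""
    wait = 0.0
    compute = 0.0
    iterator = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(iterator)  # time spent here is data-loading wait
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    return wait, compute


def slow_loader(num_batches, delay):
    """Toy loader that takes `delay` seconds to produce each batch."""
    for i in range(num_batches):
        time.sleep(delay)
        yield i


wait, compute = measure_data_wait(slow_loader(5, 0.02), lambda batch: None)
print(f"data wait fraction: {wait / (wait + compute):.2f}")
```

A wait fraction close to zero would suggest the GPU is the bottleneck, while a large fraction would point at data loading.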

# 2. Model improvement
What kinds of issues will arise as model size increases? How could these be partially alleviated? Implement some of your proposed solutions.

You could hit out-of-memory errors and slower training as the model grows, since the memory and compute cost of updating each node from all other nodes scales quadratically with sequence length. However, the `k_neighbors` argument, which limits the cost to O(N * k), alleviates that issue.
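
As a rough worked example of the quadratic versus O(N * k) point, the snippet below uses the notebook's own settings (`max_protein_length = 2000`, `num_neighbors = 32`, `hidden_dim = 128`). It assumes one float32 hidden vector per edge, which is a simplification rather than an exact account of what the model stores, and `edge_tensor_bytes` is a hypothetical helper.

```python
def edge_tensor_bytes(num_residues, neighbors_per_residue, hidden_dim, bytes_per_value=4):
    """Bytes for one hidden vector per (residue, neighbor) edge, float32 by default."""
    return num_residues * neighbors_per_residue * hidden_dim * bytes_per_value


n, k, d = 2000, 32, 128
dense_bytes = edge_tensor_bytes(n, n, d)   # every residue sees every residue
sparse_bytes = edge_tensor_bytes(n, k, d)  # every residue sees only k neighbors
print(f"{dense_bytes / 1e9:.2f} GB vs {sparse_bytes / 1e6:.1f} MB")  # 2.05 GB vs 32.8 MB
```

That is a factor of N / k = 62.5 for a single edge tensor of a single maximum-length example, before counting layers or gradients.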

Possible improvements:
- Train on spatial crops. Similar to k_neighbors, but would allow full self-attention. Training algorithm:
  - Sample a pdb
  - Sample a residue in that pdb
  - Condition on all structure within a 50 Å radius of that residue.
  - Supervise all residues within a 20 Å crop of that residue (to avoid training on residues near the edge of the spatial crop).
  - At inference time, you would scan over residues and take a fresh spatial crop around each one before predicting it. This has the downside of re-encoding the structure at every step. The `k_neighbors` approach, where the full structure is encoded once and node embeddings are then selected by neighborhood, is a much more elegant solution.
- Use a Mamba-like state-space architecture, which scales linearly with sequence length, in place of quadratic attention.
- Ablate encoder size against decoder size: for a fixed FLOP budget, what is the optimal split between the two?
- This is not necessarily a model scaling improvement, but I noticed that this model does not use true attention. 
  - I made changes to model_utils.py to replace the message-passing "attention" with true cross attention to see if that offered improvement.

## Model with message passing

In [2]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 50
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = False
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = False
    self.cross_attention_num_heads = 4
    self.mixed_precision = True
    self.compute_categorical_jacobian = False

args = MyArgs()
run_training(args)

  scaler = torch.cuda.amp.GradScaler()
  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))

epoch: 1, step: 4, time: 1.4, train: 52.935, valid: 51.814, train_acc: 0.046, valid_acc: 0.071
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               Optimizer.step#Adam.step         0.00%       0.000us         0.00%       0.000us       0.000us      76.432ms        33.52%      76.432ms      19.108ms             4  
                                           aten::linear         0.22%       2.779ms        20.34%     253.854ms 

100%|██████████| 4/4 [00:00<00:00, 16.09it/s]


epoch: 2, step: 8, time: 0.3, train: 48.275, valid: 41.828, train_acc: 0.043, valid_acc: 0.069
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.12%       2.427ms        15.27%      32.983ms      85.228us       0.000us         0.00%      31.219ms      80.669us           387  
                                            aten::copy_         5.22%      11.263ms        23.00%      49.663ms 

100%|██████████| 4/4 [00:00<00:00, 16.40it/s]


epoch: 3, step: 12, time: 0.3, train: 40.520, valid: 32.893, train_acc: 0.045, valid_acc: 0.064
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.16%       2.484ms        15.66%      33.504ms      86.575us       0.000us         0.00%      31.260ms      80.775us           387  
                                            aten::copy_         5.44%      11.635ms        23.58%      50.436ms

100%|██████████| 4/4 [00:00<00:00, 14.91it/s]


epoch: 4, step: 16, time: 0.3, train: 33.371, valid: 26.795, train_acc: 0.048, valid_acc: 0.063
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.01%       2.436ms        13.76%      33.150ms      85.660us       0.000us         0.00%      28.346ms      73.245us           387  
                                            aten::copy_        18.34%      44.185ms        46.40%     111.787ms



epoch: 5, step: 21, time: 0.7, train: 28.368, valid: 23.078, train_acc: 0.062, valid_acc: 0.095




-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         0.95%       3.077ms        23.33%      75.506ms     159.633us       0.000us         0.00%      34.767ms      73.503us           473  
                                            aten::copy_        10.75%      34.773ms        32.36%     104.717ms      50.661us      28.587ms        15.19%      30.783ms      14.892us          2067  
    autog

100%|██████████| 5/5 [00:00<00:00, 13.90it/s]


epoch: 6, step: 26, time: 0.4, train: 25.421, valid: 20.574, train_acc: 0.080, valid_acc: 0.099




-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.00%       3.098ms        13.85%      42.983ms      90.874us       0.000us         0.00%      35.062ms      74.127us           473  
                                            aten::copy_        20.52%      63.671ms        50.47%     156.600ms      91.418us      27.124ms        14.46%      29.391ms      17.158us          1713  
    autog

100%|██████████| 5/5 [00:00<00:00, 15.21it/s]


epoch: 7, step: 31, time: 0.4, train: 22.342, valid: 19.257, train_acc: 0.091, valid_acc: 0.078
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.07%       3.038ms        21.52%      61.008ms     128.981us       0.000us         0.00%      33.860ms      71.586us           473  
                                            aten::copy_         7.82%      22.177ms        25.14%      71.267ms

100%|██████████| 5/5 [00:00<00:00, 14.79it/s]


epoch: 8, step: 36, time: 0.4, train: 21.554, valid: 18.507, train_acc: 0.083, valid_acc: 0.074
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         1.06%       3.079ms        14.64%      42.376ms      89.590us       0.000us         0.00%      32.414ms      68.529us           473  
                                            aten::copy_        10.99%      31.790ms        35.72%     103.354ms

100%|██████████| 6/6 [00:00<00:00, 12.36it/s]


epoch: 9, step: 42, time: 0.8, train: 20.659, valid: 19.063, train_acc: 0.081, valid_acc: 0.082




RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/profiler_python.cpp":969, please report a bug to PyTorch. Python replay stack is empty.

## Model with full cross attention

In [None]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = False
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = True
    self.cross_attention_num_heads = 4
    self.mixed_precision = False
    self.compute_categorical_jacobian = False

args = MyArgs()
run_training(args)

# 4. Logging activations and weights

Activation visualizations can be helpful to assess whether the model is learning the expected relationships. 

One interesting visualization would show how perturbing some atoms affects the learned representations of others. For a trained model, you would expect atoms that are close together to affect each other more than atoms that are far apart. For true attention-based models, you can visualize the attention maps directly. For this model, full attention maps are not available, but you can get a similar effect by computing the categorical Jacobian. This is a useful sanity check at the start of training: it should be mostly sparse, with non-zero values for neighbors and zeros for non-neighbors, so it should already resemble a contact map. It also shows how well the model is training, since the categorical Jacobian should increasingly recapitulate the contacts as training progresses.

I've added code to compute the categorical Jacobian of the encoder and visualize it against the true contact map as training progresses.
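
To make the sparsity expectation concrete, here is a minimal sketch on a toy model. It is not the categorical Jacobian code added to the repository. It assumes a fixed k-nearest-neighbor graph and a made-up message-passing layer (`knn_indices`, `message_passing`, and `sensitivity_map` are all hypothetical), and it nudges each node's input features to see which outputs move.

```python
import numpy as np


def knn_indices(coords, k):
    """Indices of the k nearest other nodes for every node."""
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)  # exclude self from the neighbor list
    return np.argsort(dist, axis=1)[:, :k]


def message_passing(h, neighbors, num_layers):
    """Toy layer: each node becomes tanh(self + mean of its neighbors' features)."""
    for _ in range(num_layers):
        h = np.tanh(h + h[neighbors].mean(axis=1))
    return h


def sensitivity_map(h, neighbors, num_layers, eps=1e-3):
    """sens[i, j] = size of the change in node j's output when node i's input is nudged."""
    base = message_passing(h, neighbors, num_layers)
    n = h.shape[0]
    sens = np.zeros((n, n))
    for i in range(n):
        bumped = h.copy()
        bumped[i] += eps
        delta = message_passing(bumped, neighbors, num_layers) - base
        sens[i] = np.linalg.norm(delta, axis=1) / eps
    return sens


rng = np.random.default_rng(0)
coords = rng.normal(size=(40, 3)) * 10.0   # toy node positions
features = rng.normal(size=(40, 8)) * 0.5  # toy node features
neighbors = knn_indices(coords, k=4)
one_layer = sensitivity_map(features, neighbors, num_layers=1)
three_layers = sensitivity_map(features, neighbors, num_layers=3)
print((one_layer > 0).mean(), (three_layers > 0).mean())
```

With a single layer, only a node and its direct neighbors respond. Stacking layers widens the receptive field, so in this toy setting a multi-layer encoder gives a map that is denser than the neighbor graph itself.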

## Categorical Jacobian of message passing model

In [None]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = False
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = False
    self.cross_attention_num_heads = 4
    self.mixed_precision = False
    self.compute_categorical_jacobian = True

args = MyArgs()
run_training(args)

## Categorical Jacobian of cross-attention model

In [None]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"
    self.path_for_outputs = "./content/test"
    self.previous_checkpoint = ""
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4
    self.num_examples_per_epoch = 200
    self.batch_size = 2000
    self.max_protein_length = 2000
    self.hidden_dim = 128
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32
    self.dropout = 0.1
    self.backbone_noise = 0.1
    self.rescut = 3.5
    self.debug = False
    self.gradient_norm = -1.0 #no norm
    self.decoder_use_full_cross_attention = True
    self.cross_attention_num_heads = 4
    self.mixed_precision = False
    self.compute_categorical_jacobian = True

args = MyArgs()
run_training(args)

## Observations
1. The categorical Jacobian looks a bit like a contact map. This is promising!
2. The categorical Jacobian is not sparse as I expected it to be. I'd have to do some more debugging here.
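
The question also asks for average weight and activation norms, which I did not get to. A minimal sketch of how they could be collected is below, assuming forward hooks on `nn.Linear` modules are enough to cover the layers of interest. `NormLogger` and `weight_norms` are hypothetical helpers and the `nn.Sequential` model is a toy stand-in.

```python
import torch
from torch import nn


class NormLogger:
    """Records the mean activation norm of each nn.Linear output via forward hooks."""

    def __init__(self, model):
        self.activation_norms = {}
        self.handles = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                self.handles.append(module.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # L2 norm over the feature dimension, averaged over all other dimensions.
            self.activation_norms[name] = output.detach().float().norm(dim=-1).mean().item()
        return hook

    def remove(self):
        """Detach all hooks so they stop running on later forward passes."""
        for handle in self.handles:
            handle.remove()


def weight_norms(model):
    """L2 norm of every parameter tensor, keyed by parameter name."""
    return {name: p.detach().float().norm().item() for name, p in model.named_parameters()}


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
logger = NormLogger(model)
model(torch.randn(5, 8))
print(logger.activation_norms)
print(weight_norms(model))
logger.remove()
```

Logging these dictionaries every few steps gives one time series per layer. I would plot them as a layer-by-step grid so that a single layer drifting or blowing up stands out against the rest.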

# 5. Feedback

- This coding assessment is too big. 4 hours to read/understand the assessment, set up an environment, read over 1000 lines of code, come up with multiple model improvement proposals and multiple training loop improvement proposals, implement all of them, and answer the questions is a lot. I'd suggest asking the interviewee to answer each of these questions but only implement one model improvement.

- Also, this challenge seems pretty much impossible to do in time if you don't know ProteinMPNN - you would have to back out the model architecture from reading the code, which is possible but would take half of the allotted time. Providing a diagram of the architecture would set candidates on somewhat more equal footing.

- I might leave some low-hanging fruit for model improvements, rather than giving the user a model that is so good that it is still more-or-less SOTA 2 years after it was released. For example, remove the k_neighbors and allow the person doing the challenge to recognize that you could take some spatial shortcuts to avoid full self-attention.

- Provide a ready-built workspace for the user, ideally with conda/mamba installed and a conda environment that can run this code out-of-the-box.