## Take Home Assessment

**Disclaimer**: This assessment is a work in progress, so we apologise in advance for any hiccups. Any feedback is valuable!

**Setup**: You are provided with some training code for a model that takes protein 3D structure and predicts the associated amino acid sequence. This notebook provides the required steps to download the code repository and training data (a subset of the Protein Data Bank), alongside minimal code to call the training loop. Please fork the repository that you can find below and edit your own version.

**Compute**: You will be provided a [Lambda](https://cloud.lambdalabs.com/) instance with an A10 GPU on an agreed day. For this, we need your public key, and we will share an IP address to access the compute instance.

**Evaluation**: The following questions are intentionally open-ended. No specific answer is expected. The aim is to provide a semi-realistic setup that you may encounter if you were to join our team. We want to assess your ability to probe deep learning models and to come up with solutions that alleviate any limitations you identify. Please write down your answers (e.g. with plots, tables, etc.) in your copy of the repository (e.g. in this notebook or in any other format of your choice) and push them to your fork. Include documentation of everything you did to arrive at your answers. We will discuss them during the onsite interview. Please keep the time commitment under 4h.

**Questions**:
1. Log and profile the training loop.  What would you recommend if we wanted to train more quickly? Implement some of your proposals.
2. What kinds of issues will arise as model size increases? How could these be partially alleviated? Implement some of your proposed solutions.
3. The way the dataloader is organized in this project is unusual.  What will happen as we increase the size of the training dataset (e.g. using the AlphaFold database)?  How would you re-organize the code to avoid these issues?  What techniques would you consider using to ensure training scales efficiently with the dataset size?
4. Log the average norm of the weights & activations through training. How would you organize this information to help diagnose training dynamics?  How would you characterize the values you observe here?

Initial notes from reading the code:
- Hey, this code looks strangely familiar... looks like ProteinMPNN!

Questions:
1. 
2. You could hit out-of-memory errors and slower training if the model got big enough. The reason is that the memory and compute cost of updating each node with messages from all other nodes scales quadratically with sequence length. However, the k_neighbors argument alleviates this: it restricts each node's updates to its k nearest neighbors and so limits the cost to O(N·k) for N residues.
    a. Train on spatial crops. This is similar in spirit to k_neighbors but would allow full self-attention within each crop.
        - Training algorithm
            - Sample a pdb
            - Sample a residue in that pdb
            - Condition on all structure within a 50 Å radius of that residue.
            - Supervise all residues within 20 Å of that residue (to avoid supervising residues near the edge of the spatial crop).
        - At inference time, you can scan over residues and take a spatial crop around each one when predicting it. The downside is that the structure must be re-encoded at every step. The k_neighbors solution is a much more elegant way to solve this problem: it encodes the full structure once and then selects node embeddings based on neighbors.
    b. Use a Mamba-like architecture (a selective state-space model) to mix information along the sequence in linear rather than quadratic time.
    c. Perform an ablation of encoder size vs. decoder size: for a fixed FLOP budget, what is the optimal tradeoff between encoder size and decoder size?
3. The dataloader is already well-organized to minimize the I/O time spent loading from disk.
    a. PDBs are processed so that the smallest unit of training (a single chain from a single PDB ID) is in its own file.
    b. Each PDB entry also has a lightweight metadata file (`<pdbid>.pt`, next to the per-chain `<pdbid>_<chain>.pt` files). This allows fast selection of training instances without loading the full example.
    c. The dataloader also builds biological assemblies on the fly. Every `args.reload_data_every_n_epochs` epochs, the auxiliary DataLoader, `train_loader`, supplies a fresh set of randomly sampled assemblies to the primary DataLoader, `loader_train`. This could be a bottleneck, but it is mitigated by running asynchronously on the CPU during model training, so it does not block training.
    d. However, with a huge dataset like the AlphaFold database, training would outpace the auxiliary dataloader, and you would be left waiting at `pdb_dict_train = q.get().result()`. You could partially alleviate this by increasing `args.reload_data_every_n_epochs`, but that would probably still be insufficient. I don't understand exactly what the transforms here do, but I think they generate the biological assembly from the asymmetric unit. Since the model is SE(3)-invariant, a single global rotation and translation adds no diversity. However, the assembly transforms place additional copies of chains around the original ones, so they do change the context the model sees (e.g. inter-chain interfaces). The AlphaFold database also doesn't provide assembly transforms for its predicted structures, so this information would not be available there anyway. I would probably just dispense with the auxiliary dataloader and use only the primary dataloader.
4. Activation visualizations can help assess whether the model is learning the expected relationships. One interesting visualization would show how perturbing some atoms affects the learned representations of other atoms. For a trained model, you would expect atoms in close proximity to affect each other more than distant atoms. For true self-attention models, you can visualize the attention maps directly. This model does not produce full attention maps, but you can achieve a similar effect by computing the categorical Jacobian. This is a helpful sanity check: it should be mostly sparse, with zeros for all pairs outside each node's receptive field (its k nearest neighbors, a neighborhood that grows with each message-passing layer). It can also help you understand how well the model is training, since the categorical Jacobian should increasingly recapitulate the contacts as training progresses. I've done this for the encoder and visualized the categorical Jacobian against the true contact map as training progresses.
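
For answer 2a, the crop-sampling steps can be sketched in NumPy as below. This is a minimal sketch under several assumptions. It takes per-residue Cα coordinates as an `(N, 3)` array in Å, with NaN rows for missing residues. The function name `sample_spatial_crop` and its arguments are hypothetical. The "sample a PDB" step is left to the existing cluster sampler.

```python
import numpy as np


def sample_spatial_crop(ca_xyz, context_radius=50.0, supervise_radius=20.0,
                        center=None, rng=None):
    """Pick a crop centre and return (center, context_mask, loss_mask).

    ca_xyz: (N, 3) C-alpha coordinates in Angstrom. Rows containing NaN
    (missing residues) are never chosen as the centre and never fall in a crop.
    """
    rng = np.random.default_rng() if rng is None else rng
    valid = np.isfinite(ca_xyz).all(axis=1)
    if center is None:
        # Sample a residue uniformly among residues that have coordinates.
        center = int(rng.choice(np.flatnonzero(valid)))
    dist = np.linalg.norm(ca_xyz - ca_xyz[center], axis=-1)
    dist[~valid] = np.inf                      # exclude missing residues explicitly
    context_mask = dist <= context_radius      # structure the model conditions on
    loss_mask = dist <= supervise_radius       # residues that contribute to the loss
    return center, context_mask, loss_mask
```

In a training step, you would slice the structure features with `context_mask` and weight the per-residue loss with `loss_mask[context_mask]`. Residues near the crop boundary then still condition the model but are not supervised.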
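
The O(N·k) argument for k_neighbors in answer 2 can be made concrete with a k-nearest-neighbour graph. `knn_edges` below is a hypothetical NumPy helper, not the repository's implementation. This naive version still computes the full N×N distance matrix once. The saving comes afterwards: each message-passing layer touches only the N·k edges in the returned index rather than all N² pairs.

```python
import numpy as np


def knn_edges(ca_xyz, k):
    """Return an (N, k) array with each residue's k nearest residues, closest first.

    Each residue is its own nearest neighbour (distance 0), so it appears in column 0.
    """
    n = ca_xyz.shape[0]
    k = min(k, n)
    # Full pairwise distance matrix, computed once: (N, N).
    d = np.linalg.norm(ca_xyz[:, None, :] - ca_xyz[None, :, :], axis=-1)
    # Indices of the k smallest distances per row (unordered within the k).
    idx = np.argpartition(d, k - 1, axis=1)[:, :k]
    # Sort those k indices by their distance.
    order = np.take_along_axis(d, idx, axis=1).argsort(axis=1)
    return np.take_along_axis(idx, order, axis=1)
```

Gathering neighbour features with `features[edges]` then yields an `(N, k, D)` tensor, so per-layer cost grows linearly in N for fixed k.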
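
Answer 3 proposes dropping the auxiliary loader. One way to do that is a map-style dataset that samples a chain from a cluster and reads its file only inside `__getitem__`. DataLoader workers can then overlap disk I/O with GPU work, and no per-epoch `pdb_dict_train` has to be materialized.

This is a minimal sketch with several assumptions:

- It takes a `clusters` dict from cluster ID to chain IDs such as `5l3g_A`.
- The class name and the `load_fn` hook are hypothetical.
- The path layout follows the sample archive (`pdb/<pdbid[1:3]>/<pdbid>_<chain>.pt`).
- A real `load_fn` would also need to rebuild assemblies the way the existing loader does.

```python
import os
import random


class LazyClusterDataset:
    """Map-style dataset with one item per sequence cluster, loaded lazily from disk.

    Only cluster membership (strings) is kept in memory. Structures are read
    on demand, so RAM use does not grow with the total size of the files on disk.
    """

    def __init__(self, root, clusters, load_fn):
        self.root = root
        self.clusters = clusters                 # cluster_id -> list of "pdbid_chain"
        self.cluster_ids = list(clusters)
        self.load_fn = load_fn                   # e.g. torch.load (assumption)

    def __len__(self):
        return len(self.cluster_ids)

    def chain_path(self, chain_id):
        pdbid, chain = chain_id.split("_")
        # Sample archive layout: pdb/<pdbid[1:3]>/<pdbid>_<chain>.pt
        return os.path.join(self.root, "pdb", pdbid[1:3], f"{pdbid}_{chain}.pt")

    def __getitem__(self, index):
        members = self.clusters[self.cluster_ids[index]]
        # PyTorch reseeds Python's `random` in each DataLoader worker,
        # so different workers draw different chains from a cluster.
        chain_id = random.choice(members)
        return self.load_fn(self.chain_path(chain_id))
```

You could wrap this in `torch.utils.data.DataLoader(dataset, shuffle=True, num_workers=...)`. Because chains have different lengths, use `batch_size=None` or a custom `collate_fn`. The DataLoader's own prefetching would then replace the `q.get().result()` handoff. At AlphaFold-database scale, even the `clusters` dict may grow large enough to be worth keeping in an on-disk index instead.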
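
For answer 4, a perturbation map of the same kind can be sketched with finite differences on the input coordinates. Note that this is a swapped-in technique. A categorical Jacobian perturbs discrete inputs, whereas this sketch nudges each residue's coordinates and measures how every residue's embedding moves.

The sketch rests on these assumptions:

- A hypothetical `encode` callable maps `(N, 3)` coordinates to `(N, D)` embeddings.
- The map costs 3N + 1 forward passes.
- It uses an 8 Å Cα contact cutoff, a common convention rather than anything taken from the repository.

```python
import numpy as np


def sensitivity_map(encode, xyz, eps=1e-3):
    """S[i, j] = Frobenius norm of d(embedding_j) / d(coords_i), by finite differences.

    encode: hypothetical callable mapping (N, 3) coordinates to (N, D) embeddings.
    """
    n = xyz.shape[0]
    base = encode(xyz)
    sq = np.zeros((n, n))
    for i in range(n):
        for axis in range(3):
            pert = xyz.copy()
            pert[i, axis] += eps                       # displace residue i along one axis
            delta = (encode(pert) - base) / eps        # (N, D) one-sided derivative
            sq[i] += np.linalg.norm(delta, axis=-1) ** 2
    return np.sqrt(sq)


def contact_map(ca_xyz, cutoff=8.0):
    """Boolean (N, N) map of C-alpha pairs closer than `cutoff` Angstrom."""
    d = np.linalg.norm(ca_xyz[:, None, :] - ca_xyz[None, :, :], axis=-1)
    return d < cutoff
```

Plotting `S` next to `contact_map(xyz)` at several checkpoints should show the sensitivity concentrating on contacts as training proceeds, if the expectation in answer 4 holds.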
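
Question 4 also asks for the average norm of the weights and activations. A minimal PyTorch sketch registers forward hooks on leaf modules. The assumptions are:

- `ActivationNormLogger` and `weight_norms` are hypothetical helpers.
- The activation norm is the L2 norm over the last (feature) dimension, averaged over all other positions.
- Padding positions are included, since no mask is applied.

```python
import torch
import torch.nn as nn


def weight_norms(model):
    """L2 norm of every parameter tensor, keyed by parameter name."""
    return {name: p.detach().norm().item() for name, p in model.named_parameters()}


class ActivationNormLogger:
    """Records the mean L2 norm (over the feature dim) of each leaf module's output."""

    def __init__(self, model: nn.Module):
        self.norms = {}
        self.handles = []
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:        # leaf modules only
                self.handles.append(module.register_forward_hook(self._hook(name)))

    def _hook(self, name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and output.dim() > 0:
                # Norm over the last dimension, averaged over all other positions.
                self.norms[name] = output.detach().float().norm(dim=-1).mean().item()
        return hook

    def remove(self):
        for handle in self.handles:
            handle.remove()
```

Every few steps, you could record `weight_norms(model)` and `logger.norms` keyed by module name, e.g. with `torch.utils.tensorboard.SummaryWriter.add_scalar`. That makes it easy to plot each layer's trajectory over training. Grouping the curves by encoder versus decoder layer helps spot layers whose norms grow or collapse.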

Feedback: 
- This coding assessment is too big. Fitting everything into 4 hours (reading and understanding the assessment, setting up an environment, reading over 1000 lines of code, coming up with multiple model-improvement and training-loop-improvement proposals, implementing all of them, and answering the questions) is kinda insane. I'd suggest asking the interviewee to answer each of these questions but implement only one model improvement.

- Also, this challenge seems pretty much impossible if you don't know ProteinMPNN. You would have to back out the model architecture from the code, which is possible but would take half of the allotted time. The same goes for the non-standard training loop. Providing a diagram of the architecture would put candidates on a more equal footing.

- I might leave some low-hanging fruit for model improvements, rather than giving the user a model that is so good that it is still more-or-less SOTA 2 years after it was released. For example, remove the k_neighbors and allow the person doing the challenge to recognize that you could take some spatial shortcuts to avoid full self-attention.

- Provide an environment.yaml for the user to install.

In [1]:
# Download subset of training data
!wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
!tar xvf "pdb_2021aug02_sample.tar.gz"
!rm pdb_2021aug02_sample.tar.gz

--2024-10-04 00:45:44--  https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
Resolving files.ipd.uw.edu (files.ipd.uw.edu)... 2607:4000:406::160:134, 2607:4000:406::160:135, 128.95.160.135, ...
Connecting to files.ipd.uw.edu (files.ipd.uw.edu)|2607:4000:406::160:134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49690915 (47M) [application/octet-stream]
Saving to: ‘pdb_2021aug02_sample.tar.gz’


2024-10-04 00:45:55 (4.61 MB/s) - ‘pdb_2021aug02_sample.tar.gz’ saved [49690915/49690915]

./pdb_2021aug02_sample/
./pdb_2021aug02_sample/README
./pdb_2021aug02_sample/list.csv
./pdb_2021aug02_sample/pdb/
./pdb_2021aug02_sample/pdb/l3/
./pdb_2021aug02_sample/pdb/l3/5l3p.pt
./pdb_2021aug02_sample/pdb/l3/5l3g_A.pt
./pdb_2021aug02_sample/pdb/l3/5l3f.pt
./pdb_2021aug02_sample/pdb/l3/5l3r_B.pt
./pdb_2021aug02_sample/pdb/l3/4l3o_G.pt
./pdb_2021aug02_sample/pdb/l3/1l3b_E.pt
./pdb_2021aug02_sample/pdb/l3/3l3t_C.pt
./pdb_2021aug02_sample/pdb/l3/6l3y_A.pt
./p

In [1]:
from training.training import main as run_training
import random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

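class MyArgs(object):
  def __init__(self):
    self.path_for_training_data = "./pdb_2021aug02_sample"  # extracted training subset
    self.path_for_outputs = "./outputs"  # where checkpoints and logs are written
    self.previous_checkpoint = ""  # empty: train from scratch
    self.num_epochs = 2
    self.save_model_every_n_epochs = 5
    self.reload_data_every_n_epochs = 4  # how often the auxiliary loader refreshes the training examples
    self.num_examples_per_epoch = 200
    self.batch_size = 2000  # tokens (residues) per batch, not proteins
    self.max_protein_length = 2000  # skip complexes longer than this
    ######################################
    # Jacob: reduced hidden_dim from 128 to 64
    # self.hidden_dim = 128
    self.hidden_dim = 64
    ######################################
    self.num_encoder_layers = 3
    self.num_decoder_layers = 3
    self.num_neighbors = 32  # k nearest neighbors in the structure graph
    self.dropout = 0.1
    self.backbone_noise = 0.1  # std of Gaussian noise added to backbone coordinates (Å)
    self.rescut = 3.5  # PDB resolution cutoff (Å)
    self.debug = False
    self.gradient_norm = -1.0  # negative value disables gradient clipping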

args = MyArgs()
run_training(args)


  scaler = torch.cuda.amp.GradScaler()
  return torch._C._cuda_getDeviceCount() > 0
  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))

epoch: 1, step: 6, time: 39.1, train: 50.177, valid: 47.244, train_acc: 0.049, valid_acc: 0.069
epoch: 2, step: 12, time: 32.0, train: 41.552, valid: 33.108, train_acc: 0.045, valid_acc: 0.066
