In [1]:
from pathlib import Path

import pandas as pd
import numpy as np

data_path = Path("data")

### Utility Funcs

In [2]:
def open_diversity_sets(data_path, file_prefix, observations = "sequence", label = "consensus_stability_score"):
    sets = dict()
    
    for csv_f in data_path.glob(f"stability_diversity_train_{file_prefix}*"):
        parts = csv_f.stem.split("_")
        size = parts[-1]
        
        a_set = pd.read_csv(csv_f)
        
        to_drop_idx = []
        to_rename_label = None
        to_rename_obs = None
        for i, col in enumerate(a_set.columns):
            if "Unnamed" in col:
                to_drop_idx.append(i)
                
            elif label == col:
                to_rename_label = i
                
            elif observations == col:
                to_rename_obs = i
        
        a_set = a_set.drop(a_set.columns[to_drop_idx], axis=1)
        a_set = a_set.rename(columns={label: "label"})
        
        new_cols = ["label", observations, "diversity"]
        
        sets["train"] = a_set[new_cols]
        
    return sets

## What's the minimum amount of data to achieve r2 > 0.7

In [3]:
diversity_sets = open_diversity_sets(data_path, "1000")
diversity_sets["train"]

Unnamed: 0,label,sequence,diversity
0,-0.03,TELKKKLEEALKKGEEVRVKFNGIEIRNTSEDAARKAVELLEK,0.879509
1,1.15,GSSGSLSDEDFKAVFGMTRSAFAMLPLWKQQNLKKEKGLFGSS,0.879126
2,0.74,TELKKKLEEALKKGEEVRVKFNGIEIRITSEDTARKAVELLEK,0.879500
3,0.73,GMADEEKLPPGWEKRMSRSSGRVYYTNHITNASQWERPSGGSS,0.879761
4,1.35,GMADEEKLPPGWEKRMSYSSGRVYYFNHITNASQWERPSGGSS,0.879780
...,...,...,...
996,0.84,GSSGSLSDNDFKAVFGMTRSAFANLPLWKQQNLKKEKGLFGSS,0.881958
997,0.80,TELKKKLEEALKKGEEVRVKFNGIEIRIESEDAARKAVELLEK,0.879496
998,0.86,GSSGSLSDESFKAVFGMTRSAFANLPLWKQQNLKKEKGLFGSS,0.880492
999,0.95,TELKKKLEEALKKGEEVRVKFNGIEIRITSEDAWRKAVELLEK,0.879480


In [4]:
pd.read_csv(data_path / "stability_diversity_full.csv", index_col=0)

Unnamed: 0,sequence,consensus_stability_score,diversity
0,GSSQETIEVEDEEEARRVAKELRKKGYEVKDERRGNKWHVHRT,0.37,0.839413
1,TLDEARELVERAKKEGTGMDVNGQRFEDWREAERWVREQEKNK,0.62,0.846967
2,TELKKKLEEALKKGEEVRVKFNGIEIRNTSEDAARKAVELLEK,-0.03,0.879509
3,GSSQETIEVEDEEEARRVAKELRKTGYEVKIERRGNKWHVHRT,1.41,0.836507
4,TTIHVGDLTLKYDNPKKAYEIAKKLAKKYNLQVTIKNGKITVT,1.11,0.829514
...,...,...,...
10276,GSSKTQYEYDTKEEAQKAYEKFKKQGIPVTITQKNGKWFVQVE,1.59,0.833643
10277,TELKKALEEALKKGEEVRVKFNGIEIRITSEDAARKAVELLEK,0.78,0.875211
10278,SKDEAQREAERAIRSGNKEEARRILEEAGYSPEQAERIARKLG,1.26,0.816401
10279,GSSKTQYEYDTKEEAQPAYEKFKKQGIPVTITQKNGKWFVQVE,1.48,0.839330


### Load Protein Embeddings Dataset and Model

In [5]:
import pytorch_lightning as pl
import torchmetrics
from torch.nn import functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau


class LitProteins(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.r2 = torchmetrics.R2Score()
        
    def forward(self, x):
        pred_stability = self.model(x)
        return pred_stability

    def do_step(self, batch, stage):
        X, y = batch
        
        y_hat = self.model(X)
        loss = F.mse_loss(y_hat, y)
        
        self.log(f'{stage}_r2_step', self.r2(y_hat, y))
        return y_hat, loss
    
    def training_step(self, batch, batch_idx):
        y_hat, loss = self.do_step(batch, "train")

        self.log("train_loss_step", loss, prog_bar=False)
        return loss
    
    def validation_step(self, batch, batch_idx):
        y_hat, loss = self.do_step(batch, "valid")
        
        self.log("valid_loss_step", loss, prog_bar=False)
        return loss
    
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([out["loss"] for out in outputs]).mean()
        self.log('train_r2_epoch', self.r2.compute(), prog_bar=True)
        self.log("train_loss_epoch", avg_loss, on_epoch=True, prog_bar=True)
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([out for out in outputs]).mean()
        self.log('valid_r2_epoch', self.r2.compute(), prog_bar=True)
        self.log("valid_loss_epoch", avg_loss, on_epoch=True, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = ReduceLROnPlateau(optimizer, patience=3)
        return {
            "optimizer": optimizer,
#             "lr_scheduler": {
#                 "scheduler": scheduler,
#                 "monitor": "val_loss",
#             },
        }

In [10]:
from proteins import ProteinStabilityDataset, SubsetDiversitySampler
from torch.utils.data import SubsetRandomSampler

from models import ProteinMLP
import torch
import random

model = ProteinMLP()
mlp = LitProteins(model)
trainer = pl.Trainer(gpus = 0,  max_epochs=1, check_val_every_n_epoch=1, log_every_n_steps=10)

dataset = ProteinStabilityDataset(data_path / "stability.h5", ret_dict = False)
sampler = SubsetRandomSampler(random.sample(dataset.indices, 4000))
train_loader = torch.utils.data.DataLoader(dataset, batch_size = 128, sampler=sampler, pin_memory=True)

trainer.fit(mlp, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: -1it [00:00, ?it/s]

## Experiments

In [11]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import loggers as pl_loggers

In [12]:
def setup(data_path, 
          epochs=int(1e4), 
          diversity_sampling = False, 
          diversity_cutoff = 0.82,
          max_percent = 0.7,
          random_percent = 0.2, 
          ckpt_dir = ''):
    
    model = ProteinMLP()
    net = LitProteins(model)
    
    tb_logger = pl_loggers.TensorBoardLogger("logs/")
    trainer = pl.Trainer(
        logger=pl_loggers.TensorBoardLogger(f"logs/{ckpt_dir}"),
        gpus=0, 
        max_epochs=epochs, 
        log_every_n_steps=3,
        callbacks = [
            ModelCheckpoint(f"logs/{ckpt_dir}"), 
#             EarlyStopping(monitor="valid_r2_epoch", patience=5)
        ]
    )

    train_set = ProteinStabilityDataset(data_path / "stability_train.h5", ret_dict = False)
    val_set = ProteinStabilityDataset(data_path / "stability_test.h5", ret_dict = False)
    
    if diversity_sampling:
        sampler = SubsetDiversitySampler(
            valid_indices=train_set.indices,
            diversity_path=data_path / "stability_diversity_full.csv",
            diversity_cutoff=diversity_cutoff,
            max_size = int(len(train_set) * max_percent)
        )
    else:
        sampler = SubsetRandomSampler(random.sample(train_set.indices, int(len(train_set) * random_percent)))
        
    train_loader = torch.utils.data.DataLoader(train_set, batch_size = 64, sampler=sampler, num_workers=8, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(val_set, batch_size = 64, num_workers=8, pin_memory=True)
    
    return {
        "model": net, 
        "trainer": trainer,
        "train_loader": train_loader, 
        "valid_loader": valid_loader
    }

In [13]:
exp = setup(data_path, epochs=10, diversity_sampling=True)
exp["trainer"].fit(exp["model"], exp["train_loader"], exp["valid_loader"])

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


=== USING 5397 out of 7710 samples ===


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [18]:
from tqdm.autonotebook import tqdm
import pickle

def run_experiment(data_path, epochs = int(1e4), cutoffs = [0.75, 0.83, 0.84, 0.86]):
    for cut in tqdm(cutoffs):
        exp = setup(data_path, epochs=30, diversity_sampling=True, diversity_cutoff=cut, ckpt_dir=f'stability_cut_{cut}')
        exp["trainer"].fit(exp["model"], exp["train_loader"], exp["valid_loader"])

In [19]:
run_experiment(data_path=data_path)

  0%|          | 0/4 [00:00<?, ?it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


=== USING 5397 out of 7710 samples ===


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


=== USING 4091 out of 7710 samples ===


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


=== USING 3152 out of 7710 samples ===


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | ProteinMLP | 1.8 M 
1 | r2    | R2Score    | 0     
-------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


=== USING 1657 out of 7710 samples ===


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]