In [1]:
from pathlib import Path

import pandas as pd
import numpy as np

data_path = Path("data")

### Utility Funcs

In [2]:
def open_diversity_sets(data_path, file_prefix, observations = "sequence", label = "consensus_stability_score"):
    sets = dict()
    
    for csv_f in data_path.glob(f"stability_diversity_train_{file_prefix}*"):
        parts = csv_f.stem.split("_")
        size = parts[-1]
        
        a_set = pd.read_csv(csv_f)
        
        to_drop_idx = []
        to_rename_label = None
        to_rename_obs = None
        for i, col in enumerate(a_set.columns):
            if "Unnamed" in col:
                to_drop_idx.append(i)
                
            elif label == col:
                to_rename_label = i
                
            elif observations == col:
                to_rename_obs = i
        
        a_set = a_set.drop(a_set.columns[to_drop_idx], axis=1)
        a_set = a_set.rename(columns={label: "label"})
        
        new_cols = ["label", observations, "diversity"]
        
        sets["train"] = a_set[new_cols]
        
    return sets

## What's the minimum amount of data to achieve r2 > 0.7

In [3]:
diversity_sets = open_diversity_sets(data_path, "1000")
diversity_sets["train"]

Unnamed: 0,label,sequence,diversity
0,-0.03,TELKKKLEEALKKGEEVRVKFNGIEIRNTSEDAARKAVELLEK,0.879509
1,1.15,GSSGSLSDEDFKAVFGMTRSAFAMLPLWKQQNLKKEKGLFGSS,0.879126
2,0.74,TELKKKLEEALKKGEEVRVKFNGIEIRITSEDTARKAVELLEK,0.879500
3,0.73,GMADEEKLPPGWEKRMSRSSGRVYYTNHITNASQWERPSGGSS,0.879761
4,1.35,GMADEEKLPPGWEKRMSYSSGRVYYFNHITNASQWERPSGGSS,0.879780
...,...,...,...
996,0.84,GSSGSLSDNDFKAVFGMTRSAFANLPLWKQQNLKKEKGLFGSS,0.881958
997,0.80,TELKKKLEEALKKGEEVRVKFNGIEIRIESEDAARKAVELLEK,0.879496
998,0.86,GSSGSLSDESFKAVFGMTRSAFANLPLWKQQNLKKEKGLFGSS,0.880492
999,0.95,TELKKKLEEALKKGEEVRVKFNGIEIRITSEDAWRKAVELLEK,0.879480


### Load Protein Embeddings

In [4]:
from utils import load_dataset
X, y, dset = load_dataset(data_path, to_torch=True)

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


In [5]:
is_in_diversity =  lambda seq_bytes: seq_bytes.decode("utf-8") in diversity_sets["train"]["sequence"].values

idxes = [idx for idx, seq in enumerate(dset["sequences"]) if is_in_diversity(seq)]
dset.close()
X, y = X[idxes], y[idxes]

### Load model

In [6]:
from skorch.callbacks import LoadInitState, Checkpoint, LRScheduler, EarlyStopping

from skorch import NeuralNetRegressor

from models import ProteinMLP

cp = Checkpoint(dirname='models')
load_state = LoadInitState(cp)

net = NeuralNetRegressor(ProteinMLP)

net.initialize()
net.load_params(checkpoint=cp)

y_pred = net.predict(X)
y[:10]

tensor([[-2.7473],
        [ 0.2262],
        [-0.8070],
        [-0.8322],
        [ 0.7302],
        [-2.8481],
        [-2.0417],
        [-0.8322],
        [-1.0086],
        [-0.8574]])

## Experiments

In [26]:
def setup(data_path, s, epochs=int(1e4), ckpt_dir = "", load_model = False):
    diversity_sets = open_diversity_sets(data_path, s)

    is_in_diversity =  lambda seq_bytes: seq_bytes.decode("utf-8") in diversity_sets["train"]["sequence"].values
    
    X, y, dset = load_dataset(data_path, to_torch=True)
    idxes = [idx for idx, seq in enumerate(dset["sequences"]) if is_in_diversity(seq)]
    X, y = X[idxes], y[idxes]
    dset.close()

    cb = Checkpoint(dirname=f"models/{ckpt_dir}")
    sched = LRScheduler(step_every="batch")
    stopper = EarlyStopping(patience=15)
    
    net = NeuralNetRegressor(
        ProteinMLP,
        max_epochs=epochs,
        lr=3e-3,
        iterator_train__shuffle=True,
        device='cuda',
        callbacks=[cb, sched, stopper],
    )
    
    if load_model:
        load_state = LoadInitState(cp)
        
        net.initialize()
        net.load_params(checkpoint=cp)
    
    return {"model": net, "X": X, "y": y}

In [20]:
exp = setup(data_path, "1000", epochs=1000, ckpt_dir=f"full")
        
X_train, y_train, dset = load_dataset(data_path, kind = 'train', reduce = False, to_torch = True)
dset.close()

exp["model"].fit(X_train, y_train)

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master
Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.7158[0m        [32m0.5840[0m     +  0.0763
      2        [36m0.6507[0m        0.6294        0.0754
      3        [36m0.6342[0m        [32m0.5737[0m     +  0.0741
      4        [36m0.6217[0m        0.5768        0.0743
      5        [36m0.6125[0m        [32m0.5583[0m     +  0.0749
      6        [36m0.6083[0m        [32m0.5519[0m     +  0.0746
      7        [36m0.6063[0m        [32m0.5433[0m     +  0.0740
      8        [36m0.5974[0m        0.5988        0.0743
      9        [36m0.5902[0m        [32m0.5290[0m     +  0.0742
     10        [36m0.5668[0m        [32m0.5119[0m     +  0.0740
     11        [36m0.5584[0m        [32m0.4996[0m     +  0.0741
     12        [36m0.5501[0m        0.5010        0.0745
     13        [36m0.5500[0m        [32m0.4981[0m     +  0.0744
     14        [36m0.5416[0m        [32m0.4691[0

    139        [36m0.1973[0m        [32m0.1848[0m     +  0.0746
    140        [36m0.1924[0m        0.1885        0.0750
    141        0.1936        [32m0.1846[0m     +  0.0747
    142        [36m0.1922[0m        0.1853        0.0769
    143        [36m0.1917[0m        0.1924        0.0741
    144        [36m0.1903[0m        [32m0.1835[0m     +  0.0746
    145        [36m0.1892[0m        [32m0.1821[0m     +  0.0758
    146        0.1921        0.1879        0.0768
    147        0.1921        [32m0.1783[0m     +  0.0755
    148        0.1911        0.1810        0.0762
    149        [36m0.1885[0m        0.1907        0.0756
    150        [36m0.1881[0m        0.1784        0.0747
    151        [36m0.1833[0m        0.1848        0.0745
    152        0.1871        0.1785        0.0745
    153        0.1843        [32m0.1778[0m     +  0.0741
    154        [36m0.1824[0m        0.1836        0.0761
    155        0.1860        0.1879        0.0747
    15

<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
  module_=ProteinMLP(
    (fc1): Linear(in_features=1280, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=1, bias=True)
    (drop): Dropout(p=0.7, inplace=False)
    (act): LeakyReLU(negative_slope=0.01)
  ),
)

In [21]:
X_test, y_test, dset = load_dataset(data_path, kind = 'test', reduce = False, to_torch = True)
dset.close()

exp["model"].score(X_test, y_test)

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


0.8260408952518665

In [27]:
from tqdm.autonotebook import tqdm
import pickle

def run_experiment(data_path, epochs = int(1e4), sizes = ["450", "700", "1000", "1500", "2000", "3000", "4000", "5000", "6000"]):
    for s in tqdm(sizes):
        exp = setup(data_path, s, epochs=epochs, ckpt_dir=f"diversity_{s}")
                
        exp["model"].fit(exp["X"], exp["y"])
        X_test, y_test, dset = load_dataset(data_path, kind = 'test', reduce = False, to_torch = True)
        dset.close()
        
        score = net.score(X_test, y_test)
        
        print(f"TEST R2: {round(score, 3)}")
        
        pickle.dump(exp["model"].history, open(Path("logs") / f"ProteinMLP_diversity_{s}_history.pkl", "wb"))
        pickle.dump(score, open(Path("logs") / f"ProteinMLP_diversity_{s}_test.pkl", "wb"))

In [None]:
run_experiment(data_path=data_path)

  0%|          | 0/9 [00:00<?, ?it/s]

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.7029[0m        [32m0.6411[0m     +  0.0060
      2        [36m0.6405[0m        [32m0.6407[0m     +  0.0055
      3        0.6406        [32m0.6406[0m     +  0.0054
      4        [36m0.6402[0m        [32m0.6405[0m     +  0.0058
      5        [36m0.6397[0m        [32m0.6399[0m     +  0.0049
      6        [36m0.6392[0m        [32m0.6393[0m     +  0.0049
      7        [36m0.6386[0m        [32m0.6389[0m     +  0.0048
      8        [36m0.6379[0m        [32m0.6386[0m     +  0.0049
      9        0.6388        [32m0.6385[0m     +  0.0051
     10        [36m0.6377[0m        [32m0.6384[0m     +  0.0049
     11        [36m0.6374[0m        [32m0.6383[0m     +  0.0049
     12        0.6377        [32m0.6376[0m     +  0.0049
     13        [36m0.6365[0m        [32m0.6368[0m     +  0.0049
     14        [36m0.6357[0m        [32m

    126        [36m0.4297[0m        [32m0.4026[0m     +  0.0058
    127        0.4411        [32m0.3987[0m     +  0.0048
    128        0.4315        [32m0.3945[0m     +  0.0050
    129        [36m0.4277[0m        [32m0.3911[0m     +  0.0049
    130        [36m0.4227[0m        [32m0.3876[0m     +  0.0055
    131        0.4263        [32m0.3841[0m     +  0.0049
    132        [36m0.4202[0m        [32m0.3809[0m     +  0.0049
    133        [36m0.4195[0m        [32m0.3774[0m     +  0.0048
    134        [36m0.4120[0m        [32m0.3741[0m     +  0.0053
    135        0.4153        [32m0.3713[0m     +  0.0048
    136        0.4147        [32m0.3683[0m     +  0.0050
    137        0.4149        [32m0.3656[0m     +  0.0049
    138        [36m0.4071[0m        [32m0.3630[0m     +  0.0050
    139        [36m0.3989[0m        [32m0.3604[0m     +  0.0052
    140        0.4109        [32m0.3583[0m     +  0.0049
    141        0.4006        [32m0.3561[0

    263        0.3728        [32m0.3047[0m     +  0.0058
    264        0.3746        [32m0.3044[0m     +  0.0051
    265        0.3659        [32m0.3043[0m     +  0.0051
    266        0.3626        [32m0.3042[0m     +  0.0059
    267        0.3827        0.3044        0.0050
    268        0.3765        [32m0.3040[0m     +  0.0048
    269        0.3624        [32m0.3036[0m     +  0.0057
    270        0.3761        [32m0.3034[0m     +  0.0051
    271        0.3670        [32m0.3032[0m     +  0.0053
    272        0.3723        [32m0.3030[0m     +  0.0063
    273        0.3646        [32m0.3027[0m     +  0.0058
    274        0.3726        [32m0.3024[0m     +  0.0052
    275        0.3647        [32m0.3020[0m     +  0.0055
    276        0.3718        [32m0.3018[0m     +  0.0058
    277        0.3666        [32m0.3014[0m     +  0.0050
    278        0.3652        [32m0.3014[0m     +  0.0049
    279        0.3712        [32m0.3014[0m     +  0.0049
    28

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


TEST R2: 0.001


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.8313[0m        [32m0.7918[0m     +  0.0080
      2        [36m0.7910[0m        [32m0.7913[0m     +  0.0081
      3        [36m0.7904[0m        [32m0.7897[0m     +  0.0088
      4        [36m0.7882[0m        [32m0.7879[0m     +  0.0089
      5        [36m0.7867[0m        [32m0.7870[0m     +  0.0078
      6        0.7869        [32m0.7867[0m     +  0.0077
      7        [36m0.7856[0m        [32m0.7851[0m     +  0.0080
      8        [36m0.7834[0m        [32m0.7826[0m     +  0.0088
      9        [36m0.7811[0m        [32m0.7802[0m     +  0.0077
     10        [36m0.7793[0m        [32m0.7780[0m     +  0.0078
     11        [36m0.7765[0m        [32m0.7762[0m     +  0.0079
     12        [36m0.7749[0m        [32m0.7751[0m     +  0.0086
     13        0.7750        [32m0.7746[0m     +  0.0079
     14        [36m0.7745[0m    

    131        0.3926        [32m0.3438[0m     +  0.0091
    132        0.3891        [32m0.3428[0m     +  0.0088
    133        0.3976        [32m0.3427[0m     +  0.0098
    134        0.3924        [32m0.3423[0m     +  0.0093
    135        0.3868        [32m0.3421[0m     +  0.0088
    136        0.3900        [32m0.3407[0m     +  0.0081
    137        0.3906        [32m0.3404[0m     +  0.0094
    138        0.3946        [32m0.3400[0m     +  0.0084
    139        0.3825        [32m0.3392[0m     +  0.0079
    140        0.3916        [32m0.3390[0m     +  0.0084
    141        0.3958        [32m0.3388[0m     +  0.0080
    142        0.3957        0.3391        0.0080
    143        0.3954        0.3388        0.0078
    144        0.3965        [32m0.3382[0m     +  0.0078
    145        0.3888        [32m0.3370[0m     +  0.0082
    146        0.3939        [32m0.3370[0m     +  0.0082
    147        0.3885        [32m0.3358[0m     +  0.0081
    148        

    272        [36m0.3497[0m        [32m0.3051[0m     +  0.0080
    273        0.3548        [32m0.3050[0m     +  0.0084
    274        0.3675        0.3057        0.0079
    275        0.3552        [32m0.3045[0m     +  0.0078
    276        0.3621        [32m0.3032[0m     +  0.0076
    277        0.3611        [32m0.3029[0m     +  0.0079
    278        0.3534        [32m0.3025[0m     +  0.0080
    279        0.3648        [32m0.3013[0m     +  0.0079
    280        0.3550        0.3017        0.0078
    281        [36m0.3472[0m        [32m0.3003[0m     +  0.0087
    282        [36m0.3429[0m        0.3005        0.0081
    283        0.3484        [32m0.2996[0m     +  0.0080
    284        0.3636        [32m0.2993[0m     +  0.0082
    285        0.3505        0.2995        0.0091
    286        0.3516        [32m0.2992[0m     +  0.0080
    287        0.3472        [32m0.2973[0m     +  0.0080
    288        0.3491        [32m0.2963[0m     +  0.0090
    28

    409        0.2697        [32m0.2271[0m     +  0.0080
    410        0.2794        [32m0.2267[0m     +  0.0080
    411        0.2730        [32m0.2264[0m     +  0.0088
    412        0.2764        [32m0.2263[0m     +  0.0080
    413        0.2744        [32m0.2257[0m     +  0.0087
    414        0.2805        [32m0.2257[0m     +  0.0079
    415        0.2722        0.2259        0.0080
    416        0.2769        0.2262        0.0078
    417        0.2770        [32m0.2257[0m     +  0.0077
    418        0.2718        [32m0.2254[0m     +  0.0079
    419        0.2712        [32m0.2252[0m     +  0.0084
    420        0.2727        [32m0.2252[0m     +  0.0083
    421        [36m0.2634[0m        [32m0.2250[0m     +  0.0080
    422        0.2646        [32m0.2243[0m     +  0.0081
    423        0.2666        0.2243        0.0088
    424        [36m0.2552[0m        [32m0.2240[0m     +  0.0078
    425        0.2768        [32m0.2239[0m     +  0.0081
    42

    554        [36m0.2353[0m        [32m0.1934[0m     +  0.0082
    555        [36m0.2323[0m        0.1952        0.0081
    556        0.2396        0.1972        0.0078
    557        0.2357        0.2055        0.0078
    558        [36m0.2322[0m        0.2121        0.0078
    559        0.2422        0.2326        0.0078
    560        0.2371        0.2079        0.0078
    561        0.2453        0.2078        0.0086
    562        0.2399        0.2006        0.0077
    563        0.2415        0.2018        0.0077
    564        0.2449        0.1964        0.0077
    565        0.2441        0.1969        0.0077
    566        0.2422        0.1951        0.0078
    567        0.2510        0.1967        0.0087
    568        0.2383        0.1962        0.0078
Stopping since valid_loss has not improved in the last 15 epochs.


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


TEST R2: 0.001


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.8466[0m        [32m0.8512[0m     +  0.0110
      2        [36m0.8047[0m        [32m0.8500[0m     +  0.0124
      3        [36m0.8033[0m        [32m0.8479[0m     +  0.0108
      4        [36m0.8019[0m        [32m0.8471[0m     +  0.0105
      5        [36m0.8012[0m        [32m0.8458[0m     +  0.0118
      6        [36m0.7992[0m        [32m0.8424[0m     +  0.0115
      7        [36m0.7962[0m        [32m0.8395[0m     +  0.0107
      8        [36m0.7931[0m        [32m0.8376[0m     +  0.0105
      9        [36m0.7920[0m        [32m0.8365[0m     +  0.0107
     10        [36m0.7907[0m        [32m0.8362[0m     +  0.0109
     11        0.7909        [32m0.8333[0m     +  0.0107
     12        [36m0.7876[0m        [32m0.8284[0m     +  0.0106
     13        [36m0.7831[0m        [32m0.8232[0m     +  0.0108
     14        [36m0.778

    135        0.4362        0.3973        0.0109
    136        [36m0.4315[0m        0.3955        0.0106
    137        0.4442        0.4015        0.0116
    138        0.4407        0.4000        0.0106
    139        0.4384        0.3973        0.0106
    140        0.4442        [32m0.3928[0m     +  0.0109
    141        0.4365        [32m0.3928[0m     +  0.0119
    142        0.4407        [32m0.3926[0m     +  0.0132
    143        0.4382        [32m0.3917[0m     +  0.0148
    144        0.4386        0.3939        0.0109
    145        0.4391        0.3993        0.0146
    146        0.4364        0.4099        0.0108
    147        0.4388        0.4126        0.0106
    148        0.4338        0.4167        0.0113
    149        0.4357        0.4272        0.0106
    150        0.4368        0.4319        0.0106
    151        0.4369        0.4218        0.0106
    152        0.4419        [32m0.3904[0m     +  0.0106
    153        0.4361        [32m0.3900[0m  

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


TEST R2: 0.001


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.8831[0m        [32m0.9133[0m     +  0.0159
      2        [36m0.8646[0m        [32m0.9112[0m     +  0.0168
      3        [36m0.8624[0m        [32m0.9105[0m     +  0.0153
      4        [36m0.8620[0m        [32m0.9083[0m     +  0.0152
      5        [36m0.8591[0m        [32m0.9058[0m     +  0.0167
      6        [36m0.8568[0m        [32m0.9043[0m     +  0.0154
      7        [36m0.8558[0m        [32m0.9039[0m     +  0.0151
      8        [36m0.8547[0m        [32m0.9012[0m     +  0.0155
      9        [36m0.8509[0m        [32m0.8970[0m     +  0.0153
     10        [36m0.8469[0m        [32m0.8926[0m     +  0.0161
     11        [36m0.8420[0m        [32m0.8885[0m     +  0.0154
     12        [36m0.8395[0m        [32m0.8853[0m     +  0.0152
     13        [36m0.8358[0m        [32m0.8833[0m     +  0.0152
     14        

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


TEST R2: 0.001


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.9897[0m        [32m0.8176[0m     +  0.0206
      2        [36m0.8588[0m        [32m0.7955[0m     +  0.0214
      3        [36m0.8436[0m        0.8052        0.0209
      4        0.8456        0.7961        0.0207
      5        0.8440        [32m0.7842[0m     +  0.0204
      6        [36m0.8387[0m        0.7853        0.0215
      7        [36m0.8322[0m        [32m0.7763[0m     +  0.0200
      8        [36m0.8247[0m        [32m0.7735[0m     +  0.0204
      9        0.8249        0.7756        0.0205
     10        [36m0.8246[0m        [32m0.7696[0m     +  0.0203
     11        [36m0.8193[0m        0.7699        0.0206
     12        [36m0.8172[0m        0.7912        0.0210
     13        0.8223        [32m0.7630[0m     +  0.0202
     14        0.8188        [32m0.7618[0m     +  0.0205
     15        [36m0.8153[0m        [32m0.75

    138        [36m0.3394[0m        [32m0.3155[0m     +  0.0197
    139        [36m0.3355[0m        [32m0.3023[0m     +  0.0211
    140        0.3459        0.3069        0.0206
    141        0.3467        [32m0.2987[0m     +  0.0208
    142        0.3453        0.3155        0.0217
    143        0.3469        0.3112        0.0216
    144        0.3450        0.3117        0.0211
    145        0.3380        0.3100        0.0208
    146        0.3357        0.3094        0.0214
    147        0.3508        0.3116        0.0202
    148        0.3376        0.3144        0.0209
    149        0.3438        0.3096        0.0199
    150        0.3405        0.3086        0.0209
    151        [36m0.3321[0m        0.3075        0.0212
    152        [36m0.3275[0m        [32m0.2959[0m     +  0.0207
    153        [36m0.3231[0m        [32m0.2892[0m     +  0.0203
    154        [36m0.3186[0m        0.2973        0.0202
    155        0.3262        [32m0.2880[0m     + 

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


TEST R2: 0.001


Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.8658[0m        [32m0.7321[0m     +  0.0301
      2        [36m0.7699[0m        [32m0.6942[0m     +  0.0312
      3        0.7811        [32m0.6884[0m     +  0.0290
      4        [36m0.7687[0m        [32m0.6878[0m     +  0.0290
      5        0.7712        [32m0.6836[0m     +  0.0291
      6        [36m0.7649[0m        [32m0.6809[0m     +  0.0300
      7        [36m0.7594[0m        [32m0.6792[0m     +  0.0290
      8        [36m0.7575[0m        [32m0.6783[0m     +  0.0297
      9        [36m0.7563[0m        0.6786        0.0292
     10        [36m0.7517[0m        [32m0.6694[0m     +  0.0298
     11        [36m0.7450[0m        [32m0.6634[0m     +  0.0293
     12        [36m0.7409[0m        0.6638        0.0304
     13        [36m0.7326[0m        [32m0.6564[0m     +  0.0288
     14        [36m0.7301[0m        [32m0.6531[0

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


In [25]:
pickle.load(open("logs/ProteinMLP_diversity_6000_test.pkl", "rb"))

0.0010775796472866084