In [1]:
from pathlib import Path
data_path = Path("../data") / "Protera"
per_protein_csv = [pth for pth in (data_path / "prism").iterdir()]

print(per_protein_csv[:10])

[PosixPath('../data/Protera/prism/prism_merged_037_UBI4_E1_binding_limiting_E1.csv'), PosixPath('../data/Protera/prism/prism_merged_999_IF-1_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_021_PAB1_doxycyclin_sensitivity.csv'), PosixPath('../data/Protera/prism/prism_merged_006_CBS_high_B6_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_003_PTEN_abundance.csv'), PosixPath('../data/Protera/prism/prism_merged_999_GmR_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_027_Src_kinase_activity_catalytic_domain_reversed.csv'), PosixPath('../data/Protera/prism/prism_merged_999_ccdB_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_026_BRCA1_E3_ubiquitination_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_999_HAh1n1_DMS.csv')]


In [2]:
import pandas as pd

max_ = 0
for path in per_protein_csv:
    df = pd.read_csv(path)
    df = df[df["variant"].str.len() < 1000]
    size = len(df)
    if size > max_:
        max_ = size
        path_ = path

path_, max_

(PosixPath('../data/Protera/prism/prism_merged_030_HMGCR_yeast_complementation_control_medium.csv'),
 16872)

In [3]:
from protera_stability.data import ProteinStabilityDataset, EmbeddingGetter
from protera_stability.proteins import EmbeddingExtractor1D

# from torch.multiprocessing import set_start_method
# try:
#      set_start_method('spawn')
# except RuntimeError:
#     pass


df = pd.read_csv(path_)
df = df.drop_duplicates().dropna()
df = df[["variant", "Rosetta_ddg_score_02"]]
df.columns = ["sequences", "labels"]
df = df[df["sequences"].str.len() < 1000]
df = df[df.columns[::-1]]

args_dict = {
    "model_name": "esm1b_t33_650M_UR50S",
    "base_path": data_path,
    "gpu": True,
}
emb_extractor = EmbeddingExtractor1D(**args_dict)
dset = emb_extractor.generate_datasets(
    [""],
    data=df,
    h5_stem=f"stability_{path_.stem}",  # data_path / "stability_train".h5
    bs=1,
    target_name="stability_scores"
)

dataset = ProteinStabilityDataset(data_path / f"stability_{path_.stem}.h5",  ret_dict=False)
len(dataset)

Using cache found in /home/roberto/.cache/torch/hub/facebookresearch_esm_master
100%|██████████| 7771/7771 [31:25<00:00,  4.12it/s]


7771

In [4]:
import torch
from protera_stability.config.lazy import LazyCall as L
from protera_stability.config.common.mlp import mlp_esm
from protera_stability.train import get_cfg, setup_diversity, setup_data

exp_params = {
    "diversity_cutoff": 0.,
    "random_percent": 1.,
    "sampling_method": "",
    "experiment_name": "per_prot",
}

def create_cfg(exp_params):
    cfg = get_cfg(args={})
    cfg = setup_diversity(cfg, **exp_params)
    mlp_esm.n_units = 1024
    mlp_esm.n_layers = 3
    mlp_esm.act = L(torch.nn.RReLU)()
    cfg.model = mlp_esm

    cfg = setup_data(cfg, dataset=dataset)
    return cfg

In [5]:
cfg = create_cfg(exp_params)
cfg.trainer_params.gpus = 1
cfg.keys()

dict_keys(['trainer_params', 'output_dir', 'random_split', 'experiment', 'model', 'dataloader'])

In [6]:
cfg.dataloader

{'train': {'dataset': {'proteins_path': '../data/stability_train.h5', 'ret_dict': False, '_target_': <class 'protera_stability.data.dataset.ProteinStabilityDataset'>}, 'batch_size': 256, 'num_workers': 12, 'pin_memory': True, '_target_': <class 'torch.utils.data.dataloader.DataLoader'>, 'sampler': {'indices': {'dataset': {'proteins_path': '../data/stability_train.h5', 'ret_dict': False, '_target_': <class 'protera_stability.data.dataset.ProteinStabilityDataset'>}, 'set_indices': [2984, 2404, 2027, 7246, 2042, 2964, 3768, 3080, 6586, 6778, 7326, 1630, 1342, 2607, 440, 7310, 2563, 6514, 1379, 4757, 5844, 2, 2104, 3712, 1611, 7172, 5542, 6607, 3350, 3030, 7195, 555, 7335, 4585, 4375, 5091, 814, 797, 5142, 864, 2758, 2761, 3104, 5601, 3514, 1079, 6898, 5587, 1993, 5520, 42, 5305, 918, 3304, 4415, 1243, 4403, 2606, 3622, 6860, 2089, 5888, 2497, 4949, 6351, 3467, 6219, 5344, 2565, 5970, 6226, 1465, 3527, 6902, 2200, 2734, 1969, 2735, 7483, 1907, 5018, 775, 3405, 2000, 5899, 7036, 1273, 6184,

In [7]:
from protera_stability.config.common import SGD
cfg.optim = SGD
cfg.optim.lr = 2e-2
cfg.optim.weight_decay=0.01

In [8]:
from protera_stability.engine.default import DefaultTrainer

cfg.dataloader.train.dataset = dataset
# cfg.dataloader.test.dataset = dataset

trainer = DefaultTrainer(cfg)
trainer.fit()

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | model    | ProteinMLP | 1.8 M 
1 | train_r2 | R2Score    | 0     
2 | valid_r2 | R2Score    | 0     
3 | test_r2  | R2Score    | 0     
----------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.348     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/roberto/anaconda3/envs/protera-stability/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/roberto/anaconda3/envs/protera-stability/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/roberto/anaconda3/envs/protera-stability/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/roberto/anaconda3/envs/protera-stability/lib/python3.8/site-packages/protera_stability/data/dataset.py", line 56, in __getitem__
    sequences = self.sequences[idx]
IndexError: index 7419 is out of bounds for axis 0 with size 7201


In [None]:
idx = 5

for cb in trainer.trainer.callbacks:
    try:
        print(f"{cb.monitor}, {cb.mode}, {cb.patience}")
    except AttributeError:
        continue

In [None]:
trainer.trainer.callbacks