[PosixPath('../data/Protera/prism/prism_merged_037_UBI4_E1_binding_limiting_E1.csv'), PosixPath('../data/Protera/prism/prism_merged_999_IF-1_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_021_PAB1_doxycyclin_sensitivity.csv'), PosixPath('../data/Protera/prism/prism_merged_006_CBS_high_B6_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_003_PTEN_abundance.csv'), PosixPath('../data/Protera/prism/prism_merged_999_GmR_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_027_Src_kinase_activity_catalytic_domain_reversed.csv'), PosixPath('../data/Protera/prism/prism_merged_999_ccdB_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_026_BRCA1_E3_ubiquitination_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_999_HAh1n1_DMS.csv')]


In [7]:
from protera_stability.proteins import EmbeddingExtractor1D

from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
from tqdm import tqdm

from pathlib import Path

# Root of the Protera data, expressed relative to the notebook's working directory.
data_path = Path("../data", "Protera")
# One entry per item found in the "prism" sub-directory (no filtering, directory order).
per_protein_csv = list(data_path.joinpath("prism").iterdir())

print(per_protein_csv[:10])

class AsSubscriptable:
    """Adapter that exposes ``obj[idx]`` access through a lookup callback.

    ``values`` is stored untouched; every subscript operation calls
    ``get_func(values, idx)`` and returns whatever that call returns.
    """

    def __init__(self, values, get_func):
        self.values, self.get_func = values, get_func

    def __getitem__(self, idx):
        lookup = self.get_func
        return lookup(self.values, idx)


class EmbeddingGetter:
    """Serve standardized embeddings and labels for one protein CSV on demand.

    The CSV is reduced to a (labels, sequences) frame, an
    ``EmbeddingExtractor1D`` is created, and two ``StandardScaler`` objects are
    fitted at construction time: one on embeddings of a random subset of the
    sequences, one on all labels. ``X()`` and ``y()`` return subscriptable
    views that compute and scale a single item per lookup.

    Args:
        protein_csv_path: Path (or file-like object) readable by
            ``pandas.read_csv``; must contain the columns ``variant`` and
            ``Rosetta_ddg_score_02``.
        cuda: Forwarded to the extractor as its ``gpu`` argument.
        data_path: Forwarded to the extractor as its ``base_path`` argument.
        seed: Optional seed for the subset drawn to fit the embedding scaler.
            ``None`` (default) keeps the previous behaviour of drawing from
            NumPy's global random state.
    """

    # Share of the dataset embedded up front to fit the embedding scaler,
    # and the upper bound on that sample.
    SCALER_FIT_FRACTION = 0.1
    SCALER_FIT_MAX = 500

    def __init__(self, protein_csv_path, cuda, data_path, seed=None):
        self.extractor_kwargs = {
            "model_name": "esm1b_t33_650M_UR50S",
            "base_path": data_path,
            "gpu": cuda,
        }
        self.extractor = EmbeddingExtractor1D(**self.extractor_kwargs)
        self.df = self.process_csv(protein_csv_path)
        self.sequences = self.df.sequences.values
        self.labels = self.df.labels.values
        self.seed = seed

        self.initialize()

    def initialize(self):
        """Fit ``x_scaler`` on a sample of embeddings and ``y_scaler`` on all labels.

        Raises:
            ValueError: If there are no sequences to fit on.
        """
        print("Initializing on the fly scalers...")

        n_total = len(self.sequences)
        if n_total == 0:
            raise ValueError(
                "Cannot fit scalers: no sequences are available after cleaning the CSV."
            )

        self.x_scaler = StandardScaler()
        self.y_scaler = StandardScaler()

        # Use ~10% of the data (capped) to fit the embedding scaler, but never
        # fewer than one sequence: for datasets with < 10 rows the plain 10%
        # rule rounds down to an empty sample, which cannot be fitted.
        n_fit = min(
            max(int(n_total * self.SCALER_FIT_FRACTION), 1),
            self.SCALER_FIT_MAX,
        )

        # With no seed, keep drawing from the global NumPy RNG so callers who
        # seeded it themselves get the same behaviour as before.
        seed = getattr(self, "seed", None)
        rng = np.random if seed is None else np.random.RandomState(seed)
        seqs = rng.choice(self.sequences, size=n_fit, replace=False)

        embeddings = [self.extractor.get_embedding(seq) for seq in tqdm(seqs)]

        self.x_scaler.fit(embeddings)
        self.y_scaler.fit(self.labels.reshape(-1, 1))

    def process_csv(self, path):
        """Load the CSV and return a two-column frame ordered (labels, sequences).

        ``variant`` becomes ``sequences`` and ``Rosetta_ddg_score_02`` becomes
        ``labels``.
        """
        df = pd.read_csv(path)
        # NOTE(review): duplicates/NaNs are dropped across *all* CSV columns
        # before the two columns of interest are selected, so a missing value
        # in an unrelated column also removes the row -- confirm this is intended.
        df = df.drop_duplicates().dropna()
        df = df[["variant", "Rosetta_ddg_score_02"]]
        df.columns = ["sequences", "labels"]
        df = df[df.columns[::-1]]

        return df

    def _scaled_embedding(self, sequences, idx):
        """Embed ``sequences[idx]`` and return it transformed by ``x_scaler``."""
        embedding = self.extractor.get_embedding(sequences[idx])
        return self.x_scaler.transform(embedding.reshape(1, -1))[0]

    def _scaled_label(self, labels, idx):
        """Return ``labels[idx]`` transformed by ``y_scaler``."""
        return self.y_scaler.transform(labels[idx].reshape(1, -1))[0]

    def X(self):
        """Return a subscriptable view; ``X()[i]`` is the scaled embedding of sequence ``i``."""
        return AsSubscriptable(self.sequences, self._scaled_embedding)

    def y(self):
        """Return a subscriptable view; ``y()[i]`` is the scaled label ``i``."""
        return AsSubscriptable(self.labels, self._scaled_label)

[PosixPath('../data/Protera/prism/prism_merged_037_UBI4_E1_binding_limiting_E1.csv'), PosixPath('../data/Protera/prism/prism_merged_999_IF-1_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_021_PAB1_doxycyclin_sensitivity.csv'), PosixPath('../data/Protera/prism/prism_merged_006_CBS_high_B6_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_003_PTEN_abundance.csv'), PosixPath('../data/Protera/prism/prism_merged_999_GmR_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_027_Src_kinase_activity_catalytic_domain_reversed.csv'), PosixPath('../data/Protera/prism/prism_merged_999_ccdB_DMS.csv'), PosixPath('../data/Protera/prism/prism_merged_026_BRCA1_E3_ubiquitination_activity.csv'), PosixPath('../data/Protera/prism/prism_merged_999_HAh1n1_DMS.csv')]


In [8]:
# Build the getter for a single PRISM file; construction loads the extractor
# and fits the scalers, so this cell is slow.
getter = EmbeddingGetter(
    protein_csv_path=data_path / "prism/prism_merged_999_ADRB2_DMS.csv",
    cuda="cuda:1",
    data_path=data_path,
)

Using cache found in /home/roberto/.cache/torch/hub/facebookresearch_esm_master


Initializing on the fly scalers...


100%|██████████| 409/409 [00:31<00:00, 12.86it/s]


In [9]:
# Spot-check one item: scaled embedding and scaled label at index 1.
(
    getter.X()[1],
    getter.y()[1],
)

(array([-0.02848428,  2.1652691 , -0.02952591, ..., -2.1547403 ,
        -1.4675167 ,  1.2362475 ], dtype=float32),
 array([0.05850698]))

In [14]:
import torch
from protera_stability.config.lazy import LazyCall as L
from protera_stability.config.common.mlp import mlp_esm
from protera_stability.train import get_cfg, setup_diversity, setup_data

# Experiment settings; unpacked as keyword arguments into setup_diversity by create_cfg.
exp_params = dict(
    diversity_cutoff=0.866,
    random_percent=0.15,
    sampling_method="diversity",
    experiment_name="base",
)

def setup_data(cfg):
    """Placeholder for a notebook-specific data setup step (not implemented).

    Defining this name here shadows the ``setup_data`` imported from
    ``protera_stability.train`` earlier in this cell, so ``create_cfg`` will
    reach this stub and fail until it is implemented or removed.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "setup_data is a notebook placeholder that shadows "
        "protera_stability.train.setup_data; implement it, or delete this stub "
        "to use the imported version, before calling create_cfg."
    )

def create_cfg(exp_params):
    """Build the experiment config from ``exp_params``.

    Steps: obtain a base config from ``get_cfg``, pass it with ``exp_params``
    (as keyword arguments) through ``setup_diversity``, attach a customised
    copy of ``mlp_esm`` as ``cfg.model``, then hand the result to
    ``setup_data``.

    Args:
        exp_params: Mapping unpacked as keyword arguments into ``setup_diversity``.

    Returns:
        Whatever ``setup_data`` returns for the assembled config.
    """
    import copy  # local import: keeps this cell self-contained

    cfg = get_cfg(args={})
    cfg = setup_diversity(cfg, **exp_params)

    # Customise a private copy so the imported module-level ``mlp_esm`` is not
    # mutated as a side effect of calling this function.
    # NOTE(review): assumes ``mlp_esm`` is deep-copyable -- confirm.
    model_cfg = copy.deepcopy(mlp_esm)
    model_cfg.n_units = 2048
    model_cfg.act = L(torch.nn.GELU)()
    cfg.model = model_cfg

    cfg = setup_data(cfg)
    return cfg

# Assemble the config for this experiment and list its top-level keys.
cfg = create_cfg(exp_params=exp_params)
cfg.keys()

NotImplementedError: 