In [1]:
# Root of the local protgps checkout. This directory is appended to sys.path so
# that `import protgps` resolves, and the checkpoint/args files are looked up
# under `<PROTGPS_PARENT_DIR>/checkpoints/protgps/`. Edit it for your machine.
PROTGPS_PARENT_DIR = "/home/protgps" # point to the protgps local repo

In [None]:
import sys
import os
sys.path.append(PROTGPS_PARENT_DIR) # append the path of protgps
from argparse import Namespace
import pickle
from tqdm import tqdm
import pandas as pd
import torch 
from protgps.utils.loading import get_object

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Compartment labels used to name the score columns of the results table.
# Order matters: entry j labels column j of the score matrix returned by
# predict_condensates. Presumably this must match the order of the classifier's
# output logits -- confirm against the training configuration before reordering.
COMPARTMENT_CLASSES = [
    "nuclear_speckle",
    "p-body",
    "pml-body",  # was misspelled "pml-bdoy", which leaked into the column header
    "post_synaptic_density",
    "stress_granule",
    "chromosome",
    "nucleolus",
    "nuclear_pore_complex",
    "cajal_body",
    "rna_granule",
    "cell_junction",
    "transcriptional"
]

def load_model(snargs):
    """
    Load the classifier described by an args namespace from its checkpoint.

    Args:
        snargs: namespace providing at least
            - ``lightning_name``: key passed to ``get_object(..., "lightning")``
              to look up the model class,
            - ``model_path``: path of the checkpoint file to load,
            - ``relax_checkpoint_matching``: when truthy, the checkpoint is
              loaded with ``strict=False``.

    Returns:
        The object returned by the model class's ``load_from_checkpoint``.
    """
    # Resolve the model class (presumably a LightningModule subclass -- confirm
    # against the protgps registry) WITHOUT instantiating it first.
    # load_from_checkpoint is a classmethod that builds its own instance from
    # `args`, so constructing a throwaway model beforehand only doubles the
    # construction cost, and recent Lightning releases refuse to run this
    # classmethod through an instance.
    model_cls = get_object(snargs.lightning_name, "lightning")
    return model_cls.load_from_checkpoint(
        checkpoint_path=snargs.model_path,
        strict=not snargs.relax_checkpoint_matching,
        args=snargs,
    )

@torch.no_grad()
def predict_condensates(model, sequences, batch_size=1, round=True):
    """
    Score sequences with the classifier, batch by batch, with gradients disabled.

    Args:
        model: wrapper whose ``.model`` attribute is called with ``{"x": batch}``
            and returns a mapping containing a ``'logit'`` tensor.
        sequences: sliceable collection of inputs; consecutive slices of
            ``batch_size`` items are fed to the model.
        batch_size: number of sequences per forward pass.
        round: when truthy, the scores are rounded to 3 decimals. (The name
            shadows the builtin inside this function; kept for callers.)

    Returns:
        A CPU tensor of sigmoid scores, the per-batch outputs stacked
        vertically in input order.
    """
    batch_starts = range(0, len(sequences), batch_size)
    per_batch_probs = []
    for start in tqdm(batch_starts, ncols=100):
        chunk = sequences[start : start + batch_size]
        logits = model.model({"x": chunk})["logit"]
        # Move each batch off the accelerator right away so results accumulate on CPU.
        per_batch_probs.append(torch.sigmoid(logits).to("cpu"))
    stacked = torch.vstack(per_batch_probs)
    return torch.round(stacked, decimals=3) if round else stacked

In [None]:
# Load the pickled training arguments that accompany the checkpoint.
# NOTE(security): pickle.load can execute arbitrary code -- only point this at an
# args file obtained from a source you trust.
args_path = os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729.args') # assumes args file has been extracted in checkpoints/protgps
# Use a context manager so the file handle is closed as soon as the args are read.
with open(args_path, 'rb') as args_file:
    args = Namespace(**pickle.load(args_file))
args.model_path = os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729epoch=26.ckpt') # assumes checkpoint has been extracted in checkpoints/protgps
args.pretrained_hub_dir = "/home/protgps/esm_models/esm2" # should point to folder with ESM2 facebookresearch_esm_main directory
model = load_model(args)
# Switch to evaluation mode before inference, then move the model to the selected device.
model.eval()
model = model.to(device)

In [3]:
# Example inputs to score: one amino-acid sequence string per entry, each
# preceded by a comment naming the UniProt accession it was taken from.
sequences = [
    # UniProt O15116
    "MNYMPGTASLIEDIDKKHLVLLRDGRTLIGFLRSIDQFANLVLHQTVERIHVGKKYGDIPRGIFVVRGENVVLLGEIDLEKESDTPLQQVSIEEILEEQRVEQQTKLEAEKLKVQALKDRGLSIPRADTLDEY", 
    # UniProt P38432
    "MAASETVRLRLQFDYPPPATPHCTAFWLLVDLNRCRVVTDLISLIRQRFGFSSGAFLGLYLEGGLLPPAESARLVRDNDCLRVKLEERGVAENSVVISNGDINLSLRKAKKRAFQLEEGEETEPDCKYSKKHWKSRENNNNNEKVLDLEPKAVTDQTVSKKNKRKNKATCGTVGDDNEEAKRKSPKKKEKCEYKKKAKNPKSPKVQAVKDWANQRCSSPKGSARNSLVKAKRKGSVSVCSKESPSSSSESESCDESISDGPSKVTLEARNSSEKLPTELSKEEPSTKNTTADKLAIKLGFSLTPSKGKTSGTTSSSSDSSAESDDQCLMSSSTPECAAGFLKTVGLFAGRGRPGPGLSSQTAGAAGWRRSGSNGGGQAPGASPSVSLPASLGRGWGREENLFSWKGAKGRGMRGRGRGRGHPVSCVVNRSTDNQRQQQLNDVVKNSSTIIQNPVETPKKDYSLLPLLAAAPQVGEKIAFKLLELTSSYSPDVSDYKEGRILSHNPETQQVDIEILSSLPALREPGKFDLVYHNENGAEVVEYAVTQESKITVFWKELIDPRLIIESPSNTSSTEPA" 
]

In [None]:
# Score the sequences one per forward pass (batch_size=1). The per-batch outputs
# are stacked in input order, so row i of `scores` is expected to correspond to
# sequences[i] (assumes the model emits one row of logits per sequence).
scores = predict_condensates(model, sequences, batch_size=1)

In [None]:
# Column-oriented table: the input sequences first, then one "<LABEL>_Score"
# column per compartment, taken from the matching column of `scores`.
data = {
    "sequences": sequences,
    **{
        f"{label.upper()}_Score": scores[:, column].tolist()
        for column, label in enumerate(COMPARTMENT_CLASSES)
    },
}

In [None]:
# Last expression of the cell, so the notebook renders the table: the
# "sequences" column followed by one "<LABEL>_Score" column per compartment.
pd.DataFrame(data)