In [19]:
import pandas as pd
from lobster.model._utils import model_typer
from lobster.model import LobsterPCLM
from typing import List, Optional, Union, Tuple, Literal
from lobster.transforms.functional import GraftGermline
from lobster.tokenization import PmlmTokenizerTransform
import importlib


import pandas as pd
from lobster.model._utils import model_typer
from typing import Literal
from lobster.data import DataFrameDatasetInMemory
from torch.utils.data import DataLoader
from prescient.metrics.functional import naturalness

import torch
# import Levenshtein
import seaborn as sns
import matplotlib.pyplot as plt
import os

def get_lobster_model(                    
                    model_name: str = None,
                    ckpt_path: str = None,
                    model_type: Literal["LobsterPMLM", "LobsterPCLM"] = "LobsterPCLM",
                    ):
    model_cls = model_typer[model_type]
    if model_name is not None:
        model = model_cls(
                model_name=model_name,
            )  # load a specific model, e.g. ESM2-8M
    if ckpt_path is not None:
        model = model_cls.load_from_checkpoint(
                ckpt_path,
                strict=False
            )  # load specific pre-trained chkpt
    model.eval()
    return model

def preprocess_df(df: pd.DataFrame) -> (pd.DataFrame, pd.DataFrame):
    print("Preprocessing dataframe")
    df["fv_heavy"] = df["fv_heavy_aho"].apply(lambda x: x.replace("-", ""))
    df["fv_light"] = df["fv_light_aho"].apply(lambda x: x.replace("-", ""))

    unique_fv_heavy_grafts = df.drop_duplicates(subset="fv_heavy")
    unique_fv_heavy_grafts = pd.concat([unique_fv_heavy_grafts.filter(like="heavy"), unique_fv_heavy_grafts.filter(like="H")], axis=1)
    unique_fv_heavy_grafts["chain"] = "heavy"
    
    unique_fv_light_grafts = df.drop_duplicates(subset="fv_light")
    unique_fv_light_grafts = pd.concat([unique_fv_light_grafts.filter(like="light"), unique_fv_light_grafts.filter(like="L")], axis=1)
    unique_fv_light_grafts["chain"] = "light"
    
    return unique_fv_heavy_grafts, unique_fv_light_grafts

def get_pclm_likelihoods(model_name, ckpt_path, unique_fv_heavy_grafts: pd.DataFrame, unique_fv_light_grafts: pd.DataFrame) -> (pd.DataFrame, pd.DataFrame):

    pclm = get_lobster_model(model_name=None, ckpt_path=ckpt_path, model_type="LobsterPCLM")

    unique_fv_heavy_grafts_dataset = DataFrameDatasetInMemory(unique_fv_heavy_grafts, transform_fn=pclm._transform_fn, columns=["fv_heavy"])
    unique_fv_light_grafts_dataset = DataFrameDatasetInMemory(unique_fv_light_grafts, transform_fn=pclm._transform_fn, columns=["fv_light"])

    fv_heavy_batch_likelihoods = []
    fv_light_batch_likelihoods = []

    for batch in DataLoader(unique_fv_heavy_grafts_dataset, batch_size=64, shuffle=False):
        fv_heavy_batch_likelihoods.extend(pclm.batch_to_log_likelihoods(batch).to('cpu'))

    for batch in DataLoader(unique_fv_light_grafts_dataset, batch_size=64, shuffle=False):
        fv_light_batch_likelihoods.extend(pclm.batch_to_log_likelihoods(batch).to('cpu'))
        
    fv_heavy_batch_likelihoods = torch.stack(fv_heavy_batch_likelihoods)
    fv_light_batch_likelihoods = torch.stack(fv_light_batch_likelihoods)

    unique_fv_heavy_grafts[f"{model_name}_fv_heavy_likelihood"] = fv_heavy_batch_likelihoods
    unique_fv_light_grafts[f"{model_name}_fv_light_likelihood"] = fv_light_batch_likelihoods
    
    return unique_fv_heavy_grafts, unique_fv_light_grafts

def get_pmlm_likelihoods(
                    unique_fv_heavy_grafts: pd.DataFrame, 
                    unique_fv_light_grafts: pd.DataFrame, 
                    ) -> (pd.DataFrame, pd.DataFrame):
                    
    unique_fv_heavy_grafts["antiberty_fv_heavy_naturalness"] = naturalness(unique_fv_heavy_grafts["fv_heavy"].tolist())
    unique_fv_light_grafts["antiberty_fv_light_naturalness"] = naturalness(unique_fv_light_grafts["fv_light"].tolist())

    model_name = "esm2_t6_8M_UR50D"
    pmlm = get_lobster_model(model_name=model_name, ckpt_path=None, model_type="LobsterPMLM")
    unique_fv_heavy_grafts[f"{model_name}_fv_heavy_naturalness"] = pmlm.naturalness(unique_fv_heavy_grafts["fv_heavy"].tolist())
    unique_fv_light_grafts[f"{model_name}_fv_light_naturalness"] = pmlm.naturalness(unique_fv_light_grafts["fv_light"].tolist())

    return unique_fv_heavy_grafts, unique_fv_light_grafts

def score_grafts(seed_id: str, df: pd.DataFrame, seeds_df: pd.DataFrame, out_root_dir: None):
    unique_fv_heavy_grafts, unique_fv_light_grafts = preprocess_df(df)
    fv_heavy_seed = seeds_df[seeds_df["seed_id"] == seed_id].fv_heavy.values[0]
    fv_light_seed = seeds_df[seeds_df["seed_id"] == seed_id].fv_light.values[0]
    unique_fv_heavy_grafts["ed_seed_fv_heavy"] = unique_fv_heavy_grafts["fv_heavy"].apply(lambda x: Levenshtein.distance(x, fv_heavy_seed))
    unique_fv_light_grafts["ed_seed_fv_light"] = unique_fv_light_grafts["fv_light"].apply(lambda x: Levenshtein.distance(x, fv_light_seed))

    print("Computing pCLM likelihoods")
    model_name = "clm_11M_uniref50"
    ckpt_path = f"s3://prescient-data-dev/sandbox/freyn6/models/prod/lobster/clm/{model_name}.ckpt"
    unique_fv_heavy_grafts, unique_fv_light_grafts = get_pclm_likelihoods(model_name, ckpt_path, unique_fv_heavy_grafts, unique_fv_light_grafts)

    model_name = "clm_11M_uniref50_oas_pdb"
    ckpt_path = f"s3://prescient-data-dev/sandbox/freyn6/models/prod/lobster/clm/{model_name}.ckpt"
    unique_fv_heavy_grafts, unique_fv_light_grafts = get_pclm_likelihoods(model_name, ckpt_path, unique_fv_heavy_grafts, unique_fv_light_grafts)
    print("Computing pMLM likelihoods")
    unique_fv_heavy_grafts, unique_fv_light_grafts = get_pmlm_likelihoods(unique_fv_heavy_grafts, unique_fv_light_grafts)
    
    if out_root_dir is not None:
        outpath = f"{out_root_dir}/{seed_id}"
        os.makedirs(outpath, exist_ok=True)
        unique_fv_heavy_grafts.to_csv(f"{outpath}/{seed_id}_unique_fv_heavy_grafts_pCLM_scored.csv", index=False)
        unique_fv_light_grafts.to_csv(f"{outpath}/{seed_id}_unique_fv_light_grafts_pCLM_scored.csv", index=False)
        
    return unique_fv_heavy_grafts, unique_fv_light_grafts

path = importlib.resources.files("lobster") / "assets" / "pmlm_tokenizer"

tokenization_transform = PmlmTokenizerTransform(
                                                path,
                                                padding="max_length",
                                                truncation=True,
                                                max_length=512,
                                                mlm=True,
                                            )


In [14]:
%timeit
germline_df = pd.read_pickle("s3://prescient-data-dev/sandbox/wanga84/humanization/preferred_absolve_germlines_aho_tokenized.pkl")

In [13]:
germline_df['fv_heavy_aho_tok'].iloc[0]

tensor([ 0, 16,  7, 16,  4,  7, 16,  8, 30,  6,  5,  9,  7, 15, 15, 14,  6,  5,
         8,  7, 15,  7,  8, 23, 15,  5,  8,  6, 30, 19, 11, 18, 11,  8, 30, 30,
        30, 30, 30, 19,  6, 12,  8, 22,  7, 10, 16,  5, 14,  6, 16,  6,  4,  9,
        22, 20,  6, 22, 12,  8,  5, 19, 30, 30, 30, 17,  6, 17, 11, 17, 19,  5,
        16, 15,  4, 16,  6, 10,  7, 11, 20, 11, 11, 13, 11,  8, 11,  8, 11,  5,
        19, 20,  9,  4, 10,  8,  4, 10,  8, 13, 13, 11,  5,  7, 19, 19, 23,  5,
        10,  5,  9, 19, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
        30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 18, 16, 21, 22,  6, 16,  6, 11,
         4,  7, 11,  7,  8,  8,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1, 

In [7]:
germline_df.to_parquet('s3://prescient-data-dev/sandbox/freyn6/humanization/preferred_absolve_germlines_aho_tokenized.parquet')

ArrowInvalid: ('Could not convert tensor([ 0, 16,  7, 16,  4,  7, 16,  8, 30,  6,  5,  9,  7, 15, 15, 14,  6,  5,\n         8,  7, 15,  7,  8, 23, 15,  5,  8,  6, 30, 19, 11, 18, 11,  8, 30, 30,\n        30, 30, 30, 19,  6, 12,  8, 22,  7, 10, 16,  5, 14,  6, 16,  6,  4,  9,\n        22, 20,  6, 22, 12,  8,  5, 19, 30, 30, 30, 17,  6, 17, 11, 17, 19,  5,\n        16, 15,  4, 16,  6, 10,  7, 11, 20, 11, 11, 13, 11,  8, 11,  8, 11,  5,\n        19, 20,  9,  4, 10,  8,  4, 10,  8, 13, 13, 11,  5,  7, 19, 19, 23,  5,\n        10,  5,  9, 19, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,\n        30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 18, 16, 21, 22,  6, 16,  6, 11,\n         4,  7, 11,  7,  8,  8,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n         1,  1,  1,  1,  1,  1,  1,  1]) with type Tensor: did not recognize Python value type when inferring an Arrow data type', 'Conversion failed for column fv_heavy_aho_tok with type object')

In [6]:
germline_df.head()

Unnamed: 0.1,Unnamed: 0,fv_heavy,fv_heavy_aho,species_x,heavy_nonconserved_cys,heavy_unpaired_cys,heavy_ng_motif,heavy_glycosylation_motif,heavy_dg_in_h3,heavy_q_h1,...,LFR3b,LFR4,L1,L2,L3,L4,light_v_gene,light_j_gene,fv_heavy_aho_tok,fv_light_aho_tok
0,0,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLE...,QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQ...,human,False,False,True,False,False,True,...,TLTISSLQPEDVATYYC,FGQGTKVEIK,RASQGISNYLA,YAASTLQS,QKYNSAPWT,GSGTDF,IGKV1-27*01,IGKJ1*01,"[tensor(0), tensor(16), tensor(7), tensor(16),...","[tensor(0), tensor(13), tensor(12), tensor(16)..."
1,1,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLE...,QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQ...,human,False,False,True,False,False,True,...,TLTISSLQPEDVATYYC,FGQGTKLEIK,RASQGISNYLA,YAASTLQS,QKYNSAPYT,GSGTDF,IGKV1-27*01,IGKJ2*01,"[tensor(0), tensor(16), tensor(7), tensor(16),...","[tensor(0), tensor(13), tensor(12), tensor(16)..."
2,2,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLE...,QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQ...,human,False,False,True,False,False,True,...,TLTISSLQPEDVATYYC,FGQGTKLEIK,RASQGISNYLA,YAASTLQS,QKYNSAPCT,GSGTDF,IGKV1-27*01,IGKJ2*02,"[tensor(0), tensor(16), tensor(7), tensor(16),...","[tensor(0), tensor(13), tensor(12), tensor(16)..."
3,3,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLE...,QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQ...,human,False,False,True,False,False,True,...,TLTISSLQPEDVATYYC,FGQGTKLEIK,RASQGISNYLA,YAASTLQS,QKYNSAPYS,GSGTDF,IGKV1-27*01,IGKJ2*03,"[tensor(0), tensor(16), tensor(7), tensor(16),...","[tensor(0), tensor(13), tensor(12), tensor(16)..."
4,4,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLE...,QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQ...,human,False,False,True,False,False,True,...,TLTISSLQPEDVATYYC,FGQGTKLEIK,RASQGISNYLA,YAASTLQS,QKYNSAPCS,GSGTDF,IGKV1-27*01,IGKJ2*04,"[tensor(0), tensor(16), tensor(7), tensor(16),...","[tensor(0), tensor(13), tensor(12), tensor(16)..."


In [20]:
model = LobsterPCLM(model_name='CLM_mini')
model.eval();

LobsterPCLM(
  (_transform_fn): PmlmTokenizerTransform()
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(33, 32, padding_idx=1)
      (layers): ModuleList(
        (0-2): 3 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=32, out_features=32, bias=False)
            (k_proj): Linear(in_features=32, out_features=32, bias=False)
            (v_proj): Linear(in_features=32, out_features=32, bias=False)
            (o_proj): Linear(in_features=32, out_features=32, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=32, out_features=64, bias=False)
            (up_proj): Linear(in_features=32, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=32, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_l

In [36]:
model._transform_fn(germline_df['fv_heavy_aho'].tolist()[:5])[0]['input_ids'].squeeze().shape

torch.Size([512])

In [16]:
germline_df['fv_heavy_aho'].tolist()

['QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARAEY------------------------FQHWGQGTLVTVSS',
 'QVQLVQS-GAEVKKPGASVKVSCKASG-YTFTS-----YGISWVRQAPGQGLEWMGWISAY---NGNTNYAQKL

### Prepare seed sequence


In [None]:
seeds_df = pd.read_csv('s3://prescient-data-dev/sequences/accepted_seeds.csv')
to_graft = seeds_df[seeds_df["seed_id"] == "C1S-262"]
fv_heavy_aho, fv_light_aho = to_graft["fv_heavy_aho"].values[0], to_graft["fv_light_aho"].values[0]

### Graft germlines

In [None]:
graft_germlines = GraftGermline(germline_df, tokenization_transform)
grafted_df, _, _ = graft_germlines.graft(fv_heavy_aho, fv_light_aho)

graft_germlines_const_vernier = GraftGermline(germline_df, tokenization_transform, keep_vernier_from_animal="rabbit")
grafted_df_const_vernier, _, _ = graft_germlines_const_vernier.graft(fv_heavy_aho, fv_light_aho)

grafts = pd.concat([grafted_df, grafted_df_const_vernier])

### Score grafts

In [None]:
unique_fv_heavy_grafts, unique_fv_light_grafts = score_grafts(seed_id="C1S-262", df=grafts, seeds_df=seeds_df, out_root_dir="../data/")

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 8))
sns.scatterplot(data=unique_fv_heavy_grafts, x="clm_11M_uniref50_fv_heavy_likelihood", y="clm_11M_uniref50_oas_pdb_fv_heavy_likelihood", hue="ed_seed_fv_heavy", ax=ax[0, 0], palette="Blues_r")
sns.scatterplot(data=unique_fv_light_grafts, x="clm_11M_uniref50_fv_light_likelihood", y="clm_11M_uniref50_oas_pdb_fv_light_likelihood", hue="ed_seed_fv_light", ax=ax[0, 1], palette="Blues_r")
ax[0, 0].legend(loc='upper left', bbox_to_anchor=(1, 1))
ax[0, 1].legend(loc='upper left', bbox_to_anchor=(1, 1))

sns.scatterplot(data=unique_fv_heavy_grafts, x="antiberty_fv_heavy_naturalness", y="esm2_t6_8M_UR50D_fv_heavy_naturalness", hue="ed_seed_fv_heavy", ax=ax[1, 0], palette="Blues_r")
sns.scatterplot(data=unique_fv_light_grafts, x="antiberty_fv_light_naturalness", y="esm2_t6_8M_UR50D_fv_light_naturalness", hue="ed_seed_fv_light", ax=ax[1, 1], palette="Blues_r")
ax[1, 0].legend(loc='upper left', bbox_to_anchor=(1, 1))
ax[1,1 ].legend(loc='upper left', bbox_to_anchor=(1, 1))

fig.subplots_adjust(wspace=0.5)
fig.suptitle("C1S-262", fontsize=16)
