In [1]:
import torch
import yaml
from torch.cuda.amp import autocast
from model.mutaplm import MutaPLM


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show the working directory: the relative ./configs and ./ckpts paths used in the cells below are resolved against it.
%pwd

'/Users/yvesgreatti/github/MutaPLM'

In [6]:

# Load model: build MutaPLM from the inference YAML config on the chosen device.
device = torch.device("cpu")
model_config_path = "./configs/mutaplm_inference.yaml"

# Read the config through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is only
# acceptable because this is a trusted, repo-local file; switch to
# yaml.safe_load if the config contains nothing but plain scalars/lists/maps.
with open(model_config_path, "r") as config_file:
    model_cfg = yaml.load(config_file, Loader=yaml.Loader)

# The target device is injected at runtime rather than read from the YAML file.
model_cfg["device"] = device
model = MutaPLM(**model_cfg).to(device)

*** loading protein model...


Some weights of EsmForMutationDesign were not initialized from the model checkpoint at ./ckpts/esm2-650m and are newly initialized: ['esm.encoder.layer.32.LayerNorm_adapter.bias', 'esm.encoder.layer.32.LayerNorm_adapter.weight', 'esm.encoder.layer.32.crossattention_adapter.LayerNorm.bias', 'esm.encoder.layer.32.crossattention_adapter.LayerNorm.weight', 'esm.encoder.layer.32.crossattention_adapter.output.dense.bias', 'esm.encoder.layer.32.crossattention_adapter.output.dense.weight', 'esm.encoder.layer.32.crossattention_adapter.self.key.bias', 'esm.encoder.layer.32.crossattention_adapter.self.key.weight', 'esm.encoder.layer.32.crossattention_adapter.self.query.bias', 'esm.encoder.layer.32.crossattention_adapter.self.query.weight', 'esm.encoder.layer.32.crossattention_adapter.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.32.crossattention_adapter.self.value.bias', 'esm.encoder.layer.32.crossattention_adapter.self.value.weight', 'esm.encoder.layer.32.intermediate_adapter.dense.bias'

*** freezing protein model...
*** loading llm tokenizer...
*** loading llm from ./ckpts/biomedgpt-lm-7b...


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


*** adding LoRA...
trainable params: 0 || all params: 6,774,206,464 || trainable%: 0.0000
*** building delta encoder...
*** model built successfully.


In [7]:
# Load the MutaPLM checkpoint weights into the freshly built model.
checkpoint_path = "./ckpts/mutaplm.pth"

# Passing the path (instead of an open() handle) lets torch.load open and close
# the file itself, so no file descriptor is leaked.
new_ckpt = torch.load(checkpoint_path, map_location="cpu")["model"]

# strict=False tolerates key mismatches between checkpoint and model, which can
# silently hide a wrong or partial checkpoint -- so report what was skipped.
# NOTE(review): assumes the standard nn.Module.load_state_dict return value
# (missing_keys / unexpected_keys); confirm MutaPLM does not override it.
load_result = model.load_state_dict(new_ckpt, strict=False)
print(
    f"missing keys: {len(load_result.missing_keys)} | "
    f"unexpected keys: {len(load_result.unexpected_keys)}"
)
if load_result.unexpected_keys:
    print("unexpected keys (first 10):", load_result.unexpected_keys[:10])

# Switch to inference mode; as the cell's last expression this also displays the module tree.
model.eval()

MutaPLM(
  (protein_model): EsmForMutationDesign(
    (esm): EsmModel(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 1280, padding_idx=1)
        (dropout): Dropout(p=0.0, inplace=False)
        (position_embeddings): Embedding(1026, 1280, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-31): 32 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): Linear(in_features=1280, out_features=1280, bias=True)
                (key): Linear(in_features=1280, out_features=1280, bias=True)
                (value): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (rotary_embeddings): RotaryEmbedding()
              )
              (output): EsmSelfOutput(
                (dense): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=Fa

In [None]:
# Explanation: given wildtype protein and mutation site, predict its original function and mutational effect.
def apply_point_mutation(sequence, mutation):
    """Return `sequence` with a single amino-acid substitution applied.

    Args:
        sequence: Wild-type amino-acid sequence (one letter per residue).
        mutation: Substitution written as <wild-type residue><1-based position><new residue>,
            e.g. "A70K".

    Returns:
        A new string of the same length with the residue at the given position replaced.

    Raises:
        ValueError: If `mutation` is malformed, the position lies outside the sequence,
            or the wild-type residue does not match the sequence at that position.
    """
    if len(mutation) < 3 or not mutation[1:-1].isdecimal():
        raise ValueError(f"Malformed mutation {mutation!r}; expected e.g. 'A70K'.")
    original_aa, new_aa = mutation[0], mutation[-1]
    position = int(mutation[1:-1])
    if not 1 <= position <= len(sequence):
        raise ValueError(
            f"Position {position} is outside the sequence (length {len(sequence)})."
        )
    # Guard against off-by-one or wrong-sequence mistakes before running the model.
    if sequence[position - 1] != original_aa:
        raise ValueError(
            f"Mutation {mutation!r} expects {original_aa!r} at position {position}, "
            f"but the sequence has {sequence[position - 1]!r}."
        )
    return sequence[:position - 1] + new_aa + sequence[position:]


wildtype_protein = "MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK"
site = "A70K"
mutated_protein = apply_point_mutation(wildtype_protein, site)
muta_prompt = f"Next is a feature of the mutation {site[0]} to {site[-1]} at position {site[1:-1]}. Please generate a brief summary text to describe it."

# Generation itself is run in the next cell. The earlier attempt that wrapped
# model.generate in torch.cuda.amp.autocast(dtype=torch.bfloat16) is the source of
# the float / BFloat16 dtype-mismatch RuntimeError recorded below this cell.



RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::BFloat16

In [None]:
from contextlib import nullcontext

# Run the explanation task: predict the protein's function and the mutation's effect.
# Precision handling is chosen from `device` so the CPU workaround below is only
# applied on CPU, as the step comments intend.
if device.type == "cuda":
    # On GPU keep the model's dtypes and let autocast run in bfloat16.
    amp_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
else:
    # 1) Make sure model is float32 if you're on CPU
    #    (avoids the float vs. BFloat16 matmul dtype mismatch).
    model = model.float()  # casts all params/buffers to float32

    # 2) Remove CUDA autocast on CPU
    amp_ctx = nullcontext()

with torch.no_grad(), amp_ctx:
    pred_func, pred_mut = model.generate(
        [wildtype_protein], [mutated_protein], [muta_prompt]
    )

print("Predicted function:", pred_func[0])
print("Predicted effect:", pred_mut[0])


In [None]:
# Compare wild-type and mutant embeddings produced by the protein encoder.
with torch.no_grad():
    # The second argument must be the mutant *sequence*: the original call passed
    # the text prompt here, which would have been encoded as if it were a protein.
    # NOTE(review): assumes _encode_protein(wildtype_seqs, mutant_seqs) -- confirm
    # against model/mutaplm.py.
    emb_wt, emb_mut = model._encode_protein([wildtype_protein], [mutated_protein])
    # Presumably each is [1, num_query_tokens, llm_hidden] -- TODO confirm.

    # Sequence-level vectors: average over dim 1 (presumably the token axis).
    vec_wt  = emb_wt.mean(dim=1)                          # presumably [1, llm_hidden]
    vec_mut = emb_mut.mean(dim=1)                         # presumably [1, llm_hidden]
    # Difference between the mutant and wild-type mean embeddings.
    delta   = (vec_mut - vec_wt)                          # same shape as vec_wt


In [5]:
# Engineering: given wildtype protein and mutational effect, predict mutated position and new amino acid.
wildtype_protein = "MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK"
effect_text = "Strongly enhanced InsP6 kinase activity. The mutation in the ITPK protein causes a change in its catalytic activity."
# NOTE: this rebinds `muta_prompt`, which the explanation cells also use; re-run the
# explanation-input cell before re-running generation after this cell.
muta_prompt = "What is the mutated position and new amino acid?"

# torch.cuda.amp.autocast is deprecated and is disabled (with a warning) when CUDA
# is unavailable, so select the autocast device explicitly and only enable
# bfloat16 autocast when actually running on a GPU.
use_amp = device.type == "cuda"
with torch.no_grad():
    with torch.autocast(
        device_type="cuda" if use_amp else "cpu",
        dtype=torch.bfloat16,
        enabled=use_amp,
    ):
        preds = model.lm_design([wildtype_protein], [effect_text], muta_prompt=[muta_prompt])

# Rank every (position, token) score of the first sample. Each flat index is
# decoded as position * vocab_size + token_id.
# NOTE(review): assumes the last dimension of preds[0] equals len(model.protein_tokenizer).
vocab_size = len(model.protein_tokenizer)
top50 = preds[0].flatten().topk(50).indices
top50_pos = top50 // vocab_size
top50_aa = top50 % vocab_size
top50_aa = model.protein_tokenizer.batch_decode(top50_aa)

# Report only the single best-scoring candidate.
print("mutated position:", top50_pos[0].item())
print("new amino acid:", top50_aa[0])

mutated position: 70
new amino acid: K
