In [None]:
import torch
import yaml
from torch.cuda.amp import autocast
from model.mutaplm import MutaPLM

# load model
# Preferred GPU for inference. If that index does not exist on this machine,
# fall back to the first GPU instead of failing on a hard-coded "cuda:2".
PREFERRED_GPU_INDEX = 2
if torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    # NOTE(review): CPU execution of MutaPLM is not exercised anywhere in this
    # notebook -- confirm the model supports it before relying on this branch.
    device = torch.device("cpu")

model_config_path = "./configs/mutaplm_inference.yaml"
# SECURITY: yaml.Loader can construct arbitrary Python objects. Only load
# config files you trust (or switch to yaml.SafeLoader if the config contains
# plain scalars/lists/dicts only).
with open(model_config_path, "r") as config_file:
    model_cfg = yaml.load(config_file, Loader=yaml.Loader)
model_cfg["device"] = device
model = MutaPLM(**model_cfg).to(device)

# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# Passing the path (rather than an open handle) lets torch open and close the
# file itself, so no file descriptor is leaked.
new_ckpt = torch.load("./ckpts/mutaplm.pth", map_location="cpu")["model"]
# strict=False: keys missing from / unexpected in the checkpoint are tolerated
# instead of raising.
model.load_state_dict(new_ckpt, strict=False)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


*** loading protein model...


Some weights of EsmForMutationDesign were not initialized from the model checkpoint at /data3/niezk/model/esm/esm2_t33_650M_UR50D and are newly initialized: ['esm.encoder.layer.32.crossattention_adapter.self.query.bias', 'esm.encoder.layer.32.crossattention_adapter.self.value.bias', 'esm.encoder.layer.32.crossattention_adapter.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.32.crossattention_adapter.self.key.weight', 'esm.encoder.layer.32.crossattention_adapter.self.key.bias', 'mutation_classifier.bias', 'esm.encoder.layer.32.crossattention_adapter.self.value.weight', 'esm.encoder.layer.32.crossattention_adapter.output.dense.weight', 'esm.encoder.layer.32.intermediate_adapter.dense.weight', 'esm.encoder.layer.32.intermediate_adapter.dense.bias', 'esm.encoder.layer.32.crossattention_adapter.LayerNorm.weight', 'esm.encoder.layer.32.output_adapter.dense.bias', 'esm.encoder.layer.32.crossattention_adapter.self.query.weight', 'mutation_classifier.weight', 'esm.encoder.layer.32.LayerNor

*** freezing protein model...
*** loading llm tokenizer...


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


*** loading llm from /data3/niezk/model/biomedgpt-lm...
*** adding LoRA...
trainable params: 0 || all params: 6,774,206,464 || trainable%: 0.0
*** building delta encoder...
*** model built successfully.


MutaPLM(
  (protein_model): EsmForMutationDesign(
    (esm): EsmModel(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 1280, padding_idx=1)
        (dropout): Dropout(p=0.0, inplace=False)
        (position_embeddings): Embedding(1026, 1280, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-31): 32 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): Linear(in_features=1280, out_features=1280, bias=True)
                (key): Linear(in_features=1280, out_features=1280, bias=True)
                (value): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (rotary_embeddings): RotaryEmbedding()
              )
              (output): EsmSelfOutput(
                (dense): Linear(in_features=1280, out_features=1280, bias=True)
                (dropout): Dropout(p=0.0, inplace=Fa

In [4]:
# Explanation: given wildtype protein and mutation site, predict its original function and mutational effect.
def apply_point_mutation(sequence: str, site: str) -> str:
    """Return ``sequence`` with the single substitution described by ``site``.

    ``site`` is written as ``<wild-type residue><1-based position><new residue>``,
    e.g. ``"A70K"`` replaces the ``A`` at position 70 with ``K``.

    Raises:
        ValueError: if ``site`` is malformed, the position lies outside the
            sequence, or the residue found at that position is not the stated
            wild-type residue.
    """
    wt_residue, position_text, new_residue = site[:1], site[1:-1], site[-1:]
    if not (wt_residue.isalpha() and new_residue.isalpha() and position_text.isdecimal()):
        raise ValueError(f"Malformed mutation site {site!r}; expected something like 'A70K'.")
    position = int(position_text)
    if not 1 <= position <= len(sequence):
        raise ValueError(
            f"Position {position} in {site!r} is outside the sequence (length {len(sequence)})."
        )
    found = sequence[position - 1]
    if found != wt_residue:
        raise ValueError(
            f"Site {site!r} expects {wt_residue!r} at position {position}, but the sequence has {found!r}."
        )
    # Positions are 1-based: keep everything before the site, swap one residue,
    # keep everything after it.
    return sequence[: position - 1] + new_residue + sequence[position:]


wildtype_protein = "MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK"
site = "A70K"
# Validates the site against the wild-type sequence before building the mutant,
# so a typo in `site` fails loudly instead of silently mutating the wrong residue.
mutated_protein = apply_point_mutation(wildtype_protein, site)
muta_prompt = f"Next is a feature of the mutation {site[0]} to {site[-1]} at position {site[1:-1]}. Please generate a brief summary text to describe it."

# Inference only: no gradients, bfloat16 autocast. torch.autocast replaces the
# deprecated torch.cuda.amp.autocast; `device` comes from the model-loading cell.
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    pred_func, pred_mut = model.generate([wildtype_protein], [mutated_protein], [muta_prompt])

print("Predicted function:", pred_func[0])
print("Predicted effect:", pred_mut[0])

Predicted function: ase that can recognize specific palindromic sequences and target them to the proteasome for degradation. Can recognize the palindromic IN box and upstream of a 5'-AAA-3' motif in different proteins such as cyclins B1/2 (CCNB1 and CCNB2), histone H4 (H4), and histone H2B (H2B). Exhibits an endogenous activity in HEK293 cells (human embryonic kidney cells) and can induce the degradation of CCNB1 in these cells. Can drive the degradation of CCNB1 in HEK293 cells even if CCNB1 does not contain the IN box in its C-terminus. Also involved in the cell cycle regulation of G1/S phase by controlling KIP1/CHFR-mediated destabilization of CDK4 and phosphorylation of histone H3, histone H4 and histone H2B. Involved in DNA damage response by triggering protein degradation through the 26S proteasome in a p53/TP53-dependent
Predicted effect: Decrease of IN box-dependent E3 ubiquitin ligase activity.


In [5]:
# Engineering: given wildtype protein and mutational effect, predict mutated position and new amino acid.
wildtype_protein = "MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK"
effect_text = "Strongly enhanced InsP6 kinase activity. The mutation in the ITPK protein causes a change in its catalytic activity."
muta_prompt = "What is the mutated position and new amino acid?"

# Inference only: no gradients, bfloat16 autocast. torch.autocast replaces the
# deprecated torch.cuda.amp.autocast; `device` comes from the model-loading cell.
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    preds = model.lm_design([wildtype_protein], [effect_text], muta_prompt=[muta_prompt])

# Scores for the first (and only) input sequence.
# NOTE(review): assumes the last axis indexes amino-acid tokens and the axis
# before it indexes sequence positions -- confirm against MutaPLM.lm_design.
site_scores = preds[0]
# Take the stride of the flattened index from the tensor itself, so the
# position/token split stays correct even if the score width ever differs from
# len(model.protein_tokenizer).
vocab_size = site_scores.shape[-1]
flat_scores = site_scores.flatten()
# Cap k so very short inputs cannot make topk raise.
top50 = flat_scores.topk(min(50, flat_scores.numel())).indices
top50_pos = top50 // vocab_size
top50_aa = top50 % vocab_size
top50_aa = model.protein_tokenizer.batch_decode(top50_aa)
# Candidates are sorted best-first; report only the top-ranked one.
print("mutated position:", top50_pos[0].item())
print("new amino acid:", top50_aa[0])

mutated position: 70
new amino acid: K
