In [1]:
import torch 
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ESM-2 masked-language-model checkpoint on the Hugging Face Hub.
model_id = "facebook/esm2_t12_35M_UR50D"

In [2]:
# Fetch the tokenizer and the masked-LM model for the checkpoint, then place the weights on `device`.
tokenizer = AutoTokenizer.from_pretrained(model_id, do_lower_case=False)
model = AutoModelForMaskedLM.from_pretrained(model_id)
model = model.to(device)
# Inference mode for nn.Module (disables dropout); kept as the last expression so the cell shows the model repr.
model.eval()

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affine=True)
          )
          (intermediate): EsmIntermediate(
            (dense): Linear(in_features=480, out_fe

In [3]:
# Wild-type protein sequence in one-letter amino-acid codes; encoded below into input_ids.
seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

In [4]:
# Encode the wild-type sequence as a single-row batch, with the tokenizer's special tokens added.
enc = tokenizer(seq, add_special_tokens=True, return_tensors="pt")
# Move both tensors to the model's device in one step.
input_ids, attention_mask = (enc[key].to(device) for key in ("input_ids", "attention_mask"))

In [19]:
# Inspect the encoding: token ids (with special tokens at both ends) and the attention mask.
print(enc)

{'input_ids': tensor([[ 0, 20, 15, 11,  5, 19, 12,  5, 15, 16, 10, 16, 12,  8, 18,  7, 15,  8,
         21, 18,  8, 10, 16,  4,  9,  9, 10,  4,  6,  4, 12,  9,  7, 16,  2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
# Plain (unmasked) forward pass over the wild-type sequence; no autograd graph is needed.
with torch.no_grad():
    out = model(attention_mask=attention_mask, input_ids=input_ids)
    logits = out.logits
    # Normalise over the vocabulary axis (last dim) to get per-position log-probabilities.
    log_probs = logits.log_softmax(dim=-1)

In [6]:
# Sanity-check shapes; the last dimension of log_probs is the vocabulary size.
print("\n".join(
    f"{label} {value}"
    for label, value in (
        ("input_ids:", input_ids.shape),
        ("log_probs:", log_probs.shape),
        ("vocab size:", log_probs.shape[-1]),
    )
))

input_ids: torch.Size([1, 35])
log_probs: torch.Size([1, 35, 33])
vocab size: 33


In [7]:
# Id of the tokenizer's mask token; written over a position to hide that residue from the model.
MASK = tokenizer.mask_token_id

In [8]:
def mask_position(input_ids, i_bio, mask_token_id=None):
    """Return a copy of ``input_ids`` with one position replaced by the mask token.

    Only row 0 is modified, and the caller's tensor is never mutated.

    Args:
        input_ids: token-id tensor indexed as ``[0, position]``. Presumably it has
            shape (1, L), with a special token at index 0 and another at index L-1,
            as produced by the tokenizer in this notebook. TODO confirm for other inputs.
        i_bio: position to mask. Index 0 holds the leading special token, so model
            index ``i_bio`` equals the 1-based residue number and no offset is applied.
        mask_token_id: id written at the masked position. Defaults to the
            notebook-level ``MASK``.

    Returns:
        A new tensor equal to ``input_ids`` except ``[0, i_bio] == mask_token_id``.

    Raises:
        ValueError: if ``i_bio`` would hit the first or last token or fall outside
            the sequence.
    """
    if mask_token_id is None:
        mask_token_id = MASK
    seq_len = input_ids.shape[1]
    # Positions 0 and seq_len - 1 are special tokens, not residues. Masking them (or a
    # negative index that wraps around) silently produces a meaningless score.
    if not 1 <= i_bio <= seq_len - 2:
        raise ValueError(
            f"i_bio must be a residue position in [1, {seq_len - 2}], got {i_bio}"
        )
    ids = input_ids.clone()
    ids[0, i_bio] = mask_token_id
    return ids

# Residue 5 (1-based); index 0 of input_ids is a special token, so this is also the model index.
i_bio = 5
ids_masked = mask_position(input_ids, i_bio)

# Forward pass with that single position hidden behind the mask token.
with torch.no_grad():
    logits_masked = model(input_ids=ids_masked, attention_mask=attention_mask).logits
    logp_masked = logits_masked.log_softmax(dim=-1)

# Vocabulary distribution at the masked site, and the wild-type token's log-prob under it.
logp_pos = logp_masked[0][i_bio]
wt_token = int(input_ids[0, i_bio])
wt_logp = float(logp_pos[wt_token])

In [9]:
# Report the wild-type token and its masked log-prob. The position label follows i_bio instead of a hard-coded 5.
print("WT token id:", wt_token, f"WT log-prob at pos {i_bio}:", wt_logp)

WT token id: 19 WT log-prob at pos 5: -3.7084124088287354


In [10]:
# Ids of the tokenizer's classification/start (CLS) and end-of-sequence (EOS) special tokens.
CLS, EOS = tokenizer.cls_token_id, tokenizer.eos_token_id

In [11]:
def per_site_wt_logprobs(input_ids, attention_mask, *, batch_size=None, mlm=None, mask_token_id=None):
    """Masked log-probability of the wild-type token at every residue position.

    Each residue position is masked in its own copy of the sequence. The model
    scores those copies, and the log-prob of the original token at the masked
    site is collected. Summing the result gives the pseudo-log-likelihood (PLL).

    Args:
        input_ids: token ids of shape (1, L). Only a single sequence is supported.
        attention_mask: matching mask of shape (1, L). Its sum gives the true token
            count, so right-padded inputs are scored only over real residues.
            Right padding is assumed; TODO confirm if the tokenizer pads left.
        batch_size: maximum number of masked copies per forward pass. ``None`` runs
            all of them at once, as before. Set it for long sequences, because
            the full batch needs O(L^2) memory.
        mlm: model to call. Defaults to the notebook-level ``model``.
        mask_token_id: mask id. Defaults to the notebook-level ``MASK``.

    Returns:
        1-D tensor with one log-prob per residue, in sequence order. The first and
        last tokens are excluded, and the tensor is empty for a sequence without
        residues.

    Raises:
        ValueError: if ``input_ids`` is not a single row, or if ``batch_size < 1``.
    """
    if input_ids.dim() != 2 or input_ids.shape[0] != 1:
        # repeat() on a multi-row tensor would interleave different sequences.
        raise ValueError(f"expected input_ids of shape (1, L), got {tuple(input_ids.shape)}")
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if mlm is None:
        mlm = model
    if mask_token_id is None:
        mask_token_id = MASK

    device = input_ids.device
    # Real tokens, including the special tokens at both ends; padding is excluded.
    n_tokens = int(attention_mask[0].sum())
    positions = torch.arange(1, n_tokens - 1, device=device)
    n_sites = positions.numel()
    if n_sites == 0:
        return torch.empty(0, device=device)
    step = n_sites if batch_size is None else batch_size

    wt_tokens = input_ids[0, positions]
    chunks = []
    with torch.no_grad():
        for start in range(0, n_sites, step):
            pos = positions[start:start + step]
            rows = torch.arange(pos.numel(), device=device)
            # Row k of the batch is the sequence with position pos[k] masked.
            batch = input_ids.repeat(pos.numel(), 1)
            batch[rows, pos] = mask_token_id
            amask = attention_mask.repeat(pos.numel(), 1)
            logits = mlm(input_ids=batch, attention_mask=amask).logits
            # Normalise only at each row's masked site. log_softmax is applied per
            # row, so this matches normalising the full tensor first.
            site_logp = torch.log_softmax(logits[rows, pos], dim=-1)
            chunks.append(site_logp[rows, wt_tokens[start:start + step]])
    return torch.cat(chunks)
    

In [12]:
wt_logps = per_site_wt_logprobs(input_ids, attention_mask)

# PLL is the sum of the per-site masked log-probs. Pseudo-perplexity is exp of their negated mean.
print("WT per-site log-probs shape:", wt_logps.shape)
print("PLL (sum of site log-probs):", wt_logps.sum().item())
print("Perplexity (pseudo):", torch.exp(-wt_logps.mean()).item())


WT per-site log-probs shape: torch.Size([33])
PLL (sum of site log-probs): -91.56507110595703
Perplexity (pseudo): 16.033803939819336


In [22]:
# Token -> id over the tokenizer's whole vocabulary. Despite the name, this presumably
# includes special tokens as well as amino acids (the vocab size above is 33).
aa2id = dict(
    (token, tokenizer.convert_tokens_to_ids(token))
    for token in tokenizer.get_vocab()
)

In [24]:
def delta_logprob_single_mutation(input_ids, attention_mask, i_bio, mutant_aa):
    """Masked-marginal score of a single substitution at one position.

    Masks position ``i_bio`` and runs the model once. It then compares the mutant
    and wild-type log-probabilities under the distribution predicted at that site.

    Args:
        input_ids: encoded sequence. Only row 0 is read; presumably shape (1, L).
        attention_mask: mask matching ``input_ids``. It is passed to the model
            unchanged, and its row-0 sum bounds the valid positions.
        i_bio: model position to mutate. Index 0 holds the leading special token,
            so this equals the 1-based residue number.
        mutant_aa: replacement residue as a vocabulary token, e.g. ``"V"``.

    Returns:
        ``log p(mutant) - log p(wild type)`` as a Python float. Positive values mean
        the model assigns the mutant a higher probability than the wild type at this site.

    Raises:
        ValueError: if ``i_bio`` is not a residue position, or if ``mutant_aa`` does
            not map to an ordinary (non-special) vocabulary token.
    """
    n_tokens = int(attention_mask[0].sum())
    if not 1 <= i_bio <= n_tokens - 2:
        raise ValueError(
            f"i_bio must be a residue position in [1, {n_tokens - 2}], got {i_bio}"
        )

    mut_token = tokenizer.convert_tokens_to_ids(mutant_aa)
    # convert_tokens_to_ids presumably maps unknown strings (e.g. "v", "Val") to the
    # unk id instead of raising. Reject those and other special ids before spending
    # a forward pass on a meaningless substitution.
    if mut_token is None or mut_token in set(tokenizer.all_special_ids):
        raise ValueError(f"mutant_aa {mutant_aa!r} is not an ordinary vocabulary token")

    ids_masked = input_ids.clone()
    ids_masked[0, i_bio] = MASK

    with torch.no_grad():
        logits = model(ids_masked, attention_mask=attention_mask).logits
        # Normalising only the masked site's vocabulary row gives the same values as
        # normalising the whole tensor and then indexing.
        site_logp = torch.log_softmax(logits[0, i_bio], dim=-1)

    wt_token = input_ids[0, i_bio].item()
    return float(site_logp[mut_token] - site_logp[wt_token])

In [33]:
# Masked-marginal score for substituting V at residue 5. print() shows plain text, so use Δ rather than LaTeX markup.
print("Δ log p at position 5 for V:", delta_logprob_single_mutation(input_ids, attention_mask, 5, "V"))

 {$\delta$} log p at position 5 for V:  0.8428685665130615


In [29]:
seqs = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MKTAYIAKQRYISFVKSHFSRQLDEKLG",
]

# Pad the shorter sequence so both fit in one tensor; attention_mask marks the padding.
# NOTE(review): this rebinds `enc`, shadowing the single-sequence encoding from earlier cells.
enc = tokenizer(
    seqs,
    padding=True,
    truncation=False,
    add_special_tokens=True,
    return_tensors='pt',
)

ids, amask = enc["input_ids"].to(device), enc["attention_mask"].to(device)

# Unmasked forward pass over the padded batch, normalised over the vocabulary axis.
with torch.no_grad():
    logp = model(ids, attention_mask=amask).logits.log_softmax(dim=-1)