In [None]:
import esm
import torch
import json
import logging
import numpy as np
import matplotlib.pyplot as plt
# Only WARNING and above get through the root logger.
logging.getLogger().setLevel(logging.WARNING)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load an ESM checkpoint (model + alphabet) from local disk.
# Alternative checkpoints -- swap the path below to use one of them:
#   /root/models/esm2_t6_8M_UR50D.pt
#   /root/models/esm2_t12_35M_UR50D.pt
#   /root/models/esm2_t48_15B_UR50D.pt
model, alphabet = esm.pretrained.load_model_and_alphabet("/root/models/esm2_t30_150M_UR50D.pt")

# The weights are never trained in this notebook: switch off gradients for
# every parameter, put the model in eval mode and move it to the device.
model.requires_grad_(False)
model = model.eval().to(device)

# Callable that turns a list of (label, sequence) pairs into model inputs.
converter = alphabet.get_batch_converter()

In [None]:
# Alternative input sequence, kept for reference (uncomment to use it instead):
# sequence = "MRNPTLLQCFHWYYPEGGKLWPELAERADGFNDIGINMVWLPPAYKGASGGYSVGYDSYDLFDLGEFDQKGSIPTKYGDKAQLLAAIDALKRNDIAVLLDVVVNHKMGADEKEAIRVQRVNADDRTQIDEEIIECEGWTRYTFPARAGQYSQFIWDFKCFSGIDHIENPDEDGIFKIVNDYTGEGWNDQVDDELGNFDYLMGENIDFRNHAVTEEIKYWARWVMEQTQCDGFRLDAVKHIPAWFYKEWIEHVQEVAPKPLFIVAEYWSHEVDKLQTYIDQVEGKTMLFDAPLQMKFHEASRMGRDYDMTQIFTGTLVEADPFHAVTLVANHDTQPLQALEAPVEPWFKPLAYALILLRENGVPSVFYPDLYGAHYEDVGGDGQTYPIDMPIIEQLDELILARQRFAHGVQTLFFDHPNCIAFSRSGTDEFPGCVVVMSNGDDGEKTIHLGENYGNKTWRDFLGNRQERVVTDENGEATFFCNGGSVSVWVIEEVI"
sequence = "FFSPSPARKRHAPSPEPAVQGTGVAGVPEESGDAAAIPAKKAPAGQEEPGTPPSSPLSAEQLDRIQRNKAAALLRLAARNVPVGFGESWKKHLSGEFGKPYFIKLMGFVAEERKHYTVYPPPHQVFTWTQMCDIKDVKVVILGQDPAHGPNQAHGLCFSVQRPVPPPPSLENIYKELSTDIEDFVHPGHGDLSGWAKQGVLLLNAVLTVRAHQANSHKERGWEQFTDAVVSWLNQNSNGLVFLLWGSYAQKKGSAIDRKKHHVLQTAHPSPLSVYRGFFGCRHFSKTNELLQKSGKKPIDWKEL"

# The converter takes a batch of (label, sequence) pairs and returns three
# values; only the third one (the tokens) is used in this notebook.
batch = [("amyA", sequence)]
_labels, _strs, tokens = converter(batch)



In [None]:
from esm.model.esm2 import ESM2

class MyESM2(ESM2):
    """ESM2 subclass whose forward pass takes token embeddings instead of token ids.

    The embedding lookup (``self.embed_tokens``) is left to the caller, so
    arbitrary embedding tensors can be fed through the transformer stack.
    Because the token ids are never seen here, the positions of padding and
    mask tokens have to be supplied as boolean masks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, repr_layers=[], need_head_weights=False, return_contacts=False, padding_mask=None, token_mask=None):
        """Run the model on embeddings ``x``; see ``forward_without_embedding``.

        ``padding_mask`` and ``token_mask`` may be omitted (``None``), in which
        case no position is treated as padding / as a mask token.
        """
        return self.forward_without_embedding(x, padding_mask, token_mask, repr_layers, need_head_weights, return_contacts)

    def forward_without_embedding(self, x_raw, padding_mask, token_mask, repr_layers=[], need_head_weights=False, return_contacts=False):
        """Forward pass starting from already-embedded tokens.

        Args:
            x_raw: token embeddings, B x T x C.
            padding_mask: boolean B x T mask, True at padding positions, or
                ``None`` for "no padding".
            token_mask: boolean B x T mask, True at mask-token positions, or
                ``None`` for "no masked tokens".
            repr_layers: layer indices whose hidden states should be returned
                (0 is the scaled input embedding). Never mutated.
            need_head_weights: also collect the per-layer attention weights.
            return_contacts: only forces ``need_head_weights``; no "contacts"
                entry is produced by this class.

        Returns:
            dict with "logits", "representations" (layer index -> B x T x C
            tensor) and, when attention weights were requested, "attentions"
            (B x L x H x T x T).
        """
        if return_contacts:
            need_head_weights = True

        x = self.embed_scale * x_raw

        # forward() defaults both masks to None.  Treat a missing mask as
        # all-False instead of crashing on `~None` / `None.any()` below.
        if padding_mask is None:
            padding_mask = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)
        if token_mask is None:
            token_mask = torch.zeros_like(padding_mask)

        if self.token_dropout:
            # Zero the embeddings at mask-token positions, then rescale so the
            # overall magnitude matches the masking ratio assumed in training.
            x.masked_fill_(token_mask.unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = token_mask.sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero out padding positions.
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        # The layers take None when there is nothing to mask.
        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # last hidden representation should have layer norm applied
        # (len(self.layers) rather than the loop variable, which is undefined
        # when there are no layers)
        num_layers = len(self.layers)
        if num_layers in repr_layers:
            hidden_representations[num_layers] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                # Zero attention rows/columns that involve padding positions.
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            # NOTE: contacts are not computed here even when return_contacts
            # is set -- the contact head call was disabled in this subclass
            # (it takes the token ids, which this method does not receive).

        return result


# Instantiate the embedding-input variant with the loaded model's
# hyper-parameters and copy the loaded weights into it.
mymodel = MyESM2(
    num_layers=model.num_layers,
    embed_dim=model.embed_dim,
    attention_heads=model.attention_heads,
    alphabet=model.alphabet,
    token_dropout=model.token_dropout,
)
mymodel.load_state_dict(model.state_dict())

# Inference only: no parameter gradients, eval mode, on the chosen device.
mymodel.requires_grad_(False)
mymodel = mymodel.to(device).eval()

# One display label per token position (the sequence characters framed by two
# special-token labels).
words = ["<bos>", *sequence, "<eos>"]

# Embed the tokens outside the model, and derive the two boolean masks
# (B, T) that MyESM2 needs because it never sees the token ids.
x = mymodel.embed_tokens(tokens.to(device))
padding_mask = (tokens == mymodel.padding_idx).to(device)  # B, T
token_mask = tokens.eq(mymodel.mask_idx).to(device)

# Single-sample "dataset" consumed by calculate_regularization.
dataset = [x]

In [None]:
def Phi(x):
    """Return the last-layer hidden representation of ``mymodel`` for embeddings ``x``.

    Args:
        x: embedding tensor fed to ``mymodel`` in place of token ids.

    Returns:
        ``results["representations"][last_layer]`` where ``last_layer`` is
        ``len(mymodel.layers)``.

    Reads the notebook-level ``mymodel``, ``padding_mask`` and ``token_mask``.
    They are only read, never rebound, so no ``global`` declaration is needed
    (the previous one named ``model``, which this function does not use).
    """
    last_layer = len(mymodel.layers)
    results = mymodel(
        x,
        padding_mask=padding_mask,
        token_mask=token_mask,
        repr_layers=[last_layer],
        return_contacts=False,
    )
    return results["representations"][last_layer]

In [None]:
from Interpreter import Interpreter, calculate_regularization
# Compute the regularization term from Phi's outputs over `dataset`.
# NOTE(review): `dataset` holds a single sample here -- confirm that
# calculate_regularization gives a meaningful result for one sample.
regularization = calculate_regularization(
    dataset,
    Phi,
    device=device,
)

In [None]:
# Build the interpreter from the embeddings, the model function and the
# per-token labels, then move it to the device.
# NOTE(review): the `regularization` value computed in an earlier cell is not
# passed in here -- confirm that this is intentional.
interpreter = Interpreter(
    x=x,
    Phi=Phi,
    words=words,
)
interpreter = interpreter.to(device)

In [None]:
# Run the interpreter's optimisation: 5000 iterations at learning rate 0.01,
# with progress reporting switched on.
interpreter.optimize(
    iteration=5000,
    lr=0.01,
    show_progress=True,
)


In [None]:
# Show the per-token sigma values as an annotated heat map on a grid with
# N_ROWS rows, filled row by row.
# Assumes get_sigma() returns a 1-D numpy array with one value per entry of
# `words` -- TODO confirm against the Interpreter module.
N_ROWS = 20

sigma_ = interpreter.get_sigma()
n_tokens = sigma_.shape[-1]

# Pad up to the next multiple of N_ROWS so the reshape below is valid.
# `-n % N_ROWS` is 0 when n_tokens already fits; the previous formula added a
# full N_ROWS superfluous zero cells in that case.
n_pad = -n_tokens % N_ROWS
sigma_np = np.concatenate([sigma_, np.zeros(n_pad)])
words_new = words + [""] * n_pad
sigma_np = sigma_np.reshape(N_ROWS, -1)
n_cols = sigma_np.shape[1]

plt.figure(figsize=(20, 20))
plt.imshow(sigma_np, cmap="GnBu_r")

# Annotate every cell with its token label, flat index and sigma value.
for i in range(N_ROWS):
    for j in range(n_cols):
        flat_idx = i * n_cols + j
        plt.text(
            j,
            i,
            f"{words_new[flat_idx]}{flat_idx}\n{sigma_np[i][j]:.4f}",
            ha="center",
            va="center",
            color="k",
        )

# Cell coordinates carry no meaning, so hide the axis ticks.
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.show()

In [None]:
# Call the Interpreter's own visualize() method (implemented in the external
# Interpreter module; what it renders is not visible from this notebook).
interpreter.visualize()


In [None]:
import matplotlib.pyplot as plt
# Show all sigma values as a single-row heat map.
sigma_ = interpreter.get_sigma()
sigma_ = sigma_.reshape(1, -1)
# Pass the (1, N) array directly: wrapping it in a list made the data
# (1, 1, N), which imshow rejects as an invalid image shape.
# aspect="auto" stretches the one-pixel-high row so it is actually visible.
plt.imshow(sigma_, aspect="auto")
