In [1]:
import torch
import esm
from Bio import SeqIO

# Thread budget for PyTorch CPU work; kept as a named constant so it is easy to tune.
NUM_TORCH_THREADS = 20
torch.set_num_threads(NUM_TORCH_THREADS)

In [2]:
# Read every record of the orthologue FASTA file into an in-memory list.
with open("./Human_ACE_orthologues.fa") as fasta_handle:
    fasta_seqs = [record for record in SeqIO.parse(fasta_handle, "fasta")]

In [3]:
# Build (label, sequence) tuples: the record id plus its sequence as a plain string.
data = [
    (record.id, str(record.seq))
    for record in fasta_seqs
]

In [7]:
# Peek at the first (label, sequence) tuple only; the full list is too long to display.
data[:1]

[('ENSOGAP00000006131_Bushbaby',
  'MGASSGLRASGPPPVLSLLPLLQLLLLPPSPAAQALDPGLLPGNFSADEAGAQLFAQSYNSSAEQVLFQSTAASWAYSTNITEENAQRQEEAVLLNQRFAEAWGQKAKELYGPIWQNFTDPKLRKVIRAVCTLGPANLPLAKQQQYVSLQTNMSRIYSTTKVCFPNKTATCWSLDPELTNILASSRSYARLLFAWEGWHDTVGIPLKALYQDFTTISNEAYRQDGFSDTGAYWRSWYNSATFEEDLEHLYHQLEPLYLNLHAYVRRALHRRYGDRYINLRGPIPAHLLGDMWAQSWDNLYDMVVPFPGKPNLDVTSTMVKQGWNATHMFRVAEEFFTSLGLSPMPPEFWAESMLEKPADGREVVCHASAWDFYNRKDFRIKQCTQVTMDQLSTVHHEMGHVQYYLQYKDQPVSLREGANPGFHEAIGDVLGLSVSTPAHLHKIGLLDHVTNDTESDINYLLKMALDKIAFLPFGYLVDQWRWGVFSGHTPPSRYNSDWWYLRTKYQGICPPVVRNETHFDAGAKFHIPNGTPYIRYFVSFILQFQFHQALCKEAGHQGPLHQCDIYKSTQAGAKLQEVLRAGSSRPWQEVLKNMTGSDALDAQPLLDYFQPVSQWLQEQNQQNNEILGWPEYQWRPPLPTNYPEGIDLITDEAEANKFVEEYDQVSQVVWNEFAEANWNYNTNITTEASQILLQKNLEIANHTLKWGIQARQFDVSTFQNTTTKRVIKKVQDLDRAALPAKELEEYNKILLEMETTYSVATVCHTNGTCLQLEPDLTSMMATSRQYYELLWAWKSWRDKVGRAILPSFPKYVELTNKAARLNGYIDGGDSWRSMYEMPSLEQNLEELFQELQPLYLNLHAYVRRALHRHYGAQHINLEGPIPAHLLGNMWAQTWSNIYDLVVPFPSAPSMDATEAMIKQGWTPRRMFKEADNFFISLGLLPVPPEFWNKSMLEKPTDGREVVCH

In [12]:
# Load the pretrained ESM-2 checkpoint together with its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()  # disables dropout for deterministic results
batch_converter = alphabet.get_batch_converter()

# Tokenise the (label, sequence) tuples held in `data` into one padded batch.
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Per-row count of non-padding tokens.
pad_idx = alphabet.padding_idx
batch_lens = batch_tokens.ne(pad_idx).sum(dim=1)



In [10]:
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

In [13]:
# Profile CPU time and memory of a single forward pass over the tokenised batch.
# The pass runs under torch.no_grad() so autograd does not record a graph and keep
# intermediate activations alive; without it the memory figures describe a
# training-style forward and the footprint grows far beyond what inference needs.
with torch.no_grad(), profile(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
) as prof:
    model(batch_tokens)

# Report the ten operators with the highest self CPU memory usage.
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

STAGE:2024-03-22 20:32:08 522199:522199 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-03-22 20:32:42 522199:522199 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-03-22 20:32:42 522199:522199 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


KeyboardInterrupt: 

In [9]:
# Run the model without gradient tracking (on CPU) and keep the per-residue
# representations of one layer, plus the predicted contacts.
repr_layer = 33  # presumably the final layer of the t33 checkpoint; confirm if the model changes
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=True)
    layer_outputs = results["representations"]
token_representations = layer_outputs[repr_layer]



KeyboardInterrupt: 

In [None]:
# Generate per-sequence representations by averaging the residue embeddings.
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1;
# the slice stops at tokens_len - 1 so the last non-padding token is left out as well.
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt

for (label, _), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    # batch_lens counts the two special tokens, whereas the contact map is indexed by
    # residue (the fair-esm contact head drops the BOS/EOS rows and columns; verify
    # against the installed version). Crop with the residue count so padded
    # positions of shorter sequences are not drawn.
    n_residues = int(tokens_len) - 2
    plt.matshow(attention_contacts[:n_residues, :n_residues])
    # Title with the record label: the full sequence is far too long to be a readable title.
    plt.title(label)
    plt.show()