In [1]:
# default_exp models.pretrained.transformer

In [2]:
# all_func


In [2]:
import pandas as pd

from peptide.basics import *
from peptide.preprocessing.data import (
    ProteinDataset,
    ACPDataset,
    AMPDataset,
    DNABindDataset,
)

## ESM

In [3]:
import torch
# Load the `esm1b_t33_650M_UR50S` entry point from the facebookresearch/esm
# hub repository (served from the local torch hub cache when present).
# NOTE: the next cell rebinds `model` and `alphabet` through `esm.pretrained`.
model, alphabet = torch.hub.load(
    repo_or_dir="facebookresearch/esm:main",
    model="esm1b_t33_650M_UR50S",
)

Using cache found in /home/vinod/.cache/torch/hub/facebookresearch_esm_main


In [4]:
import torch
import esm

# Load ESM-1b model
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# The last two entries exercise the tokenizer: a `<mask>` token embedded in a
# sequence, and whitespace-separated residues.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Count the real (non-padding) tokens in each row. This must come from the
# token tensor, not from `len(seq)`: the raw string length counts characters,
# so "<mask>" (6 chars, 1 token) and the spaces in "K A <mask> I S Q" would
# inflate the length and pull EOS/padding positions into the average.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# The last non-padding token is the end-of-sequence token that the ESM-1b
# alphabet appends, hence the `tokens_len - 1` upper bound.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

# Look at the unsupervised self-attention map contact predictions
# import matplotlib.pyplot as plt
# for (_, seq), attention_contacts in zip(data, results["contacts"]):
#     plt.matshow(attention_contacts[: len(seq), : len(seq)])
#     plt.title(seq)
#     plt.show()

In [22]:
# Ad-hoc check: average the representations at token positions 1..3 of the
# first batch entry and show the result as a NumPy array.
# NOTE(review): this cell's execution count (22) is out of sequence, so it is
# unclear which `token_representations` it was run against — confirm on a
# fresh Restart & Run All.
token_representations[0, 1:4].mean(0).numpy()

array([ 0.26060754, -0.09024662,  0.09138174, ..., -0.04813125,
       -0.16790868, -0.02575146], dtype=float32)

In [9]:
# Sanity check: shape of the first pooled per-sequence embedding
# (the recorded output shows a 1280-element vector).
sequence_representations[0].shape

torch.Size([1280])

In [5]:
# Instantiate the ACP dataset from the project's preprocessing module.
# `DATA_STORE` is not defined in this notebook; presumably it is provided by
# `from peptide.basics import *` — verify.
acp_data = ACPDataset(DATA_STORE)

In [6]:
# Alias the dataset's `train` attribute. Later cells use `len(df)` and
# `df.loc[i, "sequence"]`, so this is presumably a pandas DataFrame with a
# "sequence" column — TODO confirm against ACPDataset.
df = acp_data.train

In [7]:
# Build the (label, sequence) pairs passed to the ESM batch converter:
# the label is the row position as a string, and the sequence is the row's
# "sequence" value joined into a single string.
sequences = []
for row_pos in range(len(df)):
    residues = df.loc[row_pos, "sequence"]
    sequences.append((str(row_pos), "".join(residues)))

In [8]:
# Preview the first five (label, sequence) pairs.
sequences[:5]

[('0', 'RRWWRRWRRW'),
 ('1', 'GWKSVFRKAKKVGKTVGGLALDHYLG'),
 ('2', 'ALWKTMLKKLGTMALHAGKAALGAAADTISQGTQ'),
 ('3', 'GLFDVIKKVAAVIGGL'),
 ('4', 'VAKLLAKLAKKVL')]

In [9]:
# Number of (label, sequence) pairs built from the training split.
len(sequences)

1378

In [8]:
# Hard-coded copy of the first five ACP training peptides, in the same
# (label, sequence) format as `sequences`, where the label is the position
# rendered as a string.
_sample_peptides = (
    "RRWWRRWRRW",
    "GWKSVFRKAKKVGKTVGGLALDHYLG",
    "ALWKTMLKKLGTMALHAGKAALGAAADTISQGTQ",
    "GLFDVIKKVAAVIGGL",
    "VAKLLAKLAKKVL",
)
acp_data_sample = [
    (str(position), peptide) for position, peptide in enumerate(_sample_peptides)
]

In [10]:
# Tokenize only the first 100 ACP peptides so the forward pass stays small.
acp_subset = sequences[:100]
batch_labels, batch_strs, batch_tokens = batch_converter(acp_subset)

In [11]:
# Run the model on the tokenized batch with gradient tracking disabled and
# keep the per-token representations from layer 33.
REPR_LAYER = 33
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[REPR_LAYER], return_contacts=True)
token_representations = results["representations"][REPR_LAYER]

In [12]:
# Shape of the per-token representation tensor for the batch
# (recorded output: [100, 50, 1280]).
token_representations.shape

torch.Size([100, 50, 1280])

In [13]:
# Mean-pool the per-token representations into one vector per ACP peptide.
#
# Lengths are taken from the padding mask of `batch_tokens`, i.e. from the
# batch that was actually embedded. Iterating over the earlier demo list
# `data` here would produce only four vectors, sliced with the lengths of
# unrelated proteins, so the averages would include EOS and padding positions.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Token 0 is the beginning-of-sequence token and the last non-padding token is
# the end-of-sequence token, so residues occupy positions 1 .. tokens_len - 2.
sequence_representations = [
    token_representations[row, 1 : tokens_len - 1].mean(0)
    for row, tokens_len in enumerate(batch_lens)
]

In [14]:
# Sanity check: shape of the first pooled embedding for the ACP batch
# (recorded output: a 1280-element vector).
sequence_representations[0].shape

torch.Size([1280])

### Bulk from fasta

**ACP**
```
python scripts/extract.py esm1b_t33_650M_UR50S ~/.peptide/datasets/fasta/ACPDataset_train.fasta  ~/.peptide/datasets/transformer/mean/acp/train/ \
    --repr_layers 33 --include mean
Transferred model to GPU
Read /home/vinod/.peptide/datasets/fasta/ACPDataset_train.fasta with 1378 sequences
Processing 1 of 10 batches (292 sequences)
Processing 2 of 10 batches (215 sequences)
Processing 3 of 10 batches (178 sequences)
Processing 4 of 10 batches (157 sequences)
Processing 5 of 10 batches (132 sequences)
Processing 6 of 10 batches (117 sequences)
Processing 7 of 10 batches (105 sequences)
Processing 8 of 10 batches (91 sequences)
Processing 9 of 10 batches (80 sequences)
Processing 10 of 10 batches (11 sequences)
```

**AMP**
```
python scripts/extract.py esm1b_t33_650M_UR50S ~/.peptide/datasets/fasta/AMPDataset_test.fasta  ~/.peptide/datasets/transformer/mean/amp/test \
    --repr_layers 33 --include mean
Transferred model to GPU
Read /home/vinod/.peptide/datasets/fasta/AMPDataset_test.fasta with 808 sequences
Processing 1 of 9 batches (204 sequences)
Processing 2 of 9 batches (157 sequences)
Processing 3 of 9 batches (124 sequences)
Processing 4 of 9 batches (102 sequences)
Processing 5 of 9 batches (85 sequences)
Processing 6 of 9 batches (63 sequences)
Processing 7 of 9 batches (44 sequences)
Processing 8 of 9 batches (26 sequences)
Processing 9 of 9 batches (3 sequences)
```