In [1]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import torch

from magneton.config import DataConfig
from magneton.core_types import SubstructType
from magneton.data import MagnetonDataModule
from magneton.models.base_models.esmc import ESMCBaseModel, ESMCConfig, ESMC_300M
from magneton.utils import get_data_dir, get_model_dir

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Resolve the project's data and model directories via the helper functions;
# every path used later in the notebook is built from these two.
data_path, model_path = get_data_dir(), get_model_dir()

This notebook provides an example of how to generate ESM-C embeddings using an existing protein dataset. Note that we need to specify both the location of the protein dataset directory and the path to the FASTA file containing the protein sequences.

In [None]:
# Locations under the shared data directory.
interpro_path = data_path / "interpro_103.0"
label_path = interpro_path / "labels" / "selected_subset"
dataset_path = interpro_path / "debug_subset"
fasta_path = data_path / "sequences" / "uniprot_sprot.fasta.gz"

# Bundle the dataset directory, file prefix, FASTA file, label directory and
# the substructure types of interest into the config handed to the data module.
data_config = DataConfig(
    substruct_types=[SubstructType.DOMAIN],
    labels_path=label_path,
    fasta_path=fasta_path,
    prefix="swissprot.with_ss.train",
    data_dir=dataset_path,
)

In [None]:
# Build the data module for the ESM-C model type from the config above.
data_module = MagnetonDataModule(data_config=data_config, model_type="esmc")

In [None]:
# Ask the data module for its training dataloader.
loader = data_module.train_dataloader()

100%|██████████████████████████████████████████████████████| 1/1 [00:13<00:00, 13.30s/it]


In [None]:
# Create an explicit iterator so individual batches can be pulled on demand.
it = iter(loader)

In [None]:
# Pull the first batch from the iterator and display it as the cell output.
# NOTE(review): the output saved with this cell shows `seqs=None` and 32
# protein IDs, whereas later cells index `batch.seqs` and report 4
# sequences -- the saved output looks stale; confirm with a fresh run.
batch = next(it)
batch

ESMCBatch(protein_ids=['A1ABS6', 'A1AWD0', 'A0KU61', 'A0JXU0', 'A1AU61', 'A0BD73', 'A0ALA8', 'A1ADB6', 'A0A2H4HHY6', 'A0AEM3', 'A0AFC3', 'A0A455LLX4', 'A0T0M9', 'A0RJ81', 'A0KYA2', 'A0A0H3NBY9', 'A0RV25', 'A0KEH8', 'A0B9K1', 'A0RCM7', 'A0T0L8', 'A0QL16', 'A0QSG3', 'A1JLK6', 'A0Q3I1', 'A0R1W8', 'A0T0H8', 'A0KZ22', 'A0PW28', 'A0A0H2URG7', 'A0RPF9', 'A1JNH0'], seqs=None, substructures=None, structure_list=None, labels=None, tokenized_seq=tensor([[ 0, 20, 17,  ...,  1,  1,  1],
        [ 0, 20,  5,  ...,  1,  1,  1],
        [ 0, 20, 21,  ...,  1,  1,  1],
        ...,
        [ 0, 20, 11,  ...,  1,  1,  1],
        [ 0, 20, 12,  ...,  1,  1,  1],
        [ 0, 20, 20,  ...,  1,  1,  1]]))

In [9]:
# Flash attention kernels need a CUDA GPU, so only enable them when the GPU
# was actually selected for `device`. On the CPU fallback the flag stays off,
# which also keeps the model out of bfloat16 in the cell that moves it to
# the device.
use_flash_attn = device == "cuda"

esmc_config = ESMCConfig(
    model_size=ESMC_300M,
    weights_path=model_path / "esmc-300m-2024-12",
    # Note that this is also the default, the hidden states from
    # the final transformer layer.
    rep_layer=29,
    use_flash_attn=use_flash_attn,
)

In [10]:
# Instantiate the embedding model and move it to the selected device.
# When flash attention is enabled the model is additionally cast to
# bfloat16; otherwise its dtype is left untouched.
embedder = ESMCBaseModel(config=esmc_config)
to_kwargs = {"device": device}
if use_flash_attn:
    to_kwargs["dtype"] = torch.bfloat16
embedder = embedder.to(**to_kwargs)

We can now embed a single amino acid sequence.

In [11]:
# Take the first sequence of the batch, show it, and embed it on its own.
first_seq = batch.seqs[0]
print(first_seq)

seq_embed = embedder.embed_single_protein(first_seq)
print(seq_embed.shape)
seq_embed

MPSTTQTTVQSIDSIDSIPTTIKRRQNDKTKTPKTKPVSKIPICPKNSSIPRLDQPSQHKFILLQSLLPITVHQLTTLVLSISRYDDYVHPFLLRLCVIIGYGYAFRFLLRREGLAIRTLGKKLGYLDGDHHPRDKVPRDSTRLNWSLPLTVGSRTVMCVLVAYDPSQQPINYLASLKWWAWLAVYLSLYPIILDFYYYCVHRAWHEVPCLWRFHRRHHTIKRPSILFTAYADSEQELFDIVGTPLLTFFTLKALHLPMDFYTWWICIQYIAYTEVMGHSGLRIYTTPPISCSWLLQRFGVELVIEDHDLHHRQGYRQARNYGKQTRIWDRLFGTCADRIETNPVNIQKGRRVMMHSINIPSLGN
torch.Size([365, 960])


tensor([[-2.1240e-02,  5.1270e-02,  3.3447e-02,  ...,  2.3193e-02,
          7.0801e-03, -1.3062e-02],
        [-4.7852e-02,  5.1270e-02, -2.5269e-02,  ...,  4.7363e-02,
          1.9043e-02,  2.2461e-02],
        [-5.7617e-02, -1.8768e-03, -5.1758e-02,  ..., -2.8839e-03,
         -6.9580e-03, -1.7334e-02],
        ...,
        [-3.2227e-02,  4.4678e-02, -8.6594e-04,  ...,  1.8555e-02,
          1.3123e-02,  1.3184e-02],
        [-3.8574e-02,  4.4189e-02,  2.2095e-02,  ...,  3.5156e-02,
          3.8086e-02,  1.2144e-06],
        [-2.6978e-02,  7.6660e-02, -2.2949e-02,  ...,  3.4180e-03,
          4.2725e-03,  3.6011e-03]], device='cuda:0', dtype=torch.bfloat16)

We can also embed a whole batch of sequences, which returns a list of `torch.Tensor` objects, one per sequence.

In [12]:
# Embed every sequence in the batch, then report how many embeddings came
# back and the shape of each one before displaying the first.
batch_embeds = embedder.embed_sequences(batch.seqs)
print(len(batch_embeds))

shapes = [emb.shape for emb in batch_embeds]
for shape in shapes:
    print(shape)

batch_embeds[0]

Processing sequences: 100%|█████████████████████████████████████████████████| 4/4 [00:00<00:00, 35.19it/s]

4
torch.Size([365, 960])
torch.Size([1567, 960])
torch.Size([1033, 960])
torch.Size([118, 960])





tensor([[-2.1240e-02,  5.1270e-02,  3.3447e-02,  ...,  2.3193e-02,
          7.0801e-03, -1.3062e-02],
        [-4.7852e-02,  5.1270e-02, -2.5269e-02,  ...,  4.7363e-02,
          1.9043e-02,  2.2461e-02],
        [-5.7617e-02, -1.8768e-03, -5.1758e-02,  ..., -2.8839e-03,
         -6.9580e-03, -1.7334e-02],
        ...,
        [-3.2227e-02,  4.4678e-02, -8.6594e-04,  ...,  1.8555e-02,
          1.3123e-02,  1.3184e-02],
        [-3.8574e-02,  4.4189e-02,  2.2095e-02,  ...,  3.5156e-02,
          3.8086e-02,  1.2144e-06],
        [-2.6978e-02,  7.6660e-02, -2.2949e-02,  ...,  3.4180e-03,
          4.2725e-03,  3.6011e-03]], device='cuda:0', dtype=torch.bfloat16)

Sequences longer than the model's maximum context length are handled as well, with no changes to the calling code. To demonstrate, we repeat the second sequence in the batch 20 times.

In [13]:
# Repeat the second sequence 20 times, in place, to obtain a much longer
# input, then list the length of every sequence in the batch.
# Caution: this cell is not idempotent -- re-running it repeats the sequence
# again, so run it once per freshly loaded batch.
batch.seqs[1] *= 20
list(map(len, batch.seqs))

[365, 31340, 1033, 118]

In [14]:
# Sequences are embedded serially here; batch-level tokenization has been
# pushed into the data loader.
batch_embeds = embedder.embed_sequences(batch.seqs)
print(len(batch_embeds))

shapes = [emb.shape for emb in batch_embeds]
for shape in shapes:
    print(shape)

Processing sequences:   0%|                                                         | 0/4 [00:00<?, ?it/s]

Processing sequences: 100%|█████████████████████████████████████████████████| 4/4 [00:00<00:00,  7.51it/s]

4
torch.Size([365, 960])
torch.Size([31340, 960])
torch.Size([1033, 960])
torch.Size([118, 960])



