In [1]:
import os

from torch.utils.data import DataLoader

from magneton.data.sequence_only import SequenceOnlyDataset, collate_sequence_datasets
from magneton.embedders.esmc_embedder import ESMCEmbedder, ESMCConfig

This notebook shows how to generate ESM-C embeddings for the proteins in an existing protein dataset. You need to specify both the location of the protein dataset directory and the path to the FASTA file containing the protein sequences.

In [2]:
# Data locations. These defaults are cluster-specific absolute paths; set the
# environment variables below to run the notebook somewhere else.
# Directory of the InterPro release (103.0) containing the SwissProt shards.
interpro_path = os.environ.get(
    "MAGNETON_INTERPRO_PATH",
    "/weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/",
)
# Gzipped FASTA file with the UniProt/SwissProt protein sequences.
fasta_path = os.environ.get(
    "MAGNETON_SWISSPROT_FASTA",
    "/rdma/vast-rdma/vast-home/rcalef/transfer/uniprot_sprot.fasta.gz",
)

In [3]:
# Load the sharded SwissProt dataset from its `with_ss` shard directory,
# pairing each entry with its sequence from the FASTA file.
swissprot_shard_dir = os.path.join(
    interpro_path, "swissprot", "sharded_swissprot", "with_ss"
)
prot_dataset = SequenceOnlyDataset(
    prefix="swissprot.with_ss",
    input_path=swissprot_shard_dir,
    fasta_path=fasta_path,
)
len(prot_dataset)

530601

In [4]:
# Iterate the dataset in its stored order, four proteins at a time. The
# collate function produces a list of (sequence, Protein) tuples.
dataloader = DataLoader(
    prot_dataset,
    collate_fn=collate_sequence_datasets,
    batch_size=4,
    num_workers=0,
    shuffle=False,
)

# Grab the first batch and show its first (sequence, Protein) pair.
batch = next(iter(dataloader))
first_item = batch[0]
print(first_item)

('MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKLSDLQKKKIDIDNKLLKEKQNLIKEEILERKKLEVLTKKQQKDEIEHQKKLKREIDAIKASTQYITDVSISSYNNTIPETEPEYDLFISHASEDKEDFVRPLAETLQQLGVNVWYDEFTLKVGDSLRQKIDSGLRNSKYGTVVLSTDFIKKDWTNYELDGLVAREMNGHKMILPIWHKITKNDVLDYSPNLADKVALNTSVNSIEEIAHQLADVILNR', Protein(uniprot_id='A0A009IHW8', kb_id='sp|A0A009IHW8|ABTIR_ACIB9', name='ABTIR_ACIB9', length=269, parsed_entries=5, total_entries=5, entries=[InterproEntry(id='IPR035897', element_type='Homologous_superfamily', match_id='G3DSA:3.40.50.10140', element_name='Toll/interleukin-1 receptor homology (TIR) domain superfamily', representative=False, positions=[(80, 266)]), InterproEntry(id='IPR000157', element_type='Domain', match_id='PF13676', element_name='Toll/interleukin-1 receptor homology (TIR) domain', representative=False, positions=[(138, 231)]), InterproEntry(id='IPR000157', element_type='Domain', match_id='PS50104', element_name='Toll/interleukin-1 receptor homology (TIR) domain', representative=True, positions=[

In [5]:
esmc_config = ESMCConfig(
    # Cluster-specific default location of the ESM-C 600M weights; set
    # MAGNETON_ESMC_WEIGHTS to load them from somewhere else.
    weights_path=os.environ.get(
        "MAGNETON_ESMC_WEIGHTS",
        "/weka/scratch/weka/kellislab/rcalef/model_weights/esmc-600m-2024-12",
    ),
    # Note that this is also the default, the hidden states from
    # the final transformer layer.
    rep_layer=35,
    use_flash_attn=False,
    # Requires a CUDA device.
    device="cuda",
)

In [6]:
# Build the embedder from the config above. Loading the weights emits a
# warning raised from inside `torch.load` (visible in the cell output).
embedder = ESMCEmbedder(config=esmc_config)

  state_dict = torch.load(


We can now embed a single amino acid sequence. The result is a tensor with one row per residue.

In [7]:
# Embed the first protein's sequence on its own.
first_seq = batch[0][0]
print(first_seq)

seq_embed = embedder.embed_single_protein(first_seq)
print(seq_embed.shape)
seq_embed

MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKLSDLQKKKIDIDNKLLKEKQNLIKEEILERKKLEVLTKKQQKDEIEHQKKLKREIDAIKASTQYITDVSISSYNNTIPETEPEYDLFISHASEDKEDFVRPLAETLQQLGVNVWYDEFTLKVGDSLRQKIDSGLRNSKYGTVVLSTDFIKKDWTNYELDGLVAREMNGHKMILPIWHKITKNDVLDYSPNLADKVALNTSVNSIEEIAHQLADVILNR
torch.Size([269, 1152])


tensor([[ 0.0288,  0.0265, -0.0140,  ...,  0.0057, -0.0190,  0.0182],
        [-0.0192,  0.0288,  0.0173,  ...,  0.0073,  0.0032, -0.0096],
        [-0.0063,  0.0354,  0.0086,  ...,  0.0073,  0.0022, -0.0060],
        ...,
        [ 0.0159, -0.0169,  0.0096,  ..., -0.0070, -0.0086,  0.0072],
        [-0.0188, -0.0211,  0.0249,  ..., -0.0003,  0.0126,  0.0464],
        [ 0.0239, -0.0315,  0.0257,  ..., -0.0056,  0.0286, -0.0071]])

Alternatively, embed a whole batch at once: `embed_batch` returns a list of `torch.Tensor`s, one per protein in the batch.

In [9]:
# One embedding tensor comes back per protein; their row counts differ.
batch_embeds = embedder.embed_batch(batch)
print(len(batch_embeds))
for per_protein_embed in batch_embeds:
    print(per_protein_embed.shape)
batch_embeds[0]

4
torch.Size([269, 1152])
torch.Size([118, 1152])
torch.Size([118, 1152])
torch.Size([119, 1152])


tensor([[ 0.0288,  0.0265, -0.0140,  ...,  0.0057, -0.0190,  0.0182],
        [-0.0192,  0.0288,  0.0173,  ...,  0.0073,  0.0032, -0.0096],
        [-0.0063,  0.0354,  0.0086,  ...,  0.0073,  0.0022, -0.0060],
        ...,
        [ 0.0159, -0.0169,  0.0096,  ..., -0.0070, -0.0086,  0.0072],
        [-0.0188, -0.0211,  0.0249,  ..., -0.0003,  0.0126,  0.0464],
        [ 0.0239, -0.0315,  0.0257,  ..., -0.0056,  0.0286, -0.0071]])

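The embedder also handles sequences longer than the model's maximum context length, with no changes to the calling code. To demonstrate, we replace the second sequence in the batch with that sequence repeated 100 times.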

In [10]:
# Replace the second protein's sequence with 100 copies of itself so it exceeds
# the context length. The entry is rebuilt from a fresh fetch of the
# (unshuffled) dataloader rather than from `batch` itself, so re-running
# this cell does not keep multiplying the sequence length.
SEQ_REPEATS = 100
orig_seq, orig_prot = next(iter(dataloader))[1]
batch[1] = (orig_seq * SEQ_REPEATS, orig_prot)
[len(seq) for seq, _ in batch]

[269, 11800, 118, 119]

In [11]:
batch_embeds = embedder.embed_batch(batch)
print(len(batch_embeds))
for embed in batch_embeds:
    print(embed.shape)

# Verify the claim above: one embedding per protein, and one row per residue
# even for the sequence longer than the context length.
assert len(batch_embeds) == len(batch)
for embed, item in zip(batch_embeds, batch):
    seq_len = len(item[0])
    assert embed.shape[0] == seq_len, (tuple(embed.shape), seq_len)

4
torch.Size([269, 1152])
torch.Size([11800, 1152])
torch.Size([118, 1152])
torch.Size([119, 1152])
