In [2]:
import torch

from magneton.config import (
    BaseModelConfig,
    DataConfig,
    ModelConfig,
    PipelineConfig,
    TrainingConfig,
)
from magneton.core_types import SubstructType
from magneton.data import MagnetonDataModule
from magneton.data.core import get_substructure_parser
from magneton.models.substructure_classifier import SubstructureClassifier

from magneton.utils import get_data_dir, get_model_dir

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

This notebook provides an example of how to generate ESM-C embeddings using an existing protein dataset. Note that we need to specify both the location of the protein dataset directory and the path to the FASTA file containing the protein sequences.

In [4]:
# Root of the InterPro release; the dataset, labels and splits all live under it.
interpro_path = get_data_dir() / "interpro_103.0"

# The FASTA file sits outside the InterPro tree, directly under the data directory.
fasta_path = get_data_dir() / "sequences" / "uniprot_sprot.fasta.gz"

# Inputs nested under the InterPro release directory.
dataset_path = interpro_path / "debug_subset"
labels_path = interpro_path / "labels" / "selected_subset"
splits_path = interpro_path / "dataset_splits" / "seq_splits.tsv"

In [5]:
# Substructure types whose annotations should be loaded.
substruct_types = [SubstructType.DOMAIN, SubstructType.ACT_SITE]

# Gather every dataset-related path and option into a single config object.
data_config = DataConfig(
    batch_size=4,
    substruct_types=substruct_types,
    data_dir=dataset_path,
    fasta_path=fasta_path,
    labels_path=labels_path,
    splits=splits_path,
)

In [6]:
# `model_type` picks which model-specific transforms (e.g. tokenization)
# the data module applies.
data_module = MagnetonDataModule(model_type="esmc", data_config=data_config)

In [7]:
# Ask the data module for its training dataloader. In the recorded run this
# call is where proteins were processed and the train split / length filter
# were applied (see the log output below).
loader = data_module.train_dataloader()

processing proteins: 100%|████████████████████████████████████| 1/1 [00:04<00:00,  4.56s/it]
INFO:magneton.data.core.unified_dataset:split train: got 3010 proteins
INFO:magneton.data.data_modules:remaining proteins after length filter: 2957 / 3010


In [8]:
# Pull a single batch off the training dataloader to use as a worked example.
batch_iter = iter(loader)
example_batch = next(batch_iter)
example_batch

ESMCBatch(protein_ids=['A1AGH4', 'A0RFP3', 'A1ADB6', 'A1A8Z8'], lengths=[190, 215, 556, 250], seqs=None, substructures=[[LabeledSubstructure(ranges=[tensor([  7, 190])], label=379, element_type=<SubstructType.DOMAIN: 'Domain'>)], [LabeledSubstructure(ranges=[tensor([69, 85])], label=80, element_type=<SubstructType.ACT_SITE: 'Active_site'>), LabeledSubstructure(ranges=[tensor([130, 144])], label=81, element_type=<SubstructType.ACT_SITE: 'Active_site'>)], [LabeledSubstructure(ranges=[tensor([184, 388])], label=755, element_type=<SubstructType.DOMAIN: 'Domain'>)], [LabeledSubstructure(ranges=[tensor([ 8, 17])], label=3, element_type=<SubstructType.ACT_SITE: 'Active_site'>)]], structure_list=None, labels=None, tokenized_seq=tensor([[ 0, 20, 17,  ...,  1,  1,  1],
        [ 0, 20, 15,  ...,  1,  1,  1],
        [ 0, 20,  8,  ..., 21,  4,  2],
        [ 0, 20,  5,  ...,  1,  1,  1]]))

Now that we have a batch of tokenized sequences, we'll set up the substructure classification model using ESM-C 300M as the base model.

In [9]:
# Parameters for the substructure classification heads.
# "embed" is shorthand for the base model's embedding dimensionality.
head_params = {
    "hidden_dims": ["embed"],
    "dropout_rate": 0.1,
}

# Parameters for the ESM-C 300M base model.
esmc_params = {
    "model_size": "300m",
    "weights_path": get_model_dir() / "esmc-300m-2024-12",
    "rep_layer": 29,
    "use_flash_attn": False,
}

model_config = ModelConfig(
    model_params=head_params,
    pooling_mechanism="mean",
    frozen_base_model=True,
)
train_config = TrainingConfig(loss_strategy="standard")
base_model_config = BaseModelConfig(model="esmc", model_params=esmc_params)

# Bundle the training, head and base-model settings into one pipeline config.
config = PipelineConfig(
    base_model=base_model_config,
    model=model_config,
    training=train_config,
)

In [10]:
# The substructure parser built from the data config supplies the label
# counts that the classifier is constructed with.
substruct_parser = get_substructure_parser(data_config)
num_classes = substruct_parser.num_labels()

model = SubstructureClassifier(config=config, num_classes=num_classes)

INFO:magneton.models.substructure_classifier:head model: ModuleDict(
  (Active_site): Sequential(
    (0): Linear(in_features=960, out_features=960, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=960, out_features=82, bias=True)
  )
  (Domain): Sequential(
    (0): Linear(in_features=960, out_features=960, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=960, out_features=917, bias=True)
  )
)


We can now get substructure embeddings, which are returned as a dict where the key is the type of substructure and the value is the tensor of embeddings for all substructures of that type within the batch. For example, if the batch contained two proteins with 3 domains each, rows 0-2 would contain the embeddings of the domains in the first protein and rows 3-5 would contain the embeddings of the domains in the second protein.

In [11]:
# Move the model and the example batch onto the selected device.
#
# The model is also switched to evaluation mode. A freshly constructed
# torch.nn.Module starts in training mode, and torch.inference_mode() only
# turns off autograd tracking, so without .eval() the Dropout(p=0.1) layers
# listed in the head summary logged above would remain active and make the
# forward passes below non-deterministic.
# (Assumes SubstructureClassifier is a torch.nn.Module, as its .to(device)
# call and ModuleDict heads suggest.)
model = model.to(device).eval()
example_batch = example_batch.to(device)

In [12]:
# Run the classifier on the example batch with autograd tracking disabled,
# then display the result (a dict keyed by SubstructType in the recorded output).
# NOTE(review): in the recorded output the Active_site tensor is printed in
# full while the Domain tensor is elided with "...". With PyTorch's default
# print threshold that implies Active_site rows far narrower than the 960-d
# embedding (possibly the 82 / 917 per-class head outputs) — confirm whether
# the forward pass returns embeddings or head outputs.
with torch.inference_mode():
    substruct_embeds = model(example_batch)
substruct_embeds

{<SubstructType.DOMAIN: 'Domain'>: tensor([[ 0.0269, -0.0068, -0.0080,  ...,  0.0192,  0.0055,  0.0119],
         [ 0.0284,  0.0014, -0.0060,  ...,  0.0221,  0.0089,  0.0020]],
        device='cuda:0'),
 <SubstructType.ACT_SITE: 'Active_site'>: tensor([[-0.0140,  0.0091,  0.0111, -0.0251, -0.0091, -0.0035, -0.0112,  0.0125,
          -0.0079,  0.0201,  0.0069, -0.0143, -0.0636,  0.0019,  0.0099,  0.0214,
          -0.0027,  0.0070,  0.0019,  0.0533,  0.0232, -0.0098, -0.0416,  0.0333,
           0.0015,  0.0190, -0.0105,  0.0167,  0.0181, -0.0032,  0.0259,  0.0222,
          -0.0283,  0.0168, -0.0105,  0.0007, -0.0137, -0.0252,  0.0022, -0.0214,
          -0.0005, -0.0048, -0.0305, -0.0022,  0.0016, -0.0065, -0.0163, -0.0094,
          -0.0002,  0.0033, -0.0212, -0.0178, -0.0090, -0.0198, -0.0196, -0.0311,
          -0.0056,  0.0115,  0.0082,  0.0012, -0.0065,  0.0207,  0.0277,  0.0040,
          -0.0102, -0.0328,  0.0147, -0.0080,  0.0240,  0.0485, -0.0223, -0.0033,
           0.0112,

If we want residue-level or protein-level embeddings from the base model (e.g. for downstream evaluation tasks after substructure-tuning), we can directly call the base model.

Note that the method for computing protein-level embeddings varies across base models (e.g. mean pooling vs CLS token).

In [13]:
# Call the base model directly, bypassing the classification heads, with
# autograd tracking disabled; then show the shape of the returned tensor.
# Presumably laid out as (batch, tokens, embed_dim) — TODO confirm against
# the base model's embed_batch documentation.
with torch.inference_mode():
    residue_embeds = model.base_model.embed_batch(example_batch)
residue_embeds.shape

torch.Size([4, 557, 960])

In [14]:
# Same base-model call as for residue embeddings, but with protein_level=True;
# the recorded output has one 960-d row per protein in the batch.
# How residues are pooled into that vector is decided by the base model
# (not visible here).
with torch.inference_mode():
    protein_embeds = model.base_model.embed_batch(example_batch, protein_level=True)
protein_embeds.shape

torch.Size([4, 960])