In [1]:
#model_checkpoint = "facebook/esm2_t12_35M_UR50D"
model_checkpoint = "facebook/esm2_t33_650M_UR50D"

import requests, pandas, os
from io import BytesIO

query_url = "https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"
tmp_file = "/tmp/uniprot.parquet.gz"

if not os.path.exists(tmp_file):
    # Get data
    uniprot_request = requests.get(query_url)
    # Store data as binary object
    bio = BytesIO(uniprot_request.content)
    # Reach bionary object as compressed csv
    df = pandas.read_csv(bio, compression='gzip', sep='\t')
    # Write out to local location for faster reloads
    df.to_parquet(tmp_file, compression="gzip")
else:
    # Load from checkpoint
    df = pandas.read_parquet(tmp_file)
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,SUBCELLULAR LOCATION: Endoplasmic reticulum me...
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
4,A0PJY2,MDSSCHNATTKMLATAPARGNMMSTSKPLAFSIERIMARTPEPKAL...,SUBCELLULAR LOCATION: Nucleus {ECO:0000269|Pub...
...,...,...,...
11972,Q9H8W2,MRPGSSPRAPECGAPALPRPQLDRLPARPAPSRGRGAPSLRWPAKE...,
11973,Q9HAA7,MLFGIRILVNTPSPLVTGLHHYNPSIHRDQGECANQWRKGPGSAHL...,
11974,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,
11975,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,


In [2]:
df = df.dropna()  # Drop proteins with missing columns
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
9,A1E959,MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
14,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
18,A2RU49,MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
20,A2RUH7,MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...,"SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa..."
21,A4D126,MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:..."
...,...,...,...
11495,Q8NBC4,MFPRPVLNSRAQAILLPQPPNMLDHRQWPPRLASFPFTKTGMLSRA...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
11535,Q8TDY3,MFNPHALDSPAVIFDNGSGFCKAGLSGEFGPRHMVSSIVGHLKFQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ..."
11547,Q8WWF8,MAGTARHDREMAIQAKKKLTTATDPIERLRLQCLARGSAGIKGLGR...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
11671,Q9NUJ7,MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...


In [3]:
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
17,A2RU14,MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
33,A5X5Y0,MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...,SUBCELLULAR LOCATION: Postsynaptic cell membra...
36,A6ND01,MACWWPLLLELWTVMPTWAGDELLNICMNAKHHKRVPSPEDKLYEE...,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...
11895,Q86UQ5,MQSDIYHPGHSFPSWVLCWVHSCGHEGHLRETAEIRKTHQNGDLQI...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11918,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11955,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11963,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


In [4]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]
membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

# Quick check to make sure we got it right
assert(len(sequences) == len(labels))

In [5]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
tokenizer(train_sequences[0])

{'input_ids': [0, 20, 5, 9, 7, 16, 7, 4, 7, 4, 13, 6, 10, 6, 21, 4, 4, 6, 10, 4, 5, 5, 12, 7, 5, 15, 16, 7, 4, 4, 6, 10, 15, 7, 7, 7, 7, 10, 23, 9, 6, 12, 17, 12, 8, 6, 17, 18, 19, 10, 17, 15, 4, 15, 19, 4, 5, 18, 4, 10, 15, 10, 20, 17, 11, 17, 14, 8, 10, 6, 14, 19, 21, 18, 10, 5, 14, 8, 10, 12, 18, 22, 10, 11, 7, 10, 6, 20, 4, 14, 21, 15, 11, 15, 10, 6, 16, 5, 5, 4, 13, 10, 4, 15, 7, 18, 13, 6, 12, 14, 14, 14, 19, 13, 15, 15, 15, 10, 20, 7, 7, 14, 5, 5, 4, 15, 7, 7, 10, 4, 15, 14, 11, 10, 15, 18, 5, 19, 4, 6, 10, 4, 5, 21, 9, 7, 6, 22, 15, 19, 16, 5, 7, 11, 5, 11, 4, 9, 9, 15, 10, 15, 9, 15, 5, 15, 12, 21, 19, 10, 15, 15, 15, 16, 4, 20, 10, 4, 10, 15, 16, 5, 9, 15, 17, 7, 9, 15, 15, 12, 13, 15, 19, 11, 9, 7, 4, 15, 11, 21, 6, 4, 4, 7, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [8]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [9]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)
test_dataset

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

# Print the shape and columns of the datasets after adding labels
print(train_dataset.shape, list(train_dataset[0].keys()))
print(test_dataset.shape, list(test_dataset[0].keys()))

(3886, 3) ['input_ids', 'attention_mask', 'labels']
(1296, 3) ['input_ids', 'attention_mask', 'labels']


  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [10]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1  # 2: {0: cytosolic, 1: membrane}
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [34]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8
strat = "steps" # "epoch"

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    eval_strategy = strat,
    save_strategy = strat,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    max_steps=100,
    eval_steps=20,
    #num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    report_to="none",
    include_tokens_per_second=True,
)

In [35]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [36]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [37]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy
20,No log,0.317503,0.94213
40,No log,0.33607,0.939043
60,No log,0.344598,0.940586
80,No log,0.346892,0.943673
100,No log,0.346528,0.942901


Could not locate the best model at esm2_t33_650M_UR50D-finetuned-localization/checkpoint-80/pytorch_model.bin, if you are running a distributed training on multiple nodes, you should activate `--save_on_each_node`.


TrainOutput(global_step=100, training_loss=0.013548849821090699, metrics={'train_runtime': 179.6321, 'train_samples_per_second': 4.454, 'train_steps_per_second': 0.557, 'train_tokens_per_second': 2097.62, 'total_flos': 1403037318561600.0, 'train_loss': 0.013548849821090699, 'epoch': 0.205761316872428})

In [32]:
import tempfile
from pathlib import Path
from typing import Sequence, Tuple

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback, RichModelSummary
from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm as nllm
from nemo.lightning import resume
from nemo.lightning.nemo_logger import NeMoLogger
from nemo.lightning.pytorch import callbacks as nl_callbacks
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.peft import PEFT
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule

from bionemo.core.data.load import load
from bionemo.esm2.api import ESM2GenericConfig
from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

__all__: Sequence[str] = ("train_model",)

def train_model(
    experiment_name: str,
    experiment_dir: Path,
    config: ESM2GenericConfig,
    data_module: pl.LightningDataModule,
    n_steps_train: int,
    metric_tracker: Callback | None = None,
    tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
    peft: PEFT | None = None,
    _use_rich_model_summary: bool = True,
) -> Tuple[Path, Callback | None, nl.Trainer]:
    """Trains a BioNeMo ESM2 model using PyTorch Lightning.

    Parameters:
        experiment_name: The name of the experiment.
        experiment_dir: The directory where the experiment will be saved.
        config: The configuration for the ESM2 model.
        data_module: The data module for training and validation.
        n_steps_train: The number of training steps.
        metric_tracker: Optional callback to track metrics
        tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
        peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.
        _use_rich_model_summary: Whether to use the RichModelSummary callback, omitted in our test suite until
            https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to True.

    Returns:
        A tuple containing the path to the saved checkpoint, a MetricTracker
        object, and the PyTorch Lightning Trainer object.
    """
    checkpoint_callback = nl_callbacks.ModelCheckpoint(
        save_last=True,
        save_on_train_epoch_end=True,
        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
        every_n_train_steps=n_steps_train // 2,
        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
    )

    # Setup the logger and train the model
    nemo_logger = NeMoLogger(
        log_dir=str(experiment_dir),
        name=experiment_name,
        tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name),
        ckpt=checkpoint_callback,
    )
        # Needed so that the trainer can find an output directory for the profiler
    # ckpt_path needs to be a string for SerDe
    optimizer = MegatronOptimizerModule(
        config=OptimizerConfig(
            lr=5e-4,
            optimizer="adam",
            use_distributed_optimizer=True,
            fp16=config.fp16,
            bf16=config.bf16,
        )
    )
    module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        ddp="megatron",
        find_unused_parameters=True,
        enable_nemo_ckpt_io=True,
    )
    strategy = 'ddp_notebook'

    if _use_rich_model_summary:
        # RichModelSummary is not used in the test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved due
        # to errors with serialization / deserialization.
        callbacks: list[Callback] = [RichModelSummary(max_depth=4)]
    else:
        callbacks = []

    if metric_tracker is not None:
        callbacks.append(metric_tracker)
    if peft is not None:
        callbacks.append(
            ModelTransform()
        )  # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft).
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=strategy,
        limit_val_batches=2,
        val_check_interval=n_steps_train // 2,
        max_steps=n_steps_train,
        num_nodes=1,
        log_every_n_steps=n_steps_train // 2,
        callbacks=callbacks,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )
    nllm.train(
        model=module,
        data=data_module,
        trainer=trainer,
        log=nemo_logger,
        resume=resume.AutoResume(
            resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
        ),
    )
    ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
    return ckpt_path, metric_tracker, trainer


In [33]:
# set the results directory
experiment_results_dir = tempfile.TemporaryDirectory().name

# create a List[Tuple] with (sequence, target) values
artificial_sequence_data = [
    "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
    "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]
data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data]

# we are training and validating on the same dataset for simplicity
dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(train_dataset=train_dataset, valid_dataset=test_dataset)

experiment_name = "finetune_regressor"
n_steps_train = 50
seed = 42

# To download a 650M pre-trained ESM2 model
pretrain_ckpt_path = load("esm2/650m:2.0")

config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))

checkpoint, metrics, trainer = train_model(
    experiment_name=experiment_name,
    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
    config=config,  # same config as before since we are just continuing training
    data_module=data_module,
    n_steps_train=n_steps_train,
)
print(f"Experiment completed with checkpoint stored at {checkpoint}")

MisconfigurationException: `Trainer(strategy='ddp_spawn')` is not compatible with an interactive environment. Run your code as a script, or choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.

In [31]:
help(nl.Trainer)


Help on class Trainer in module nemo.lightning.pytorch.trainer:

class Trainer(lightning.pytorch.trainer.trainer.Trainer, nemo.lightning.io.mixin.IOMixin)
 |  Trainer(*, accelerator: Union[str, lightning.pytorch.accelerators.accelerator.Accelerator] = 'auto', strategy: Union[str, lightning.pytorch.strategies.strategy.Strategy] = 'auto', devices: Union[List[int], str, int] = 'auto', num_nodes: int = 1, precision: Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], NoneType] = None, logger: Union[lightning.pytorch.loggers.logger.Logger, Iterable[lightning.pytorch.loggers.logger.Logger], bool, NoneType] = None, callbacks: Union[List[lightning.pytorch.callbacks.callback.Callback], lightning.pytorch.callbacks.callback.Callback, NoneType] = None, fast_dev_run: Union[int, bool] = False, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None