In [2]:
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerFast


ss_dataset = load_dataset("lamm-mit/protein_secondary_structure_from_PDB", split="train", streaming=True)

  from .autonotebook import tqdm as notebook_tqdm
  import pynvml  # type: ignore[import]


In [3]:
model = AutoModelForTokenClassification.from_pretrained(
    "example_8m_checkpoint", num_labels=8, trust_remote_code=True, dtype="bfloat16"
)

Some weights of NVEsmForTokenClassification were not initialized from the model checkpoint at example_8m_checkpoint and are newly initialized: ['classifier._extra_state', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
entry = next(iter(ss_dataset))

In [5]:
ss_tokenizer = PreTrainedTokenizerFast(tokenizer_file="ss_tokenizer.json")
aa_tokenizer = AutoTokenizer.from_pretrained("example_8m_checkpoint")

tokenizer_args = {
    "max_length": 1024,
    "truncation": True,
    # "stride": 100,  # figure this out later
    # "return_overflowing_tokens": True,
}

In [6]:
def tokenize(example):
    result = {}
    result["input_ids"] = aa_tokenizer(example["Sequence"], **tokenizer_args)["input_ids"]
    tokenized_labels = ss_tokenizer(example["Secondary_structure"], **tokenizer_args)["input_ids"]
    result["labels"] = [[ii if ii != 8 else -100 for ii in item] for item in tokenized_labels]
    return result

In [8]:
for item in ss_dataset:
    break

In [12]:
max_length = max(item["Sequence_length"] for item in ss_dataset)

In [13]:
max_length

19350

In [14]:
tokenized_dataset = ss_dataset.map(
    tokenize, batched=True, remove_columns=[col for col in ss_dataset.features if col not in ["input_ids", "labels"]]
)

In [15]:
from transformers import DataCollatorForTokenClassification

In [16]:
collator = DataCollatorForTokenClassification(tokenizer=aa_tokenizer, padding="max_length", max_length=1024)

In [17]:
import torch

In [19]:
dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=16, collate_fn=collator)
batch = next(iter(dataloader))

In [23]:
model.to("cuda")

NVEsmForTokenClassification(
  (esm): NVEsmModel(
    (embeddings): NVEsmEmbeddings(
      (word_embeddings): Embedding(64, 320, padding_idx=1)
    )
    (encoder): NVEsmEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerLayer(
          (self_attention): MultiheadAttention(
            (layernorm_qkv): LayerNormLinear()
            (core_attention): DotProductAttention(
              (flash_attention): FlashAttention()
              (fused_attention): FusedAttention()
              (unfused_attention): UnfusedDotProductAttention(
                (scale_mask_softmax): FusedScaleMaskSoftmax()
                (attention_dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (proj): Linear()
          )
          (layernorm_mlp): LayerNormMLP()
        )
      )
      (emb_layer_norm_after): LayerNorm()
      (rotary_embeddings): RotaryPositionEmbedding()
    )
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (classifier): Linear()
)

In [24]:
batch = {k: v.to("cuda") for k, v in batch.items()}

In [26]:
output = model(**batch)

In [28]:
output.loss

tensor(2.2031, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>)