In [31]:
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerFast


# Stream the training split instead of downloading the whole dataset first.
ss_dataset = load_dataset(
    "lamm-mit/protein_secondary_structure_from_PDB",
    split="train",
    streaming=True,
)
tokenizer = AutoTokenizer.from_pretrained("example_8m_checkpoint")

# Options for the exploratory tokenizer call: truncate to 128-token windows,
# return the overflow as additional windows, and report the character span of
# every token.
tokenizer_args = dict(
    max_length=128,
    truncation=True,
    stride=16,  # TODO: figure this out later
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

In [33]:
# Pull the first record off the streamed dataset for manual inspection.
example = next(iter(ss_dataset))

In [36]:
# Tokenize the record's sequence string with the options in `tokenizer_args`.
tokenized_example = tokenizer(example["Sequence"], **tokenizer_args)

In [49]:
# Length, in characters, of the raw sequence string.
len(example["Sequence"])

817

In [59]:
# Secondary-structure annotation string stored on the same record.
example["Secondary_structure"]

'~~HHHHHHHHT~~B~SS~TTTTTTTT~EEEEEEETT~~GGGSS~EETTEES~HHHHHHTTTSEEESSEE~~~~SSHHHHHHHHHHHSS~~~SSS~HHHHTTTB~~~~HHHHHHTTT~EEEEE~SS~TTGGGHHHHHHHTT~SEEE~TTTS~SSSEETTEE~HHHHHHHHHHHHHHHHHTT~~EEEEEE~~TT~TT~~~~GGG~~~~~~TTTTTSHHHHHHHHHHHHHHHHHHHHHHHHHHTSGGGEEEEEEES~~~~~GGGS~GGGGGGS~~TT~~SSTTHHHHB~~EEEE~TT~~~~EE~~S~EEGGGHHHHHHHHTT~~~TT~~~SS~~GGGBSS~~EEE~SSS~TTEEE~SSEEEE~SSSSTTSEEEETTS~EE~~~HHHHHHHHHHHHHHHHHHHHHHT~~B~~~~~HHHHHHHHT~~B~SS~TTTTTTTT~EEEEEEETT~~GGGSS~EETTEES~HHHHHHTTTSEEESSEE~~~~SSHHHHHHHHHHHSS~~~SSS~HHHHTTTB~~~~HHHHHHTTT~EEEEE~SS~TTGGGHHHHHHHTT~SEEE~HHHH~SSSEETTEE~HHHHHHHHHHHHHHHHHTT~~EEEEEE~~TTSTT~~~~GGG~~~~~~TTTTTSHHHHHHHHHHHHHHHHHHHHHHHHHHTSGGGEEEEEEE~~~~~~GGGS~GGGGGGS~~TT~~SSTTHHHHB~~EEEE~TT~~~~EE~~S~EEGGGHHHHHHHHTT~~~TT~~~SS~~GGGBSS~~EEE~SSS~TTEEE~SSEEEE~SSSSTTSEEEETTS~EE~~~HHHHHHHHHHHHHHHHHHHHHHT~~B~~'

In [54]:
# Token ids of the first window returned by the tokenizer.
tokenized_example["input_ids"][0]

[0,
 12,
 11,
 5,
 15,
 17,
 7,
 17,
 13,
 12,
 15,
 6,
 7,
 8,
 7,
 15,
 8,
 17,
 14,
 13,
 19,
 18,
 6,
 5,
 5,
 15,
 6,
 15,
 17,
 4,
 12,
 12,
 7,
 16,
 4,
 9,
 8,
 18,
 16,
 10,
 17,
 4,
 11,
 17,
 7,
 15,
 12,
 17,
 6,
 16,
 8,
 12,
 11,
 14,
 11,
 4,
 13,
 6,
 4,
 16,
 17,
 9,
 11,
 20,
 19,
 8,
 17,
 16,
 18,
 18,
 16,
 11,
 7,
 8,
 15,
 8,
 17,
 11,
 5,
 13,
 5,
 9,
 22,
 8,
 7,
 19,
 11,
 8,
 11,
 18,
 14,
 8,
 6,
 19,
 19,
 11,
 17,
 11,
 16,
 11,
 19,
 6,
 13,
 10,
 7,
 12,
 14,
 8,
 20,
 14,
 10,
 4,
 4,
 6,
 15,
 17,
 13,
 19,
 15,
 11,
 5,
 11,
 18,
 21,
 11,
 17,
 13,
 2]

In [63]:
# Take the (start, end) character span recorded for token 3 of the first
# window and use it to slice the annotation string.
example["Secondary_structure"][slice(*tokenized_example["offset_mapping"][0][3])]

'H'

In [None]:
# Token ids of the third window returned by the tokenizer.
tokenized_example["input_ids"][2]

In [None]:
# Names of the fields returned by the tokenizer call.
list(tokenized_example.keys())

['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping']

In [2]:
# Column names and value types of the dataset.
ss_dataset.features

{'PDB_ID': Value('string'),
 'Sequence': Value('string'),
 'Secondary_structure': Value('string'),
 'AH': Value('float64'),
 'BS': Value('float64'),
 'T': Value('float64'),
 'UNSTRUCTURED': Value('float64'),
 'BETABRIDGE': Value('float64'),
 '310HELIX': Value('float64'),
 'PIHELIX': Value('float64'),
 'BEND': Value('float64'),
 'Sequence_length': Value('int64'),
 'Sequence_spaced': Value('string'),
 'Primary_SS_Type': Value('string'),
 'Secondary_SS_Type': Value('string')}

In [3]:
# Load the checkpoint with a token-classification head sized for 8 labels,
# requesting bfloat16 weights. The loader warning in this cell's output reports
# that the classifier parameters are newly initialized, not read from the
# checkpoint.
# NOTE(review): trust_remote_code=True lets the loader run Python code shipped
# with the checkpoint — only use it with a checkpoint you trust.
model = AutoModelForTokenClassification.from_pretrained(
    "example_8m_checkpoint", num_labels=8, trust_remote_code=True, dtype="bfloat16"
)

Some weights of NVEsmForTokenClassification were not initialized from the model checkpoint at example_8m_checkpoint and are newly initialized: ['classifier._extra_state', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Fetch the first record of the streamed dataset.
entry = next(iter(ss_dataset))

In [5]:
# Two tokenizers: the checkpoint's own tokenizer for the sequence strings and
# one built from a local tokenizer file for the secondary-structure strings.
aa_tokenizer = AutoTokenizer.from_pretrained("example_8m_checkpoint")
ss_tokenizer = PreTrainedTokenizerFast(tokenizer_file="ss_tokenizer.json")

# Options passed to both tokenizers: truncate anything longer than 1024 tokens.
tokenizer_args = dict(
    max_length=1024,
    truncation=True,
    # stride=100,  # figure this out later
    # return_overflowing_tokens=True,
)

In [6]:
def tokenize(example, *, special_label_id=8, ignore_index=-100):
    """Tokenize a batch of sequences together with their structure labels.

    Intended for ``Dataset.map(..., batched=True)``: each column of ``example``
    holds a list of strings. Both columns are tokenized with the module-level
    ``tokenizer_args``.

    Args:
        example: Mapping with ``"Sequence"`` and ``"Secondary_structure"``
            entries.
        special_label_id: Token id produced by ``ss_tokenizer`` that must not
            be used as a training label. Defaults to 8, the value that was
            previously hard-coded.
            NOTE(review): presumably the special/unknown token of
            ``ss_tokenizer.json`` — confirm against the tokenizer file.
        ignore_index: Value written in place of ``special_label_id``. Defaults
            to -100, the default ``ignore_index`` of PyTorch's cross-entropy
            loss.

    Returns:
        dict with ``"input_ids"`` (token ids of the sequences) and ``"labels"``
        (token ids of the structure strings, with every ``special_label_id``
        replaced by ``ignore_index``), one inner list per input row.
    """
    sequence_ids = aa_tokenizer(example["Sequence"], **tokenizer_args)["input_ids"]
    structure_ids = ss_tokenizer(example["Secondary_structure"], **tokenizer_args)["input_ids"]
    # NOTE(review): nothing here verifies that each label row is as long as the
    # matching input_ids row; this assumes both tokenizers emit one token per
    # character plus the same special tokens — TODO confirm.
    labels = [
        [ignore_index if token_id == special_label_id else token_id for token_id in row]
        for row in structure_ids
    ]
    return {"input_ids": sequence_ids, "labels": labels}

In [8]:
# Bind `item` to the first record by leaving the loop on its first iteration.
for item in ss_dataset:
    break

In [12]:
# One full pass over the stream to find the largest recorded sequence length.
max_length = max(map(lambda row: row["Sequence_length"], ss_dataset))

In [13]:
# Show the largest `Sequence_length` value found in the dataset.
max_length

19350

In [14]:
# Tokenize batch-wise and drop the dataset's own columns, so that only the
# fields produced by `tokenize` remain (a source column named "input_ids" or
# "labels" would be kept).
tokenized_dataset = ss_dataset.map(
    tokenize,
    batched=True,
    remove_columns=[
        name for name in ss_dataset.features if name not in ("input_ids", "labels")
    ],
)

In [15]:
from transformers import DataCollatorForTokenClassification

In [16]:
# Collate examples into fixed-size batches: every sequence is padded out to
# 1024 tokens using the sequence tokenizer's padding settings.
collator = DataCollatorForTokenClassification(
    tokenizer=aa_tokenizer,
    padding="max_length",
    max_length=1024,
)

In [17]:
import torch

In [19]:
# Batch the tokenized stream 16 examples at a time, then pull one batch to
# inspect.
dataloader = torch.utils.data.DataLoader(
    tokenized_dataset,
    batch_size=16,
    collate_fn=collator,
)
batch = next(iter(dataloader))

In [23]:
# Move the model to the GPU; as the last expression of the cell, the returned
# module is displayed.
model.to("cuda")

NVEsmForTokenClassification(
  (esm): NVEsmModel(
    (embeddings): NVEsmEmbeddings(
      (word_embeddings): Embedding(64, 320, padding_idx=1)
    )
    (encoder): NVEsmEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerLayer(
          (self_attention): MultiheadAttention(
            (layernorm_qkv): LayerNormLinear()
            (core_attention): DotProductAttention(
              (flash_attention): FlashAttention()
              (fused_attention): FusedAttention()
              (unfused_attention): UnfusedDotProductAttention(
                (scale_mask_softmax): FusedScaleMaskSoftmax()
                (attention_dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (proj): Linear()
          )
          (layernorm_mlp): LayerNormMLP()
        )
      )
      (emb_layer_norm_after): LayerNorm()
      (rotary_embeddings): RotaryPositionEmbedding()
    )
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (classifier): Linear()
)

In [24]:
# Move every entry of the batch onto the GPU alongside the model.
batch = {name: tensor.to("cuda") for name, tensor in batch.items()}

In [26]:
# Forward pass with the batch entries passed as keyword arguments.
output = model(**batch)

In [28]:
# Loss reported by the model for this batch.
output.loss

tensor(2.2031, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>)