In [3]:
WARMUP_STEPS = 6000
EPOCHS = 5
BATCH_SIZE = 64
LR = 2e-4
CACHE_KEY = "dna_key"

import wandb
wandb.login()
wandb.init(
    project="nuclprot",
    name="testing nuclprot CROSSATTN_FIN",
    config={
        "WARMUP_STEPS": WARMUP_STEPS,
        "EPOCHS": EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "LR": LR,
    },
)

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnikolamilicevic[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [1]:
from datasets import load_dataset

ds = load_dataset("vladak/anthem_hla_seq")

In [3]:
# subsample
# from datasets import DatasetDict
# ds = DatasetDict({
#     split: ds[split].shuffle(seed=42).select(range(int(0.02 * len(ds[split]))))
#     for split in ds
# })
# ds

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 10780
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 3451
    })
})

In [2]:
ds

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 539019
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 172580
    })
})

In [4]:
import hashlib
def get_sequence_id(example):
    example[CACHE_KEY] = int(hashlib.sha256(example['Sequence'].encode()).hexdigest(), 16) % (10**12)
    return example

ds = ds.map(get_sequence_id)
ds

Map:   0%|          | 0/539019 [00:00<?, ? examples/s]

Map:   0%|          | 0/172580 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key'],
        num_rows: 539019
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key'],
        num_rows: 172580
    })
})

In [5]:
# Create model registry as singleton.
# Use this class to avoid creating the same model in multiple places
import torch
from transformers import AutoModel, AutoTokenizer

class ModelRegistry:
    _models = {}
    _tokenizers = {}
    _instance = None
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def get_model(cls, name):
        if name not in cls._models:
            model = AutoModel.from_pretrained(name).to(cls._device)
            cls._models[name] = model
        return cls._models[name]
    
    @classmethod
    def get_tokenizer(cls, name):
        if name not in cls._tokenizers:
            tokenizer = AutoTokenizer.from_pretrained(name)
            cls._tokenizers[name] = tokenizer
        return cls._tokenizers[name]

In [6]:
from tqdm import tqdm
from transformers import DataCollatorWithPadding
import torch.nn.functional as F

class EmbeddingCache(torch.utils.data.Dataset):
    """This class will precompute embeddings for the data and cache
    them for future reuse."""
    def __init__(self, data, key, input_ids_name, attention_mask_name, emb_model_name, device) -> None:
        """
        Args:
            data: Huggingface dataset that will be cached.
            key: Column with unique values for each sample (e.g. sequence hash).
            value: The value that we are computing embedding for and caching (input_ids + attn_mask).
            emb_model: Model used for computing embeddings.
            device: Device on which embeddings will reside.
        """
        self.device = device
        self.key = key
        self.emb_model_name = emb_model_name
        self.emb_model = ModelRegistry.get_model(emb_model_name)
        self.data = self._filter_duplicates(data, key)
        self.cached_embeddings = self._cache_embeddings(input_ids_name, attention_mask_name)

    def _filter_duplicates(self, data, key):
        seen = set()
        filtered_dataset =  data.filter(lambda example: not (example[key] in seen or seen.add(example[key])))
        return filtered_dataset

    def _cache_embeddings(self, input_ids_name, attention_mask_name):
        embeddings_cache = {}
        for sample in tqdm(self.data):
            id = sample[self.key]
            inputs = {
                "input_ids": torch.tensor(sample[input_ids_name]).unsqueeze(0).to(self.device),
                "attention_mask": torch.tensor(sample[attention_mask_name]).unsqueeze(0).to(self.device)
            }
            embeddings_cache[id] = self._compute_embedding(inputs)
        return embeddings_cache

    def _compute_embedding(self, inputs):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                embedding = self.emb_model(**inputs).last_hidden_state.detach().cpu()
        return embedding

    def __getitem__(self, keys):
        embeddings = []
        for key in keys:
            embeddings.append(self.cached_embeddings[key.item()])

        # pad with 0s to the end
        max_len = max(e.shape[1] for e in embeddings)
        for i in range(len(embeddings)):
            padding_size = max_len - embeddings[i].shape[1]
            if padding_size > 0:
                embeddings[i] = F.pad(embeddings[i], (0, 0, 0, padding_size))
            embeddings[i] = embeddings[i].squeeze(0)
        
        stacked_embeddings = torch.stack(embeddings, dim=0)
        return stacked_embeddings

    def __len__(self):
        return len(self.data)

In [7]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding
import torch.nn.functional as F

class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1):
        super(CrossAttentionLayer, self).__init__()
        self.modality1_to_modality2_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.modality2_to_modality1_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
        ffn_hidden_dim = embed_dim * 3
        self.ffn_modality1 = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.ReLU(),
            nn.Linear(ffn_hidden_dim, embed_dim),
        )
        self.ffn_modality2 = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.ReLU(),
            nn.Linear(ffn_hidden_dim, embed_dim),
        )

        self.modality1_norm = nn.LayerNorm(embed_dim)
        self.modality2_norm = nn.LayerNorm(embed_dim)
        self.ffn_modality1_norm = nn.LayerNorm(embed_dim)
        self.ffn_modality2_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, modality1_embedding, modality2_embedding, key_pad_mask_modality1, key_pad_mask_modality2):
        # Modality1 attending to Modality2
        attended_modality1, _ = self.modality1_to_modality2_attention(
            query=modality1_embedding, 
            key=modality2_embedding,
            value=modality2_embedding,
            key_padding_mask=key_pad_mask_modality2
        )
        attended_modality1 = self.modality1_norm(modality1_embedding + attended_modality1)
        x_modality1 = self.ffn_modality1(attended_modality1)
        x_modality1 = self.ffn_modality1_norm(attended_modality1 + self.dropout(x_modality1))

        # Modality2 attending to Modality1
        attended_modality2, _ = self.modality2_to_modality1_attention(
            query=modality2_embedding, 
            key=modality1_embedding, 
            value=modality1_embedding,
            key_padding_mask=key_pad_mask_modality1
        )
        attended_modality2 = self.modality2_norm(modality2_embedding + attended_modality2)
        x_modality2 = self.ffn_modality2(attended_modality2)
        x_modality2 = self.ffn_modality2_norm(attended_modality2 + self.dropout(x_modality2))

        return x_modality1, x_modality2

class BindingAffinityModelWithMultiHeadCrossAttention(nn.Module):
    def __init__(self, modality1_model_name, modality2_model_name, num_layers=3, hidden_dim=1024, modality1_cache=None, modality2_cache=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load pretrained ESM2 model for proteins
        self.modality1_model = ModelRegistry.get_model(modality1_model_name)
        self.modality1_tokenizer = ModelRegistry.get_tokenizer(modality1_model_name)
        self.modality1_model.eval()
        
        # Load pretrained ChemLLM for SMILES (ligands)
        self.modality2_model = ModelRegistry.get_model(modality2_model_name)
        self.modality2_tokenizer = ModelRegistry.get_tokenizer(modality2_model_name)
        self.modality2_model.eval()

        self.modality1_cache = modality1_cache
        self.modality2_cache = modality2_cache

        # Disable gradient computation for both base models
        for param in self.modality1_model.parameters():
            param.requires_grad = False

        for param in self.modality2_model.parameters():
            param.requires_grad = False
        
        self.modality1_embedding_dim = self.modality1_model.config.hidden_size
        self.modality2_embedding_dim = self.modality2_model.config.hidden_size

        # Projecting to the size of Modality1 model 
        self.project_to_common = nn.Linear(self.modality2_embedding_dim, self.modality1_embedding_dim)

        self.layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim=self.modality1_embedding_dim) for _ in range(num_layers)
        ])

        self.ffn_class_head = nn.Sequential(
            nn.Linear(2 * self.modality1_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        
    def forward(
            self,
            modality1_input_ids,
            modality1_attention_mask,
            modality2_input_ids,
            modality2_attention_mask,
            modality1_cache_keys=None,
            modality2_cache_keys=None,
        ):
        # Modality 1
        modality1_inputs = {
            "input_ids": modality1_input_ids,
            "attention_mask": modality1_attention_mask
        }
        
        if self.modality1_cache:
            modality1_embedding = self.modality1_cache[modality1_cache_keys]
            modality1_embedding = modality1_embedding.to(self.device)
        else:
            # perform in FP16 for lower memory usage (matmuls)
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    modality1_outputs = self.modality1_model(**modality1_inputs)
                    modality1_embedding = modality1_outputs.last_hidden_state
        
        special_tokens_mask_modality1 = (modality1_inputs['input_ids'] == self.modality1_tokenizer.cls_token_id)\
        | (modality1_inputs['input_ids'] == self.modality1_tokenizer.eos_token_id)\
        | (modality1_inputs['input_ids'] == self.modality1_tokenizer.pad_token_id)
        
        # Modality 2
        modality2_inputs = {
            "input_ids": modality2_input_ids,
            "attention_mask": modality2_attention_mask
        }
    
        if self.modality2_cache:
            modality2_embedding = self.modality2_cache[modality2_cache_keys]
            modality2_embedding = modality2_embedding.to(self.device)
        else:
            # perform in FP16 for lower memory usage (matmuls)
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    modality2_outputs = self.modality2_model(**modality2_inputs)
                    modality2_embedding = modality2_outputs.last_hidden_state

        special_tokens_mask_modality2 = (modality2_inputs['input_ids'] == self.modality2_tokenizer.cls_token_id)\
        | (modality2_inputs['input_ids'] == self.modality2_tokenizer.eos_token_id)\
        | (modality2_inputs['input_ids'] == self.modality2_tokenizer.pad_token_id)

        # project embeddings to same dimension
        modality2_embedding = self.project_to_common(modality2_embedding)
        
        for layer in self.layers:
            modality1_embedding, modality2_embedding = layer(modality1_embedding, modality2_embedding, special_tokens_mask_modality1, special_tokens_mask_modality2)

        # Perform mean pooling
        modality2_embedding = (modality2_embedding * ~special_tokens_mask_modality2.unsqueeze(dim=-1)).mean(dim=1)
        modality1_embedding = (modality1_embedding * ~special_tokens_mask_modality1.unsqueeze(dim=-1)).mean(dim=1)
        # Combine embeddings
        combined = torch.cat([modality1_embedding, modality2_embedding], dim=1)
        logits = self.ffn_class_head(combined)
        return logits


esm_model_name = "facebook/esm2_t33_650M_UR50D"  # Replace with the correct ESM2 model name
dna_model_name = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species" # Replace with the correct ChemLLM model name


In [8]:
# Data split 
from datasets import DatasetDict

dataset_test = ds['test']
dataset_test_val = dataset_test.train_test_split(test_size=0.5)

dataset_dict = {
    "train": ds['train'],
    "test": dataset_test_val["train"],
    "validation": dataset_test_val['test']
}
dataset = DatasetDict(dataset_dict)
dataset

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key'],
        num_rows: 10780
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key'],
        num_rows: 1725
    })
    validation: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key'],
        num_rows: 1726
    })
})

In [9]:
# Tokenization of DNA and protein sequences
dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name)
esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
print(f"dna tokenizer is fast: {dna_tokenizer.is_fast}")
print(f"esm tokenizer is fast: {esm_tokenizer.is_fast}")

def tokenize_dna(examples):
    toks = dna_tokenizer(examples["Sequence"])
    return {
        "dna_input_ids": toks["input_ids"],
        "dna_attention_mask": toks["attention_mask"]
    }

def tokenize_proteins(examples):
    toks =  esm_tokenizer(examples["peptide"], truncation=True)
    return {
        "protein_input_ids": toks["input_ids"],
        "protein_attention_mask": toks["attention_mask"]
    }

tokenized_dataset = dataset.map(tokenize_proteins, batched=True)
tokenized_dataset = tokenized_dataset.map(tokenize_dna, batched=True)
tokenized_dataset

dna tokenizer is fast: False
esm tokenizer is fast: False


Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/1726 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/1726 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 10780
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 1725
    })
    validation: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', 'dna_key', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 1726
    })
})

In [10]:
# Custom data collator
from transformers import DataCollatorWithPadding

class CustomDataCollator:
    def __init__(self, dna_collator, esm_collator):
            self.dna_collator = dna_collator
            self.esm_collator = esm_collator

    def __call__(self, batch):
        batch_dna = [{"input_ids": b["dna_input_ids"], "attention_mask": b["dna_attention_mask"], CACHE_KEY: b[CACHE_KEY]} for b in batch]
        # "sequence": b["Sequence"]
        batch_protein = [{"input_ids": b["protein_input_ids"], "attention_mask": b["protein_attention_mask"]} for b in batch]
        # "peptide": b["peptide"]

        collated_dna = self.dna_collator(batch_dna)
        collated_esm = self.esm_collator(batch_protein)

        return {
            CACHE_KEY: collated_dna[CACHE_KEY],
            "dna_input_ids": collated_dna["input_ids"],
            "dna_attention_mask": collated_dna["attention_mask"],
            "protein_input_ids": collated_esm["input_ids"],
            "protein_attention_mask": collated_esm["attention_mask"],
            "label": torch.tensor([x['Label'] for x in batch])
        }

In [11]:
from torch.utils.data import DataLoader
dna_collator = DataCollatorWithPadding(tokenizer=dna_tokenizer)
esm_collator = DataCollatorWithPadding(tokenizer=esm_tokenizer)
collator = CustomDataCollator(dna_collator=dna_collator, esm_collator=esm_collator)
train_dataloader = DataLoader(tokenized_dataset["train"], batch_size=BATCH_SIZE, collate_fn=collator)
test_dataloader = DataLoader(tokenized_dataset["test"], batch_size=BATCH_SIZE, collate_fn=collator)
val_dataloader = DataLoader(tokenized_dataset["validation"], batch_size=BATCH_SIZE, collate_fn=collator)

In [14]:
keys = torch.tensor(tokenized_dataset["train"][CACHE_KEY][:16])

In [16]:
# Training loop 
import torch.nn as nn
from tqdm import tqdm

    
def lr_lambda(step):
        if step < WARMUP_STEPS:
            # Linear warmup
            return step / WARMUP_STEPS
        else:
            remaining_steps = total_steps - WARMUP_STEPS
            decay_step = step - WARMUP_STEPS
            return max(0.5 * LR, 1.0 - 0.5 * (decay_step / remaining_steps))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
dna_cache = EmbeddingCache(data=tokenized_dataset["train"], key=CACHE_KEY, input_ids_name="dna_input_ids", attention_mask_name="dna_attention_mask", emb_model_name=dna_model_name, device=device)
model = BindingAffinityModelWithMultiHeadCrossAttention(modality1_model_name=esm_model_name, modality2_model_name=dna_model_name, modality2_cache=dna_cache).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
total_steps = EPOCHS * len(train_dataloader)

Using device cuda


100%|██████████| 110/110 [00:08<00:00, 12.39it/s]
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
def train_model(model, train_dataloader, val_dataloader):
    step = 0
    ACCUMULATION_STEPS = 2

    for epoch in range(EPOCHS):
        print(f"Epoch: {epoch + 1}/{EPOCHS}")
        model.train()
        train_loss = 0.0
        train_progress = tqdm(train_dataloader, desc="Training")

        for batch in train_progress:
            dna_input_ids = batch["dna_input_ids"].to(device)
            dna_attention_mask = batch["dna_attention_mask"].to(device)
            protein_input_ids = batch["protein_input_ids"].to(device)
            protein_attention_mask = batch["protein_attention_mask"].to(device)
            dna_cache_keys=batch[CACHE_KEY]
            targets = batch["label"].unsqueeze(dim=-1).to(device)
            preds = model(
                modality1_input_ids=protein_input_ids, 
                modality1_attention_mask=protein_attention_mask, 
                modality2_input_ids=dna_input_ids, 
                modality2_attention_mask=dna_attention_mask,
                modality2_cache_keys=dna_cache_keys
                )
            loss = criterion(preds, targets.float())
            loss.backward()
            train_loss += loss.item()
            step += 1
            if step % ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
            scheduler.step()

            if step % 100 == 0:
                wandb.log({"train_loss": loss.item()})
                wandb.log({"lr": optimizer.param_groups[0]["lr"]})

        train_loss /= len(train_dataloader)
        print(f"Epoch: {epoch} Train loss: {train_loss}")

        model.eval()
        val_loss = 0.0
        val_progress = tqdm(val_dataloader, desc="Validation")
        with torch.no_grad():
            for batch in val_progress:
                dna_input_ids = batch["dna_input_ids"].to(device)
                dna_attention_mask = batch["dna_attention_mask"].to(device)
                protein_input_ids = batch["protein_input_ids"].to(device)
                protein_attention_mask = batch["protein_attention_mask"].to(device)
                dna_cache_keys=batch[CACHE_KEY]
                targets = batch["label"].unsqueeze(dim=-1).to(device)
                preds = model(
                    modality1_input_ids=protein_input_ids, 
                    modality1_attention_mask=protein_attention_mask, 
                    modality2_input_ids=dna_input_ids, 
                    modality2_attention_mask=dna_attention_mask,
                    modality2_cache_keys=dna_cache_keys
                )
                loss = criterion(preds, targets.float())
                val_loss += loss.item()
                
        val_loss /= len(val_dataloader)
        print(f"Epoch: {epoch} Val loss: {val_loss}")

In [18]:
train_model(model, train_dataloader, val_dataloader)

Epoch: 1/2


Training:   0%|          | 0/169 [00:00<?, ?it/s]

Training:  32%|███▏      | 54/169 [19:38<01:16,  1.51it/s]   

In [12]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total, trainable = count_parameters(model)
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,}")

Total parameters: 3,293,934,871
Trainable parameters: 104,297,729


In [13]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, classification_report


# def evaluate_model(model, test_loader):
model.eval()
all_predictions = []
all_targets = []
test_progress = tqdm(test_dataloader, desc="Test set")
with torch.no_grad():
    for batch in test_progress:
        ligand_input_ids = batch["dna_input_ids"].to(device)
        ligand_attention_mask = batch["dna_attention_mask"].to(device)
        protein_input_ids = batch["protein_input_ids"].to(device)
        protein_attention_mask = batch["protein_attention_mask"].to(device)
        targets = batch["label"].unsqueeze(dim=-1).to(device)
        preds = model(
            ligand_input_ids,
            ligand_attention_mask,
            protein_input_ids,
            protein_attention_mask,
        )
        # transform preds to 0 - 1 
        # do sigmoid or sth
        probs = torch.sigmoid(preds)
        preds = (probs > 0.5).float()
        all_targets.append(targets)
        all_predictions.append(preds)

all_predictions = torch.cat(all_predictions).cpu()
all_targets = torch.cat(all_targets).cpu()

accuracy = accuracy_score(all_targets, all_predictions)
precision = precision_score(all_targets, all_predictions)
recall = recall_score(all_targets, all_predictions)
f1 = f1_score(all_targets, all_predictions)
auc = roc_auc_score(all_targets, all_predictions)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")
print(classification_report(all_targets, all_predictions))

# evaluate_model(model, test_loader=test_dataloader)

Test set: 100%|██████████| 68/68 [01:27<00:00,  1.28s/it]

Accuracy: 0.8445
Precision: 0.7971
Recall: 0.9278
F1-score: 0.8575
              precision    recall  f1-score   support

           0       0.91      0.76      0.83      2138
           1       0.80      0.93      0.86      2176

    accuracy                           0.84      4314
   macro avg       0.85      0.84      0.84      4314
weighted avg       0.85      0.84      0.84      4314






In [13]:
dataset['train'].features

{'HLA': Value(dtype='string', id=None),
 'peptide': Value(dtype='string', id=None),
 'Label': Value(dtype='int64', id=None),
 'Length': Value(dtype='int64', id=None),
 'Sequence': Value(dtype='string', id=None),
 '__index_level_0__': Value(dtype='int64', id=None)}