In [1]:
import torch
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import TokenClassifierOutput

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')
LOSS = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
DROPOUT = 0.5
OUTPUT_SIZE = 1

class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(DROPOUT)
        self.classifier = nn.Linear(config.hidden_size, OUTPUT_SIZE)
        self.distance_regressor = nn.Linear(config.hidden_size, OUTPUT_SIZE)
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        cbs_logits = self.classifier(sequence_output)
        distance_logits = self.distance_regressor(sequence_output)

        loss = None
        # changed to ignore special tokens at the seq start and end 
        # as well as invalid positions (labels -100)
        if labels is not None:
            loss_fct = LOSS

            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)

            active_labels = torch.where(
              active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
            )

            valid_logits=active_logits[active_labels!=-100]
            valid_labels=active_labels[active_labels!=-100]
            
            
            valid_labels=valid_labels.type(torch.LongTensor).to('cuda:0')
            
            valid_logits = torch.softmax(valid_logits, dim=1)
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=cbs_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )# , TokenClassifierOutput(
        #     loss=loss,
        #     logits=distance_logits,
        #     hidden_states=outputs.hidden_states,
        #     attentions=outputs.attentions,
        # )

In [2]:
from dataclasses import dataclass
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch
        if "distances" in features[0].keys():
            label_names = ['labels', 'distances']
        else:
            label_names = ['labels']

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = {label_name: [feature[label_name] for feature in features] for label_name in label_names}

        no_labels_features = [{k: v for k, v in feature.items() if k not in label_names} for feature in features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        for label_name in label_names:
            if padding_side == "right":
                batch[label_name] = [
                    # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
                    # changed to pad the special tokens at the beginning and end of the sequence
                    [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels[label_name]
                ]
            else:
                batch[label_name] = [
                    [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels[label_name]
                ]
    
            batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
        return batch

def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result

def tolist(x):
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
        x = x.numpy()
    return x.tolist()


2025-03-20 12:10:16.498165: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-20 12:10:16.533590: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from transformers import AutoTokenizer

MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
# MODEL_NAME = 'facebook/esm2_t30_150M_UR50D'
# MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

model = EsmForTokenClassificationCustom.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight', 'distance_regressor.bias', 'distance_regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
from datasets import Dataset
import csv
import numpy as np

MAX_LENGTH = 1024
def get_dataset(annotation_path, tokenizer):

    sequences = []
    labels = []

    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            sequence = row[4]
            # max sequence length of ESM2
            if len(sequence) > MAX_LENGTH: continue 

            indices = [int(residue[1:]) for residue in row[3].split(' ')]
            label = np.zeros(len(sequence))
            label[indices] = 1
            sequences.append(sequence)
            labels.append(label) # np.eye(NUMBER_OF_CLASSES)[label])

            assert len(sequence) == len(label)

    train_tokenized = tokenizer(sequences, max_length=MAX_LENGTH, padding=True, truncation=True)
    
    dataset = Dataset.from_dict(train_tokenized)
    dataset = dataset.add_column("labels", labels)

    return dataset


In [5]:
from torch.utils.data import DataLoader

train_dataset = get_dataset('/home/skrhakv/cryptic-nn/src/fine-tuning/train.txt', tokenizer)
val_dataset = get_dataset('/home/skrhakv/cryptic-nn/src/fine-tuning/val.txt', tokenizer)

data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=data_collator)
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, shuffle=True, collate_fn=data_collator)


In [6]:
from sklearn import metrics
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

def print_used_memory():
    free, total = torch.cuda.mem_get_info(torch.device('cuda:0'))
    mem_used_MB = (total - free) / 1024 ** 2
    mem_total_MB = (total) / 1024 ** 2
    print(f'{mem_used_MB} MB / {mem_total_MB} MB')

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    # print('Before test:')
    # print_used_memory()
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # print('before prediction:')
            # print_used_memory()

            output = model(input_ids, attention_mask=attention_mask)
            # print('after prediction:')
            # print_used_memory()
        
            logits = output.logits.flatten(1)
            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_logits = logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_logits))

            test_loss = loss_fn(valid_flattened_logits, valid_flattened_labels)

            test_losses.append(test_loss.cpu().detach().numpy())
            # print(valid_flattened_logits)
            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
            del input_ids, attention_mask, labels, logits, valid_flattened_logits, valid_flattened_labels
            gc.collect()
            torch.cuda.empty_cache()
    
    # print('after test')
    # print_used_memory()

    model.train()

    batch_losses = []

    # TRAIN

    # TODO: the following row causes the memory explosion
    # with torch.inference_mode():

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        # print('before train prediction')
        # print_used_memory()

        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits.flatten(1)
        labels = batch['labels'].to(device)
        flattened_labels = labels.flatten()
        # print('after train prediction')
        # print_used_memory()

        valid_flattened_logits = logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]

        loss = loss_fn(valid_flattened_logits, valid_flattened_labels)
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

# 35M:
# Epoch: 3 | Loss: 0.26539, Accuracy: 87.91% | Test loss: 0.71823, AUC: 0.8429591457364806, MCC: 0.33372585153627116, sum: 7430.0

# 150M:
# Epoch: 2 | Loss: 0.47842, Accuracy: 89.44% | Test loss: 0.63896, AUC: 0.8584368023837566, MCC: 0.36990818473863774, sum: 6639.0



Epoch: 0 | Loss: 0.73659, Accuracy: 17.84% | Test loss: 1.04761, AUC: 0.4395135999006268, MCC: -0.053494933235651576, sum: 48365.0
Epoch: 1 | Loss: 0.37189, Accuracy: 86.20% | Test loss: 0.60297, AUC: 0.8741129887987007, MCC: 0.3633713411500189, sum: 9149.0
Epoch: 2 | Loss: 0.18138, Accuracy: 89.32% | Test loss: 0.61765, AUC: 0.8726657087875005, MCC: 0.3940537278837092, sum: 7042.0
Epoch: 3 | Loss: 0.24919, Accuracy: 90.79% | Test loss: 0.79896, AUC: 0.8621024539783995, MCC: 0.3896582797271626, sum: 5727.0
Epoch: 4 | Loss: 0.07897, Accuracy: 90.17% | Test loss: 0.81526, AUC: 0.8614779733187512, MCC: 0.38747794123876006, sum: 6235.0


KeyboardInterrupt: 

In [None]:
from sklearn import metrics
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

def print_used_memory():
    free, total = torch.cuda.mem_get_info(torch.device('cuda:0'))
    mem_used_MB = (total - free) / 1024 ** 2
    mem_total_MB = (total) / 1024 ** 2
    print(f'{mem_used_MB} MB / {mem_total_MB} MB')

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    print('Before test:')
    print_used_memory()
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            print('before prediction:')

            print_used_memory()

            output = model(input_ids, attention_mask=attention_mask)
            print('after prediction:')
            print_used_memory()
        
            logits = output.logits.flatten(1)
            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_logits = logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_logits))

            test_loss =  loss_fn(valid_flattened_logits, valid_flattened_labels)

            test_losses.append(test_loss.cpu().detach().numpy())
            # print(valid_flattened_logits)
            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
            del input_ids, attention_mask, labels, logits, valid_flattened_logits, valid_flattened_labels
            gc.collect()
            torch.cuda.empty_cache()
    
    print('after test')
    print_used_memory()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        print('before train prediction')
        print_used_memory()

        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits.flatten(1)
        labels = batch['labels'].to(device)
        flattened_labels = labels.flatten()

        valid_flattened_logits = logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]

        loss = loss_fn(valid_flattened_logits, valid_flattened_labels)
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

# 35M:
# Epoch: 3 | Loss: 0.26539, Accuracy: 87.91% | Test loss: 0.71823, AUC: 0.8429591457364806, MCC: 0.33372585153627116, sum: 7430.0

# 150M:
# Epoch: 2 | Loss: 0.47842, Accuracy: 89.44% | Test loss: 0.63896, AUC: 0.8584368023837566, MCC: 0.36990818473863774, sum: 6639.0



Before test:
3043.0 MB / 81153.75 MB
before prediction:
3043.0 MB / 81153.75 MB
after prediction:
31835.0 MB / 81153.75 MB
after test
31837.0 MB / 81153.75 MB


8M:

Epoch: 0 | Loss: 0.78285, Accuracy: 18.43% | Test loss: 1.03734, AUC: 0.42720166912931357, MCC: -0.054964089541398004, sum: 47960.0
Epoch: 1 | Loss: 1.00518, Accuracy: 89.54% | Test loss: 0.79164, AUC: 0.7681007298652949, MCC: 0.23497681155777314, sum: 5025.0
Epoch: 2 | Loss: 0.67339, Accuracy: 78.65% | Test loss: 0.75463, AUC: 0.7986967979922242, MCC: 0.23954103427478463, sum: 12999.0
Epoch: 3 | Loss: 0.61600, Accuracy: 86.49% | Test loss: 0.74335, AUC: 0.7968872367921364, MCC: 0.26001800489415483, sum: 7627.0
Epoch: 4 | Loss: 0.47289, Accuracy: 85.79% | Test loss: 0.75486, AUC: 0.7990373875320378, MCC: 0.2609895855410456, sum: 8165.0
Epoch: 5 | Loss: 0.43352, Accuracy: 80.06% | Test loss: 0.78084, AUC: 0.8002741989833781, MCC: 0.24479102586531848, sum: 12085.0
Epoch: 6 | Loss: 0.43338, Accuracy: 85.32% | Test loss: 0.84007, AUC: 0.7894465876826939, MCC: 0.2561339791448311, sum: 8452.0
Epoch: 7 | Loss: 0.28171, Accuracy: 84.58% | Test loss: 0.87809, AUC: 0.7905991596640114, MCC: 0.24990835735850425, sum: 8914.0
Epoch: 8 | Loss: 0.20200, Accuracy: 87.29% | Test loss: 0.98652, AUC: 0.7812768510780681, MCC: 0.2568167667372486, sum: 6988.0
Epoch: 9 | Loss: 0.23710, Accuracy: 88.13% | Test loss: 1.04308, AUC: 0.783673662642753, MCC: 0.2642501885152083, sum: 6431.0

8M transfer learning:
Epoch: 39 | Loss: 0.84134, Accuracy: 85.23% | Test loss: 0.72810, AUC: 0.7989206203533358, MCC: 0.25630946545449773, sum: 8860.0

# NaN error:
It started giving NaNs. I restarted a few times, rolled back, tried other things. Not sure what actually helped

# 650M, batch size = 8:
Epoch: 0 | Loss: 0.73659, Accuracy: 17.84% | Test loss: 1.04761, AUC: 0.4395135999006268, MCC: -0.053494933235651576, sum: 48365.0
Epoch: 1 | Loss: 0.37189, Accuracy: 86.20% | Test loss: 0.60297, AUC: 0.8741129887987007, MCC: 0.3633713411500189, sum: 9149.0
Epoch: 2 | Loss: 0.18138, Accuracy: 89.32% | Test loss: 0.61765, AUC: 0.8726657087875005, MCC: 0.3940537278837092, sum: 7042.0
Epoch: 3 | Loss: 0.24919, Accuracy: 90.79% | Test loss: 0.79896, AUC: 0.8621024539783995, MCC: 0.3896582797271626, sum: 5727.0
Epoch: 4 | Loss: 0.07897, Accuracy: 90.17% | Test loss: 0.81526, AUC: 0.8614779733187512, MCC: 0.38747794123876006, sum: 6235.0

In [4]:
from datasets import Dataset
import csv
import numpy as np

MAX_LENGTH = 1024

def get_dataset_with_distances(annotation_path, tokenizer, scaler, distances_path='/home/skrhakv/cryptic-nn/data/cryptobench/residue-distances'):

    sequences = []
    labels = []
    distances = []
    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            protein_id = row[0].lower() + row[1]
            sequence = row[4]
            # max sequence length of ESM2
            if len(sequence) > MAX_LENGTH: continue 

            indices = [int(residue[1:]) for residue in row[3].split(' ')]
            label = np.zeros(len(sequence))
            label[indices] = 1
            distance = np.load(f'{distances_path}/{protein_id}.npy')
            distance[distance == -1] = 0.5

            if len(distance) != len(sequence): 
                print(f'{protein_id} doesn\'t match. Skipping ...')
                break

            # scale the distance
            distance = scaler.transform(distance.reshape(-1, 1)).reshape(1, -1)[0]

            sequences.append(sequence)
            labels.append(label) # np.eye(NUMBER_OF_CLASSES)[label])
            distances.append(distance)
    train_tokenized = tokenizer(sequences) #, padding='max_length', truncation=True, max_length=MAX_LENGTH)# , max_length=MAX_LENGTH, padding=True, truncation=True)
    
    dataset = Dataset.from_dict(train_tokenized)
    dataset = dataset.add_column("labels", labels)
    dataset = dataset.add_column("distances", distances)
    
    return dataset

from sklearn.preprocessing import StandardScaler

def train_scaler(annotation_path, distances_path='/home/skrhakv/cryptic-nn/data/cryptobench/residue-distances'):
    distances = []

    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            protein_id = row[0].lower() + row[1]
            distance = np.load(f'{distances_path}/{protein_id}.npy')
            distance[distance == -1] = 0.5
            distances.append(distance)

    scaler = StandardScaler()
    scaler.fit(np.concatenate(distances).reshape(-1, 1))
    return scaler


In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding=True,truncation=True)

scaler = train_scaler('/home/skrhakv/cryptic-nn/src/fine-tuning/train.txt')
train_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/src/fine-tuning/train.txt', tokenizer, scaler)
val_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/src/fine-tuning/val.txt', tokenizer, scaler)
from torch.utils.data import DataLoader
data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=8, collate_fn=data_collator) # shuffle=True, 
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=data_collator)

val_iterator = iter(val_dataloader)
x = next(val_iterator)
print(x['input_ids'].shape, x['attention_mask'].shape, x['labels'].shape, x['distances'].shape)

torch.Size([193, 833]) torch.Size([193, 833]) torch.Size([193, 833]) torch.Size([193, 833])


In [6]:
from sklearn import metrics
from torch import nn

model = EsmForTokenClassificationCustom.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class RMSLELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))


device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
cbs_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
distances_loss_fn = RMSLELoss() 

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output1, output2 = model(input_ids, attention_mask=attention_mask)
            
            cbs_logits = output1.logits.flatten(1)
            distance_logits = output2.logits.flatten(1)

            labels = batch['labels'].to(device)
            distances = batch['distances'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
            valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]
            valid_flattened_distances = distances.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_cbs_logits))

            cbs_test_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
            distances_test_loss =  distances_loss_fn(valid_flattened_distance_logits, valid_flattened_distances)

            test_loss = cbs_test_loss + distances_test_loss
            test_losses.append(test_loss.cpu().detach().numpy())

            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)

            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_cbs_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
    
    torch.cuda.empty_cache()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        output1, output2 = model(input_ids, attention_mask=attention_mask)
        
        cbs_logits = output1.logits.flatten(1)
        distance_logits = output2.logits.flatten(1)

        labels = batch['labels'].to(device)
        distances = batch['distances'].to(device)

        flattened_labels = labels.flatten()

        valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
        valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]
        valid_flattened_distances = distances.flatten()[flattened_labels != -100]

        predictions = torch.round(torch.sigmoid(valid_flattened_cbs_logits))

        cbs_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
        distances_loss =  distances_loss_fn(valid_flattened_distance_logits, valid_flattened_distances)

        loss = cbs_loss + 4 * distances_loss
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
        
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f} - CBS: {cbs_test_loss}, distance: {distances_test_loss}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight', 'distance_regressor.bias', 'distance_regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 0 | Loss: 2.62910, Accuracy: 22.17% | Test loss: 1.48222 - CBS: 1.028655767440796, distance: 0.4535630941390991, AUC: 0.45355163887608463, MCC: -0.04503063367232409, sum: 23.0
Epoch: 1 | Loss: 2.75044, Accuracy: 93.80% | Test loss: 1.30920 - CBS: 0.901236891746521, distance: 0.4079667627811432, AUC: 0.6626167473062828, MCC: 0.1263499090393085, sum: 104.0
Epoch: 2 | Loss: 2.84905, Accuracy: 90.45% | Test loss: 1.22923 - CBS: 0.8445392847061157, distance: 0.38468629121780396, AUC: 0.729914440872734, MCC: 0.2096951501950845, sum: 149.0
Epoch: 3 | Loss: 2.70347, Accuracy: 88.89% | Test loss: 1.17809 - CBS: 0.7982466220855713, distance: 0.37984153628349304, AUC: 0.7637647883872939, MCC: 0.21954990130813484, sum: 156.0
Epoch: 4 | Loss: 2.60240, Accuracy: 87.04% | Test loss: 1.15696 - CBS: 0.7784027457237244, distance: 0.3785569965839386, AUC: 0.7701274244147808, MCC: 0.22228374455587616, sum: 134.0
Epoch: 5 | Loss: 2.72189, Accuracy: 87.56% | Test loss: 1.14809 - CBS: 0.77476602792739

Tryout: 
1. focal loss and a contrastive triplet center loss (https://academic.oup.com/bib/article/25/1/bbad488/7505238)