In [None]:
import torch
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import TokenClassifierOutput

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')
LOSS = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
DROPOUT = 0.25
OUTPUT_SIZE = 1

class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    def __init__(self, config):

        # 
        # config.use_cache = False
        super().__init__(config)
        print(config)
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(DROPOUT)
        self.classifier = nn.Linear(config.hidden_size, OUTPUT_SIZE)
        self.distance_regressor = nn.Linear(config.hidden_size, OUTPUT_SIZE)
        self.plDDT_regressor = nn.Linear(config.hidden_size, OUTPUT_SIZE)
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # print(sequence_output.shape)

        sequence_output = self.dropout(sequence_output)
        cbs_logits = self.classifier(sequence_output)
        distance_logits = self.distance_regressor(sequence_output)
        plddt_logits = self.plDDT_regressor(sequence_output)

        return TokenClassifierOutput(
            logits=cbs_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )# , TokenClassifierOutput(
        #     logits=distance_logits,
        #     hidden_states=outputs.hidden_states,
        #     attentions=outputs.attentions,
        # ), TokenClassifierOutput(
        #     logits=plddt_logits,
        #     hidden_states=outputs.hidden_states,
        #     attentions=outputs.attentions,
        # )

In [2]:
from dataclasses import dataclass
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch
        if "distances" in features[0].keys():
            label_names = ['labels', 'distances']
        else:
            label_names = ['labels']

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = {label_name: [feature[label_name] for feature in features] for label_name in label_names}

        no_labels_features = [{k: v for k, v in feature.items() if k not in label_names} for feature in features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        for label_name in label_names:
            if padding_side == "right":
                batch[label_name] = [
                    # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
                    # changed to pad the special tokens at the beginning and end of the sequence
                    [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels[label_name]
                ]
            else:
                batch[label_name] = [
                    [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels[label_name]
                ]
    
            batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
        return batch


2025-04-02 15:32:08.443140: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-02 15:32:08.481606: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from transformers import AutoTokenizer

# MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'
# MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
# MODEL_NAME = 'facebook/esm2_t30_150M_UR50D'
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# TODO: try torch_dtype=torch.bfloat16
model = EsmForTokenClassificationCustom.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight', 'distance_regressor.bias', 'distance_regressor.weight', 'plDDT_regressor.bias', 'plDDT_regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmConfig {
  "_name_or_path": "facebook/esm2_t6_8M_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 320,
  "initializer_range": 0.02,
  "intermediate_size": 1280,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.39.2",
  "use_cache": false,
  "vocab_list": null,
  "vocab_size": 33
}



In [4]:
from datasets import Dataset
import csv
import numpy as np

MAX_LENGTH = 1024
def get_dataset(annotation_path, tokenizer):

    sequences = []
    labels = []

    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            sequence = row[4]
            # max sequence length of ESM2
            if len(sequence) > MAX_LENGTH: continue 

            indices = [int(residue[1:]) for residue in row[3].split(' ')]
            label = np.zeros(len(sequence))
            label[indices] = 1
            sequences.append(sequence)
            labels.append(label) # np.eye(NUMBER_OF_CLASSES)[label])

            assert len(sequence) == len(label)

    train_tokenized = tokenizer(sequences, max_length=MAX_LENGTH, padding=True, truncation=True)
    
    dataset = Dataset.from_dict(train_tokenized)
    dataset = dataset.add_column("labels", labels)

    return dataset


In [5]:
from torch.utils.data import DataLoader

train_dataset = get_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer)
val_dataset = get_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, shuffle=True, collate_fn=data_collator)


In [None]:
from sklearn import metrics
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

def print_used_memory():
    free, total = torch.cuda.mem_get_info(torch.device('cuda:0'))
    mem_used_MB = (total - free) / 1024 ** 2
    mem_total_MB = (total) / 1024 ** 2
    print(f'{mem_used_MB} MB / {mem_total_MB} MB')

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

test_losses = []
train_losses = []

# with torch.autocast(device_type='cuda'):
for epoch in range(EPOCHS):
    model.eval()
    # print('Before test:')
    # print_used_memory()
    # VALIDATION LOOP
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # print('before prediction:')
            # print_used_memory()

            output = model(input_ids, attention_mask=attention_mask)
            # print('after prediction:')
            # print_used_memory()
        
            logits = output.logits.flatten(1)
            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_logits = logits.flatten()[flattened_labels != -100].float()
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_logits))

            test_loss = loss_fn(valid_flattened_logits, valid_flattened_labels)

            test_losses.append(test_loss.cpu().detach().numpy())
            # print(valid_flattened_logits)
            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                    y_pred=predictions)
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
            del input_ids, attention_mask, labels, logits, valid_flattened_logits, valid_flattened_labels
            gc.collect()
            torch.cuda.empty_cache()
    
    # print('after test')
    # print_used_memory()

    model.train()

    batch_losses = []

    # TRAIN

    # TODO: the following row causes the memory explosion
    # with torch.inference_mode():

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        # print('before train prediction')
        # print_used_memory()

        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits.flatten(1)
        labels = batch['labels'].to(device)
        flattened_labels = labels.flatten()
        # print('after train prediction')
        # print_used_memory()

        valid_flattened_logits = logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]

        loss = loss_fn(valid_flattened_logits, valid_flattened_labels)
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
        
        torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

# 35M:
# Epoch: 3 | Loss: 0.26539, Accuracy: 87.91% | Test loss: 0.71823, AUC: 0.8429591457364806, MCC: 0.33372585153627116, sum: 7430.0

# 150M:
# Epoch: 2 | Loss: 0.47842, Accuracy: 89.44% | Test loss: 0.63896, AUC: 0.8584368023837566, MCC: 0.36990818473863774, sum: 6639.0



Epoch: 0 | Loss: 0.27190, Accuracy: 92.94% | Test loss: 1.34428, AUC: 0.7471983781967955, MCC: 0.2704698714087362, sum: 2488.0
Epoch: 1 | Loss: 0.27539, Accuracy: 89.50% | Test loss: 1.23362, AUC: 0.7643900761301369, MCC: 0.2460163020180396, sum: 5081.0
Epoch: 2 | Loss: 0.19173, Accuracy: 88.98% | Test loss: 1.31812, AUC: 0.7681208232947736, MCC: 0.2472843175379563, sum: 5488.0
Epoch: 3 | Loss: 0.10782, Accuracy: 91.13% | Test loss: 1.52325, AUC: 0.7705293194415301, MCC: 0.2629112046245141, sum: 3965.0
Epoch: 4 | Loss: 0.08166, Accuracy: 91.15% | Test loss: 1.69592, AUC: 0.7622030622471703, MCC: 0.2549547849067751, sum: 3871.0


In [None]:
from sklearn import metrics
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

def print_used_memory():
    free, total = torch.cuda.mem_get_info(torch.device('cuda:0'))
    mem_used_MB = (total - free) / 1024 ** 2
    mem_total_MB = (total) / 1024 ** 2
    print(f'{mem_used_MB} MB / {mem_total_MB} MB')

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    print('Before test:')
    print_used_memory()
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            print('before prediction:')

            print_used_memory()

            output = model(input_ids, attention_mask=attention_mask)
            print('after prediction:')
            print_used_memory()
        
            logits = output.logits.flatten(1)
            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_logits = logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_logits))

            test_loss =  loss_fn(valid_flattened_logits, valid_flattened_labels)

            test_losses.append(test_loss.cpu().detach().numpy())
            # print(valid_flattened_logits)
            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
            del input_ids, attention_mask, labels, logits, valid_flattened_logits, valid_flattened_labels
            gc.collect()
            torch.cuda.empty_cache()
    
    print('after test')
    print_used_memory()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        print('before train prediction')
        print_used_memory()

        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits.flatten(1)
        labels = batch['labels'].to(device)
        flattened_labels = labels.flatten()

        valid_flattened_logits = logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]

        loss = loss_fn(valid_flattened_logits, valid_flattened_labels)
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

# 35M:
# Epoch: 3 | Loss: 0.26539, Accuracy: 87.91% | Test loss: 0.71823, AUC: 0.8429591457364806, MCC: 0.33372585153627116, sum: 7430.0

# 150M:
# Epoch: 2 | Loss: 0.47842, Accuracy: 89.44% | Test loss: 0.63896, AUC: 0.8584368023837566, MCC: 0.36990818473863774, sum: 6639.0



Before test:
3043.0 MB / 81153.75 MB
before prediction:
3043.0 MB / 81153.75 MB
after prediction:
31835.0 MB / 81153.75 MB
after test
31837.0 MB / 81153.75 MB


8M:

Epoch: 0 | Loss: 0.78285, Accuracy: 18.43% | Test loss: 1.03734, AUC: 0.42720166912931357, MCC: -0.054964089541398004, sum: 47960.0
Epoch: 1 | Loss: 1.00518, Accuracy: 89.54% | Test loss: 0.79164, AUC: 0.7681007298652949, MCC: 0.23497681155777314, sum: 5025.0
Epoch: 2 | Loss: 0.67339, Accuracy: 78.65% | Test loss: 0.75463, AUC: 0.7986967979922242, MCC: 0.23954103427478463, sum: 12999.0
Epoch: 3 | Loss: 0.61600, Accuracy: 86.49% | Test loss: 0.74335, AUC: 0.7968872367921364, MCC: 0.26001800489415483, sum: 7627.0
Epoch: 4 | Loss: 0.47289, Accuracy: 85.79% | Test loss: 0.75486, AUC: 0.7990373875320378, MCC: 0.2609895855410456, sum: 8165.0
Epoch: 5 | Loss: 0.43352, Accuracy: 80.06% | Test loss: 0.78084, AUC: 0.8002741989833781, MCC: 0.24479102586531848, sum: 12085.0
Epoch: 6 | Loss: 0.43338, Accuracy: 85.32% | Test loss: 0.84007, AUC: 0.7894465876826939, MCC: 0.2561339791448311, sum: 8452.0
Epoch: 7 | Loss: 0.28171, Accuracy: 84.58% | Test loss: 0.87809, AUC: 0.7905991596640114, MCC: 0.24990835735850425, sum: 8914.0
Epoch: 8 | Loss: 0.20200, Accuracy: 87.29% | Test loss: 0.98652, AUC: 0.7812768510780681, MCC: 0.2568167667372486, sum: 6988.0
Epoch: 9 | Loss: 0.23710, Accuracy: 88.13% | Test loss: 1.04308, AUC: 0.783673662642753, MCC: 0.2642501885152083, sum: 6431.0

8M transfer learning:
Epoch: 39 | Loss: 0.84134, Accuracy: 85.23% | Test loss: 0.72810, AUC: 0.7989206203533358, MCC: 0.25630946545449773, sum: 8860.0

# NaN error:
It started giving NaNs. I restarted a few times, rolled back, tried other things. Not sure what actually helped

# 650M, batch size = 8:
Epoch: 0 | Loss: 0.73659, Accuracy: 17.84% | Test loss: 1.04761, AUC: 0.4395135999006268, MCC: -0.053494933235651576, sum: 48365.0
Epoch: 1 | Loss: 0.37189, Accuracy: 86.20% | Test loss: 0.60297, AUC: 0.8741129887987007, MCC: 0.3633713411500189, sum: 9149.0
Epoch: 2 | Loss: 0.18138, Accuracy: 89.32% | Test loss: 0.61765, AUC: 0.8726657087875005, MCC: 0.3940537278837092, sum: 7042.0
Epoch: 3 | Loss: 0.24919, Accuracy: 90.79% | Test loss: 0.79896, AUC: 0.8621024539783995, MCC: 0.3896582797271626, sum: 5727.0
Epoch: 4 | Loss: 0.07897, Accuracy: 90.17% | Test loss: 0.81526, AUC: 0.8614779733187512, MCC: 0.38747794123876006, sum: 6235.0

(to compare with 650M transfer learning:)
(Epoch: 19 | Loss: 0.57333, Accuracy: 87.17% | Test loss: 0.61883, AUC: 0.8642494368067841, MCC: 0.3501214427034594, sum: 8540.0)
so 0.01 AUC improvement, not great not terrible

# 650M MULTITASK learning:
Epoch: 0 | Loss: 0.64907, Accuracy: 27.87% | Test loss: 1.51112 - CBS: 1.3719383478164673, distance: 0.13918635249137878, AUC: 0.43274620440958667, MCC: -0.0542650498061712, sum: 111.0
Epoch: 1 | Loss: 0.48987, Accuracy: 83.13% | Test loss: 0.87216 - CBS: 0.8713946342468262, distance: 0.0007648273603990674, AUC: 0.8760669807589696, MCC: 0.3457642510325305, sum: 80.0
Epoch: 2 | Loss: 0.10602, Accuracy: 83.81% | Test loss: 0.99556 - CBS: 0.9947194457054138, distance: 0.0008437958895228803, AUC: 0.8668396447010817, MCC: 0.33656286186003864, sum: 36.0
Epoch: 3 | Loss: 0.45065, Accuracy: 85.19% | Test loss: 1.37749 - CBS: 1.376583456993103, distance: 0.0009044440812431276, AUC: 0.855343432123152, MCC: 0.3286448894311547, sum: 88.0
Epoch: 4 | Loss: 0.09049, Accuracy: 86.44% | Test loss: 1.36612 - CBS: 1.3654472827911377, distance: 0.0006769609753973782, AUC: 0.8566140648513804, MCC: 0.3439374629449201, sum: 26.0
Epoch: 5 | Loss: 0.11011, Accuracy: 88.46% | Test loss: 1.60821 - CBS: 1.6076503992080688, distance: 0.0005574710085056722, AUC: 0.8539008246932507, MCC: 0.3623427913775662, sum: 22.0
Epoch: 6 | Loss: 0.12407, Accuracy: 88.35% | Test loss: 1.73687 - CBS: 1.73642098903656, distance: 0.000453823187854141, AUC: 0.8575289308249205, MCC: 0.3587223828788271, sum: 73.0
Epoch: 7 | Loss: 0.00377, Accuracy: 90.80% | Test loss: 2.10933 - CBS: 2.1088972091674805, distance: 0.00043249232112430036, AUC: 0.8599045856585643, MCC: 0.3764254979736856, sum: 19.0
Epoch: 8 | Loss: 0.09617, Accuracy: 90.60% | Test loss: 2.00999 - CBS: 2.009589195251465, distance: 0.00039868077146820724, AUC: 0.8536332037559458, MCC: 0.37197599850317287, sum: 53.0
Epoch: 9 | Loss: 0.03469, Accuracy: 89.00% | Test loss: 1.89704 - CBS: 1.8966468572616577, distance: 0.0003907561185769737, AUC: 0.8506434369981348, MCC: 0.3449904319459605, sum: 24.0
# 650M boosted cbs-loss:
Epoch: 0 | Loss: 2.98668, Accuracy: 50.06% | Test loss: 1.27300 - CBS: 1.0052306652069092, distance: 0.2677696645259857, AUC: 0.4612774687124019, MCC: -0.023694198792789625, sum: 27406.0
Epoch: 1 | Loss: 1.02846, Accuracy: 89.12% | Test loss: 0.58897 - CBS: 0.5878161787986755, distance: 0.0011529671028256416, AUC: 0.8777632042771085, MCC: 0.4059588496168788, sum: 7218.0
Epoch: 2 | Loss: 0.55701, Accuracy: 87.81% | Test loss: 0.64764 - CBS: 0.6457691788673401, distance: 0.0018734950572252274, AUC: 0.8685130445474384, MCC: 0.3790404140190229, sum: 7928.0

In [None]:
from datasets import Dataset
import csv
import numpy as np

MAX_LENGTH = 1024

def get_dataset_with_distances(annotation_path, tokenizer, scaler, distances_path='/home/skrhakv/cryptic-nn/data/cryptobench/residue-distances', uniprot_ids=False):

    sequences = []
    labels = []
    distances = []
    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            if not uniprot_ids:
                protein_id = row[0].lower() + row[1]
            else:
                protein_id = row[0]
            sequence = row[4]
            # max sequence length of ESM2
            if len(sequence) > MAX_LENGTH: continue 
            
            indices = [int(residue[1:]) for residue in row[3].split(' ')]
            label = np.zeros(len(sequence))
            label[indices] = 1
            distance = np.load(f'{distances_path}/{protein_id}.npy')
            distance[distance == -1] = 0.5
            distance = np.clip(distance, 0, 10)

            if len(distance) != len(sequence): 
                print(f'{protein_id} doesn\'t match. Skipping ...')
                break

            # scale the distance
            distance = scaler.transform(distance.reshape(-1, 1)).reshape(1, -1)[0]

            sequences.append(sequence)
            labels.append(label) # np.eye(NUMBER_OF_CLASSES)[label])
            distances.append(distance)
    train_tokenized = tokenizer(sequences) #, padding='max_length', truncation=True, max_length=MAX_LENGTH)# , max_length=MAX_LENGTH, padding=True, truncation=True)
    
    dataset = Dataset.from_dict(train_tokenized)
    dataset = dataset.add_column("labels", labels)
    dataset = dataset.add_column("distances", distances)
    
    return dataset

from sklearn.preprocessing import StandardScaler, MinMaxScaler

def train_scaler(annotation_path, distances_path='/home/skrhakv/cryptic-nn/data/cryptobench/residue-distances', uniprot_ids=False):
    distances = []

    with open(annotation_path) as f:
        reader = csv.reader(f, delimiter=";")

        for row in reader:
            if not uniprot_ids:
                protein_id = row[0].lower() + row[1]
            else:
                protein_id = row[0]
            distance = np.load(f'{distances_path}/{protein_id}.npy')
            distance[distance == -1] = 0.5
            distances.append(distance)

    scaler = MinMaxScaler()
    scaler.fit(np.concatenate(distances).reshape(-1, 1))
    return scaler


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding=True,truncation=True)

scaler = train_scaler('/home/skrhakv/cryptic-nn/src/fine-tuning/train.txt')
train_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/src/fine-tuning/train.txt', tokenizer, scaler)
val_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/src/fine-tuning/val.txt', tokenizer, scaler)

# scaler = train_scaler('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt')
# train_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer, scaler)
# val_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer, scaler)

from torch.utils.data import DataLoader
data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=data_collator)

# val_iterator = iter(val_dataloader)
# x = next(val_iterator)
# print(x['input_ids'].shape, x['attention_mask'].shape, x['labels'].shape, x['distances'].shape)


In [None]:
class RMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return torch.sqrt(self.mse(pred + 1, actual + 1))

class MSLELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return self.mse(torch.log(pred + 1), torch.log(actual + 1))

class RMSLELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))


In [None]:
from sklearn import metrics
from torch import nn

model = EsmForTokenClassificationCustom.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 10

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
cbs_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
distances_loss_fn = nn.MSELoss() 

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output1, output2 = model(input_ids, attention_mask=attention_mask)
            
            cbs_logits = output1.logits.flatten(1)
            distance_logits = output2.logits.flatten(1)

            labels = batch['labels'].to(device)
            distances = batch['distances'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
            valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]
            valid_flattened_distances = distances.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_cbs_logits))

            cbs_test_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
            distances_test_loss =  distances_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)

            test_loss = cbs_test_loss + distances_test_loss
            test_losses.append(test_loss.cpu().detach().numpy())

            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)

            # print(torch.sum(torch.isnan(torch.sigmoid(valid_flattened_cbs_logits))))
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_cbs_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
    
    torch.cuda.empty_cache()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        output1, output2 = model(input_ids, attention_mask=attention_mask)
        
        cbs_logits = output1.logits.flatten(1)
        distance_logits = output2.logits.flatten(1)

        labels = batch['labels'].to(device)
        distances = batch['distances'].to(device)

        flattened_labels = labels.flatten()

        valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
        valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]
        valid_flattened_distances = distances.flatten()[flattened_labels != -100]

        cbs_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
        distances_loss =  distances_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)
# different loss, sigmoid
        loss = 4 * cbs_loss +  distances_loss
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
        
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f} - CBS: {cbs_test_loss}, distance: {distances_test_loss}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight', 'distance_regressor.bias', 'distance_regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 0 | Loss: 2.98668, Accuracy: 50.06% | Test loss: 1.27300 - CBS: 1.0052306652069092, distance: 0.2677696645259857, AUC: 0.4612774687124019, MCC: -0.023694198792789625, sum: 27406.0
Epoch: 1 | Loss: 1.02846, Accuracy: 89.12% | Test loss: 0.58897 - CBS: 0.5878161787986755, distance: 0.0011529671028256416, AUC: 0.8777632042771085, MCC: 0.4059588496168788, sum: 7218.0
Epoch: 2 | Loss: 0.55701, Accuracy: 87.81% | Test loss: 0.64764 - CBS: 0.6457691788673401, distance: 0.0018734950572252274, AUC: 0.8685130445474384, MCC: 0.3790404140190229, sum: 7928.0


KeyboardInterrupt: 

Tryout: 
1. focal loss and a contrastive triplet center loss (https://academic.oup.com/bib/article/25/1/bbad488/7505238)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding=True,truncation=True)

scaler = train_scaler('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', distances_path='/home/skrhakv/cryptic-nn/data/ligysis/plDDT', uniprot_ids=True)
train_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', tokenizer, scaler, distances_path='/home/skrhakv/cryptic-nn/data/ligysis/plDDT', uniprot_ids=True)
val_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer, scaler)

from torch.utils.data import DataLoader
data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator) 
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=data_collator)



In [None]:
from sklearn import metrics
from torch import nn

model = EsmForTokenClassificationCustom.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=0.0001)
EPOCHS = 3

def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

class_weights = torch.tensor([0.5303, 8.7481], device='cuda:0')

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
cbs_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
plDDT_loss_fn = nn.MSELoss() 

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output1, _, output3 = model(input_ids, attention_mask=attention_mask)
            
            cbs_logits = output1.logits.flatten(1)
            distance_logits = output3.logits.flatten(1)

            labels = batch['labels'].to(device)
            distances = batch['distances'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
            valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]
            valid_flattened_distances = distances.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_cbs_logits))

            cbs_test_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
            # distances_test_loss =  distances_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)

            test_loss = cbs_test_loss # + distances_test_loss
            test_losses.append(test_loss.cpu().detach().numpy())

            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)

            # print(torch.sum(torch.isnan(torch.sigmoid(valid_flattened_cbs_logits))))
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_cbs_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
    
    torch.cuda.empty_cache()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        output1, _, output3 = model(input_ids, attention_mask=attention_mask)
        
        cbs_logits = output1.logits.flatten(1)
        distance_logits = output3.logits.flatten(1)

        labels = batch['labels'].to(device)
        distances = batch['distances'].to(device)

        flattened_labels = labels.flatten()

        valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
        valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]
        valid_flattened_distances = distances.flatten()[flattened_labels != -100]

        cbs_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
        distances_loss =  plDDT_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)
        
        # different loss, sigmoid
        loss = cbs_loss +  distances_loss
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
        
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f} - CBS: {cbs_test_loss}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight', 'distance_regressor.bias', 'distance_regressor.weight', 'plDDT_regressor.bias', 'plDDT_regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 0 | Loss: 0.83850, Accuracy: 7.99% | Test loss: 1.10899 - CBS: 1.1089880466461182, AUC: 0.40072266332934214, MCC: -0.06774546446956191, sum: 53778.0
Epoch: 1 | Loss: 0.37627, Accuracy: 84.18% | Test loss: 0.65449 - CBS: 0.654488742351532, AUC: 0.8773803128224543, MCC: 0.35904512468287153, sum: 10427.0
Epoch: 2 | Loss: 0.22312, Accuracy: 80.66% | Test loss: 0.67400 - CBS: 0.6740042567253113, AUC: 0.8757899214885686, MCC: 0.33932737993656564, sum: 12702.0


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding=True,truncation=True)

scaler = train_scaler('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt')
train_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer, scaler)
val_dataset = get_dataset_with_distances('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer, scaler)

from torch.utils.data import DataLoader
data_collator = DataCollatorForTokenClassificationESM(tokenizer) 
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=data_collator)



In [None]:

EPOCHS = 10

# TODO: Try multiply the class_weights[1] * 2
# BCEWithLogitsLoss - sigmoid is already built-in!
cbs_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
distances_loss_fn = nn.MSELoss() 

test_losses = []
train_losses = []

for epoch in range(EPOCHS):
    model.eval()
    
    # VALIDATION LOOP
    with torch.inference_mode():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output1, output2, _ = model(input_ids, attention_mask=attention_mask)
            
            cbs_logits = output1.logits.flatten(1)
            distance_logits = output2.logits.flatten(1)

            labels = batch['labels'].to(device)
            distances = batch['distances'].to(device)

            flattened_labels = labels.flatten()

            valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
            valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
            valid_flattened_labels = labels.flatten()[flattened_labels != -100]
            valid_flattened_distances = distances.flatten()[flattened_labels != -100]

            predictions = torch.round(torch.sigmoid(valid_flattened_cbs_logits))

            cbs_test_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
            distances_test_loss =  distances_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)

            test_loss = cbs_test_loss + distances_test_loss
            test_losses.append(test_loss.cpu().detach().numpy())

            # compute metrics on test dataset
            test_acc = accuracy_fn(y_true=valid_flattened_labels,
                                   y_pred=predictions)

            # print(torch.sum(torch.isnan(torch.sigmoid(valid_flattened_cbs_logits))))
            fpr, tpr, thresholds = metrics.roc_curve(valid_flattened_labels.cpu().numpy(), torch.sigmoid(valid_flattened_cbs_logits).cpu().numpy())
            roc_auc = metrics.auc(fpr, tpr)

            mcc = metrics.matthews_corrcoef(valid_flattened_labels.cpu().numpy(), predictions.cpu().numpy())
    
    torch.cuda.empty_cache()

    model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        # padded_labels += batch['labels'].tolist()

        output1, output2, _ = model(input_ids, attention_mask=attention_mask)
        
        cbs_logits = output1.logits.flatten(1)
        distance_logits = output2.logits.flatten(1)

        labels = batch['labels'].to(device)
        distances = batch['distances'].to(device)

        flattened_labels = labels.flatten()

        valid_flattened_cbs_logits = cbs_logits.flatten()[flattened_labels != -100]
        valid_flattened_distance_logits = distance_logits.flatten()[flattened_labels != -100]
        valid_flattened_labels = labels.flatten()[flattened_labels != -100]
        valid_flattened_distances = distances.flatten()[flattened_labels != -100]

        cbs_loss =  cbs_loss_fn(valid_flattened_cbs_logits, valid_flattened_labels)
        # distances_loss =  distances_loss_fn(torch.sigmoid(valid_flattened_distance_logits), valid_flattened_distances)
# different loss, sigmoid
        loss = cbs_loss # +  distances_loss
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        batch_losses.append(loss.cpu().detach().numpy())
        
    torch.cuda.empty_cache()

    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f} - CBS: {cbs_test_loss}, distance: {distances_test_loss}, AUC: {roc_auc}, MCC: {mcc}, sum: {sum(predictions)}")

Epoch: 0 | Loss: 0.51899, Accuracy: 86.57% | Test loss: 0.94554 - CBS: 0.7093595862388611, distance: 0.23618388175964355, AUC: 0.8648569544953557, MCC: 0.3733514022556253, sum: 8817.0
Epoch: 1 | Loss: 0.47167, Accuracy: 89.55% | Test loss: 0.80418 - CBS: 0.5722964406013489, distance: 0.23188810050487518, AUC: 0.8841012865297871, MCC: 0.4222520065360723, sum: 7076.0
Epoch: 2 | Loss: 0.05917, Accuracy: 88.98% | Test loss: 0.93403 - CBS: 0.6853230595588684, distance: 0.24871017038822174, AUC: 0.8669081084432939, MCC: 0.39911661201568155, sum: 7247.0
Epoch: 3 | Loss: 0.31031, Accuracy: 91.95% | Test loss: 1.11387 - CBS: 0.8524396419525146, distance: 0.2614312171936035, AUC: 0.8680197789784566, MCC: 0.42306381062307064, sum: 5002.0
Epoch: 4 | Loss: 0.04782, Accuracy: 92.16% | Test loss: 1.25636 - CBS: 0.9804365038871765, distance: 0.2759261131286621, AUC: 0.8705917051782379, MCC: 0.43595482054907025, sum: 4965.0
Epoch: 5 | Loss: 0.09336, Accuracy: 91.67% | Test loss: 1.37343 - CBS: 1.087983

KeyboardInterrupt: 