# Introducing a new task to MAMMAL

Welcome and thanks for joining our journey! 


This tutorial is a walkthrough guide for creating a custom task in the MAMMAL framework in order to fine-tune our base model [**biomed.omics.bl.sm.ma-ted-458m**](https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m). \
As a case study we will use our [protein solubility prediction task](https://github.com/BiomedSciAI/biomed-multi-alignment/tree/main?tab=readme-ov-file#protein-solubility-prediction). We will break down the main components so you can create or modify them for your own task and data.

## Installation
First, make sure you have the mammal package installed. Please follow the [installation guide](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/main/README.md#installation) in the main README file.

For this notebook, we can simply run the next code block (you might need to restart the session and re-run):

In [None]:
!pip install biomed-multi-alignment

## The Big Picture
In order to implement a downstream task for the MAMMAL framework, one should implement a subclass of `MammalTask` ([source](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/e56a03e0e9f69e42f919a96def739b78e50a47e5/mammal/task.py#L15)). A `MammalTask` consists of three main components:
1. Data module - a Lightning data module (`LightningDataModule`) where we load and process the task's data.
2. `data_preprocessing()` method - responsible for formatting the input prompt into the model's expected format.
3. `process_model_output()` method - takes the raw output of the model and translates it into a human-readable answer.

### 1. Data Module
We need to create a custom Lightning data module in order to load and prepare the input data for the model. \
In the data module we will need to:
1. Load the raw dataset from its source (in this example we download it from Zenodo) - done in `load_datasets()`.
2. Process the raw data to match the model's prompt structure - done in `data_preprocessing()`, which we cover in the next section.
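
As a minimal, pandas-free sketch of what `load_datasets()` reads, the snippet below pairs line *i* of a `{split}_src` file (a protein sequence) with line *i* of the matching `{split}_tgt` file (its 0/1 label). The one-item-per-line layout is an assumption inferred from the `pd.read_csv(..., names=[...])` calls in `load_datasets()`; the helper `read_split` and the toy sequences are hypothetical.

```python
import os
import tempfile


def read_split(raw_data_path: str, set_name: str) -> list[dict]:
    """Pair line i of `{set_name}_src` (protein) with line i of `{set_name}_tgt` (label)."""
    with open(os.path.join(raw_data_path, f"{set_name}_src")) as f_src:
        proteins = [line.strip() for line in f_src if line.strip()]
    with open(os.path.join(raw_data_path, f"{set_name}_tgt")) as f_tgt:
        labels = [int(line) for line in f_tgt if line.strip()]
    if len(proteins) != len(labels):
        raise ValueError(f"{set_name}: {len(proteins)} sequences vs {len(labels)} labels")
    # Same flat "data.*" keys as the tutorial's sample dicts
    return [
        {"data.sample_id": i, "data.protein": protein, "data.label": label}
        for i, (protein, label) in enumerate(zip(proteins, labels))
    ]


# Toy files standing in for the real DSOL download
with tempfile.TemporaryDirectory() as tmp_dir:
    with open(os.path.join(tmp_dir, "train_src"), "w") as f:
        f.write("MKTAYIAK\nGSHMLE\n")
    with open(os.path.join(tmp_dir, "train_tgt"), "w") as f:
        f.write("1\n0\n")
    train_samples = read_split(tmp_dir, "train")

print(train_samples[0])
```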

In [None]:
import os
import wget
import shutil
import pandas as pd

from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.datasets.dataset_default import DatasetDefault
from fuse.data.pipelines.pipeline_default import PipelineDefault


_SOLUBILITY_URL = "https://zenodo.org/api/records/1162886/files-archive"


def load_datasets(data_path: str) -> dict[str, DatasetDefault]:
    """
    Automatically downloads the data (if needed) and creates a dataset iterator for "train", "val" and "test".
    paper: https://academic.oup.com/bioinformatics/article/34/15/2605/4938490
    Data retrieved from: https://zenodo.org/records/1162886
    The benchmark requires classifying protein sequences into binary labels - Soluble or Insoluble (1 or 0).
    :param data_path: path to a directory to store the raw data
    :return: dictionary that maps fold name "train", "val" and "test" to a dataset iterator
    """

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    raw_data_path = os.path.join(data_path, "sameerkhurana10-DSOL_rv0.2-20562ad/data")
    if not os.path.exists(raw_data_path):
        wget.download(_SOLUBILITY_URL, data_path)
        file_path = os.path.join(data_path, "1162886.zip")
        shutil.unpack_archive(file_path, extract_dir=data_path)
        inner_file_path = os.path.join(
            data_path, "sameerkhurana10", "DSOL_rv0.2-v0.3.zip"
        )
        shutil.unpack_archive(inner_file_path, extract_dir=data_path)
        assert os.path.exists(
            raw_data_path
        ), f"Error: download complete but {raw_data_path} doesn't exist"

    # read files
    df_dict = {}
    for set_name in ["train", "val", "test"]:
        input_df = pd.read_csv(
            os.path.join(raw_data_path, f"{set_name}_src"), names=["data.protein"]
        )
        labels_df = pd.read_csv(
            os.path.join(raw_data_path, f"{set_name}_tgt"), names=["data.label"]
        )
        df_dict[set_name] = (input_df, labels_df)

    ds_dict = {}
    for set_name in ["train", "val", "test"]:
        input_df, labels_df = df_dict[set_name]
        size = len(labels_df)
        print(f"{set_name} set size is {size}")
        dynamic_pipeline = PipelineDefault(
            "solubility",
            [
                (OpReadDataframe(input_df, key_column=None), dict()),
                (OpReadDataframe(labels_df, key_column=None), dict()),
            ],
        )

        ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline)
        ds.create()
        ds_dict[set_name] = ds

    return ds_dict

In [None]:
# Load dataset - train, validation and test splits
ds_dict = load_datasets("./data")

# Retrieve and visualize a single sample
sample_dict = ds_dict["train"][0]

print("Visualize sample in a tree-fashion")
sample_dict.print_tree(print_values=True)

print("Visualize sample as a raw flat dictionary")
sample_dict

train set size is 62478
val set size is 6942
test set size is 1999
Visualize sample in a tree-fashion
--- data
------ label -> 1
------ initial_sample_id -> 0
------ sample_id -> 0
------ protein -> GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK
Visualize sample as a raw flat dictionary


{'data.initial_sample_id': 0, 'data.sample_id': 0, 'data.protein': 'GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK', 'data.label': 1}

Now we can complete our data module class.

##### NOTE
The `data_preprocessing` callable function will be implemented and explained in the next section.

In [None]:
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from fuse.data.utils.collates import CollateDefault
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp


class ProteinSolubilityDataModule(pl.LightningDataModule):
    def __init__(
        self,
        *,
        tokenizer_op: ModularTokenizerOp,
        data_preprocessing: callable,
        # Default values for simplicity
        batch_size: int = 1,
        data_path: str = "./example_solubility_data",
        protein_max_seq_length: int = 1250,
        encoder_input_max_seq_len: int = 1260,
        labels_max_seq_len: int = 4,
    ) -> None:
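        """
        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer op used to tokenize the prompts
            data_preprocessing (callable): function that formats and tokenizes a single sample dict
            batch_size (int): batch size
            data_path (str): path to the raw data; if it doesn't exist, the data is downloaded to this path
            protein_max_seq_length (int): max protein sequence length (the protein is truncated to it)
            encoder_input_max_seq_len (int): max tokenized sequence length of the encoder inputs
            labels_max_seq_len (int): max tokenized sequence length of the labels
        """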
        super().__init__()
        self.data_path = data_path
        self.tokenizer_op = tokenizer_op
        self.protein_max_seq_length = protein_max_seq_length
        self.encoder_input_max_seq_len = encoder_input_max_seq_len
        self.labels_max_seq_len = labels_max_seq_len
        self.batch_size = batch_size
        self.data_preprocessing = data_preprocessing

        self.pad_token_id = self.tokenizer_op.get_token_id("<PAD>")

    def setup(self, stage: str) -> None:
        self.ds_dict = load_datasets(self.data_path)

        task_pipeline = [
            (
                # Prepare the input string(s) in modular tokenizer input format
                self.data_preprocessing,
                dict(
                    protein_sequence_key="data.protein",
                    solubility_label_key="data.label",
                    tokenizer_op=self.tokenizer_op,
                    protein_max_seq_length=self.protein_max_seq_length,
                    encoder_input_max_seq_len=self.encoder_input_max_seq_len,
                    labels_max_seq_len=self.labels_max_seq_len,
                ),
            ),
        ]

        for ds in self.ds_dict.values():
            ds.dynamic_pipeline.extend(task_pipeline)

    def train_dataloader(self) -> DataLoader:
        train_loader = DataLoader(
            dataset=self.ds_dict["train"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
            shuffle=True,
        )
        return train_loader

    def val_dataloader(self) -> DataLoader:
        val_loader = DataLoader(
            self.ds_dict["val"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
        )

        return val_loader

    def test_dataloader(self) -> DataLoader:
        test_loader = DataLoader(
            self.ds_dict["test"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
        )

        return test_loader

    def predict_dataloader(self) -> DataLoader:
        return self.test_dataloader()
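
Conceptually, `setup()` appends an `(op, kwargs)` step to each split's dynamic pipeline, so every sample first passes through the reading ops from `load_datasets()` and then through `data_preprocessing`. The sketch below imitates only that ordering with plain functions; `run_pipeline`, `toy_read` and `toy_preprocessing` are hypothetical stand-ins, and FuseMedML's real `PipelineDefault` has its own op interface and bookkeeping.

```python
from typing import Any, Callable

Step = tuple[Callable[..., dict], dict[str, Any]]


def run_pipeline(sample_dict: dict, steps: list[Step]) -> dict:
    """Apply each (op, kwargs) step in order, threading the sample dict through."""
    for op, kwargs in steps:
        sample_dict = op(sample_dict, **kwargs)
    return sample_dict


def toy_read(sample_dict: dict, *, proteins: list[str]) -> dict:
    # Stand-in for OpReadDataframe: look up the raw value by sample id
    sample_dict["data.protein"] = proteins[sample_dict["data.sample_id"]]
    return sample_dict


def toy_preprocessing(sample_dict: dict, *, protein_sequence_key: str, max_len: int) -> dict:
    # Stand-in for data_preprocessing: derive a model-facing field from a raw one
    sample_dict["toy.truncated"] = sample_dict[protein_sequence_key][:max_len]
    return sample_dict


steps: list[Step] = [(toy_read, dict(proteins=["MKTAYIAK", "GSHMLE"]))]
# Mirrors setup(): the task-specific step is appended after the reading steps
steps.extend([(toy_preprocessing, dict(protein_sequence_key="data.protein", max_len=4))])

print(run_pipeline({"data.sample_id": 1}, steps))
```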

### 2. `data_preprocessing()` function


This function plays a crucial role in the task's workflow. Here we turn the raw data we loaded into a prompt that fits the model's pretraining prompt distribution (see the paper for more details). \
Besides formatting the prompt as a string, we also tokenize it using our custom modular tokenizer operator. It is therefore essential that the raw data is translated into entities that align with the lexicon the model was trained on.
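
To see the prompt format in isolation, the tokenizer-free sketch below builds the strings that `data_preprocessing()` hands to the tokenizer: the encoder input and, when a label is available, the labels and decoder inputs. The templates are copied from the function; `build_prompts` itself is a hypothetical helper. Our reading is that the `<@TOKENIZER-TYPE=...>` tags are directives to the modular tokenizer (which sub-tokenizer to use, and `@MAX-LEN` to truncate the protein) rather than regular tokens.

```python
from typing import Optional


def build_prompts(protein_sequence: str, label: Optional[int], protein_max_seq_length: int = 1250) -> dict[str, str]:
    # Encoder input: same template as ENCODER_INPUTS_STR in data_preprocessing()
    prompts = {
        "encoder_inputs": (
            "<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SOLUBILITY><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}>"
            f"<SEQUENCE_NATURAL_START>{protein_sequence}<SEQUENCE_NATURAL_END><EOS>"
        )
    }
    if label is not None:
        # Target: the sentinel token followed by the class token
        prompts["labels"] = f"<@TOKENIZER-TYPE=AA><SENTINEL_ID_0><{label}><EOS>"
        # Decoder input: the same tokens, prefixed with <DECODER_START>
        prompts["decoder_inputs"] = f"<@TOKENIZER-TYPE=AA><DECODER_START><SENTINEL_ID_0><{label}><EOS>"
    return prompts


for name, text in build_prompts("MKTAYIAK", label=1).items():
    print(f"{name}: {text}")
```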


In [None]:
import torch

from mammal.keys import *  # Import all dictionary static keys


def data_preprocessing(
    sample_dict: dict,
    *,
    protein_sequence_key: str,
    tokenizer_op: ModularTokenizerOp,
    solubility_label_key: str | None = None,
    protein_max_seq_length: int = 1250,
    encoder_input_max_seq_len: int | None = 1260,
    labels_max_seq_len: int | None = 4,
    device: str | torch.device = "cpu",
) -> dict:
    """
    :param sample_dict: a dictionary with raw data
    :param protein_sequence_key: sample_dict key which points to the protein sequence
    :param tokenizer_op: tokenizer op
    :param solubility_label_key: sample_dict key which points to the label (labels are built only if it is present)
    :param protein_max_seq_length: max sequence length of a protein. Will be used to truncate the protein
    :param encoder_input_max_seq_len: max sequence length of the encoder input. Will be used to truncate/pad the encoder_input.
    :param labels_max_seq_len: max sequence length of labels. Will be used to truncate/pad the labels and decoder inputs.
    :param device: device on which the output tensors are created
    :return: the same sample_dict, updated with the prompt strings, token ids and attention masks
    """
    protein_sequence = sample_dict[protein_sequence_key]
    solubility_label = sample_dict.get(solubility_label_key, None)

    sample_dict[ENCODER_INPUTS_STR] = (
        f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SOLUBILITY><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}><SEQUENCE_NATURAL_START>{protein_sequence}<SEQUENCE_NATURAL_END><EOS>"
    )
    tokenizer_op(
        sample_dict=sample_dict,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        max_seq_len=encoder_input_max_seq_len,
    )
    sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
        sample_dict[ENCODER_INPUTS_TOKENS], device=device
    )
    sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=device
    )

    if solubility_label is not None:
        pad_id = tokenizer_op.get_token_id("<PAD>")
        ignore_token_value = -100
        sample_dict[LABELS_STR] = (
            f"<@TOKENIZER-TYPE=AA><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
        tokenizer_op(
            sample_dict=sample_dict,
            key_in=LABELS_STR,
            key_out_tokens_ids=LABELS_TOKENS,
            key_out_attention_mask=LABELS_ATTENTION_MASK,
            max_seq_len=labels_max_seq_len,
        )
        sample_dict[LABELS_TOKENS] = torch.tensor(
            sample_dict[LABELS_TOKENS], device=device
        )
        sample_dict[LABELS_ATTENTION_MASK] = torch.tensor(
            sample_dict[LABELS_ATTENTION_MASK], device=device
        )
        # Replace pad_id with -100 so the loss ignores the padding positions
        sample_dict[LABELS_TOKENS][
            sample_dict[LABELS_TOKENS] == pad_id
        ] = ignore_token_value

        sample_dict[DECODER_INPUTS_STR] = (
            f"<@TOKENIZER-TYPE=AA><DECODER_START><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
        tokenizer_op(
            sample_dict=sample_dict,
            key_in=DECODER_INPUTS_STR,
            key_out_tokens_ids=DECODER_INPUTS_TOKENS,
            key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK,
            max_seq_len=labels_max_seq_len,
        )
        sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[DECODER_INPUTS_TOKENS], device=device
        )
        sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[DECODER_INPUTS_ATTENTION_MASK], device=device
        )

    return sample_dict
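
The label handling above relies on two conventions, sketched below with plain lists and hypothetical token ids (the real ids come from `tokenizer_op.get_token_id(...)`). First, padding positions in the labels are set to -100, the default `ignore_index` of PyTorch's cross-entropy loss, so they do not contribute to the loss. Second, the decoder inputs are the labels shifted right by `<DECODER_START>` (teacher forcing), so the class token lands at index 1 of the labels. We assume here that the `<@TOKENIZER-TYPE=AA>` directive produces no token, which leaves three real label tokens plus one pad for `labels_max_seq_len=4`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

# Hypothetical token ids; the real ones come from tokenizer_op.get_token_id(...)
PAD, EOS, DECODER_START, SENTINEL_0, CLASS_1 = 0, 2, 5, 7, 11


def mask_padding(token_ids: list[int], pad_id: int) -> list[int]:
    """Replace padding ids with IGNORE_INDEX so the loss skips those positions."""
    return [IGNORE_INDEX if t == pad_id else t for t in token_ids]


# Both sequences tokenized to labels_max_seq_len=4 for a soluble (<1>) sample
labels = mask_padding([SENTINEL_0, CLASS_1, EOS, PAD], pad_id=PAD)
decoder_inputs = [DECODER_START, SENTINEL_0, CLASS_1, EOS]

# Teacher forcing: at step i the decoder has seen decoder_inputs[: i + 1]
# and is trained to predict labels[i]
for i, (last_seen, target) in enumerate(zip(decoder_inputs, labels)):
    print(f"step {i}: last input {last_seen} -> target {target}")
```

Step 1 is the one whose target is the class token, which is consistent with `classification_position = 1` in `process_model_output()`.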

In [None]:
# Loading the tokenizer
tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")

In [None]:
import copy

protein_sequence_key = "data.protein_sequence"

# Define dummy initial sample dict
initial_sample_dict = dict()
initial_sample_dict[protein_sequence_key] = "AAA"

processed_sample_dict = data_preprocessing(
    copy.deepcopy(initial_sample_dict),
    protein_sequence_key=protein_sequence_key,
    tokenizer_op=tokenizer_op,
)

print(
    "Visualize participating sample dicts before and after the data processing for the model:"
)
print(f"{initial_sample_dict=}")
print(f"{processed_sample_dict=}")

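### 3. `process_model_output()` function

This function goes the other way: it takes the decoder's raw output (the generated token ids and, optionally, the per-token scores) and translates it into a predicted class and its scores.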

In [None]:
import numpy as np


def process_model_output(
    tokenizer_op: ModularTokenizerOp,
    decoder_output: np.ndarray,
    decoder_output_scores: np.ndarray,
) -> dict | None:
    """
    Extract the predicted solubility class and its scores.
    Expects the decoder output to be <SENTINEL_ID_0><0><EOS> or <SENTINEL_ID_0><1><EOS>,
    so the class token is read at position 1.
    Note: the normalized score is the positive ('<1>') score divided by the sum of the '<0>' and '<1>' scores.
        BE CAREFUL: both absolute scores can be very low while the normalized score is still very high.
    Returns a dictionary:
        dict(
            pred = # 1 for '<1>', 0 for '<0>', -1 if the token at the class position is neither
            not_normalized_scores = # the score of the positive token, e.g. 0.01 (None if no scores were given)
            normalized_scores = # positive_token_score / (positive_token_score + negative_token_score)
        )
    """

    negative_token_id = tokenizer_op.get_token_id("<0>")
    positive_token_id = tokenizer_op.get_token_id("<1>")
    label_id_to_int = {
        negative_token_id: 0,
        positive_token_id: 1,
    }
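    classification_position = 1

    # Scores may be unavailable; default to None instead of raising a NameError below
    not_normalized_score = None
    normalized_score = None
    if decoder_output_scores is not None: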
        not_normalized_score = decoder_output_scores[
            classification_position, positive_token_id
        ]
        normalized_score = not_normalized_score / (
            not_normalized_score
            + decoder_output_scores[classification_position, negative_token_id]
            + 1e-10
        )
    ans = dict(
        pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
        not_normalized_scores=not_normalized_score,
        normalized_scores=normalized_score,
    )

    return ans
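
The sketch below reproduces the scoring logic on a toy score matrix, using plain lists instead of NumPy arrays. Token ids 2 and 3 are hypothetical stand-ins for `<0>` and `<1>`, and we take the greedy argmax of the scores as a stand-in for `decoder_output`; `summarize_position` is a hypothetical helper. The example also illustrates the caveat from the docstring: at position 1 the model puts most of its mass on an unrelated token, so the prediction falls back to -1, yet the normalized score is still about 0.8.

```python
def summarize_position(scores: list[list[float]], position: int, neg_id: int, pos_id: int) -> dict:
    row = scores[position]
    # Greedy choice at this position, standing in for decoder_output[position]
    predicted_id = max(range(len(row)), key=row.__getitem__)
    label_id_to_int = {neg_id: 0, pos_id: 1}
    pos_score, neg_score = row[pos_id], row[neg_id]
    return dict(
        pred=label_id_to_int.get(predicted_id, -1),  # -1 if neither <0> nor <1> was chosen
        not_normalized_scores=pos_score,
        normalized_scores=pos_score / (pos_score + neg_score + 1e-10),
    )


# Hypothetical scores: 3 decoder positions x 4 vocabulary ids (ids 2 and 3 play <0> and <1>)
scores = [
    [0.97, 0.01, 0.01, 0.01],
    [0.90, 0.00, 0.02, 0.08],
    [0.10, 0.80, 0.05, 0.05],
]
print(summarize_position(scores, position=1, neg_id=2, pos_id=3))
```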

### Metrics
As part of the ML workflow, it is essential to evaluate tasks using standard metrics. This repository includes a minimal implementation for metrics collection, for both classification and regression tasks, provided in `mammal/metrics.py`. The metrics are based on the [fuse-med-ml](https://github.com/BiomedSciAI/fuse-med-ml/tree/master) package.

- **Classification metrics**: AUC-ROC, Accuracy, MCC (Matthews Correlation Coefficient).
- **Regression metrics**: Pearson correlation, Spearman correlation, MAE, MSE, RMSE, R2.

Since solubility is a classification problem (soluble or insoluble), we will use the classification metrics. Refer to the `classification_metrics` function used in the code block below.
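
For intuition about two of the classification metrics, the sketch below computes accuracy and MCC from binary predictions with the standard confusion-matrix formulas. It is a plain re-derivation for illustration, not the implementation in `mammal/metrics.py` (which works on token ids at `class_position` through fuse-med-ml); `binary_metrics` is a hypothetical helper.

```python
import math


def binary_metrics(preds: list[int], labels: list[int]) -> dict[str, float]:
    # A pred of -1 (unparseable output) is neither 0 nor 1: it counts as an error
    # for accuracy and is left out of the confusion-matrix cells used by MCC
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    tn = sum(p == 0 and y == 0 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / len(labels),
        # MCC is conventionally reported as 0 when any marginal count is zero
        "mcc": (tp * tn - fp * fn) / denom if denom else 0.0,
    }


print(binary_metrics(preds=[1, 0, 1, 1, 0, 0], labels=[1, 0, 0, 1, 0, 1]))
```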

### Define Task Object

Now we can combine all of the components together to form our `MammalTask` object that will define the new task.

In [None]:
from typing import Any

from mammal.metrics import classification_metrics
from mammal.task import (
    MammalTask,
    MetricBase,
)


class ProteinSolubilityTask(MammalTask):
    def __init__(
        self,
        *,
        name: str,
        tokenizer_op: ModularTokenizerOp,
        logger: Any | None = None,
    ) -> None:
        super().__init__(
            name=name,
            logger=logger,
            tokenizer_op=tokenizer_op,
        )

        self.preds_key = CLS_PRED
        self.scores_key = SCORES
        self.labels_key = LABELS_TOKENS

    def data_module(self, **kwargs) -> pl.LightningDataModule:
        return ProteinSolubilityDataModule(
            tokenizer_op=self._tokenizer_op,
            data_preprocessing=self.data_preprocessing,
            **kwargs
        )

    def train_metrics(self) -> dict[str, MetricBase]:
        metrics = super().train_metrics()
        metrics.update(
            classification_metrics(
                self.name(),
                class_position=1,
                tokenizer_op=self._tokenizer_op,
                class_tokens=["<0>", "<1>"],
            )
        )

        return metrics

    def validation_metrics(self) -> dict[str, MetricBase]:
        validation_metrics = super().validation_metrics()
        validation_metrics.update(
            classification_metrics(
                self.name(),
                class_position=1,
                tokenizer_op=self._tokenizer_op,
                class_tokens=["<0>", "<1>"],
            )
        )
        return validation_metrics

    @staticmethod
    def data_preprocessing(
        sample_dict: dict,
        *,
        protein_sequence_key: str,
        tokenizer_op: ModularTokenizerOp,
        solubility_label_key: str | None = None,
        protein_max_seq_length: int = 1250,
        encoder_input_max_seq_len: int | None = 1260,
        labels_max_seq_len: int | None = 4,
        device: str | torch.device = "cpu",
    ) -> dict:

        # We call the function we defined above, just to keep things clean.
        # Another option is to write the underlying logic inside this method definition.
        sample_dict = data_preprocessing(
            sample_dict=sample_dict,
            tokenizer_op=tokenizer_op,
            protein_sequence_key=protein_sequence_key,
            solubility_label_key=solubility_label_key,
            protein_max_seq_length=protein_max_seq_length,
            encoder_input_max_seq_len=encoder_input_max_seq_len,
            labels_max_seq_len=labels_max_seq_len,
            device=device,
        )
        return sample_dict

    @staticmethod
    def process_model_output(
        tokenizer_op: ModularTokenizerOp,
        decoder_output: np.ndarray,
        decoder_output_scores: np.ndarray,
    ) -> dict | None:

        # We call the function we defined above, just to keep things clean
        ans = process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=decoder_output,
            decoder_output_scores=decoder_output_scores,
        )

        return ans

## Finetune model
Now let's create a simple fine-tuning block using the task we just created. The following code is based on `main_finetune.py`, which you can run with your own task and configuration.
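
The fine-tuning cell below uses `cosine_annealing_with_warmup_lr_scheduler`, whose defaults are not shown in this notebook. As a rough, generic picture of what such a schedule does, the sketch below ramps the learning rate linearly during warmup and then decays it along a half cosine to zero. `cosine_with_warmup` and its parameters are hypothetical and need not match MAMMAL's function. Because the cell sets `"interval": "step"`, Lightning steps the scheduler after every optimizer step rather than once per epoch.

```python
import math


def cosine_with_warmup(step: int, warmup_steps: int, total_steps: int, base_lr: float) -> float:
    """Hypothetical schedule: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear ramp up to base_lr
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


for step in (0, 9, 10, 55, 100):
    print(step, cosine_with_warmup(step, warmup_steps=10, total_steps=100, base_lr=1e-5))
```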

In [None]:
from torch.optim import AdamW

from fuse.dl.lightning.pl_module import LightningModuleDefault

from mammal.model import Mammal
from mammal.lr_schedulers import cosine_annealing_with_warmup_lr_scheduler

# Seed for reproducibility
pl.seed_everything(42)

# Load the pre-trained model from the HF hub
model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")

# Initialize OUR task :-)
task = ProteinSolubilityTask(
    name="our_solubility_prediction_task", tokenizer_op=tokenizer_op
)

# Extract the task's data module to pass it to the Lightning trainer
pl_data_module = task.data_module()

# Define the optimizer and LR scheduler for the Lightning module
opt = AdamW(params=model.parameters(), lr=0.00001)
lr_scheduler = cosine_annealing_with_warmup_lr_scheduler(optimizer=opt)
optimizers_and_lr_schs = dict(
    optimizer=opt,
    lr_scheduler={"scheduler": lr_scheduler, "interval": "step"},
)

# Initialize 'LightningModuleDefault', a subclass of 'pl.LightningModule' defined in Fuse-Med-ML
pl_module = LightningModuleDefault(
    model=model,
    losses=task.losses(),
    validation_metrics=task.validation_metrics(),
    train_metrics=task.train_metrics(),
    optimizers_and_lr_schs=optimizers_and_lr_schs,
    model_dir="mammal_solubility_finetune",
)

# Create Lightning's Trainer
pl_trainer = pl.Trainer(max_epochs=1)

# Let the training begin
pl_trainer.fit(model=pl_module, datamodule=pl_data_module)



Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m'


Attempting to load model from dir: pretrained_model_name_or_path='/dccstor/mm_hcls/usr/sagi/.cache/models--ibm--biomed.omics.bl.sm.ma-ted-458m/snapshots/421daf3f8eae4ada57ffd3580f7347828b34d69a'


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint


train set size is 62478
val set size is 6942
test set size is 1999


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type   | Params | Mode 
------------------------------------------
0 | _model | Mammal | 458 M  | train
------------------------------------------
458 M     Trainable params
0         Non-trainable params
458 M     Total params
1,832.029 Total estimated model params size (MB)
579       Modules in train mode
0         Modules in eval mode


/dccstor/mm_hcls/usr/sagi/envs/biomed_multi_alignment/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
/dccstor/mm_hcls/usr/sagi/envs/biomed_multi_alignment/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


`Trainer.fit` stopped: `max_epochs=1` reached.
