# HOW TO: Introducing a New Task

Welcome and thanks for joining our journey! 


This tutorial walks through creating a custom task in the MAMMAL framework to fine-tune our base model [**biomed.omics.bl.sm.ma-ted-458m**](https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m). \
As a case study we will use our [protein solubility prediction](https://github.com/BiomedSciAI/biomed-multi-alignment/tree/main?tab=readme-ov-file#protein-solubility-prediction) task. We will break down the main components so that you can create or modify them for your own task and data.

## Installation
First, make sure you have the `mammal` package installed. Please follow the [installation guide](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/main/README.md#installation) in the main README file.

## The Big Picture
To implement a downstream task for the MAMMAL framework, one should implement a `MammalTask` ([source](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/e56a03e0e9f69e42f919a96def739b78e50a47e5/mammal/task.py#L15)). A `MammalTask` consists of three main components:
1. Data Module - A Lightning data module class where we load and process the data for the task.
2. `data_preprocessing()` function - Responsible for formatting the input prompt.
3. `process_model_output()` function - Takes the raw output of the model and translates it into a human-readable answer.

### 1. Data Module
We need to create a custom Lightning data module in order to load and prepare the input data for the model. \
In the datamodule we will need to:
1. Load the raw dataset from our source (in this example we download it from Zenodo) - will be done in `load_datasets()`.
2. Process the raw data to match the model's prompt structure - will be done in `data_preprocessing()`, which we cover in the next section.

In [None]:
import os
import wget
import shutil
import pandas as pd

from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.datasets.dataset_default import DatasetDefault
from fuse.data.pipelines.pipeline_default import PipelineDefault


_SOLUBILITY_URL = "https://zenodo.org/api/records/1162886/files-archive"


def load_datasets(data_path: str) -> dict[str, DatasetDefault]:
    """
    Automatically downloads the data and creates a dataset iterator for each of "train", "val" and "test".
    paper: https://academic.oup.com/bioinformatics/article/34/15/2605/4938490
    Data retrieved from: https://zenodo.org/records/1162886
    The benchmark requires classifying protein sequences into binary labels - Soluble or Insoluble (1 or 0).
    :param data_path: path to a directory to store the raw data
    :return: dictionary that maps fold name "train", "val" and "test" to a dataset iterator
    """

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    raw_data_path = os.path.join(data_path, "sameerkhurana10-DSOL_rv0.2-20562ad/data")
    if not os.path.exists(raw_data_path):
        wget.download(_SOLUBILITY_URL, data_path)
        file_path = os.path.join(data_path, "1162886.zip")
        shutil.unpack_archive(file_path, extract_dir=data_path)
        inner_file_path = os.path.join(
            data_path, "sameerkhurana10", "DSOL_rv0.2-v0.3.zip"
        )
        shutil.unpack_archive(inner_file_path, extract_dir=data_path)
        assert os.path.exists(
            raw_data_path
        ), f"Error: download complete but {raw_data_path} doesn't exist"

    # read files
    df_dict = {}
    for set_name in ["train", "val", "test"]:
        input_df = pd.read_csv(
            os.path.join(raw_data_path, f"{set_name}_src"), names=["data.protein"]
        )
        labels_df = pd.read_csv(
            os.path.join(raw_data_path, f"{set_name}_tgt"), names=["data.label"]
        )
        df_dict[set_name] = (input_df, labels_df)

    ds_dict = {}
    for set_name in ["train", "val", "test"]:
        input_df, labels_df = df_dict[set_name]
        size = len(labels_df)
        print(f"{set_name} set size is {size}")
        dynamic_pipeline = PipelineDefault(
            "solubility",
            [
                (OpReadDataframe(input_df, key_column=None), dict()),
                (OpReadDataframe(labels_df, key_column=None), dict()),
            ],
        )

        ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline)
        ds.create()
        ds_dict[set_name] = ds

    return ds_dict

In [None]:
# Load dataset - train, validation and test splits
ds_dict = load_datasets("./data")

# Retrieve and visualize a single sample
sample_dict = ds_dict["train"][0]

print("Visualize sample in a tree-fashion")
sample_dict.print_tree(print_values=True)

print("Visualize sample as a raw flat dictionary")
sample_dict

train set size is 62478
val set size is 6942
test set size is 1999
Visualize sample in a tree-fashion
--- data
------ protein -> GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK
------ initial_sample_id -> 0
------ label -> 1
------ sample_id -> 0
Visualize sample as a raw flat dictionary


{'data.initial_sample_id': 0, 'data.sample_id': 0, 'data.protein': 'GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK', 'data.label': 1}
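
To make the pairing of inputs and labels explicit, here is a minimal standard-library sketch of what `load_datasets()` does with the `*_src` and `*_tgt` files: line *i* of the source file is matched with line *i* of the target file. The helper name `read_split` and the toy sequences are assumptions made for illustration; the tutorial itself relies on `pandas` and the `fuse` pipeline ops shown above.

```python
import os
import tempfile


def read_split(raw_data_path: str, set_name: str) -> list:
    """Pair line i of <set_name>_src with line i of <set_name>_tgt."""
    with open(os.path.join(raw_data_path, f"{set_name}_src")) as f:
        proteins = [line.strip() for line in f if line.strip()]
    with open(os.path.join(raw_data_path, f"{set_name}_tgt")) as f:
        labels = [int(line.strip()) for line in f if line.strip()]
    if len(proteins) != len(labels):
        raise ValueError(
            f"{set_name}: {len(proteins)} sequences but {len(labels)} labels"
        )
    # Same flat keys as the sample dictionary printed above
    return [{"data.protein": p, "data.label": y} for p, y in zip(proteins, labels)]


# Toy files standing in for the downloaded data (the sequences are made up)
with tempfile.TemporaryDirectory() as tmp_dir:
    with open(os.path.join(tmp_dir, "train_src"), "w") as f:
        f.write("MKTAYIAK\nGSHMLEDP\n")
    with open(os.path.join(tmp_dir, "train_tgt"), "w") as f:
        f.write("1\n0\n")
    toy_samples = read_split(tmp_dir, "train")

print(toy_samples)
```

The length check is an addition of this sketch: because the two files are aligned only by row order, a mismatch would otherwise silently shift labels.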

Now we can complete our data module class.

##### NOTE
The `data_preprocessing` callable function will be implemented and explained in the next section.

In [None]:
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from fuse.data.utils.collates import CollateDefault
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp


class ProteinSolubilityDataModule(pl.LightningDataModule):
    def __init__(
        self,
        *,
        data_path: str,
        batch_size: int,
        tokenizer_op: ModularTokenizerOp,
        train_dl_kwargs: dict,
        valid_dl_kwargs: dict,
        seed: int,
        data_preprocessing: callable,
        protein_max_seq_length: int,
        encoder_input_max_seq_len: int,
        labels_max_seq_len: int,
    ) -> None:
        """Lightning data module for the protein solubility task.
        Args:
            data_path (str): path to the raw data; if it does not exist, the data will be downloaded to this path.
            batch_size (int): batch size
            tokenizer_op (ModularTokenizerOp): tokenizer op
            train_dl_kwargs (dict): train dataloader constructor parameters
            valid_dl_kwargs (dict): validation dataloader constructor parameters
            seed (int): random seed
            data_preprocessing (callable): function that formats and tokenizes a single sample
            protein_max_seq_length (int): max protein sequence length, used to truncate the protein
            encoder_input_max_seq_len (int): max tokenizer sequence length for the encoder inputs
            labels_max_seq_len (int): max tokenizer sequence length for the labels
        """
        super().__init__()
        self.data_path = data_path
        self.tokenizer_op = tokenizer_op
        self.protein_max_seq_length = protein_max_seq_length
        self.encoder_input_max_seq_len = encoder_input_max_seq_len
        self.labels_max_seq_len = labels_max_seq_len
        self.batch_size = batch_size
        self.train_dl_kwargs = train_dl_kwargs
        self.valid_dl_kwargs = valid_dl_kwargs
        self.seed = seed
        self.data_preprocessing = data_preprocessing

        self.pad_token_id = self.tokenizer_op.get_token_id("<PAD>")

    def setup(self, stage: str) -> None:
        self.ds_dict = load_datasets(self.data_path)

        task_pipeline = [
            (
                # Prepare the input string(s) in modular tokenizer input format
                self.data_preprocessing,
                dict(
                    protein_sequence_key="data.protein",
                    solubility_label_key="data.label",
                    tokenizer_op=self.tokenizer_op,
                    protein_max_seq_length=self.protein_max_seq_length,
                    encoder_input_max_seq_len=self.encoder_input_max_seq_len,
                    labels_max_seq_len=self.labels_max_seq_len,
                ),
            ),
        ]

        for ds in self.ds_dict.values():
            ds.dynamic_pipeline.extend(task_pipeline)

    def train_dataloader(self) -> DataLoader:
        train_loader = DataLoader(
            dataset=self.ds_dict["train"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
            shuffle=True,
            **self.train_dl_kwargs,
        )
        return train_loader

    def val_dataloader(self) -> DataLoader:
        val_loader = DataLoader(
            self.ds_dict["val"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
            **self.valid_dl_kwargs,
        )

        return val_loader

    def test_dataloader(self) -> DataLoader:
        test_loader = DataLoader(
            self.ds_dict["test"],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
            **self.valid_dl_kwargs,
        )

        return test_loader

    def predict_dataloader(self) -> DataLoader:
        return self.test_dataloader()

### 2. `data_preprocessing()` function


This method plays a crucial role in the task's workflow. Here we turn the raw data that we loaded into a prompt that fits the model's pretraining prompt distribution (see the paper for more details). \
Besides formatting the prompt as a string, we also tokenize it using our custom modular tokenizer operator.
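
Before looking at the full implementation, it may help to preview the three prompt strings on their own, with tokenization left out. The templates below are copied from the `data_preprocessing()` code in this section; the helper name `build_prompts` and its dictionary keys are assumptions made for illustration only.

```python
from typing import Optional


def build_prompts(
    protein_sequence: str,
    solubility_label: Optional[int] = None,
    protein_max_seq_length: int = 1250,
) -> dict:
    """Format the prompt strings only; no tokenization is performed here."""
    prompts = {
        # The encoder sees the task tokens, a sentinel marking where the
        # answer goes, and then the protein sequence itself.
        "encoder_input": (
            "<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            "<SOLUBILITY><SENTINEL_ID_0>"
            f"<@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}>"
            f"<SEQUENCE_NATURAL_START>{protein_sequence}<SEQUENCE_NATURAL_END><EOS>"
        )
    }
    if solubility_label is not None:
        # Labels and decoder inputs are only built when a label is available
        prompts["labels"] = (
            f"<@TOKENIZER-TYPE=AA><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
        prompts["decoder_input"] = (
            f"<@TOKENIZER-TYPE=AA><DECODER_START><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
    return prompts


inference_prompts = build_prompts("AAA")
training_prompts = build_prompts("AAA", solubility_label=1)
print(inference_prompts["encoder_input"])
print(training_prompts["labels"])
```

Note that the label follows `<SENTINEL_ID_0>` in the target string, which is why the class token is later read from position 1 of the decoder output.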


In [None]:
import torch

from mammal.keys import *  # Import all dictionary static keys


def data_preprocessing(
    sample_dict: dict,
    *,
    protein_sequence_key: str,
    tokenizer_op: ModularTokenizerOp,
    solubility_label_key: str | None = None,
    protein_max_seq_length: int = 1250,
    encoder_input_max_seq_len: int | None = 1260,
    labels_max_seq_len: int | None = 4,
    device: str | torch.device = "cpu",
) -> dict:
    """
    :param sample_dict: a dictionary with raw data
    :param protein_sequence_key: sample_dict key which points to protein sequence
    :param solubility_label_key: sample_dict key which points to label
    :param protein_max_seq_length: max sequence length of a protein. Will be used to truncate the protein
    :param encoder_input_max_seq_len: max sequence length of the encoder input. Will be used to truncate/pad the encoder_input.
    :param labels_max_seq_len: max sequence length of labels. Will be used to truncate/pad the labels.
    :param tokenizer_op: tokenizer op

    """
    protein_sequence = sample_dict[protein_sequence_key]
    solubility_label = sample_dict.get(solubility_label_key, None)

    sample_dict[ENCODER_INPUTS_STR] = (
        f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SOLUBILITY><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}><SEQUENCE_NATURAL_START>{protein_sequence}<SEQUENCE_NATURAL_END><EOS>"
    )
    tokenizer_op(
        sample_dict=sample_dict,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        max_seq_len=encoder_input_max_seq_len,
    )
    sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
        sample_dict[ENCODER_INPUTS_TOKENS], device=device
    )
    sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=device
    )

    if solubility_label is not None:
        pad_id = tokenizer_op.get_token_id("<PAD>")
        ignore_token_value = -100
        sample_dict[LABELS_STR] = (
            f"<@TOKENIZER-TYPE=AA><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
        tokenizer_op(
            sample_dict=sample_dict,
            key_in=LABELS_STR,
            key_out_tokens_ids=LABELS_TOKENS,
            key_out_attention_mask=LABELS_ATTENTION_MASK,
            max_seq_len=labels_max_seq_len,
        )
        sample_dict[LABELS_TOKENS] = torch.tensor(
            sample_dict[LABELS_TOKENS], device=device
        )
        sample_dict[LABELS_ATTENTION_MASK] = torch.tensor(
            sample_dict[LABELS_ATTENTION_MASK], device=device
        )
        # replace pad_id with -100 so that padded label positions can be ignored by the loss
        pad_id_tns = torch.tensor(pad_id)
        sample_dict[LABELS_TOKENS][
            (sample_dict[LABELS_TOKENS][..., None] == pad_id_tns).any(-1).nonzero()
        ] = ignore_token_value

        sample_dict[DECODER_INPUTS_STR] = (
            f"<@TOKENIZER-TYPE=AA><DECODER_START><SENTINEL_ID_0><{solubility_label}><EOS>"
        )
        tokenizer_op(
            sample_dict=sample_dict,
            key_in=DECODER_INPUTS_STR,
            key_out_tokens_ids=DECODER_INPUTS_TOKENS,
            key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK,
            max_seq_len=labels_max_seq_len,
        )
        sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[DECODER_INPUTS_TOKENS], device=device
        )
        sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[DECODER_INPUTS_ATTENTION_MASK], device=device
        )

    return sample_dict
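
The label-padding step in the function above can be illustrated on plain Python lists. This is a sketch under stated assumptions: the token ids and the pad id are made up (the real pad id comes from `tokenizer_op.get_token_id("<PAD>")`), and the helper name `mask_padding` is hypothetical. The value `-100` matches `ignore_token_value` in the tutorial code and is also the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which is presumably why it is used.

```python
IGNORE_TOKEN_VALUE = -100  # same value as `ignore_token_value` in the tutorial code


def mask_padding(label_ids: list, pad_id: int) -> list:
    """Replace every padding id with the ignore value; keep all other ids."""
    return [IGNORE_TOKEN_VALUE if token_id == pad_id else token_id for token_id in label_ids]


toy_pad_id = 0  # hypothetical pad id, for illustration only
# Made-up ids standing in for <SENTINEL_ID_0>, <1>, <EOS>, <PAD>
toy_label_ids = [7, 42, 3, 0]

masked_label_ids = mask_padding(toy_label_ids, toy_pad_id)
print(masked_label_ids)  # [7, 42, 3, -100]
```

Only the labels are masked this way; the encoder and decoder inputs keep their pad ids and rely on the attention masks instead.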

In [None]:
# Loading the tokenizer
tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")

In [None]:
import copy

protein_sequence_key = "data.protein_sequence"

# Define dummy initial sample dict
initial_sample_dict = dict()
initial_sample_dict[protein_sequence_key] = "AAA"

processed_sample_dict = data_preprocessing(
    copy.deepcopy(initial_sample_dict),
    protein_sequence_key=protein_sequence_key,
    tokenizer_op=tokenizer_op,
)

print(
    "Visualize participating sample dicts before and after the data processing for the model:"
)
print(f"{initial_sample_dict=}")
print(f"{processed_sample_dict=}")

Visualize participating sample dicts before and after the data processing for the model:
initial_sample_dict={'data.protein_sequence': 'AAA'}
processed_sample_dict={'data.protein_sequence': 'AAA', 'data.query.encoder_input': '<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SOLUBILITY><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN=1250><SEQUENCE_NATURAL_START>AAA<SEQUENCE_NATURAL_END><EOS>', 'data.query.encoder_input.with_placeholders': '<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SOLUBILITY><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN=1250><SEQUENCE_NATURAL_START>AAA<SEQUENCE_NATURAL_END><EOS>', 'data.query.encoder_input.per_meta_part_encoding': [Encoding(num_tokens=4, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]), Encoding(num_tokens=6, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])], 'data.encoder_input_token_ids': tensor([  6, 274,  27,  ..

### 3. `process_model_output()` function

In [None]:
import numpy as np


def process_model_output(
    tokenizer_op: ModularTokenizerOp,
    decoder_output: np.ndarray,
    decoder_output_scores: np.ndarray,
) -> dict | None:
    """
    Extract predicted solubility class and scores
    expecting decoder output to be <SENTINEL_ID_0><0><EOS> or <SENTINEL_ID_0><1><EOS>
    note - the normalized version will calculate the positive ('<1>') score divided by the sum of the scores for both '<0>' and '<1>'
        BE CAREFUL as both negative and positive absolute scores can be drastically low, and normalized score could be very high.
    outputs a dictionary containing:
        dict(
            pred = # predicted class as an int: 0 or 1, or -1 if the token at the classification position is neither '<0>' nor '<1>'
            not_normalized_scores = # the score for the positive token... e.g.  0.01
            normalized_scores = # (positive_token_score) / (positive_token_score+negative_token_score)
        )
    """

    negative_token_id = tokenizer_op.get_token_id("<0>")
    positive_token_id = tokenizer_op.get_token_id("<1>")
    label_id_to_int = {
        negative_token_id: 0,
        positive_token_id: 1,
    }
    classification_position = 1

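    # Default to None so the returned dict is well defined when no scores are passed
    not_normalized_score = None
    normalized_score = None
    if decoder_output_scores is not None: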
        not_normalized_score = decoder_output_scores[
            classification_position, positive_token_id
        ]
        normalized_score = not_normalized_score / (
            not_normalized_score
            + decoder_output_scores[classification_position, negative_token_id]
            + 1e-10
        )
    ans = dict(
        pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
        not_normalized_scores=not_normalized_score,
        normalized_scores=normalized_score,
    )

    return ans
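
The warning in the docstring about normalized scores is easier to see with numbers. The sketch below reproduces the normalization formula on a toy score table written as nested lists; the vocabulary size, the token ids (2 for `<0>`, 3 for `<1>`) and the score values are all made up for illustration, and the helper name is hypothetical.

```python
def normalized_positive_score(
    scores: list,
    classification_position: int,
    positive_token_id: int,
    negative_token_id: int,
    eps: float = 1e-10,
) -> float:
    """Positive score divided by the sum of positive and negative scores."""
    positive = scores[classification_position][positive_token_id]
    negative = scores[classification_position][negative_token_id]
    return positive / (positive + negative + eps)


# Toy scores indexed as [position][token_id]; each row sums to 1
toy_scores = [
    [0.97, 0.01, 0.01, 0.01],      # position 0: expected to be <SENTINEL_ID_0>
    [0.9899, 0.0, 0.0001, 0.01],   # position 1: where the class token is read
]

raw_score = toy_scores[1][3]
normalized = normalized_positive_score(
    toy_scores, classification_position=1, positive_token_id=3, negative_token_id=2
)
print(f"raw positive score: {raw_score}, normalized: {normalized:.4f}")
```

Here the model assigns only 0.01 to `<1>`, yet the normalized score is about 0.99 because `<0>` received even less. Both values are worth inspecting before trusting a prediction.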

### Finally - Define Task Object

Now we can combine all of the components together to form our `MammalTask` object that will define the new task.

In [None]:
from typing import Any

from mammal.metrics import classification_metrics
from mammal.task import (
    MammalTask,
    MetricBase,
)


class ProteinSolubilityTask(MammalTask):
    def __init__(
        self,
        *,
        name: str,
        tokenizer_op: ModularTokenizerOp,
        data_module_kwargs: dict,
        seed: int,
        logger: Any | None = None,
    ) -> None:
        super().__init__(
            name=name,
            logger=logger,
            tokenizer_op=tokenizer_op,
        )
        self._data_module_kwargs = data_module_kwargs
        self._seed = seed

        self.preds_key = CLS_PRED
        self.scores_key = SCORES
        self.labels_key = LABELS_TOKENS

    def data_module(self) -> pl.LightningDataModule:
        return ProteinSolubilityDataModule(
            tokenizer_op=self._tokenizer_op,
            seed=self._seed,
            data_preprocessing=self.data_preprocessing,
            **self._data_module_kwargs,
        )

    def train_metrics(self) -> dict[str, MetricBase]:
        metrics = super().train_metrics()
        metrics.update(
            classification_metrics(
                self.name(),
                class_position=1,
                tokenizer_op=self._tokenizer_op,
                class_tokens=["<0>", "<1>"],
            )
        )

        return metrics

    def validation_metrics(self) -> dict[str, MetricBase]:
        validation_metrics = super().validation_metrics()
        validation_metrics.update(
            classification_metrics(
                self.name(),
                class_position=1,
                tokenizer_op=self._tokenizer_op,
                class_tokens=["<0>", "<1>"],
            )
        )
        return validation_metrics

    @staticmethod
    def data_preprocessing(
        sample_dict: dict,
        *,
        protein_sequence_key: str,
        tokenizer_op: ModularTokenizerOp,
        solubility_label_key: str | None = None,
        protein_max_seq_length: int = 1250,
        encoder_input_max_seq_len: int | None = 1260,
        labels_max_seq_len: int | None = 4,
        device: str | torch.device = "cpu",
    ) -> dict:

        # We use the method we defined above, just to make it cleaner
        sample_dict = data_preprocessing(
            sample_dict=sample_dict,
            tokenizer_op=tokenizer_op,
            protein_sequence_key=protein_sequence_key,
            solubility_label_key=solubility_label_key,
            protein_max_seq_length=protein_max_seq_length,
            encoder_input_max_seq_len=encoder_input_max_seq_len,
            labels_max_seq_len=labels_max_seq_len,
            device=device,
        )
        return sample_dict

    @staticmethod
    def process_model_output(
        tokenizer_op: ModularTokenizerOp,
        decoder_output: np.ndarray,
        decoder_output_scores: np.ndarray,
    ) -> dict | None:

        # We use the method we defined above, just to make it cleaner
        ans = process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=decoder_output,
            decoder_output_scores=decoder_output_scores,
        )

        return ans

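### MORE

Losses and metrics: the framework already provides implementations for regression and classification tasks, such as the `classification_metrics` helper used above.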