In [None]:
!pip install biomed-multi-alignment

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.datasets.dataset_default import DatasetDefault
from fuse.data.pipelines.pipeline_default import PipelineDefault

def load_datasets(
    data_path: str,
    *,
    seed: int = 42,
    holdout_fraction: float = 0.4,
) -> dict[str, DatasetDefault]:
    """Read the binding CSV and build stratified train / val / test datasets.

    The CSV must contain an ``index`` column plus exactly two other columns, which
    are taken (in order) as the protein sequence and the binding label.
    ``holdout_fraction`` of the rows is held out from training and then split
    evenly into validation and test. Both splits are stratified on the label so
    every split keeps the overall class ratio.

    Args:
        data_path: path to the CSV file.
        seed: ``random_state`` used for both splits (42 reproduces the original splits).
        holdout_fraction: fraction of rows held out for val + test
            (default 0.4, i.e. a 60/20/20 split).

    Returns:
        Mapping ``"train"`` / ``"val"`` / ``"test"`` -> created ``DatasetDefault``
        whose samples expose ``data.protein`` and ``data.label``.
    """
    # Load CSV; the file's own index is dropped so rows are addressed 0..N-1.
    df = pd.read_csv(data_path, index_col="index").reset_index(drop=True)
    # Raises if the file does not have exactly two data columns.
    df.columns = ["data.protein", "data.label"]

    # Split the data: train vs. holdout, then holdout -> val / test (50/50).
    train_df, holdout_df = train_test_split(
        df, test_size=holdout_fraction, random_state=seed, stratify=df["data.label"]
    )
    val_df, test_df = train_test_split(
        holdout_df, test_size=0.5, random_state=seed, stratify=holdout_df["data.label"]
    )

    ds_dict = {}
    for set_name, set_df in (("train", train_df), ("val", val_df), ("test", test_df)):
        # Re-index each split to 0..n-1. NOTE(review): presumably OpReadDataframe with
        # key_column=None looks rows up by positional sample id, which must line up
        # with sample_ids=size below — confirm against fuse docs.
        input_df = set_df[["data.protein"]].reset_index(drop=True)
        labels_df = set_df[["data.label"]].reset_index(drop=True)
        size = len(labels_df)
        print(f"{set_name} set size is {size}")
        dynamic_pipeline = PipelineDefault(
            "binding",
            [
                (OpReadDataframe(input_df, key_column=None), dict()),
                (OpReadDataframe(labels_df, key_column=None), dict()),
            ],
        )

        ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline)
        ds.create()
        ds_dict[set_name] = ds

    return ds_dict

In [4]:
# Load dataset - train, validation and test splits
# (load_datasets also prints the size of each split).
ds_dict = load_datasets("non_binders.csv")

# Retrieve and visualize a single sample: the first row of the training split,
# shown below both as a tree and as a flat dict keyed "data.*".
sample_dict = ds_dict["train"][0]

print("Visualize sample in a tree-fashion")
sample_dict.print_tree(print_values=True)

print("Visualize sample as a raw flat dictionary")
# Last expression of the cell, so Jupyter renders it as the cell output.
sample_dict

train set size is 606
val set size is 202
test set size is 202
Visualize sample in a tree-fashion
--- data
------ protein -> MAEVQLQASGGGFVQPGGSLRLSCAASGSYSSREVMGWFRQAPGKEREFVSAISSDSNHFRYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTATYYCALGSNEYHHMSNGYWGQGTQVTVSSA
------ label -> 0
------ initial_sample_id -> 0
------ sample_id -> 0
Visualize sample as a raw flat dictionary


{'data.initial_sample_id': 0, 'data.sample_id': 0, 'data.protein': 'MAEVQLQASGGGFVQPGGSLRLSCAASGSYSSREVMGWFRQAPGKEREFVSAISSDSNHFRYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTATYYCALGSNEYHHMSNGYWGQGTQVTVSSA', 'data.label': 0}

In [5]:
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from fuse.data.utils.collates import CollateDefault
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp


class ProteinBindingDataModule(pl.LightningDataModule):
    """Lightning data module for the antibody / target-protein binding task.

    ``setup`` builds the train / val / test datasets with ``load_datasets`` and
    appends ``data_preprocessing`` to each split's dynamic pipeline, so the
    dataloaders yield tokenized samples collated with ``CollateDefault``.
    """

    def __init__(
        self,
        *,
        tokenizer_op: ModularTokenizerOp,
        data_preprocessing: callable,
        # Default values for simplicity
        batch_size: int = 1,
        data_path: str = "./non_binders.csv",
        protein_max_seq_length: int = 600,  # was 1250
        encoder_input_max_seq_len: int = 610,  # was 1260
        labels_max_seq_len: int = 4,
        num_workers: int = 0,
    ) -> None:
        """
        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer op, forwarded to ``data_preprocessing``.
            data_preprocessing (callable): per-sample op appended to every split's dynamic
                pipeline; it receives the keyword arguments assembled in ``setup``.
            batch_size (int): batch size used by all dataloaders.
            data_path (str): path to an existing CSV file read by ``load_datasets``
                (nothing is downloaded).
            protein_max_seq_length (int): max tokenized length of each protein sequence.
            encoder_input_max_seq_len (int): max tokenized length of the whole encoder input.
            labels_max_seq_len (int): max tokenized length of the labels and decoder inputs.
            num_workers (int): ``DataLoader`` worker processes; the default 0 loads in the
                main process, exactly as before this parameter existed.
        """
        super().__init__()
        self.data_path = data_path
        self.tokenizer_op = tokenizer_op
        self.protein_max_seq_length = protein_max_seq_length
        self.encoder_input_max_seq_len = encoder_input_max_seq_len
        self.labels_max_seq_len = labels_max_seq_len
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_preprocessing = data_preprocessing

        self.pad_token_id = self.tokenizer_op.get_token_id("<PAD>")

    def setup(self, stage: str) -> None:
        """Build all three splits and attach the task-specific preprocessing.

        ``stage`` is ignored. Every call rebuilds the datasets (with fresh
        pipelines) from ``data_path``, so repeated calls never extend the same
        pipeline twice.
        """
        self.ds_dict = load_datasets(self.data_path)

        task_pipeline = [
            (
                # Prepare the input string(s) in modular tokenizer input format
                self.data_preprocessing,
                dict(
                    protein_sequence_key="data.protein",
                    binding_label_key="data.label",
                    tokenizer_op=self.tokenizer_op,
                    protein_max_seq_length=self.protein_max_seq_length,
                    encoder_input_max_seq_len=self.encoder_input_max_seq_len,
                    labels_max_seq_len=self.labels_max_seq_len,
                ),
            ),
        ]

        for ds in self.ds_dict.values():
            ds.dynamic_pipeline.extend(task_pipeline)

    def _make_dataloader(self, split: str, *, shuffle: bool = False) -> DataLoader:
        """Wrap the ``split`` dataset built in ``setup`` in a ``DataLoader``."""
        return DataLoader(
            dataset=self.ds_dict[split],
            batch_size=self.batch_size,
            collate_fn=CollateDefault(),
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        # Only the training split is shuffled.
        return self._make_dataloader("train", shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_dataloader("val")

    def test_dataloader(self) -> DataLoader:
        return self._make_dataloader("test")

    def predict_dataloader(self) -> DataLoader:
        # Prediction runs over the held-out test split.
        return self.test_dataloader()

In [6]:
import torch

from mammal.keys import *  # Import all dictionary static keys

# Target-protein sequence paired with every antibody in the encoder input.
# NOTE(review): the original local was named `tbg_sequence` — presumably
# thyroxine-binding globulin; confirm its identity / source.
TBG_SEQUENCE = "MSPFLYLVLLVLGLHATIHCASPEGKVTACHSSQPNATLYKMSSINADFAFNLYRRFTVETPDKNIFFSPVSISAALVMLSFGACCSTQTEIVETLGFNLTDTPMVEIQHGFQHLICSLNFPKKELELQIGNALFIGKHLKPLAKFLNDVKTLYETEVFSTDFSNISAAKQEINSHVEMQTKGKVVGLIQDLKPNTIMVLVNYIHFKAQWANPFDPSKTEDSSSFLIDKTTTVQVPMMHQMEQYYHLVDMELNCTVLQMDYSKNALALFVLPKEGQMESVEAAMSSKTLKKWNRLLQKGWVDLFVPKFSISATYDLGATLLKMGIQHAYSENADFSGLTEDNGLKLSNAAHKAVLHIGEKGTEAAAVPEVELSDQPENTFLHPIIQIDRSFMLLILERSTRSILFLGKVVNPTEA"


def _tokenize_to_tensors(
    sample_dict: dict,
    tokenizer_op: ModularTokenizerOp,
    *,
    key_in: str,
    key_out_tokens_ids: str,
    key_out_attention_mask: str,
    max_seq_len: int | None,
    device: str | torch.device,
) -> None:
    """Tokenize ``sample_dict[key_in]`` in place, then convert the token ids and
    attention mask written by the tokenizer into tensors on ``device``."""
    tokenizer_op(
        sample_dict=sample_dict,
        key_in=key_in,
        key_out_tokens_ids=key_out_tokens_ids,
        key_out_attention_mask=key_out_attention_mask,
        max_seq_len=max_seq_len,
    )
    sample_dict[key_out_tokens_ids] = torch.tensor(
        sample_dict[key_out_tokens_ids], device=device
    )
    sample_dict[key_out_attention_mask] = torch.tensor(
        sample_dict[key_out_attention_mask], device=device
    )


def data_preprocessing(
    sample_dict: dict,
    *,
    protein_sequence_key: str,
    tokenizer_op: ModularTokenizerOp,
    binding_label_key: str | None = None,
    protein_max_seq_length: int = 600,  # was 1250
    encoder_input_max_seq_len: int | None = 610,  # was 1260
    labels_max_seq_len: int | None = 4,
    device: str | torch.device = "cpu",
    target_sequence: str | None = None,
) -> dict:
    """
    Build and tokenize the model inputs (and, when a label is present, the labels
    and decoder inputs) for one antibody / target-protein binding sample.

    :param sample_dict: a dictionary with raw data; it is modified in place and returned.
    :param protein_sequence_key: sample_dict key which points to the antibody protein sequence
    :param tokenizer_op: tokenizer op
    :param binding_label_key: sample_dict key which points to the label. When it is None or
        missing from sample_dict, only the encoder inputs are produced (inference mode).
    :param protein_max_seq_length: max sequence length of each protein. Will be used to truncate the proteins
    :param encoder_input_max_seq_len: max sequence length of the encoder input. Will be used to truncate/pad the encoder_input.
    :param labels_max_seq_len: max sequence length of labels. Will be used to truncate/pad the labels and decoder inputs.
    :param device: device on which the output tensors are created.
    :param target_sequence: target-protein sequence paired with the antibody;
        defaults to ``TBG_SEQUENCE``.
    :return: sample_dict with the encoder (and optionally label / decoder) strings,
        token ids and attention masks added.
    """
    protein_sequence = sample_dict[protein_sequence_key]
    binding_label = sample_dict.get(binding_label_key, None)
    if target_sequence is None:
        target_sequence = TBG_SEQUENCE

    sample_dict[ENCODER_INPUTS_STR] = (
        f"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0><@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}><MOLECULAR_ENTITY><MOLECULAR_ENTITY_ANTIBODY_HEAVY_CHAIN>{protein_sequence}<@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>{target_sequence}<EOS>"
    )
    _tokenize_to_tensors(
        sample_dict,
        tokenizer_op,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        max_seq_len=encoder_input_max_seq_len,
        device=device,
    )

    # No label (e.g. inference): encoder inputs are all that is needed.
    if binding_label is None:
        return sample_dict

    pad_id = tokenizer_op.get_token_id("<PAD>")
    ignore_token_value = -100

    sample_dict[LABELS_STR] = (
        f"<@TOKENIZER-TYPE=AA><SENTINEL_ID_0><{binding_label}><EOS>"
    )
    _tokenize_to_tensors(
        sample_dict,
        tokenizer_op,
        key_in=LABELS_STR,
        key_out_tokens_ids=LABELS_TOKENS,
        key_out_attention_mask=LABELS_ATTENTION_MASK,
        max_seq_len=labels_max_seq_len,
        device=device,
    )
    # Replace padded label positions with -100, the default `ignore_index` of
    # torch.nn.CrossEntropyLoss, so they are presumably excluded from the loss
    # (NOTE(review): confirm the task's loss uses that ignore index).
    labels = sample_dict[LABELS_TOKENS]
    labels[labels == pad_id] = ignore_token_value

    sample_dict[DECODER_INPUTS_STR] = (
        f"<@TOKENIZER-TYPE=AA><DECODER_START><SENTINEL_ID_0><{binding_label}><EOS>"
    )
    _tokenize_to_tensors(
        sample_dict,
        tokenizer_op,
        key_in=DECODER_INPUTS_STR,
        key_out_tokens_ids=DECODER_INPUTS_TOKENS,
        key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK,
        max_seq_len=labels_max_seq_len,
        device=device,
    )

    return sample_dict

In [7]:
# Load the modular tokenizer of the pre-trained MAMMAL checkpoint from the HF hub
TOKENIZER_CHECKPOINT = "ibm/biomed.omics.bl.sm.ma-ted-458m"
tokenizer_op = ModularTokenizerOp.from_pretrained(TOKENIZER_CHECKPOINT)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

tokenizer/gene_tokenizer.json:   0%|          | 0.00/2.76M [00:00<?, ?B/s]

(…)th_aug_4272372_samples_balanced_1_1.json:   0%|          | 0.00/191k [00:00<?, ?B/s]

tokenizer/t5_tokenizer_AA_special.json:   0%|          | 0.00/70.1k [00:00<?, ?B/s]

tokenizer/config.yaml:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer/cell_attributes_tokenizer.json:   0%|          | 0.00/92.7k [00:00<?, ?B/s]

In [None]:
import copy

protein_sequence_key = "data.protein_sequence"

# Minimal label-free sample: with no binding_label_key, data_preprocessing only
# builds the encoder inputs and skips labels / decoder inputs.
initial_sample_dict = {protein_sequence_key: "AAA"}

# Preprocess a deep copy so `initial_sample_dict` stays untouched for the comparison
processed_sample_dict = data_preprocessing(
    copy.deepcopy(initial_sample_dict),
    tokenizer_op=tokenizer_op,
    protein_sequence_key=protein_sequence_key,
)

print("Visualize participating sample dicts before and after the data processing for the model:")
print(f"{initial_sample_dict=}")
print(f"{processed_sample_dict=}")

In [9]:
import numpy as np


def process_model_output(
    tokenizer_op: ModularTokenizerOp,
    decoder_output: np.ndarray,
    decoder_output_scores: np.ndarray,
) -> dict | None:
    """
    Extract predicted binding class and scores
    expecting decoder output to be <SENTINEL_ID_0><0><EOS> or <SENTINEL_ID_0><1><EOS>
    note - the normalized version will calculate the positive ('<1>') score divided by the sum of the scores for both '<0>' and '<1>'
        BE CAREFUL as both negative and positive absolute scores can be drastically low, and normalized score could be very high.
    outputs a dictionary containing:
        dict(
            predicted_token_str = #... e.g. '<1>'
            not_normalized_score = #the score for the positive token... e.g.  0.01
            normalized_score = #... (positive_token_score) / (positive_token_score+negative_token_score)
        )
        if there is any error in parsing the model output, None is returned.
    """

    negative_token_id = tokenizer_op.get_token_id("<0>")
    positive_token_id = tokenizer_op.get_token_id("<1>")
    label_id_to_int = {
        negative_token_id: 0,
        positive_token_id: 1,
    }
    classification_position = 1

    if decoder_output_scores is not None:
        not_normalized_score = decoder_output_scores[
            classification_position, positive_token_id
        ]
        normalized_score = not_normalized_score / (
            not_normalized_score
            + decoder_output_scores[classification_position, negative_token_id]
            + 1e-10
        )
    ans = dict(
        pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
        not_normalized_scores=not_normalized_score,
        normalized_scores=normalized_score,
    )

    return ans

In [10]:
from typing import Any

from mammal.metrics import classification_metrics
from mammal.task import (
    MammalTask,
    MetricBase,
)


class ProteinBindingTask(MammalTask):
    """MAMMAL task for binary antibody / target-protein binding classification.

    Plugs this notebook's ``ProteinBindingDataModule``, ``data_preprocessing`` and
    ``process_model_output`` into the ``MammalTask`` interface, and adds
    classification metrics computed on the class token ('<0>' / '<1>') at
    decoder position 1.
    """

    def __init__(
        self,
        *,
        name: str,
        tokenizer_op: ModularTokenizerOp,
        logger: Any | None = None,
    ) -> None:
        super().__init__(
            name=name,
            logger=logger,
            tokenizer_op=tokenizer_op,
        )

        # Sample-dict keys (constants star-imported from mammal.keys) for
        # predictions, scores and labels.
        self.preds_key = CLS_PRED
        self.scores_key = SCORES
        self.labels_key = LABELS_TOKENS

    def data_module(self, **kwargs) -> pl.LightningDataModule:
        """Create the data module; ``kwargs`` are forwarded to ``ProteinBindingDataModule``."""
        return ProteinBindingDataModule(
            tokenizer_op=self._tokenizer_op,
            data_preprocessing=self.data_preprocessing,
            **kwargs,
        )

    def _classification_metrics(self) -> dict:
        """Classification metrics on the class token at position 1
        (labels look like <SENTINEL_ID_0><0|1><EOS>)."""
        return classification_metrics(
            self.name(),
            class_position=1,
            tokenizer_op=self._tokenizer_op,
            class_tokens=["<0>", "<1>"],
        )

    def train_metrics(self) -> dict[str, MetricBase]:
        """Base-class train metrics plus binding classification metrics."""
        metrics = super().train_metrics()
        metrics.update(self._classification_metrics())
        return metrics

    def validation_metrics(self) -> dict[str, MetricBase]:
        """Base-class validation metrics plus binding classification metrics."""
        validation_metrics = super().validation_metrics()
        validation_metrics.update(self._classification_metrics())
        return validation_metrics

    @staticmethod
    def data_preprocessing(
        sample_dict: dict,
        *,
        protein_sequence_key: str,
        tokenizer_op: ModularTokenizerOp,
        binding_label_key: str | None = None,
        # NOTE(review): these defaults (1250 / 1260) differ from the 600 / 610 that
        # ProteinBindingDataModule passes during fine-tuning; they only apply when
        # this method is called without explicit lengths — confirm which is intended.
        protein_max_seq_length: int = 1250,
        encoder_input_max_seq_len: int | None = 1260,
        labels_max_seq_len: int | None = 4,
        device: str | torch.device = "cpu",
    ) -> dict:
        """Task hook: delegates to the module-level ``data_preprocessing`` function."""
        return data_preprocessing(
            sample_dict=sample_dict,
            tokenizer_op=tokenizer_op,
            protein_sequence_key=protein_sequence_key,
            binding_label_key=binding_label_key,
            protein_max_seq_length=protein_max_seq_length,
            encoder_input_max_seq_len=encoder_input_max_seq_len,
            labels_max_seq_len=labels_max_seq_len,
            device=device,
        )

    @staticmethod
    def process_model_output(
        tokenizer_op: ModularTokenizerOp,
        decoder_output: np.ndarray,
        decoder_output_scores: np.ndarray,
    ) -> dict | None:
        """Task hook: delegates to the module-level ``process_model_output`` function."""
        return process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=decoder_output,
            decoder_output_scores=decoder_output_scores,
        )

In [14]:
from torch.optim import AdamW

from fuse.dl.lightning.pl_module import LightningModuleDefault

from mammal.model import Mammal
from mammal.lr_schedulers import cosine_annealing_with_warmup_lr_scheduler

# Fine-tuning configuration (tune here rather than inside the calls below)
FINETUNE_SEED = 42
PRETRAINED_MODEL_ID = "ibm/biomed.omics.bl.sm.ma-ted-458m"
FINETUNE_LR = 1e-5
FINETUNE_MAX_EPOCHS = 3
FINETUNE_MODEL_DIR = "mammal_binding_finetune"

# Seed for reproducibility (before the model is loaded)
pl.seed_everything(FINETUNE_SEED)

# Pre-trained MAMMAL weights from the HF hub
model = Mammal.from_pretrained(PRETRAINED_MODEL_ID)

# Our task, and the datamodule it provides for the Lightning module
task = ProteinBindingTask(name="our_binding_prediction_task", tokenizer_op=tokenizer_op)
pl_data_module = task.data_module()

# Optimizer plus a per-step cosine-annealing-with-warmup LR schedule
opt = AdamW(params=model.parameters(), lr=FINETUNE_LR)
lr_scheduler = cosine_annealing_with_warmup_lr_scheduler(optimizer=opt)
optimizers_and_lr_schs = {
    "optimizer": opt,
    "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
}

# Fuse-Med-ML's LightningModuleDefault (a pl.LightningModule subclass)
pl_module = LightningModuleDefault(
    model=model,
    losses=task.losses(),
    validation_metrics=task.validation_metrics(),
    train_metrics=task.train_metrics(),
    optimizers_and_lr_schs=optimizers_and_lr_schs,
    model_dir=FINETUNE_MODEL_DIR,
)

# Train
pl_trainer = pl.Trainer(max_epochs=FINETUNE_MAX_EPOCHS)
pl_trainer.fit(model=pl_module, datamodule=pl_data_module)


INFO:lightning_fabric.utilities.seed:Seed set to 42


Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m'


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Attempting to load model from dir: pretrained_model_name_or_path='/root/.cache/huggingface/hub/models--ibm--biomed.omics.bl.sm.ma-ted-458m/snapshots/6d319d8dcf97f8821635327fc8cda24670553daa'


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint


train set size is 606
val set size is 202
test set size is 202


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name   | Type   | Params | Mode 
------------------------------------------
0 | _model | Mammal | 458 M  | train
------------------------------------------
458 M     Trainable params
0         Non-trainable params
458 M     Total params
1,832.029 Total estimated model params size (MB)
579       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


In [13]:
# Metrics most recently logged by the trainer after fine-tuning.
# NOTE(review): validation MCC is 0 while validation accuracy is ~0.99; that
# combination typically means the predictions collapsed to a single (majority)
# class — check the label balance of non_binders.csv before trusting accuracy.
pl_trainer.logged_metrics

{'validation.losses.our_binding_prediction_task_ce': tensor(0.0176),
 'validation.losses.our_binding_prediction_task_scalars_mse': tensor(0.),
 'validation.losses.total_loss': tensor(0.0176),
 'validation.metrics.our_binding_prediction_task_perplexity': tensor(1.0178),
 'validation.metrics.our_binding_prediction_task_token_acc': tensor(0.9967),
 'validation.metrics.our_binding_prediction_task_aucroc': tensor(0.9975),
 'validation.metrics.our_binding_prediction_task_acc': tensor(0.9901),
 'validation.metrics.our_binding_prediction_task_mcc': tensor(0.),
 'train.losses.our_binding_prediction_task_ce': tensor(0.0216),
 'train.losses.our_binding_prediction_task_scalars_mse': tensor(0.),
 'train.losses.total_loss': tensor(0.0216),
 'train.metrics.our_binding_prediction_task_perplexity': tensor(1.0218),
 'train.metrics.our_binding_prediction_task_token_acc': tensor(0.9967),
 'train.metrics.our_binding_prediction_task_aucroc': tensor(0.5628),
 'train.metrics.our_binding_prediction_task_acc': 