# Multi-task Training with Hugging Face Transformers and NLP

### Or: A recipe for multi-task training with Transformers' Trainer and NLP datasets



Hugging Face has been building a lot of exciting new NLP functionality lately. The newly released [NLP](https://github.com/huggingface/nlp) provides a wide coverage of task data sets and metrics, as well as a simple interface for processing and caching the inputs extremely efficiently. They have also recently introduced a [Trainer](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class to the Transformers library that handles all of the training and validation logic.

However, one feature that Hugging Face's current offerings do not yet support is *multi-task training*. While there has been some discussion about the best way to support multi-task training ([1](https://github.com/huggingface/transformers/issues/4340), [2](https://github.com/huggingface/nlp/issues/217)), the community has not yet settled on a convention for doing so. Multi-task training has been shown to improve task performance ([1](https://www.aclweb.org/anthology/P19-1441/), [2](https://arxiv.org/abs/1910.10683)) and is a common experimental setting for NLP researchers.

In this Colab notebook, we will show how to use both the new NLP library and the Trainer for a **multi-task** training scheme.

So let's get started!

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Library setup

First up, we will install the *NLP* and *Transformers* libraries. 

<font color='red'>**Note: After running the following cell, you will need to restart your runtime for the installation to work properly.**</font>

In [None]:
#!pip install git+https://github.com/huggingface/nlp
!pip install -q transformers==2.11.0
!pip install -q nlp==0.2.0
!pip install -q datasets


In [None]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
logging.basicConfig(level=logging.INFO)

from datasets import load_dataset

## Fetching our data

To showcase our multi-task functionality, we will choose tasks of different formats:

* STS-B: A two-sentence textual similarity scoring task. (Prediction is a real number between 0 and 5)
* RTE: A two-sentence natural language entailment task. (Prediction is one of two classes)
* Commonsense QA: A multiple-choice question-answering task. (Each example consists of 5 separate text inputs, prediction is which one of the 5 choices is correct)

In particular, notice that unlike STS-B and RTE, Commonsense QA consists of feeding *multiple* inputs into the transformer model. Many other tasks have weirder formats too, so our setup needs to be flexible enough to accommodate very different kinds of tasks.

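Now, actually getting the task data is super simple. We can simply call the `nlp.load_dataset` method, which automatically downloads the data and prepares it for use. (The cells below use the equivalent `load_dataset` function from the `datasets` package, and load the SciTail entailment and US Airline sentiment datasets; the other tasks are left commented out.)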

In [None]:
# dataset_dict = {
#     "mrpc": load_dataset('glue', 'mrpc', split="train+test+validation"),
#     #"stsb": load_dataset('glue', 'stsb'),

#     # "rte": load_dataset('glue', 'rte'),
#     "scitail": load_dataset('scitail', 'tsv_format', split="train[:10%]+test+validation"),

#     #"commonsense_qa": nlp.load_dataset('commonsense_qa'),
#     "US_Airline": load_dataset("Shayanvsf/US_Airline_Sentiment", split="train[:10%]+test+validation"),
# }

In [None]:
dataset_dict = {
    # "mrpc": load_dataset('glue', 'mrpc'),
    # #"stsb": load_dataset('glue', 'stsb'),

    # "rte": load_dataset('glue', 'rte'),
    "scitail": load_dataset('scitail', 'tsv_format'),

    #"commonsense_qa": nlp.load_dataset('commonsense_qa'),
    "US_Airline": load_dataset("Shayanvsf/US_Airline_Sentiment"),
}


In [None]:
import pandas as pd
df_train = pd.DataFrame(dataset_dict["scitail"]['train'])
df_test = pd.DataFrame(dataset_dict["scitail"]['test'])
df_validation =  pd.DataFrame(dataset_dict["scitail"]['validation'])

In [None]:
df_train['label'] = (df_train['label'] == 'entails').astype(int)
df_test['label'] = (df_test['label'] == 'entails').astype(int)
df_validation['label'] = (df_validation['label'] == 'entails').astype(int)

In [None]:
from datasets import Dataset
dataset_dict["scitail"]['train'] = Dataset.from_pandas(df_train)
dataset_dict["scitail"]['test'] = Dataset.from_pandas(df_test)
dataset_dict["scitail"]['validation'] = Dataset.from_pandas(df_validation)

In [None]:
# a = Subset(dataset_dict['US_Airline']["train"], list(range(123)))
# pd.DataFrame(a.dataset)

In [None]:
# from torch.utils.data import Subset

# dataset_dict['US_Airline']['train'] = Subset(dataset_dict['US_Airline']["train"], list(range(123)))
# dataset_dict['mrpc']['train'] = Subset(dataset_dict['mrpc']["train"], list(range(123)))
# dataset_dict['scitail']['train'] = Subset(dataset_dict['scitail']["train"], list(range(123)))

# dataset_dict

We can show one example from each task.

In [None]:
for task_name, dataset in dataset_dict.items():
    print(task_name)
    print(dataset_dict[task_name]["train"][50])
    print()

scitail
{'premise': 'In this food chain, energy flows from the grass (producer) to the deer (primary consumer) to the tiger (secondary consumer).', 'hypothesis': 'A consumer and producer in a food chain are involved when a deer eats a leaf.', 'label': 1}

US_Airline
{'airline_sentiment': 1, 'airline_sentiment_confidence': 0.6744, 'negativereason_confidence': 0.0, 'text': '@SouthwestAir The Fact That U See Black History Month 12 Months A Year Is Honorable! We WILL BE An Economic Base For Corp. Like U In Future!'}



## Creating a Multi-task Model

Next up, we are going to create a multi-task model. 

Typically, a multi-task model in the age of BERT works by having a shared BERT-style encoder transformer, and different task heads for each task.

![Multi-Task 1](https://drive.google.com/uc?id=1TCdyyoHInbiZtSOUmyJN1miCj1iysygU)

We could try to implement this directly in code, but there are two downsides to this approach:

1. Hugging Face's Transformers has implementations for single-task models, but not modular task heads. This means we will need to do a lot of our own leg work to write our own task heads.
2. This format assumes that the input is processed the same way in the encoder for every task. Already, Commonsense QA is problematic for this approach, since it requires the encoder to process *multiple* input sequences for a single example. Other tasks may similarly break this abstraction.

Instead, we are going to do something **radically different**. We are going to create separate models for each task, but we are going to make them share the same encoder.

![Multi-Task 2](https://drive.google.com/uc?id=1xmghPPO5RC-TnpYP4_PpZ-TRfJF33S6p)

This will serve the same goal as having the encoder be jointly trained across multiple tasks, but still retain the independent implementations of each model. As such, we can use the existing task-model implementations in Transformers, such as `RobertaForSequenceClassification` and `RobertaForMultipleChoice`.

Importantly, the shared encoder ensures that during training, all updates apply to the same encoder weights, and sharing it **does not consume any additional GPU memory** for the encoder, since its weights are stored only once.
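The sharing mechanism is ordinary object aliasing, and it can be illustrated without any real weights. The sketch below is only an illustration: `ToyEncoder` and `ToyTaskModel` are hypothetical stand-ins for a Transformers encoder and a task model, and the "update" is a plain list mutation rather than a gradient step.

```python
# Toy stand-ins for transformers models (hypothetical classes, no real weights).
class ToyEncoder:
    def __init__(self):
        self.weights = [0.0, 0.0, 0.0]


class ToyTaskModel:
    def __init__(self, num_labels):
        self.encoder = ToyEncoder()      # each model starts with its own encoder
        self.head = [0.0] * num_labels   # task-specific parameters stay separate


model_a = ToyTaskModel(num_labels=2)
model_b = ToyTaskModel(num_labels=5)

# Re-point model_b at model_a's encoder; model_b's original encoder is dropped.
setattr(model_b, "encoder", model_a.encoder)

# An "update" made through one task model is visible through the other.
model_a.encoder.weights[0] += 1.0
print(model_b.encoder.weights[0])          # 1.0
print(model_a.encoder is model_b.encoder)  # True
```

Only one encoder object survives, while each task model keeps its own head.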

First, we define our `MultitaskModel` class:

In [None]:
class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models. 

        We do this by creating each single-task model, and having them share
        the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name, 
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        return self.taskmodels_dict[task_name](**kwargs)

As described above, the `MultitaskModel` class consists of only two components: the shared encoder and a dictionary of the individual task models. Now we can create the corresponding task models by supplying the individual model classes and model configs. We will use Transformers' AutoModels to automate the choice of model class for a given model architecture (in our case, `bert-base-uncased`).

In [None]:
model_name = "bert-base-uncased"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        # "mrpc": transformers.AutoModelForSequenceClassification,
        "scitail": transformers.AutoModelForSequenceClassification,
        "US_Airline": transformers.AutoModelForSequenceClassification,
        #"commonsense_qa": transformers.AutoModelForMultipleChoice,
    },
    model_config_dict={
        # "mrpc": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "scitail": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "US_Airline": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        #"commonsense_qa": transformers.AutoConfig.from_pretrained(model_name),
    },
)

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

To confirm that both task models use the same encoder, we can check the data pointers of the respective encoders. In this case, we'll check that the word embeddings in each model all point to the same memory location.

In [None]:
if model_name.startswith("bert-"):
    print(multitask_model.encoder.embeddings.word_embeddings.weight.data_ptr())
    # print(multitask_model.taskmodels_dict["mrpc"].roberta.embeddings.word_embeddings.weight.data_ptr())
    print(multitask_model.taskmodels_dict["scitail"].bert.embeddings.word_embeddings.weight.data_ptr())
    print(multitask_model.taskmodels_dict["US_Airline"].bert.embeddings.word_embeddings.weight.data_ptr())
else:
    print("Exercise for the reader: add a check for other model architectures =)")

193724416
193724416
193724416


## Processing our task data

We have created a dictionary of NLP datasets above, but we need to do a little more work to convert the respective task data into model inputs.

We'll start by first getting the tokenizer corresponding to our model.

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Next, we'll write some short functions to convert from raw text to tokenized text inputs. 

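* Both STS-B and RTE are two-sentence input tasks, so we will concatenate the two sentences with the corresponding special tokens. (The tokenizer's `batch_encode_plus` method handles this for us.) With a RoBERTa tokenizer, the input might look like the following (a BERT tokenizer such as `bert-base-uncased` uses `[CLS]` and `[SEP]` instead):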

```
['<s>', 'This', 'is', 'my', 'premise', '.', '</s>', '</s>', 'This', 'is', 'my', 'hypothesis', '.', '</s>']
```

* CommonsenseQA is a multiple-choice task. A single example consists of a question and five possible answer choices. We will feed the model inputs concatenated like `QUESTION + CHOICE_1`, `QUESTION + CHOICE_2` and so on.
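To make the multiple-choice format concrete, here is a minimal sketch of the pairing step only, before any tokenization. The helper name `build_choice_pairs`, the example question, and the field names are invented for illustration and are not taken from the Commonsense QA dataset.

```python
def build_choice_pairs(question, choices):
    """Pair the question with every answer choice, keeping the choice order."""
    return [(question, choice) for choice in choices]


# Invented example; the field names here are assumptions, not the dataset schema.
example = {
    "question": "Where would you put a plate after washing it?",
    "choices": ["cupboard", "floor", "garden", "oven", "car"],
    "answer_index": 0,
}

pairs = build_choice_pairs(example["question"], example["choices"])
for pair in pairs:
    print(pair)
```

Each of the five pairs would then be tokenized like a two-sentence input, and the model scores the pairs against each other.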

In [None]:
max_length = 128


def convert_to_scitail_features(example_batch):
    inputs = list(zip(example_batch['premise'], example_batch['hypothesis']))
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["label"]
    return features

def convert_to_US_Airline_features(example_batch):
    inputs = list(example_batch['text'])
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["airline_sentiment"]
    return features

convert_func_dict = {
    # "mrpc": convert_to_mrpc_features,
    "scitail": convert_to_scitail_features,
    "US_Airline": convert_to_US_Airline_features,
}

Now that we have defined the above functions, we can use the `dataset.map` method available in the NLP library to apply the functions over our entire datasets. The NLP library handles the mapping efficiently and caches the features.
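With `batched=True`, the conversion function receives a dictionary of lists (one list per column) rather than a single example, and returns new columns in the same shape. The toy function below mimics that contract in plain Python. It is a sketch for intuition only: `toy_batched_map` and `add_lengths` are hypothetical names, and this is not how the library implements `map` (it does no caching, for instance).

```python
def toy_batched_map(columns, fn, batch_size=2):
    """Apply fn to slices of a column-oriented table and merge the results."""
    num_rows = len(next(iter(columns.values())))
    output = {}
    for start in range(0, num_rows, batch_size):
        batch = {
            name: values[start:start + batch_size]
            for name, values in columns.items()
        }
        result = fn(batch)
        merged = {**batch, **result}  # new columns sit alongside the originals
        for name, values in merged.items():
            output.setdefault(name, []).extend(values)
    return output


def add_lengths(example_batch):
    # A batched function takes a dict of lists and returns a dict of lists.
    return {"num_chars": [len(text) for text in example_batch["text"]]}


table = {"text": ["good flight", "late again", "ok"], "label": [1, 0, 1]}
mapped = toy_batched_map(table, add_lengths)
print(mapped["num_chars"])  # [11, 10, 2]
```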

In [None]:
columns_dict = {
    # "mrpc": ['input_ids', 'attention_mask', 'labels'],
    "scitail": ['input_ids', 'attention_mask', 'labels'],
    "US_Airline": ['input_ids', 'attention_mask', 'labels'],
}

features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

  0%|          | 0/24 [00:00<?, ?ba/s]

scitail train 23097 23097
scitail train 23097 23097


  0%|          | 0/3 [00:00<?, ?ba/s]

scitail test 2126 2126
scitail test 2126 2126


  0%|          | 0/2 [00:00<?, ?ba/s]

scitail validation 1304 1304
scitail validation 1304 1304


  0%|          | 0/9 [00:00<?, ?ba/s]

US_Airline train 8078 8078
US_Airline train 8078 8078


  0%|          | 0/3 [00:00<?, ?ba/s]

US_Airline validation 2308 2308
US_Airline validation 2308 2308


  0%|          | 0/2 [00:00<?, ?ba/s]

US_Airline test 1155 1155
US_Airline test 1155 1155


As a recap:

* We have created our multi-task model by fusing several single-task Transformer models
* We have created a (cached) dictionary of featurized inputs for each of our tasks, using NLP datasets

Next up, we need to 

1. Set up our data loading
2. Set up our Trainer 
3. Start training!

## Preparing a multi-task data loader and Trainer

Setting up a multi-task data loader should be simple in principle - we simply need to sample from multiple single-task data loaders with some probability, and feed each batch to the multi-task model above. Of course, along with each batch, we also need to tell the model what task it is for, so `MultitaskModel` knows to use the right corresponding task-model.
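The sampling idea can be sketched without any real data loaders. In the sketch below, each task contributes as many entries to a schedule as it has batches, and the schedule is shuffled, so larger tasks are visited proportionally more often. The function name `make_task_schedule` and the batch counts are made up for illustration.

```python
import random


def make_task_schedule(num_batches_per_task, seed=None):
    """Return a shuffled list of task names, one entry per batch."""
    schedule = []
    for task_name, num_batches in num_batches_per_task.items():
        schedule += [task_name] * num_batches
    random.Random(seed).shuffle(schedule)
    return schedule


# Hypothetical batch counts: the larger task appears more often in the schedule.
schedule = make_task_schedule({"scitail": 6, "US_Airline": 2}, seed=0)
print(schedule)
print(schedule.count("scitail"), schedule.count("US_Airline"))  # 6 2
```

Because every batch of every task appears exactly once, one pass over the schedule corresponds to one epoch over all tasks.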

However, because we want to use the built-in `Trainer` class in Transformers, this gets a little tricky, since the `Trainer` expects a single data loader, and expects a very specific format of per-batch data. This slice of code is somewhat of a hack around that constraint. (This can become a lot more streamlined with some tweaks to the Trainer code from the Hugging Face folks =))

We need to define a `MultitaskDataloader` that combines several data loaders into a single "data loader" - not so different from our multi-task model above! This `MultitaskDataloader` should do what we described: sample from different single-task data loaders, and yield a task batch and the corresponding task name (we're going to add the `task_name` to the batch data).
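Attaching the task name to each batch has one wrinkle: a training loop that calls `.to(device)` on every batch value would fail on a plain string. A small sketch of the workaround, using plain lists in place of tensors, is shown below. `TaskName` and `tag_batches` are hypothetical names chosen for this illustration.

```python
class TaskName(str):
    """A string that tolerates `.to(device)` by returning itself."""

    def to(self, device):
        return self


def tag_batches(task_name, batches):
    """Yield a copy of each batch dict with an extra `task_name` entry."""
    for batch in batches:
        tagged = dict(batch)
        tagged["task_name"] = TaskName(task_name)
        yield tagged


# Toy batches; real batches would hold tensors instead of nested lists.
toy_batches = [{"input_ids": [[101, 102]]}, {"input_ids": [[101, 103]]}]
tagged = list(tag_batches("scitail", toy_batches))
print(tagged[0]["task_name"].to("cuda"))  # scitail
```

Since `TaskName` is still a `str`, it can be used directly as a dictionary key when the model looks up the matching task model.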

We will also need to override the `get_train_dataloader` method of the `Trainer` to play well with our `MultitaskDataloader`. We do this with a `MultitaskTrainer`.

In [None]:
import dataclasses
from torch.utils.data.dataloader import DataLoader
from transformers.training_args import is_tpu_available
from transformers.trainer import get_tpu_sampler
from transformers.data.data_collator import DataCollator, DefaultDataCollator, InputDataClass
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict


class NLPDataCollator(DataCollator):
    """
    Extending the existing DataCollator to work with NLP dataset batches
    """
    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if isinstance(first, dict):
            # NLP datasets currently present features as a list of dictionaries
            # (one per example), so we adapt the collate_batch logic for that.
            batch = {}  # initialized here so unlabeled batches do not raise NameError
            if "labels" in first and first["labels"] is not None:
                if first["labels"].dtype == torch.int64:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
                else:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
                batch["labels"] = labels
            for k, v in first.items():
                if k != "labels" and v is not None and not isinstance(v, str):
                    batch[k] = torch.stack([f[k] for f in features])
            return batch
        else:
            # Otherwise, fall back to the default collate_batch
            return DefaultDataCollator().collate_batch(features)


class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going to call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error.
    """
    def to(self, device):
        return self


class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader to also yield a task name
    """
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)
    
    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            len(dataloader.dataset) 
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])    


import json
import logging
import math
import os
import random
import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from transformers.data.data_collator import DataCollator, DefaultDataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
from transformers.training_args import TrainingArguments, is_tpu_available


try:
    from apex import amp

    _has_apex = True
except ImportError:
    _has_apex = False


def is_apex_available():
    return _has_apex


if is_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

try:
    from torch.utils.tensorboard import SummaryWriter

    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter

        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False


def is_tensorboard_available():
    return _has_tensorboard


try:
    import wandb

    wandb.ensure_configured()
    if wandb.api.api_key is None:
        _has_wandb = False
        wandb.termwarn("W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.")
    else:
        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
except ImportError:
    _has_wandb = False


def is_wandb_available():
    return _has_wandb


logger = logging.getLogger(__name__)


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class MultitaskTrainer(transformers.Trainer):


    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """

        OUTPUT_DIM = 2
        NUM_EPOCHS = 10
        Prob_per_epoch_Sentiment = np.zeros((NUM_EPOCHS,len(dataset_dict["US_Airline"]['train']),OUTPUT_DIM)) #2 labels
        Prob_per_epoch_Entailment = np.zeros((NUM_EPOCHS,len(dataset_dict["scitail"]['train']),OUTPUT_DIM)) #2 labels
        # Prob_per_epoch_Paraphrase = np.zeros((NUM_EPOCHS,len(dataset_dict["mrpc"]['train']),OUTPUT_DIM)) #2 labels

        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

            train_preds_dict = {'scitail': None, "US_Airline": None} #TODO: automatic generation with for ...
            # stacked_train_preds_Sentiment = None
            # stacked_train_preds_Entailment = None
            # stacked_train_preds_Paraphrase = None
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                
                (temp_loss, label_logits) = self._training_step(model, inputs, optimizer)
                # Accumulate the loss so the logged deltas and the final average are meaningful
                tr_loss += temp_loss

                # Extract per-sample class probabilities for this batch
                preds = nn.Softmax(dim=1)(label_logits)  # logits -> probabilities

                # Move the probabilities to the CPU as a NumPy array
                train_preds = preds.detach().cpu().numpy()

                if train_preds_dict[inputs["task_name"]] is None:  # first batch of this task
                    train_preds_dict[inputs["task_name"]] = train_preds
                else:
                    train_preds_dict[inputs["task_name"]] = np.vstack(
                        (train_preds_dict[inputs["task_name"]], train_preds)
                    )

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            # Instead of self.evaluate(), score each task separately
                            # on its training split.
                            for task_name in train_preds_dict:
                                eval_dataset = self.eval_dataset[task_name]["train"]
                                eval_dataloader = DataLoaderWithTaskname(
                                    task_name,
                                    self.get_eval_dataloader(eval_dataset=eval_dataset),
                                )
                                print(eval_dataloader.data_loader.collate_fn)
                                output = self._prediction_loop(
                                    eval_dataloader,
                                    description=f"Train: {task_name}",
                                )
                                self._log(output.metrics)
                                print("acc:", np.mean(
                                    np.argmax(output.predictions, axis=1)
                                    == output.label_ids
                                ))

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            
            
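            # TODO: generalize beyond the two hard-coded tasks.
            # Store this epoch's stacked probabilities; the Prob_per_epoch_* arrays
            # are assumed to be allocated earlier in the notebook.
            Prob_per_epoch_Sentiment[epoch][:, :] = train_preds_dict['US_Airline']
            Prob_per_epoch_Entailment[epoch][:, :] = train_preds_dict['scitail']
            # Prob_per_epoch_Paraphrase[epoch][:, :] = train_preds_dict['mrpc']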

            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step), Prob_per_epoch_Sentiment, Prob_per_epoch_Entailment



    def _training_step(
        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
    ):  # returns (loss value as a float, logits tensor)
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        logits = outputs[1]
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item(), logits

    
    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_tpu_available():
            train_sampler = get_tpu_sampler(train_dataset)
        else:
            # train_sampler = (
            #     RandomSampler(train_dataset)
            #     if self.args.local_rank == -1
            #     else DistributedSampler(train_dataset)
            # )
            train_sampler = SequentialSampler(train_dataset)

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
              train_dataset,
              batch_size=self.args.train_batch_size,
              sampler=train_sampler,
              collate_fn=self.data_collator.collate_batch,
            ),
        )

        if is_tpu_available():
            data_loader = pl.ParallelLoader(
                data_loader, [self.args.device]
            ).per_device_loader(self.args.device)
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each 
        task Dataloader
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })
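
The idea behind the multi-task loader can be sketched without PyTorch. This is a minimal illustration, not the notebook's `MultitaskDataloader`: the function names `iter_multitask_batches` and `batches_per_epoch`, and the use of plain lists as stand-ins for per-task loaders, are assumptions made for this example. Every task contributes each of its batches once per epoch, and the order in which tasks are visited is shuffled.

```python
import math
import random


def iter_multitask_batches(task_batches, seed=None):
    """Yield (task_name, batch) pairs, visiting tasks in a shuffled order.

    task_batches maps a task name to a list of batches (a stand-in for a
    per-task DataLoader). Each batch is yielded exactly once, and batches
    within a task keep their original order.
    """
    rng = random.Random(seed)
    # One schedule entry per batch, so larger tasks are visited proportionally more often.
    schedule = [name for name, batches in task_batches.items() for _ in batches]
    rng.shuffle(schedule)
    iterators = {name: iter(batches) for name, batches in task_batches.items()}
    for name in schedule:
        yield name, next(iterators[name])


def batches_per_epoch(num_examples, batch_size):
    """Number of batches a loader yields when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# Toy data: task "a" has 3 batches, task "b" has 2.
toy = {"a": ["a0", "a1", "a2"], "b": ["b0", "b1"]}
order = list(iter_multitask_batches(toy, seed=0))
print(order)

# With the dataset sizes reported later in this notebook and a batch size of 64:
steps = batches_per_epoch(8078, 64) + batches_per_epoch(23097, 64)
print(steps)  # 488
```

The last line matches the 488 iterations per epoch shown in the training progress bars: one epoch is the sum of the per-task batch counts.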

## Time to train!

Okay, we have done all the hard work, and now it is time for it to pay off. We can simply create our `MultitaskTrainer` and start training!

(This takes about 45 minutes for me on Colab, but it will depend on the GPU you are allocated.)

In [None]:
!pip install -q wandb



In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
torch.cuda.memory_summary(device=None, abbreviated=False)



In [None]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}

eval_dataset = features_dict


trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="/content/drive/MyDrive/NLP Bachelors' Project/checkpoint/MTL(Sentiment,Entailment)/",
        overwrite_output_dir=True,
        learning_rate=1e-5,
        do_train=True,
        num_train_epochs=10,
        evaluate_during_training=True,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=64,
        save_steps=488 * 2,  # checkpoint every 2 epochs (488 steps per epoch here)
        logging_steps=488,  # log and evaluate once per epoch
    ),
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset 
)
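# Resume from the step-2928 checkpoint (6 epochs of 488 steps). train() returns
# (TrainOutput, sentiment probabilities, entailment probabilities).
an = trainer.train('checkpoint-2928')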

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Iteration:   0%|          | 0/488 [00:00<?, ?it/s]

{"loss": 0.0004969353436446581, "learning_rate": 9e-06, "epoch": 7.0, "step": 3416}
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: scitail:   0%|          | 0/2888 [00:00<?, ?it/s]

{"eval_loss": 0.24908141630000005, "epoch": 7.0, "step": 3416}
acc: 0.8989479153136771
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: US_Airline:   0%|          | 0/1010 [00:00<?, ?it/s]

{"eval_loss": 0.19733879812393743, "epoch": 7.0, "step": 3416}
acc: 0.9196583312701163


Iteration:   0%|          | 0/488 [00:00<?, ?it/s]

{"loss": -0.00030474625833210397, "learning_rate": 8.000000000000001e-06, "epoch": 8.0, "step": 3904}
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: scitail:   0%|          | 0/2888 [00:00<?, ?it/s]

{"eval_loss": 0.15638995563615904, "epoch": 8.0, "step": 3904}
acc: 0.942503355414123
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: US_Airline:   0%|          | 0/1010 [00:00<?, ?it/s]

{"eval_loss": 0.07791496693424069, "epoch": 8.0, "step": 3904}
acc: 0.9733845011141372


Iteration:   0%|          | 0/488 [00:00<?, ?it/s]

{"loss": -1.8276243669087768e-05, "learning_rate": 7e-06, "epoch": 9.0, "step": 4392}
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: scitail:   0%|          | 0/2888 [00:00<?, ?it/s]

{"eval_loss": 0.1298135291335787, "epoch": 9.0, "step": 4392}
acc: 0.9509027146382647
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: US_Airline:   0%|          | 0/1010 [00:00<?, ?it/s]

{"eval_loss": 0.04091287649439482, "epoch": 9.0, "step": 4392}
acc: 0.9891062144095073


Iteration:   0%|          | 0/488 [00:00<?, ?it/s]

{"loss": -8.719893111312976e-05, "learning_rate": 6e-06, "epoch": 10.0, "step": 4880}
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: scitail:   0%|          | 0/2888 [00:00<?, ?it/s]

{"eval_loss": 0.134148374070069, "epoch": 10.0, "step": 4880}
acc: 0.9530242022773521
<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Train: US_Airline:   0%|          | 0/1010 [00:00<?, ?it/s]

{"eval_loss": 0.027687410784235207, "epoch": 10.0, "step": 4880}
acc: 0.9929437979697945
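
The run above shows only 4 epochs in the progress bar because training was resumed from `checkpoint-2928`. The resume logic in `train` reduces to integer arithmetic on the step number encoded in the checkpoint name. A minimal sketch follows; `resume_position` is a hypothetical helper written for this example, and it assumes one optimizer step per batch (no gradient accumulation) with 488 batches per epoch, as in the run above.

```python
def resume_position(model_path, batches_per_epoch, gradient_accumulation_steps=1):
    """Return (global_step, epochs_trained, steps_to_skip) for a checkpoint path."""
    # "checkpoint-2928" or "out/checkpoint-2928/" -> 2928
    global_step = int(model_path.split("-")[-1].split("/")[0])
    updates_per_epoch = batches_per_epoch // gradient_accumulation_steps
    epochs_trained = global_step // updates_per_epoch
    steps_to_skip = global_step % updates_per_epoch
    return global_step, epochs_trained, steps_to_skip


step, epochs_done, skip = resume_position("checkpoint-2928", batches_per_epoch=488)
print(step, epochs_done, skip)  # 2928 6 0
print(10 - epochs_done)         # 4 epochs left out of num_train_epochs=10
```

Note that the parsing relies on the step number following the last hyphen in the path, so a directory name containing a later hyphen would break it.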


In [None]:
[Prob_per_epoch_Sentiment_MTL, Prob_per_epoch_Entailment_MTL] = an[1:3]

In [None]:
print(Prob_per_epoch_Sentiment_MTL.shape)
print(Prob_per_epoch_Entailment_MTL.shape)
# print(Prob_per_epoch_Paraphrase_MTL.shape)

(10, 8078, 2)
(10, 23097, 2)
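
These shapes read as (epochs, examples, classes). The sketch below reproduces, with NumPy only, how the training loop builds one such array: each batch of logits is turned into probabilities with a softmax, batches are stacked row-wise, and the stacked result fills one epoch slot. The sizes and the random logits are made up for illustration. Row `i` refers to the same example in every epoch only if batches arrive in the same order each epoch, which is presumably why the trainer uses a `SequentialSampler`.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


num_epochs, num_examples, num_classes, batch_size = 3, 10, 2, 4
rng = np.random.default_rng(0)
prob_per_epoch = np.zeros((num_epochs, num_examples, num_classes))

for epoch in range(num_epochs):
    stacked = None
    for start in range(0, num_examples, batch_size):
        stop = min(start + batch_size, num_examples)
        # Fake model output standing in for the logits of one batch
        batch_logits = rng.normal(size=(stop - start, num_classes))
        batch_probs = softmax(batch_logits)
        # The first batch initializes the array; later batches are appended as rows.
        stacked = batch_probs if stacked is None else np.vstack((stacked, batch_probs))
    prob_per_epoch[epoch] = stacked

print(prob_per_epoch.shape)  # (3, 10, 2)
```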


In [None]:
%cd 
import pickle
NUM_EPOCHS = 10

#with open(f"/content/drive/MyDrive/NLP Bachelors' Project/Sentiment_MTL_SNT_ENT_prob_per_{NUM_EPOCHS}epochs.pkl", 'wb') as f:
with open(f"/content/drive/MyDrive/NLP Bachelors' Project/FRG_Info(pkl files)/Multi Task/Sentiment_MTL_prob_per_{NUM_EPOCHS}epochs_ES.pkl", 'wb') as f:
    pickle.dump(Prob_per_epoch_Sentiment_MTL, f)

#with open(f"/content/drive/MyDrive/NLP Bachelors' Project/Entailment_MTL_SNT_ENT_prob_per_{NUM_EPOCHS}epochs.pkl", 'wb') as f:
with open(f"/content/drive/MyDrive/NLP Bachelors' Project/FRG_Info(pkl files)/Multi Task/Entailment_MTL_prob_per_{NUM_EPOCHS}epochs_ES.pkl", 'wb') as f:
    pickle.dump(Prob_per_epoch_Entailment_MTL, f)



/root
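
To reuse the saved arrays in a later session, they can be read back with `pickle.load`. The sketch below round-trips a small stand-in array through a temporary file; the file name and the array contents are placeholders, not the notebook's actual data.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for an (epochs, examples, classes) probability array
probs = np.full((2, 5, 2), 0.5)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "probs_demo.pkl")  # placeholder file name
    with open(path, "wb") as f:
        pickle.dump(probs, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored.shape)
print(np.array_equal(probs, restored))
```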


All done! Now we can evaluate our multi-task model on both tasks. In this case, we can simply use single-task data loaders, since we are evaluating each task individually.

We will use the (private) `_prediction_loop` method from the Trainer.

# **Evaluate** on test data

In [None]:
preds_dict = {}
for task_name in ["scitail","US_Airline"]:
    eval_dataloader = DataLoaderWithTaskname(
        task_name,
        trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"])
    )
    print(eval_dataloader.data_loader.collate_fn)
    preds_dict[task_name] = trainer._prediction_loop(
        eval_dataloader, 
        description=f"Test: {task_name}",
    )

<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Test: scitail:   0%|          | 0/266 [00:00<?, ?it/s]

<bound method NLPDataCollator.collate_batch of <__main__.NLPDataCollator object at 0x7f5c3989da90>>


Test: US_Airline:   0%|          | 0/145 [00:00<?, ?it/s]

Now that we have all the predictions, let's go ahead and score them. Both tasks here, sentiment on US_Airline and entailment on SciTail, are classification tasks, so simple accuracy is a natural metric, and we can compute it easily.

In [None]:
# Evaluate sentiment
np.mean(
    np.argmax(preds_dict["US_Airline"].predictions, axis=1)
    == preds_dict["US_Airline"].label_ids
)

0.9515151515151515

In [None]:
# Evaluate scitail
np.mean(
    np.argmax(preds_dict["scitail"].predictions, axis=1)
    == preds_dict["scitail"].label_ids
)

0.8842897460018815
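
Both cells apply the same computation, so it can be wrapped in a small helper. This is a sketch: `accuracy` is a hypothetical function name, and the `Prediction` namedtuple is only a stand-in for the object returned by `_prediction_loop`, of which it mimics the `predictions` and `label_ids` fields used above.

```python
from collections import namedtuple

import numpy as np

# Stand-in for the prediction output; only the two fields used here are modeled.
Prediction = namedtuple("Prediction", ["predictions", "label_ids"])


def accuracy(output):
    """Fraction of examples whose highest-scoring class equals the label."""
    predicted = np.argmax(output.predictions, axis=1)
    return float(np.mean(predicted == output.label_ids))


demo = {
    "task_a": Prediction(
        np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]]),
        np.array([0, 1, 1, 1]),
    ),
}
for task_name, output in demo.items():
    print(task_name, accuracy(output))  # task_a 0.75
```

With the real `preds_dict`, the same loop would report one accuracy per task name.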

In the run above, the test accuracies were approximately:

* US_Airline (sentiment): ~0.95
* SciTail (entailment): ~0.88

These aren't award-winning scores, nor are our tasks chosen for multi-task training synergy, but hopefully we have demonstrated how to do multi-task training with some of Hugging Face's latest offerings!

# An advertisement: Come check out jiant!

While the above recipe works, we saw what some of the frictions were: handling multi-task data loading, coercing the Trainer to work with multi-task inputs, and handling the featurization for each of the tasks.

If you are interested in more streamlined multi-task (or even single-task) fine-tuning work, we are building [jiant](https://jiant.info/), a research-oriented NLP library built directly on Transformers, where multi-task training is a first-class feature. `jiant` aims to facilitate cutting-edge NLP transfer learning research through broad task coverage and modular components, and we highly recommend using `jiant` for streamlined multi-task training workflows.

(If you've previously worked with `jiant`, we are currently undertaking [a complete rewrite](https://github.com/jiant-dev/jiant) to better support current research needs and engineering workflows.)

Click [here](https://jiant.info/) to learn more, or attend our system demo presentation at [ACL 2020](https://acl2020.org/program/accepted/#system-demonstrations).