# Multitask \*BERT\* training 

Fork of a wonderful notebook https://colab.research.google.com/github/zphang/zphang.github.io/blob/master/files/notebooks/Multi_task_Training_with_Transformers_NLP.ipynb#scrollTo=U4YUxdIZz3_i

Ideas on question answering train implementation from https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb#scrollTo=a9a5qSsEqiWp

- Updated to be used with fresh versions of `transformers` and `datasets`
- Training Russian models on 6 tasks from Russian SuperGLUE, XNLI, sberquad(in progress)
- Evaluation

In [1]:
%%capture
! pip install transformers datasets natasha

In [5]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import datasets
import pprint
import pandas as pd
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DefaultDataCollator, InputDataClass
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict, Iterable, Callable, Any, Optional
import os
import os.path
# from google.colab import drive
from natasha import Segmenter, NewsEmbedding, Doc
from tqdm.auto import tqdm
import pymorphy2
from multiprocessing import Pool, cpu_count
import gensim

In [17]:
k = np.random.rand(512, 768)
q = np.random.rand(256, 768)
q.shape, k.transpose().shape

((256, 768), (768, 512))

In [11]:
segmenter = Segmenter()
morph = pymorphy2.MorphAnalyzer()
emb = NewsEmbedding()
glove = emb.as_gensim

pp = pprint.PrettyPrinter()
# Optional, if want to save the model, checkpoints, logs etc. to your google drive
# drive.mount('/content/drive')
model_save_path = 'drive/MyDrive/models/multitask_model'

AttributeError: 'KeyedVectors' object has no attribute 'add'

In [140]:
def load_mokoron(positive_url: str = 'https://www.dropbox.com/s/fnpq3z4bcnoktiv/positive.csv', negative_url: str = 'https://www.dropbox.com/s/r6u59ljhhjdg6j0/negative.csv', part_test: float = 0.2, part_val: float = 0.5) -> datasets.DatasetDict:
    positive_path = positive_url.split('/')[-1]
    negative_path = negative_url.split('/')[-1]
    if not os.path.isfile(positive_path):
        print('Downloading [mokoron / positive]', end=' => ')
        os.system(f'wget {positive_url}')
    else:
        print(f'Reusing [mokoron / positive] @ {positive_path}', end=' => ')
    if not os.path.isfile(negative_path):
        print('Downloading [mokoron / negative]',  end=' =>\n')
        os.system(f'wget {negative_url}')
    else:
        print(f'Reusing [mokoron / negative] @ {negative_path}', end=' =>\n')
    print('Preparing [mokoron / positive]', end=' => ')
    pos_df = pd.read_csv(positive_path, sep = ';', header = None)
    pos_df['label'] = 1
    pos_df.rename(columns={3: 'text'}, inplace=True)
    pos_df = pos_df[['text', 'label']]
    print('Preparing [mokoron / negative]', end=' =>\n')
    neg_df = pd.read_csv(negative_path, sep = ';', header = None)
    neg_df['label'] = 0
    neg_df.rename(columns={3: 'text'}, inplace=True)
    neg_df = neg_df[['text', 'label']]
    print('Merging, shuffling & splitting [mokoron / positive <-> mokoron / negative]', end=' => ')
    dataset = datasets.Dataset.from_pandas(pd.concat([pos_df, neg_df], ignore_index=True))
    train_testvalid = dataset.train_test_split(test_size=part_test)
    test_valid = train_testvalid['test'].train_test_split(test_size=part_val)
    print('Done')
    return datasets.DatasetDict({
        'train': train_testvalid['train'],
        'test': test_valid['test'],
        'validation': test_valid['train']
    })

In [136]:
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)
def sample_word_synonyms(word: str, min_num: int = 1, max_num: int = 3, 
                         variation: int = 10, repeats: bool = True, 
                         positive_bonus: float = 0.5, return_str: bool = True, 
                         include_source: bool = True,
                         include_source_normal_form: bool=True) -> Union[str, Iterable[str]]:
    parsed = morph.parse(word)[0]
    if parsed.normal_form != '' and parsed.normal_form is not None:
        nf = parsed.normal_form
    else:
        nf = word
    try:
        syns_probs = glove.most_similar(nf, topn=variation)
        if min_num == max_num:
            n = max_num
        else:
            n = np.random.choice(np.arange(min_num, max_num))
        syn_variants = [sp[0] for sp in syns_probs]
        scores = [sp[1] for sp in syns_probs]
        m = np.mean(scores)
        if include_source:
            syn_variants.append(word)
            scores.append(m)
        if include_source_normal_form:
            syn_variants.append(nf)
            scores.append(m)
        scores_shifted = scores - m
        scores_shifted[scores_shifted > 0] += positive_bonus
        probs = softmax(scores_shifted)
        syns = np.random.choice(syn_variants, n, p=probs, replace=repeats)
        # Add inflection
        if isinstance(syns, str):
            syns = [syns, ]
        return ' '.join(syns) if return_str else syns
    except KeyError:
        syn = np.random.choice([word, nf])
        return syn if return_str else [syn, ]

In [137]:
def aug_syn(text: str, include_source: bool = True, num_samples: int = 10, 
            num_mods: Optional[int] = None, percent_num_mods: float = 0.1) -> Iterable[str]:
    doc = Doc(text)
    doc.segment(segmenter)
    n = len(doc.tokens)
    if include_source:
        yield text
    if num_mods is None:
        if 0 < percent_num_mods <= 1:
            num_mods = int(n * percent_num_mods)
        else:
            raise ValueError(f'aug_syn :: Expected 0 < percent_num_mods: float <= 1, got percent_num_mods: {type(percent_num_mods)} = {percent_num_mods}')
    if num_mods > n:
        raise ValueError("Can't replace to synonyms more words than in the original text")
    for _ in range(num_samples):
        repls = set(np.random.choice(np.arange(0, n-1), num_mods, replace=False))
        tokens = []
        for i, token in enumerate(doc.tokens):
            if i in repls:
                tokens.append(sample_word_synonyms(token.text, repeats=np.random.choice([False, True])))
            else:
                tokens.append(token.text)
        yield ' '.join(tokens)

In [160]:
def mmap(datadict: Union[datasets.DatasetDict, Dict[str, datasets.Dataset]], func, 
         verbose: bool = True, dataset_name: Optional[str] = None,
         parallel: bool = True) -> datasets.DatasetDict:
    data = dict(datadict) if isinstance(datadict, datasets.DatasetDict) else datadict
    result = dict()
    parts = data.keys()
    if dataset_name is None:
        dataset_name = 'unnamed_dataset'
    partsv = '[' + ', '.join([f'{dataset_name}[{p}]' for p in parts]) + ']'
    if verbose:
        print(f'map({func.__name__}, {partsv})' + '{')
    for i, (part, dataset) in enumerate(data.items()):
        if verbose:
            print(f'    map({func.__name__}, {dataset_name}[{part}]):')
        cols = dict()
        i = 0
        if parallel:
            with Pool(processes=cpu_count()) as pool:
                samples = list(tqdm(pool.imap(func, dataset), total=len(dataset)))
            for sample in samples:
                if isinstance(sample, dict):
                    for col, value in sample.items():
                        if col not in cols:
                            cols[col] = [None, ] * i
                        else:
                            cols[col].append(value)
                    i += 1
                elif isinstance(sample, Iterable):
                    for subsample in sample:
                        for col, value in subsample.items():
                            if col not in cols:
                                cols[col] = [None, ] * i
                            else:
                                cols[col].append(value)
                        i += 1
        else:
            for sample in tqdm(dataset):
                new_sample = func(sample)
                if isinstance(new_sample, dict):
                    for col, value in new_sample.items():
                        if col not in cols:
                            cols[col] = [None, ] * i
                        else:
                            cols[col].append(value)
                    i += 1
                elif isinstance(new_sample, Iterable):
                    for subsample in new_sample:
                        for col, value in subsample.items():
                            if col not in cols:
                                cols[col] = [None, ] * i
                            else:
                                cols[col].append(value)
                        i += 1
        result[part] = datasets.Dataset.from_dict(cols)
    print('}')
    return datasets.DatasetDict(result)

def fmap(datadict, funcs) -> datasets.DatasetDict:
    f = lambda data: {col: (data[col] if f is None else f(data)) for col, f in funcs.items()}
    return mmap(datadict, f, parallel=False)

In [None]:
parus_questions = {
    'cause': {
        'По какой причине?', 
        'Почему это произошло?', 
        'Почему?', 
        'Чем это обосновано?',
        'Причина этому', 
        'Это произошло, потому что',
        'Вследствие чего это произошло?',
    },
    'effect': {
        'Что произошло в результате?',
        'В итоге получилось, что', 
        'Из-за этого',
        'Результатом этого всего стало',
        'Вследствие чего',
        'Повлияло на результат',
        'Следствие'
    }
}

def parus_preproc_aug(sample: Dict[str, Any]) -> Iterable[Dict[str, Any]]:
    for question in parus_questions[sample['question']]:
        yield {'label': sample['label'], 
               'input': tuple(f'{sample["premise"]} {question} {sample[f"choice{j}"]}' for j in (1, 2))}

def parus_preproc_aug_list(*args, **kwargs) -> List[Dict[str, Any]]:
    return list(parus_preproc_aug(*args, **kwargs))

def preproc_danetqa(sample) -> List[Dict[str, Any]]:
    return [{
        'input': f'{passage} {sample["question"]}',
        'label': sample['label']
    } for passage in aug_syn(sample['passage'])]

# def preprocess_sberquad(example):
#     example['start'] = example['answers']['answer_start'][0]
#     example['span'] = example['answers']['text'][0]
#     example['end'] = example['start'] + len(example['span'])
#     return example

data = {name: datasets.load_dataset("russian_super_glue", name) for name in {'lidirus', 'rcb', 'parus', 'muserc', 'terra', 'russe', 'rwsd', 'danetqa', 'rucos'}}
data.update({'xnli': datasets.load_dataset("xnli", 'ru'), 'sberquad': datasets.load_dataset("sberquad"), 'mokoron': load_mokoron()})
data['parus'] = mmap(data['parus'], parus_preproc_aug_list)
# data['danetqa'] = mmap(data['danetqa'], lambda x: [{'passage': v, 'question': x['question'], 'label': x['label']} for v in aug_syn(x['passage'])])
# data['danetqa'] = fmap(data['danetqa'], {'input': lambda x: f"{x['passage']} {x['question']}", 'label': None})
data['danetqa'] = mmap(data['danetqa'], preproc_danetqa)
# data['sberquad'] = data['sberquad'].map(preprocess_sberquad)

# TODO: print also validation / test sizes
for task in data:
  if 'train' in data[task]:
    print('Task id:', task)
    print('Train size:', data[task]['train'].num_rows)
    print('Train example:')
    pp.pprint(data[task]['train'][0])
    print()

# Tasks of interest in the data dictionary
# task_names = {'danetqa', 'xnli', 'parus', 'terra', 'sberquad', 'mokoron'}
task_names = {'danetqa', 'xnli', 'parus', 'terra', 'mokoron'}

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/rcb/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/lidirus/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/1 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/terra/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/muserc/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/rwsd/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/russe/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/rucos/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/danetqa/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset russian_super_glue (/root/.cache/huggingface/datasets/russian_super_glue/parus/0.0.1/6fcadbfc1d8f0298b2f01ff277093772efe9e1b98f3c0df8ab5f511b3b9e13c9)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset xnli (/root/.cache/huggingface/datasets/xnli/ru/1.1.0/243f155ecab4d4f6e82e4eeab62b8c6b1f7abfcb8ed7fcc1661be8e25b117404)


  0%|          | 0/3 [00:00<?, ?it/s]

Reusing dataset sberquad (/root/.cache/huggingface/datasets/sberquad/sberquad/1.0.0/62115d937acf2634cfacbfee10c13a7ee39df3ce345bb45af7088676f9811e77)


  0%|          | 0/3 [00:00<?, ?it/s]

Preparing [mokoron / positive] => Preparing [mokoron / negative] => Merging, shuffling & splitting [mokoron / positive <-> mokoron / negative] => Done
map(parus_preproc_aug_list, [unnamed_dataset[train], unnamed_dataset[validation], unnamed_dataset[test]]){
    map(parus_preproc_aug_list, unnamed_dataset[train]):


  0%|          | 0/400 [00:00<?, ?it/s]

    map(parus_preproc_aug_list, unnamed_dataset[validation]):


  0%|          | 0/100 [00:00<?, ?it/s]

    map(parus_preproc_aug_list, unnamed_dataset[test]):


  0%|          | 0/500 [00:00<?, ?it/s]

}
map(preproc_danetqa, [unnamed_dataset[train], unnamed_dataset[validation], unnamed_dataset[test]]){
    map(preproc_danetqa, unnamed_dataset[train]):


  0%|          | 0/1749 [00:00<?, ?it/s]

In [None]:
class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, task_models):
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.task_models = nn.ModuleDict(task_models)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict, verbose: bool = False):
        """
        Creating a MultitaskModel using the model class and config objects
        from single-task models. Creating each single-task model, and 
        having them share the same encoder transformer.
        """
        cls.verbose = verbose
        shared_encoder = None
        task_models = dict()
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name, 
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            task_models[task_name] = model
        return cls(encoder=shared_encoder, task_models=task_models)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute
        """
        if model.__class__.__name__.startswith("Bert"):
            return "bert"
        elif model.__class__.__name__.startswith("Roberta"):
            return "roberta"
        elif model.__class__.__name__.startswith("Albert"):
            return "albert"
        else:
            raise NotImplementedError(f"Add support for new model {model._class_.__name__}")

    def forward(self, task_name, **kwargs):
        if self.verbose:
            print(f'data -> encoder -> {task_name}')
        return self.task_models[task_name](**kwargs)

In [None]:
# model_name = "sberbank-ai/ruRoberta-large" # Won't work in Colab Pro with a reasonable input sequence len and batch size
model_name = 'DeepPavlov/rubert-base-cased'
multitask_model = MultitaskModel.create(
    model_name=model_name,
    verbose=False,
    model_type_dict={
        "danetqa": transformers.AutoModelForSequenceClassification,
        "xnli": transformers.AutoModelForSequenceClassification,
        "parus": transformers.AutoModelForSequenceClassification,
        'terra': transformers.AutoModelForSequenceClassification,
        # 'sberquad': transformers.AutoModelForQuestionAnswering,
        'mokoron': transformers.AutoModelForSequenceClassification,
    },
    model_config_dict={
        "danetqa": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "xnli": transformers.AutoConfig.from_pretrained(model_name, num_labels=3),
        "parus": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        'terra': transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        # 'sberquad': transformers.AutoConfig.from_pretrained(model_name),
        'mokoron': transformers.AutoConfig.from_pretrained(model_name, num_labels = 2)
    },
)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were n

In [None]:
max_len = 512

def convert_nli(batch):
    inputs = list(zip(batch['hypothesis'], batch['premise']))
    features = tokenizer.batch_encode_plus(inputs, max_length=max_len, pad_to_max_length=True)
    features["labels"] = batch["label"]
    return features

def convert_simple(batch):
    features = tokenizer.batch_encode_plus(batch['input'], max_length=max_len, pad_to_max_length=True)
    features["labels"] = batch["label"]
    return features

def convert_mokoron(batch):
    features = tokenizer.batch_encode_plus(batch['text'], max_length=max_len, pad_to_max_length=True)
    features["labels"] = batch["label"]
    return features

# def convert_sberquad(batch):
#     tokenized_examples = tokenizer.batch_encode_plus(
#         batch['question'], batch['context'], 
#         max_length=max_len, pad_to_max_length=True, 
#         return_offsets_mapping=True
#     )
#     # Since one example might give us several features if it has a long context, we need a map from a feature to
#     # its corresponding example. This key gives us just that.
#     sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
#     # The offset mappings will give us a map from token to character position in the original context. This will
#     # help us compute the start_positions and end_positions.
#     offset_mapping = tokenized_examples.pop("offset_mapping")

#     # Let's label those examples!
#     tokenized_examples["start_positions"] = []
#     tokenized_examples["end_positions"] = []

#     for i, offsets in enumerate(offset_mapping):
#         # We will label impossible answers with the index of the CLS token.
#         input_ids = tokenized_examples["input_ids"][i]
#         cls_index = input_ids.index(tokenizer.cls_token_id)

#         # Grab the sequence corresponding to that example (to know what is the context and what is the question).
#         sequence_ids = tokenized_examples.sequence_ids(i)

#         # One example can give several spans, this is the index of the example containing this span of text.
#         sample_index = sample_mapping[i]
#         answers = examples["answers"][sample_index]
#         # If no answers are given, set the cls_index as answer.
#         if len(answers["answer_start"]) == 0:
#             tokenized_examples["start_positions"].append(cls_index)
#             tokenized_examples["end_positions"].append(cls_index)
#         else:
#             # Start/end character index of the answer in the text.
#             start_char = answers["answer_start"][0]
#             end_char = start_char + len(answers["text"][0])

#             # Start token index of the current span in the text.
#             token_start_index = 0
#             while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
#                 token_start_index += 1

#             # End token index of the current span in the text.
#             token_end_index = len(input_ids) - 1
#             while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
#                 token_end_index -= 1

#             # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
#             if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
#                 tokenized_examples["start_positions"].append(cls_index)
#                 tokenized_examples["end_positions"].append(cls_index)
#             else:
#                 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
#                 # Note: we could go after the last offset if the answer is the last word (edge case).
#                 while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
#                     token_start_index += 1
#                 tokenized_examples["start_positions"].append(token_start_index - 1)
#                 while offsets[token_end_index][1] >= end_char:
#                     token_end_index -= 1
#                 tokenized_examples["end_positions"].append(token_end_index + 1)

#     return tokenized_examples

conversions = {
    'xnli': convert_nli,
    'terra': convert_nli,
    'parus': convert_simple,
    'danetqa': convert_simple,
    'mokoron': convert_mokoron,
    # 'sberquad': convert_sberquad
}

In [None]:
columns = ['input_ids', 'attention_mask', 'labels']
features = dict()
for task_name in task_names:
    print('>', task_name)
    features[task_name] = dict()
    for phase, phase_dataset in data[task_name].items():
        features[task_name][phase] = phase_dataset.map(
            conversions[task_name],
            batched=True,
            load_from_cache_file=False,
            num_proc = 2,
        )
        print(task_name, phase, len(phase_dataset), len(features[task_name][phase]))
        features[task_name][phase].set_format(type="torch", columns=columns)
        print(task_name, phase, len(phase_dataset), len(features[task_name][phase]))

> danetqa


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


danetqa train 1749 1749
danetqa train 1749 1749


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


danetqa validation 821 821
danetqa validation 821 821


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


danetqa test 805 805
danetqa test 805 805
> parus


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


parus train 400 400
parus train 400 400


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


parus validation 100 100
parus validation 100 100


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


parus test 500 500
parus test 500 500
> mokoron


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


mokoron train 181467 181467
mokoron train 181467 181467


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


mokoron test 22684 22684
mokoron test 22684 22684


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


mokoron validation 22683 22683
mokoron validation 22683 22683
> xnli


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


xnli train 392702 392702
xnli train 392702 392702


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


xnli test 5010 5010
xnli test 5010 5010


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


xnli validation 2490 2490
xnli validation 2490 2490
> terra


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


terra train 2616 2616
terra train 2616 2616


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


terra validation 307 307
terra validation 307 307


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


terra test 3198 3198
terra test 3198 3198


In [None]:
class NLPDataCollator(DefaultDataCollator):
    """
    Extending the existing DataCollator to work with NLP dataset batches
    """
    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if isinstance(first, dict):
            # NLP data sets current works presents features as lists of dictionary
            # (one per example), so we  will adapt the collate_batch logic for that
            if "labels" in first and first["labels"] is not None:
                if first["labels"].dtype == torch.int64:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
                else:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
                batch = {"labels": labels}
            for k, v in first.items():
                if k != "labels" and v is not None and not isinstance(v, str):
                    batch[k] = torch.stack([f[k] for f in features])
            return batch
        else:
            # otherwise, revert to using the default collate_batch
            return DefaultDataCollator().collate_batch(features)


class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error
    """
    def to(self, device):
        return self


class DataLoaderWithTaskName:
    """
    Wrapper around a DataLoader to also yield a task name
    """
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)
    
    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            len(dataloader.dataset) 
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some-other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])    

class MultitaskTrainer(transformers.Trainer):
    def get_single_train_dataloader(self, task_name: str, train_dataset) -> DataLoaderWithTaskName:
        """
        Create a single-task data loader that also yields task names
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (
            RandomSampler(train_dataset) 
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )
        return DataLoaderWithTaskName(
            task_name=task_name,
            data_loader=DataLoader(
              train_dataset,
              batch_size=self.args.train_batch_size,
              sampler=train_sampler,
              collate_fn=self.data_collator.collate_batch,
            ),
        )

    def get_single_eval_dataloader(self, task_name: str, eval_dataset) -> DataLoaderWithTaskName:
        """
        Create a single-task data loader for evaluation that also yields task names
        """
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires a eval_dataset.")
        eval_sampler = (
            RandomSampler(eval_dataset) 
            if self.args.local_rank == -1
            else DistributedSampler(eval_dataset)
        )
        return DataLoaderWithTaskName(
            task_name=task_name,
            data_loader=DataLoader(
              eval_dataset,
              batch_size=self.args.eval_batch_size,
              sampler=eval_sampler,
              collate_fn=self.data_collator.collate_batch,
            ),
        )


    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each 
        task Dataloader
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })

    def get_eval_dataloader(self, eval_dataset) -> MultitaskDataloader:
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each 
        task Dataloader
        """
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        mt_dataloader = MultitaskDataloader({
            task_name: self.get_single_eval_dataloader(task_name, task_dataset)
            for task_name, task_dataset in eval_dataset.items()
        })
        # TODO Fix this hell
        # Idk why but I always catch an error
        # AttributeError: 'MultitaskDataloader' object has no attribute 'batch_size'
        mt_dataloader.batch_size = 128
        return mt_dataloader

In [None]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features.items()
}
eval_dataset = {
    task_name: dataset["validation"] 
    for task_name, dataset in features.items()
}

In [None]:
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir=model_save_path,
        overwrite_output_dir=True,
        learning_rate=1e-5,
        # evaluation_strategy='steps',
        # logging_strategy='steps',
        logging_steps=100,
        eval_steps=1000,
        do_train=True,
        num_train_epochs=1,
        per_device_train_batch_size=15, 
        per_device_eval_batch_size=128, # TODO fix, this value isn't read where it's needed!
        save_steps=5000,
    ),
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

***** Running training *****
  Num examples = 578934
  Num Epochs = 1
  Instantaneous batch size per device = 15
  Total train batch size (w. parallel, distributed & accumulation) = 15
  Gradient Accumulation steps = 1
  Total optimization steps = 38598


Step,Training Loss,Validation Loss
1000,0.6386,0.132026
2000,0.5642,0.134654


***** Running Evaluation *****
  Num examples = 26401
  Batch size = 256
***** Running Evaluation *****
  Num examples = 26401
  Batch size = 256
***** Running Evaluation *****
  Num examples = 26401
  Batch size = 256


RuntimeError: ignored

In [None]:
# TODO: Evaluation on test dataset