In [2]:
import transformers
import numpy as np
import evaluate, torch, os
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import (DataCollatorWithPadding, 
                          DataCollatorForTokenClassification,
                          AutoTokenizer,
                          default_data_collator,
                          AutoConfig,
                          AutoModelForSequenceClassification,
                          AutoModelForTokenClassification,
                          AutoModelForQuestionAnswering,
                          PreTrainedTokenizerFast,
                          AdamW,
                          get_scheduler)

from datasets import load_dataset, DatasetDict, load_metric, concatenate_datasets
from torch import nn
from tqdm.auto import tqdm

### Simulataneous NER and NLI using TaskSampler

In [3]:
class TaskSampler():
    """ 
    Class for sampling batches from a dictionary of dataloaders according to a weighted sampling scheme.

    Dynamic task weights can be externally computed and set using the set_task_weights method,
    or, this class can be extended with methods and state state to implement a more complex sampling scheme.

    You probably/shouldn't need to use this with multiple GPUs, but if you do, you'll may need
    to extend/debug it yourself since the current implementation is not distributed-aware.
    
    Args:
        dataloader_dict (dict[str, DataLoader]): Dictionary of dataloaders to sample from.
        task_weights (list[float], optional): List of weights for each task. If None, uniform weights are used. Defaults to None.
        max_iters (int, optional): Maximum number of iterations. If None, infinite. Defaults to None.
    """
    def __init__(self, 
                *,
                dataloader_dict: dict[str, DataLoader],
                task_weights=None,
                max_iters=None):
        
        assert dataloader_dict is not None, "Dataloader dictionary must be provided."

        self.dataloader_dict = dataloader_dict
        self.task_names = list(dataloader_dict.keys())
        self.dataloader_iterators = self._initialize_iterators()
        self.task_weights = task_weights if task_weights is not None else self._get_uniform_weights()
        self.max_iters = max_iters if max_iters is not None else float("inf")
    
    # Initialization methods
    def _get_uniform_weights(self):
        return [1/len(self.task_names) for _ in self.task_names]
    
    def _initialize_iterators(self):
        return {name:iter(dataloader) for name, dataloader in self.dataloader_dict.items()}
    
    # Weight getter and setter methods (NOTE can use these to dynamically set weights)
    def set_task_weights(self, task_weights):
        assert sum(self.task_weights) == 1, "Task weights must sum to 1."
        self.task_weights = task_weights
    
    def get_task_weights(self):
        return self.task_weights

    # Sampling logic
    def _sample_task(self):
        return np.random.choice(self.task_names, p=self.task_weights)
    
    def _sample_batch(self, task):
        try:
            return self.dataloader_iterators[task].__next__()
        except StopIteration:
            print(f"Restarting iterator for {task}")
            self.dataloader_iterators[task] = iter(self.dataloader_dict[task])
            return self.dataloader_iterators[task].__next__()
        except KeyError as e:
            print(e)
            raise KeyError("Task not in dataset dictionary.")
    
    # Iterable interface
    def __iter__(self):
        self.current_iter = 0
        return self
    
    def __next__(self):
        if self.current_iter >= self.max_iters:
            raise StopIteration
        else:
            self.current_iter += 1
        task = self._sample_task()
        batch = self._sample_batch(task)
        return task, batch

In [4]:
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

In [5]:
ner_datasets = load_dataset("Babelscape/wikineural")
nli_datasets = load_dataset("multi_nli")
squad_datasets = load_dataset("squad_v2")

Found cached dataset parquet (/home/jhdavis/.cache/huggingface/datasets/Babelscape___parquet/Babelscape--wikineural-579d1dc98d2a6b93/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/27 [00:00<?, ?it/s]

Found cached dataset multi_nli (/home/jhdavis/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset squad_v2 (/home/jhdavis/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
ner_label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
ner_labels_vocab = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
ner_labels_vocab_reverse = {v:k for k,v in ner_labels_vocab.items()}
#labels_vocab_reverse

In [7]:
ner_train_dataset = concatenate_datasets([ner_datasets["train_en"]])
ner_val_dataset = concatenate_datasets([ner_datasets["val_en"]]) 
ner_test_dataset = concatenate_datasets([ner_datasets["test_en"]])

nli_train_dataset = nli_datasets["train"]
nli_val_dataset = nli_datasets["validation_matched"]

squad_train_dataset = squad_datasets["train"].select(range(round(0.1*len(squad_datasets["train"]))))
squad_val_dataset = squad_datasets["validation"]

### Preprocessing the data

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
label_all_tokens = False

In [9]:
max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
pad_on_right = tokenizer.padding_side == "right"

def prepare_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [10]:
train_features = prepare_features(squad_train_dataset[:5])
squad_train_tokenized = squad_train_dataset.map(prepare_features, batched=True, num_proc=4, remove_columns=squad_train_dataset.column_names)

Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-4e230f77dbb0bd71_*_of_00004.arrow


In [11]:
# This is also only for Squad...
def prepare_validation_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [12]:
squad_val_features = squad_val_dataset.map(
    prepare_validation_features,
    batched=True, num_proc=4,
    remove_columns=squad_val_dataset.column_names
)

Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-f6dc69689d034286_*_of_00004.arrow


In [13]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [14]:
ner_train_tokenized = ner_train_dataset.map(tokenize_and_align_labels, batched=True, num_proc=4, remove_columns=ner_train_dataset.column_names)
ner_val_tokenized = ner_val_dataset.map(tokenize_and_align_labels, batched=True, num_proc=4, remove_columns=ner_val_dataset.column_names)
ner_test_tokenized = ner_test_dataset.map(tokenize_and_align_labels, batched=True, num_proc=4, remove_columns=ner_test_dataset.column_names)

Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/Babelscape___parquet/Babelscape--wikineural-579d1dc98d2a6b93/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-f82d460afdcc6927_*_of_00004.arrow
Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/Babelscape___parquet/Babelscape--wikineural-579d1dc98d2a6b93/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-c8bf1c020e3caab9_*_of_00004.arrow
Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/Babelscape___parquet/Babelscape--wikineural-579d1dc98d2a6b93/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-1b57291c63d37a7e_*_of_00004.arrow


In [15]:
sentence1_key = "premise"
sentence2_key = "hypothesis"
padding = "max_length"
max_seq_length = 128

def preprocess_function(examples):
    # Tokenize the texts
    texts = (
        (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    )
    result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)

    if "label" in examples:
        result["labels"] = examples["label"]
    return result

In [16]:
nli_train_tokenized = nli_train_dataset.map(preprocess_function, batched=True, num_proc=4, remove_columns=nli_train_dataset.column_names)
nli_val_tokenized = nli_val_dataset.map(preprocess_function, batched=True, num_proc=4, remove_columns=nli_val_dataset.column_names)

Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-0a27a65efa67b7c4_*_of_00004.arrow
Loading cached processed dataset at /home/jhdavis/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-8076d56d45c61aae_*_of_00004.arrow


In [17]:
ner_model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, 
                                                            num_labels=len(ner_label_list), 
                                                            label2id=ner_labels_vocab, 
                                                            id2label=ner_labels_vocab_reverse)
nli_config = AutoConfig.from_pretrained(model_checkpoint, num_labels=3, finetuning_task="mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=nli_config)
squad_model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

model_dict = {"nli": nli_model, "ner": ner_model, "squad": squad_model}

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN t

In [18]:
accelerator = Accelerator()
device = accelerator.device

for model in model_dict.values():
    model.distilbert = nli_model.distilbert
    model.to(device)

In [19]:
nli_dataloader = DataLoader(nli_train_tokenized, collate_fn=default_data_collator, batch_size=16)
ner_dataloader = DataLoader(ner_train_tokenized, collate_fn=DataCollatorForTokenClassification(tokenizer=tokenizer), batch_size=16)
squad_dataloader = DataLoader(squad_train_tokenized, collate_fn=default_data_collator, batch_size=8)
dataloader_dict = {"nli": nli_dataloader, "ner": ner_dataloader, "squad": squad_dataloader}

nli_val_dataloader = DataLoader(nli_val_tokenized, collate_fn=default_data_collator, batch_size=16)
ner_val_dataloader = DataLoader(ner_val_tokenized, collate_fn=DataCollatorForTokenClassification(tokenizer=tokenizer), batch_size=16)
val_dataloader_dict = {"nli": nli_val_dataloader, "ner": ner_val_dataloader}

In [20]:
optimizers = {
    "nli": AdamW(nli_model.parameters(), lr=2e-5),
    "ner": AdamW(ner_model.parameters(), lr=2e-5),
    "squad": AdamW(squad_model.parameters(), lr=2e-5)
}



In [21]:
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizers["ner"],
    num_warmup_steps=0,
    num_training_steps=min(len(nli_dataloader), len(ner_dataloader)) + len(squad_dataloader)
)

In [22]:
task_sampler = TaskSampler(dataloader_dict=dataloader_dict, task_weights=[0.5, 0.5, 0.0])

In [23]:
(   nli_model, ner_model, squad_model, 
    optimizers["nli"], optimizers["ner"], optimizers["squad"], 
    nli_dataloader, ner_dataloader, squad_dataloader,
    nli_val_dataloader, ner_val_dataloader,
    lr_scheduler, task_sampler
) = accelerator.prepare(
    nli_model, ner_model, squad_model,
    optimizers["nli"], optimizers["ner"], optimizers["squad"],
    nli_dataloader, ner_dataloader, squad_dataloader,
    nli_val_dataloader, ner_val_dataloader,
    lr_scheduler, task_sampler
)

In [24]:
ner_metric = evaluate.load("seqeval")
nli_metric = evaluate.load("accuracy")
squad_metric = evaluate.load("squad_v2")

# This function is for NER only
def ner_compute_metrics():
    results = ner_metric.compute()
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

metric_dict = {"ner": ner_metric, "nli": nli_metric, "squad": squad_metric}

In [25]:
# This is for NER only
def get_labels(predictions, references):
    y_pred = predictions.detach().cpu().clone().numpy()
    y_true = references.detach().cpu().clone().numpy()
    true_predictions = [
        [ner_label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    true_labels = [
        [ner_label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    return true_predictions, true_labels

In [26]:
# This is for Squad only
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

In [28]:
num_epochs = 1
num_updates_per_epoch = min(len(nli_dataloader), len(ner_dataloader)) + len(squad_dataloader)
completed_steps = 0
completed_epochs = 0
progress_bar = tqdm(range(num_updates_per_epoch*num_epochs))

resume_from_checkpoint = ""
resume_step = None
starting_epoch = 0
checkpointing_steps = 2480

if resume_from_checkpoint != False:
    if resume_from_checkpoint == "":
        # Get the most recent checkpoint
        dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
        dirs.sort(key=os.path.getctime)
        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
        accelerator.print(f"Resumed from checkpoint: {path}")
        accelerator.load_state(path)
    else:
        accelerator.print(f"Resumed from checkpoint: {resume_from_checkpoint}")
        accelerator.load_state(resume_from_checkpoint)
        path = os.path.basename(resume_from_checkpoint)
    
    # Extract `epoch_{i}` or `step_{i}`
    training_difference = os.path.splitext(path)[0]

    if "epoch" in training_difference:
        starting_epoch = int(training_difference.replace("epoch_", "")) + 1
        resume_step = None
    else:
        resume_step = int(training_difference.replace("step_", ""))
        starting_epoch = resume_step // num_updates_per_epoch*num_epochs
        resume_step -= starting_epoch * num_updates_per_epoch

for epoch in range(starting_epoch, num_epochs):
    for model in model_dict.values(): model.train()
    # TRAINING LOOP
    for step, (task, batch) in enumerate(task_sampler):
        # skip to resume point if applicable
        if resume_from_checkpoint != False and epoch == starting_epoch and resume_step is not None and step < resume_step:
            completed_steps += 1
            if step % 10 == 0 and step > 0: progress_bar.update(10)
            if step == resume_step: progress_bar.update(step % 10)
            continue
        #Set this to do dynamic weights
        #task_sampler.set_task_weights([0.0, 0.5])
        if step >= min(len(nli_dataloader), len(ner_dataloader)):
            # Switch to only training the MRC head
            task_sampler.set_task_weights([0.0, 0.0, 1.0])
        optimizers[task].zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model_dict[task](**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizers[task].step()
        lr_scheduler.step()
        progress_bar.update(1)
        completed_steps += 1
        
        if isinstance(checkpointing_steps, int):
            if completed_steps % checkpointing_steps == 0:
                output_dir = f"step_{completed_steps}"
                accelerator.save_state(output_dir)
        
        if completed_steps >= num_updates_per_epoch: break
    
    for task in ("ner", "nli", "squad"):
        progress_bar_eval = tqdm(range(len(val_dataloader_dict[task])))
        model_dict[task].eval()
        samples_seen = 0
        all_start_logits = []
        all_end_logits = []
        # EVAL LOOP
        for step, batch in enumerate(val_dataloader_dict[task]):
            batch = {k: v.to(device) for k, v in batch.items()}
            if task == "squad":
                with torch.no_grad():
                    outputs = model(**batch)
                    start_logits = outputs.start_logits
                    end_logits = outputs.end_logits
                    if not args.pad_to_max_length:  # necessary to pad predictions and labels for being gathered
                        start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
                        end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
                    all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
                    all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
            else:
                with torch.no_grad():
                    outputs = model_dict[task](**batch)
                preds = outputs.logits.argmax(dim=-1)
                labels = batch["labels"]
                preds, labels = accelerator.gather((preds, labels))
                # If we are in a multiprocess environment, the last batch has duplicates
                if accelerator.num_processes > 1:
                    if step == len(val_dataloader_dict[task]) - 1:
                        preds = preds[: len(val_dataloader_dict[task].dataset) - samples_seen]
                        labels = labels[: len(val_dataloader_dict[task].dataset) - samples_seen]
                    else:
                        samples_seen += labels_gathered.shape[0]
                if task == "ner": 
                    preds, labels = get_labels(preds, labels)
                if len(preds) != len(labels):
                    print("Skipping a bad test")
                    continue
                metric_dict[task].add_batch(
                    predictions=preds,
                    references=labels
                )
                progress_bar_eval.update(1)
        if task == "squad":
            max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
            
            # concatenate the numpy array
            start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len)
            end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len)

            # delete the list of numpy arrays
            del all_start_logits
            del all_end_logits

            outputs_numpy = (start_logits_concat, end_logits_concat)
            prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
            predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
            print(f"Predict metrics: {predict_metric}")
        else:
            eval_metric = ner_compute_metrics() if task == "ner" else metric_dict[task].compute()
            print(f"{task} epoch {epoch}:", eval_metric)

    if checkpointing_steps == "epoch":
        output_dir = f"epoch_{epoch}"
        accelerator.save_state(output_dir)

  0%|          | 0/7444 [00:00<?, ?it/s]

Resumed from checkpoint: step_5790
Restarting iterator for squad
Restarting iterator for squad
Restarting iterator for squad
Restarting iterator for squad
Restarting iterator for squad


  0%|          | 0/725 [00:00<?, ?it/s]

ner epoch 0: {'precision': 0.5589257317863243, 'recall': 0.48295165394402034, 'f1': 0.5181686641731962, 'accuracy': 0.937776369262891}


  0%|          | 0/614 [00:00<?, ?it/s]

nli epoch 0: {'accuracy': 0.5762608252674478}


KeyError: 'squad'