In [None]:
# Hugging Face Hub model id used for the tokenizer (and, later, the model).
checkpoint = "bert-base-uncased"
# GLUE task to fine-tune on; must be a key of `task_to_keys`:
# "cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli".
task = "mnli-mm"
# Number of training epochs.
num_epochs = 5

In [None]:
# "mnli-mm" is not its own GLUE config: it loads the "mnli" data and metric.
# Every other task name is passed through unchanged.
actual_task = {"mnli-mm": "mnli"}.get(task, task)
dataset = load_dataset("glue", actual_task)
metric = load_metric('glue', actual_task)
# Fast (Rust-backed) tokenizer matching the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)

In [None]:
# GLUE task name -> (first text column, second text column).
# Single-sentence tasks use None for the second column.
_single = ("sentence", None)
_nli = ("premise", "hypothesis")
_pair = ("sentence1", "sentence2")
task_to_keys = {
    "cola": _single,
    "mnli": _nli,
    "mnli-mm": _nli,
    "mrpc": _pair,
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": _pair,
    "sst2": _single,
    "stsb": _pair,
    "wnli": _pair,
}
# Drop the helper names so the notebook namespace matches the plain literal.
del _single, _nli, _pair
# Text column names for the selected task (the second is None for
# single-sentence tasks). A KeyError here means `task` is not a known name.
selected_keys = task_to_keys[task]
sentence1_key, sentence2_key = selected_keys

In [None]:
def tokenize_function(example):
    """Tokenize the task's text column(s) with truncation enabled.

    ``example`` is the dict passed by ``Dataset.map(..., batched=True)``.
    The columns read are the globals ``sentence1_key`` and, when it is not
    None, ``sentence2_key``; the two texts are passed to the tokenizer as a pair.
    """
    texts = [example[sentence1_key]]
    if sentence2_key is not None:
        texts.append(example[sentence2_key])
    return tokenizer(*texts, truncation=True)

In [None]:
# Output size: 3 classes for the MNLI variants, 1 for STS-B (a regression
# target), 2 for every other GLUE task.
num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Pads each batch to its longest sequence at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Drop the raw text columns and the example index so only tokenizer outputs
# and the label remain. The text column names depend on the task (see
# task_to_keys), so they cannot be hard-coded as "sentence1"/"sentence2":
# that raises for MNLI ("premise"/"hypothesis"), CoLA/SST-2 ("sentence"), etc.
columns_to_remove = [
    key for key in (sentence1_key, sentence2_key, "idx") if key is not None
]
tokenized_datasets = tokenized_datasets.remove_columns(columns_to_remove)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names

In [None]:
# GLUE's MNLI has no plain "validation" split. It ships "validation_matched"
# and "validation_mismatched", and "mnli-mm" evaluates on the mismatched one.
if task == "mnli-mm":
    validation_key = "validation_mismatched"
elif task == "mnli":
    validation_key = "validation_matched"
else:
    validation_key = "validation"

batch_size = 128
train_dataloader = DataLoader(tokenized_datasets["train"],
                              shuffle=True, batch_size=batch_size, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_datasets[validation_key],
                             batch_size=batch_size, collate_fn=data_collator)

# Sanity check: take one training batch and display each tensor's shape.
batch = next(iter(train_dataloader))
{k: v.shape for k, v in batch.items()}

### Code borrowed and adapted from

1.   [NLP lecture code](https://github.com/amir-jafari/NLP/blob/master/Lecture_09/Lecture%20Code/12-training.py)
2.   [Hugging Face transformers notebooks](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification.ipynb#scrollTo=7k8ge1L1IrJk)
3.   [Ashwin Geet D'Sa's answer on Stack Overflow](https://stackoverflow.com/questions/65205582/how-can-i-add-a-bi-lstm-layer-on-top-of-bert-model?rq=1)

