# Commonsense QA: Baselines

EECS 595 Final Project, Task 1: Commonsense QA

Credit: Ziqiao Ma

Last update: 2020.12.1

# Setup

## Colab setups

Run this cell to load the autoreload extension.

In [1]:
%load_ext autoreload
%autoreload 2

Run the following cell to mount your Google Drive.

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Fill in the Google Drive path where you uploaded the file.

In [3]:
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'Colab Notebooks/eecs595/commonsense_qa'

Check that the files can be found.

In [4]:
import os
import sys

GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
sys.path.append(GOOGLE_DRIVE_PATH)
print(os.listdir(GOOGLE_DRIVE_PATH))

['roberta.ipynb', 'csqa-graph-reasoning.ipynb', 'csqa-baseline.ipynb']


## Dependency installation

In [5]:
import json
import codecs
import argparse
from copy import deepcopy
from tqdm import tqdm, trange

import random
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

Install `datasets`

In [6]:
!pip install datasets
from datasets import load_dataset


Install `sentencepiece` for `XLNetTokenizer`

In [7]:
!pip install sentencepiece
import sentencepiece


Install `transformers`

In [8]:
!pip install transformers
# !pip install transformers==2.0.0

from transformers import (AdamW, get_linear_schedule_with_warmup, AutoModelForQuestionAnswering,
                          BertConfig, BertForMultipleChoice, BertTokenizer,
                          XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer,
                          RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer,
                          GPT2Config, GPT2ForSequenceClassification, GPT2Tokenizer)


## Helper Functions

In [9]:
SEED = 0

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_model(model='all'):
    if model == 'bert':
        return BertConfig, BertForMultipleChoice, BertTokenizer
    elif model == 'xlnet':
        return XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer
    elif model == 'roberta':
        return RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer
    elif model == 'gpt2':
        # return GPT2Config, AutoModelForQuestionAnswering, GPT2Tokenizer
        raise NotImplementedError('gpt2 is not supported yet')
    raise ValueError('Unknown model type: {}'.format(model))


def load_optimizer(args, model, train_size):
    # train_size is the number of batches per epoch, so the schedule spans all epochs
    num_training_steps = train_size * args.num_train_epochs
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_training_steps)

    return model, optimizer, scheduler
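As a sanity check on `load_optimizer`, the sketch below re-implements in plain Python the per-step learning-rate multiplier that, to our understanding, `get_linear_schedule_with_warmup` applies (linear warmup from 0 to 1 over `num_warmup_steps`, then linear decay to 0 at `num_training_steps`), together with the name-based weight-decay split. The helpers `linear_warmup_factor` and `split_decay_groups`, the parameter names, and `epochs = 3` are illustrative assumptions. The sketch shows why `num_training_steps` should equal batches per epoch times epochs: with a smaller value the multiplier reaches zero before training ends.

```python
# Plain-Python sketch (assumed to mirror transformers' linear warmup/decay schedule)
# plus the name-based weight-decay split used in load_optimizer.

def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 to 1.
        return step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 to 0, reaching 0 at num_training_steps.
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


def split_decay_groups(param_names, no_decay=('bias', 'LayerNorm.weight')):
    """Names containing any no_decay substring get weight_decay=0.0."""
    decay = [n for n in param_names if not any(nd in n for nd in no_decay)]
    no_decay_names = [n for n in param_names if any(nd in n for nd in no_decay)]
    return decay, no_decay_names


# 609 batches of 16 cover the 9741 training questions; 3 epochs is illustrative.
batches_per_epoch, epochs, warmup = 609, 3, 10
total_steps = batches_per_epoch * epochs
for step in (0, 5, 10, batches_per_epoch, 2 * batches_per_epoch, total_steps):
    print(step, round(linear_warmup_factor(step, warmup, total_steps), 3))

# Hypothetical parameter names in the style of named_parameters().
names = ['bert.encoder.layer.0.attention.self.query.weight',
         'bert.encoder.layer.0.attention.self.query.bias',
         'bert.embeddings.LayerNorm.weight',
         'classifier.weight']
print(split_decay_groups(names))
```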

# Benchmark

## Dataset

As a question answering benchmark, [Commonsense QA](https://arxiv.org/abs/1811.00937) presents a natural language question $Q$ of $m$ tokens $\{q_1,q_2,\cdots,q_m\}$ and five answer choices $\{a_1,a_2,\cdots,a_5\}$, labeled $\{A,B,\cdots,E\}$ [1]. The questions do not contain the evidence needed to infer the answer, so a model must draw on broad commonsense knowledge and reason over it to pick the right choice.
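Concretely, each record pairs one question $Q$ with five candidates $a_1,\dots,a_5$, and the answer-key letter becomes the class index the model is trained to predict. A minimal sketch using the example from the preview output in this notebook; `to_candidates` is a hypothetical helper, not part of the pipeline.

```python
# One CommonsenseQA record with the fields listed in the dataset statistics
# (values copied from the preview output in this notebook).
example = {
    'question': 'What home entertainment equipment requires cable?',
    'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
                'text': ['radio shack', 'substation', 'cabinet', 'television', 'desk']},
    'answerKey': 'D',
}

LABELS = ['A', 'B', 'C', 'D', 'E']


def to_candidates(record):
    """Hypothetical helper: return Q, the five candidates a_1..a_5, and the gold index."""
    answers = record['choices']['text']
    label = LABELS.index(record['answerKey'])  # letter -> class index
    return record['question'], answers, label


question, answers, label = to_candidates(example)
print(question)
for letter, text in zip(LABELS, answers):
    print(letter, text)
print('gold index:', label, '->', answers[label])
```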

In [10]:
def load_data(dataset='commonsense_qa', preview=-1):

    assert dataset in {'commonsense_qa', 'conv_entail', 'eat'}

    if dataset == 'commonsense_qa':
        ds = load_dataset('commonsense_qa')

        if preview > 0:
            print('\nLoading an example...')
            # Index the train split directly; each row is a dict of fields
            example = ds['train'][preview]
            print(example['question'])
            for label, text in zip(example['choices']['label'], example['choices']['text']):
                print(label, text)
            print('Ans:', example['answerKey'])

    elif dataset == 'conv_entail':
        dev_file = '/content/drive/Shareddrives/EECS595-Fall2020/Final_Project_Common/Conversational_Entailment/dev_set.json'
        act_file = '/content/drive/Shareddrives/EECS595-Fall2020/Final_Project_Common/Conversational_Entailment/act_tag.json'
        with codecs.open(dev_file, 'r', encoding='utf-8') as f:
            dev_set = json.load(f)
        with codecs.open(act_file, 'r', encoding='utf-8') as f:
            act_tag = json.load(f)
        ds = dev_set, act_tag

        if preview > 0:
            print('Preview not yet implemented for this dataset.')

    else:
        file_name = '/content/drive/Shareddrives/EECS595-Fall2020/Final_Project_Common/EAT/eat_train.json'
        with codecs.open(file_name, 'r', encoding='utf-8') as f:
            ds = json.load(f)

        if preview > 0:
            print('\nLoading an example...')
            story = ds[preview]['story']
            label = ds[preview]['label']
            bp = ds[preview]['breakpoint']
            for line in story:
                print(line)
            print(label)
            print(bp)

    return ds

Run the following code to preview the dataset:

In [11]:
ds = load_data(dataset='commonsense_qa', preview=5)
print('\nDataset statistics:')
print(ds)


Using custom data configuration default



Downloading and preparing dataset commonsense_qa/default (download: 4.46 MiB, generated: 2.08 MiB, post-processed: Unknown size, total: 6.54 MiB) to /root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29...



Dataset commonsense_qa downloaded and prepared to /root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29. Subsequent calls will reuse this data.

Loading an example...
What home entertainment equipment requires cable?
A radio shack
B substation
C cabinet
D television
E desk
Ans: D

Dataset statistics:
DatasetDict({
    train: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 9741
    })
    validation: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 1221
    })
    test: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 1140
    })
})


## Data Preprocessing

In [12]:
class InputExample(object):
    """
    A single multiple choice question.
    """

    def __init__(self, example_id, question, answers, label):
        self.example_id = example_id
        self.question = question
        self.answers = answers
        self.label = label


class InputFeatures(object):
    """
    A single feature converted from an example.
    """

    def __init__(self, example_id, choices_features, label):
        self.example_id = example_id
        self.label = label
        self.choices_features = [
            {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids}
            for _, input_ids, input_mask, segment_ids in choices_features
        ]


class CommonsenseQAProcessor:
    """
    A Commonsense QA Data Processor
    """

    def __init__(self):
        self.dataset = None
        self.labels = [0, 1, 2, 3, 4]
        self.LABELS = ['A', 'B', 'C', 'D', 'E']

    def get_split(self, split='train'):
        if self.dataset is None:
            self.dataset = load_data(dataset='commonsense_qa', preview=-1)
        return self.dataset[split]

    def create_examples(self, split='train'):
        examples = []
        data_tr = self.get_split(split)
        example_id = 0

        for question, choices, answerKey in zip(data_tr['question'], data_tr['choices'], data_tr['answerKey']):
            answers = np.array(choices['text'])
            label = self.LABELS.index(answerKey)
            examples.append(InputExample(
                example_id=example_id, question=question,
                answers=answers, label=label
            ))
            example_id += 1

        return examples


def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """
    Truncates a sequence pair in place to the maximum length.

    This is a simple heuristic which will always truncate the longer sequence one token at a time.
    This makes more sense than truncating an equal percent of tokens from each,
    since if one sequence is very short then each token that's truncated
    likely contains more information than a longer sequence.

    However, since tokens from the question and the options should ideally not be removed,
    consider using a larger max length or truncating only the context.
    """

    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            warning = 'Attention! you are removing from token_b (swag task is ok). ' \
                      'If you are training ARC and RACE (you are popping question + options), ' \
                      'you need to try to use a bigger max seq length!'
            print(warning)
            tokens_b.pop()


def examples_to_features(examples, label_list, max_seq_length, tokenizer,
                         cls_token_at_end=False,
                         cls_token='[CLS]',
                         cls_token_segment_id=1,
                         sep_token='[SEP]',
                         sequence_a_segment_id=0,
                         sequence_b_segment_id=1,
                         sep_token_extra=False,
                         pad_token_segment_id=0,
                         pad_on_left=False,
                         pad_token=0,
                         mask_padding_with_zero=True):
    """
    Convert Commonsense QA examples to features.

    The convention in BERT is:
    (a) For sequence pairs:
    tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1

    (b) For single sequences:
    tokens:   [CLS] the dog is hairy . [SEP]
    type_ids:   0   0   0   0  0     0   0

    Where "type_ids" are used to indicate whether this is the first sequence or the second sequence.
    The embedding vectors for `type=0` and `type=1` were learned during pre-training
    and are added to the word piece embedding vector (and position vector).
    This is not *strictly* necessary since the [SEP] token unambiguously separates the sequences,
    but it makes it easier for the model to learn the concept of sequences.

    For classification tasks, the first vector (corresponding to [CLS]) is used as the "sentence vector".
    Note that this only makes sense because the entire model is fine-tuned.
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in tqdm(enumerate(examples), desc="Converting examples to features", disable=True):

        choices_features = []
        # Build one (question, answer) sequence per candidate answer
        for ending_idx, answer in enumerate(example.answers):

            tokens_a = tokenizer.tokenize(example.question)
            if example.question.find("_") != -1:
                # Fill-in-the-blank question: substitute the answer into the blank
                tokens_b = tokenizer.tokenize(example.question.replace("_", answer))
            else:
                tokens_b = tokenizer.tokenize(answer)

            special_tokens_count = 4 if sep_token_extra else 3
            truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)

            tokens = tokens_a + [sep_token]
            if sep_token_extra:
                tokens += [sep_token]

            segment_ids = [sequence_a_segment_id] * len(tokens)

            if tokens_b:
                tokens += tokens_b + [sep_token]
                segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

            if cls_token_at_end:
                tokens = tokens + [cls_token]
                segment_ids = segment_ids + [cls_token_segment_id]
            else:
                tokens = [cls_token] + tokens
                segment_ids = [cls_token_segment_id] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens.
            # Only real tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = max_seq_length - len(input_ids)

            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids

            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            choices_features.append((tokens, input_ids, input_mask, segment_ids))

        label = label_map[example.label]

        if ex_index < 0:
            print("*** Example ***")
            print("example_id: {}".format(example.example_id))
            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
                print("choice: {}".format(choice_idx))
                print("tokens: {}".format(' '.join(tokens)))
                print("input_ids: {}".format(' '.join(map(str, input_ids))))
                print("input_mask: {}".format(' '.join(map(str, input_mask))))
                print("segment_ids: {}".format(' '.join(map(str, segment_ids))))
                print("label: {}".format(label))

        features.append(InputFeatures(
            example_id=example.example_id,
            choices_features=choices_features,
            label=label
        ))

    return features


def load_features(args, tokenizer, mode='train'):
    """
    Load the processed Commonsense QA dataset
    """

    def select_field(feature_list, field_name):
        return [
            [choice[field_name] for choice in feature.choices_features]
            for feature in feature_list
        ]

    assert mode in {'train', 'validation', 'test'}
    print("Creating features from dataset...")

    processor = CommonsenseQAProcessor()
    label_list = processor.labels
    examples = processor.create_examples(split=mode)

    print("Number of {} examples: {}".format(mode, len(examples)))
    features = examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
                                    cls_token_at_end=bool(args.model_type in ['xlnet']),
                                    cls_token=tokenizer.cls_token,
                                    sep_token=tokenizer.sep_token,
                                    sep_token_extra=bool(args.model_type in ['roberta']),
                                    cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
                                    pad_on_left=bool(args.model_type in ['xlnet']),
                                    pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    return dataset
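To make the layout built by `examples_to_features` concrete, here is a toy, dependency-free re-creation of the BERT-style case (`[CLS]` first, single `[SEP]`, right padding). Assumptions: a lowercase whitespace split and a growing dict stand in for the real WordPiece tokenizer and vocabulary, and `truncate_pair` / `encode_choice` are illustrative names. Each example becomes `num_choices` rows of length `max_seq_length`, which is why the tensors in `load_features` have shape `(num_examples, 5, max_seq_length)`.

```python
# Toy re-creation of the BERT-style layout: [CLS] question [SEP] answer [SEP] + padding.
# Assumption: whitespace split + dict vocabulary stand in for a real tokenizer.

def truncate_pair(tokens_a, tokens_b, max_length):
    """Pop from the longer list (ties pop tokens_b) until the pair fits."""
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()


def encode_choice(question, answer, max_seq_length, vocab, pad_id=0):
    tokens_a = question.lower().split()
    tokens_b = answer.lower().split()
    truncate_pair(tokens_a, tokens_b, max_seq_length - 3)  # room for [CLS] and two [SEP]

    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
    # Segment 0: [CLS], question, first [SEP]; segment 1: answer and last [SEP].
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = [vocab.setdefault(t, len(vocab)) for t in tokens]
    input_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding

    padding = max_seq_length - len(input_ids)  # right-pad, as for BERT
    return (tokens,
            input_ids + [pad_id] * padding,
            input_mask + [0] * padding,
            segment_ids + [0] * padding)


vocab = {'[PAD]': 0}
question = 'What home entertainment equipment requires cable?'
answers = ['radio shack', 'substation', 'cabinet', 'television', 'desk']

# One example -> one encoded row per candidate answer.
features = [encode_choice(question, a, 16, vocab) for a in answers]
tokens, input_ids, input_mask, segment_ids = features[3]
print(tokens)
print(segment_ids)
print(input_mask)
```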

# Experiment

## Pre-trained Model

In [13]:
def train(args, model, tokenizer):

    print('\n Loading training dataset')
    dataset_tr = load_features(args, tokenizer, mode='train')
    sampler_tr = RandomSampler(dataset_tr)
    dataloader_tr = DataLoader(dataset_tr, sampler=sampler_tr, batch_size=args.batch_size)

    print('\n Loading validation dataset')
    dataset_val = load_features(args, tokenizer, 'validation')
    sampler_val = SequentialSampler(dataset_val)
    dataloader_val = DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size)

    model, optimizer, scheduler = load_optimizer(args, model, len(dataloader_tr))

    num_steps = 0
    best_steps = 0
    tr_loss = 0.0
    best_val_acc, best_val_loss = 0.0, float('inf')
    best_model = None

    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False, leave=True, position=1)

    for _ in train_iterator:

        epoch_iterator = tqdm(dataloader_tr, desc="Iteration", disable=False, leave=True, position=1)
        for step, batch in enumerate(epoch_iterator):

            model.train()

            batch = tuple(b.to(args.device) for b in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                      'labels': batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()

            optimizer.step()
            scheduler.step()
            model.zero_grad()
            num_steps += 1

            if args.logging_steps > 0 and num_steps % args.logging_steps == 0:
                results = evaluate(args, model, dataloader_val)
                print("\n val acc: {}, val loss: {}"
                      .format(str(results['val_acc']), str(results['val_loss'])))
                if results["val_acc"] > best_val_acc:
                    best_val_acc, best_val_loss = results["val_acc"], results["val_loss"]
                    best_steps = num_steps
                    best_model = deepcopy(model)

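    avg_tr_loss = tr_loss / max(num_steps, 1)
    print("\n avg train loss: {}, best val acc: {} (step {})"
          .format(avg_tr_loss, best_val_acc, best_steps))

    # Fall back to the last model if validation never ran (e.g. logging_steps <= 0)
    return best_model if best_model is not None else model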


def evaluate(args, model, dataloader):

    val_loss = 0.0
    num_steps = 0
    preds, labels = None, None

    results = {}

    for batch in tqdm(dataloader, desc="Validation", disable=True, leave=True, position=1):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                      'labels': batch[3]}
            outputs = model(**inputs)
            loss, logits = outputs[:2]

            val_loss += loss.mean().item()

        num_steps += 1

        if preds is None:
            preds = logits.detach().cpu().numpy()
            labels = inputs['labels'].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            labels = np.append(labels, inputs['labels'].detach().cpu().numpy(), axis=0)

    loss = val_loss / num_steps
    preds = np.argmax(preds, axis=1)
    acc = (preds == labels).mean()
    result = {"val_acc": acc, "val_loss": loss}
    results.update(result)

    return results


def test(args, tokenizer, model):

    # The public CommonsenseQA test split ships without answer keys,
    # so the labeled validation split is used for the final evaluation.
    print('\nTesting...')
    dataset = load_features(args, tokenizer, mode='validation')
    sampler = SequentialSampler(dataset)
    dataloader_test = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size)

    results = evaluate(args, model, dataloader_test)
    print("\n final validation acc: {}, final validation loss: {}"
          .format(str(results['val_acc']), str(results['val_loss'])))
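The metric in `evaluate` and the checkpoint rule in `train` can be sketched without NumPy or PyTorch: the prediction is the arg-max over each row of `num_choices` logits, accuracy is the fraction of rows matching the gold index, and `train` keeps a copy of the model whenever validation accuracy strictly improves. The logits below are made up; the two accuracies are the values the BERT run in this notebook logged at steps 100 and 200. `accuracy` and `best_checkpoint` are illustrative helpers.

```python
# Dependency-free sketch of the metric in evaluate() and the checkpoint rule in train().

def accuracy(logits, labels):
    """logits: one row of num_choices scores per example; labels: gold indices."""
    # Arg-max per row; like np.argmax, ties resolve to the first index.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def best_checkpoint(history):
    """history: (step, val_acc) pairs; keep the first step with the highest accuracy."""
    best_step, best_acc = None, 0.0
    for step, acc in history:
        if acc > best_acc:  # strict '>' so a later tie does not replace the earlier model
            best_step, best_acc = step, acc
    return best_step, best_acc


logits = [[0.1, 2.0, -1.0, 0.3, 0.0],   # predicts choice 1
          [1.5, 0.2, 0.2, 0.1, 1.4],    # predicts choice 0
          [0.0, 0.0, 0.0, 3.0, 0.1]]    # predicts choice 3
labels = [1, 4, 3]
print(accuracy(logits, labels))

# Validation accuracies logged by the BERT run at steps 100 and 200.
print(best_checkpoint([(100, 0.4742014742014742), (200, 0.49713349713349714)]))
```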

In [14]:
def main(args):

    print('Using device', args.device)
    set_seed(args.seed)

    processor = CommonsenseQAProcessor()
    num_labels = len(processor.labels)

    config_class, model_class, tokenizer_class = load_model(args.model_type)
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name,
        num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=True)
    
    model = model_class.from_pretrained(
        args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
    model.to(args.device)

    best_model = train(args, model, tokenizer)
    test(args, tokenizer, best_model)

## Runtime

The default model type is `bert`, with the checkpoint set by:
```
parser.add_argument("--model_name", type=str, default='bert-base-uncased',
                    help="Path to pre-trained model or shortcut name. See https://huggingface.co/models")
```

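This produces the following warnings. They are expected here: the pre-training heads (`cls.*`) are discarded and the multiple-choice `classifier` is newly initialized.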

```
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using custom data configuration default
```
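The `classifier.weight` and `classifier.bias` named in these warnings form the multiple-choice head. Roughly, `BertForMultipleChoice` flattens the `(batch, num_choices, seq_len)` inputs, encodes each (question, answer) pair, scores its pooled `[CLS]` vector with a linear layer that has a single output, reshapes the scores to `(batch, num_choices)`, and applies cross-entropy over the choices. The plain-Python sketch below mimics only that final scoring step; the vectors and weights are made-up numbers, and `choice_scores` / `cross_entropy` are illustrative helpers, not library functions.

```python
import math

# Toy multiple-choice head: one linear layer maps each choice's pooled vector
# to a single score; cross-entropy is taken over the num_choices scores.

def choice_scores(pooled, weight, bias):
    """pooled: one hidden-size vector per (question, answer) pair."""
    return [sum(w * h for w, h in zip(weight, vec)) + bias for vec in pooled]


def cross_entropy(scores, label):
    """Negative log-softmax probability of the gold choice (numerically stable)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[label]


pooled = [[0.2, 0.1], [0.0, 0.5], [0.3, 0.3], [1.0, 0.8], [0.1, 0.0]]  # hidden size 2
weight, bias = [1.0, 1.0], 0.0
scores = choice_scores(pooled, weight, bias)
print([round(s, 2) for s in scores])
print(round(cross_entropy(scores, label=3), 4))
```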

See https://huggingface.co/models for other pre-trained models.
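Switching `model_type` also changes how `load_features` formats each sequence. The sketch below gathers those switches into one hypothetical helper, `format_kwargs`, mirroring the arguments `load_features` passes to `examples_to_features`; the checkpoint ids are standard Hugging Face names given only as examples.

```python
# Hypothetical helper mirroring the model-specific arguments that
# load_features passes to examples_to_features.
def format_kwargs(model_type):
    is_xlnet = model_type == 'xlnet'
    return {
        'cls_token_at_end': is_xlnet,                 # XLNet appends <cls> at the end
        'sep_token_extra': model_type == 'roberta',   # RoBERTa uses a doubled separator
        'cls_token_segment_id': 2 if is_xlnet else 0,
        'pad_on_left': is_xlnet,                      # XLNet pads on the left
        'pad_token_segment_id': 4 if is_xlnet else 0,
    }


# Example Hugging Face checkpoint ids for the supported model types.
CHECKPOINTS = {'bert': 'bert-base-uncased',
               'roberta': 'roberta-base',
               'xlnet': 'xlnet-base-cased'}

for model_type, model_name in CHECKPOINTS.items():
    print(model_type, model_name, format_kwargs(model_type))
```

With the `run` helper defined below, a call such as `run('roberta', 'roberta-base', 'commonqa')` should follow the same pipeline with RoBERTa's formatting.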

In [16]:
def run(model_type='bert',
        model_name='bert-base-uncased',
        task_name=None,
        batch_size=16,
        logging_steps=50,
        lr=1e-5,
        epochs=1):
  
    parser = argparse.ArgumentParser(description="Common sense question answering")
    parser.add_argument("--model_type", type=str, default=model_type,
                        help="Model: <str> [ bert | xlnet | roberta | gpt2 ]")
    parser.add_argument("--task_name", default=task_name, type=str, required=False,
                        help="The name of the task to train: <str> [ commonqa ]")
    parser.add_argument("--model_name", type=str,
                        default=model_name,
                        help="Path to pre-trained model or shortcut name. "
                             "See https://huggingface.co/models")
    parser.add_argument("--config_name", type=str,
                        default=model_name,
                        help="Pre-trained config name or path")
    parser.add_argument("--tokenizer_name", default=model_name, type=str,
                        help="Pre-trained tokenizer name or path if not the same as model_name")

    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. "
                                "Sequences longer than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--batch_size", default=batch_size, type=int,
                        help="Batch size for training.")

    parser.add_argument("--learning_rate", default=lr, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")

    parser.add_argument("--num_train_epochs", default=epochs, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_steps", default=10, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--logging_steps', type=int, default=logging_steps,
                        help="Log every n updates steps.")

    # Note: --fp16 and --opt_level are parsed but not used by train() in this notebook.
    parser.add_argument('--fp16', type=bool, default=True,
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--seed", type=int, default=SEED, help="Random seed: <int>")
    parser.add_argument("--device", default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    args, unknown = parser.parse_known_args()
    main(args)
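`run` parses its options with `parse_known_args` rather than `parse_args`. In a Colab or Jupyter kernel, `sys.argv` typically carries the kernel's own flags (such as `-f <connection file>`), which `parse_args` would reject with an error; `parse_known_args` returns them separately. The sketch below passes an explicit argument list (the kernel path is made up) and also shows a pitfall of the `--fp16` definition: with `type=bool`, any non-empty string, including `"False"`, parses as `True`.

```python
import argparse

parser = argparse.ArgumentParser(description="Common sense question answering")
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--fp16", type=bool, default=True)

# Simulated notebook argv: the kernel flag is unknown to this parser.
argv = ["-f", "/tmp/kernel-example.json", "--batch_size", "8"]
args, unknown = parser.parse_known_args(argv)
print(args.batch_size, unknown)  # unknown flags are returned, not rejected

# type=bool calls bool("False"), which is True for any non-empty string.
fp16_args, _ = parser.parse_known_args(["--fp16", "False"])
print(fp16_args.fp16)
```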

## Finetuning

### Bert

In [None]:
model_type = 'bert'
model_name = 'bert-base-uncased'
task_name = 'commonqa'
batch_size = 16
logging_steps = 100
lr = 5e-5

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly


 Loading training dataset
Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221





 val acc: 0.4742014742014742, val loss: 1.3063861580638143



 val acc: 0.49713349713349714, val loss: 1.2598319223948888



 val acc: 0.515970515970516, val loss: 1.2454247072145537



 val acc: 0.5405405405405406, val loss: 1.1484103017039113



 val acc: 0.5659295659295659, val loss: 1.1136866111259955



 val acc: 0.5618345618345618, val loss: 1.0950990009617496



Iteration: 100%|██████████| 609/609 [22:05<00:00,  2.18s/it]

Epoch: 100%|██████████| 1/1 [22:05<00:00, 1325.43s/it]


Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.5659295659295659, final validation loss: 1.1136866111259955
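In the bert-base-uncased run above, the final validation accuracy (0.5659) matches the best of the periodic evaluations rather than the last one (0.5618), even though the last evaluation had the lower loss (1.095 vs. 1.114). This suggests that `run()` keeps the checkpoint with the highest validation accuracy. The source of `run()` is not part of this section, so treat that as an inference from the logs. A minimal sketch of accuracy-based checkpoint selection, where `train_step` and `evaluate` are hypothetical stand-ins for the notebook's own code:

```python
import copy

def train_with_best_checkpoint(model, batches, train_step, evaluate, logging_steps):
    """Run one epoch, evaluating every `logging_steps` steps.

    `train_step(model, batch)` and `evaluate(model) -> (acc, loss)` are
    hypothetical stand-ins for the notebook's own training and evaluation code.
    Returns a copy of the state dict with the highest validation accuracy seen.
    """
    best_acc, best_loss, best_state = float("-inf"), float("inf"), None
    for step, batch in enumerate(batches, start=1):
        train_step(model, batch)
        if step % logging_steps == 0:
            acc, loss = evaluate(model)
            print(f" val acc: {acc}, val loss: {loss}")
            # Select on accuracy only: a later evaluation with lower loss but
            # lower accuracy does not replace the stored checkpoint.
            if acc > best_acc:
                best_acc, best_loss = acc, loss
                best_state = copy.deepcopy(model.state_dict())
    return best_state, best_acc, best_loss
```

With a PyTorch model, `model.load_state_dict(best_state)` would restore the selected weights before the final evaluation.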


In [None]:
model_type = 'bert'
model_name = 'bert-large-uncased'
task_name = 'commonqa'
batch_size = 4
logging_steps = 400
lr = 1e-5

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda


Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-large-uncased and are new


 Loading training dataset
Creating features from dataset...
Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221



 val acc: 0.5249795249795249, val loss: 1.2298493488551745



 val acc: 0.5470925470925471, val loss: 1.1950963718049668



 val acc: 0.5872235872235873, val loss: 1.0778233563783122



 val acc: 0.5954135954135954, val loss: 1.0424056557658452



 val acc: 0.6044226044226044, val loss: 1.0187847280716584



 val acc: 0.6167076167076168, val loss: 0.9998837094875722




Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.6167076167076168, final validation loss: 0.9998837094875722
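The iteration totals in the progress bars follow from the 9,741 training questions and the batch size: 2,436 steps at `batch_size = 4` and 4,871 steps at `batch_size = 2`, so the last, incomplete batch is kept. The 609 steps of the bert-base-uncased run correspond to a batch size of 16, although that cell is not shown in this section. Likewise, each run prints six validation results because six multiples of `logging_steps` fit in one epoch. A small sketch of this arithmetic, assuming the DataLoader does not drop the last batch and that evaluation fires whenever the step count is a multiple of `logging_steps`:

```python
import math

TRAIN_SIZE = 9741  # "Training number: 9741" in the logs

def steps_per_epoch(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

def num_evaluations(steps: int, logging_steps: int) -> int:
    """Evaluations triggered when the step count is a multiple of logging_steps."""
    return steps // logging_steps

# (model name, batch_size, logging_steps) taken from the cells in this notebook.
configs = [
    ("bert-large-uncased", 4, 400),
    ("deepset/bert-large-uncased-whole-word-masking-squad2", 2, 800),
    ("xlnet-base-cased", 4, 400),
]
for name, bs, log_every in configs:
    steps = steps_per_epoch(TRAIN_SIZE, bs)
    print(f"{name}: {steps} steps, {num_evaluations(steps, log_every)} evaluations")
```

Passing `drop_last=True` would shave one step off each of these runs, since 9,741 is not divisible by 2, 4 or 16.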


In [None]:
model_type = 'bert'
model_name = 'deepset/bert-large-uncased-whole-word-masking-squad2'
task_name = 'commonqa'
batch_size = 2
logging_steps = 800
lr = 5e-6

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda



Some weights of the model checkpoint at deepset/bert-large-uncased-whole-word-masking-squad2 were not used when initializing BertForMultipleChoice: ['qa_outputs.weight', 'qa_outputs.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at deepset/bert-large-uncased-whole-word-masking-squad2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



 Loading training dataset
Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221



 val acc: 0.5102375102375102, val loss: 1.2335076117110721



 val acc: 0.556920556920557, val loss: 1.124694232470714



 val acc: 0.5724815724815725, val loss: 1.0850894078429416



 val acc: 0.5888615888615889, val loss: 1.077146301704494



 val acc: 0.5913185913185913, val loss: 1.051629226924562



 val acc: 0.592956592956593, val loss: 1.0395029854149833




Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.592956592956593, final validation loss: 1.0395029854149833
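Every run prints its periodic evaluations in the same ` val acc: ..., val loss: ...` format, so the learning curves can be pulled out of a saved copy of this output and compared across models. A minimal sketch using a regular expression. Saving the output to text is an assumed workflow, and the sample lines are copied from the deepset/bert-large-uncased-whole-word-masking-squad2 run above:

```python
import re
from typing import List, Tuple

# Matches lines such as " val acc: 0.5102375102375102, val loss: 1.2335076117110721".
# Anchored at line start, so " final validation acc: ..." lines are not matched.
VAL_LINE = re.compile(r"^\s*val acc: ([0-9.]+), val loss: ([0-9.]+)\s*$")

def parse_val_curve(log_text: str) -> List[Tuple[float, float]]:
    """Return (accuracy, loss) pairs for each periodic evaluation, in order."""
    curve = []
    for line in log_text.splitlines():
        m = VAL_LINE.match(line)
        if m:
            curve.append((float(m.group(1)), float(m.group(2))))
    return curve

sample = """
 val acc: 0.5102375102375102, val loss: 1.2335076117110721
 val acc: 0.556920556920557, val loss: 1.124694232470714
 val acc: 0.5724815724815725, val loss: 1.0850894078429416
 final validation acc: 0.592956592956593, final validation loss: 1.0395029854149833
"""
curve = parse_val_curve(sample)
best_acc, best_loss = max(curve)  # tuples compare by accuracy first
print(len(curve), best_acc, best_loss)
```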


### XLNet
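CommonsenseQA questions have five answer choices. A useful reference for reading the XLNet numbers below is a model that assigns equal probability to every choice. Its expected accuracy is 1/5 = 0.2 and its cross-entropy loss is ln 5 ≈ 1.609. The XLNet validation losses below (roughly 1.595 to 1.609) sit essentially at that value, and the accuracies (0.19 to 0.27) are near 0.2. This indicates that the model is still predicting close to uniformly at these checkpoints, plausibly because of the very small `lr = 1e-6`. A short check:

```python
import math

NUM_CHOICES = 5  # CommonsenseQA answer options (A-E)

def uniform_baseline(num_choices: int):
    """Accuracy and cross-entropy of a predictor that spreads probability evenly."""
    accuracy = 1.0 / num_choices
    # Cross-entropy per example is -log p(correct) = -log(1/k) = log k.
    loss = math.log(num_choices)
    return accuracy, loss

chance_acc, chance_loss = uniform_baseline(NUM_CHOICES)
# Periodic XLNet validation losses as printed in the log below.
xlnet_losses = [1.5974556199865404, 1.5953456003681508, 1.6092822820532555,
                1.6087489396918053, 1.606630280516506]
print(chance_acc, round(chance_loss, 4))
print([round(chance_loss - loss, 4) for loss in xlnet_losses])
```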

In [None]:
model_type = 'xlnet'
model_name = 'xlnet-base-cased'
task_name = 'commonqa'
batch_size = 4
logging_steps = 400
lr = 1e-6

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForMultipleChoice: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForMultipleChoice were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using custom data 


 Loading training dataset
Creating features from dataset...
Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221



 val acc: 0.2596232596232596, val loss: 1.5974556199865404



 val acc: 0.26945126945126946, val loss: 1.5953456003681508



 val acc: 0.19246519246519248, val loss: 1.6092822820532555



 val acc: 0.23013923013923013, val loss: 1.6087489396918053





 val acc: 0.23423423423423423, val loss: 1.606630280516506





 val acc: 0.2375102375102375, val loss: 1.606173419874478



Iteration:  99%|█████████▉| 2415/2436 [30:49<00:14,  1.45it/s]

Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.26945126945126946, final validation loss: 1.5953456003681508


### RoBERTa
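Each run loads its checkpoint into `RobertaForMultipleChoice`, and transformers warns that `classifier.weight` and `classifier.bias` are newly initialized: the pretrained encoder is reused, but the scoring head starts from random weights. The NumPy sketch below illustrates, under our reading of the transformers implementation, how such a head turns the five (question, choice) encodings of a CommonsenseQA question into one prediction. The encoder is replaced by a random stand-in, and all sizes and labels are illustrative assumptions, not values taken from the notebook's `run` function.

```python
import numpy as np

# Illustrative sizes: CommonsenseQA has 5 answer choices per question;
# batch size, sequence length, and hidden size here are toy values.
batch_size, num_choices, seq_len, hidden = 2, 5, 8, 16
rng = np.random.default_rng(0)

# Each question is paired with every choice, so the token ids have shape
# (batch, num_choices, seq_len); the choice dimension is flattened first.
input_ids = rng.integers(0, 1000, size=(batch_size, num_choices, seq_len))
flat_ids = input_ids.reshape(-1, seq_len)  # (batch * num_choices, seq_len)

# Random stand-in for the encoder's pooled output of each (question, choice) pair.
pooled = rng.standard_normal((flat_ids.shape[0], hidden))

# The newly initialized head: a single linear layer mapping hidden -> 1 score.
classifier_weight = rng.standard_normal((1, hidden)) * 0.02
classifier_bias = np.zeros(1)
scores = pooled @ classifier_weight.T + classifier_bias  # (batch * num_choices, 1)

# Regroup the scores per question and take a softmax over the choices.
logits = scores.reshape(batch_size, num_choices)
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# Cross-entropy against hypothetical gold choice indices, averaged over the batch.
labels = np.array([0, 3])
loss = -np.log(probs[np.arange(batch_size), labels]).mean()
predictions = logits.argmax(axis=1)
print(logits.shape, predictions, round(float(loss), 4))
```

Because the head only sees one pair at a time and the softmax is taken across choices, the same weights handle any number of answer options.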

In [None]:
model_type = 'roberta'
model_name = 'roberta-base'
task_name = 'commonqa'
batch_size = 16
logging_steps = 100
lr = 5e-5

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForMultipleChoice: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 Loading training dataset
Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221





 val acc: 0.32678132678132676, val loss: 1.5397118500300817





 val acc: 0.42424242424242425, val loss: 1.3650531861689184





 val acc: 0.4881244881244881, val loss: 1.2662005378054333





 val acc: 0.5192465192465192, val loss: 1.2012166961446984





 val acc: 0.5249795249795249, val loss: 1.1803535594568624





 val acc: 0.5389025389025389, val loss: 1.1480596057780377



Iteration: 100%|██████████| 609/609 [22:02<00:00,  2.17s/it]

Epoch: 100%|██████████| 1/1 [22:02<00:00, 1322.57s/it]


Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.5389025389025389, final validation loss: 1.1480596057780377
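The iteration counts and the number of `val acc` lines in these logs follow from the cell settings and the 9,741 training questions. A minimal sketch, assuming `run` keeps the final partial batch and evaluates on the validation set every `logging_steps` iterations (the `schedule` helper is hypothetical; `run` itself is not shown in this section):

```python
import math

TRAIN_SIZE = 9741  # "Training number: 9741" in the logs


def schedule(batch_size, logging_steps, train_size=TRAIN_SIZE):
    """Return (iterations per epoch, validation passes per epoch).

    Assumes the last partial batch is kept (no drop_last) and that
    validation runs once every `logging_steps` iterations.
    """
    iterations = math.ceil(train_size / batch_size)
    evaluations = iterations // logging_steps
    return iterations, evaluations


for name, bs, steps in [("roberta-base", 16, 100), ("roberta-large", 2, 800)]:
    its, evals = schedule(bs, steps)
    print(f"{name}: {its} iterations/epoch, {evals} validation passes")
```

Both settings give six validation passes per epoch, which matches the six `val acc` lines printed for each of the two RoBERTa runs that completed.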


In [None]:
model_type = 'roberta'
model_name = 'roberta-large'
task_name = 'commonqa'
batch_size = 2
logging_steps = 800
lr = 2e-6

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda


Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForMultipleChoice: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 Loading training dataset
Creating features from dataset...


Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 9741


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)



 Loading validation dataset
Creating features from dataset...
Training number: 1221





 val acc: 0.2285012285012285, val loss: 1.6063991359720058





 val acc: 0.4357084357084357, val loss: 1.5088766754746632





 val acc: 0.49713349713349714, val loss: 1.2734163771772151





 val acc: 0.5642915642915642, val loss: 1.147228055031522





 val acc: 0.5904995904995906, val loss: 1.0808421138783768





 val acc: 0.6003276003276004, val loss: 1.0647238510738564



Iteration:  99%|█████████▉| 4814/4871 [1:19:02<01:02,  1.10s/it]

Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221

Testing...

 final validation acc: 0.6003276003276004, final validation loss: 1.0647238510738564
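In each finished run, the reported final validation accuracy equals the highest accuracy logged during training, which is not always the last one: for the run ending with the 26.9% result (XLNet), the final 0.2695 matches an earlier evaluation rather than the last logged 0.2375. This suggests that `run` restores the best checkpoint before testing, possibly via the `deepcopy` imported at the top, but that is an inference, since its source is not shown here. The check below uses only values copied from the logs; for XLNet, only the evaluations visible in this section are listed.

```python
# Validation accuracies copied from the logs, in logging order,
# paired with the "final validation acc" each run reported.
logged = {
    "xlnet": ([0.26945126945126946, 0.19246519246519248, 0.23013923013923013,
               0.23423423423423423, 0.2375102375102375], 0.26945126945126946),
    "roberta-base": ([0.32678132678132676, 0.42424242424242425, 0.4881244881244881,
                      0.5192465192465192, 0.5249795249795249, 0.5389025389025389],
                     0.5389025389025389),
    "roberta-large": ([0.2285012285012285, 0.4357084357084357, 0.49713349713349714,
                       0.5642915642915642, 0.5904995904995906, 0.6003276003276004],
                      0.6003276003276004),
}

VAL_SIZE = 1221  # validation questions, per the logs

for name, (history, final) in logged.items():
    best = max(history)
    # Accuracies are fractions of the 1221 questions, so recover the count.
    correct = round(final * VAL_SIZE)
    print(f"{name}: final={final:.4f}, best logged={best:.4f}, "
          f"last logged={history[-1]:.4f}, correct={correct}/{VAL_SIZE}")
```

The recovered counts are 329, 658, and 733 correct answers out of 1221.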


In [None]:
model_type = 'roberta'
model_name = 'deepset/roberta-base-squad2'
task_name = 'commonqa'
batch_size = 16
logging_steps = 100
lr = 5e-5

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

In [None]:
model_type = 'roberta'
model_name = 'roberta-large'
task_name = 'commonqa'
batch_size = 2
logging_steps = 800
lr = 1e-5

run(model_type, model_name, task_name, batch_size, logging_steps, lr)

Using device cuda



Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForMultipleChoice: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 Loading training dataset
Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 9741

 Loading validation dataset
Creating features from dataset...


Using custom data configuration default
Reusing dataset commonsense_qa (/root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29)


Training number: 1221





 val acc: 0.18263718263718265, val loss: 1.6094383553475289





 val acc: 0.20884520884520885, val loss: 1.6094372887072899



Iteration:  33%|███▎      | 1614/4871 [26:19<59:27,  1.10s/it]

# Results

## Validation Accuracy

The validation accuracy of each baseline on the CommonsenseQA [1] validation set (1,221 questions), before any further hyperparameter tuning, is listed as follows:
* BERT: 56.6% (base), 61.7% (large)
* XLNet: 26.9%
* RoBERTa: 53.9% (base), 60.0% (large)
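Two reference points help read these numbers. CommonsenseQA [1] has five answer choices per question, so random guessing scores 20%, and a model that assigns equal probability to every choice has a cross-entropy loss of ln 5 ≈ 1.6094. The sketch below compares the reported losses with that value. XLNet finishes only slightly below it, and the `roberta-large` run with `lr = 1e-5`, whose log stops partway through the epoch, sits at ln 5 to about six decimal places at its last logged evaluation. This suggests it was still predicting a near-uniform distribution at that point.

```python
import math

NUM_CHOICES = 5  # answer options per CommonsenseQA question

# A model that spreads probability evenly over the choices gives the gold
# answer probability 1/5, so its cross-entropy loss is ln(5).
chance_accuracy = 1 / NUM_CHOICES
uniform_loss = math.log(NUM_CHOICES)

# Validation losses copied from the logs: final values for finished runs,
# last logged value for the unfinished lr=1e-5 run.
reported = {
    "xlnet": 1.5953456003681508,
    "roberta-base": 1.1480596057780377,
    "roberta-large (lr=2e-6)": 1.0647238510738564,
    "roberta-large (lr=1e-5), last logged": 1.6094372887072899,
}

print(f"chance accuracy = {chance_accuracy:.1%}, uniform loss = {uniform_loss:.4f}")
for name, loss in reported.items():
    print(f"{name}: loss {loss:.4f} ({uniform_loss - loss:+.4f} below uniform)")
```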

## Todo

A few things to do in the future:
1. Further tune `xlnet`.
2. Run `roberta-large` with more epochs. Currently `epochs = 1`.
3. Finish all finetuning experiments.

# References

[1] Talmor, A., Herzig, J., Lourie, N., & Berant, J. (2019). CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge. ArXiv, abs/1811.00937.