Forked from https://www.kaggle.com/yihdarshieh/more-nli-datasets
Changes:
1. Use the XLM-RoBERTa large model (`jplu/tf-xlm-roberta-large`) instead of the base model;
2. Set the random seeds;
3. Use more MNLI data (60,000 training examples).

<center><img src="https://raw.githubusercontent.com/chiapas/kaggle/master/competitions/contradictory-my-dear-watson/header.png" width="1000"></center>
<br>
<center><h1>Detecting contradiction and entailment in multilingual text using TPUs</h1></center>
<br>

#### Natural Language Inferencing (NLI) is a classic NLP (Natural Language Processing) problem that involves taking two sentences (the _premise_ and the _hypothesis_) and deciding how they are related: whether the premise entails the hypothesis, contradicts it, or neither.

#### In this notebook, we will use more NLI datasets, including

* [The Stanford Natural Language Inference Corpus (SNLI)](https://nlp.stanford.edu/projects/snli/)
* [The Multi-Genre NLI Corpus (MultiNLI, MNLI)](https://cims.nyu.edu/~sbowman/multinli/)
* [Cross-lingual NLI Corpus (XNLI)](https://cims.nyu.edu/~sbowman/xnli/)

#### We will also use Hugging Face's recent library [nlp](https://huggingface.co/nlp/) to work with these datasets.

#### Import

In [None]:
import os
os.environ["WANDB_API_KEY"] = "0" ## to silence warning

import numpy as np
import random
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import plotly.express as px

!pip uninstall -y transformers
!pip install transformers

import transformers
import tokenizers

# Hugging Face new library for datasets (https://huggingface.co/nlp/)
!pip install nlp
import nlp

import datetime

strategy = None

def seed_all(seed=2020):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
seed_all(2020)

In [None]:
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Datasets

## Competition dataset

In [None]:
original_train = pd.read_csv("../input/contradictory-my-dear-watson/train.csv")

original_train = shuffle(original_train)
original_valid = original_train[:len(original_train) // 5]
original_train = original_train[len(original_train) // 5:]

In [None]:
print(f"original - training: {len(original_train)} examples")
original_train.head(10)

In [None]:
print(f"original - validation: {len(original_valid)} examples")
original_valid.head(10)

In [None]:
original_test = pd.read_csv("../input/contradictory-my-dear-watson/test.csv")
print(f"original - test: {len(original_test)} examples")
original_test.head(10)

## Extra datasets

### Let's use Hugging Face's new library [nlp](https://huggingface.co/nlp/) to get more NLI datasets.

#### Load a dataset - The Multi-Genre NLI Corpus (MNLI)
First, let's load [The Multi-Genre NLI Corpus (MultiNLI, MNLI)](https://cims.nyu.edu/~sbowman/multinli/). It contains 433,000 sentence pairs annotated with textual entailment information.

In [None]:
mnli = nlp.load_dataset(path='glue', name='mnli')

#### check the loaded dataset

Let's look at some information about the MNLI dataset. By default, the return value of [nlp.load_dataset](https://huggingface.co/nlp/package_reference/loading_methods.html#nlp.load_dataset) is a dictionary with split names as keys. These are usually `train`, `validation` and `test`, but not always: MNLI, for example, has `validation_matched` and `validation_mismatched`. The values are [nlp.arrow_dataset.Dataset](https://huggingface.co/nlp/master/package_reference/main_classes.html#nlp.Dataset) objects.



In [None]:
print(mnli, '\n')

print('The split names in MNLI dataset:')
for k in mnli:
    print('   ', k)
    
# Get the datasets
print("\nmnli['train'] is ", type(mnli['train']))

mnli['train']

#### look inside 'nlp.arrow_dataset.Dataset'

In order to get the number of examples in a dataset, for example, `mnli['train']`, you can do
```
    mnli['train'].num_rows
```


You can iterate over a [nlp.arrow_dataset.Dataset](https://huggingface.co/nlp/master/package_reference/main_classes.html#nlp.Dataset) object like this:
```
    for elt in mnli['train']:
        ...
```
At each step, you get one example: a dictionary mapping feature names to values.

You can also access the content of a [nlp.arrow_dataset.Dataset](https://huggingface.co/nlp/master/package_reference/main_classes.html#nlp.Dataset) object by feature name. For example, the training dataset in `mnli` has the features `premise`, `hypothesis`, `label` and `idx`.

You can either specify a feature name first (you get a list) followed by a slice, like
```
    # You get a `list` first, then slice it
    mnli['train']['premise'][:3]
```
or use slice notation first to get a dictionary (which represents a sliced dataset) followed by a feature name.
```
    # You get a `dictionary` (of lists) first, then a list
    mnli['train'][:3]['premise']
```

The results will be the same.
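
To see why the two access orders agree, here is a toy, pure-Python stand-in for a columnar dataset. The class `ToyColumnarDataset` is hypothetical and is not part of `nlp`; it only mimics the two access patterns described above: indexing by a feature name returns a list, and indexing by a slice returns a dictionary of lists.

```python
class ToyColumnarDataset:
    """Hypothetical stand-in mimicking the two access patterns described above."""

    def __init__(self, columns):
        # columns: dict mapping feature name -> list of values (all the same length)
        self.columns = columns

    def __getitem__(self, key):
        if isinstance(key, str):
            # Feature name first: return the whole column as a list
            return list(self.columns[key])
        if isinstance(key, slice):
            # Slice first: return a dict of lists restricted to the selected rows
            return {name: values[key] for name, values in self.columns.items()}
        raise TypeError(f"unsupported key type: {type(key).__name__}")


ds = ToyColumnarDataset({
    'premise': ['p0', 'p1', 'p2', 'p3'],
    'hypothesis': ['h0', 'h1', 'h2', 'h3'],
    'label': [0, 1, 2, 0],
})

column_first = ds['premise'][:3]   # a list, then sliced
slice_first = ds[:3]['premise']    # a dict of lists, then one list
print(column_first == slice_first)  # True
```

Either order selects the same rows of the same column, which is why the results match.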

To get the names of the classes, you can do

```
mnli['train'].features['label'].names
```

Let's use what we have learned to inspect some training examples:

In [None]:
print('The number of training examples in mnli dataset:', mnli['train'].num_rows)
print('The number of validation examples in mnli dataset - part 1:', mnli['validation_matched'].num_rows)
print('The number of validation examples in mnli dataset - part 2:', mnli['validation_mismatched'].num_rows, '\n')

print('The class names in mnli dataset:', mnli['train'].features['label'].names)
print('The feature names in mnli dataset:', list(mnli['train'].features.keys()), '\n')

for elt in mnli['train']:
    
    print('premise:', elt['premise'])
    print('hypothesis:', elt['hypothesis'])
    print('label:', elt['label'])
    print('label name:', mnli['train'].features['label'].names[elt['label']])
    print('idx', elt['idx'])
    print('-' * 80)
    
    if elt['idx'] >= 10:
        break

Note that the class names are
```
    ['entailment', 'neutral', 'contradiction'] 
```
which correspond to the labels of the original competition dataset, as described on [this competition data page](https://www.kaggle.com/c/contradictory-my-dear-watson/data):

> label: the classification of the relationship between the premise and hypothesis (0 for entailment, 1 for neutral, 2 for contradiction)
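
Rather than comparing the two conventions by eye, the correspondence can be checked directly. A small sketch, where the helper `labels_are_compatible` is hypothetical and the two mappings are copied from the class names printed above and from the quoted data page:

```python
# Class names as printed for the `nlp` datasets above (index = label id)
nlp_label_names = ['entailment', 'neutral', 'contradiction']

# Competition convention, from the data page quoted above
competition_labels = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}


def labels_are_compatible(names, mapping):
    """Return True if the class at position i of `names` is the class with id i in `mapping`."""
    return len(names) == len(mapping) and all(
        mapping.get(i) == name for i, name in enumerate(names)
    )


print(labels_are_compatible(nlp_label_names, competition_labels))  # True
```

Because the ids line up, the extra datasets' labels can be used as-is, without remapping.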

### Load more extra datasets

#### The Stanford Natural Language Inference Corpus (SNLI)

Next, let's load [The Stanford Natural Language Inference Corpus (SNLI)](https://nlp.stanford.edu/projects/snli/). It contains 570,000 sentence pairs annotated with textual entailment information.

In [None]:
snli = nlp.load_dataset(path='snli')

print('The number of training examples in snli dataset:', snli['train'].num_rows)
print('The number of validation examples in snli dataset:', snli['validation'].num_rows, '\n')

print('The class names in snli dataset:', snli['train'].features['label'].names)
print('The feature names in snli dataset:', list(snli['train'].features.keys()), '\n')

for idx, elt in enumerate(snli['train']):
    
    print('premise:', elt['premise'])
    print('hypothesis:', elt['hypothesis'])
    print('label:', elt['label'])
    print('label name:', snli['train'].features['label'].names[elt['label']])
    print('-' * 80)
    
    if idx >= 10:
        break

Again, the class names are
```
    ['entailment', 'neutral', 'contradiction'] 
```
which correspond to the labels of the original competition dataset.

In [SNLI](https://nlp.stanford.edu/projects/snli/), the same premise appears with several different hypotheses and labels. On a first try, I got `nan` as the training loss, so I won't use this dataset in the current notebook.
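
One possible cause worth checking, which is an assumption and was not verified in this notebook: the Hugging Face SNLI data marks pairs without a gold label with `label == -1`, and TensorFlow's sparse cross-entropy can return `NaN` on accelerators for labels outside `0..2` instead of raising an error. Note that the unified format below also uses `-1` to mean "no label". A minimal sketch of filtering such records out of the unified dictionaries; the helper `drop_unlabeled` and the sample records are hypothetical:

```python
def drop_unlabeled(records, valid_labels=(0, 1, 2)):
    """Yield only records whose 'label' is a real class id.

    `records` is any iterable of dicts in the unified format built later in this
    notebook ({'premise', 'hypothesis', 'label', 'lang'}), where -1 means "no label".
    """
    for record in records:
        if record['label'] in valid_labels:
            yield record


# Made-up records for illustration only
sample = [
    {'premise': 'A man plays.', 'hypothesis': 'A man sleeps.', 'label': 2, 'lang': 'en'},
    {'premise': 'A man plays.', 'hypothesis': 'A person plays.', 'label': -1, 'lang': 'en'},
]
print(len(list(drop_unlabeled(sample))))  # 1
```

In this notebook, such a filter could wrap `get_raw_dataset('snli train')` before tokenization; whether that alone removes the `nan` loss is untested here.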

#### The Cross-Lingual NLI Corpus (XNLI)

[MNLI](https://cims.nyu.edu/~sbowman/multinli/) and [SNLI](https://nlp.stanford.edu/projects/snli/) contain only English sentences. Let's load the multilingual [Cross-lingual NLI Corpus (XNLI)](https://cims.nyu.edu/~sbowman/xnli/) dataset. It contains only validation and test splits, with no training examples.

In [None]:
xnli = nlp.load_dataset(path='xnli')

print('The number of validation examples in xnli dataset:', xnli['validation'].num_rows, '\n')

print('The class names in xnli dataset:', xnli['validation'].features['label'].names)
print('The feature names in xnli dataset:', list(xnli['validation'].features.keys()), '\n')

for idx, elt in enumerate(xnli['validation']):
    
    print('premise:', elt['premise'])
    print('hypothesis:', elt['hypothesis'])
    print('label:', elt['label'])
    print('label name:', xnli['validation'].features['label'].names[elt['label']])
    print('-' * 80)
    
    if idx >= 3:
        break

The class names are still
```
    ['entailment', 'neutral', 'contradiction'],
```
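However, the features `premise` and `hypothesis` are no longer strings but dictionaries holding the sentences in different languages: `premise` maps each language code to a sentence, while `hypothesis` contains the parallel lists `language` and `translation`!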
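
Each XNLI record therefore expands into one sentence pair per language. A minimal pure-Python sketch of that expansion, mirroring the `dict` branch of `_get_features` below; the function name `flatten_xnli_example` and the toy record (made-up sentences, not real XNLI data) are illustrative:

```python
def flatten_xnli_example(example):
    """Turn one XNLI-style record into (premise, hypothesis, lang, label) tuples.

    Assumes the layout shown above: `premise` maps language code -> sentence, and
    `hypothesis` holds two parallel lists, 'language' and 'translation'.
    """
    hyp = example['hypothesis']
    hypotheses = dict(zip(hyp['language'], hyp['translation']))
    for lang, premise in example['premise'].items():
        # Keep only languages that have both a premise and a hypothesis
        if lang in hypotheses:
            yield premise, hypotheses[lang], lang, example['label']


toy = {
    'premise': {'en': 'It is raining.', 'fr': 'Il pleut.'},
    'hypothesis': {'language': ['en', 'fr'], 'translation': ['It is dry.', 'Il fait sec.']},
    'label': 2,
}
for row in flatten_xnli_example(toy):
    print(row)
```

All pairs produced from one record share its label, since the translations express the same premise/hypothesis relation.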

## Unified dataset format

Since the 4 datasets have different formats, we are going to create a unified interface, based on [tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), to make working with them easier.

### Make a unified format of raw datasets

In [None]:
def _get_features(elt):
    '''
    Args:
        elt: a row (`pd.Series`) of a competition `pd.DataFrame`, or an example (`dict`)
            of a `nlp.arrow_dataset.Dataset` that we have seen above
    
    Yields: tuples of 3 elements: (premise, hypothesis, language)
    '''

    if isinstance(elt, pd.Series):
        # Competition data: the language code is stored in `lang_abv`
        yield (elt['premise'], elt['hypothesis'], elt['lang_abv'])
    
    elif isinstance(elt['premise'], str):
        # MNLI / SNLI: English-only string features
        yield (elt['premise'], elt['hypothesis'], 'en')
    
    elif isinstance(elt, dict):
        
        # XNLI: `premise` is a dict mapping language -> sentence
        premises = elt['premise']
        
        # `hypothesis` is a dict of two parallel lists
        hypotheses_dict = elt['hypothesis']
        langs = hypotheses_dict['language']
        translations = hypotheses_dict['translation']
        
        hypotheses = {k: v for k, v in zip(langs, translations)}
                
        # One pair per language present in both premise and hypothesis
        for lang in premises:
            if lang in hypotheses:
                yield (premises[lang], hypotheses[lang], lang)
        
def _get_raw_datasets_from_nlp(ds: nlp.arrow_dataset.Dataset):
    """ From a `nlp.arrow_dataset.Dataset` that we have seen above to a generator of dictionaries with unified format.
    
    Yield a dictionary with keys: 'premise', 'hypothesis', 'label', 'lang'
    """
    
    for elt in ds:
        # -1 marks a missing label
        label = elt.get('label', -1)
        for features in _get_features(elt):
            yield {'premise': features[0], 'hypothesis': features[1], 'label': label, 'lang': features[2]}
            
def _get_raw_datasets_from_dataframe(ds: pd.core.frame.DataFrame):
    """ From a competition `pd.DataFrame` to a generator of dictionaries with unified format.
    
    The test set has no `label` column, so its examples get the label -1.
    """
    
    for _, elt in ds.iterrows():
        label = elt['label'] if 'label' in elt else -1
        for features in _get_features(elt):
            yield {'premise': features[0], 'hypothesis': features[1], 'label': label, 'lang': features[2]}

raw_ds_mapping = {
    'original train': (_get_raw_datasets_from_dataframe, original_train, len(original_train)),
    'original valid': (_get_raw_datasets_from_dataframe, original_valid, len(original_valid)),
    'snli train': (_get_raw_datasets_from_nlp, snli['train'], snli['train'].num_rows),
    'snli valid': (_get_raw_datasets_from_nlp, snli['validation'], snli['validation'].num_rows),
    'mnli train': (_get_raw_datasets_from_nlp, mnli['train'], mnli['train'].num_rows),
    'mnli valid 1': (_get_raw_datasets_from_nlp, mnli['validation_matched'], mnli['validation_matched'].num_rows),
    'mnli valid 2': (_get_raw_datasets_from_nlp, mnli['validation_mismatched'], mnli['validation_mismatched'].num_rows),
    'xnli valid': (_get_raw_datasets_from_nlp, xnli['validation'], xnli['validation'].num_rows * 15), # 15 languages
    'original test': (_get_raw_datasets_from_dataframe, original_test, len(original_test)),
}
def get_raw_dataset(ds_name):
    """Yield the examples of `ds_name` (a key of `raw_ds_mapping`) in the unified format."""
    
    fn, ds, _ = raw_ds_mapping[ds_name]
    
    yield from fn(ds)

#### sanity check

In [None]:
for k in raw_ds_mapping:
    for idx, x in enumerate(get_raw_dataset(k)):
        print(x)
        if idx >= 3:
            break

### Working with tf.data.Dataset

In [None]:
def get_unbatched_dataset(ds_names, model_name, max_len=64):
    """
    Args:
        ds_names: list[str] or dict[str:int], the names of the datasets to use, and optionally, how many examples to use from each of them.
        model_name: str, a valid Hugging Face transformers model name.
            For example: 'distilbert-base-uncased', 'bert-base-uncased', etc.
        max_len: int, the length every tokenized sentence pair is padded or truncated to.
    
    Returns:
        A tuple (`tf.data.Dataset`, number of examples).
    """

    if isinstance(ds_names, list):
        ds_names = {k: None for k in ds_names}
    # Names that are not in `raw_ds_mapping` are silently ignored
    ds_names = {k: v for k, v in ds_names.items() if k in raw_ds_mapping}    
    
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast=True)
    
    nb_examples = 0

    sentence_pairs = []
    labels = []    
    for name in ds_names:
        
        raw_ds = get_raw_dataset(name)
        nb_examples_to_use = raw_ds_mapping[name][2]
        if ds_names[name]:
            nb_examples_to_use = min(ds_names[name], nb_examples_to_use)
        nb_examples += nb_examples_to_use
        
        n = 0
        for x in raw_ds:
            sentence_pairs.append((x['premise'], x['hypothesis']))
            labels.append(x['label'])
            n += 1
            if n >= nb_examples_to_use:
                break

    # `transformers.tokenization_utils_base.BatchEncoding` object -> `dict`
    r = dict(tokenizer.batch_encode_plus(batch_text_or_text_pairs=sentence_pairs, max_length=max_len, padding='max_length', truncation=True))

    # This is very slow
    dataset = tf.data.Dataset.from_tensor_slices((r, labels))

    return dataset, nb_examples

def get_batched_training_dataset(dataset, nb_examples, batch_size=16, shuffle_buffer_size=1, repeat=False):
    
    if repeat:
        dataset = dataset.repeat()
    
    if not shuffle_buffer_size:
        shuffle_buffer_size = nb_examples
    dataset = dataset.shuffle(shuffle_buffer_size)
    
    dataset = dataset.batch(batch_size, drop_remainder=True)
    
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    
    return dataset


def get_prediction_dataset(dataset, batch_size=16):
    
    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    
    return dataset
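
Note the default `shuffle_buffer_size=1` in `get_batched_training_dataset`. `tf.data.Dataset.shuffle` fills a buffer with `buffer_size` elements and samples uniformly from it, so a buffer of 1 leaves the order unchanged, while `None` (mapped to `nb_examples`) gives a full shuffle; the training run at the end of this notebook passes `shuffle_buffer_size=None`. A pure-Python approximation of that buffering, where the helper `buffered_shuffle` is ours and not a TensorFlow API:

```python
import random


def buffered_shuffle(items, buffer_size, rng):
    """Approximate `tf.data.Dataset.shuffle(buffer_size)` in pure Python.

    Fill a buffer with up to `buffer_size` elements, then repeatedly emit a uniformly
    chosen element and let the next input element take its place.
    """
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            i = rng.randrange(len(buffer))
            # Swap the chosen element to the end and pop it (O(1) removal)
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Drain whatever is left once the input is exhausted
    rng.shuffle(buffer)
    yield from buffer


rng = random.Random(2020)
print(list(buffered_shuffle(range(10), buffer_size=1, rng=rng)))   # original order
print(list(buffered_shuffle(range(10), buffer_size=10, rng=rng)))  # a full shuffle
```

A buffer smaller than the dataset only mixes nearby elements, which matters when examples are concatenated dataset by dataset, as `get_unbatched_dataset` does.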

#### sanity check

In [None]:
for k in raw_ds_mapping.keys():

    ds, nb_examples = get_unbatched_dataset({k: 100}, model_name='distilbert-base-uncased')
    ds_batched = get_batched_training_dataset(ds, nb_examples, batch_size=16, shuffle_buffer_size=1, repeat=False)
    print('{} - select {} examples'.format(k, nb_examples))
    
    for x in ds_batched:
        # print(x)
        break

# Training

In [None]:
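def keep_head_tail(lst, len_wanted, head_ratio=0.5):
    """Truncate `lst` to `len_wanted` items, keeping its head and tail (not used below)."""
    if len(lst) <= len_wanted:
        # Nothing to truncate; without this check, head and tail would overlap
        return lst
    len_head = int(len_wanted * head_ratio)
    len_tail = len_wanted - len_head
    if len_tail == 0:
        return lst[:len_head]
    return lst[:len_head] + lst[-len_tail:]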

## Trainer

In [None]:
class Classifier(tf.keras.Model):
    
    def __init__(self, model_name):
        
        super(Classifier, self).__init__()
        
        self.transformer = transformers.TFAutoModel.from_pretrained(model_name)
        self.dropout = tf.keras.layers.Dropout(rate=0.05)
        self.global_pool = tf.keras.layers.GlobalAveragePooling1D()
        self.classifier = tf.keras.layers.Dense(3)

    def call(self, inputs, training=False):
        
        # Sequence outputs
        x = self.transformer(inputs, training=training)[0]        
        x = self.dropout(x, training=training)
        x = self.global_pool(x)
        
        return self.classifier(x)

class Trainer:
    
    def __init__(
        self, ds_names, model_name, max_len=64,
        batch_size_per_replica=16, prediction_batch_size_per_replica=64,
        shuffle_buffer_size=1
    ):

        global strategy
        
        try:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
        except ValueError:
            strategy = tf.distribute.get_strategy() # for CPU and single GPU

        print('Number of replicas:', strategy.num_replicas_in_sync)             
        
        self.ds_names = ds_names
        self.model_name = model_name
        self.max_len = max_len
    
        self.batch_size_per_replica = batch_size_per_replica
        self.prediction_batch_size_per_replica = prediction_batch_size_per_replica
        
        self.batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
        self.prediction_batch_size = prediction_batch_size_per_replica * strategy.num_replicas_in_sync

        self.shuffle_buffer_size = shuffle_buffer_size

        train_ds, self.nb_examples = get_unbatched_dataset(
            ds_names=ds_names, model_name=model_name, max_len=max_len
        )
        self.train_ds = get_batched_training_dataset(
            train_ds, self.nb_examples, batch_size=self.batch_size,
            shuffle_buffer_size=self.shuffle_buffer_size, repeat=True
        )
        
        valid_ds, self.nb_valid_examples = get_unbatched_dataset(
            ds_names=['original valid'], model_name=model_name, max_len=max_len
        )
        self.valid_ds = get_prediction_dataset(valid_ds, self.prediction_batch_size)
        self.valid_labels = next(iter(self.valid_ds.map(lambda inputs, label: label).unbatch().batch(len(original_valid))))
        
        test_ds, self.nb_test_examples = get_unbatched_dataset(
            ds_names=['original test'], model_name=model_name, max_len=max_len
        )
        self.test_ds = get_prediction_dataset(test_ds, self.prediction_batch_size)
        
        self.steps_per_epoch = self.nb_examples // self.batch_size
                   
    def get_model(self, model_name, lr, verbose=False):

        with strategy.scope():

            model = Classifier(model_name)

            # False = transfer learning, True = fine-tuning
            model.trainable = True 

            # Just run a dummy batch, not necessary
            dummy = model(tf.constant(1, shape=[1, 64]))

            if verbose:
                model.summary()

            # Instantiate an Adam optimizer with a constant learning rate
            optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

            # In a custom training loop under a distribution strategy, only the `NONE` and `SUM`
            # reductions are allowed, and the reduction has to be specified explicitly.
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
            
            # Instantiate metrics
            metrics = {
                'train loss': tf.keras.metrics.Sum(),
                'train acc': tf.keras.metrics.SparseCategoricalAccuracy()
            }

            return model, loss_fn, optimizer, metrics
        
    def get_routines(self, model, loss_fn, optimizer, metrics):

        def train_1_step(batch):
            
            inputs, labels = batch
    
            with tf.GradientTape() as tape:

                logits = model(inputs, training=True)
                # Remember that we use the `SUM` reduction when we define the loss object.
                loss = loss_fn(labels, logits) / self.batch_size

            grads = tape.gradient(loss, model.trainable_variables)

            # Update the model's parameters.
            optimizer.apply_gradients(zip(grads, model.trainable_variables))            
            
            # update metrics
            metrics['train loss'].update_state(loss)
            metrics['train acc'].update_state(labels, logits)

        @tf.function
        def dist_train_1_epoch(data_iter):
            """
            Iterate inside `tf.function` to optimize training time.
            """
            for _ in tf.range(self.steps_per_epoch):
                strategy.run(train_1_step, args=(next(data_iter),))        

        @tf.function                
        def predict_step(batch):

            inputs, _ = batch
            
            logits = model(inputs, training=False)
            return logits

        def predict_fn(dist_test_ds):

            all_logits = []
            for batch in dist_test_ds:

                # PerReplica object
                logits = strategy.run(predict_step, args=(batch,))

                # Tuple of tensors
                logits = strategy.experimental_local_results(logits)

                # tf.Tensor
                logits = tf.concat(logits, axis=0)

                all_logits.append(logits)

            # tf.Tensor
            logits = tf.concat(all_logits, axis=0)

            return logits         
                
        return dist_train_1_epoch, predict_fn
        
    def train(self, train_name, model_name, epochs, verbose=False):

        global strategy
        
        try:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
        except ValueError:
            strategy = tf.distribute.get_strategy() # for CPU and single GPU

        print('Number of replicas:', strategy.num_replicas_in_sync)        
        
        model, loss_fn, optimizer, metrics = self.get_model(model_name, 1e-5, verbose=verbose)
        dist_train_1_epoch, predict_fn = self.get_routines(model, loss_fn, optimizer, metrics)
        
        train_dist_ds = strategy.experimental_distribute_dataset(self.train_ds)
        train_dist_iter = iter(train_dist_ds)
        
        dist_valid_ds = strategy.experimental_distribute_dataset(self.valid_ds)
        dist_test_ds = strategy.experimental_distribute_dataset(self.test_ds)

        history = {}
        # Start below any possible accuracy so that `best.h5` is always written at least once
        best_acc = -1.0
        for epoch in range(epochs):
            
            s = datetime.datetime.now()

            dist_train_1_epoch(train_dist_iter)

            # get metrics
            train_loss = metrics['train loss'].result() / self.steps_per_epoch
            train_acc = metrics['train acc'].result()

            # reset metrics
            metrics['train loss'].reset_states()
            metrics['train acc'].reset_states()
                   
            print('epoch: {}\n'.format(epoch + 1))
            print('train loss: {}'.format(train_loss))
            print('train acc: {}\n'.format(train_acc)) 
                
            e = datetime.datetime.now()
            elapsed = (e - s).total_seconds()            
            
            logits = predict_fn(dist_valid_ds)

            valid_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(self.valid_labels, logits, from_logits=True, axis=-1))
            valid_acc = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(self.valid_labels, logits))
            if valid_acc>best_acc:
                best_acc=valid_acc
                #save the model
                model.save_weights('best.h5')
                
            print('valid loss: {}'.format(valid_loss))
            print('valid acc: {}\n'.format(valid_acc))
            
            print('train timing: {}\n'.format(elapsed))
            
            history[epoch] = {
                'train loss': float(train_loss),
                'train acc': float(train_acc),
                'valid loss': float(valid_loss),
                'valid acc': float(valid_acc),                
                'train timing': elapsed
            }

            print('-' * 40)
        
        print('best acc:{}'.format(best_acc))
        model.load_weights('best.h5')
        logits = predict_fn(dist_test_ds)
        preds = tf.math.argmax(logits, axis=-1)
        
        submission = pd.read_csv('/kaggle/input/contradictory-my-dear-watson/sample_submission.csv')
        submission['prediction'] = preds.numpy()
        submission.to_csv(f'submission-{train_name}.csv', index=False)
        
        return history, submission,logits
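
Why `train_1_step` divides a `SUM`-reduced loss by `self.batch_size` (the global batch size, not the per-replica one): each replica only sees `batch_size_per_replica` examples, and the gradients from all replicas are summed when they are applied, so dividing every replica's sum by the global batch size makes the total equal to the mean over the whole batch. The same reasoning explains why `metrics['train loss']` is a `Sum` metric that is later divided by `steps_per_epoch`. A framework-free sketch of that arithmetic, where `per_replica_scaled_losses` is a hypothetical helper and the loss values are made up:

```python
def per_replica_scaled_losses(per_example_losses, num_replicas):
    """Split a global batch across replicas as a distribution strategy would, and
    scale each replica's SUM-reduced loss by the *global* batch size."""
    global_batch_size = len(per_example_losses)
    assert global_batch_size % num_replicas == 0, "global batch must split evenly"
    per_replica = global_batch_size // num_replicas
    shards = [
        per_example_losses[r * per_replica:(r + 1) * per_replica]
        for r in range(num_replicas)
    ]
    # Mirrors `loss_fn(labels, logits) / self.batch_size` with SUM reduction
    return [sum(shard) / global_batch_size for shard in shards]


losses = [0.5, 1.5, 2.0, 0.0, 1.0, 3.0, 0.5, 1.5]
scaled = per_replica_scaled_losses(losses, num_replicas=4)
# Summing the replica contributions recovers the mean over the global batch
print(sum(scaled), sum(losses) / len(losses))  # 1.25 1.25
```

Dividing by the per-replica batch size instead would inflate the effective loss (and gradients) by a factor of `num_replicas`.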

## Train

In [None]:
def print_config(trainer):

    print('nb. of training examples used: {}'.format(trainer.nb_examples))
    print('nb. of valid examples used: {}'.format(trainer.nb_valid_examples))
    print('nb. of test examples used: {}'.format(trainer.nb_test_examples))
    
    print('per replica batch size for training: {}'.format(trainer.batch_size_per_replica))
    print('batch size for training: {}'.format(trainer.batch_size))

    print('per replica batch size for prediction: {}'.format(trainer.prediction_batch_size_per_replica))
    print('batch size for prediction: {}'.format(trainer.prediction_batch_size))
    
    print('steps per epoch: {}'.format(trainer.steps_per_epoch))
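
`print_config` reports values computed in `Trainer.__init__`, and the arithmetic is simple enough to check by hand. A sketch, where the helper `batch_config` and the example numbers (8 replicas, as on an 8-core TPU, and 100,000 examples) are illustrative and not taken from an actual run:

```python
def batch_config(nb_examples, batch_size_per_replica, num_replicas):
    """Reproduce the Trainer's batch arithmetic (global batch size, steps per epoch)."""
    batch_size = batch_size_per_replica * num_replicas
    # Integer division: because the training dataset is repeated, the leftover
    # examples are not lost but roll over into the next "epoch"
    steps_per_epoch = nb_examples // batch_size
    return batch_size, steps_per_epoch


print(batch_config(nb_examples=100_000, batch_size_per_replica=16, num_replicas=8))  # (128, 781)
```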

In [None]:
epochs = 15
# model_name = 'jplu/tf-xlm-roberta-base'
model_name = 'jplu/tf-xlm-roberta-large'
# model_name = 'distilbert-base-uncased'

### train on the original dataset

In [None]:
# trainer = Trainer(
#     ds_names={'original train': None}, model_name=model_name,
#     max_len=64, batch_size_per_replica=16, prediction_batch_size_per_replica=64,
#     shuffle_buffer_size=None
# )

# print_config(trainer)

# train_name = f'{model_name} + original-dataset'.replace('/', '-')
# history_1, submission_1, _ = trainer.train(train_name=train_name, model_name=model_name, epochs=epochs, verbose=True)

#### Plot history

In [None]:
def plot(history, metric):
    """
    metric: 'loss' or 'acc'
    """
    
    h = {
        f'train {metric}': [history[epoch][f'train {metric}'] for epoch in history],
        f'valid {metric}': [history[epoch][f'valid {metric}'] for epoch in history]
    }
        
    fig = px.line(
        h, x=range(1, len(history) + 1), y=[f'train {metric}', f'valid {metric}'], 
        title=f'model {metric}', labels={'x': 'Epoch', 'value': metric}
    )
    fig.show()
    
def plot_2(history1, history2, metric, desc1, desc2):
    
    h = {
        f'train {metric} - {desc1}': [history1[epoch][f'train {metric}'] for epoch in history1],
        f'valid {metric} - {desc1}': [history1[epoch][f'valid {metric}'] for epoch in history1],
        f'train {metric} - {desc2}': [history2[epoch][f'train {metric}'] for epoch in history2],
        f'valid {metric} - {desc2}': [history2[epoch][f'valid {metric}'] for epoch in history2]        
    }
        
    fig = px.line(
        h, x=range(1, len(history1) + 1), y=[f'train {metric} - {desc1}', f'valid {metric} - {desc1}', f'train {metric} - {desc2}', f'valid {metric} - {desc2}'], 
        title=f'model {metric}', labels={'x': 'Epoch', 'value': metric}
    )
    fig.show()    

In [None]:
# plot(history_1, 'loss')
# plot(history_1, 'acc')

In [None]:
# pd.read_csv(f'submission-{train_name}.csv').head(20)

### original dataset + xnli

In [None]:
# trainer = Trainer(
#     ds_names={'original train': None, 'xnli valid': None}, model_name=model_name,
#     max_len=64, batch_size_per_replica=16, prediction_batch_size_per_replica=64,
#     shuffle_buffer_size=None
# )

# print_config(trainer)

# train_name = f'{model_name} + extra-xnli'.replace('/', '-')
# history_2, submission_2, _ = trainer.train(train_name=train_name, model_name=model_name, epochs=epochs, verbose=True)

In [None]:
# plot(history_2, 'loss')
# plot(history_2, 'acc')

#### Compare [only original dataset] vs. [+ xnli]

In [None]:
# plot_2(history_1, history_2, 'loss', desc1='only original dataset', desc2='+ xnli')
# plot_2(history_1, history_2, 'acc', desc1='only original dataset', desc2='+ xnli')

In [None]:
# pd.read_csv(f'submission-{train_name}.csv').head(20)

### original dataset + xnli + mnli

In [None]:
trainer = Trainer(
    ds_names={'original train': None, 'xnli valid': None, 'mnli train': 60000, 'mnli valid 1': None, 'mnli valid 2': None}, model_name=model_name,
    max_len=208, batch_size_per_replica=16, prediction_batch_size_per_replica=64,
    shuffle_buffer_size=None
)

print_config(trainer)

train_name = f'{model_name} + extra-xnli-mnli'.replace('/', '-')
history_3, submission_3,preds = trainer.train(train_name=train_name, model_name=model_name, epochs=epochs, verbose=True)

In [None]:
np.savez_compressed('preds',a=preds)

In [None]:
plot(history_3, 'loss')
plot(history_3, 'acc')

#### Compare [only original dataset] vs. [+ xnli/mnli]

In [None]:
# plot_2(history_1, history_3, 'loss', desc1='only original dataset', desc2='+ mnli/xnli')
# plot_2(history_1, history_3, 'acc', desc1='only original dataset', desc2='+ mnli/xnli')

#### Compare [+ xnli] vs. [+ xnli/mnli]

In [None]:
# plot_2(history_2, history_3, 'loss', desc1='+ xnli', desc2='+ mnli/xnli')
# plot_2(history_2, history_3, 'acc', desc1='+ xnli', desc2='+ mnli/xnli')

In [None]:
s = pd.read_csv(f'submission-{train_name}.csv')
s.to_csv('submission.csv', index=False)

s.head(20)