# Project Goals
* Conduct error analysis on model predictions on a test set
* Explore different ensembling techniques for improving overall inference accuracy

# Model Information
* bert-base-multilingual-cased
    * Number of parameters: 177 M
    * Max epochs: 10
    * Best epoch: 1
    * Accuracy on Leaderboard public dataset: 0.63156


* xlm-roberta-base
    * Number of parameters: 278 M
    * Max epochs: 10
    * Best epoch: 3
    * Accuracy on Leaderboard public dataset: 0.67930
    

* xlm-roberta-large
    * Number of parameters: 560 M
    * Max epochs: 5
    * Best epoch: 2
    * Accuracy on Leaderboard public dataset: 0.73763

# Environment Setup

In [None]:
# Check if TPU/GPU is available and record the result in DEVICE
# ("tpu", "gpu" or "cpu"); later cells branch on this value.
import tensorflow as tf

try:
    # A ValueError from the resolver is treated as "no TPU attached".
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    DEVICE = "tpu"
except ValueError:
    # tf.test.is_gpu_available() is deprecated; listing the physical GPU
    # devices is the replacement recommended by TensorFlow.
    DEVICE = "gpu" if tf.config.list_physical_devices("GPU") else "cpu"

print("Accelerator: {}".format(DEVICE))

In [None]:
# Set up an environment for accessing TPU
# Only runs when the device-detection cell set DEVICE to "tpu".
if DEVICE == "tpu":
    # Download the PyTorch/XLA setup script from the pytorch/xla repository
    # and run it with the extra apt packages it is asked to install.
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
    !python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev
    # NOTE(review): pytorch-lightning is installed only on this branch but is
    # imported unconditionally later; presumably it is preinstalled on
    # GPU/CPU images - confirm.
    !pip install pytorch-lightning
    import torch_xla
    import torch_xla.core.xla_model as xm

In [None]:
import os
# Placeholder Weights & Biases API key; per the original note it is set only
# to silence a warning ("0" is a dummy value, not a real key).
os.environ["WANDB_API_KEY"] = "0"  # to silence warning

In [None]:
!pip install datasets

In [None]:
import gc
import glob
import numpy as np
import pandas as pd
import seaborn as sn
import datasets
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List, Dict
from tqdm import tqdm
from scipy.special import softmax
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.utils.extmath import weighted_mode
from collections import OrderedDict, defaultdict
from transformers import BertTokenizer, BertForSequenceClassification, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
from torch.utils.data import Dataset, DataLoader
try:
    from pytorch_lightning import LightningModule, loggers, seed_everything
    from pytorch_lightning.core.decorators import auto_move_data
except OSError:  # Reloading pytorch_lightning again to resolve OSError issues with TPU
    from pytorch_lightning import LightningModule, loggers, seed_everything
    from pytorch_lightning.core.decorators import auto_move_data

In [None]:
# Remove temp installation files to release space
import shutil  # imported here because the notebook's import cell does not provide it


def clear_working_dir(pattern="/kaggle/working/*"):
    """Delete every file and directory matched by the glob *pattern*.

    Removal is best-effort: a path that cannot be deleted is reported with a
    printed message and collected instead of aborting the clean-up.

    Args:
        pattern: Glob pattern selecting the entries to delete.

    Returns:
        The list of matched paths that could not be removed.
    """
    not_removed = []
    for path in glob.glob(pattern):
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except OSError:
            # Only filesystem errors are expected here; anything else is a
            # programming error and should surface instead of being hidden.
            print("Not removable: {}".format(path))
            not_removed.append(path)
    return not_removed


gc.collect()
unremoved_paths = clear_working_dir()

# Define Global Variables

In [None]:
# Each model is described by two constants: PRETRAINED_* is the model name
# handed to the tokenizer / from_pretrained calls, FINETUNED_* is the path of
# the fine-tuned checkpoint file. The checkpoint file names carry the
# max-epoch and best-epoch values listed under "Model Information".

# MultilingualBERT Model
PRETRAINED_BERT_MODEL = "bert-base-multilingual-cased"
FINETUNED_BERT_MODEL = "../input/multilingualbert-finetuned-10epochs/bert-base-multilingual-cased_ft_10epochs_epoch1.ckpt"

# XLM-RoBERTa-Base Model
PRETRAINED_XLMBASE_MODEL = "xlm-roberta-base"
FINETUNED_XLMBASE_MODEL = "../input/xlmrobertabase-finetuned-10epochs/xlm-roberta-base_ft_10epochs_epoch3.ckpt"

# XLM-RoBERTa-Large Model
PRETRAINED_XLMLARGE_MODEL = "xlm-roberta-large"
FINETUNED_XLMLARGE_MODEL = "../input/xlmrobertalarge-finetuned-5epochs/xlm-roberta-large_ft_5epochs_epoch2.ckpt"

In [None]:
# Global Variables
SEED = 2020
MAX_EPOCHS = 1

# These settings have the same value on every accelerator.
MAX_TOKEN_LEN = 50
GPUS = 1
NUM_WORKERS = 4

# Only the batch size and the TPU core count depend on the accelerator.
BATCH_SIZE, TPU_CORES = (8, 8) if DEVICE == "tpu" else (16, 1)

# Exploratory Data Analysis

In [None]:
# Test set with labels
# Later cells read its id, premise, hypothesis, label and language columns.
testset_df = pd.read_csv("../input/nli-test-set/nli_test_set.csv")
testset_df.head()

In [None]:
len(testset_df)

In [None]:
testset_df.groupby(["language", "label"]).size()

In [None]:
# Production dataset for inference and submission
# The tokenization cell adds per-model token columns to this frame, and the
# ensemble predictions computed from it are written to submission.csv.
prod_df = pd.read_csv("../input/contradictory-my-dear-watson/test.csv")
prod_df.head()

In [None]:
len(prod_df)

In [None]:
prod_df["language"].value_counts()

# Model Evaluation

In [None]:
class NLIEvalDataset(Dataset):
    """Torch dataset that tokenizes premise/hypothesis pairs for one model.

    Args:
        dataset: Frame with ``id``, ``premise`` and ``hypothesis`` columns,
            plus a ``label`` column unless ``production`` is True.
        model_name: One of "bert-base-multilingual-cased",
            "xlm-roberta-base" or "xlm-roberta-large"; selects the tokenizer.
        max_token_len: Length every encoded pair is padded/truncated to.
        production: When True, items carry no label.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """

    def __init__(self,
                 dataset: pd.DataFrame,
                 model_name: str,
                 max_token_len: int = MAX_TOKEN_LEN,
                 production: bool = False
                ):
        self.dataset = dataset
        self.model_name = model_name
        self.max_token_len = max_token_len
        self.production = production

        # Build the tokenizer once here: loading it inside __getitem__ would
        # repeat the from_pretrained call for every single row.
        if model_name in ["bert-base-multilingual-cased"]:
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
        elif model_name in ["xlm-roberta-base", "xlm-roberta-large"]:
            self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
        else:
            raise ValueError("Unsupported model name: {}".format(model_name))

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.dataset)

    def __getitem__(self, index: int):
        """Encode row ``index``.

        Returns:
            ``(inputs, row_id)`` in production mode, otherwise
            ``(inputs, label, row_id)``. ``inputs`` is a dict with the keys
            "input_ids", "token_type_ids" and "attention_mask".
        """
        row_id = self.dataset.id.values[index]
        premise = self.dataset.premise.values[index]
        hypothesis = self.dataset.hypothesis.values[index]

        # padding="max_length" replaces the deprecated pad_to_max_length=True.
        encoded_sents = self.tokenizer.encode_plus(premise,
                                                   hypothesis,
                                                   add_special_tokens=True,
                                                   padding="max_length",
                                                   max_length=self.max_token_len,
                                                   truncation=True,
                                                   return_attention_mask=True,
                                                   return_token_type_ids=True,
                                                   return_tensors="pt")

        # Only a single pair is encoded, so element 0 is taken from each
        # returned tensor before the DataLoader batches the items.
        inputs = {
            "input_ids": encoded_sents["input_ids"][0],
            "token_type_ids": encoded_sents["token_type_ids"][0],
            "attention_mask": encoded_sents["attention_mask"][0]
        }

        if self.production:
            return inputs, row_id
        label = self.dataset.label.values[index]
        return inputs, label, row_id

In [None]:
class NLIEvalModelModule(LightningModule):
    """LightningModule wrapper around one sequence-classification model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @auto_move_data
    def forward(self, inputs):
        """Run the wrapped model on a dict of encoded inputs.

        ``inputs`` must provide "input_ids", "attention_mask" and
        "token_type_ids"; the wrapped model's output is returned unchanged.
        """
        model_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "token_type_ids": inputs["token_type_ids"],
        }
        return self.model(**model_kwargs)

In [None]:
def model_evaluation(dataset: pd.DataFrame, 
                     model_name: str, 
                     checkpoint_path: str, 
                     batch_size: int = BATCH_SIZE):
    """Run one fine-tuned model over a labelled dataset.

    Args:
        dataset: Labelled frame accepted by ``NLIEvalDataset``.
        model_name: "bert-base-multilingual-cased", "xlm-roberta-base" or
            "xlm-roberta-large".
        checkpoint_path: Checkpoint file containing a "state_dict" entry.
        batch_size: Number of rows per inference batch.

    Returns:
        A DataFrame with the columns row_id, label, pred_label,
        label_0_prob, label_1_prob and label_2_prob.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    # Fail fast on an unknown model instead of hitting an unbound variable
    # after the checkpoint has already been read.
    if model_name in ["bert-base-multilingual-cased"]:
        model_class = BertForSequenceClassification
    elif model_name in ["xlm-roberta-base", "xlm-roberta-large"]:
        model_class = XLMRobertaForSequenceClassification
    else:
        raise ValueError("Unsupported model name: {}".format(model_name))

    # Prepare Dataset
    nli_dataset = NLIEvalDataset(dataset, model_name, production=False)
    nli_dataloader = DataLoader(nli_dataset, batch_size=batch_size, shuffle=False)

    # Load Fine-tuned Model
    # The weights are copied into a freshly built model below, so mapping the
    # checkpoint to CPU is always sufficient and also lets checkpoints saved
    # on an accelerator load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # Checkpoint keys carry the wrapper attribute name ("model.") as a
    # prefix; strip it only where present so other keys are left intact.
    prefix = "model."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint["state_dict"].items()
    )
    model = model_class.from_pretrained(model_name,
                                        state_dict=new_state_dict,
                                        num_labels=3)

    model_module = NLIEvalModelModule(model)
    model_module.freeze() # eval
    info_dict = defaultdict(list)
    prob_list = []

    # Model inference
    for inputs, label, row_id in tqdm(nli_dataloader):
        raw_predictions = model_module.forward(inputs)
        # The first element of the model output is used as the class scores.
        predictions = raw_predictions[0].cpu().numpy()
        pred_label = np.argmax(predictions, axis=1)
        pred_probs = softmax(predictions, axis=1)

        # Numeric ids are collated into a tensor, string ids into a sequence;
        # store plain Python values in both cases.
        info_dict["row_id"] += row_id.tolist() if torch.is_tensor(row_id) else list(row_id)
        info_dict["label"] += label.cpu().numpy().tolist()
        info_dict["pred_label"] += pred_label.tolist()
        prob_list += pred_probs.tolist()

    info_df = pd.DataFrame.from_dict(info_dict)
    prob_df = pd.DataFrame(prob_list, columns=["label_0_prob", "label_1_prob", "label_2_prob"])
    return pd.concat([info_df, prob_df], axis=1)

In [None]:
# General settings
# Pass the global SEED to pytorch_lightning's seed_everything; intended to
# make runs repeatable.
seed_everything(SEED)

In [None]:
def plot_confusion_matrix(confusion_matrix, norm=False):
    """Draw a 3x3 confusion matrix as an annotated heatmap.

    Rows are labelled as true labels and columns as predicted labels. With
    ``norm=True`` the cells are formatted as floats (".2g"), otherwise as
    integers ("d").
    """
    plt.figure(figsize=(7, 7))
    row_names = ["True Label: 0", "True Label: 1", "True Label: 2"]
    column_names = ["Predicted Label: 0", "Predicted Label: 1", "Predicted Label: 2"]
    cm_df = pd.DataFrame(confusion_matrix, index=row_names, columns=column_names)
    cell_format = ".2g" if norm else "d"
    sn.heatmap(cm_df, cmap="YlGnBu", annot=True, fmt=cell_format)
    plt.show()

# Model Evaluation: bert-base-multilingual-cased

In [None]:
BERT_TESTSET_PREDICTIONS = "../input/nli-test-set/bert_testset_predictions.csv"
if not os.path.isfile(BERT_TESTSET_PREDICTIONS):
    # No cached predictions: run the model and save the result for reuse.
    bert_pred_df = model_evaluation(testset_df, PRETRAINED_BERT_MODEL, FINETUNED_BERT_MODEL)
    bert_pred_df.to_csv("bert_testset_predictions.csv")
else:
    # Cached predictions exist: load them instead of re-running inference.
    bert_pred_df = pd.read_csv(BERT_TESTSET_PREDICTIONS, index_col=0)

In [None]:
bert_pred_df.head()

In [None]:
# Share of test rows where BERT's predicted label equals the true label.
bert_accuracy_score = accuracy_score(y_true=bert_pred_df["label"].values,
                                     y_pred=bert_pred_df["pred_label"].values)
print(bert_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for BERT.
bert_y_true = bert_pred_df.label.values
bert_y_pred = bert_pred_df.pred_label.values
bert_confusion_matrix = confusion_matrix(bert_y_true, bert_y_pred)
bert_norm_confusion_matrix = confusion_matrix(bert_y_true, bert_y_pred, normalize="true")
plot_confusion_matrix(bert_confusion_matrix)
plot_confusion_matrix(bert_norm_confusion_matrix, norm=True)

# Model Evaluation: xlm-roberta-base

In [None]:
XLMBASE_TESTSET_PREDICTIONS = "../input/nli-test-set/xlmbase_testset_predictions.csv"
if not os.path.isfile(XLMBASE_TESTSET_PREDICTIONS):
    # No cached predictions: run the model and save the result for reuse.
    xlmbase_pred_df = model_evaluation(testset_df, PRETRAINED_XLMBASE_MODEL, FINETUNED_XLMBASE_MODEL)
    xlmbase_pred_df.to_csv("xlmbase_testset_predictions.csv")
else:
    # Cached predictions exist: load them instead of re-running inference.
    xlmbase_pred_df = pd.read_csv(XLMBASE_TESTSET_PREDICTIONS, index_col=0)

In [None]:
xlmbase_pred_df.head()

In [None]:
# Share of test rows where XLM-R base's predicted label equals the true label.
xlmbase_accuracy_score = accuracy_score(y_true=xlmbase_pred_df["label"].values,
                                        y_pred=xlmbase_pred_df["pred_label"].values)
print(xlmbase_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for XLM-R base.
xlmbase_y_true = xlmbase_pred_df.label.values
xlmbase_y_pred = xlmbase_pred_df.pred_label.values
xlmbase_confusion_matrix = confusion_matrix(xlmbase_y_true, xlmbase_y_pred)
xlmbase_norm_confusion_matrix = confusion_matrix(xlmbase_y_true, xlmbase_y_pred, normalize="true")
plot_confusion_matrix(xlmbase_confusion_matrix)
plot_confusion_matrix(xlmbase_norm_confusion_matrix, norm=True)

# Model Evaluation: xlm-roberta-large

In [None]:
XLMLARGE_TESTSET_PREDICTIONS = "../input/nli-test-set/xlmlarge_testset_predictions.csv"
if not os.path.isfile(XLMLARGE_TESTSET_PREDICTIONS):
    # No cached predictions: run the model and save the result for reuse.
    xlmlarge_pred_df = model_evaluation(testset_df, PRETRAINED_XLMLARGE_MODEL, FINETUNED_XLMLARGE_MODEL)
    xlmlarge_pred_df.to_csv("xlmlarge_testset_predictions.csv")
else:
    # Cached predictions exist: load them instead of re-running inference.
    xlmlarge_pred_df = pd.read_csv(XLMLARGE_TESTSET_PREDICTIONS, index_col=0)

In [None]:
xlmlarge_pred_df.head()

In [None]:
# Share of test rows where XLM-R large's predicted label equals the true label.
xlmlarge_accuracy_score = accuracy_score(y_true=xlmlarge_pred_df["label"].values,
                                         y_pred=xlmlarge_pred_df["pred_label"].values)
print(xlmlarge_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for XLM-R large.
xlmlarge_y_true = xlmlarge_pred_df.label.values
xlmlarge_y_pred = xlmlarge_pred_df.pred_label.values
xlmlarge_confusion_matrix = confusion_matrix(xlmlarge_y_true, xlmlarge_y_pred)
xlmlarge_norm_confusion_matrix = confusion_matrix(xlmlarge_y_true, xlmlarge_y_pred, normalize="true")
plot_confusion_matrix(xlmlarge_confusion_matrix)
plot_confusion_matrix(xlmlarge_norm_confusion_matrix, norm=True)

# Ensemble Option 1 - Argmax + Majority Voting

In [None]:
# Join the three per-model prediction frames on row_id so every row carries
# each model's predicted label and class probabilities.
shared_columns = ["row_id", "pred_label", "label_0_prob", "label_1_prob", "label_2_prob"]

# First join: overlapping columns get the '_bert' / '_xlmbase' suffixes.
tmp_df = bert_pred_df.merge(xlmbase_pred_df[shared_columns],
                            on="row_id",
                            suffixes=('_bert', '_xlmbase'))

# Second join: the XLM-R large columns no longer clash with anything, so they
# arrive unsuffixed and are renamed explicitly below.
merge_df = tmp_df.merge(xlmlarge_pred_df[shared_columns], on="row_id")
xlmlarge_names = {'label_bert': 'label',
                  'pred_label': 'pred_label_xlmlarge',
                  'label_0_prob': 'label_0_prob_xlmlarge',
                  'label_1_prob': 'label_1_prob_xlmlarge',
                  'label_2_prob': 'label_2_prob_xlmlarge'}
merge_df = merge_df.rename(columns=xlmlarge_names)
merge_df.head()

In [None]:
# Per-row majority vote over the three models' predicted labels.
vote_columns = ["pred_label_bert", "pred_label_xlmbase", "pred_label_xlmlarge"]
majority_voting_df = merge_df.loc[:, vote_columns]
row_modes = majority_voting_df.mode(axis=1)
# NOTE(review): when all three models disagree every label is a mode and
# column 0 of the mode frame is used; presumably that is the smallest
# label - confirm against the pandas DataFrame.mode documentation.
majority_voting_df['majority'] = row_modes[0].astype("int")
majority_voting_df.head(10)

In [None]:
# Accuracy of the majority-vote ensemble on the test set.
majority_voting_accuracy_score = accuracy_score(y_true=merge_df["label"].values,
                                                y_pred=majority_voting_df["majority"].values)
print(majority_voting_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for majority voting.
majority_y_true = merge_df.label.values
majority_y_pred = majority_voting_df.majority.values
majority_voting_confusion_matrix = confusion_matrix(majority_y_true, majority_y_pred)
majority_voting_norm_confusion_matrix = confusion_matrix(majority_y_true, majority_y_pred, normalize="true")
plot_confusion_matrix(majority_voting_confusion_matrix)
plot_confusion_matrix(majority_voting_norm_confusion_matrix, norm=True)

# Ensemble Option 2 - Argmax + Weighted Voting

In [None]:
# Each model's weight is its test-set accuracy divided by the sum of all
# three accuracies, in the order (bert, xlmbase, xlmlarge).
model_acc_array = np.array([bert_accuracy_score, xlmbase_accuracy_score, xlmlarge_accuracy_score])
total_accuracy = sum(model_acc_array)
MODEL_WEIGHTING = model_acc_array / total_accuracy
print(MODEL_WEIGHTING)

In [None]:
# Per-row weighted vote: each model's predicted label counts with that
# model's entry in MODEL_WEIGHTING.
weighted_voting_df = merge_df.loc[:, ["pred_label_bert", "pred_label_xlmbase", "pred_label_xlmlarge"]]
winning_labels, winning_scores = weighted_mode(weighted_voting_df, MODEL_WEIGHTING, axis=1)
weighted_voting_df['majority'] = winning_labels
weighted_voting_df['majority_score'] = winning_scores
weighted_voting_df['majority'] = weighted_voting_df['majority'].astype("int")
weighted_voting_df.head(10)

In [None]:
# Accuracy of the weighted-vote ensemble on the test set.
weighted_voting_accuracy_score = accuracy_score(y_true=merge_df["label"].values,
                                                y_pred=weighted_voting_df["majority"].values)
print(weighted_voting_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for weighted voting.
weighted_vote_y_true = merge_df.label.values
weighted_vote_y_pred = weighted_voting_df.majority.values
weighted_voting_confusion_matrix = confusion_matrix(weighted_vote_y_true, weighted_vote_y_pred)
weighted_voting_norm_confusion_matrix = confusion_matrix(weighted_vote_y_true, weighted_vote_y_pred, normalize="true")
plot_confusion_matrix(weighted_voting_confusion_matrix)
plot_confusion_matrix(weighted_voting_norm_confusion_matrix, norm=True)

# Ensemble Option 3 - Averaged Probabilities + Argmax

In [None]:
# Keep the true label and the per-model class probabilities; drop the row id
# and the per-model predicted labels.
# NOTE(review): the column order comes from Index.difference, which presumably
# sorts names alphabetically (bert < xlmbase < xlmlarge) - confirm, because
# the weighted averages further down line weights up by position.
non_probability_columns = ["pred_label_bert", "pred_label_xlmbase", "pred_label_xlmlarge", "row_id"]
prob_df = merge_df[merge_df.columns.difference(non_probability_columns)]
prob_df.head()

In [None]:
# Average each class probability across the models with equal weight, then
# predict the class whose averaged probability is largest.
averaged_prob_df = pd.DataFrame()
class_columns = (("label_0", "label_0_prob_avg"),
                 ("label_1", "label_1_prob_avg"),
                 ("label_2", "label_2_prob_avg"))
for column_prefix, averaged_column in class_columns:
    is_class_column = prob_df.columns.str.startswith(column_prefix)
    averaged_prob_df[averaged_column] = np.average(prob_df.loc[:, is_class_column], axis=1)

# idxmax yields the winning column name, which is mapped back to its label id.
averaged_prob_df["idxmax"] = averaged_prob_df.idxmax(axis=1)
label_mapping = {"label_0_prob_avg": 0, "label_1_prob_avg": 1, "label_2_prob_avg": 2}
averaged_prob_df["pred_label"] = averaged_prob_df["idxmax"].map(label_mapping)
averaged_prob_df.head(10)

In [None]:
# Accuracy of the averaged-probability ensemble on the test set.
averaged_prob_accuracy_score = accuracy_score(y_true=merge_df["label"].values,
                                              y_pred=averaged_prob_df["pred_label"].values)
print(averaged_prob_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for averaged probabilities.
averaged_y_true = merge_df.label.values
averaged_y_pred = averaged_prob_df.pred_label.values
averaged_prob_confusion_matrix = confusion_matrix(averaged_y_true, averaged_y_pred)
averaged_prob_norm_confusion_matrix = confusion_matrix(averaged_y_true, averaged_y_pred, normalize="true")
plot_confusion_matrix(averaged_prob_confusion_matrix)
plot_confusion_matrix(averaged_prob_norm_confusion_matrix, norm=True)

# Ensemble Option 4 - Weighted Probabilities + Argmax

In [None]:
# Weight each model's class probabilities by MODEL_WEIGHTING, then predict
# the class whose weighted average probability is largest.
weighted_prob_df = pd.DataFrame()

# MODEL_WEIGHTING is ordered (bert, xlmbase, xlmlarge). Select the probability
# columns in that same order explicitly, so the weights are not paired with
# models merely through the column order prob_df happens to have.
weighted_model_order = ["bert", "xlmbase", "xlmlarge"]
for class_id in range(3):
    model_columns = ["label_{}_prob_{}".format(class_id, model_key)
                     for model_key in weighted_model_order]
    weighted_prob_df["label_{}_prob_avg".format(class_id)] = np.average(
        prob_df[model_columns], weights=MODEL_WEIGHTING, axis=1)

# idxmax yields the winning column name, which is mapped back to its label id.
weighted_prob_df["idxmax"] = weighted_prob_df.idxmax(axis=1)
label_mapping = {"label_0_prob_avg": 0, "label_1_prob_avg": 1, "label_2_prob_avg": 2}
weighted_prob_df["pred_label"] = weighted_prob_df["idxmax"].map(label_mapping)
weighted_prob_df.head(10)

In [None]:
# Accuracy of the weighted-probability ensemble on the test set.
weighted_prob_accuracy_score = accuracy_score(y_true=merge_df["label"].values,
                                              y_pred=weighted_prob_df["pred_label"].values)
print(weighted_prob_accuracy_score)

In [None]:
# Raw-count and row-normalized confusion matrices for weighted probabilities.
weighted_prob_y_true = merge_df.label.values
weighted_prob_y_pred = weighted_prob_df.pred_label.values
weighted_prob_confusion_matrix = confusion_matrix(weighted_prob_y_true, weighted_prob_y_pred)
weighted_prob_norm_confusion_matrix = confusion_matrix(weighted_prob_y_true, weighted_prob_y_pred, normalize="true")
plot_confusion_matrix(weighted_prob_confusion_matrix)
plot_confusion_matrix(weighted_prob_norm_confusion_matrix, norm=True)

# Model Inference: Weighted Probabilities + Argmax

In [None]:
def tokenization(dataset, model_name, prefix, max_token_len=None):
    """Tokenize every premise/hypothesis pair and store the encodings.

    Three columns are added to ``dataset`` in place: ``<prefix>_input_ids``,
    ``<prefix>_token_type_ids`` and ``<prefix>_attention_mask``.

    Args:
        dataset: Frame with ``premise`` and ``hypothesis`` columns.
        model_name: "bert-base-multilingual-cased", "xlm-roberta-base" or
            "xlm-roberta-large"; selects the tokenizer.
        prefix: Prefix for the names of the added columns.
        max_token_len: Length each pair is padded/truncated to; the global
            MAX_TOKEN_LEN is used when omitted.

    Returns:
        The same ``dataset`` object, with the new columns.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    if model_name in ["bert-base-multilingual-cased"]:
        tokenizer = BertTokenizer.from_pretrained(model_name)
    elif model_name in ["xlm-roberta-base", "xlm-roberta-large"]:
        tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
    else:
        raise ValueError("Unsupported model name: {}".format(model_name))

    if max_token_len is None:
        # Resolved at call time, exactly as the hard-coded global used to be.
        max_token_len = MAX_TOKEN_LEN

    sents = list(zip(dataset.premise.values, dataset.hypothesis.values))
    # padding="max_length" replaces the deprecated pad_to_max_length=True.
    encoded_sents = tokenizer.batch_encode_plus(sents,
                                                add_special_tokens=True,
                                                padding="max_length",
                                                max_length=max_token_len,
                                                truncation=True,
                                                return_attention_mask=True,
                                                return_token_type_ids=True)

    for field in ("input_ids", "token_type_ids", "attention_mask"):
        dataset[prefix + "_" + field] = encoded_sents[field]
    return dataset

In [None]:
# Tokenize the production data once per model. The PRETRAINED_* constants are
# used rather than repeated string literals so each tokenizer always matches
# the model that is later loaded for inference.
prod_df = tokenization(prod_df, PRETRAINED_BERT_MODEL, "bert")
prod_df = tokenization(prod_df, PRETRAINED_XLMBASE_MODEL, "xlmbase")
prod_df = tokenization(prod_df, PRETRAINED_XLMLARGE_MODEL, "xlmlarge")
prod_df.head(10)

In [None]:
class NLIProdDataset(Dataset):
    """Torch dataset serving pre-tokenized inputs for all three models.

    The wrapped frame must hold an ``id`` column plus, for each of the
    prefixes bert, xlmbase and xlmlarge, the columns ``<prefix>_input_ids``,
    ``<prefix>_token_type_ids`` and ``<prefix>_attention_mask``.
    """

    def __init__(self, dataset: pd.DataFrame):
        self.dataset = dataset

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.dataset)

    def __getitem__(self, index: int):
        """Return ``(bert_inputs, xlmbase_inputs, xlmlarge_inputs, row_id)``.

        Each ``*_inputs`` value is a dict of NumPy arrays keyed by
        "input_ids", "token_type_ids" and "attention_mask".
        """
        frame = self.dataset
        row_id = frame.id.values[index]

        def inputs_for(prefix):
            # Collect one model's three encoded columns for the requested row.
            return {
                "input_ids": np.array(getattr(frame, prefix + "_input_ids").values[index]),
                "token_type_ids": np.array(getattr(frame, prefix + "_token_type_ids").values[index]),
                "attention_mask": np.array(getattr(frame, prefix + "_attention_mask").values[index]),
            }

        return inputs_for("bert"), inputs_for("xlmbase"), inputs_for("xlmlarge"), row_id

In [None]:
class NLIProdModelModule(LightningModule):
    """LightningModule bundling the three fine-tuned models for inference."""

    def __init__(self, bert_model, xlmbase_model, xlmlarge_model):
        super().__init__()
        self.bert_model = bert_model
        self.xlmbase_model = xlmbase_model
        self.xlmlarge_model = xlmlarge_model

    @auto_move_data
    def forward(self, bert_inputs, xlmbase_inputs, xlmlarge_inputs):
        """Run each model on its own encoded inputs.

        Every ``*_inputs`` dict must provide "input_ids", "attention_mask"
        and "token_type_ids". Returns the three raw model outputs in the
        order (bert, xlmbase, xlmlarge).
        """
        def run(model, inputs):
            return model(input_ids=inputs["input_ids"],
                         attention_mask=inputs["attention_mask"],
                         token_type_ids=inputs["token_type_ids"])

        bert_predictions = run(self.bert_model, bert_inputs)
        xlmbase_predictions = run(self.xlmbase_model, xlmbase_inputs)
        xlmlarge_predictions = run(self.xlmlarge_model, xlmlarge_inputs)
        return bert_predictions, xlmbase_predictions, xlmlarge_predictions

In [None]:
def load_model(model_dict: Dict[str, str], accelerator: str = DEVICE):
    """Build a sequence-classification model from a fine-tuned checkpoint.

    Args:
        model_dict: Mapping with "model_name" (one of
            "bert-base-multilingual-cased", "xlm-roberta-base",
            "xlm-roberta-large") and "checkpoint_path" (a checkpoint file
            containing a "state_dict" entry).
        accelerator: With "gpu" the checkpoint is loaded as saved; with any
            other value its tensors are mapped to the CPU.

    Returns:
        The model with three output labels and the checkpoint weights.

    Raises:
        ValueError: If the model name is not a supported model.
    """
    model_name = model_dict["model_name"]
    checkpoint_path = model_dict["checkpoint_path"]

    # Fail fast on an unknown model instead of hitting an unbound variable
    # after the checkpoint has already been read.
    if model_name in ["bert-base-multilingual-cased"]:
        model_class = BertForSequenceClassification
    elif model_name in ["xlm-roberta-base", "xlm-roberta-large"]:
        model_class = XLMRobertaForSequenceClassification
    else:
        raise ValueError("Unsupported model name: {}".format(model_name))

    if accelerator == "gpu":
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Checkpoint keys carry the wrapper attribute name ("model.") as a
    # prefix; strip it only where present so other keys are left intact.
    prefix = "model."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint["state_dict"].items()
    )

    return model_class.from_pretrained(model_name,
                                       state_dict=new_state_dict,
                                       num_labels=3)

In [None]:
def model_inference(dataset: pd.DataFrame, 
                    model_dict: Dict[str, str],
                    batch_size: int = BATCH_SIZE):
    """Run all three fine-tuned models over a pre-tokenized dataset.

    Args:
        dataset: Frame accepted by ``NLIProdDataset``.
        model_dict: Mapping with the keys "bert", "xlmbase" and "xlmlarge";
            each value is the per-model mapping expected by ``load_model``.
        batch_size: Number of rows per inference batch.

    Returns:
        A tuple of three DataFrames in the order (bert, xlmbase, xlmlarge),
        each with the columns row_id, label_0_prob, label_1_prob and
        label_2_prob.
    """
    model_keys = ("bert", "xlmbase", "xlmlarge")
    prob_columns = ["label_0_prob", "label_1_prob", "label_2_prob"]

    # Prepare Dataset
    nli_dataset = NLIProdDataset(dataset)
    nli_dataloader = DataLoader(nli_dataset, batch_size=batch_size, shuffle=False)

    # Load Fine-tuned Model
    models = [load_model(model_dict[key]) for key in model_keys]
    model_module = NLIProdModelModule(*models)
    model_module.freeze() # eval

    row_id_list = []
    prob_lists = {key: [] for key in model_keys}

    # Model inference
    for bert_inputs, xlmbase_inputs, xlmlarge_inputs, row_ids in tqdm(nli_dataloader):
        raw_outputs = model_module.forward(bert_inputs, xlmbase_inputs, xlmlarge_inputs)
        for key, raw_output in zip(model_keys, raw_outputs):
            # The first element of each model output is used as class scores.
            scores = raw_output[0].cpu().numpy()
            prob_lists[key] += softmax(scores, axis=1).tolist()

        # Numeric ids are collated into a tensor, string ids into a sequence;
        # store plain Python values in both cases.
        row_id_list += row_ids.tolist() if torch.is_tensor(row_ids) else list(row_ids)

    info_df = pd.DataFrame({"row_id": row_id_list})
    return tuple(
        pd.concat([info_df, pd.DataFrame(prob_lists[key], columns=prob_columns)], axis=1)
        for key in model_keys
    )

In [None]:
def model_ensemble(pred_dfs: List[pd.DataFrame], 
                   model_weighting: np.ndarray):
    """Combine per-model class probabilities into one weighted prediction.

    Args:
        pred_dfs: One frame per model, each with the columns row_id,
            label_0_prob, label_1_prob and label_2_prob, all in the same
            row order. Any number of models (at least one) is accepted.
        model_weighting: One weight per frame, in the same order as
            ``pred_dfs``.

    Returns:
        A DataFrame with the columns label_0_prob_avg, label_1_prob_avg,
        label_2_prob_avg, idxmax (name of the winning column), id (taken
        from the first frame's row_id) and prediction (the winning label).

    Raises:
        ValueError: If ``pred_dfs`` is empty.
    """
    if len(pred_dfs) == 0:
        raise ValueError("pred_dfs must contain at least one DataFrame")

    weighted_prob_df = pd.DataFrame()
    for class_id in range(3):
        source_column = "label_{}_prob".format(class_id)
        # One row per model, one column per sample; averaging along axis 0
        # applies one weight to each model.
        stacked = np.vstack([pred_df[source_column].values for pred_df in pred_dfs])
        weighted_prob_df[source_column + "_avg"] = np.average(stacked,
                                                              weights=model_weighting,
                                                              axis=0)

    # idxmax yields the winning column name, which is mapped back to its label id.
    weighted_prob_df["idxmax"] = weighted_prob_df.idxmax(axis=1)
    label_mapping = {"label_0_prob_avg": 0, "label_1_prob_avg": 1, "label_2_prob_avg": 2}
    weighted_prob_df["id"] = pred_dfs[0].row_id.values
    weighted_prob_df["prediction"] = weighted_prob_df["idxmax"].map(label_mapping)
    return weighted_prob_df

In [None]:
# Describe the three ensemble members. Plain dicts are used so that a
# mistyped key raises KeyError instead of silently yielding an empty string.
model_dict = {
    "bert": {
        "model_name": PRETRAINED_BERT_MODEL,
        "checkpoint_path": FINETUNED_BERT_MODEL,
    },
    "xlmbase": {
        "model_name": PRETRAINED_XLMBASE_MODEL,
        "checkpoint_path": FINETUNED_XLMBASE_MODEL,
    },
    "xlmlarge": {
        "model_name": PRETRAINED_XLMLARGE_MODEL,
        "checkpoint_path": FINETUNED_XLMLARGE_MODEL,
    },
}

# Run all three models on the production data and combine their
# probabilities with the accuracy-based weights. This rebinds
# weighted_prob_df, which previously held the test-set ensemble result.
pred_dfs = list(model_inference(prod_df, model_dict))
weighted_prob_df = model_ensemble(pred_dfs, MODEL_WEIGHTING)
weighted_prob_df.head(10)

In [None]:
# Write the submission file: one row per id with its ensemble prediction.
submission_columns = ["id", "prediction"]
pred_pd = weighted_prob_df.loc[:, submission_columns]
pred_pd.to_csv('submission.csv', index=False)
pred_pd.head(10)