In [5]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import json

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import torch
from torch.utils.data import Dataset

from transformers import BertModel
from transformers import BertTokenizer
from transformers import BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

from typing import *

# Load the MultiNLI training split (JSON Lines: one JSON object per line).
# Forward slashes are accepted by Python on Windows as well as POSIX,
# unlike the previously hard-coded "\\" separator, which only worked on Windows.
df_train = pd.read_json("multinli_1.0/multinli_1.0_train.jsonl", lines=True)
df_train.head()

Unnamed: 0,annotator_labels,genre,gold_label,pairID,promptID,sentence1,sentence1_binary_parse,sentence1_parse,sentence2,sentence2_binary_parse,sentence2_parse
0,[neutral],government,neutral,31193n,31193,Conceptually cream skimming has two basic dime...,( ( Conceptually ( cream skimming ) ) ( ( has ...,(ROOT (S (NP (JJ Conceptually) (NN cream) (NN ...,Product and geography are what make cream skim...,( ( ( Product and ) geography ) ( ( are ( what...,(ROOT (S (NP (NN Product) (CC and) (NN geograp...
1,[entailment],telephone,entailment,101457e,101457,you know during the season and i guess at at y...,( you ( ( know ( during ( ( ( the season ) and...,(ROOT (S (NP (PRP you)) (VP (VBP know) (PP (IN...,You lose the things to the following level if ...,( You ( ( ( ( lose ( the things ) ) ( to ( the...,(ROOT (S (NP (PRP You)) (VP (VBP lose) (NP (DT...
2,[entailment],fiction,entailment,134793e,134793,One of our number will carry out your instruct...,( ( One ( of ( our number ) ) ) ( ( will ( ( (...,(ROOT (S (NP (NP (CD One)) (PP (IN of) (NP (PR...,A member of my team will execute your orders w...,( ( ( A member ) ( of ( my team ) ) ) ( ( will...,(ROOT (S (NP (NP (DT A) (NN member)) (PP (IN o...
3,[entailment],fiction,entailment,37397e,37397,How do you know? All this is their information...,( ( How ( ( ( do you ) know ) ? ) ) ( ( All th...,(ROOT (S (SBARQ (WHADVP (WRB How)) (SQ (VBP do...,This information belongs to them.,( ( This information ) ( ( belongs ( to them )...,(ROOT (S (NP (DT This) (NN information)) (VP (...
4,[neutral],telephone,neutral,50563n,50563,yeah i tell you what though if you go price so...,( yeah ( i ( ( tell you ) ( what ( ( though ( ...,(ROOT (S (VP (VB yeah) (S (NP (FW i)) (VP (VB ...,The tennis shoes have a range of prices.,( ( The ( tennis shoes ) ) ( ( have ( ( a rang...,(ROOT (S (NP (DT The) (NN tennis) (NNS shoes))...


In [20]:
# Keep only the gold label and the sentence pair, under the column names MNLI_Dataset reads.
# Selecting by name (instead of dropping everything else and renaming positionally)
# cannot silently mislabel columns if the file's column order or column set changes.
N_TRAIN_SAMPLES = 15000  # subset size for quick experiments

df = (
    df_train[['gold_label', 'sentence1', 'sentence2']]
    .rename(columns={'gold_label': 'label', 'sentence1': 'context1', 'sentence2': 'context2'})
    .head(N_TRAIN_SAMPLES)
    .reset_index(drop=True)  # fresh 0..n-1 index for positional access downstream
)

# Every label must be one MNLI_Dataset.label_map can encode; fail here rather than mid-training.
assert df['label'].isin(['neutral', 'entailment', 'contradiction']).all(), "unexpected gold_label values"

df.to_csv('MNLI_train.csv')

In [15]:
class MNLI_Dataset(Dataset):
    """Map-style dataset yielding BERT-encoded sentence pairs with integer labels.

    Each item is ``(input_ids, attention_mask, token_type_ids, label)``:
    the three tensors are the unpadded encoding of the pair (padding is left to
    the DataLoader's ``collate_fn``), and ``label`` is an int from ``label_map``.

    Args:
        df: DataFrame with columns ``label``, ``context1`` and ``context2``
            (first and second sentence of the pair). Rows are accessed by
            position, so any index (filtered, shuffled, ...) is supported.
        max_length: maximum encoded length of a pair; longer pairs are
            truncated so they fit BERT's 512 position embeddings.
    """

    def __init__(self, df, max_length=512):
        self.df = df
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.label_map = {'neutral': 0, 'entailment': 1, 'contradiction': 2}
        self.max_length = max_length

    def __getitem__(self, index):
        # The DataLoader passes positions 0..len-1; .iloc keeps this correct even
        # when the frame's index labels are not a fresh 0..n-1 range.
        row = self.df.iloc[index]
        encoded = self.tokenizer.encode_plus(
            row['context1'],
            row['context2'],
            truncation=True,
            max_length=self.max_length,
        )

        input_ids = torch.tensor(encoded['input_ids'])
        mask = torch.tensor(encoded['attention_mask'])
        token = torch.tensor(encoded['token_type_ids'])
        label = self.label_map[row['label']]  # KeyError for labels outside label_map

        return input_ids, mask, token, label

    def __len__(self):
        return len(self.df)


In [18]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_batch(sample):
    """Merge a list of dataset items into right-padded batch tensors.

    Args:
        sample: list of ``(input_ids, attention_mask, token_type_ids, label)``
            items; the first three are 1-D tensors whose lengths may differ.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)``. The first three
        are stacked to shape ``(batch, longest)``, with shorter rows padded by 0.
        ``labels`` is a 1-D tensor of the item labels.
    """
    def _pad_field(position):
        # Gather the tensor at `position` from every item and zero-pad to the longest.
        return pad_sequence([item[position] for item in sample], batch_first=True)

    labels = torch.tensor([item[3] for item in sample])
    return _pad_field(0), _pad_field(1), _pad_field(2), labels


# Number of sentence pairs per batch.
BATCH_SIZE = 2

In [19]:
# Seed the shuffling so the batch shown below (and the training order) is
# identical across kernel restarts.
SEED = 42

train_dataset = MNLI_Dataset(df)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
    generator=torch.Generator().manual_seed(SEED),
)

# Peek at one collated batch to sanity-check padding and shapes.
data = next(iter(train_loader))
data

(tensor([[  101,  9498,  1112,  1128,  2824,  1292,  7864,  2365,  9298,  1115,
           1152, 20710,  1193,  3870,  1103,  1269,  8249,  1107,  4883, 23632,
          17368,  1116,  1105,  1103,  3528,  2556, 10335,  1105,  7180,  1107,
           1103, 11417,  1104,  3701,   119,   102,  1220,  3870,  1103,  1269,
           8249,  1107,  4883, 23632, 17368,  1116,  1190,  1152,  1202,  1107,
          18080,   119,   102],
         [  101, 23840,  4198,   117,  1242,  3721,  1115,  4132,  2265,  1106,
           1103,  3722,  1107,  1412,  2025,  2017,  2775,   118,  6842, 18460,
           1104,  2233,   119,   102, 25911,  2365,  3721,  1529,  2775,  6842,
           2233,  3622,   119,   102,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 