In [1]:
import torch
import numpy as np
from transformers import BertTokenizer
import pandas as pd

MODEL_NAME = "bert-base-cased"
# MODEL_NAME = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df_raw = pd.read_csv('https://raw.githubusercontent.com/sudoghut/contradictory-my-dear-watson/main/data/train_10_en.csv')
# df_raw = pd.read_csv('https://raw.githubusercontent.com/sudoghut/contradictory-my-dear-watson/main/data/train.csv')
df_raw = df_raw[:10]
# print(df_raw.head(5))
print(df_raw["label"].value_counts())
df_raw["label"].to_numpy()
df_raw.info()


0    5
2    3
1    2
Name: label, dtype: int64
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10 entries, 0 to 9
Data columns (total 6 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   id          10 non-null     object
 1   premise     10 non-null     object
 2   hypothesis  10 non-null     object
 3   lang_abv    10 non-null     object
 4   language    10 non-null     object
 5   label       10 non-null     int64 
dtypes: int64(1), object(5)
memory usage: 608.0+ bytes


In [3]:
# labels = {'Entailment':0,
#           'Neutral':1,
#           'Contradiction':2,
#           }

class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):

        self.labels = df['label']
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in df['premise']]
        # self.premise = [tokenizer(text, 
        #                        padding='max_length', max_length = 512, truncation=True,
        #                         return_tensors="pt") for text in df['premise']]
        # self.hypothesis = [tokenizer(text, 
        #                        padding='max_length', max_length = 512, truncation=True,
        #                         return_tensors="pt") for text in df['hypothesis']]
        # self.texts = self.premise + self.hypothesis

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        # print("hereit")
        # # print(idx)
        # print(self.texts)
        # print("hereend")
        return batch_texts, batch_y

In [4]:
np.random.seed(112)
# df_train, df_val, df_test = np.split(df_raw.sample(frac=1, random_state=42), 
#                                      [int(.8*len(df_raw)), int(.9*len(df_raw))])
df_train, df_val, df_test = np.split(df_raw.sample(frac=1, random_state=42), 
                                     [10, 10])
df_val = df_train
df_test = df_train
print(len(df_train),len(df_val), len(df_test))

10 10 10


In [5]:
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self, dropout=0.5):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

In [6]:
from torch.optim import Adam
from tqdm import tqdm

def train(model, train_data, val_data, learning_rate, epochs):
    print("gethere1")
    train, val = Dataset(train_data), Dataset(val_data)
    print("gethere2")
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True)
    # train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)
    print("gethere3")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("gethere4")
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)
    print("gethere5")
    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()

    for epoch_num in range(epochs):
            print("gethere6")
            total_acc_train = 0
            total_loss_train = 0
 
            for train_input, train_label in tqdm(train_dataloader):
                print("gethere7")
                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)
                print("gethere8")
                output = model(input_id, mask)
                print("gethere9")
                batch_loss = criterion(output, train_label.long())
                total_loss_train += batch_loss.item()
                print("gethere10")
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc
                print("gethere11")
                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
                print("gethere12")
            print("gethere13")
            total_acc_val = 0
            total_loss_val = 0
            print("gethere14")
            with torch.no_grad():
                print("gethere15")
                for val_input, val_label in val_dataloader:
                    print("gethere16")
                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)
                    print("gethere17")
                    batch_loss = criterion(output, val_label.long())
                    total_loss_val += batch_loss.item()
                    print("gethere18")
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
                | Train Accuracy: {total_acc_train / len(train_data): .3f} \
                | Val Loss: {total_loss_val / len(val_data): .3f} \
                | Val Accuracy: {total_acc_val / len(val_data): .3f}')
                  
EPOCHS = 5
model = BertClassifier()
LR = 1e-6

train(model, df_train, df_val, LR, EPOCHS)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


gethere1
gethere2
gethere3
gethere4
gethere5
gethere6


  0%|          | 0/10 [00:00<?, ?it/s]

gethere7
gethere8
gethere9
gethere10
gethere11


 10%|█         | 1/10 [00:03<00:32,  3.65s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 20%|██        | 2/10 [00:07<00:29,  3.72s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 30%|███       | 3/10 [00:11<00:26,  3.78s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 40%|████      | 4/10 [00:15<00:22,  3.77s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 50%|█████     | 5/10 [00:18<00:18,  3.75s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 60%|██████    | 6/10 [00:22<00:14,  3.73s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 70%|███████   | 7/10 [00:26<00:11,  3.74s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 80%|████████  | 8/10 [00:29<00:07,  3.73s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 90%|█████████ | 9/10 [00:33<00:03,  3.74s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


100%|██████████| 10/10 [00:37<00:00,  3.74s/it]

gethere12
gethere13
gethere14
gethere15
gethere16





gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
Epochs: 1 | Train Loss:  1.620                 | Train Accuracy:  0.300                 | Val Loss:  0.840                 | Val Accuracy:  0.300
gethere6


  0%|          | 0/10 [00:00<?, ?it/s]

gethere7
gethere8
gethere9
gethere10
gethere11


 10%|█         | 1/10 [00:03<00:33,  3.67s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 20%|██        | 2/10 [00:07<00:29,  3.70s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 30%|███       | 3/10 [00:11<00:25,  3.69s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 40%|████      | 4/10 [00:14<00:22,  3.68s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 50%|█████     | 5/10 [00:18<00:18,  3.67s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 60%|██████    | 6/10 [00:22<00:14,  3.67s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 70%|███████   | 7/10 [00:25<00:11,  3.67s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 80%|████████  | 8/10 [00:29<00:07,  3.65s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 90%|█████████ | 9/10 [00:33<00:03,  3.66s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


100%|██████████| 10/10 [00:36<00:00,  3.67s/it]

gethere12
gethere13
gethere14
gethere15
gethere16





gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
Epochs: 2 | Train Loss:  1.968                 | Train Accuracy:  0.200                 | Val Loss:  0.817                 | Val Accuracy:  0.300
gethere6


  0%|          | 0/10 [00:00<?, ?it/s]

gethere7
gethere8
gethere9
gethere10
gethere11


 10%|█         | 1/10 [00:03<00:33,  3.70s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 20%|██        | 2/10 [00:07<00:29,  3.68s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 30%|███       | 3/10 [00:11<00:25,  3.68s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 40%|████      | 4/10 [00:14<00:21,  3.66s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 50%|█████     | 5/10 [00:18<00:18,  3.67s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 60%|██████    | 6/10 [00:21<00:14,  3.66s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 70%|███████   | 7/10 [00:25<00:10,  3.66s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 80%|████████  | 8/10 [00:29<00:07,  3.66s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 90%|█████████ | 9/10 [00:32<00:03,  3.65s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


100%|██████████| 10/10 [00:36<00:00,  3.67s/it]

gethere12
gethere13
gethere14
gethere15
gethere16





gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
Epochs: 3 | Train Loss:  1.729                 | Train Accuracy:  0.200                 | Val Loss:  0.903                 | Val Accuracy:  0.200
gethere6


  0%|          | 0/10 [00:00<?, ?it/s]

gethere7
gethere8
gethere9
gethere10
gethere11


 10%|█         | 1/10 [00:03<00:33,  3.70s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 20%|██        | 2/10 [00:07<00:29,  3.74s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 30%|███       | 3/10 [00:11<00:26,  3.84s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 40%|████      | 4/10 [00:15<00:24,  4.07s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 50%|█████     | 5/10 [00:20<00:20,  4.20s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 60%|██████    | 6/10 [00:24<00:17,  4.26s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 70%|███████   | 7/10 [00:29<00:13,  4.34s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 80%|████████  | 8/10 [00:33<00:08,  4.37s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 90%|█████████ | 9/10 [00:38<00:04,  4.40s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


100%|██████████| 10/10 [00:42<00:00,  4.25s/it]

gethere12
gethere13
gethere14
gethere15
gethere16





gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
Epochs: 4 | Train Loss:  1.811                 | Train Accuracy:  0.200                 | Val Loss:  0.734                 | Val Accuracy:  0.200
gethere6


  0%|          | 0/10 [00:00<?, ?it/s]

gethere7
gethere8
gethere9
gethere10
gethere11


 10%|█         | 1/10 [00:04<00:40,  4.52s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 20%|██        | 2/10 [00:08<00:35,  4.48s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 30%|███       | 3/10 [00:13<00:31,  4.49s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 40%|████      | 4/10 [00:17<00:26,  4.48s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 50%|█████     | 5/10 [00:22<00:22,  4.49s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 60%|██████    | 6/10 [00:26<00:17,  4.48s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 70%|███████   | 7/10 [00:31<00:13,  4.46s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 80%|████████  | 8/10 [00:35<00:08,  4.46s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


 90%|█████████ | 9/10 [00:40<00:04,  4.46s/it]

gethere12
gethere7
gethere8
gethere9
gethere10
gethere11


100%|██████████| 10/10 [00:44<00:00,  4.48s/it]

gethere12
gethere13
gethere14
gethere15
gethere16





gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
gethere16
gethere17
gethere18
Epochs: 5 | Train Loss:  1.655                 | Train Accuracy:  0.100                 | Val Loss:  0.705                 | Val Accuracy:  0.400
