In [None]:
# !pip install pytreebank
# !pip install loguru
# !pip install transformers

In [None]:
"""This module defines a configurable SSTDataset class."""

import pytreebank
import torch
from loguru import logger
import pandas as pd
import seaborn as sn
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import BertTokenizer
from torch.utils.data import Dataset

# Module-level setup: both objects are created once and shared by every
# SSTDataset instance defined below.

# WordPiece tokenizer for the "bert-large-uncased" vocabulary; SSTDataset uses
# it to turn sentences into token-id lists.
logger.info("Loading the tokenizer")
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

# The treebank as returned by pytreebank; SSTDataset indexes it by split name.
logger.info("Loading SST")
sst = pytreebank.load_sst()


def rpad(array, n=70):
    """Right-pad or truncate a list so that it has exactly ``n`` items.

    Args:
        array: list of token ids.
        n: length of the returned list.

    Returns:
        A new list of length ``n``: ``array`` followed by zeros when it is
        shorter than ``n``, or the first ``n`` items of ``array`` when it is
        longer.
    """
    current_len = len(array)
    if current_len > n:
        # Truncate to n (not n - 1) so truncated rows have the same length as
        # padded ones and can be stacked into a single batch tensor.
        return array[:n]
    extra = n - current_len
    return array + ([0] * extra)


def get_binary_label(label):
    """Collapse a fine-grained sentiment label into a binary one.

    Args:
        label: fine-grained label; values below 2 and above 2 are accepted.

    Returns:
        0 when ``label`` is below 2, 1 when it is above 2.

    Raises:
        ValueError: if ``label`` is neither below nor above 2 (the middle
            class has no binary counterpart).
    """
    if label < 2:
        verdict = 0
    elif label > 2:
        verdict = 1
    else:
        raise ValueError("Invalid label")
    return verdict


class SSTDataset(Dataset):
    """SST samples exposed as ``(token ids, label)`` pairs.

    Three independent switches select what is loaded:
        - the split of the treebank,
        - root sentences only versus every labeled node,
        - binary versus fine-grained labels.
    """

    def __init__(self, split="train", root=True, binary=True):
        """Tokenize the requested split and keep the samples in ``self.data``.

        Args:
            split: str
                Key used to index the module-level ``sst`` object.
                NOTE(review): the original docs list train / val / test, but
                pytreebank documents its keys as train / dev / test --
                confirm that "val" is actually accepted.
            root: bool
                If true, use only the first line of each tree (the root
                sentence). Else, use every labeled line of each tree.
            binary: bool
                If true, drop samples labeled 2 and map the rest through
                ``get_binary_label``. Else, keep the labels as they are.
        """
        logger.info(f"Loading SST {split} set")
        self.sst = sst[split]

        logger.info("Tokenizing")

        def to_ids(text):
            # Every sample is padded / cut by rpad with a target length of 66.
            return rpad(tokenizer.encode(text), n=66)

        def wrap(text):
            return "[CLS] " + text + " [SEP]"

        samples = []
        if root:
            for tree in self.sst:
                if binary:
                    if tree.label != 2:
                        ids = to_ids(wrap(tree.to_lines()[0]))
                        samples.append((ids, get_binary_label(tree.label)))
                else:
                    # NOTE(review): this is the only configuration that
                    # encodes the raw sentence without the explicit
                    # "[CLS] ... [SEP]" wrapper -- confirm it is intentional.
                    samples.append((to_ids(tree.to_lines()[0]), tree.label))
        else:
            for tree in self.sst:
                for label, line in tree.to_labeled_lines():
                    if not binary:
                        samples.append((to_ids(wrap(line)), label))
                    elif label != 2:
                        samples.append(
                            (to_ids(wrap(line)), get_binary_label(label))
                        )
        self.data = samples

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return sample ``index`` as ``(tensor of token ids, label)``."""
        token_ids, target = self.data[index]
        return torch.tensor(token_ids), target


In [None]:
# Load the saved model onto the CPU. The loaded object is used directly as a
# model further down (``bert.eval()``, ``bert(batch)``), so the file presumably
# holds a whole pickled model rather than a state dict -- TODO confirm.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
# NOTE(review): newer PyTorch releases default to ``weights_only=True``, which
# may refuse to load a fully pickled model -- verify against the installed
# version.
bert = torch.load('../Models/bert-large-uncased__root__fine__e5.pickle', map_location=torch.device('cpu'))

In [None]:
# Evaluate the loaded model on the fine-grained, root-only test split.
dataset = SSTDataset("test", root=True, binary=False)

batch_size = 32
device = torch.device('cpu')
# The metrics computed from these predictions do not depend on sample order,
# so iterate without shuffling to keep the run reproducible. Default batching
# stacks the samples, so every sample must have the same length.
generator = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=False
)

# Collect per-batch results and concatenate once at the end instead of
# re-allocating the growing tensors on every iteration.
pred_chunks = []
actual_chunks = []

bert.eval()
with torch.no_grad():
    for batch, labels in tqdm(generator):
        batch, labels = batch.to(device), labels.to(device)
        # The first element of the model output is treated as the logits.
        logits = bert(batch)[0]
        pred_chunks.append(torch.argmax(logits, dim=1))
        actual_chunks.append(labels)

# Fall back to empty tensors so an empty dataset still yields valid results.
bert_pred = torch.cat(pred_chunks, 0) if pred_chunks else torch.LongTensor([])
bert_actual = (
    torch.cat(actual_chunks, 0) if actual_chunks else torch.LongTensor([])
)

In [None]:
# accuracy_score is declared as (y_true, y_pred); pass both by keyword so the
# ground truth and the predictions cannot be mixed up.
acc = accuracy_score(y_true=bert_actual.numpy(), y_pred=bert_pred.numpy())
acc

In [None]:
# Put the ground-truth and predicted labels side by side, one row per sample.
data = {
    'y_Actual': bert_actual.numpy(),
    'y_Predicted': bert_pred.numpy(),
}
df = pd.DataFrame(data, columns=['y_Actual', 'y_Predicted'])

# Cross-tabulate the two columns: rows are actual labels, columns are
# predicted labels, and no row/column totals are added.
confusion_matrix = pd.crosstab(
    df['y_Actual'],
    df['y_Predicted'],
    rownames=['Actual'],
    colnames=['Predicted'],
    margins=False,
)

# Draw the table as an annotated heatmap.
plt.figure(figsize=(12, 8))
plt.title('BERT Large Uncased')
sn.heatmap(confusion_matrix, annot=True, cmap='Blues', fmt='g')
plt.show()