In [1]:
# Base URL of the d2l data mirror, shared by most registry entries below.
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'
# Registry of downloadable resources: name -> (download URL, expected SHA-1
# hex digest). `download` accepts one of these names in place of a URL.
DATA_HUB = {
    'SNLI': ('https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
             '9fcde07509c7e87ec61c640c1b2753d9041758e4'),
    'glove.6b.50d': (DATA_URL + 'glove.6B.50d.zip',
                     '0b8703943ccdb6eb788e6f091b8946e82231bc4d'),
    'glove.6b.100d': (DATA_URL + 'glove.6B.100d.zip',
                      'cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a'),
    'glove.42b.300d': (DATA_URL + 'glove.42B.300d.zip',
                       'b5116e234e9eb9076672cfeabf5469f3eec904fa'),
    'bert.small': (DATA_URL + 'bert.small.torch.zip',
                   'c72329e68a732bef0452e4b96a1c341c8910f81f'),
    'bert.base': (DATA_URL + 'bert.base.torch.zip',
                  '225d66f04cae318b841a13d32af3acc165f253ac'),
}

import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms
from pathlib import Path

# Module-level alias of torch.nn.Module; not referenced elsewhere in the
# visible code -- presumably kept for d2l-style code that expects `nn_Module`.
nn_Module = nn.Module

import json
import multiprocessing
import collections
import hashlib
import inspect
import math
import os
import random
import re
import shutil
import sys
import tarfile
import time
import zipfile
from collections import defaultdict
import pandas as pd
import requests
from IPython import display
from matplotlib import pyplot as plt
import matplotlib.ticker as tkr
from matplotlib.font_manager import FontProperties
from matplotlib_inline import backend_inline

import numpy as np
import torch
import torchvision
from PIL import Image
from scipy.spatial import distance_matrix
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

In [2]:
def nlinumpy(x, *args, **kwargs):
    """Detach tensor `x` from autograd and return it as a NumPy array."""
    return x.detach().numpy(*args, **kwargs)


def nlito(x, *args, **kwargs):
    """Forward to `x.to(...)` (device and/or dtype conversion)."""
    return x.to(*args, **kwargs)


def size(x, *args, **kwargs):
    """Return the total number of elements of tensor `x`."""
    return x.numel(*args, **kwargs)


def argmax(x, *args, **kwargs):
    """Forward to `x.argmax(...)`."""
    return x.argmax(*args, **kwargs)


def astype(x, *args, **kwargs):
    """Forward to `x.type(...)` (cast to another dtype)."""
    return x.type(*args, **kwargs)


def reduce_sum(x, *args, **kwargs):
    """Forward to `x.sum(...)`."""
    return x.sum(*args, **kwargs)

In [3]:
def get_dataloader_workers():
    """Return the number of DataLoader worker processes to use (fixed at 4)."""
    num_workers = 4
    return num_workers

In [4]:
def download(url, folder='../data', sha1_hash=None):
    """Download a file to folder and return the local filepath."""
    if not url.startswith('http'):
        # For back compatability
        url, sha1_hash = DATA_HUB[url]
    os.makedirs(folder, exist_ok=True)
    fname = os.path.join(folder, url.split('/')[-1])
    # Check if hit cache
    if os.path.exists(fname) and sha1_hash:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname
    r = requests.get(url, stream=True, verify=True)
    with open(fname, 'wb') as f:
        f.write(r.content)
    # Download
    print(f' ---> Downloaded {fname} from {url}.')
    return fname

In [5]:
def extract_snli(fname):
    """Extract the SNLI zip into a flat `snli_1.0` directory next to `fname`.

    Every regular file in the archive is written directly into
    `<dirname(fname)>/snli_1.0/` under its base name (any directory prefix is
    dropped). Directory entries, `__MACOSX` metadata, `.DS_Store` files and
    files whose base name contains 'Icon' are skipped. If the target directory
    already exists, nothing is extracted.

    Args:
        fname: Path to the downloaded SNLI zip archive.
    """
    base_dir = os.path.dirname(fname)
    target_dir = os.path.join(base_dir, 'snli_1.0')
    if os.path.exists(target_dir):
        print(f' ---> snli data already extracted at {target_dir}')
        return
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(fname, 'r') as archive:
        for info in archive.infolist():
            item = info.filename
            # Directory entries would otherwise become empty files named
            # after the directory once the path is flattened below.
            if info.is_dir():
                continue
            if '__MACOSX' in item or '.DS_Store' in item:
                continue
            file_name = Path(item).name
            if 'Icon' in file_name:
                continue
            # Both handles are managed by `with`, so neither leaks on error.
            with archive.open(info) as source, \
                    open(os.path.join(target_dir, file_name), 'wb') as target:
                shutil.copyfileobj(source, target)
    print(f' ---> extracted snli data to {target_dir}')

In [6]:
def download_extract(name, folder=None):
    """Download and extract a zip/tar file.

    Args:
        name: URL or `DATA_HUB` key, passed straight to `download`.
        folder: Optional sub-directory (relative to the download directory)
            to return instead of the default.

    Returns:
        `os.path.join(download_dir, folder)` when `folder` is given, otherwise
        the downloaded path with its last extension removed.

    Raises:
        AssertionError: for files that are neither `.zip`, `.tar` nor `.gz`
            (raised explicitly so it is not stripped under `python -O`).
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        if 'snli_1.0.zip' in fname:
            # SNLI gets a dedicated extractor that skips macOS metadata and
            # flattens paths.
            extract_snli(fname)
        else:
            # Context managers close the archive handles (previously leaked).
            with zipfile.ZipFile(fname, 'r') as fp:
                fp.extractall(base_dir)
    elif ext in ('.tar', '.gz'):
        with tarfile.open(fname, 'r') as fp:
            fp.extractall(base_dir)
    else:
        raise AssertionError('Only zip/tar files can be extracted.')
    return os.path.join(base_dir, folder) if folder else data_dir

In [7]:
def read_snli(data_dir, is_train):
    """Read the SNLI dataset into premises, hypotheses, and labels.

    Args:
        data_dir: Directory containing `snli_1.0_train.txt` and
            `snli_1.0_test.txt`.
        is_train: Read the training split when true, the test split otherwise.

    Returns:
        `(premises, hypotheses, labels)` as parallel lists. The texts come
        from tab-separated columns 1 and 2 with parentheses removed and runs
        of whitespace collapsed. Labels map entailment/contradiction/neutral
        to 0/1/2. The header line and rows whose first column is any other
        value are skipped.
    """
    def extract_text(s):
        # Remove information that will not be used by us (parentheses --
        # presumably bracketing of a parse string; TODO confirm column schema)
        s = s.replace('(', '').replace(')', '')
        # Substitute two or more consecutive whitespace with space
        s = re.sub(r'\s{2,}', ' ', s)
        return s.strip()

    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    premises, hypotheses, labels = [], [], []
    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header row
        for line in f:
            row = line.split('\t')
            if row[0] not in label_set:
                continue
            premises.append(extract_text(row[1]))
            hypotheses.append(extract_text(row[2]))
            labels.append(label_set[row[0]])
    return premises, hypotheses, labels

In [8]:
def tokenize(lines, token='word'):
    """Split each text line into tokens.

    Args:
        lines: Iterable of strings.
        token: 'word' splits on whitespace; 'char' splits into characters.

    Returns:
        A list with one token list per input line.
    """
    assert token in ('word', 'char'), 'Unknown token type: ' + token
    if token == 'word':
        return [line.split() for line in lines]
    return [list(line) for line in lines]

In [9]:
class Vocab:
    """Vocabulary for text: a sorted token <-> index mapping with an `<unk>` entry."""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        # None defaults instead of shared mutable `[]` defaults; passing lists
        # works exactly as before.
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The unique tokens, sorted lexicographically (indices do not follow
        # frequency). Tokens seen fewer than `min_freq` times are left out and
        # therefore map to <unk>.
        self.idx_to_token = sorted(set(['<unk>'] + list(reserved_tokens) + [
            token for token, freq in self.token_freqs if freq >= min_freq]))
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token (or a list/tuple of tokens, recursively) to index(es)."""
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence of indices, back to token(s).

        A list/tuple, or an array/tensor with `ndim >= 1`, yields a list of
        tokens. This includes length-1 sequences: the previous
        `len(indices) > 1` test sent those to `idx_to_token[indices]`, which
        raises TypeError for lists. Scalars (ints, 0-d tensors) yield a token.
        """
        if isinstance(indices, (list, tuple)) or getattr(indices, 'ndim', 0) >= 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']

In [10]:
def truncate_pad(line, num_steps, padding_token):
    """Return `line` cut or right-padded with `padding_token` to `num_steps` items."""
    shortfall = num_steps - len(line)
    if shortfall < 0:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * shortfall  # Pad

In [11]:
class SNLIDataset(torch.utils.data.Dataset):
    """A customized dataset to load the SNLI dataset.

    `dataset` is indexed as `(premises, hypotheses, labels)`. Each item is
    `((premise_ids, hypothesis_ids), label)`. The id rows are truncated or
    padded to `num_steps` with the vocabulary's '<pad>' index.
    """

    def __init__(self, dataset, num_steps, vocab=None, is_train=True):
        self.num_steps = num_steps
        premise_tokens = tokenize(dataset[0])
        hypothesis_tokens = tokenize(dataset[1])
        # Build a vocabulary (min_freq=5, '<pad>' reserved) only when none is
        # supplied, so a test split can reuse the training vocabulary.
        self.vocab = vocab if vocab is not None else Vocab(
            premise_tokens + hypothesis_tokens,
            min_freq=5, reserved_tokens=['<pad>'])
        self.premises = self._pad(premise_tokens)
        self.hypotheses = self._pad(hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        count = str(len(self.premises))
        if is_train:
            print(' ---> read ' + count + ' examples from training dataset')
        else:
            print(' ---> read ' + count + ' examples from test dataset')

    def _pad(self, lines):
        """Convert token lists to id rows of length `num_steps` and stack them."""
        pad_id = self.vocab['<pad>']
        rows = [truncate_pad(self.vocab[line], self.num_steps, pad_id)
                for line in lines]
        return torch.tensor(rows)

    def __getitem__(self, idx):
        pair = (self.premises[idx], self.hypotheses[idx])
        return pair, self.labels[idx]

    def __len__(self):
        return len(self.premises)

In [12]:
def load_data_snli(batch_size, num_steps=50):
    """Download the SNLI dataset and return data iterators and vocabulary.

    Returns:
        `(train_iter, test_iter, vocab, data_dir)`:
        - a shuffled training DataLoader;
        - an unshuffled test DataLoader;
        - the vocabulary built from the training split and reused for the test split;
        - the directory the data was extracted to.
    """
    num_workers = get_dataloader_workers()
    data_dir = download_extract('SNLI')
    train_data, test_data = (read_snli(data_dir, split_is_train)
                             for split_is_train in (True, False))
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab, False)

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(dataset, batch_size,
                                           shuffle=shuffle,
                                           num_workers=num_workers)

    return (make_loader(train_set, True), make_loader(test_set, False),
            train_set.vocab, data_dir)

In [13]:
def cpu():
    """Return the torch CPU device."""
    device_name = 'cpu'
    return torch.device(device_name)
    
def gpu(i=0, gpu_type=None):
    """Return device `<gpu_type>:<i>` for 'cuda' or 'mps'; any other type yields cpu()."""
    if gpu_type in ("cuda", "mps"):
        return torch.device(f'{gpu_type}:{i}')
    # Unknown or missing backend: fall back to the CPU.
    return cpu()

def num_gpus():
    """Detect the available GPU backend and how many devices it exposes.

    Returns:
        `("cuda", count)` when CUDA is available, otherwise `("mps", count)`
        when Apple's MPS backend is available, otherwise `(None, 0)`.
        A one-line notice describing the choice is printed on every call.
    """
    if torch.cuda.is_available():
        count = torch.cuda.device_count()
        print(f' ---> will run on nividia cuda gpu(s) count {count}')
        return "cuda", count
    # Older PyTorch builds lack `torch.backends.mps` and/or
    # `torch.mps.device_count`; probe with getattr instead of raising.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        device_count = getattr(getattr(torch, 'mps', None), 'device_count', None)
        # Without device_count(), assume one device (presumably all MPS exposes).
        count = device_count() if callable(device_count) else 1
        print(f' ---> will run on mps gpu(s) count {count}')
        return "mps", count
    print(f' ---> No gpus found. will run on device {cpu()}.')
    return None, 0
        
def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    gpu_type, count = num_gpus()
    has_device_i = gpu_type is not None and i < count
    return gpu(i, gpu_type) if has_device_i else cpu()
    
def try_all_gpus():
    """Return every detected GPU device, or `[cpu()]` when there is none."""
    gpu_type, count = num_gpus()
    if gpu_type and count:
        return [gpu(idx, gpu_type) for idx in range(count)]
    return [cpu()]


In [14]:
class TokenEmbedding:
    """Pretrained token embeddings (e.g. GloVe) loaded from `<data_dir>/vec.txt`.

    Index 0 is '<unk>' with an all-zero vector; lookups of unknown tokens
    return that row.
    """
    def __init__(self, embedding_name):
        self.idx_to_token, self.idx_to_vec = self._load_embedding(
            embedding_name)
        self.unknown_idx = 0
        self.token_to_idx = {token: idx for idx, token in
                             enumerate(self.idx_to_token)}

    def _load_embedding(self, embedding_name):
        """Parse space-separated `token v1 v2 ...` lines into tokens and a float tensor."""
        idx_to_token, idx_to_vec = ['<unk>'], []
        data_dir = download_extract(embedding_name)
        # GloVe website: https://nlp.stanford.edu/projects/glove/
        # fastText website: https://fasttext.cc/
        # Decode as UTF-8 (cp437 turned every non-ASCII token into mojibake
        # that never matched vocabulary tokens); undecodable bytes are
        # replaced rather than aborting the whole load.
        with open(os.path.join(data_dir, 'vec.txt'), 'r',
                  encoding='utf-8', errors='replace') as f:
            for line in f:
                elems = line.rstrip().split(' ')
                token, elems = elems[0], [float(elem) for elem in elems[1:]]
                # Skip header information, such as the top row in fastText
                if len(elems) > 1:
                    idx_to_token.append(token)
                    idx_to_vec.append(elems)
        if not idx_to_vec:
            raise ValueError(f'No embedding vectors found in {data_dir}')
        # Prepend the all-zero vector for '<unk>' (index 0).
        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec
        return idx_to_token, torch.tensor(idx_to_vec)

    def __getitem__(self, tokens):
        """Return a (len(tokens), dim) tensor of vectors; unknown tokens map to row 0."""
        indices = [self.token_to_idx.get(token, self.unknown_idx)
                   for token in tokens]
        # Explicit long dtype: torch.tensor([]) would otherwise be float and
        # be rejected as an index for an empty token list.
        return self.idx_to_vec[torch.tensor(indices, dtype=torch.long)]

    def __len__(self):
        return len(self.idx_to_token)

In [15]:
class Timer:
    """Stopwatch that records the duration of each start/stop interval."""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Begin (or restart) the current interval."""
        self.tik = time.time()

    def stop(self):
        """End the current interval, record its length and return it in seconds."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Return the mean recorded interval (ZeroDivisionError if none)."""
        total = sum(self.times)
        return total / len(self.times)

    def sum(self):
        """Return the total of all recorded intervals."""
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded intervals as a list."""
        return np.cumsum(np.array(self.times)).tolist()


In [16]:
def use_svg_display():
    """Make Jupyter render matplotlib figures inline as SVG."""
    svg_format = 'svg'
    backend_inline.set_matplotlib_formats(svg_format)

def set_figsize(figsize=(3.5, 2.5)):
    """Switch to SVG output and set matplotlib's default figure size."""
    use_svg_display()
    plt.rcParams.update({'figure.figsize': figsize})

def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Apply labels, scales, limits, an optional legend and a grid to `axes`."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()


In [17]:
class Animator:
    """For plotting data in animation: re-draws all series on each `add` call."""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        legend = [] if legend is None else legend
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # A 1x1 grid gives a single Axes; wrap it so self.axes[0] works.
            self.axes = [self.axes]

        def _configure():
            set_axes(self.axes[0], xlabel, ylabel, xlim, ylim,
                     xscale, yscale, legend)

        self.config_axes = _configure
        self.X = None
        self.Y = None
        self.fmts = fmts

    def add(self, x, y):
        """Append one point per series (None entries are skipped) and redraw."""
        ys = y if hasattr(y, "__len__") else [y]
        n = len(ys)
        xs = x if hasattr(x, "__len__") else [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for series, (x_val, y_val) in enumerate(zip(xs, ys)):
            if x_val is None or y_val is None:
                continue
            self.X[series].append(x_val)
            self.Y[series].append(y_val)
        ax = self.axes[0]
        ax.cla()
        for x_line, y_line, fmt in zip(self.X, self.Y, self.fmts):
            ax.plot(x_line, y_line, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

In [18]:
class Accumulator:
    """For accumulating sums over `n` variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add `args[i]` (converted with float()) to the i-th running sum.

        Passing fewer values than slots leaves the remaining sums unchanged.
        The old zip-based version silently shrank `data` in that case.

        Raises:
            ValueError: if more values than slots are given (previously the
                extras were silently ignored).
        """
        if len(args) > len(self.data):
            raise ValueError(
                f'Accumulator has {len(self.data)} slots, got {len(args)} values')
        for i, value in enumerate(args):
            self.data[i] += float(value)

    def reset(self):
        """Zero every running sum."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [19]:
def accuracy(y_hat, y):
    """Return the number of correct predictions as a float.

    When `y_hat` has more than one column it is treated as per-class scores
    and reduced with argmax over axis 1. Otherwise it is used as-is. The
    predictions are cast to `y`'s dtype before comparing.
    """
    predictions = y_hat
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        predictions = argmax(y_hat, axis=1)
    hits = astype(predictions, y.dtype) == y
    return float(reduce_sum(astype(hits, y.dtype)))

In [20]:
def train_in_batch(net, X, y, loss, trainer, devices):
    """Run one optimization step on a minibatch and report its statistics.

    Args:
        net: Model (possibly DataParallel-wrapped) called as `net(X)`.
        X: A tensor, or a list/tuple of tensors (e.g. premise/hypothesis
            pairs); moved to `devices[0]`. A tuple is passed on as a list.
        y: Label tensor, moved to `devices[0]`.
        loss: Callable returning per-example losses (they are summed here).
        trainer: Optimizer stepped once.
        devices: Sequence of devices; only `devices[0]` is used directly.

    Returns:
        `(loss_sum, num_correct)`: the summed loss as a detached tensor and
        the float count returned by `accuracy`.
    """
    if isinstance(X, (list, tuple)):
        # Required for BERT fine-tuning / SNLI pairs
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    loss_sum = l.sum()  # computed once and reused for backward and reporting
    loss_sum.backward()
    trainer.step()
    # Detach so callers do not keep the autograd graph alive.
    return loss_sum.detach(), accuracy(pred, y)

In [21]:
def evaluate_accuracy_gpu_for_snli(net, data_iter, device):
    """Compute the accuracy for a model on a dataset using a GPU.

    Args:
        net: Model; if it is an `nn.Module` it is switched to eval mode, and
            a falsy `device` defaults to the device of its first parameter.
        data_iter: Iterable of `(X, y)` batches. X may be a tensor or a
            list/tuple of tensors, matching `train_in_batch`.
        device: Device to move each batch to.

    Returns:
        `(accuracy, num_predictions, num_correct, num_wrong)`. The counts are
        floats, as accumulated by `Accumulator`.
    """
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = Accumulator(2)

    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, (list, tuple)):
                # Required for BERT fine-tuning / SNLI pairs
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(accuracy(net(X), y), size(y))

    # test_accuracy, no. of predictions, No. of correct predictions, no. of wrong predictions
    return (metric[0] / metric[1], metric[1], metric[0], metric[1] - metric[0])

In [22]:
def train_and_evaluate(net, train_iter, test_iter, loss, trainer, num_epochs, devices):
    """Train a model with multiple GPUs, evaluating on `test_iter` after each epoch.

    Args:
        net: Model to train; wrapped in `nn.DataParallel` over `devices`.
        train_iter / test_iter: Batch iterables (`len(train_iter)` must work).
        loss: Per-example loss callable (see `train_in_batch`).
        trainer: Optimizer over `net`'s parameters.
        num_epochs: Number of passes over `train_iter`; must be >= 1.
        devices: Non-empty list of devices; `devices[0]` hosts the model.

    Returns:
        `{'training': {...}, 'test': {...}}`. Each sub-dict has the keys
        total_examples, passed, failed, accuracy and loss:
        - accuracy is an int percentage, truncated;
        - loss is a string with 3 decimals;
        - training figures cover the last epoch only;
        - test 'loss' is `1 - test accuracy`, not a value of `loss`.

    Raises:
        ValueError: if `num_epochs < 1` (there would be no test accuracy).
    """
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be at least 1, got {num_epochs}')
    print(" ---> started training model ... ")
    timer, num_batches = Timer(), len(train_iter)
    # Refresh the training curves ~5 times per epoch. max(1, ...) avoids a
    # modulo-by-zero when an epoch has fewer than 5 batches.
    plot_every = max(1, num_batches // 5)
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                        legend=['train loss', 'train accuracy', 'test accuracy'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_in_batch(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc, test_num_of_examples, test_passed, test_failed = \
            evaluate_accuracy_gpu_for_snli(net, test_iter, devices[0])
        animator.add(epoch + 1, (None, None, test_acc))

    test_loss = 1 - test_acc
    training, test, results = dict(), dict(), dict()
    training['total_examples'], test['total_examples'] = metric[2], test_num_of_examples
    training['passed'], test['passed'] = metric[1], test_passed
    training['failed'], test['failed'] = metric[2] - metric[1], test_failed
    training['accuracy'], test['accuracy'] = int((metric[1] * 100) / metric[3]), int(test_acc * 100)
    training['loss'], test['loss'] = f'{metric[0] / metric[2]:.3f}', f'{test_loss:.3f}'
    results['training'] = training
    results['test'] = test
    return results
    

In [23]:
class MaskLM(nn.Module):
    """The masked language model task of BERT.

    Gathers the encoder outputs at `pred_positions` and maps each one through
    an MLP (LazyLinear -> ReLU -> LayerNorm -> LazyLinear) to `vocab_size`
    logits.
    """
    def __init__(self, vocab_size, num_hiddens, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.LazyLinear(num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.LazyLinear(vocab_size))

    def forward(self, X, pred_positions):
        """Return logits of shape (batch_size, num_pred_positions, vocab_size).

        Args:
            X: Encoder output indexed as `X[batch, position]`; presumably
                (batch_size, seq_len, num_hiddens).
            pred_positions: (batch_size, num_pred_positions) integer positions.
        """
        num_pred_positions = pred_positions.shape[1]
        pred_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        # Create the batch index on X's device so indexing needs no
        # cross-device copy of the index tensor.
        batch_idx = torch.arange(0, batch_size, device=X.device)
        # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
        # `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        masked_X = X[batch_idx, pred_positions]
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

In [24]:
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis.

    Args:
        X: 3D tensor of scores.
        valid_lens: None (plain softmax), a 1D tensor with one length per
            batch entry, or a 2D tensor with one length per (batch, row).
            Positions at or beyond the valid length get ~0 probability.

    Returns:
        A new tensor of X's shape. The input `X` is not modified: the old
        `X[~mask] = value` wrote through the reshape view into the caller's
        tensor whenever X was contiguous.
    """
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len.to(X.device)[:, None]
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


In [25]:
class DotProductAttention(nn.Module):
    """Scaled dot product attention: softmax(Q K^T / sqrt(d)) V, with optional masking."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend `queries` over `keys`/`values`.

        Shapes (as documented by the original author):
            queries: (batch_size, no. of queries, d)
            keys: (batch_size, no. of key-value pairs, d)
            values: (batch_size, no. of key-value pairs, value dimension)
            valid_lens: (batch_size,) or (batch_size, no. of queries)
        """
        scale = math.sqrt(queries.shape[-1])
        # keys.transpose(1, 2) swaps the last two axes of the 3D keys tensor.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        # Kept on the module so the weights can be inspected after a call.
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [26]:
class NLIHyperParameters:
    """The base class of hyperparameters."""
    # The previous placeholder `save_hyperparameters` that did
    # `raise NotImplemented` was dead code: it was immediately redefined below
    # (and `NotImplemented` is not an exception class anyway).

    def save_hyperparameters(self, ignore=None):
        """Save the calling function's arguments as attributes of `self`.

        Reads the local variables of the caller's frame (typically a subclass
        `__init__`). It stores each one in `self.hparams` and as an attribute.
        Skipped are 'self', names in `ignore`, and names starting with '_'.

        Args:
            ignore: Optional iterable of argument names not to save.
        """
        frame = inspect.currentframe().f_back
        try:
            _, _, _, local_vars = inspect.getargvalues(frame)
        finally:
            # Drop the frame reference promptly to avoid a reference cycle.
            del frame
        skip = set(ignore or ()) | {'self'}
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in skip and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

In [27]:
class ProgressBoard(NLIHyperParameters):
    """The board that plots data points in animation"""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        # Every constructor argument becomes an attribute (self.xlabel, ...).
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        """Buffer point (x, y) for series `label`; every `every_n` points plot their mean.

        The averaged points are kept in `self.data[label]`. When
        `self.display` is true, all series are redrawn on the current
        matplotlib figure. A redundant `raise NotImplemented` stub of this
        method was removed; it was immediately redefined anyway.
        """
        # every_n < 1 would never flush the buffer, so clamp it.
        every_n = max(1, every_n)
        Point = collections.namedtuple('Point', ['x', 'y'])
        if not hasattr(self, 'raw_points'):
            self.raw_points = collections.OrderedDict()
            self.data = collections.OrderedDict()
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        if len(points) != every_n:
            return
        mean = lambda values: sum(values) / len(values)
        line.append(Point(mean([p.x for p in points]),
                          mean([p.y for p in points])))
        points.clear()
        if not self.display:
            return
        use_svg_display()
        if self.fig is None:
            self.fig = plt.figure(figsize=self.figsize)
        plt_lines, labels = [], []
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            plt_lines.append(plt.plot([p.x for p in v], [p.y for p in v],
                                      linestyle=ls, color=color)[0])
            labels.append(k)
        axes = self.axes if self.axes else plt.gca()
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        # The previous `self.xlabel = self.x` raised AttributeError (there is
        # no `x` attribute) whenever xlabel was unset; use an empty label.
        axes.set_xlabel(self.xlabel or '')
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)
        # `display` here is the module-level IPython `display` import, not the
        # `self.display` flag.
        display.display(self.fig)
        display.clear_output(wait=True)

In [28]:
class NLIModule(nn.Module, NLIHyperParameters):
    """The base class of models.

    Subclasses are expected to define `self.net` (used by `forward`) and
    implement `loss`. `self.trainer` must be assigned before `plot` is used,
    and `self.lr` before `configure_optimizers`; nothing in this class sets
    them.
    """
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        # Message fixed: the assertion fires when `net` is NOT defined.
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        # int(n) is 0 when there are fewer batches than plots per epoch; the
        # board would then never flush its buffer, so clamp to at least 1.
        self.board.draw(x, nlinumpy(nlito(value, cpu())),
                        ('train_' if train else 'val_') + key,
                        every_n=max(1, int(n)))

    def training_step(self, batch):
        """Compute the loss on `batch` (inputs..., label), plot it and return it."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        """Compute and plot the validation loss on `batch`."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        # A dead `raise NotImplementedError` stub that preceded this
        # definition (and was shadowed by it) has been removed.
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def apply_init(self, inputs, init=None):
        """Run one forward pass (materializing lazy layers), then apply `init` to `self.net`."""
        self.forward(*inputs)
        if init is not None:
            self.net.apply(init)

In [29]:
class MultiHeadAttention(NLIModule):
    """Multi-head attention.

    Projects queries/keys/values with LazyLinear layers, splits the hidden
    axis into `num_heads` heads, and runs `DotProductAttention` on all heads
    in one batched call. It then re-joins the heads and applies `W_o`.
    Extra `**kwargs` are accepted but not used.
    """
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Inputs: (batch_size, no. of queries or key-value pairs, features).

        valid_lens is (batch_size,) or (batch_size, no. of queries), or None.
        Returns (batch_size, no. of queries, num_hiddens).
        """
        q_heads = self.transpose_qkv(self.W_q(queries))
        k_heads = self.transpose_qkv(self.W_k(keys))
        v_heads = self.transpose_qkv(self.W_v(values))
        if valid_lens is not None:
            # Heads of one example are adjacent on axis 0 after
            # transpose_qkv, so repeat each length num_heads times in place.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        heads_out = self.attention(q_heads, k_heads, v_heads, valid_lens)
        return self.W_o(self.transpose_output(heads_out))

    def transpose_qkv(self, X):
        """(batch, steps, num_hiddens) -> (batch * num_heads, steps, num_hiddens / num_heads)."""
        batch, steps = X.shape[0], X.shape[1]
        per_head = X.reshape(batch, steps, self.num_heads, -1).permute(0, 2, 1, 3)
        return per_head.reshape(-1, per_head.shape[2], per_head.shape[3])

    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        grouped = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        grouped = grouped.permute(0, 2, 1, 3)
        return grouped.reshape(grouped.shape[0], grouped.shape[1], -1)


In [30]:
class AddNorm(nn.Module):
    """Residual connection (with dropout on the sublayer output) followed by LayerNorm."""
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        """Return LayerNorm(X + Dropout(Y)); X is the residual input, Y the sublayer output."""
        residual = X + self.dropout(Y)
        return self.ln(residual)

In [31]:
class PositionWiseFFN(nn.Module):
    """The positionwise feed-forward network: Linear -> ReLU -> Linear on the last axis."""
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

In [32]:
class TransformerEncoderBlock(nn.Module):
    """The Transformer encoder block: self-attention + AddNorm, then FFN + AddNorm."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_hiddens, num_heads,
                                            dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))


In [33]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Build a BERT input sequence and its segment ids.

    Returns `(['<cls>', *tokens_a, '<sep>'(, *tokens_b, '<sep>')], segments)`.
    Segment id 0 marks '<cls>', tokens_a and the first '<sep>'; id 1 marks
    tokens_b and its trailing '<sep>'.
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * len(tokens)
    if tokens_b is not None:
        second = tokens_b + ['<sep>']
        tokens += second
        segments += [1] * len(second)
    return tokens, segments

In [34]:
class BERTEncoder(nn.Module):
    """BERT encoder: token + segment + learned positional embeddings, then Transformer blocks."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout, max_len=1000, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        # nn.Sequential(*blocks) names the children '0', '1', ..., the same
        # state_dict keys as add_module(f'{i}', ...).
        self.blks = nn.Sequential(*[
            TransformerEncoderBlock(num_hiddens, ffn_num_hiddens, num_heads,
                                    dropout, True)
            for _ in range(num_blks)])
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Return the encoded sequence; its shape matches the summed embeddings."""
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        seq_len = X.shape[1]
        X = X + self.pos_embedding[:, :seq_len, :]
        for block in self.blks:
            X = block(X, valid_lens)
        return X

In [35]:
class NextSentencePred(nn.Module):
    """The next sentence prediction task of BERT: a 2-way linear classifier."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Input width is inferred lazily on the first forward pass.
        self.output = nn.LazyLinear(2)

    def forward(self, X):
        """Map `X` of shape (batch size, `num_hiddens`) to 2 logits per example."""
        logits = self.output(X)
        return logits

In [36]:
class BERTModel(nn.Module):
    """The BERT model: encoder plus masked-LM and next-sentence-prediction heads.

    Args mirror `BERTEncoder`; `max_len` bounds the positional embeddings.
    """
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, max_len=1000):
        super().__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,
                                   num_heads, num_blks, dropout,
                                   max_len=max_len)
        # Hidden layer of the NSP classifier, applied to the '<cls>' state.
        self.hidden = nn.Sequential(nn.LazyLinear(num_hiddens), nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens)
        self.nsp = NextSentencePred()

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        """Return (encoded_X, mlm_Y_hat, nsp_Y_hat).

        `mlm_Y_hat` is None unless `pred_positions` is given.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = (None if pred_positions is None
                     else self.mlm(encoded_X, pred_positions))
        # Position 0 holds the '<cls>' token, whose state summarises the pair.
        cls_state = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_state))
        return encoded_X, mlm_Y_hat, nsp_Y_hat


In [37]:
class SNLIBERTDataset(torch.utils.data.Dataset):
    """SNLI premise/hypothesis pairs encoded as fixed-length BERT inputs.

    Each example becomes '<cls>' premise '<sep>' hypothesis '<sep>', truncated
    so the whole sequence fits in `max_len`, then right-padded with '<pad>'.

    Args:
        dataset: Sequence whose first two items are lists of premise and
            hypothesis sentences and whose third item holds the labels.
        max_len: Fixed length of every encoded sequence.
        vocab: Token-to-id mapping supporting list indexing and containing
            '<pad>'. Required in practice despite the None default.
        is_train: Only selects which split name is printed.
    """
    def __init__(self, dataset, max_len, vocab=None, is_train=True):
        # Sentences are lowercased before tokenization.
        premise_tokens, hypothesis_tokens = (
            tokenize([s.lower() for s in sentences]) for sentences in dataset[:2])
        all_premise_hypothesis_tokens = [
            [p_tokens, h_tokens]
            for p_tokens, h_tokens in zip(premise_tokens, hypothesis_tokens)]
        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        split = 'training' if is_train else 'test'
        print(' ---> read ' + str(len(self.all_token_ids)) + ' examples from ' + split + ' dataset.')

    def _preprocess(self, all_premise_hypothesis_tokens):
        """Encode every pair in parallel and stack the results.

        Returns:
            (token_ids, segments, valid_lens): two long tensors with one
            `max_len`-long row per example, and the unpadded lengths.
        """
        # The context manager shuts the 4 worker processes down once `map`
        # has returned, instead of leaving them alive.
        with multiprocessing.Pool(4) as pool:
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids, all_segments, valid_lens = [], [], []
        for token_ids, segments, valid_len in out:
            all_token_ids.append(token_ids)
            all_segments.append(segments)
            valid_lens.append(valid_len)
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        """Encode one [premise tokens, hypothesis tokens] pair (runs in a worker)."""
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = get_tokens_and_segments(p_tokens, h_tokens)
        padding = self.max_len - len(tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * padding
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        """Pop tokens in place from the longer list until the pair fits."""
        # Reserve slots for the '<cls>', '<sep>' and '<sep>' tokens of the
        # BERT input.
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)


In [38]:
class BERTClassifier(nn.Module):
    """3-way SNLI classifier on top of a BERT model's encoder and hidden layer.

    `bert` must expose `encoder` and `hidden`; both are shared (not copied),
    so fine-tuning updates them in place.
    """
    def __init__(self, bert):
        super().__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.LazyLinear(3)

    def forward(self, inputs):
        """`inputs` is the (tokens, segments, valid_lens) triple; returns 3 logits per example."""
        tokens_X, segments_X, valid_lens_x = inputs
        # Classify from the '<cls>' token's state at position 0.
        cls_state = self.encoder(tokens_X, segments_X, valid_lens_x)[:, 0, :]
        return self.output(self.hidden(cls_state))

In [39]:
def predict_bert(net, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis.

    Args:
        net: A classifier such as `BERTClassifier`, called with a
            (token_ids, segments, valid_lens) triple and returning one row of
            class scores per example.
        vocab: Token-to-id mapping supporting list indexing.
        premise: List of premise tokens.
        hypothesis: List of hypothesis tokens.

    Returns:
        'entailment' for class 0, 'contradiction' for class 1, otherwise
        'neutral'.
    """
    # Disable dropout for inference; note `net` is left in eval mode.
    net.eval()
    device = next(iter(net.parameters())).device
    # SNLIBERTDataset lowercases every sentence before training, so apply the
    # same normalisation here; otherwise capitalised words don't match.
    premise_tokens = [token.lower() for token in premise]
    hypothesis_tokens = [token.lower() for token in hypothesis]
    tokens, segments = get_tokens_and_segments(premise_tokens, hypothesis_tokens)
    token_ids = torch.tensor(vocab[tokens], device=device).unsqueeze(0)
    segments = torch.tensor(segments, device=device).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=device).unsqueeze(0)
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        scores = net((token_ids, segments, valid_len))
    label = torch.argmax(scores, dim=1).item()
    if label == 0:
        return 'entailment'
    if label == 1:
        return 'contradiction'
    return 'neutral'

In [40]:
def format_data(data, mode=None):
    """Flatten a results dict into table rows.

    With `mode` (e.g. 'training' or 'test'), return that split's
    [total_examples, passed, failed, accuracy, loss]. Without it, return one
    row per split ("Training", then "Test"), each prefixed with the split
    name and with the accuracy rendered as a percentage string.
    """
    fields = ('total_examples', 'passed', 'failed', 'accuracy', 'loss')
    if mode:
        section = data[mode]
        return [section[field] for field in fields]
    sections = [("Training", data['training']), ("Test", data['test'])]
    rows = []
    for label, section in sections:
        values = [section[field] for field in fields]
        values[3] = f"{values[3]}%"
        rows.append([label] + values)
    return rows

In [41]:
def format_kb(x, pos):
    """Formats a y-axis tick value in thousands, e.g. 1500 -> '1.5 K'.

    The (value, position) signature is what `matplotlib.ticker.FuncFormatter`
    passes; `pos` is unused.
    """
    # True division keeps sub-thousand ticks distinct (floor division would
    # render 0, 500 and 999 all as '0.0 K'); ':g' drops a trailing '.0'.
    return f'{x / 1000:g} K'

In [42]:
def display_results(data, title="notitle"):
    """Show one model's training/test results as a table, bar charts and pies.

    The top row is a summary table built from `format_data(data)`. The bottom
    row has passed/failed bar charts for training and test, then a
    passed/failed pie chart for each. The figure is saved to
    ``images/{title}_results.png`` (creating ``images/`` if needed) and shown.

    Args:
        data: Results dict with 'training' and 'test' entries, each holding
            'total_examples', 'passed', 'failed', 'accuracy' and 'loss'.
        title: Table title; also used in the output filename.
    """
    columns = ["Mode", "Total examples", "Passed", "Failed", "Accuracy", "Loss"]
    subgroups = ["Passed", "Failed"]
    df = pd.DataFrame(format_data(data), columns=columns)
    # [passed, failed] per split (indices 1 and 2 of a per-mode row).
    counts = {
        "Training": format_data(data, 'training')[1:3],
        "Test": format_data(data, 'test')[1:3],
    }
    # Highlight the "Mode" column of both body rows.
    colors = [["#E8E190"] + ["w"] * (len(columns) - 1) for _ in counts]

    fig = plt.figure(figsize=(10, 5))
    fig.patch.set_visible(False)

    # Top row: summary table spanning the full width, axes hidden.
    ax = plt.subplot2grid((2, 2), (0, 0), colspan=2)
    ax.axis('off')
    ax.axis('tight')
    table = ax.table(cellText=df.values, colLabels=df.columns, loc='center',
                     cellLoc='center', cellColours=colors)
    table.auto_set_font_size(False)
    for (row, col), cell in table.get_celld().items():
        if row == 0 or col == -1:
            # Header row (and row labels, if any).
            cell.set_text_props(fontproperties=FontProperties(weight='bold'))
            cell.set_facecolor('#CCD9C7')
        elif col == 0:
            cell.set_text_props(fontproperties=FontProperties(weight='bold'))
        else:
            cell.set_fontsize(10)
    ax.set_title(title)

    # Bottom row: four 2-column panels on a 10-column grid, from column 1.
    grid = (2, 10)
    bar_width = .20
    x_positions = np.arange(1)
    for j, split in enumerate(counts):
        ax = plt.subplot2grid(grid, (1, 2 * j + 1), colspan=2)
        for i, subgroup in enumerate(subgroups):
            ax.bar(x_positions + i * bar_width, counts[split][i], bar_width, label=subgroup)
        ax.legend(fontsize="x-small")
        # Two centre-aligned bars at x and x + w: the group centre is x + w/2.
        ax.set_xticks(x_positions + bar_width / 2)
        ax.set_xticklabels([split])
        ax.yaxis.set_major_formatter(tkr.FuncFormatter(format_kb))
    for j, split in enumerate(counts, start=len(counts)):
        ax = plt.subplot2grid(grid, (1, 2 * j + 1), colspan=2)
        ax.pie(counts[split], labels=["passed", "failed"], autopct='%1.1f%%', startangle=90)
        ax.set_title(split)

    # Adjust layout to prevent overlapping titles.
    plt.tight_layout()
    os.makedirs('images', exist_ok=True)
    plt.savefig(f'images/{title}_results.png')
    plt.show()
    plt.close(fig)


In [43]:
def compare_results(glove, bert_base, bert_small, mode="test"):
    """Compare one split's results across the GloVe, BERT.Base and BERT.Small models.

    Draws a summary table on top. Below it are a grouped passed/failed bar
    chart and two pie charts showing each model's share of passed and failed
    examples. The figure is saved under ``images/`` (created if needed) and
    shown.

    Args:
        glove, bert_base, bert_small: Results dicts accepted by `format_data`.
        mode: Split to compare, e.g. "test" or "training".
    """
    title = f'Comparison of {mode} results across GloVe, BERT.Base, BERT.Small embeddings'
    embeddings = ["GloVe", "BERT.Base", "BERT.Small"]
    rows = [format_data(results, mode) for results in (glove, bert_base, bert_small)]
    # Per-model [passed, failed] counts (indices 1 and 2 of a format_data row).
    counts = [row[1:3] for row in rows]

    columns = ["Embeddings Type", "Total examples", "Passed", "Failed", "Accuracy", "Loss"]
    labels = ["Passed", "Failed"]
    # Highlight the first column of every body row.
    cell_colours = [["#E8E190"] + ["w"] * (len(columns) - 1) for _ in embeddings]
    bar_width = .20
    x_positions = np.arange(len(labels))

    fig = plt.figure(figsize=(10, 5))
    fig.patch.set_visible(False)

    # Top row: summary table spanning the full width, axes hidden.
    ax = plt.subplot2grid((2, 2), (0, 0), colspan=2)
    df = pd.DataFrame([[name] + row for name, row in zip(embeddings, rows)], columns=columns)
    ax.axis('off')
    ax.axis('tight')
    table = ax.table(cellText=df.values, colLabels=df.columns, loc='center',
                     cellLoc='center', cellColours=cell_colours)
    table.auto_set_font_size(False)
    ax.set_title(title)
    for (row, col), cell in table.get_celld().items():
        if row == 0 or col == -1:
            # Header row (and row labels, if any).
            cell.set_text_props(fontproperties=FontProperties(weight='bold'))
            cell.set_facecolor('#CCD9C7')
        elif col == 0:
            cell.set_text_props(fontproperties=FontProperties(weight='bold'))
        else:
            cell.set_fontsize(10)

    # Bottom row: three 2-column panels on an 8-column grid, from column 1.
    grid = (2, 8)
    ax = plt.subplot2grid(grid, (1, 1), colspan=2)
    for j, (name, model_counts) in enumerate(zip(embeddings, counts)):
        ax.bar(x_positions + j * bar_width, model_counts, bar_width, label=name)
    ax.legend(fontsize="x-small")
    # Bars are centre-aligned at x, x + w, x + 2w, so the tick belongs halfway
    # between the first and last bar of each group.
    ax.set_xticks(x_positions + bar_width * (len(embeddings) - 1) / 2)
    ax.set_xticklabels(labels)
    ax.yaxis.set_major_formatter(tkr.FuncFormatter(format_kb))

    # One pie per outcome: each model's share of the passed / failed examples.
    for k, label in enumerate(labels):
        ax = plt.subplot2grid(grid, (1, 3 + 2 * k), colspan=2)
        ax.pie([model_counts[k] for model_counts in counts], labels=embeddings,
               autopct='%1.1f%%', startangle=90)
        ax.set_title(label)

    plt.tight_layout()
    os.makedirs('images', exist_ok=True)
    plt.savefig(f'images/{title}.png')
    plt.show()
    plt.close(fig)

In [44]:
def predict_glove(net, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis.

    Args:
        net: Model called with ``[premises, hypotheses]``, each a (1, seq_len)
            tensor of token ids, returning one row of class scores.
        vocab: Token-to-id mapping supporting list indexing.
        premise: List of premise tokens.
        hypothesis: List of hypothesis tokens.

    Returns:
        'entailment' for class 0, 'contradiction' for class 1, otherwise
        'neutral'.
    """
    # Disable dropout for inference; note `net` is left in eval mode.
    net.eval()
    device = next(iter(net.parameters())).device
    premise = torch.tensor(vocab[premise], device=device).reshape((1, -1))
    hypothesis = torch.tensor(vocab[hypothesis], device=device).reshape((1, -1))
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        scores = net([premise, hypothesis])
    label = torch.argmax(scores, dim=1).item()
    if label == 0:
        return 'entailment'
    if label == 1:
        return 'contradiction'
    return 'neutral'

In [45]:
def unit_testing(glove_net, glove_vocab, bert_small_net, bert_small_vocab, bert_base_net, bert_base_vocab):
    """Run hand-written premise/hypothesis cases through all three models and print the results as JSON.

    Each sentence is tokenized once from its displayed text (words and
    punctuation marks become separate tokens). This guarantees the models see
    exactly the sentence that is printed.
    """
    # (premise, hypothesis, expected label)
    cases = [
        ('A man is running the coding example.', 'The man is sleeping.', 'contradiction'),
        ('I do need sleep.', 'I am tired.', 'entailment'),
        ('The musicians are performing for us.', 'The musicians are famous.', 'neutral'),
    ]

    def _tokens(sentence):
        # e.g. 'I am tired.' -> ['I', 'am', 'tired', '.']
        return re.findall(r"\w+|[^\w\s]", sentence)

    unit_test_data = []
    for premise, hypothesis, expected in cases:
        p_tokens, h_tokens = _tokens(premise), _tokens(hypothesis)
        unit_test_data.append({
            'Premise:': premise,
            'Hypothesis:': hypothesis,
            'Expected Result': expected,
            'Glove Embeddings Actual Result': predict_glove(glove_net, glove_vocab, p_tokens, h_tokens),
            'Bert.Small Embeddings Actual Result': predict_bert(bert_small_net, bert_small_vocab, p_tokens, h_tokens),
            'Bert.Base Embeddings Actual Result': predict_bert(bert_base_net, bert_base_vocab, p_tokens, h_tokens),
        })
    formatted_unit_test_data = json.dumps(unit_test_data, indent=4)
    print("unit testing results")
    print(formatted_unit_test_data)