In [1]:
# Tools 
import os
import time
import shutil
import random
from typing import Tuple
from argparse import Namespace
import matplotlib.pyplot as plt

# Preprocessing
import nltk
from nltk.corpus import stopwords
from nltk import ngrams
from nltk.tokenize import TweetTokenizer
from nltk import FreqDist
import pandas as pd
import numpy as np

# Pytorch 
import torch 
from bitnet import BitLinear # Binary layer 
from torchtext.datasets import PennTreebank
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.functional as F

# Scikit learn
from sklearn.metrics import accuracy_score


In [2]:
# Load the Penn Treebank language-modelling corpus. torchtext caches the raw
# files under `root`; one iterable is returned per requested split.
custom_path = './'
train_iter, val_iter, test_iter = PennTreebank(root=custom_path, split=('train', 'valid', 'test'))

In [3]:
class Attention(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        emb_size: size of the last dimension of the input ``x``.
        head_size: size of the key/query/value projections, i.e. the last
            dimension of the output.
        max_seq_len: longest sequence the pre-built causal mask supports.
            If ``None``, the mask is built in ``forward`` for the actual
            sequence length instead.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, head_size, max_seq_len=None, dropout=0.2):
        super().__init__()
        self.key = nn.Linear(emb_size, head_size, bias=False)
        self.query = nn.Linear(emb_size, head_size, bias=False)
        self.value = nn.Linear(emb_size, head_size, bias=False)
        # Lower-triangular ones: row i has ones in columns 0..i, so position i
        # may only attend to itself and earlier positions. A buffer moves with
        # the module on .to(device) but is not a trainable parameter.
        tril = None if max_seq_len is None else torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer('tril', tril)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` (batch, time-step, emb_size) -> (batch, time-step, head_size).

        Raises:
            ValueError: if the sequence is longer than the pre-built mask.
        """
        T = x.shape[1]
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        scores = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        if self.tril is not None:
            if T > self.tril.size(0):
                raise ValueError(
                    f"sequence length {T} exceeds max_seq_len={self.tril.size(0)}"
                )
            mask = self.tril[:T, :T]
        else:
            mask = torch.tril(torch.ones(T, T, device=x.device))
        # Hide future positions: -inf becomes a zero weight after softmax.
        scores = scores.masked_fill(mask == 0, float('-inf'))  # (B, T, T)
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        # Aggregate the values with the attention weights.
        v = self.value(x)   # (B, T, hs)
        return weights @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)


class MultiHeadAttention(nn.Module):
    """Several causal attention heads in parallel, concatenated and projected back to ``emb_size``.

    Args:
        n_heads: number of parallel ``Attention`` heads.
        head_size: output size of each head.
        emb_size: model width (input and output size of this layer).
        max_seq_len: forwarded to every ``Attention`` head (``None`` -> mask built per call).
        dropout: dropout probability for the heads and the output projection.
    """

    def __init__(self, n_heads, head_size, emb_size, max_seq_len=None, dropout=0.2):
        super().__init__()
        self.heads = nn.ModuleList(
            [Attention(emb_size, head_size, max_seq_len, dropout) for _ in range(n_heads)]
        )
        # Output projection from the concatenated heads back to the model width.
        self.proy = nn.Linear(n_heads * head_size, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, n_heads * head_size)
        return self.dropout(self.proy(out))                  # (B, T, emb_size)


class FeedForward(nn.Module):
    """Position-wise feed-forward layer: Linear (4x expansion) -> ReLU -> Linear -> Dropout.

    Args:
        emb_size: input and output size.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, emb_size, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size),
            nn.ReLU(),
            nn.Linear(4 * emb_size, emb_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class block(nn.Module):
    """Transformer block: residual self-attention followed by a residual feed-forward layer.

    Args:
        emb_size: model width.
        n_heads: number of attention heads.
        head_size: output size of each attention head.
        max_seq_len: longest supported sequence (``None`` -> causal mask built per call).
        dropout: dropout probability used by the sub-layers.
    """

    def __init__(self, emb_size, n_heads, head_size, max_seq_len=None, dropout=0.2):
        super().__init__()
        self.mha = MultiHeadAttention(n_heads, head_size, emb_size, max_seq_len, dropout)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
        self.ff = FeedForward(emb_size, dropout)

    def forward(self, x):
        # LayerNorm is applied to each sub-layer's input (pre-norm); the
        # sub-layer output is added back onto the residual stream.
        x = x + self.mha(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class Transformer(nn.Module):
    """Causal Transformer that maps token ids to per-position vocabulary logits.

    ``args`` must provide ``vocab_size``, ``emb_size``, ``max_seq_len``,
    ``num_layers``, ``num_heads`` and ``head_size``. It may optionally provide
    ``dropout`` (default 0.2).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.num_layers = args.num_layers
        self.num_heads = args.num_heads
        self.max_seq_len = args.max_seq_len
        dropout = getattr(args, 'dropout', 0.2)
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)    # token embeddings
        self.pos = nn.Embedding(args.max_seq_len, args.emb_size)   # learned position embeddings
        self.ln_f = nn.LayerNorm(args.emb_size)  # final layer norm
        self.lm_head = nn.Linear(args.emb_size, args.vocab_size)
        self.layers = nn.Sequential(*[
            block(args.emb_size, self.num_heads, args.head_size, args.max_seq_len, dropout)
            for _ in range(self.num_layers)
        ])

        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if ``T`` exceeds ``max_seq_len`` (no position embedding exists).
        """
        _, T = idx.shape  # idx is a chunk of sequences (B, T)
        if T > self.max_seq_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len={self.max_seq_len}")
        # Build positions on idx's own device so the model keeps working after
        # .to(device), independent of args.device.
        positions = torch.arange(T, device=idx.device)
        x = self.emb(idx) + self.pos(positions)  # (B, T, emb_size)
        x = self.layers(x)
        x = self.ln_f(x)
        return self.lm_head(x)
