## We often want a model to use only part of an input sequence. For next-token prediction on a sentence, we "cover" (mask) the later positions; otherwise the model could trivially copy the token to be predicted straight from its input. TODO: expand on the training algorithm that necessitates this, why -inf is used to achieve it, and the comparison trick for building masks. For now, straight to the point: train a masked model.

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
from torchtext.data.utils import get_tokenizer
from torch.nn import TransformerEncoder, TransformerEncoderLayer


In [3]:
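# What does the data look like?

# First, some data wrangling.
# Note: torchtext.data.Field is a legacy API; newer torchtext releases
# moved it to torchtext.legacy and later removed it.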

TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
def batchify(data, bsz):
    data = TEXT.numericalize([data.examples[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 3
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)

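# Parameters and utility function for training
bptt = 5  # window length (number of time steps per batch)
def get_batch(source, i):
    # Take up to bptt rows starting at i; leave one row for the shifted target.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    # The target is the same window shifted by one step, flattened for the loss.
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target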

In [4]:
data, targets = get_batch(train_data, 0)

In [6]:
data  # shape (bptt, bsz) = (5, 3): rows are time steps, columns are batch streams

tensor([[   3,  158,   27],
        [  12,    9,  123],
        [3852,  296, 4173],
        [3872, 8105,  243],
        [ 884,    4,    6]])

In [9]:
targets.size()

torch.Size([15])
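The relationship between `data` and `targets` is easier to see on a toy tensor. The sketch below reuses the `get_batch` logic on made-up token ids (the `source` values and `bptt = 3` are assumptions for illustration, not the WikiText2 data) and shows that the target is the same window shifted forward by one time step.

```python
import torch

# Toy stand-in for a batchified corpus: 7 time steps, 2 parallel streams.
# Built the same way as batchify: view(bsz, -1).t().contiguous().
source = torch.arange(14).view(2, 7).t().contiguous()  # shape (7, 2)

bptt = 3  # hypothetical window length for this sketch

def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

data, target = get_batch(source, 0)

# Reshaping the flat target back to data's shape makes the shift visible:
# every entry of `shifted` is the token that follows the matching entry of `data`.
shifted = target.view_as(data)
print(data)
print(shifted)
```

Because the toy ids are consecutive within each stream, each entry of `shifted` is the matching entry of `data` plus one.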

In [1]:
# The triangular shape of the mask can look confusing at first,
# but it is simple once we go back to first principles.
'''
Masking: an artifact of using matrix notation to capture per-position vector sums.
Take a sequence of length 2; after all, attention maps a sequence to a sequence:
[x1, x2] -> attn -> [y1, y2], which should approximate the next tokens [x2, x3].
y1 comes out first, so it may only see x1; y2 may see both x1 and x2.
So some information has to be masked out. Looking closer:
y1 ~ dot(q1, k1)_sm * v1 = a*v1
y2 ~ dot(q2, k1)_sm * v1 + dot(q2, k2)_sm * v2 = b*v1 + c*v2
where ~ means proportional to, and _sm marks a score after softmax,
normalized over the _sm terms in the same sum.
In matrix form:
Y = M*V, where
M =
[[a, 0]
 [b, c]]
Y = [y1, y2]^T
V = [v1, v2]^T
M is the row-wise softmax of the raw scores s_ij = dot(q_i, k_j),
so before the softmax it must look like
[[s11, -inf]
 [s21, s22]]
because exp(-inf) = 0 yields the 0 in M while each row still sums to 1.
This shows where the mask comes in: right before the softmax step.
After that, it is business as usual.
'''
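The length-2 argument above can be checked numerically. This is a minimal sketch with made-up scores and values (the numbers in `scores` and `V` are assumptions chosen only for readability): setting the blocked score to `-inf` before the row-wise softmax gives a weight of exactly zero, and the remaining weights in that row still sum to one.

```python
import torch
import torch.nn.functional as F

# Made-up raw scores s[i, j] = dot(q_i, k_j) for a length-2 sequence.
scores = torch.tensor([[1.0, 2.0],
                       [0.5, 1.5]])

# Position 1 (row 0) must not see position 2 (column 1).
masked = scores.clone()
masked[0, 1] = float('-inf')

# Row-wise softmax: exp(-inf) is exactly 0, so row 0 puts all weight on v1,
# while row 1 is an ordinary softmax over both positions.
M = F.softmax(masked, dim=-1)

# Made-up values; y1 depends only on v1, y2 mixes v1 and v2.
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0]])
Y = M @ V
print(M)
print(Y)
```

Zeroing the entry after the softmax instead would leave row 0 summing to less than one, which is why the mask is applied to the scores rather than to the weights.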

In [2]:
def attention(Q, K, V, mask=None):
    '''Scaled dot-product attention with an optional boolean mask.

    mask: bool tensor broadcastable to the score matrix; True marks positions to hide.
    '''
    # Swap the last two dims of K so this works with or without leading batch dims.
    scores = torch.matmul(Q, K.transpose(-2, -1))
    d_k = K.size(-1)
    if mask is not None:  # `if mask:` raises for tensors with more than one element
        # masked_fill broadcasts the mask over any batch dims and avoids in-place indexing.
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores / math.sqrt(d_k), dim=-1)
    return torch.matmul(weights, V)
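One way to build the triangular mask for such a function is the comparison trick mentioned in the opening TODO: compare a row of key positions against a column of query positions and let broadcasting fill in every pair. The sketch below assumes the convention that `True` means "hide this position" and uses a hypothetical `seq_len` of 4; `torch.triu` is shown only as a cross-check.

```python
import torch

seq_len = 4  # hypothetical sequence length

# Row index i = query position, column index j = key position.
positions = torch.arange(seq_len)

# Broadcasting a (1, seq_len) row against a (seq_len, 1) column compares
# every (i, j) pair. True marks "key j lies in the future of query i".
causal_mask = positions[None, :] > positions[:, None]

# The same mask from the upper-triangle helper, for comparison.
triu_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

print(causal_mask)
```

The first row blocks everything except the first position, and the last row blocks nothing, which is the triangular shape discussed earlier.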


In [None]:
# We need an input batch with the right dimensions, then verify that the output is correct.
# Let's engineer an example.
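One possible way to engineer such a check is sketched below. It assumes a `(batch, seq, dim)` layout (unlike the `(seq, batch)` layout returned by `get_batch`), random inputs with made-up sizes, and a hypothetical helper `masked_attention` that restates scaled dot-product attention so the block runs on its own. The idea is that with a causal mask, changing the last token must leave every earlier output untouched.

```python
import math
import torch
import torch.nn.functional as F

def masked_attention(Q, K, V, mask=None):
    # Hypothetical stand-alone restatement of scaled dot-product attention;
    # True in `mask` means "block this key position".
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    return torch.matmul(F.softmax(scores, dim=-1), V)

torch.manual_seed(0)
batch, seq_len, dim = 2, 4, 8  # made-up sizes
x = torch.randn(batch, seq_len, dim)

# Causal mask: True above the diagonal, broadcast over the batch dimension.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

out = masked_attention(x, x, x, mask)

# Perturb only the last position and recompute.
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0
out_perturbed = masked_attention(x_perturbed, x_perturbed, x_perturbed, mask)

# Outputs before the last position should not have changed.
prefix_unchanged = torch.allclose(out[:, :-1], out_perturbed[:, :-1])
# The first output can only attend to itself, so it should equal its own value vector.
first_is_v1 = torch.allclose(out[:, 0], x[:, 0])
print(prefix_unchanged, first_is_v1)
```

Running the same comparison with `mask=None` is a useful contrast: without the mask, the earlier outputs are free to depend on the perturbed last token.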
