In [64]:
import regex
import json
import torch
from torch import nn
import torch.nn.functional as F

import random, math
import tqdm

## Self-Attention
- A weighted average over all the input vectors

$$
y_{i} = \sum_{j} w_{ij} x_{j}
$$

$$
w'_{{i}{j}} = {x_{i}}^Tx_{j}
$$

$$
w_{ij} = \frac{\exp w'_{ij}}{\sum_{k} \exp w'_{ik}}
$$

![self_attention.svg](attachment:self_attention.svg)

In [56]:
# Instantiate PyTorch's built-in nn.Transformer with its default arguments.
# NOTE(review): no later cell visible in this notebook uses `transformer_model`.
transformer_model = nn.Transformer()

In [57]:
# Example source/target sentence pair (English and French).
# NOTE(review): neither string is used by the later cells visible here.
msg_src = "this is cool"
msg_tgt = "ça c'est cool"

In [85]:

class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention ("wide" variant).

    Keys, queries and values are each projected from `emb` to `emb * heads`
    features, so every head attends over a full `emb`-sized vector. The
    per-head outputs are concatenated and projected back down to `emb`.
    """

    def __init__(self, emb, heads=8, mask=False):
        """
        :param emb: size of the input (and output) embedding dimension
        :param heads: number of attention heads
        :param mask: falsy for unmasked attention. Any truthy value hides all
            later positions from each query (the diagonal stays visible). The
            special value 'first' additionally zeroes the attention row of the
            first position after the softmax.
        """

        super().__init__()

        self.emb = emb
        self.heads = heads
        self.mask = mask

        self.tokeys = nn.Linear(emb, emb * heads, bias=False)
        self.toqueries = nn.Linear(emb, emb * heads, bias=False)
        self.tovalues = nn.Linear(emb, emb * heads, bias=False)

        self.unifyheads = nn.Linear(heads * emb, emb)

    def forward(self, x):
        """
        :param x: tensor of size (batch, time, emb)
        :return: tensor of size (batch, time, emb)
        """

        b, t, e = x.size()
        h = self.heads
        assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

        keys    = self.tokeys(x)   .view(b, t, h, e)
        queries = self.toqueries(x).view(b, t, h, e)
        values  = self.tovalues(x) .view(b, t, h, e)

        # compute scaled dot-product self-attention

        # - fold heads into the batch dimension
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
        values = values.transpose(1, 2).contiguous().view(b * h, t, e)

        # - get dot product of queries and keys, and scale
        dot = torch.bmm(queries, keys.transpose(1, 2))
        dot = dot / math.sqrt(e) # dot contains b*h  t-by-t matrices with raw self-attention logits

        assert dot.size() == (b*h, t, t), f'Matrix has size {dot.size()}, expected {(b*h, t, t)}.'

        if self.mask:
            # Mask out everything strictly above the diagonal, so that position i
            # can only attend to positions <= i. Done inline: the notebook never
            # defined the `mask_` helper the original code called.
            rows, cols = torch.triu_indices(t, t, offset=1, device=dot.device)
            dot[:, rows, cols] = float('-inf')

        dot = F.softmax(dot, dim=2) # dot now has row-wise self-attention probabilities

        if self.mask == 'first':
            # - With the diagonal left unmasked, the first row is [1, 0, ..., 0] after
            #   the softmax. In 'first' mode that row is zeroed by hand, so the first
            #   position attends to nothing.
            dot = dot.clone()
            dot[:, :1, :] = 0.0

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, e)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)

        return self.unifyheads(out)

In [86]:
class TransformerBlock(nn.Module):
    """
    One transformer layer: self-attention followed by a position-wise
    feed-forward network. Each sub-layer is wrapped in a residual connection,
    then layer normalisation, then dropout.
    """

    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0):
        """
        :param emb: embedding dimension of the inputs and outputs
        :param heads: number of attention heads
        :param mask: mask setting forwarded to the attention layer
        :param seq_length: accepted for interface compatibility; not used here
        :param ff_hidden_mult: hidden width of the feed-forward net, as a multiple of `emb`
        :param dropout: dropout probability applied after each normalisation
        """
        super().__init__()

        self.attention = SelfAttention(emb, heads=heads, mask=mask)
        self.mask = mask

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        hidden = ff_hidden_mult * emb
        self.ff = nn.Sequential(nn.Linear(emb, hidden), nn.ReLU(), nn.Linear(hidden, emb))

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: input tensor whose last dimension is `emb`
        :return: tensor of the same size as `x`
        """
        # attention sub-layer: residual -> norm -> dropout
        x = self.do(self.norm1(self.attention(x) + x))

        # feed-forward sub-layer: residual -> norm -> dropout
        return self.do(self.norm2(self.ff(x) + x))

In [87]:
class CTransformer(nn.Module):
    """
    Transformer for sequence classification: token + position embeddings, a
    stack of unmasked transformer blocks, pooling over time and a linear
    projection to class log-probabilities.
    """

    def __init__(self, emb, heads, depth, seq_length, num_tokens, num_classes, max_pool=True, dropout=0.0):
        """
        :param emb: embedding dimension
        :param heads: number of attention heads per block
        :param depth: number of transformer blocks
        :param seq_length: number of position embeddings, i.e. the longest
            sequence `forward` can embed
        :param num_tokens: number of rows in the token embedding (vocabulary size)
        :param num_classes: number of output classes
        :param max_pool: if True, max-pool over the time dimension; otherwise mean-pool
        :param dropout: dropout probability
        """
        super().__init__()
        self.num_tokens, self.max_pool = num_tokens, max_pool
        self.token_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=num_tokens)
        self.pos_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=seq_length)
        tblocks = []
        for i in range(depth):
            tblocks.append(
                TransformerBlock(emb=emb, heads=heads, seq_length=seq_length, mask=False, dropout=dropout))
        self.tblocks = nn.Sequential(*tblocks)
        self.toprobs = nn.Linear(emb, num_classes)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: integer tensor of token indices, size (batch, time)
        :return: log-probabilities over the classes, size (batch, num_classes)
        """
        tokens = self.token_embedding(x)
        b, t, e = tokens.size()
        # Build the position indices on the same device as the embedded tokens
        # instead of relying on a notebook-level `device` global.
        positions = self.pos_embedding(torch.arange(t, device=tokens.device))[None, :, :].expand(b, t, e)
        x = tokens + positions
        x = self.do(x)
        x = self.tblocks(x)
        x = x.max(dim=1)[0] if self.max_pool else x.mean(dim=1) # pool over the time dimension
        x = self.toprobs(x)
        return F.log_softmax(x, dim=1)

In [88]:
from torchtext import data, datasets, vocab


In [89]:
# Vocabulary size: used to cap the text vocabulary and to size the model's token embedding.
vocab_size = 50_000
vocab_size

50000

In [90]:
# Field describing how the review text is processed.
# - include_lengths=True: later cells read the token tensor as `batch.text[0]`,
#   so each batch's text is presumably a (tokens, lengths) pair -- TODO confirm.
# - batch_first=True: later cells treat dim 0 as the batch and dim 1 as time.
# - lower=True: presumably lowercases the text before the vocabulary is built -- verify
#   against the torchtext Field documentation.
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
TEXT

<torchtext.data.field.Field at 0x15feb6128>

In [91]:
# Field for the class label (a single value per example rather than a token sequence).
# NOTE(review): the training loop subtracts 1 from `batch.label`, which suggests the
# label vocabulary reserves index 0 (e.g. for an unknown token) -- confirm.
LABEL = data.Field(sequential=False)
LABEL

<torchtext.data.field.Field at 0x15feb6b38>

In [92]:
# Number of output classes; passed to the classifier as `num_classes`.
NUM_CLS = 2

In [93]:
# Load the IMDB train/test splits using the TEXT and LABEL fields.
train, test = datasets.IMDB.splits(TEXT, LABEL)
# Build the vocabularies from the training split only.
# NOTE(review): `vocab_size - 2` presumably leaves room for two special tokens
# (unknown and padding) so the total stays within `vocab_size` -- confirm.
TEXT.build_vocab(train, max_size=vocab_size - 2)
LABEL.build_vocab(train)

In [94]:
# Run configuration used by the cells below.
device, batch_size = 'cpu', 64   # device handed to the batch iterators; examples per batch
lr = 1e-4                        # learning rate given to the optimizer
num_epochs = 80                  # number of passes over the training iterator

In [95]:
# Batch iterators over the train and test splits.
train_iter, test_iter = data.BucketIterator.splits(
    (train, test),
    batch_size=batch_size,
    device=device,
)

In [96]:
# Twice the length of the longest training batch. `mx` is later used as the number of
# position embeddings and as the cut-off length for inputs.
mx = 2 * max(batch.text[0].size(1) for batch in train_iter)

In [97]:
mx

4940

In [98]:
# Model hyperparameters.
embedding_size, num_heads, depth = 128, 8, 4
vocab_size = 50_000
max_pool = True            # max-pool (rather than mean-pool) over the time dimension

# Optimisation settings.
lr_warmup = 10_000         # warmup length, compared against the count of seen instances
gradient_clipping = 1.0    # max gradient norm; values <= 0 disable clipping

In [99]:
# Build the classifier from the configuration above; dropout keeps its default.
model = CTransformer(
    emb=embedding_size,
    heads=num_heads,
    depth=depth,
    seq_length=mx,
    num_tokens=vocab_size,
    num_classes=NUM_CLS,
    max_pool=max_pool,
)

In [100]:
# Adam over all model parameters, starting at the configured learning rate.
opt = torch.optim.Adam(model.parameters(), lr=lr)

In [101]:
# training loop
# - `base_lr` is the target learning rate. The global `lr` is never overwritten, so the
#   warmup schedule is computed from a fixed value and the cell can be re-run safely.
base_lr = lr
seen = 0  # number of training instances processed so far
for epoch in range(num_epochs):

    print(f'\n epoch {epoch}')
    model.train(True)

    for batch in tqdm.tqdm(train_iter):
        # learning rate warmup
        # - we linearly increase the learning rate from 1e-10 to base_lr over the first
        #   `lr_warmup` instances, then hold it at base_lr
        # - the rate has to be written into the optimizer's param groups; assigning to
        #   an `opt.lr` attribute has no effect on the update
        if lr_warmup > 0:
            warm_lr = max(base_lr * min(seen / lr_warmup, 1.0), 1e-10)
            for group in opt.param_groups:
                group['lr'] = warm_lr

        opt.zero_grad()

        tokens = batch.text[0]
        label = batch.label - 1

        # truncate sequences longer than the position-embedding table
        if tokens.size(1) > mx:
            tokens = tokens[:, :mx]
        out = model(tokens)
        loss = F.nll_loss(out, label)

        loss.backward()

        # clip gradients
        # - If the total gradient vector has a length > gradient_clipping, we clip it back down.
        if gradient_clipping > 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

        opt.step()

        seen += tokens.size(0)

    # evaluate on the test iterator after every epoch
    with torch.no_grad():

        model.train(False)
        tot, cor = 0.0, 0.0

        for batch in test_iter:

            tokens = batch.text[0]
            label = batch.label - 1

            if tokens.size(1) > mx:
                tokens = tokens[:, :mx]
            out = model(tokens).argmax(dim=1)

            tot += float(tokens.size(0))
            cor += float((label == out).sum().item())

        acc = cor / tot
        # `test_iter` is built from the test split, so this is test accuracy
        # (the original referenced an undefined `arg.final` here).
        print(f'-- test accuracy {acc:.3}')

  0%|          | 0/391 [00:00<?, ?it/s]


 epoch 0


  0%|          | 1/391 [01:27<9:30:30, 87.77s/it]

KeyboardInterrupt: 

In [None]:
 # Persist the model's parameters (state dict only) to a file in the working directory.
 torch.save(model.state_dict(), 'saved_transformer.pth')

## Sources
- https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
- http://www.peterbloem.nl/blog/transformers