In [None]:
# Rewrite of https://github.com/harvardnlp/annotated-transformer/blob/master/The%20Annotated%20Transformer.ipynb
# pip install torch matplotlib spacy torchtext seaborn 

In [2]:
import torch, copy, pdb, math
import torch.nn as nn
import torch.nn.functional as F  

In [3]:
N = 2 # number of stacked encoder layers (and, separately, of decoder layers)
h = 8 # attention heads computed in parallel
d_model = 512 # model width: embedding size and the in/out size of every sub-layer
dropout_rate = 0.1 # dropout probability applied to each sub-layer output
d_ff = 2048 # hidden width of the position-wise feed-forward block

vocab_size_src  = 51 # rows in the source embedding table
vocab_size_tgt  = 51 # rows in the target embedding table and generator output size

In [4]:
class Embedding(nn.Module):
    """Learned token embedding whose output is multiplied by sqrt(d_model).

    `vocab_size` is the number of rows in the lookup table; the row width is
    the module-level `d_model`.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.emb = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up the rows selected by the indices in `x` and scale them."""
        scale = math.sqrt(d_model)
        vectors = self.emb(x)
        return vectors * scale

In [5]:
def PositionalEncoding():
    """Build the fixed sinusoidal position table.

    Returns a tensor of shape (1, 999, d_model) that does not require grad.
    Even feature columns hold sin(position * freq) and odd columns hold
    cos(position * freq), with freq = exp(-k * ln(10000) / d_model) for
    k = 0, 2, 4, ...
    """
    max_len = 999
    # Frequencies are computed once, in log space.
    even_dims = torch.arange(0, d_model, 2)
    inv_freq = torch.exp(even_dims * -(math.log(10000.0) / d_model))
    # (max_len, 1) * (d_model / 2,) broadcasts to one angle per (position, frequency).
    angles = torch.arange(0, max_len).unsqueeze(1) * inv_freq
    table = torch.zeros(max_len, d_model, requires_grad=False)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    # Leading axis of size 1 so the table broadcasts when added to a batch.
    return table.unsqueeze(0)

In [6]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Width (`d_model`) and head count (`h`) come from the module-level
    settings. `l1`, `l2`, `l3` project the query, key and value inputs;
    `l4` projects the re-merged heads.
    """

    def __init__(self):
        super().__init__()
        assert d_model % h == 0
        # d_v is assumed to equal d_k.
        self.d_k = d_model // h
        self.h = h
        self.l1 = nn.Linear(d_model, d_model)
        self.l2 = nn.Linear(d_model, d_model)
        self.l3 = nn.Linear(d_model, d_model)
        self.l4 = nn.Linear(d_model, d_model)

    def attention(self, Q, K, V, mask=None):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

        When `mask` is given, score positions where `mask == 0` are replaced
        by -1e9 before the softmax.
        """
        scale = math.sqrt(Q.size(-1))
        logits = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)
        return torch.matmul(F.softmax(logits, dim=-1), V)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, then merge heads and project."""
        batch = Q.size(0)
        # Insert a head axis so the same mask is applied to all h heads.
        head_mask = None if mask is None else mask.unsqueeze(1)

        def split_heads(projection, tensor):
            # d_model => h x d_k, with the head axis moved in front of the sequence axis.
            return projection(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        queries = split_heads(self.l1, Q)
        keys = split_heads(self.l2, K)
        values = split_heads(self.l3, V)

        context = self.attention(queries, keys, values, mask=head_mask)

        # "Concat" the heads back into d_model-wide vectors with a view.
        merged = context.transpose(1, 2).contiguous().view(batch, -1, d_model)
        return self.l4(merged)

In [7]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model)."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the two linear maps with a ReLU in between."""
        hidden = F.relu(self.l1(x))
        return self.l2(hidden)

In [8]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block.

    Each sub-layer is wrapped as `x + dropout(sublayer(norm(x)))`, i.e. the
    LayerNorm is applied to the sub-layer input and the residual is added
    un-normalised. No attention mask is passed to the self-attention.
    """

    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadedAttention()
        self.ff = FeedForward()
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        "Follow Figure 1 (left) for connections."
        # LayerNorm is deterministic, so normalise once and reuse the result as
        # query, key and value instead of running the same norm three times.
        normed = self.norm1(x)
        x = x + self.dropout1(self.mha(normed, normed, normed))
        x = x + self.dropout2(self.ff(self.norm2(x)))
        return x

In [9]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, attention over `memory`, feed-forward.

    Each sub-layer is wrapped as `x + dropout(sublayer(norm(x)))`. `mmha` is
    the self-attention that receives `tgt_mask`; `mha` attends over `memory`
    (used as both keys and values, without a mask and without any
    normalisation applied by this layer).
    """

    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.mmha     = MultiHeadedAttention()
        self.mha      = MultiHeadedAttention()
        self.ff       = FeedForward()
        self.norm1    = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2    = nn.LayerNorm(d_model, eps=1e-6)
        self.norm3    = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.dropout3 = nn.Dropout(dropout_rate)

    def forward(self, x, memory, tgt_mask):
        "Follow Figure 1 (right) for connections."
        # LayerNorm is deterministic, so normalise once and reuse the result as
        # query, key and value instead of running the same norm three times.
        normed = self.norm1(x)
        x = x + self.dropout1(self.mmha(normed, normed, normed, tgt_mask))
        x = x + self.dropout2(self.mha(self.norm2(x), memory, memory))
        x = x + self.dropout3(self.ff(self.norm3(x)))
        return x

In [10]:
class Generator(nn.Module):
    """Output head: a linear projection to `vocab` logits followed by log-softmax."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the last axis of the projected input."""
        logits = self.proj(x)
        return torch.log_softmax(logits, dim=-1)

In [11]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer built from the module-level hyperparameters
    (`N`, `d_model`, `vocab_size_src`, `vocab_size_tgt`, ...).

    `forward` returns the normalised decoder output; it does NOT apply
    `self.generator`, which callers invoke themselves.

    `encode` and `decode` return the raw output of the layer stacks. The
    final LayerNorms (`norm1` for the encoder, `norm2` for the decoder) are
    applied only inside `forward`, so code that calls `encode`/`decode`
    directly must apply them itself to match what `forward` computes.

    No source padding mask is used anywhere in the model.
    """
    def __init__(self):
        super(Transformer, self).__init__()

        self.src_embed     = Embedding(vocab_size_src)
        # Registered as a buffer so the table follows the module through
        # .to(device) / .double() / .half(). persistent=False keeps it out of
        # state_dict(), so checkpoint keys are the same as when this was a
        # plain tensor attribute.
        self.register_buffer("pe", PositionalEncoding(), persistent=False)
        # Each EncoderLayer()/DecoderLayer() call already builds an independent
        # module, so no deepcopy is needed.
        self.encoderLayers = nn.ModuleList([EncoderLayer() for _ in range(N)])
        self.norm1         = nn.LayerNorm(d_model, eps=1e-6)

        self.tgt_embed     = Embedding(vocab_size_tgt)
        self.decoderLayers = nn.ModuleList([DecoderLayer() for _ in range(N)])
        self.generator     = Generator(d_model, vocab_size_tgt)
        self.norm2         = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src, tgt, tgt_mask):
        "Take in and process src and masked target sequences."
        memory = self.norm1(self.encode(src))
        tgt    = self.norm2(self.decode(memory, tgt, tgt_mask))
        return tgt

    def encode(self, src):
        "Embed `src`, add positional encodings, run the encoder stack (no final norm)."
        src = self.src_embed(src)
        # Slice the table to the sequence length (axis 1 of the embedded input).
        src = src + self.pe[:, :src.size(1)]
        for layer in self.encoderLayers:
            src = layer(src)
        memory = src
        return memory

    def decode(self, memory, tgt, tgt_mask):
        "Embed `tgt`, add positional encodings, run the decoder stack (no final norm)."
        tgt = self.tgt_embed(tgt)
        tgt = tgt + self.pe[:, :tgt.size(1)]
        for layer in self.decoderLayers:
            tgt = layer(tgt, memory, tgt_mask)
        return tgt

In [12]:
def make_model(src_vocab, tgt_vocab, N=6, 
               d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Helper: construct a Transformer and apply Xavier-uniform initialisation.

    NOTE: `Transformer()` takes no arguments and is configured entirely by the
    module-level settings (`vocab_size_src`, `vocab_size_tgt`, `N`, `d_model`,
    `d_ff`, `h`, `dropout_rate`). The arguments of this function therefore do
    NOT change the model that is built. They are kept for signature
    compatibility, and a warning is emitted when a passed value differs from
    the module-level value that is actually used.

    Returns:
        The initialised `Transformer` instance.
    """
    import warnings

    # Parameter names shadow the module-level settings, so read those through
    # globals() to compare "requested" against "actually used".
    module_settings = globals()
    requested = {
        "vocab_size_src": src_vocab,
        "vocab_size_tgt": tgt_vocab,
        "N": N,
        "d_model": d_model,
        "d_ff": d_ff,
        "h": h,
        "dropout_rate": dropout,
    }
    ignored = sorted(
        name for name, value in requested.items()
        if name in module_settings and module_settings[name] != value
    )
    if ignored:
        warnings.warn(
            "make_model() arguments are ignored; the model is built from the "
            "module-level settings, which differ for: " + ", ".join(ignored),
            stacklevel=2,
        )

    model = Transformer()

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg. Only tensors with more than
    # one dimension are touched, so biases and LayerNorm weights keep their
    # default initialisation.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [13]:
def subsequent_mask(size):
    """Return a (1, size, size) boolean mask that hides later positions.

    Entry [0, i, j] is True when j <= i (position i may look at position j)
    and False when j > i.
    """
    positions = torch.arange(size)
    query_idx = positions.unsqueeze(1)  # rows: the position doing the looking
    key_idx = positions.unsqueeze(0)    # columns: the position being looked at
    return (key_idx <= query_idx).unsqueeze(0)

In [14]:
class Batch:
    """Container for one training batch and its masks.

    Always sets `src` and `src_mask` (True where `src != pad`, with an extra
    axis inserted before the last one). When `trg` is given it also sets:
      - `trg`:      `trg` without its last column (decoder input),
      - `trg_y`:    `trg` without its first column (prediction targets),
      - `trg_mask`: padding mask combined with the subsequent-position mask,
      - `ntokens`:  count of non-pad entries in `trg_y`.
    """

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = src.ne(pad).unsqueeze(-2)
        if trg is None:
            return
        self.trg = trg[:, :-1]
        self.trg_y = trg[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        self.ntokens = self.trg_y.ne(pad).sum()

    def make_std_mask(self, tgt, pad):
        """Create a mask that hides both padding and future positions."""
        not_pad = tgt.ne(pad).unsqueeze(-2)
        causal = subsequent_mask(tgt.size(-1))
        return not_pad & causal

In [15]:
def run_epoch(data_iter, model, loss_compute):
    """Run one pass over `data_iter` and return the loss per token.

    For every batch the model is called with `(batch.src, batch.trg,
    batch.trg_mask)` and `loss_compute(out, batch.trg_y, batch.ntokens)` is
    added to a running total. Whether parameters are updated is decided
    entirely by `loss_compute`.

    Returns:
        Sum of the per-batch losses divided by the sum of `batch.ntokens`.

    Raises:
        ZeroDivisionError: if `data_iter` yields no batches.
    """
    total_tokens = 0
    total_loss = 0
    for batch in data_iter:
        # Call the module itself rather than .forward() so that any hooks
        # registered on it are run.
        out = model(batch.src, batch.trg, batch.trg_mask)
        total_loss += loss_compute(out, batch.trg_y, batch.ntokens)
        total_tokens += batch.ntokens
    return total_loss / total_tokens

In [16]:
# Running maxima shared across successive batch_size_fn calls. A module-level
# `global` statement does not create these names, so they are initialised
# explicitly; otherwise a first call with count != 1 raises NameError.
max_src_in_batch = 0
max_tgt_in_batch = 0
def batch_size_fn(new, count, sofar):
    """Keep augmenting batch and calculate total number of tokens + padding.

    Args:
        new: example being added; `len(new.src)` and `len(new.trg)` are read.
        count: number of examples in the batch including `new`. A value of 1
            resets the running maxima.
        sofar: unused.

    Returns:
        `count` times the longest length seen in the current batch, taking
        the larger of the source side and the target side (target lengths
        are counted as `len(new.trg) + 2`).
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

In [17]:
class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every step.

    The rate is
        factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
    which rises linearly for `warmup` steps and then decays as step ** -0.5.

    Args:
        model_size: model width used in the rate formula.
        factor: constant multiplier on the rate.
        warmup: number of warm-up steps.
        optimizer: wrapped optimizer; must expose `param_groups` and `step()`.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        # Use the value passed by the caller, not a module-level global.
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Advance the step counter, write the new rate to every param group, then step."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the learning rate for `step` (defaults to the current step).

        `step` must be >= 1: `step ** (-0.5)` raises ZeroDivisionError at 0.
        """
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

In [18]:
def data_gen(V, batch, nbatches):
    """Yield `nbatches` random batches for a src-tgt copy task.

    Each batch holds `batch` rows of `V` integers drawn from [1, V), with the
    first column forced to 1. The target is a copy of the source, and 0 is
    passed to `Batch` as the padding value.
    """
    for _ in range(nbatches):
        tokens = torch.randint(1, V, size=(batch, V), requires_grad=False)
        tokens[:, 0] = 1
        target = copy.deepcopy(tokens)
        yield Batch(tokens, target, 0)

In [19]:
class SimpleLossCompute:
    """Apply the generator, compute the loss and, when training, take one optimizer step.

    Args:
        generator: callable mapping model output to per-token scores.
        criterion: callable `(scores, targets) -> scalar loss`; scores are
            flattened to (-1, vocab) and targets to (-1,).
        opt: optimizer wrapper exposing `step()` and `optimizer.zero_grad()`,
            or None for evaluation.
    """
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        """Return the loss rescaled by `norm`, detached from the graph.

        The criterion value is divided by `norm` before the backward pass and
        multiplied by `norm` again in the returned value.
        """
        x = self.generator(x)
        loss  = self.criterion(x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1))
        loss /= norm
        if self.opt is not None:
            # Backpropagate only when there is an optimizer to consume the
            # gradients. Doing it during evaluation accumulated gradients that
            # were never zeroed and leaked into the next training step.
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.data * norm

In [21]:
# Train the simple copy task.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode one sequence greedily, one token at a time.

    Starts from `start_symbol` and appends the arg-max token of the last
    position until the output holds `max_len` tokens. `src_mask` is accepted
    for signature compatibility but is not used: the model takes no source
    mask. Only the first batch row is decoded.
    """
    # Transformer.forward applies norm1 to the encoder output and norm2 to the
    # decoder output before the generator sees them; encode()/decode() do not.
    # Apply the same norms here so inference matches what was trained.
    memory = model.norm1(model.encode(src))
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.norm2(model.decode(memory, ys, subsequent_mask(ys.size(1))))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

V = 51
criterion = nn.CrossEntropyLoss()#LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
for nth_test in range(1):
    model = make_model(V, V, N=2)
    model_opt = NoamOpt(d_model, 
                        1, 
                        400, 
                        torch.optim.Adam(model.parameters(), 
                                         lr=0, 
                                         betas=(0.9, 0.98), 
                                         eps=1e-9
                                        )
                       )
    #model_opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(50):
        model.train()
        run_epoch(data_gen(V, 30, 20), model, SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        check = run_epoch(data_gen(V, 30, 5), model, SimpleLossCompute(model.generator, criterion, None))
        print(epoch+1, ':',check)
        # Stop early once the per-token evaluation loss is negligible.
        if check < 0.00001:
            break

    # Sanity check: decode the sequence 1..t and count positions copied correctly.
    model.eval()
    t = 50
    src = torch.arange(1, t+1).unsqueeze(0)
    src_mask = torch.ones(1, 1, t) 
    pred = greedy_decode(model, src, src_mask, max_len=t, start_symbol=1)
    #print(pred)
    print('!!!!!!!!!!!!!!!!!!!!!', sum([i==j for i,j in zip(pred[0],range(1,t+1))]))

1 : tensor(0.0027)
2 : tensor(0.0026)
3 : tensor(0.0026)
4 : tensor(0.0024)
5 : tensor(0.0022)
6 : tensor(0.0017)
7 : tensor(0.0013)
8 : tensor(0.0008)
9 : tensor(0.0003)
10 : tensor(0.0002)
11 : tensor(8.3170e-05)
12 : tensor(4.8422e-05)
13 : tensor(4.7529e-05)
14 : tensor(7.6590e-05)
15 : tensor(8.1474e-05)
16 : tensor(6.4091e-05)
17 : tensor(3.7960e-05)
18 : tensor(3.6851e-05)
19 : tensor(5.2938e-05)
20 : tensor(3.7670e-05)
21 : tensor(7.1067e-05)
22 : tensor(4.2804e-05)
23 : tensor(2.3188e-05)
24 : tensor(2.7754e-05)
25 : tensor(4.1535e-05)
26 : tensor(2.3626e-05)
27 : tensor(2.0970e-05)
28 : tensor(1.6799e-05)
29 : tensor(2.6592e-05)
30 : tensor(1.8358e-05)
31 : tensor(2.3255e-05)
32 : tensor(2.2536e-05)
33 : tensor(1.1023e-05)
34 : tensor(2.0560e-05)
35 : tensor(1.6535e-05)
36 : tensor(1.6381e-05)
37 : tensor(1.6006e-05)
38 : tensor(1.6477e-05)
39 : tensor(2.2869e-05)
40 : tensor(1.4038e-05)
41 : tensor(1.3275e-05)
42 : tensor(1.0311e-05)
43 : tensor(1.1329e-05)
44 : tensor(1.283