In [1]:
# Rewrite of https://github.com/harvardnlp/annotated-transformer/blob/master/The%20Annotated%20Transformer.ipynb
# pip install torch matplotlib spacy torchtext seaborn 

In [2]:
import torch, copy, pdb, math
import torch.nn as nn
import torch.nn.functional as F  

In [3]:
N = 1 # number of encoder or decoder stacks 
h = 8 # heards in parallel
d_model = 512 # dimensions 
dropout_rate = 0.1
d_ff = 2048

vocab_size_src  = 51
vocab_size_tgt  = 51

In [4]:
class Embedding(nn.Module):
    def __init__(self, vocab_size):
        super(Embedding, self).__init__() 
        self.emb = nn.Embedding(vocab_size, d_model) # https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

    def forward(self, x):
        return self.emb(x) * math.sqrt(d_model)

In [5]:
def PositionalEncoding(): 
    # Implement the PE function.
    # Compute the positional encodings once in log space.
    max_len = 999
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, requires_grad=False)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0) 

In [6]:
class MultiHeadedAttention(nn.Module):
    def __init__(self):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k     = d_model // h
        self.h       = h
        self.l1      = nn.Linear(d_model, d_model) 
        self.l2      = nn.Linear(d_model, d_model) 
        self.l3      = nn.Linear(d_model, d_model) 
        self.l4      = nn.Linear(d_model, d_model)  
        
    def attention(self, Q, K, V, mask=None):
        "Compute 'Scaled Dot Product Attention'"
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9) 
        weights = F.softmax(scores, dim = -1) 
        x = torch.matmul(weights, V) 
        if mask is not None:
            print('Q:', Q.shape)
            print('x:', x.shape)
            print('mask:', mask.shape) 
            print('scores:', scores.shape) 
        return x

    def forward(self, Q, K, V, mask=None):
        if mask is not None:
            print('pre:',Q)
            print('pre==decoder_post_eb:',Q.shape)
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = Q.size(0) 
        # 1) Do all the linear projections in batch from d_model => h x d_k
        Q = self.l1(Q).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        K = self.l2(K).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        V = self.l3(V).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
         
        # 2) Apply attention on all the projected vectors in batch. 
        #print('post:',Q)
        x = self.attention(Q, K, V, mask=mask) 
        
        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, d_model)
        x = self.l4(x)
        return x

In [7]:
class FeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model) 

    def forward(self, x):
        x = self.l1(x)
        x = F.relu(x) 
        x = self.l2(x)
        return x

In [8]:
class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadedAttention()
        self.ff = FeedForward( ) 
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6) 
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
 
    def forward(self, x):
        "Follow Figure 1 (left) for connections."
        x = x + self.dropout1(self.mha(self.norm1(x), self.norm1(x), self.norm1(x)))  
        x = x + self.dropout2(self.ff (self.norm2(x)))  
        return x

In [9]:
class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self):
        super(DecoderLayer, self).__init__() 
        self.mmha     = MultiHeadedAttention()
        self.mha      = MultiHeadedAttention() 
        self.ff       = FeedForward( ) 
        self.norm1    = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2    = nn.LayerNorm(d_model, eps=1e-6)
        self.norm3    = nn.LayerNorm(d_model, eps=1e-6) 
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.dropout3 = nn.Dropout(dropout_rate)  
 
    def forward(self, x, memory, tgt_mask):
        "Follow Figure 1 (right) for connections." 
        x = x + self.dropout1(self.mmha(x, x, x, tgt_mask) )   
        x = x + self.dropout2(self.mha (x, memory, memory))    
        x = x + self.dropout3(self.ff  (x))    
        return x

In [10]:
class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

In [11]:
class Transformer(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many 
    other models.
    """
    def __init__(self):
        super(Transformer, self).__init__()
          
        self.src_embed     = Embedding(vocab_size_src)
        self.pe            = PositionalEncoding()
        self.encoderLayers = nn.ModuleList([copy.deepcopy(EncoderLayer()) for _ in range(N)]) 
        self.norm1         = nn.LayerNorm(d_model, eps=1e-6)
        
        self.tgt_embed     = Embedding(vocab_size_tgt)
        self.decoderLayers = nn.ModuleList([copy.deepcopy(DecoderLayer()) for _ in range(N)])  
        self.generator     = Generator(d_model, vocab_size_tgt)
        self.norm2         = nn.LayerNorm(d_model, eps=1e-6)
        
    def forward(self, src, tgt, tgt_mask):
        "Take in and process masked src and target sequences."
        memory = self.norm1(self.encode(src))
        tgt    = self.norm2(self.decode(memory, tgt, tgt_mask)) 
        return tgt 
    
    def encode(self, src):
        src = self.src_embed(src) 
        src = src + self.pe[:, :src.size(1)]  
        for layer in self.encoderLayers:
            src = layer(src) 
        memory = src
        return memory
    
    def decode(self, memory, tgt, tgt_mask):
        print('decoder_pre_eb:', tgt.shape)
        tgt = self.tgt_embed(tgt)
        print('decoder_post_eb:', tgt.shape)
        print('decoder_post_eb:', tgt)
        for layer in self.decoderLayers: 
            tgt = layer(tgt, memory, tgt_mask) 
        return tgt

In [12]:
def make_model(src_vocab, tgt_vocab, 
               d_model=512, d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    model = Transformer(
        )
    
    # This was important from their code. 
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [13]:
def subsequent_mask(size):
    attn_shape = (1, size, size) 
    return torch.triu(torch.ones(attn_shape, dtype=int), diagonal=1) == 0

In [14]:
class Batch:
    "Object for holding a batch of data with mask during training."
    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()
     
    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)) 
        return tgt_mask

In [15]:
def run_epoch(data_iter, model, loss_compute):
    "Standard Training and Logging Function" 
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg, batch.trg_mask)  

In [16]:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

In [17]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = d_model
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) 

In [18]:
def data_gen(V, batch, nbatches): 
    "Generate random data for a src-tgt copy task."
    for i in range(nbatches): 
        src = torch.randint(1, V, size=(batch, V), requires_grad=False)
        src[:, 0] = 1
        tgt = copy.deepcopy(src)
        yield Batch(src, tgt, 0)

In [19]:
class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt
        
    def __call__(self, x, y, norm):
        x = self.generator(x)
        loss  = self.criterion(x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)) 
        loss /= norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad() 
        return loss.data * norm

In [20]:
# Train the simple copy task.
V = 51
criterion = nn.CrossEntropyLoss()#LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)

model = make_model(V, V)
model_opt = NoamOpt(d_model, 
                    1, 
                    400, 
                    torch.optim.Adam(model.parameters(), 
                                     lr=0, 
                                     betas=(0.9, 0.98), 
                                     eps=1e-9
                                    )
                   )

model.train()
run_epoch(data_gen(V, 1, 1), model, SimpleLossCompute(model.generator, criterion, model_opt))

decoder_pre_eb: torch.Size([1, 50])
decoder_post_eb: torch.Size([1, 50, 512])
decoder_post_eb: tensor([[[-0.7567, -1.2534,  1.8538,  ...,  1.2469, -0.1786, -0.3044],
         [-0.6531,  0.7032,  1.2039,  ..., -0.6993, -0.8673,  0.9826],
         [-0.8913, -1.2010,  0.5727,  ..., -2.2280,  2.1302, -1.4048],
         ...,
         [-0.7579,  0.4750,  1.3414,  ...,  2.0982,  1.1911,  1.9675],
         [-1.3025, -2.1970, -1.1736,  ..., -0.6874, -0.5123, -1.4236],
         [-2.2855,  1.6247, -0.6763,  ...,  1.4621,  0.3762, -1.1457]]],
       grad_fn=<MulBackward0>)
pre: tensor([[[-0.7567, -1.2534,  1.8538,  ...,  1.2469, -0.1786, -0.3044],
         [-0.6531,  0.7032,  1.2039,  ..., -0.6993, -0.8673,  0.9826],
         [-0.8913, -1.2010,  0.5727,  ..., -2.2280,  2.1302, -1.4048],
         ...,
         [-0.7579,  0.4750,  1.3414,  ...,  2.0982,  1.1911,  1.9675],
         [-1.3025, -2.1970, -1.1736,  ..., -0.6874, -0.5123, -1.4236],
         [-2.2855,  1.6247, -0.6763,  ...,  1.4621,  0.376