In [None]:
import math
import os

# Must be set before keras is imported so that it selects the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import keras
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from keras.layers import TorchModuleWrapper
from keras.optimizers import Adam

# stuff by hand

getting used to pytorch...


code src:
https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

## Attention

Forward pass steps:
1. Generate matrices (project the input into Q, K, V matrices)
2. Calculate attention scores (matmul of query x key, scaled by 1/sqrt(embed size), then normalized with softmax)
3. Apply attention - take the attention-weighted sum of the value vectors and pass it through self.fc_out

### simple self attn

In [None]:
#self attn class

class SimpleSelfAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    `heads` is stored on the instance but is not used by the computation:
    attention is computed once over the full embedding width.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads

        # Bias-free projections for values, keys and queries. nn.Linear holds a
        # randomly initialised (out_features, in_features) weight matrix.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        # Output projection applied to the attended values (has a bias).
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query):
        """Attend `query` over `key`/`value` and project the result.

        All three inputs must be 3-D, (batch_size, seq_length, embed_size),
        because the products below use torch.bmm. Note the argument order:
        value, key, query.
        """
        q_proj = self.queries(query)
        k_proj = self.keys(key)
        v_proj = self.values(value)

        # (batch, q_len, embed) x (batch, embed, k_len) -> (batch, q_len, k_len);
        # swapping the last two key dims lines up the embed axis for the matmul.
        scale = self.embed_size ** 0.5
        scores = torch.bmm(q_proj, k_proj.transpose(1, 2)) / scale

        # Normalise over the key axis so each query's weights sum to 1.
        weights = F.softmax(scores, dim=-1)

        # Attention-weighted sum of the value vectors, then the final projection.
        context = torch.bmm(weights, v_proj)
        return self.fc_out(context)



### multi self attn

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The model dimension is split evenly across `num_heads` heads, each working
    on vectors of size d_k = d_model // num_heads. Queries, keys and values are
    linearly projected, attended per head, concatenated, and projected again.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        # Each head processes d_model / num_heads dimensions, so the split must be exact.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # total model width (= input embedding size)
        self.num_heads = num_heads  # number of attention heads
        self.d_k = d_model // num_heads  # per-head width for Q/K/V

        # Learned projections; nn.Linear creates an (out_features, in_features)
        # weight matrix plus a bias of size out_features.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return softmax(Q K^T / sqrt(d_k)) V.

        Positions where `mask == 0` are excluded from attention; `mask` must be
        broadcastable to the score tensor (..., q_len, k_len).
        """
        # Transpose the last two dims of K so the feature axes line up,
        # then scale by sqrt of the per-head width.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill masked scores with the most negative finite value of the score
            # dtype so they get ~0 probability after softmax. A hard-coded -1e9
            # is not representable in float16 and makes masked_fill raise there.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis.
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Weighted sum of the values: (..., q_len, k_len) x (..., k_len, d_v).
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        # view() reshapes without copying; the transpose moves the head axis in
        # front of the sequence axis so every head is attended in parallel.
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view; contiguous() is required
        # before view() can merge the head and d_k axes back into d_model.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project Q/K/V, attend per head, merge heads and apply W_o.

        Inputs are (batch, seq, d_model); the result is (batch, q_len, d_model).
        """
        # Project and split: each becomes (batch, num_heads, seq, d_k).
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Attention is computed independently for every head.
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge heads back to d_model and map through the output projection.
        return self.W_o(self.combine_heads(attn_output))







## FFNN

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Two dense layers with a ReLU in between, applied along the last axis."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expand from the model width to the (usually larger) hidden width ...
        self.fc1 = nn.Linear(d_model, d_ff)
        # ... and project back down to the model width.
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return fc2(relu(fc1(x))); the last dim of `x` must be d_model."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

## positional encodings

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature indices hold sin(position * freq) and odd indices hold
    cos(position * freq), with frequencies decaying geometrically from 1 down
    towards 1/10000. The table is precomputed for `max_seq_length` positions and
    stored as a (non-trainable) buffer of shape (1, max_seq_length, d_model).
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()

        # Table that will hold one encoding row per position.
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index: 10000^(-2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        # (max_seq_length, ceil(d_model / 2)) via broadcasting.
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column, so
        # trim the angles to match; for an even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: moves with .to()/.cuda() and is saved in state_dict, but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return `x` plus the encodings for its first x.size(1) positions.

        `x` is indexed as (batch, seq, d_model); seq must not exceed the
        `max_seq_length` given at construction or the addition will not broadcast.
        """
        return x + self.pe[:, :x.size(1)]

## Encoder & Decoder

### Encoder

- multi-head attn layer
- position wise FFNN layer
- 2 layer norm layers

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(sublayer(x))), i.e. a
    residual connection followed by layer normalisation.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        """Build the sub-layers.

        d_ff is the hidden width of the feed-forward network; dropout is the
        drop probability shared by both residual branches.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        # One LayerNorm per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the block to `x`; `mask` is forwarded to the self-attention."""
        # Self-attention: x serves as query, key and value, so every token can
        # look at every (unmasked) token.
        attended = self.dropout(self.self_attn(x, x, x, mask))
        x = self.norm1(x + attended)

        # Feed-forward on the normalised result, again with residual + norm.
        transformed = self.dropout(self.feed_forward(x))
        return self.norm2(x + transformed)


### Decoder


- multi head attn x2
- FFNN
- layer norm x3

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as
    LayerNorm(x + Dropout(sublayer(x))).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()

        # Attention over the decoder's own sequence and over the encoder output.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        # One LayerNorm per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Run the block on decoder input `x` given the encoder output.

        tgt_mask is passed to the self-attention (intended to hide future
        tokens); src_mask is passed to the cross-attention (intended to hide
        padding in the encoder output).
        """
        # 1) Self-attention over the decoder sequence.
        self_attended = self.dropout(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + self_attended)

        # 2) Cross-attention: decoder state is the query, encoder output
        #    supplies the keys and values.
        cross_attended = self.dropout(
            self.cross_attn(x, enc_output, enc_output, src_mask)
        )
        x = self.norm2(x + cross_attended)

        # 3) Position-wise feed-forward network.
        transformed = self.dropout(self.feed_forward(x))
        return self.norm3(x + transformed)

## combine everything (transformer block)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer over integer token ids.

    Token id 0 is treated as padding when building the attention masks. The
    forward pass returns unnormalised scores (logits) over the target
    vocabulary; no softmax is applied here.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()

        # Token embeddings for the source and target sequences.
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # A single positional encoding shared by both sides.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # Stacks of num_layers encoder / decoder blocks.
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        # Maps decoder output to target-vocabulary logits.
        self.fc = nn.Linear(d_model, tgt_vocab_size)

        # Dropout applied to the embedded inputs.
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks from (batch, seq) token-id tensors.

        Returns:
            src_mask: (batch, 1, 1, src_len), True where the source token is
                not padding (id != 0).
            tgt_mask: (batch, 1, tgt_len, tgt_len), True where the query token
                is not padding AND the key position is not in the future.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)

        # No-peek (causal) mask: lower triangle incl. diagonal is True. Built on
        # tgt's device so the `&` below also works when the inputs live on a GPU.
        upper = torch.triu(
            torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1
        )
        nopeak_mask = (1 - upper).bool()
        tgt_mask = tgt_mask & nopeak_mask

        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        # Embed tokens, add positional encodings, then apply dropout.
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        # Encoder: feed the embedded source through each encoder layer in turn.
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        # Decoder: each layer sees the running decoder state plus the final
        # encoder output.
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        # Project to the target vocabulary size (logits, not probabilities).
        return self.fc(dec_output)