## About This Notebook

This is a playground notebook of sorts where I can try to implement a lot of the code from the assignments from scratch.

The assignments are helpful, but there are too many hints and it's too easy for someone to work through them without truly understanding how transformers, attention, etc. work.

#### Terms

* cross attention
* causal attention
* scaled dot product attention

In [1]:
import sys
import os

import numpy as np

import textwrap
# 70-column text wrapper, presumably for pretty-printing long text output
# elsewhere in the notebook.
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
# trax's numpy-compatible array API; used as `jnp` throughout this notebook.
from trax.fastmath import numpy as jnp

# Print entire numpy arrays instead of the truncated "..." summary
# (threshold=sys.maxsize effectively disables summarization). Debugging aid.
np.set_printoptions(threshold=sys.maxsize)

### Scaled Dot-Product Attention

In [2]:
def scaled_dp_attention(q, k, v, m, embed_dims, vocab_size):
    '''
    Scaled dot-product attention: softmax(q @ k^T / sqrt(embed_dims)) @ v.

    q - query  [batch, seq_len_out, embed_dims]
    k - keys   [batch, seq_len_in, embed_dims]
    v - values [batch, seq_len_in, value_dims]
    m - mask   [batch, seq_len_out, seq_len_in] (or anything broadcastable to
        it). Non-zero / True entries are attended to; zero / False entries
        are blocked. Pass None to attend everywhere.
    embed_dims - depth of q and k; used for the 1/sqrt(embed_dims) scaling.
    vocab_size - unused; kept so existing callers keep working.

    Returns the attention output [batch, seq_len_out, value_dims].
    '''
    # Swap only the last two axes. jnp.transpose(k) would reverse *all* axes
    # of a 3-D tensor (batch included).
    kt = jnp.swapaxes(k, -1, -2)

    # matmul batches over the leading axis; jnp.dot on 3-D inputs is not a
    # batched product and would return a 4-D array.
    logits = jnp.matmul(q, kt) / jnp.sqrt(embed_dims)

    if m is not None:
        # Blocked positions get a huge negative logit so their weight is ~0.
        logits = jnp.where(jnp.asarray(m) != 0, logits,
                           jnp.full_like(logits, -1e9))

    # Numerically stable softmax over the key axis: subtracting the row max
    # keeps every exponent <= 0, so exp() cannot overflow.
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    weights = jnp.exp(logits)
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)

    # Note: it's a little unclear to me how I can use trax layers with
    # jax.numpy operations.
    return jnp.matmul(weights, v)
    

In [1]:
def softmax(x):
    """Return softmax(x) along the last axis, computed in log space.

    The log-sum-exp of each row is subtracted before exponentiating. Every
    exponent is then <= 0, so exp() cannot overflow, and no explicit
    normalizing division is needed. The result is an array of probabilities
    that sums to 1 along the last axis.
    """
    log_normalizer = trax.fastmath.logsumexp(x, axis=-1, keepdims=True)
    log_probs = x - log_normalizer
    return jnp.exp(log_probs)

### Causal Attention

a.k.a. masked multi-head attention



In [3]:
#helper functions
def compute_attention_heads_closure(n_heads, d_head):
    """Placeholder: not implemented yet; currently returns None.

    TODO: presumably this should return a closure that splits activations
    into `n_heads` heads of width `d_head` (as in the course assignment).
    """

def dot_product_self_attention():
    """Placeholder: not implemented yet; currently returns None.

    TODO: presumably the (causally masked) self-attention core, i.e. scaled
    dot-product attention where q, k and v all come from the same sequence.
    """

def compute_attention_output_closure():
    """Placeholder: not implemented yet; currently returns None.

    TODO: presumably should return a closure that merges the per-head
    attention outputs back into a single activation tensor.
    """

In [None]:
def CausalAttention():
    """Transformer-style multi-head causal attention (placeholder).

    Not implemented yet; currently returns None.
    TODO: assemble from the helper functions above.
    """
    return None

## RNN Encoder / Decoder with Attention

In [6]:

def nmt_encoder(vocab_size, embed_dims, lstm_hidden, num_lstms):
    '''
    Build the input encoder: a token embedding followed by a stack of LSTMs.

    The encoder's output activations are what get passed along to attention
    (as keys and values) and from there on to the decoder.

    vocab_size  - size of the input vocabulary
    embed_dims  - embedding width
    lstm_hidden - number of units in each LSTM
    num_lstms   - how many LSTM layers to stack after the embedding

    Returns a trax `tl.Serial` model.
    '''
    lstm_stack = [tl.LSTM(lstm_hidden) for _ in range(num_lstms)]
    return tl.Serial(tl.Embedding(vocab_size, embed_dims), lstm_stack)
    
def pre_attention_decoder(target_vocab_size, embed_dims, lstm_hidden):
    '''
    Build the first ("pre-attention") decoder over the target token sequences.

    The target tokens are shifted right, embedded, and run through a single
    LSTM. The resulting hidden states are used as the queries in the attention
    layer.
    - This is only the first decoder in the model; a second decoder runs
      after the attention layer.
    - Note: only a single LSTM is used here, versus a stack of them in the
      encoder. Not really sure why.

    Returns a trax `tl.Serial` model.
    '''
    layers = [
        tl.ShiftRight(),
        tl.Embedding(target_vocab_size, embed_dims),
        tl.LSTM(lstm_hidden),
    ]
    return tl.Serial(*layers)


In [None]:
# Attention helper functions

def prepare_qkv_from_encoder_decoder(encoder_activations, decoder_activations, input_token_seq, attention_heads=1):
    '''
    Map encoder/decoder activations to attention queries, keys, values and mask.

    A very light level of indirection:
      queries <- decoder_activations
      keys    <- encoder_activations
      values  <- encoder_activations

    encoder_activations - encoder output; returned unchanged as keys and values
    decoder_activations - decoder output [batch, decoder_len, ...]; returned
                          unchanged as queries
    input_token_seq     - input token ids [batch, seq_len]; id 0 is padding
    attention_heads     - number of heads the mask is broadcast over

    Returns (queries, keys, values, mask). The mask has shape
    [batch, attention_heads, decoder_len, seq_len] and is 1 for real tokens
    and 0 for padding.
    '''
    queries = decoder_activations
    keys    = encoder_activations
    values  = encoder_activations

    batch_size, seq_len = input_token_seq.shape
    # Decoder activations are [batch, decoder_len, features]. Only the length
    # is needed; unpacking the full 3-D shape into two names would fail.
    decoder_len = decoder_activations.shape[1]

    # If token is padding (id 0) then mask = 0, otherwise 1.
    mask = jnp.where(input_token_seq == 0, 0, 1)

    # Mask dims transform:
    #   [batch, seq_len] -> [batch, attention_heads, decoder_len, encoder_len]
    # For the RNN encoder above (embedding + LSTMs), there is one encoder
    # activation per input token, so seq_len == encoder_len.
    mask = jnp.reshape(mask, (batch_size, 1, 1, seq_len))

    # Adding zeros of shape [1, heads, decoder_len, 1] broadcasts the mask to
    # the full [batch, heads, decoder_len, seq_len] shape.
    mask = mask + jnp.zeros((1, attention_heads, decoder_len, 1))

    return (queries, keys, values, mask)
    