## About This Notebook

This is a playground notebook of sorts where I can try to implement a lot of the code from the assignments from scratch.

The assignments are helpful, but they contain so many hints that it's easy to work through them without truly understanding how transformers, attention, etc. work.

#### Terms

* cross attention
* causal attention
* scaled dot product attention

In [1]:
import sys
import os

import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

### Scaled Dot-Product Attention

In [2]:
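def scaled_dp_attention(q, k, v, m, embed_dims):
    '''
    q - queries [batch, seq_len_out, embed_dims]
    k - keys    [batch, seq_len_in, embed_dims]
    v - values  [batch, seq_len_in, embed_dims]
    m - mask    [batch, seq_len_out, seq_len_in], True where attention is allowed
    '''
    # swap only the last two axes; a plain transpose would also reverse the batch axis
    kt  = jnp.swapaxes(k, -1, -2)                       # [batch, embed_dims, seq_len_in]
    qkt = jnp.matmul(q, kt)                             # [batch, seq_len_out, seq_len_in]
    mqk = jnp.where(m, qkt, jnp.full_like(qkt, -1e9))   # add in the mask: blocked positions get a huge negative score
    w   = softmax(mqk / jnp.sqrt(embed_dims))           # softmax is defined in the next cell

    return jnp.matmul(w, v)                             # [batch, seq_len_out, embed_dims]

    #Note: it's a little unclear to me how I can use trax layers with jax.numpy operations.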
    

In [1]:
# Numerically stable softmax: exp(x - logsumexp(x)) replaces the explicit division by sum(exp(x)),
# and logsumexp is computed in a way that avoids overflow on large inputs.
def softmax(x):
    d = trax.fastmath.logsumexp(x, axis=-1, keepdims=True)
    p = jnp.exp(x - d)
    return p #array of probabilities
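For comparison, the same masked scaled dot-product attention can be sketched in plain numpy, which makes it easy to check shapes and that masked keys really get zero weight. The names `stable_softmax` and `scaled_dot_product_attention` are hypothetical, and returning the weights is only for inspection.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # exp(x - logsumexp(x)); subtracting the max first keeps exp from overflowing.
    x_max = np.max(x, axis=axis, keepdims=True)
    lse = x_max + np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))
    return np.exp(x - lse)

def scaled_dot_product_attention(q, k, v, mask):
    """q: [batch, len_q, d]; k, v: [batch, len_k, d]; mask: [batch, len_q, len_k], True = attend."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)  # [batch, len_q, len_k]
    scores = np.where(mask, scores, -1e9)             # blocked keys get ~0 weight after softmax
    weights = stable_softmax(scores, axis=-1)
    return weights @ v, weights                       # [batch, len_q, d], [batch, len_q, len_k]

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 3, 4))
k = rng.normal(size=(2, 5, 4))
v = rng.normal(size=(2, 5, 4))
mask = np.ones((2, 3, 5), dtype=bool)
mask[:, :, -1] = False                                # pretend the last key is padding
out, w = scaled_dot_product_attention(q, k, v, mask)
print(out.shape)                                      # (2, 3, 4)
```

Every row of `w` sums to 1 and its last column is zero, which is exactly what the `-1e9` fill is for.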

### Causal Attention

a.k.a. masked multi-head attention

See assignment 2. 


In [3]:
#helper functions
def compute_attention_heads_closure(n_heads, d_head):
    pass

def dot_product_self_attention():
    pass

def compute_attention_output_closure():
    pass
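The helper stubs above suggest splitting the model dimension into heads and merging them back afterwards. One common approach, sketched here in numpy, folds the heads into the batch axis so a single batched matmul attends within each head separately. `split_heads` and `merge_heads` are hypothetical stand-ins for `compute_attention_heads_closure` / `compute_attention_output_closure`, not a claim about the assignment's solution.

```python
import numpy as np

def split_heads(x, n_heads, d_head):
    """[batch, seq_len, n_heads * d_head] -> [batch * n_heads, seq_len, d_head]."""
    batch, seq_len, _ = x.shape
    x = x.reshape(batch, seq_len, n_heads, d_head)
    x = x.transpose(0, 2, 1, 3)                  # [batch, n_heads, seq_len, d_head]
    return x.reshape(batch * n_heads, seq_len, d_head)

def merge_heads(x, n_heads, d_head):
    """Inverse of split_heads: [batch * n_heads, seq_len, d_head] -> [batch, seq_len, n_heads * d_head]."""
    batch_heads, seq_len, _ = x.shape
    batch = batch_heads // n_heads
    x = x.reshape(batch, n_heads, seq_len, d_head)
    x = x.transpose(0, 2, 1, 3)                  # [batch, seq_len, n_heads, d_head]
    return x.reshape(batch, seq_len, n_heads * d_head)

x = np.arange(2 * 3 * 8).reshape(2, 3, 8)        # batch=2, seq_len=3, n_heads * d_head = 8
heads = split_heads(x, n_heads=2, d_head=4)
print(heads.shape)                               # (4, 3, 4)
```

Here `heads[1]` holds the second head of the first batch element, i.e. the last four features of `x[0]`, and `merge_heads` restores `x` exactly.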

In [None]:
def CausalAttention():
    """Transformer-style multi-headed causal attention.
    """
    pass
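What makes attention "causal" is the mask: position i may only attend to positions 0..i. This numpy sketch isolates that idea by skipping the learned q/k/v projections and using a single head (q = k = v = x), so it is a simplification of a full `CausalAttention` layer, not an implementation of it. The function names are hypothetical.

```python
import numpy as np

def causal_mask(seq_len):
    # Lower-triangular boolean matrix: row i is True for columns 0..i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def causal_self_attention(x):
    """Single-head causal self-attention with q = k = v = x; x: [batch, seq_len, d]."""
    d = x.shape[-1]
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(d)      # [batch, seq_len, seq_len]
    scores = np.where(causal_mask(x.shape[1]), scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stabilize before exp
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(1)
x = rng.normal(size=(1, 4, 8))
y = causal_self_attention(x)

x2 = x.copy()
x2[0, -1] += 10.0                        # perturb only the last position
y2 = causal_self_attention(x2)
print(np.allclose(y[0, :-1], y2[0, :-1]))  # True: earlier positions never see the future
```

The first output position can only attend to itself, so `y[0, 0]` equals `x[0, 0]`.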

## RNN Encoder / Decoder with Attention

In [6]:

def nmt_encoder(vocab_size, embed_dims, lstm_hidden, num_lstms):
    '''
    Encodes the input token sequences before passing them along to attention and then on to the decoder.
    '''
    input_encoder = tl.Serial(
        tl.Embedding(vocab_size, embed_dims),
        [tl.LSTM(lstm_hidden) for _ in range(num_lstms) ],
    )
    
    return input_encoder
    
def pre_attention_decoder(target_vocab_size, embed_dims, lstm_hidden):
    '''
    "Decodes" the target token sequences. The decoder hidden states are used as queries in the attention layer.
    - This is only the first decoder in this model; a second decoder runs after the attention layer.
    - Shifts the supplied token sequences right before running the layers.
    - Note: only a single LSTM is used for this decoder vs. several for the encoder. Not really sure why.
    '''
    
    pa_target_decoder = tl.Serial(
        tl.ShiftRight(),
        tl.Embedding(target_vocab_size, embed_dims),
        tl.LSTM(lstm_hidden),
    )
    return pa_target_decoder
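Shifting right means the decoder at position t only sees target tokens 0..t-1 (teacher forcing). A numpy sketch of that shift, assuming a pad id of 0 and a shift of one position; these are assumed to match `tl.ShiftRight`'s defaults in train mode, which is worth checking against the Trax docs. `shift_right` is a hypothetical helper.

```python
import numpy as np

def shift_right(tokens, pad_id=0):
    """Prepend pad_id along the sequence axis and drop the last token: [batch, seq_len] -> same shape."""
    padded = np.pad(tokens, ((0, 0), (1, 0)), constant_values=pad_id)
    return padded[:, :-1]

targets = np.array([[5, 6, 7, 1]])   # hypothetical token ids
print(shift_right(targets))          # [[0 5 6 7]]
```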


In [4]:
# Attention helper functions

def prepare_qkv_from_encoder_decoder(encoder_activations, decoder_activations, input_token_seq):
    '''
    Preparing inputs for attention layer from activations of encoder and decoder hidden states.
    Super light level of indirection that maps activations to QKV in attention.
    '''
    attention_heads=1
    
    queries = decoder_activations
    keys    = encoder_activations
    values  = encoder_activations
    
    (batch_size, seq_len) = input_token_seq.shape
    (_, decoder_len, _)   = decoder_activations.shape
    
    # if token is padding then mask = 0, otherwise 1
    mask  = jnp.where(input_token_seq == 0, 0, 1)
    
    # mask dims transform: [batch_size, seq_len] -> [batch_size, attention_heads, decoder_len, encoder_len]
    
    # add dimensions
    mask = jnp.reshape(mask, (batch_size, 1, 1, seq_len))
    
    # seq_len equals the encoder length: the embedding and LSTM layers emit one activation per input token.
    
    # adding in this way causes broadcast / dim expansion along the decoder-length axis
    mask = mask + jnp.zeros((1, attention_heads, decoder_len, 1))
    
    return (queries, keys, values, mask)
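The mask expansion above can be checked on toy shapes in plain numpy. The token ids and the decoder length of 3 are made up; the head axis stays at size 1 and broadcasts across however many heads the attention layer uses.

```python
import numpy as np

input_tokens = np.array([[4, 9, 2, 0],    # hypothetical ids; 0 is padding
                         [7, 3, 0, 0]])
decoder_len = 3                            # stands in for decoder_activations.shape[1]
batch_size, seq_len = input_tokens.shape

mask = np.where(input_tokens == 0, 0, 1)         # [batch, encoder_len]
mask = mask.reshape(batch_size, 1, 1, seq_len)   # [batch, 1, 1, encoder_len]
mask = mask + np.zeros((1, 1, decoder_len, 1))   # [batch, 1, decoder_len, encoder_len]
print(mask.shape)                                # (2, 1, 3, 4)
print(mask[1, 0])                                # every decoder row repeats 1, 1, 0, 0
```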
    

In [1]:
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    
    return tl.Serial([
        # copy the (input, target) pair: input, target, input, target
        tl.Select([0,1,0,1]),
        # pass sublayers as separate arguments: the encoder takes input, the pre-attention decoder takes target
        tl.Parallel(
            nmt_encoder(input_vocab_size, d_model, d_model, n_encoder_layers),
            pre_attention_decoder(target_vocab_size, d_model, d_model),
        ),
        # takes three inputs (encoder activations, decoder activations, input tokens) and returns
        # four (queries, keys, values, mask); the target tokens stay on the stack beneath them
        tl.Fn('PrepareAttentionInput', prepare_qkv_from_encoder_decoder, n_out=4),
        
        #https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Residual
        #TODO: Can you understand the Stack and Residual better? 
        tl.Residual(
            tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)
        ),
        # AttentionQKV returns (activations, mask); this just drops the mask.
        tl.Select([0,2]),
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)], #do you need to use d_model here? 
        tl.Dense(target_vocab_size),
        tl.LogSoftmax()
    ])

#TODO: test me.
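One way to follow the stack through `NMTAttn` without running Trax is a symbolic emulation with strings. This is a pure-Python model, not Trax code; it assumes that `AttentionQKV` maps (q, k, v, mask) to (activations, mask) and that `Residual` adds the first output back onto q, per the Trax combinator docs linked above. `select` and `run_layer` are hypothetical helpers.

```python
from typing import Callable, List, Optional, Sequence

def select(stack: List[str], indices: Sequence[int], n_in: Optional[int] = None) -> List[str]:
    # Like tl.Select: read the top n_in items (index 0 = top), push them back in `indices` order.
    if n_in is None:
        n_in = max(indices) + 1
    top, rest = stack[:n_in], stack[n_in:]
    return [top[i] for i in indices] + rest

def run_layer(stack: List[str], f: Callable[..., tuple], n_in: int) -> List[str]:
    # Pop n_in items (top first), push f's outputs (first output on top).
    return list(f(*stack[:n_in])) + stack[n_in:]

stack = ["input", "target"]
stack = select(stack, [0, 1, 0, 1])
# ['input', 'target', 'input', 'target']
stack = run_layer(stack, lambda i, t: (f"enc({i})", f"dec({t})"), n_in=2)     # Parallel(encoder, decoder)
# ['enc(input)', 'dec(target)', 'input', 'target']
stack = run_layer(stack, lambda e, d, i: (d, e, e, f"mask({i})"), n_in=3)    # PrepareAttentionInput
# ['dec(target)', 'enc(input)', 'enc(input)', 'mask(input)', 'target']
stack = run_layer(stack, lambda q, k, v, m: (f"res_attn({q})", m), n_in=4)   # Residual(AttentionQKV)
# ['res_attn(dec(target))', 'mask(input)', 'target']
stack = select(stack, [0, 2])                                               # drop the mask
stack = run_layer(stack, lambda x: (f"log_probs({x})",), n_in=1)            # LSTMs, Dense, LogSoftmax
print(stack)  # ['log_probs(res_attn(dec(target)))', 'target']
```

So the model should end with two items on the stack: the log-probabilities and the untouched target tokens.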

<img src = "NMTModel.png">

#### Run the Model Once

Test the model to make sure you understand the inputs and outputs and to make sure there are no errors in how you coded it.


In [None]:
def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    pass
     
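#Note: you can implement the function above, or pass partially translated targets (with a batch dim) to the model directly
#and get a probability distribution back.

#Note the input is a pair. NMTAttn(...) only builds the model; you call the model itself, after initializing or loading weights:
# model = NMTAttn(mode='eval')
# log_probs, _ = model((input_tokens, padded_with_batch))

#How does that work with tl.Select? The pair is unpacked onto the stack with input_tokens on top
#(see the stack experiments below), so tl.Select([0,1,0,1]) copies both token arrays once.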
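Following the note above, `next_symbol` could pad the partial translation, run the model once, read the log-probabilities at the position of the next token, and sample with a temperature. In this numpy sketch, `model_fn` is a hypothetical callable standing in for an initialized `NMTAttn` model, padding to a power of two is an assumption (it keeps the number of distinct input shapes small), and the sampling is plain numpy rather than any Trax helper.

```python
import numpy as np

def next_symbol_sketch(model_fn, input_tokens, cur_output_tokens, temperature, rng):
    """model_fn: (input_tokens [1, in_len], targets [1, out_len]) -> (log_probs [1, out_len, vocab], targets)."""
    token_length = len(cur_output_tokens)
    # Pad the partial translation to the next power of 2 strictly above token_length.
    padded_length = 2 ** int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]       # add the batch dim
    log_probs, _ = model_fn((input_tokens, padded_with_batch))
    # Targets are shifted right inside the model, so position token_length predicts the next token.
    logits = log_probs[0, token_length, :]
    if temperature == 0.0:
        return int(np.argmax(logits))                   # greedy decoding
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

# Hypothetical stand-in model: the same distribution at every position, favoring token 2.
def fake_model(inputs):
    _, targets = inputs
    log_p = np.log(np.array([0.1, 0.1, 0.6, 0.1, 0.1]))
    return np.tile(log_p, (1, targets.shape[1], 1)), targets

rng = np.random.default_rng(0)
print(next_symbol_sketch(fake_model, np.array([[3, 4, 1]]), [2, 2], 0.0, rng))  # 2
```

Temperature 0 reduces to greedy argmax; higher temperatures flatten the distribution before sampling.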

### Broadcasting Practice

In [5]:
c10 = jnp.zeros((1, 10,))
r10 = jnp.zeros((10, 1,))



In [6]:
cr10 = c10 + r10

In [7]:
cr10.shape

(10, 10)
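The rule behind `cr10.shape` (and the attention-mask expansion earlier) can be written out by hand: align the shapes from the right, treat missing leading dims as 1, and require each pair of dims to be equal or to contain a 1. `broadcast_shape` is a hypothetical helper, not a numpy function.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Result shape of broadcasting shapes a and b under numpy's rules."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)   # a size-1 dim stretches to match the other
    return tuple(reversed(out))

print(broadcast_shape((1, 10), (10, 1)))            # (10, 10), same as cr10.shape
print(broadcast_shape((2, 1, 1, 4), (1, 1, 3, 1)))  # (2, 1, 3, 4), like the attention mask
```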

### Trax Practice

* TODO: read the [Layers Intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html)
* Try to create a CreateFirst where m is the first thing returned. What happens?
* Download the ungraded Trax lab from Coursera NLP C4 W1 and experiment with it.

In [43]:
# Let's create a similar network to what we have above but let's make it simple and deterministic

#TODO: could you change this to work with jax DeviceArrays just to see how it works? 

def Addition():
    layer_name = "Addition"  # don't forget to give your custom layer a name to identify

    # Custom function for the custom layer
    def func(x, y):
        return jnp.add(x,y)

    return tl.Fn(layer_name, func)

def Multiplication():
    layer_name = "Multiplication"  # don't forget to give your custom layer a name to identify

    # Custom function for the custom layer
    def func(x, y):
        return jnp.multiply(x, y)  # element-wise

    return tl.Fn(layer_name, func)

# Simple layer to create one new argument. Similar stack effect to prepare_qkv_from_encoder_decoder
def CreateOne():
    layer_name = "CreateOne"

    # NOTE: func returns 4 values, but tl.Fn defaults to n_out=1. Passing n_out=4 (as NMTAttn does
    # for PrepareAttentionInput) would make Trax's stack bookkeeping match what is actually returned;
    # e.g. simple.n_out below reports 8 even though calling simple returns 11 arrays.
    def func(a, b, c):
        m = jnp.array([1000.0, 2000.0])
        return (a, b, c, m)

    return tl.Fn(layer_name, func)

def Divide():
    layer_name = "Divide"

    def func(a, b):
        return jnp.divide(a, b)

    return tl.Fn(layer_name, func)

def Noop():
    layer_name = "Noop"

    def func(a, b):
        return (a, b)

    return tl.Fn(layer_name, func)

simple = tl.Serial([
    tl.Select([0,1,0,1,0,1,0,1,0,1]),
    CreateOne()
    
])

topOfStack = tl.Serial([
    tl.Select([0,1,0,1]),
    CreateOne(),
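    # wanted n_in=4, but any value above 2 raises an error: CreateOne is declared with
    # tl.Fn's default n_out=1, so Trax believes only 2 items are on the stack at this point
    tl.Select([0], n_in=2)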
])

invert = tl.Serial([
    tl.Select([1,0]), # 0 references top of the stack. first position is top of new stack.  
])

invert2 = tl.Serial([
    tl.Select([1,0]), 
    Noop(),
])

simple.n_out

8

In [44]:
in_tok = jnp.array([1.0, 2.0])
out_tok = jnp.array([100.0, 200.0])

simple((in_tok, out_tok))


#It appears that when you call a tl.Serial layer
# - The first argument you give it goes on the top of the stack.
# - The last argument goes on the bottom.


#note: in the returned output, the first n values are the items returned, in order, by the last function.
#in this case it's a,b,c,m in spots 0-3 of the returned tuple below.

(DeviceArray([1., 2.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32),
 DeviceArray([1000., 2000.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32))

In [35]:
topOfStack((in_tok, out_tok))

(DeviceArray([1., 2.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32),
 DeviceArray([1000., 2000.], dtype=float32),
 DeviceArray([100., 200.], dtype=float32))

In [39]:
invert((in_tok, out_tok))

(DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32))

In [40]:
#this implies 
# - the first argument to the function is the top of the stack. 
# - the first returned value from a function goes in the first position on a stack. 
invert2((in_tok, out_tok))

(DeviceArray([100., 200.], dtype=float32),
 DeviceArray([1., 2.], dtype=float32))

##### Stack rules

* first argument to layer is top of stack
* first item in the tuple a layer returns goes on top of the stack.
* 0 in `tl.Select` refers to top of stack
* first position in input array to `tl.Select` is the top of the stack. 
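The rules above can be encoded in a few lines of pure Python and checked against the recorded `invert` and `topOfStack` outputs. This is an emulation, not Trax code; the default `n_in = max(indices) + 1` is assumed to match `tl.Select`, and `select` / `run_fn` are hypothetical helpers.

```python
def select(stack, indices, n_in=None):
    # index 0 is the top of the stack; the first listed index becomes the new top
    if n_in is None:
        n_in = max(indices) + 1
    return [stack[i] for i in indices] + stack[n_in:]

def run_fn(stack, f, n_in):
    # the first argument is the top item; the first returned value becomes the new top
    return list(f(*stack[:n_in])) + stack[n_in:]

stack = ["in_tok", "out_tok"]        # calling a layer on (in_tok, out_tok): in_tok is on top
print(select(stack, [1, 0]))         # ['out_tok', 'in_tok'], like invert

s = select(stack, [0, 1, 0, 1])
s = run_fn(s, lambda a, b, c: (a, b, c, "m"), n_in=3)   # CreateOne
print(s)                             # ['in_tok', 'out_tok', 'in_tok', 'm', 'out_tok']
print(select(s, [0], n_in=2))        # ['in_tok', 'in_tok', 'm', 'out_tok'], matches topOfStack above
print(select(s, [0], n_in=4))        # ['in_tok', 'out_tok'], if CreateOne declared n_out=4
```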