$$
\newcommand{\mat}[1]{\boldsymbol {#1}}
\newcommand{\mattr}[1]{\boldsymbol {#1}^\top}
\newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}}
\newcommand{\vec}[1]{\boldsymbol {#1}}
\newcommand{\vectr}[1]{\boldsymbol {#1}^\top}
\newcommand{\rvar}[1]{\mathrm {#1}}
\newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}}
\newcommand{\diag}{\mathop{\mathrm {diag}}}
\newcommand{\set}[1]{\mathbb {#1}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\bb}[1]{\boldsymbol{#1}}
$$

# CS236781: Deep Learning
# Tutorial 7: Attention

## Introduction

In this tutorial, we will cover:

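- Attention: input soft attention and self attention
- Masked softmax for variable-length sequences
- MLP (additive) attention
- Sentiment analysis with an attention-based RNN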

In [1]:
# Setup
%matplotlib inline
import os
import sys
import time
import torch
import matplotlib.pyplot as plt

In [2]:
plt.rcParams['font.size'] = 20
data_dir = os.path.expanduser('~/.pytorch-datasets')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Theory Reminders

## Attention

Intuitively, some parts of the input may be more important than others.

An **Attention** mechanism allows the model to "focus" on, i.e. give a *greater weight* to, specific parts of the input or of some other intermediate representation in the model.
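In its simplest form, attention turns a set of relevance scores into non-negative weights with a softmax and returns the weighted average of the inputs. A minimal sketch, with made-up inputs and scores (not taken from any model in this tutorial):

```python
import torch

# Three "parts" of an input, each described by 2 features (made-up values)
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

# Hypothetical relevance scores; in a real model these come from learned layers
scores = torch.tensor([2.0, 0.0, 0.0])

# Softmax turns scores into weights that are non-negative and sum to 1
w = torch.softmax(scores, dim=0)

# The attended output is the weighted average of the parts: (3,) @ (3, 2) -> (2,)
context = w @ x
print(w, context)  # the first part gets most of the weight; context is about [0.89, 0.21]
```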

Example from an image captioning [paper](https://arxiv.org/pdf/1502.03044.pdf) (K. Xu et al. 2015):

<img src="img/attn_ic1.png" width="900"/>

<img src="img/attn_ic2.png" width="800"/>


### Input soft attention

One place to apply attention is to the **input features**.

In the context of our RNN model, we can change its hidden state update to:


$$
\begin{align}
\vec{a}_t &= \sigma\left( \mat{W}_{ha} \vec{h}_{t-1} + \mat{W}_{xa} \vec{x}_t+ \vec{b}_a\right) \\
\vec{g}_t &= \mathrm{softmax}(\alpha \vec{a}_t) \\
\vec{h}_t &= \varphi_h\left( \mat{W}_{hh} \vec{h}_{t-1} + \mat{W}_{xh} (\vec{x}_t \odot \vec{g}_t)+ \vec{b}_h\right) \\
\end{align}
$$
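Because $\vec{a}_t$ comes out of a sigmoid, its entries lie in $(0,1)$. With $\alpha=1$, any two weights in $\vec{g}_t$ therefore differ by less than a factor of $e$, i.e. the gating is close to uniform; a larger $\alpha$ makes it sharper. A quick numerical check, using made-up values for $\vec{a}_t$:

```python
import math
import torch

# Made-up pre-activations spread over (0, 1), like the sigmoid outputs a_t
a_t = torch.linspace(0.05, 0.95, steps=8)

ratios = {}
for alpha in (1.0, 10.0):
    g_t = torch.softmax(alpha * a_t, dim=0)
    # Largest/smallest gate weight equals exp(alpha * (max(a_t) - min(a_t)))
    ratios[alpha] = (g_t.max() / g_t.min()).item()
    bound = math.exp(alpha * 0.9)
    print(f"alpha={alpha}: sum={g_t.sum().item():.3f}, "
          f"max/min={ratios[alpha]:.1f}, exp(0.9*alpha)={bound:.1f}")
```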


In [3]:
import torch.nn as nn

class RNNLayerInputAttn(nn.Module):
    def __init__(self, in_dim, h_dim, out_dim, phi_h=torch.tanh, phi_y=torch.sigmoid):
        super().__init__()
        self.phi_h, self.phi_y = phi_h, phi_y
        
        # Attention parameters
        self.fc_xa = nn.Linear(in_dim, in_dim, bias=False)
        self.fc_ha = nn.Linear(h_dim, in_dim, bias=True)
        
        # Regular RNN parameters
        self.fc_xh = nn.Linear(in_dim, h_dim, bias=False)
        self.fc_hh = nn.Linear(h_dim, h_dim, bias=True)
        self.fc_hy = nn.Linear(h_dim, out_dim, bias=True)
        
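    def forward(self, xt, h_prev=None):
        if h_prev is None:
            # Start from a zero hidden state on the same device/dtype as the input
            h_prev = torch.zeros(xt.shape[0], self.fc_hh.in_features,
                                 device=xt.device, dtype=xt.dtype)
            
        # Calculate the attention gating gt: a weight for each feature of x
        at = torch.sigmoid(self.fc_xa(xt) + self.fc_ha(h_prev))
        # Softmax over the feature dimension; this corresponds to alpha=1 in the equation
        gt = torch.softmax(at, dim=1)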
        
        # Apply regular RNN with gated input
        ht = self.phi_h(self.fc_xh(xt * gt) + self.fc_hh(h_prev))
        
        yt = self.fc_hy(ht)
        
        if self.phi_y is not None:
            yt = self.phi_y(yt)
        
        return yt, ht
        

We can interpret this as a soft (differentiable) gating of the input.

This makes sense for image captioning, where we want to emphasize image regions based on their feature maps.

What about our sentiment analysis task?

### Self attention

Another place to apply attention in the context of RNNs is to the **hidden states**.

In an ICLR 2017 [paper](https://arxiv.org/pdf/1703.03130.pdf), Lin et al. proposed
a self-attention mechanism for building sentence embeddings, and applied it to sentiment analysis, among other tasks.

<img src="img/self_attn_sa.png" width="700" />

The problem with applying attention to the hidden state vectors is that their number changes from batch to batch,
depending on the sentence length.

This approach creates a **sentence embedding** $M$ of a fixed size:

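$$
\begin{align}
\mat{H}_T &= \left[ \vectr{h}_1; \dots; \vectr{h}_T \right] \in\set{R}^{T\times d_h}\\
\mat{A} &= \mathrm{softmax}\left(\mat{W}_{s2} \tanh\left( \mat{W}_{s1} \mattr{H}_T \right) \right) \in\set{R}^{r\times T},\ 
\mat{W}_{s1}\in\set{R}^{d_a \times d_h},\ \mat{W}_{s2}\in\set{R}^{r \times d_a} \\
\mat{M} &= \mat{A}\mat{H}_T \in\set{R}^{r\times d_h}
\end{align}
$$

The softmax is applied along the second dimension of $\mat{A}$, i.e. over the $T$ time steps, so each of its $r$ rows holds attention weights over the sentence and $\mat{M}$ has the same size for any sentence length.

The sentence embedding $M$ is then fed into an FC classifier to produce the prediction.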
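A minimal PyTorch sketch of these equations, assuming batch-first hidden states of shape `(B, T, d_h)`. The class name `StructuredSelfAttention` is hypothetical and padding is ignored; the point is that $\mat{M}$ has shape $(r, d_h)$ per sentence regardless of $T$:

```python
import torch
import torch.nn as nn

class StructuredSelfAttention(nn.Module):
    """Sketch of M = A H_T with A = softmax(W_s2 tanh(W_s1 H_T^T)); names are illustrative."""
    def __init__(self, h_dim: int, a_dim: int, r: int):
        super().__init__()
        self.W_s1 = nn.Linear(h_dim, a_dim, bias=False)  # W_s1 in R^{d_a x d_h}
        self.W_s2 = nn.Linear(a_dim, r, bias=False)      # W_s2 in R^{r x d_a}

    def forward(self, H: torch.Tensor):
        # H: (B, T, d_h), the stacked hidden states of each sentence in the batch
        # Row-wise linear layers compute (W_s2 tanh(W_s1 H^T))^T: (B, T, r) -> (B, r, T)
        scores = self.W_s2(torch.tanh(self.W_s1(H))).transpose(1, 2)
        # Softmax over the T time steps: each of the r rows of A sums to 1
        A = torch.softmax(scores, dim=-1)
        # (B, r, T) @ (B, T, d_h) -> (B, r, d_h), independent of T
        M = torch.bmm(A, H)
        return M, A

attn = StructuredSelfAttention(h_dim=16, a_dim=8, r=3)
shapes = {}
for T in (5, 12):
    M, A = attn(torch.randn(2, T, 16))
    shapes[T] = (tuple(M.shape), tuple(A.shape))
    print(T, shapes[T])  # M is (2, 3, 16) for both sentence lengths
```

One way to feed $\mat{M}$ to an FC classifier is to flatten it into a vector of length $r \cdot d_h$, which is fixed for every sentence.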

*Self exercise:* Modify our `SentimentRNN` and add the Self-Attention layer.

In [4]:
from torch import Tensor

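def valid_softmax2d(X: Tensor, valid_len: Tensor=None):
    """
    Applies masked softmax on the last dimension of a 2d tensor.
    :param X: A 2d Tensor of shape (B,S).
    :param valid_len: A tensor of shape (B,) with the valid length of each row,
    or None for an unmasked softmax.
    :return: Softmax of X along the last dimension, where positions at or
    after valid_len get zero weight.
    """
    if valid_len is None:
        return torch.softmax(X, dim=-1)
    
    assert X.ndim == 2
    assert valid_len.ndim == 1
    assert X.shape[0] == valid_len.shape[0]
    
    B, S = X.shape
    # mask[i, j] is True for the valid positions j < valid_len[i]
    valid_len = valid_len.to(X.device)
    mask = torch.arange(S, device=X.device)[None,:] < valid_len[:,None]
    
    # masked_fill returns a new tensor, so the caller's X is not modified in-place
    X = X.to(torch.float).masked_fill(~mask, float('-inf'))
    return torch.softmax(X, dim=-1)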
    
valid_softmax2d(torch.ones(3,4), valid_len=torch.tensor([3, 2, 1]))

#%%

def valid_softmax3d(X: Tensor, valid_len: Tensor=None):
    """
    Applied masked softmax on the last dimension of a 3d tensor.
    :param X: A 3d Tensor fo shape (B,E,S).
    :param valid_len: A tensor of shape (B,) representing the valid length 
    for each sample. Can be None, which means softmax will be applied 
    without any masking.
    :return: X after setting elements after the valid_length to -inf 
    on the last dimension and applying softmax.
    """
    assert X.ndim == 3
    B, E, S = X.shape
    X2d = X.view(B*E, S)
    
    if valid_len is not None:
        valid_len = valid_len.repeat_interleave(E, dim=0)
        
    X2d_s = valid_softmax2d(X2d, valid_len)
    return X2d_s.view(B,E,S)

#%%

X = torch.rand(3,4,5)
v = torch.tensor([2,3,1])

valid_softmax3d(X, v)


tensor([[[0.5363, 0.4637, 0.0000, 0.0000, 0.0000],
         [0.4384, 0.5616, 0.0000, 0.0000, 0.0000],
         [0.5111, 0.4889, 0.0000, 0.0000, 0.0000],
         [0.3149, 0.6851, 0.0000, 0.0000, 0.0000]],

        [[0.2433, 0.3085, 0.4482, 0.0000, 0.0000],
         [0.2871, 0.3667, 0.3463, 0.0000, 0.0000],
         [0.2354, 0.3299, 0.4346, 0.0000, 0.0000],
         [0.3343, 0.4579, 0.2078, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
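One caveat of masking with `-inf`, as `valid_softmax2d` does: if a sample has `valid_len = 0`, its whole row is `-inf` and the softmax returns NaN, which then propagates into the loss. A small self-contained demonstration, together with one common workaround (filling with the most negative finite float; this is an alternative choice, not what the functions above do):

```python
import torch

X = torch.ones(2, 4)
valid_len = torch.tensor([2, 0])  # the second row has no valid positions at all
mask = torch.arange(4)[None, :] < valid_len[:, None]

# -inf masking: a fully masked row becomes softmax(-inf, ..., -inf) = nan
out_inf = torch.softmax(X.masked_fill(~mask, float('-inf')), dim=-1)

# Alternative: fill with the most negative finite float instead of -inf.
# Valid rows are unchanged; a fully masked row becomes uniform instead of nan.
out_min = torch.softmax(X.masked_fill(~mask, torch.finfo(X.dtype).min), dim=-1)

print(out_inf)
print(out_min)
```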

In [5]:
class MLPAttention(nn.Module):
    def __init__(self, q_dim, k_dim, v_dim, h_dim, dropout=0.):
        super().__init__()
        self.wk = nn.Linear(k_dim, h_dim, bias=False)
        self.wq = nn.Linear(q_dim, h_dim, bias=False)
        self.v  = nn.Linear(h_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q: Tensor, k: Tensor, v: Tensor, valid_len: Tensor=None):
        """
        :param q: Queries tensor of shape (B, Q, q_dim)
        :param k: Keys tensor of shape (B, KV, k_dim)
        :param v: Values tensor of shape (B, KV, v_dim)
        :param valid_len: Sequence lengths tensor of shape (B,).
        :return: Attended values tensor, of shape (B, Q, v_dim).
        """
        # (B, KV, k_dim) -> (B, KV, h_dim) -> (B, 1, KV, h_dim)
        wk_k = self.wk(k).unsqueeze(1)
        
        # (B, Q, q_dim)  -> (B, Q, h_dim)  -> (B, Q, 1, h_dim)
        wq_q = self.wq(q).unsqueeze(2)
        
        # (B, Q, KV, h_dim)
        z1 = torch.tanh(wq_q + wk_k)
        
        # (B, Q, KV, 1) -> (B, Q, KV)
        z2 = self.v(z1).squeeze(dim=-1)
        
        a = valid_softmax3d(z2, valid_len)
        a = self.dropout(a)
        
        # (B, Q, KV) * (B, KV, v_dim) = (B, Q, v_dim)
        return torch.bmm(a, v)
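Written out, `MLPAttention` computes the additive (MLP) attention score for a query $\vec{q}$ and each key $\vec{k}_j$, normalizes the scores with a softmax over the first $L$ keys ($L$ being the sample's `valid_len`; later positions get zero weight), and averages the values. Here $\mat{W}_q$, $\mat{W}_k$ and $\vec{w}_v$ denote the weights of `wq`, `wk` and `v`, and dropout on the weights is omitted:

$$
\begin{align}
e(\vec{q}, \vec{k}_j) &= \vectr{w}_v \tanh\left( \mat{W}_q \vec{q} + \mat{W}_k \vec{k}_j \right) \\
\beta_j &= \frac{\exp\left(e(\vec{q}, \vec{k}_j)\right)}{\sum_{j'=1}^{L} \exp\left(e(\vec{q}, \vec{k}_{j'})\right)},\quad j = 1,\dots,L \\
\mathrm{out}(\vec{q}) &= \sum_{j=1}^{L} \beta_j \vec{v}_j
\end{align}
$$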

In [6]:
    
keys = torch.ones(2, 10, 2, dtype=torch.float) 
vals = torch.arange(40, dtype=torch.float).reshape(1, 10, 4).repeat_interleave(2, dim=0)
ques = torch.ones(2, 1, 2, dtype=torch.float)

mlp_attn = MLPAttention(ques.shape[-1], keys.shape[-1], vals.shape[-1], 100, dropout=0.1)
out = mlp_attn(ques, keys, vals, valid_len=torch.tensor([1, 6]))
print(out, out.shape)


tensor([[[ 0.0000,  1.1111,  2.2222,  3.3333]],

        [[ 7.4074,  8.3333,  9.2593, 10.1852]]], grad_fn=<BmmBackward>) torch.Size([2, 1, 4])
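Why these numbers? All keys (and all queries) are identical, so every score is equal and the masked softmax gives uniform weights $1/L$ over the first $L$ = `valid_len` values. Without dropout, the result would simply be the mean of the valid value rows, as this self-contained check (which rebuilds the same `vals` tensor) shows. The printed output above differs because `Dropout(p=0.1)` was active: it zeroes some attention weights and scales the remaining ones by $1/0.9$, e.g. $1/0.9 \approx 1.1111$ in the first sample.

```python
import torch

# Same toy values as above: batch of 2, 10 value vectors of dimension 4
vals = torch.arange(40, dtype=torch.float).reshape(1, 10, 4).repeat_interleave(2, dim=0)
valid_len = torch.tensor([1, 6])

# Equal scores -> uniform weights 1/valid_len on the valid positions, 0 elsewhere
mask = torch.arange(10)[None, :] < valid_len[:, None]   # (B, KV)
weights = mask.float() / valid_len[:, None].float()     # (B, KV)

# (B, 1, KV) @ (B, KV, 4) -> (B, 1, 4)
expected = torch.bmm(weights.unsqueeze(1), vals)
print(expected)  # [[0, 1, 2, 3]] for the first sample, [[10, 11, 12, 13]] for the second
```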


#### Back to sentiment analysis

In [7]:
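# Note: Field, SST and BucketIterator belong to the legacy torchtext API
# (moved to torchtext.legacy in torchtext 0.9 and removed in 0.12)
import torchtext.data
import torchtext.datasets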

# torchtext Field objects parse text (e.g. a review) and create a tensor representation
# This Field object will be used for tokenizing the movie reviews text
review_parser = torchtext.data.Field(
    sequential=True, use_vocab=True, lower=True,
    init_token='<sos>', eos_token='<eos>', dtype=torch.long,
    tokenize='spacy', tokenizer_language='en_core_web_sm'
)

# This Field object converts the text labels into numeric values (0,1,2)
label_parser = torchtext.data.Field(
    is_target=True, sequential=False, unk_token=None, use_vocab=True
)


# Load SST, tokenize the samples and labels
# ds_X are Dataset objects which will use the parsers to return tensors
ds_train, ds_valid, ds_test = torchtext.datasets.SST.splits(
    review_parser, label_parser, root=data_dir
)

review_parser.build_vocab(ds_train)
label_parser.build_vocab(ds_train)

n_train = len(ds_train)
print(f'Number of training samples: {n_train}')
print(f'Number of test     samples: {len(ds_test)}')

Number of training samples: 8544
Number of test     samples: 2210


In [59]:
BATCH_SIZE = 4

# BucketIterator creates batches with samples of similar length
# to minimize the number of <pad> tokens in the batch.
dl_train, dl_valid, dl_test = torchtext.data.BucketIterator.splits(
    (ds_train, ds_valid, ds_test), batch_size=BATCH_SIZE,
    shuffle=True, device=device)
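Bucketing helps because every sequence in a batch is padded to the longest one. The sketch below uses `torch.nn.utils.rnn.pad_sequence` (not the torchtext iterator) on made-up token indices to show the `(S, B)` layout and how per-sample valid lengths, i.e. the `valid_len` argument of the masked softmax above, could be recovered. `PAD_IDX = 1` is an assumption; read the real index from the vocabulary.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 1  # assumed padding index

# Three made-up token-index sequences of different lengths
seqs = [torch.tensor([5, 8, 2]),
        torch.tensor([7, 3]),
        torch.tensor([4, 9, 6, 2, 3])]

# Pad to the longest sequence; sequence dimension first, as in the batches above
X = pad_sequence(seqs, padding_value=PAD_IDX)  # shape (S, B) = (5, 3)

# Recover each sample's valid length by counting non-pad tokens
# (assumes PAD_IDX never appears inside a real sequence)
valid_len = (X != PAD_IDX).sum(dim=0)          # tensor([3, 2, 5])

# Fraction of the batch that is padding: 5 of 15 entries here
pad_fraction = (X == PAD_IDX).float().mean().item()
print(X.shape, valid_len, pad_fraction)
```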

In [52]:
class SentimentRNN(nn.Module):
    def __init__(self, in_dim, embedding_dim, h_dim, out_dim):
        super().__init__()
        
        # nn.Embedding converts from token index to dense tensor
        self.embedding = nn.Embedding(in_dim, embedding_dim)
        
        # PyTorch multilayer GRU RNN
        self.rnn = nn.GRU(embedding_dim, h_dim, num_layers=2, bias=False)
        
        # Our custom MLP attention
        self.attn = MLPAttention(h_dim, h_dim, h_dim, h_dim)
        
        # Output layer to create class scores
        self.out_fc = nn.Linear(h_dim, out_dim, bias=True)
        
        # To convert class scores to log-probability we'll apply log-softmax
        self.log_softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, X):
        # X shape: (S, B) Note batch dim is not first!
        S, B = X.shape
        embedded = self.embedding(X) # embedded shape: (S, B, E)
        
        # GRU returns all hidden states (S, B, H)
        h, _ = self.rnn(embedded)
        
        # Transpose to (B, S, H)
        h = h.transpose(0, 1)
        
        # Apply self-attention to hidden states -> (B, S, H) and transpose -> (B, H, S)
        # This gives us S self-weighted hidden states.
        # Note: no valid_len is passed, so <pad> positions are attended to as well.
        a = self.attn(h, h, h)
        a = a.transpose(-2, -1)
        
        # Alternative sentence embedding: apply attention to hidden states
        # (B, H, S) * (B, S, H) -> (B, H, H) -> (B, H*H); out_fc would then need H*H inputs
        # m = torch.bmm(a, h).view(B, -1)
        
        # Create sentence embedding: average weighted hidden states over sequence
        # (B, H, S) -> (B, H)
        m = torch.mean(a, dim=2)
        
        # Create output scores: (B, out_dim)
        yhat = self.out_fc(m)
        
        # Class scores to log-probability
        yhat_log_proba = self.log_softmax(yhat)
        
        return yhat_log_proba
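`torch.mean(a, dim=2)` averages over all `S` positions, including `<pad>` ones. If valid lengths are available, a masked mean restricts the average to real tokens. A hypothetical helper `masked_mean`, sketched on a tiny batch-first `(B, S, H)` tensor (i.e. the attention output before the transpose):

```python
import torch

def masked_mean(a: torch.Tensor, valid_len: torch.Tensor) -> torch.Tensor:
    """Average a (B, S, H) tensor over S, using only the first valid_len[i] positions."""
    S = a.shape[1]
    # (B, S, 1) boolean mask of the valid positions
    mask = (torch.arange(S, device=a.device)[None, :] < valid_len[:, None]).unsqueeze(-1)
    summed = (a * mask).sum(dim=1)  # (B, H)
    # clamp avoids division by zero for empty sequences
    return summed / valid_len[:, None].clamp(min=1).to(a.dtype)

a = torch.arange(12, dtype=torch.float).reshape(2, 3, 2)  # (B=2, S=3, H=2)
m = masked_mean(a, torch.tensor([3, 1]))
print(m)                     # [[2, 3], [6, 7]]
print(torch.mean(a, dim=1))  # [[2, 3], [8, 9]]: the second row also averages padded positions
```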

In [53]:
INPUT_DIM = len(review_parser.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = 3

model = SentimentRNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
model

SentimentRNN(
  (embedding): Embedding(15482, 100)
  (rnn): GRU(100, 128, num_layers=2, bias=False)
  (attn): MLPAttention(
    (wk): Linear(in_features=128, out_features=128, bias=False)
    (wq): Linear(in_features=128, out_features=128, bias=False)
    (v): Linear(in_features=128, out_features=1, bias=False)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (out_fc): Linear(in_features=128, out_features=3, bias=True)
  (log_softmax): LogSoftmax()
)

In [54]:
x0, y0 = next(iter(dl_train))

In [55]:
x0.shape

torch.Size([51, 4])

In [56]:
yhat0 = model(x0)
print(yhat0, yhat0.shape)

tensor([[-1.1210, -0.9384, -1.2630],
        [-1.0988, -0.9599, -1.2595],
        [-1.1440, -1.0450, -1.1094],
        [-1.1052, -0.9209, -1.3068]], grad_fn=<LogSoftmaxBackward>) torch.Size([4, 3])


In [57]:
def train(model, optimizer, loss_fn, dataloader, max_epochs=4, max_batches=200):
    for epoch_idx in range(max_epochs):
        total_loss, num_correct = 0, 0
        num_batches, num_samples = 0, 0
        start_time = time.time()

        for batch_idx, batch in enumerate(dataloader):
            X, y = batch.text, batch.label

            # Forward pass
            y_pred_log_proba = model(X)

            # Backward pass
            optimizer.zero_grad()
            loss = loss_fn(y_pred_log_proba, y)
            loss.backward()

            # Weight updates
            optimizer.step()

            # Accumulate loss and accuracy over the batches actually seen
            total_loss += loss.item()
            y_pred = torch.argmax(y_pred_log_proba, dim=1)
            num_correct += torch.sum(y_pred == y).float().item()
            num_batches += 1
            num_samples += y.shape[0]

            if batch_idx == max_batches-1:
                break
                
        print(f"Epoch #{epoch_idx}, loss={total_loss/num_batches:.3f}, accuracy={num_correct/num_samples:.3f}, elapsed={time.time()-start_time:.1f} sec")

In [58]:
import torch.optim as optim

rnn_model = SentimentRNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)

optimizer = optim.Adam(rnn_model.parameters(), lr=1e-3)

# Recall: LogSoftmax + NLL is equiv to CrossEntropy on the class scores
loss_fn = nn.NLLLoss()

train(rnn_model, optimizer, loss_fn, dl_train)

Epoch #0, loss=1.048, accuracy=0.424, elapsed=7.7 sec
Epoch #1, loss=1.040, accuracy=0.469, elapsed=7.9 sec
Epoch #2, loss=1.025, accuracy=0.471, elapsed=7.7 sec
Epoch #3, loss=1.028, accuracy=0.476, elapsed=8.4 sec
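The comment in the training cell recalls that `LogSoftmax` followed by `NLLLoss` is equivalent to cross-entropy on the raw class scores. A quick numerical check with random scores and made-up targets:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
scores = torch.randn(4, 3)        # (B, num_classes) raw class scores
y = torch.tensor([0, 2, 1, 2])    # made-up targets

nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(scores), y)
ce = nn.CrossEntropyLoss()(scores, y)
print(nll.item(), ce.item())  # equal up to floating-point error
```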


In [None]:
class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0.):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # PyTorch's LSTM takes the input size as its first argument
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        X = self.embedding(X)  # X shape: (batch_size, seq_len, embed_size)
        # RNN needs first axis to be timestep, i.e., seq_len
        X = X.transpose(0, 1)
        # Without an explicit initial state, nn.LSTM starts from zero hidden and cell states
        out, state = self.rnn(X)
        # out shape: (seq_len, batch_size, num_hiddens)
        # state is a tuple (h_n, c_n), each of shape (num_layers, batch_size, num_hiddens),
        # where h_n is the final hidden state and c_n the final memory cell
        return out, state

**Image credits**

Some images in this tutorial were taken and/or adapted from:

- Zhang et al., Dive into Deep Learning, 2019
- Fundamentals of Deep Learning, Nikhil Buduma, O'Reilly 2017
- Andrej Karpathy, http://karpathy.github.io
- MIT 6.S191
- Stanford cs231n
- K. Xu et al. 2015, https://arxiv.org/abs/1502.03044