In [1]:
## why not put attention heads as 3rd dimension? (BATCH_SIZE, ATTENTION_HEADS, SEQ_LEN, D_MODEL)
## Apply padding masks to multi head attention
## how to stack transformer encoder decoder layers

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# from transformers_si.model_blocks.transformer import Transformers

In [3]:
from utils.logging import logs

In [4]:
ROOT_DIR = '/Users/adityarustagi/Documents/self-implementations/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Building Blocks

#### Scaled Dot Product Attention

In [5]:
class ScaledDotProductAttention(nn.Module) :

    def __init__(self, 
                 n_heads:int = 8,
                 d_model:int = 512,
                 mask:bool = False
        ) -> None :

        """
        Args:
            n_heads (int): Number of heads in the multi head attention. Defualts to 8
            d_model (int, optional): Dimension of the input. Defaults to 512.
            mask (bool, optional): Whether to apply masking. Defaults to False
        """

        super(ScaledDotProductAttention, self).__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.mask = mask
        self.d_k = int(d_model/n_heads)

    def forward(self,
                key : torch.Tensor,
                query : torch.Tensor,
                value : torch.Tensor
        ) -> torch.Tensor :

        """
        Calculate scaler dot product of key, query and values as described in https://arxiv.org/pdf/1706.03762.pdf

        Args:
            key (torch.Tensor): Key tensor. Shape = (n_heads, batch_size, seq_len, d_model/n_heads)
            query (torch.Tensor): Query tensor. Shape = (n_heads, batch_size, seq_len, d_model/n_heads)
            value (torch.Tensor): Value tensor. Shape = (n_heads, batch_size, seq_len, d_model/n_heads)

        Returns:
            value_with_attention: Value with attention applied. Shape = (n_heads, batch_size, seq_len, d_model/n_heads)
        """

        # assert key.size() == query.size() == value.size(), "Key, query and value must have same shape"

        batch_size, seq_len = key.size(1), key.size(2)

        attention_scores = torch.matmul(query, key.transpose(2, 3))/torch.sqrt(torch.tensor(self.d_k))
        attention_scores = torch.softmax(attention_scores, dim = 3)
        
        if self.mask :
            mask = torch.ones(self.n_heads, batch_size, seq_len, seq_len)
            mask = torch.tril(mask)
            attention_scores = torch.matmul(attention_scores, mask)
            
        value_with_attention = torch.matmul(attention_scores, value)

        return value_with_attention, attention_scores

#### Multi Head Attention

In [6]:
class MultiHeadAttention(nn.Module) :
    
    def __init__(self, 
                 n_head: int = 8, 
                 d_model: int = 512, 
                 dropout: float = 0.1, 
                 mask: bool = False,
                 self_attention:bool = True
        ) :

        """
        Args:
            n_head (int): Number of heads. Defaults to 8.
            d_model (int): Dimension of input. Defaults to 512.
            dropout (float): Dropout rate. Defaults to 0.1.
            mask (bool): Whether to mask the attention. Defaults to False.
            self_attention (bool): Whether to use self attention. Defaults to True.
        """

        super(MultiHeadAttention, self).__init__()
        
        self.n_head = n_head
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.self_attention = self_attention

        self.d_k = self.d_v = d_model // n_head
        self.w_qs = nn.Linear(d_model, n_head * self.d_k)
        self.w_ks = nn.Linear(d_model, n_head * self.d_k)
        self.w_vs = nn.Linear(d_model, n_head * self.d_v)

        self.attention = ScaledDotProductAttention(n_head, d_model, mask)

        self.mha_linear = nn.Linear(d_model, d_model)

        nn.init.normal_(self.w_qs.weight, mean = 0, std = np.sqrt(2.0 / (d_model + self.d_k)))
        nn.init.normal_(self.w_ks.weight, mean = 0, std = np.sqrt(2.0 / (d_model + self.d_k)))
        nn.init.normal_(self.w_vs.weight, mean = 0, std = np.sqrt(2.0 / (d_model + self.d_v)))

    def forward(self, x, q = None) :

        """
        Implementation of multi head attention layer.

        Args:
            x (torch.Tensor): Padded input with the shaep batch_len, seq_len, d_model
            q (torch.Tensor): Query with the shape batch_size, seq_len, d_model. Defaults to None.
        
        Returns:
            torch.Tensor: Values with multiheadattention applied. Shape = (batch_size, seq_len, d_model)
        
        Raises:
            ValueError: If mode is cross attention and query passed in forward is None.
            ValueError: If mode is cross attention and shape of query is not same as input coming from encoder.
        
        References:
            https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/MultiHeadAttention.py
            https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Transformer.py
            https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/PositionalEncoding.py
        """
         
        if not self.self_attention:
            if q is None :
                raise ValueError("q is required for cross attention")
            # elif x.size() != q.size() :
            #     raise ValueError("q and X must have same size")
        else :
            q = x

        key = F.gelu(self.w_ks(x))
        query = F.gelu(self.w_qs(q))
        value = F.gelu(self.w_vs(x))

        ## keeping n_heads as major dimension
        key = key.view(-1, key.size(0), key.size(1), self.d_k)
        query = query.view(-1, query.size(0), query.size(1), self.d_k)
        value = value.view(-1, value.size(0), value.size(1), self.d_v)

        value, attention = self.attention(key, query, value)

        value = value.view(value.size(1), value.size(2), -1)

        # value = self.dropout(value)

        value = self.dropout(F.gelu(self.mha_linear(value)))

        return value, attention

#### Add and layer normalization

In [7]:
class AddLayerNormalization(nn.Module) :

    def __init__(self, d_model) :
        
        super().__init__()

        self.layer_norm = nn.LayerNorm([d_model])

    def forward(self, x, mha_output) :
        
        return self.layer_norm(x + mha_output)

#### Point Wise Feedforward

In [8]:
class PointWiseFeedforward(nn.Module) :

    def __init__(self, 
                 d_ff: int = 2048, 
                 d_model: int = 512
    ) -> None :
        
        """
        Args:
            d_ff (int): Intermediate size of the feedforward layer.
            d_model (int):  Size of the embeddings.
        """
        
        super(PointWiseFeedforward, self).__init__()

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x) :
        
        linear1_output = self.linear1(x)
        linear2_output = self.linear2(F.gelu(linear1_output))

        return linear2_output
        

####  Single Encoder Layer

In [9]:
class EncoderLayer(nn.Module) :

    def __init__(self,
                 n_heads: int = 8,
                 d_model: int = 512,
                 d_ff: int = 2048
        ) -> None :
        
        super(EncoderLayer, self).__init__()
        
        self.mha = MultiHeadAttention(n_heads, d_model, )
        self.layer_norm = AddLayerNormalization(d_model)
        self.pff = PointWiseFeedforward(d_ff, d_model)
        self.layer_norm2 = AddLayerNormalization(d_model)

    def forward(self, x) :

        mha_output, mha_attention_scores = self.mha(x)
        # logs(f"mha_output shape: {mha_output.shape}")
        norm_output1 = self.layer_norm(x, mha_output)
        # logs(f"norm_output1 shape: {norm_output1.shape}")

        pff_output = self.pff(norm_output1)
        # logs(f"pff_output shape: {pff_output.shape}")
        norm_output2 = self.layer_norm2(norm_output1, pff_output)
        # logs(f"norm_output2 shape: {norm_output2.shape}")

        return norm_output2, mha_attention_scores

#### Single Decoder Layer

In [10]:
class DecoderLayer(nn.Module) :

    def __init__(self, 
                 n_heads,
                 d_model,
                 d_ff, 
    ) -> None :

        super(DecoderLayer, self).__init__()

        self.mha = MultiHeadAttention(n_head=n_heads, d_model=d_model, mask = True)
        self.cross_mha = MultiHeadAttention(n_head=n_heads, d_model=d_model, self_attention=False)
        self.layer_norm1 = AddLayerNormalization(d_model)
        self.layer_norm2 = AddLayerNormalization(d_model)
        self.layer_norm3 = AddLayerNormalization(d_model)
        self.pff = PointWiseFeedforward(d_ff, d_model)

    def forward(self, x, enc_out) :
        ## passing encoder output to all decoder layers : to be discussed with Deepak
        decoder_query, _ = self.mha(x)
        norm_decoder_query = self.layer_norm1(x, decoder_query)

        x, _ = self.cross_mha(enc_out, norm_decoder_query)
        norm_cross_x = self.layer_norm2(norm_decoder_query, x)

        x = self.pff(norm_cross_x)
        norm_decoder_output = self.layer_norm3(norm_cross_x, x)

        return norm_decoder_output

#### Position Embedding

In [11]:
class PositionEmbedding(nn.Module) :

    def __init__(self,
        max_seq_len: int = 128, 
        d_model: int = 512,
        dropout: int = 0.1
    ) :

        super(PositionEmbedding, self).__init__()

        self.embedding = torch.zeros(max_seq_len, d_model)
        self.dropout = nn.Dropout()
        
        for i in range(max_seq_len) :
            self.embedding[i, 0::2] = torch.sin((i/1000**(2*torch.arange(512)[::2]/512)))
            self.embedding[i, 1::2] = torch.cos((i/1000**(2*torch.arange(512)[1::2]/512)))

    def forward(self, x) :

        embedding = torch.repeat_interleave(self.embedding.unsqueeze(0), x.size(0), 0)

        return self.dropout(x + embedding[:, :x.size(1), :])

## Encoder

In [12]:
class Encoder(nn.Module) :

    def __init__(self,
                 n_layer: int = 6,
                 n_heads: int = 8,
                 d_model: int = 512,
                 d_ff: int = 2048
    ) :
        super(Encoder, self).__init__()
        
        self.encoder = nn.ModuleDict({
            f'encoder_layer_{i}' : 
            (
                EncoderLayer(
                    n_heads,
                    d_model,
                    d_ff
                )
            ) for i in range(n_layer)
            })

    def forward(self, x) :
        # logs(f'input size : {x.size()}')
        for name, layer in self.encoder.items() :
            x, attention_scores = layer(x)
            # logs(f'{name} output size : {x.size()}')
        return x, attention_scores

## Decoder

In [13]:
class Decoder(nn.Module) :

    def __init__(self,
                 n_layer: int = 6,
                 n_heads: int = 8,
                 d_model: int = 512,
                 d_ff: int = 2048
    ) -> None :

        super(Decoder, self).__init__()

        self.decoder = nn.ModuleDict({
            f'decoder_layer_{i}' :
            (
                DecoderLayer(
                    n_heads,
                    d_model,
                    d_ff
                )
            ) for i in range(n_layer)
        })


    def forward(self, x, enc_out) :

        for name, layer in self.decoder.items() :
            x = layer(x, enc_out)
            
        return x

# TRANSFORMER

In [14]:
class Transformers(nn.Module) :

    def __init__(self,
                 n_layer,
                 n_heads,
                 d_model,
                 d_ff,
                 max_seq_len,
                 vocab_size,
        ) -> None :

        super(Transformers, self).__init__()

        vocab_size = vocab_size + 2

        self.encoder = Encoder(n_layer, n_heads, d_model, d_ff)
        self.decoder = Decoder(n_layer, n_heads, d_model, d_ff)
        self.positonal_embedding = PositionEmbedding(max_seq_len, d_model)

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.logit_layer = nn.Linear(d_model, vocab_size)

        self.max_seq_len = max_seq_len


    def encoder_pass(self, x) :
        
        x = self.embedding(x)
        x = self.positonal_embedding(x)
        x = self.encoder(x)

        return x
    

    def decoder_pass(self, enc_output, input_ids) :

        x = self.embedding(input_ids)
        x = self.positonal_embedding(x)
        x = self.decoder(x, enc_output)

        next_token_logits = F.relu(self.logit_layer(x))
        next_token_logits = next_token_logits.reshape(-1, next_token_logits.size(2))

        return F.log_softmax(next_token_logits, dim=1)

    def forward(self, encoder_inp, decoder_inp) :

        enc_output, attention_scores = self.encoder_pass(encoder_inp)
        output = self.decoder_pass(enc_output, decoder_inp)

        return attention_scores, output

In [165]:
max_seq_len = 10
debug = True

def generate(model, enc_output, tokenizer, max_seq_len):
    
    input_ids = torch.tensor([[tokenizer.bos_token_id] for _ in range(enc_output.size(0))]).to(device)

    unfinished_sequences = torch.ones(input_ids.size(0), 1).to(device)
    eos_token_id_tensor = torch.tensor([tokenizer.eos_token_id]).to(device)
    # logs('unfinished_sequences size: {}'.format(unfinished_sequences.size()), debug)

    sentence_length = input_ids.size(1)

    while sentence_length <= max_seq_len :

        x = model.embedding(input_ids)
        x = model.positonal_embedding(x)
        x = model.decoder(x, enc_output)
        next_token_logits = x[:, -1, :]
        
        # logs(f'next_token_logits size: {next_token_logits.size()}', debug)

        next_token_logits = F.softmax(next_token_logits, dim=1)
        next_token_indices = torch.argmax(next_token_logits, dim = 1)

        # logs(f'next_token_indices post softmax size: {next_token_indices.size()}', debug)

        # logs(f'next_token_indices * unfinished_sequences size: {(next_token_indices.unsqueeze(1) * unfinished_sequences).size()}', debug)

        # logs(f'tokenizer.pad_token_id * (1 - unfinished_sequences) size: {(tokenizer.pad_token_id * (1 - unfinished_sequences)).size()}', debug)

        next_token_indices = (
            next_token_indices.unsqueeze(1) * unfinished_sequences + tokenizer.pad_token_id * (1 - unfinished_sequences)
        )
        unfinished_sequences = unfinished_sequences.mul(
            next_token_indices.tile(
                eos_token_id_tensor.shape[0]
            ).ne(eos_token_id_tensor).prod(dim = 0)
        )

        if unfinished_sequences.max() == 0 :
            break

        # print(input_ids.size())
        # print(next_token_indices.size())

        input_ids = torch.cat(
            (
                input_ids, 
                next_token_indices
            ), 
            dim = 1).long()

        sentence_length += 1
    
        print(input_ids.size(1))

    return input_ids


In [167]:
enc_output, attention_scores = transformers.encoder(torch.rand(16, 24, 512))

input_ids = generate(transformers, enc_output, tokenizer, max_seq_len)

2
3
4
5
6
7
8
9
10
11


In [168]:
tokenizer.batch_decode(input_ids)

['[BOS] Ç Æ ğ ğ ğ [unused24] µ M s µ',
 '[BOS] Ç [unused65] Λ ğ Џ [unused24] ł ğ µ Џ',
 '[BOS] Ç [unused9] t ğ [unused7] t M t [unused2] ğ',
 '[BOS] ū V Đ ğ Џ ğ l t µ ł',
 '[BOS] Õ F a ğ [unused24] [unused12] Æ ğ [unused24] µ',
 '[BOS] [unused9]! ğ ğ t Æ [unused12] ğ M ğ',
 '[BOS] ˈ < [unused9] [unused52] ¿ [unused12] ğ M t ł',
 '[BOS] ʁ s [unused9] M t [unused62] t ğ ς t',
 '[BOS] ø ő õ [unused52] t t ğ [unused24] t µ',
 '[BOS] ÿ Ţ [unused77] ğ [unused24] ğ [unused24] ğ M M',
 '[BOS] Î Æ ğ õ t [unused24] t t µ ł',
 '[BOS] ʁ " 6 ¦ M ğ [unused78] t [unused16] M',
 '[BOS] Õ Ə µ ğ [unused16] ğ [unused16] t ğ Џ',
 '[BOS] Ç O t [unused80] Æ t [unused24] t t ğ',
 '[BOS] ʁ µ [unused97] [unused89] [unused9] M t Џ ğ Џ',
 '[BOS] µ l ğ µ t t ğ M ğ t']

In [67]:
next_token_indices = torch.rand(5, 1)
unfinished_sequences = torch.tensor([1, 1, 0, 1, 0]).unsqueeze(1)

In [68]:
next_token_indices = (
    next_token_indices * unfinished_sequences + tokenizer.pad_token_id * (1 - unfinished_sequences)
)

In [69]:
next_token_indices

tensor([[0.8011],
        [0.6144],
        [0.0000],
        [0.3131],
        [0.0000]])

In [72]:
eos_token_id_tensor = torch.tensor([tokenizer.eos_token_id])

In [98]:
next_token_indices.tile(1, 1, 1).size()

torch.Size([1, 5, 1])

In [100]:
unfinished_sequences.mul(next_token_indices.tile(eos_token_id_tensor.shape[0]).ne(eos_token_id_tensor).prod(dim = 0))

tensor([[1],
        [1],
        [0],
        [1],
        [0]])

In [50]:
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)

In [64]:
next_token_indices * unfinished_sequences + tokenizer.pad_token_id*(1 - unfinished_sequences)

tensor([[0.4767],
        [0.0613],
        [0.0000],
        [0.6866],
        [0.0000]])

tensor([[0.4767],
        [0.0613],
        [0.0000],
        [0.6866],
        [0.0000]])

# Dataset

#### Prepare tokenizer

In [15]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')


tokenizer.add_special_tokens({
    'bos_token' : '[BOS]',
    'eos_token' : '[EOS]'
})


2

In [22]:
def get_data(file_location, chunksize=1000, n_chunks = 100) :

    data = []


    for i, items in enumerate(pd.read_csv(file_location, chunksize=chunksize)) :

        data.append(items)
        if i == n_chunks :
            break

    data = pd.concat(data)
    data.index = range(len(data))

    return data

In [23]:
data = get_data(file_location=os.path.join(ROOT_DIR, 'data/en-fr.csv'), n_chunks = 1)

In [24]:
class Data(Dataset) :

    def __init__(self, data) :
        self.data = data

    def __len__(self) :
        return len(self.data)
    
    def __getitem__(self, index) -> any:
        row = self.data.loc[index]
        return {'en' : row['en'], 'fr' : row['fr']}


In [25]:
def preprocess_batch(batch, type = 'input') :

    ## Append the BOS and EOS token based on wether the batch is the encoder input, decoder input(output shifted left)
    ## or the label (output shifted right)
    if type == 'input' :
        input_token_ids = [
            torch.cat(
                (torch.tensor([tokenizer.bos_token_id]), torch.tensor(inp), torch.tensor([tokenizer.eos_token_id])),
            ) for inp in batch['input_ids']
        ]

    elif type == 'output' :
        input_token_ids = [
            torch.cat(
                (torch.tensor([tokenizer.bos_token_id]), torch.tensor(inp)),
            ) for inp in batch['input_ids']
        ]

    elif type == 'label' :
        input_token_ids = [
            torch.cat(
                (torch.tensor(inp), torch.tensor([tokenizer.eos_token_id])),
            ) for inp in batch['input_ids']
        ]

    ## pad the token to the maxiumum sentence length
    input_token_ids = pad_sequence(input_token_ids, batch_first=True, padding_value = tokenizer.pad_token_id)

    return input_token_ids

# def collate_fn(samples):
    
#     eng_samples = [items['en'] for items in samples]
#     fr_samples = [items['fr'] for items in samples]

#     batch = {}

#     for language, sample in {'en' : eng_samples, 'fr' : fr_samples}.items() :

#         sample = tokenizer.batch_encode_plus(sample)
#         batch[language] = preprocess_batch(sample)

#     # samples['fr'] = tokenizer.batch_encode_plus(samples['fr'])
#     return batch  

In [26]:
dataset = Data(data)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Modelling

In [101]:
transformers = Transformers(
    n_layer=6,
    n_heads=8,
    d_model=512,
    d_ff=2048,
    max_seq_len=128,
    vocab_size = tokenizer.vocab_size
)

In [28]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformers.parameters(), lr = 1e-4)

In [29]:
#, collate_fn=collate_fn
debug = False
for idx, rows in enumerate(dataloader) :
    transformers.zero_grad()
    en_token_ids = tokenizer.batch_encode_plus(rows['en'], add_special_tokens = False)
    fr_token_ids = tokenizer.batch_encode_plus(rows['fr'], add_special_tokens = False)
    encoder_inp = preprocess_batch(en_token_ids, type='input')
    decoder_inp = preprocess_batch(fr_token_ids, type='output')
    label = preprocess_batch(fr_token_ids, type='label')
    # print(encoder_inp.size(), decoder_inp.size(), label.size())
    attention_scores, output = transformers(encoder_inp, decoder_inp)
    loss = criterion(output, label.reshape(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())
    break
    

11.655684471130371
