In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import logging
import gc
from tqdm.notebook import tqdm, trange

In [2]:
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
get_device()

device(type='cuda')

In [69]:
def collect_garbage():
    print(f'CPU memory            : {gc.collect()}')
    print(f'CUDA memory allocated : {torch.cuda.memory_allocated()}' )
    print(f'CUDA memory reserved  : {torch.cuda.memory_reserved()}')
    print(torch.cuda.empty_cache())

In [70]:
collect_garbage()

CPU memory            : 381
CUDA memory allocated : 698128896
CUDA memory reserved  : 826277888
None


In [6]:
def scaled_dot_product(q, k, v, mask = None ):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(d_k)
    # is permute(1,0,2,3) needed ?
    
    if mask is not None :
        # logging.debug(f"scaled type: {type(scaled)}")
        # logging.debug(f"scaled.size() : {scaled.size()}")
        # logging.debug(f"mask : {type(mask)}")
        # logging.debug(f"mask.size(): {mask.size()}")
        scaled += mask
    attention = F.softmax(scaled, dim = -1)
    values = torch.matmul(attention, v)
    return values, attention

In [7]:
q,k,v = [ torch.rand(15,4,96) for _ in range(3)]
print(q.shape, k.shape, v.shape)

torch.Size([15, 4, 96]) torch.Size([15, 4, 96]) torch.Size([15, 4, 96])


In [8]:
values, attention = scaled_dot_product(q,k,v)
print(values.size(), attention.size())

torch.Size([15, 4, 96]) torch.Size([15, 4, 4])


In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3*d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask = None ):
        logging.debug("MultiHeadAttention BEGINS")
        batch_size, max_sequence_length, d_model = x.size()
        logging.debug(f"x.size(): {x.size()}")
        logging.debug(f"mask.size() : {mask.size()}" if mask is not None else "mask is None")
        qkv = self.qkv_layer(x)
        logging.debug(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(batch_size, max_sequence_length, self.num_heads, 3 * self.head_dim )
        logging.debug(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0,2,1,3)
        logging.debug(f"qkv.size(): {qkv.size()}")
        q,k,v = qkv.chunk(3, dim = -1)
        logging.debug(f"q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}")
        values, attention = scaled_dot_product(q,k,v, mask = mask)
        logging.debug(f"values.size(): {values.size()}, attention.size(): {attention.size()}")
        values = values.reshape(batch_size, max_sequence_length, self.d_model)     
        logging.debug(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        logging.debug(f"out.size(): {out.size()}")
        logging.debug("MultiHeadAttention ENDS")
        
        return out

In [10]:
logging.getLogger().setLevel(logging.DEBUG)

In [11]:
t1 = torch.rand(15,4,768)
mha = MultiHeadAttention(d_model = 768, num_heads = 8)
with torch.no_grad():
    t2 = mha(t1)
print(t1.size(), t2.size())

DEBUG:root:MultiHeadAttention BEGINS
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:mask is None
DEBUG:root:qkv.size(): torch.Size([15, 4, 2304])
DEBUG:root:qkv.size(): torch.Size([15, 4, 8, 288])
DEBUG:root:qkv.size(): torch.Size([15, 8, 4, 288])
DEBUG:root:q.size(): torch.Size([15, 8, 4, 96]), k.size(): torch.Size([15, 8, 4, 96]), v.size(): torch.Size([15, 8, 4, 96])
DEBUG:root:values.size(): torch.Size([15, 8, 4, 96]), attention.size(): torch.Size([15, 8, 4, 4])
DEBUG:root:values.size(): torch.Size([15, 4, 768])
DEBUG:root:out.size(): torch.Size([15, 4, 768])
DEBUG:root:MultiHeadAttention ENDS


torch.Size([15, 4, 768]) torch.Size([15, 4, 768])


In [12]:
class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps = 1e-5):
        super(LayerNormalization, self).__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, x):
        logging.debug("LayerNormalization BEGINS")
        dims = [-(i+1) for i in range(len(self.parameters_shape))]
        mean = x.mean(dim = dims, keepdim= True)
        logging.debug(f"mean.size(): {mean.size()}")
        var = ((x - mean)**2).mean(dim = dims, keepdim = True)
        logging.debug(f"var.size(): {var.size()}")
        std = (var + self.eps).sqrt()
        logging.debug(f"std.size(): {std.size()}")
        y = (x - mean) / std
        logging.debug(f"y.size() : {y.size()}")
        out = self.gamma * y + self.beta
        logging.debug(f"out.size(): {out.size()}")

        logging.debug("LayerNormalization ENDS")
        return out

In [13]:
t1 = torch.rand(15,4,768)
ln_layer = LayerNormalization(parameters_shape=[768], eps = 1e-5)
with torch.no_grad():
    t2 = ln_layer(t1)
print(t1.size(), t2.size())

DEBUG:root:LayerNormalization BEGINS
DEBUG:root:mean.size(): torch.Size([15, 4, 1])
DEBUG:root:var.size(): torch.Size([15, 4, 1])
DEBUG:root:std.size(): torch.Size([15, 4, 1])
DEBUG:root:y.size() : torch.Size([15, 4, 768])
DEBUG:root:out.size(): torch.Size([15, 4, 768])
DEBUG:root:LayerNormalization ENDS


torch.Size([15, 4, 768]) torch.Size([15, 4, 768])


In [14]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob = 0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = drop_prob)

    def forward(self, x):
        logging.debug("PositionwiseFeedForward BEGINS")
        x = self.linear1(x)
        logging.debug(f"x.size(): {x.size()}")
        x = self.relu(x)
        logging.debug(f"x.size(): {x.size()}")
        x = self.dropout(x)
        logging.debug(f"x.size(): {x.size()}")
        x = self.linear2(x)
        logging.debug(f"x.size(): {x.size()}")
        logging.debug("PositionwiseFeedForward ENDS")
        
        return x

In [15]:
pff_layer = PositionwiseFeedForward(d_model = 768,
                                   hidden = 768,
                                   drop_prob=0.1)
t1 = torch.rand(15,4,768)
with torch.no_grad():
    t2 = pff_layer(t1)
print(t1.size(), t2.size())

DEBUG:root:PositionwiseFeedForward BEGINS
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:PositionwiseFeedForward ENDS


torch.Size([15, 4, 768]) torch.Size([15, 4, 768])


In [16]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        
        self.attention = MultiHeadAttention(d_model = d_model,
                                            num_heads=num_heads) 
        self.dropout1 = nn.Dropout(p = drop_prob)
        self.norm1 = LayerNormalization(parameters_shape = [d_model])
       
        self.ffn = PositionwiseFeedForward(d_model = d_model,
                                           hidden = ffn_hidden,
                                           drop_prob = drop_prob)
        self.dropout2 = nn.Dropout(p = drop_prob)
        self.norm2 = LayerNormalization(parameters_shape = [d_model]) 
       

    def forward(self, x, self_attention_mask):
        logging.debug("EncoderLayer BEGINS")
        r_x = x
        x = self.attention(x, mask = self_attention_mask)
        # logging.debug(f"x.size() : {x.size()}")
        x = self.dropout1(x)
        x = self.norm1(x + r_x)

        r_x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + r_x)
        logging.debug("EncoderLayer ENDS")

        return x
        

In [17]:
enc_layer = EncoderLayer(d_model = 768,
                        ffn_hidden = 768,
                        num_heads = 8,
                        drop_prob = 0.1)
t1 = torch.rand(15,4,768)
# self_attention_mask_t2 = torch.rand(15,4,4)
self_attention_mask_t2 = torch.rand(15,8,4,4)
with torch.no_grad():
    t2 = enc_layer(t1, self_attention_mask_t2)
print(t1.size(), t2.size())

DEBUG:root:EncoderLayer BEGINS
DEBUG:root:MultiHeadAttention BEGINS
DEBUG:root:x.size(): torch.Size([15, 4, 768])
DEBUG:root:mask.size() : torch.Size([15, 8, 4, 4])
DEBUG:root:qkv.size(): torch.Size([15, 4, 2304])
DEBUG:root:qkv.size(): torch.Size([15, 4, 8, 288])
DEBUG:root:qkv.size(): torch.Size([15, 8, 4, 288])
DEBUG:root:q.size(): torch.Size([15, 8, 4, 96]), k.size(): torch.Size([15, 8, 4, 96]), v.size(): torch.Size([15, 8, 4, 96])
DEBUG:root:values.size(): torch.Size([15, 8, 4, 96]), attention.size(): torch.Size([15, 8, 4, 4])
DEBUG:root:values.size(): torch.Size([15, 4, 768])
DEBUG:root:out.size(): torch.Size([15, 4, 768])
DEBUG:root:MultiHeadAttention ENDS
DEBUG:root:LayerNormalization BEGINS
DEBUG:root:mean.size(): torch.Size([15, 4, 1])
DEBUG:root:var.size(): torch.Size([15, 4, 1])
DEBUG:root:std.size(): torch.Size([15, 4, 1])
DEBUG:root:y.size() : torch.Size([15, 4, 768])
DEBUG:root:out.size(): torch.Size([15, 4, 768])
DEBUG:root:LayerNormalization ENDS
DEBUG:root:Positionwis

torch.Size([15, 4, 768]) torch.Size([15, 4, 768])


In [18]:
class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

In [19]:
class Encoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers):
        super().__init__()
        self.layers = SequentialEncoder(*[
                EncoderLayer(
                    d_model=d_model,
                    ffn_hidden = ffn_hidden,
                    num_heads = num_heads,
                    drop_prob = drop_prob,
                    )
                for _ in range(num_layers)
            ])

    def forward(self, x, self_attention_mask):
        x = self.layers(x,self_attention_mask)
        return x

In [20]:
logging.getLogger().setLevel(logging.DEBUG)

In [21]:
logging.getLogger().setLevel(logging.WARN)

In [22]:
t1 = torch.rand(1,4,512)
self_attention_mask_t2 = torch.rand(1,4,4)
enc = Encoder(d_model = 512,
              ffn_hidden = 512,
              num_heads = 8,
              drop_prob = 0.1,
              num_layers = 6)
with torch.no_grad():
    t2 = enc(t1, self_attention_mask_t2)
print(t1.size(), t2.size())

torch.Size([1, 4, 512]) torch.Size([1, 4, 512])


In [23]:
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadCrossAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2*d_model)
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask):
        logging.debug("MultiHeadCrossAttention BEGINS")
        batch_size, sequence_length, d_model = x.size()
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2*self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0,2,1,3)
        q = q.permute(0,2,1,3)
        k,v = kv.chunk(2, dim = -1)
        """ We don't need the mask in cross attention, removing in outerfunction but why ?"""
        values, attention = scaled_dot_product(q,k,v,mask = mask)
        values = values.permute(0,2,1,3)
        values = values.reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        logging.debug("MultiHeadCrossAttention ENDS")
        return out
        

In [24]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        
        self.self_attention = MultiHeadAttention(d_model = d_model,
                                                 num_heads = num_heads,
                                                )
        self.dropout1 = nn.Dropout(p = drop_prob)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model = d_model,
                                                                num_heads = num_heads)
        self.dropout2 = nn.Dropout(p = drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])

        self.ffn = PositionwiseFeedForward(d_model = d_model,
                                           hidden= ffn_hidden,
                                           drop_prob = drop_prob)
        self.dropout3 = nn.Dropout(p = drop_prob)
        self.norm3 = LayerNormalization(parameters_shape = [d_model])

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        r_y = y
        y = self.self_attention(y, mask = self_attention_mask)
        y = self.dropout1(y)
        y = self.norm1(y + r_y)

        r_y = y
        y = self.encoder_decoder_attention(x,y,mask = cross_attention_mask)
        y = self.dropout2(y)
        y = self.norm2(y + r_y)

        r_y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + r_y)

        return y

In [25]:
class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

In [26]:
class Decoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,):
        super(Decoder,self).__init__()
        self.layers = SequentialDecoder(*[
            DecoderLayer(d_model=d_model,
                         ffn_hidden=ffn_hidden,
                         num_heads= num_heads,
                         drop_prob = drop_prob,
                        )
            for _ in range(num_layers)
        ])
    def forward(self,
                x,
                y,
                self_attention_mask,
                cross_attention_mask):
        logging.debug("Decoder BEGINS")
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        logging.debug("Decoder ENDS")
        return y

In [27]:
logging.getLogger().setLevel(logging.DEBUG)

In [28]:
logging.getLogger().setLevel(logging.WARN)

In [29]:
t1_x = torch.rand(1,4,512)
t1_y = torch.rand(1,4,512)
t1_self_attention_mask = torch.rand(1,4,4)
t1_cross_attention_mask = torch.rand(1,4,4)

dec = Decoder(d_model=512,
              ffn_hidden=512,
              num_heads=8,
              drop_prob=0.1,
              num_layers = 4)
with torch.no_grad():
    t2 = dec(x = t1_x,
             y = t1_y,
             self_attention_mask = t1_self_attention_mask,
             cross_attention_mask = t1_cross_attention_mask,)
    print(t1.shape, t2.shape)


torch.Size([1, 4, 512]) torch.Size([1, 4, 512])


In [30]:
class Transformer(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,):
        super(Transformer, self).__init__()
        self.encoder = Encoder(d_model = d_model,
                              ffn_hidden = ffn_hidden,
                              num_heads = num_heads,
                              drop_prob = drop_prob,
                              num_layers = num_layers)
        self.decoder = Decoder(d_model = d_model,
                              ffn_hidden = ffn_hidden,
                              num_heads = num_heads,
                              drop_prob = drop_prob,
                              num_layers = num_layers)

    def forward(self,
                x,
                y,
                encoder_self_attention_mask = None,
                decoder_self_attention_mask = None,
                decoder_cross_attention_mask = None,):
        logging.debug("Transformer BEGINS")
        x = self.encoder(x,
                         encoder_self_attention_mask)
        out = self.decoder(x,
                           y,
                           decoder_self_attention_mask,
                           decoder_cross_attention_mask)
        logging.debug("Transformer ENDS")
        # need a linear layer that maps to vocabulary size
        return out

In [31]:
trfm = Transformer(d_model = 768,
                  ffn_hidden = 768,
                  num_heads = 8,
                  drop_prob = 0.1,
                  num_layers = 1)

t1_x = torch.rand(1,4,768)
t1_y = torch.rand(1,4,768)
t1_encoder_self_attention = torch.rand(1,4,4)
t1_decoder_self_attention = torch.rand(1,4,4)
t1_decoder_cross_attention = torch.rand(1,4,4)

with torch.no_grad():
    t2 = trfm(t1_x,
             t1_y,
             t1_encoder_self_attention,
             t1_decoder_self_attention,
             t1_decoder_cross_attention)
    print(t1_x.shape, t2.shape)

torch.Size([1, 4, 768]) torch.Size([1, 4, 768])


In [49]:
def train_model(model:nn.Module, num_epochs: int, device: torch.device):
    # number of sentences in text
    batch_size = 15
    # dimensions of each word
    d_model = 768
    # number of words in a sentence
    sequence_length = 4
    # number of heads in multi head attention
    num_heads = 8

    # train and validation losses
    train_losses, val_losses = [],[]

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr= 0.001)
    
    for epoch in trange(num_epochs, desc="Epochs"):
        """ Training Phase """
        # Set model to train mode
        model.train()
        running_loss = 0.0
        
        x = torch.rand(batch_size,sequence_length,d_model).to(device)
        y = torch.rand(batch_size,sequence_length,d_model).to(device)
        encoder_self_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)
        decoder_self_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)
        decoder_cross_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)

        optimizer.zero_grad()
        output = model(x, y, encoder_self_attention_mask, decoder_cross_attention_mask, decoder_cross_attention_mask)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # print(f'loss.item(): {loss.item()}')
        # running_loss += loss.item() * y.size(0)
        running_loss += loss.item()

        train_loss = running_loss / 1
        train_losses.append(train_loss)

        """ Validation Phase """
        model.eval()
        running_loss = 0.0

        with torch.no_grad():
            # reusing past values
            x = torch.rand(batch_size,sequence_length, d_model).to(device)
            y = torch.rand(batch_size,sequence_length,d_model).to(device)
            encoder_self_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)
            decoder_self_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)
            decoder_cross_attention_mask = torch.rand(batch_size, num_heads, sequence_length, sequence_length).to(device)

            output = model(x, y, encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask)
            loss = criterion(output, y)

        val_loss = running_loss / 1
        val_losses.append(val_loss)

        # Log epoch stats
        logging.info(f"Epoch {epoch+1}/{num_epochs} ; Train loss : {train_loss} ; Valid loss : {val_loss}")

    # returning losses
    return train_losses, val_losses

In [50]:
trfm_model = Transformer(d_model=768,
                   ffn_hidden=768,
                   num_heads = 8,
                   drop_prob=0.1,
                   num_layers=6).to(get_device())

In [51]:
# print(trfm_model)
# with open("model.txt", "w") as f:
#     f.write(str(trfm_model))

In [52]:
train_losses, val_losses = train_model(model = trfm_model, num_epochs = 100, device = get_device())

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
plt.figure(figsize=(6,4))
plt.plot(train_losses, label = "Training loss")
plt.plot(val_losses, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [56]:
collect_garbage()

0
None
