In [29]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset

In [30]:
!pip install --upgrade datasets
dataset = load_dataset("cnn_dailymail","3.0.0")



In [31]:
!pip install git
!git clone https://github.com/osamakhaled123/Basic-Transformer-Model
!cd '/content/Basic-Transformer-Model'
%cd Basic-Transformer-Model/

[31mERROR: Could not find a version that satisfies the requirement git (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for git[0m[31m
[0mCloning into 'Basic-Transformer-Model'...
remote: Enumerating objects: 64, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 64 (delta 30), reused 46 (delta 14), pack-reused 0 (from 0)[K
Receiving objects: 100% (64/64), 34.82 KiB | 2.68 MiB/s, done.
Resolving deltas: 100% (30/30), done.
/content/Basic-Transformer-Model/Basic-Transformer-Model


In [32]:
import Processing_Summarizing_Datasets_From_Scratch as pre

In [33]:
# Number of leading training examples used for this experiment.
NUM_EXAMPLES = 14000

# Slice the training split once and reuse it for both columns
# (the slice is indexed by column name below).
train_slice = dataset['train'][:NUM_EXAMPLES]
articles = np.array(train_slice['article'])
summaries = np.array(train_slice['highlights'])
texts = {'articles':articles,
         'summaries':summaries}

# Sequence-length limits handed to the preprocessing step and to the model.
max_input_len = 1000
max_target_len = 85

In [34]:
# Run the repository's preprocessing over the article/summary arrays with the two length limits.
# It returns four values: text_train, train_data, train_target and the vocabulary.
# NOTE(review): the trailing `{}` is presumably an initial (empty) vocabulary — confirm
# against `pre.preprocessing`.
text_train, train_data, train_target, vocab = pre.preprocessing(texts, max_input_len, max_target_len, {})

In [35]:
# Sanity check: display the shapes of the three preprocessing outputs.
text_train.shape, train_data.shape, train_target.shape

((27932,), torch.Size([13966, 1000]), torch.Size([13966, 85]))

In [36]:
# Split the encoded inputs/targets into train, validation and test sets of batches.
train_set, val_set, test_set = pre.splitting_and_batching(
    input_data=train_data,
    target_data=train_target,
    split_frac=0.9,
    batch_size=32,
)

# Positional Encoding

In [37]:
class positional_encoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    The table is computed once at construction time and stored as a
    non-persistent buffer, so it follows the module through ``.to(device)``
    without adding entries to ``state_dict()``.

    Formula (as implemented here): for position ``p`` in ``1..max_length`` and
    feature index ``i`` in ``0..emb_dim-1`` the angle is
    ``p / max_length ** (2 * i / emb_dim)``; even feature indices take the
    sine of the angle and odd feature indices take the cosine.
    NOTE(review): this differs from the "Attention Is All You Need" formula
    (base 10000, exponent shared by each sin/cos pair, positions from 0);
    it is kept as written so existing results do not change.

    Args:
        max_length: number of positions in the table.
        emb_dim: size of the embedding (feature) dimension.
    """

    def __init__(self, max_length, emb_dim):
        super(positional_encoding, self).__init__()

        # (max_length, emb_dim) grid holding the 1-based position in every column.
        pos = torch.arange(1, max_length+1, dtype=torch.float32).unsqueeze(1).repeat(1, emb_dim)
        # (max_length, emb_dim) grid holding the per-feature divisor.
        equation = torch.pow(max_length, (2*torch.arange(emb_dim))/emb_dim).tile(max_length, 1)

        pos_enc = torch.zeros(size=(max_length, emb_dim))
        pos_enc[:, 0::2] = torch.sin(pos[:, 0::2] / equation[:, 0::2])
        pos_enc[:, 1::2] = torch.cos(pos[:, 1::2] / equation[:, 1::2])

        # Buffers (not plain attributes) so that .to()/.cuda() move them with the module.
        self.register_buffer("pos", pos, persistent=False)
        self.register_buffer("equation", equation, persistent=False)
        # Leading axis of size 1 lets the table broadcast over the batch dimension.
        self.register_buffer("pos_enc", pos_enc.unsqueeze(0), persistent=False)

    def forward(self):
        """Return the precomputed table of shape ``(1, max_length, emb_dim)``.

        The call has no side effects, so it can be repeated safely (the table
        is not reshaped or recomputed between calls).
        """
        return self.pos_enc

# Scaled Dot-Product Attention

In [38]:
class scaled_dot_product_attention(nn.Module):
    """Single attention head: learned Q/K/V projections followed by scaled dot-product attention.

    Args:
        emb_dim: feature size of the inputs and of each projection.
        causal: when True, a query at position ``i`` may only attend to keys at
            positions ``<= i``.
        dropout: dropout probability applied to the projected Q, K and V.
    """

    def __init__(self, emb_dim, causal = False, dropout = 0.1): #size=(batch, max_len, emb_dim)
        super(scaled_dot_product_attention, self).__init__()

        self.Q_linear = nn.Linear(emb_dim, emb_dim, dtype=torch.float32)
        self.K_linear = nn.Linear(emb_dim, emb_dim, dtype=torch.float32)
        self.V_linear = nn.Linear(emb_dim, emb_dim, dtype=torch.float32)
        self.dropout = nn.Dropout(dropout)
        self.causal = causal
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attention_mask):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries, ``(batch, q_len, emb_dim)``.
            K: keys, ``(batch, k_len, emb_dim)``.
            V: values, ``(batch, k_len, emb_dim)``.
            attention_mask: tensor broadcastable to ``(batch, q_len, k_len)``
                in which entries equal to 0 mark keys that must be ignored,
                or ``None`` to disable masking.

        Returns:
            Tensor of shape ``(batch, q_len, emb_dim)``.
        """
        Q = self.dropout(self.Q_linear(Q))
        K = self.dropout(self.K_linear(K))
        V = self.dropout(self.V_linear(V))

        # Scale by sqrt(d_k) with a plain Python float: no extra tensor, no device mismatch.
        scores = torch.matmul(Q, torch.transpose(K, -1, -2)) / (K.size(-1) ** 0.5)

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float(-1e10))

        if self.causal:
            # Rows index queries and columns index keys, so the mask must be
            # (q_len, k_len) and must live on the same device as the scores.
            q_len, k_len = scores.size(-2), scores.size(-1)
            future = torch.ones(q_len, k_len, device=scores.device).triu(diagonal=1).bool()
            scores = scores.masked_fill(future, float(-1e10))

        attention_weights = self.softmax(scores)
        return torch.matmul(attention_weights, V)

# Multi-Head Attention

In [39]:
class multihead_attention(nn.Module):
    """Multi-head attention built from independent ``scaled_dot_product_attention`` heads.

    The feature dimension of Q, K and V is cut into ``num_heads`` contiguous
    slices of width ``emb_dim // num_heads``; each slice is processed by its own
    head, the head outputs are concatenated along the feature dimension and
    passed through a final linear projection followed by dropout.

    Args:
        num_heads: number of attention heads.
        emb_dim: feature size of the inputs; must be divisible by ``num_heads``.
        causal: forwarded to every head.
        dropout: dropout probability, forwarded to every head and applied to
            the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``emb_dim``.
    """

    def __init__(self, num_heads, emb_dim, causal=False, dropout=0.1):
        super(multihead_attention, self).__init__()

        # Fail early: with a non-divisible emb_dim the concatenated heads would be
        # narrower than emb_dim and the output projection would reject them.
        if num_heads <= 0 or emb_dim % num_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive num_heads ({num_heads})")

        # Width of the feature slice handled by each head
        # (attribute name kept for compatibility).
        self.input_batch = emb_dim // num_heads
        self.linear_projection = nn.Linear(emb_dim, emb_dim, dtype=torch.float32)
        self.dropout = nn.Dropout(dropout)
        self.heads = nn.ModuleList([
            scaled_dot_product_attention(self.input_batch, causal, dropout)
            for _ in range(num_heads)])

    def forward(self, Q, K, V, attention_mask):
        """Run every head on its feature slice and project the concatenation.

        Args:
            Q, K, V: tensors whose last dimension is ``emb_dim``.
            attention_mask: passed unchanged to every head.

        Returns:
            Tensor with the same leading dimensions as ``Q`` and last dimension ``emb_dim``.
        """
        width = self.input_batch
        outputs = [
            head(q_slice, k_slice, v_slice, attention_mask)
            for head, q_slice, k_slice, v_slice in zip(
                self.heads,
                Q.split(width, dim=-1),
                K.split(width, dim=-1),
                V.split(width, dim=-1))
        ]

        concatenated_heads_outputs = torch.cat(outputs, dim=-1)
        linear_projection = self.linear_projection(concatenated_heads_outputs)
        return self.dropout(linear_projection)

# Transformer Encoder-Decoder Block

In [40]:
class transformer_encoder_decoder(nn.Module):
    """One attention + feed-forward block used for both the encoder and the decoder stacks.

    Data flow: multi-head attention, residual connection with ``Q``, layer
    norm, position-wise feed-forward network (Linear -> ReLU -> dropout ->
    Linear -> dropout), residual connection with the first normalised tensor,
    layer norm.

    Args:
        num_heads: number of attention heads.
        emb_dim: model (input/output) feature size.
        dff: hidden size of the feed-forward network.
        dropout: dropout probability.
    """

    def __init__(self, num_heads, emb_dim, dff, dropout=0.1):
        super().__init__()
        self.multi_head_attention = multihead_attention(
            num_heads, emb_dim, dropout=dropout)
        self.layer_normalization_1 = nn.LayerNorm(emb_dim, dtype=torch.float32)
        self.layer_normalization_2 = nn.LayerNorm(emb_dim, dtype=torch.float32)
        # Feed-forward network: emb_dim -> dff -> emb_dim
        # (the reference paper uses d_model = 512 and d_ff = 2048).
        self.RelU_layer = nn.Linear(emb_dim, dff, dtype=torch.float32)
        self.RelU = nn.ReLU()
        self.Linear_layer = nn.Linear(dff, emb_dim, dtype=torch.float32)
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, attention_mask):
        """Apply the block; the result has the same shape as ``Q``."""
        # Attention sub-layer with residual connection, then normalisation.
        attended = self.multi_head_attention(Q, K, V, attention_mask)
        normalized = self.layer_normalization_1(attended + Q)

        # Position-wise feed-forward sub-layer.
        hidden = self.dropout(self.RelU(self.RelU_layer(normalized)))
        projected = self.dropout(self.Linear_layer(hidden))

        # Second residual connection, then normalisation.
        return self.layer_normalization_2(projected + normalized)


# Transformer

In [46]:
class Transformer(nn.Module):
    """Encoder-decoder transformer producing vocabulary logits for every target position.

    Args:
        vocab: token -> index mapping; must contain a padding entry under the
            key ``'<pad>'`` or ``'pad'``.
        max_input_length: maximum source length; the model consumes sources of
            at most ``max_input_length - 1`` tokens.
        max_target_length: maximum target length; the model consumes targets of
            at most ``max_target_length - 1`` tokens.
        emb_dim: embedding / model feature size.
        dff: hidden size of the feed-forward networks.
        num_heads: number of attention heads per attention layer.
        num_encoder_blocks: number of encoder blocks.
        num_decoder_blocks: number of decoder blocks.
        dropout: dropout probability.

    Raises:
        KeyError: if ``vocab`` has neither a ``'<pad>'`` nor a ``'pad'`` entry.
    """

    def __init__(self, vocab, max_input_length, max_target_length,
                 emb_dim, dff, num_heads, num_encoder_blocks, num_decoder_blocks,
                 dropout):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.num_encoder_blocks = num_encoder_blocks
        self.num_decoder_blocks = num_decoder_blocks
        self.dff = dff
        self.Dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

        # Resolve the padding index once so the embedding and every mask agree on it
        # (the constructor and forward() previously looked it up under different keys).
        for pad_key in ('<pad>', 'pad'):
            if pad_key in vocab:
                self.pad_idx = vocab[pad_key]
                break
        else:
            raise KeyError("vocab must contain a '<pad>' or 'pad' entry")

        self.embedding = nn.Embedding(len(vocab), emb_dim, padding_idx=self.pad_idx)
        self.positional_encoding_encoder = positional_encoding(max_input_length-1, emb_dim)
        self.positional_encoding_decoder = positional_encoding(max_target_length-1, emb_dim)
        # Evaluate each positional table exactly once and keep a private copy as a
        # non-persistent buffer: it then moves with the model across devices and
        # forward() never has to call the encoding modules again.
        self.register_buffer("_encoder_positions",
                             self.positional_encoding_encoder().detach().clone(),
                             persistent=False)
        self.register_buffer("_decoder_positions",
                             self.positional_encoding_decoder().detach().clone(),
                             persistent=False)

        self.encoder_blocks = nn.ModuleList([transformer_encoder_decoder(num_heads,
        emb_dim, dff, dropout) for _ in range(num_encoder_blocks)])

        self.decoder_blocks = nn.ModuleList([transformer_encoder_decoder(num_heads,
        emb_dim, dff, dropout) for _ in range(num_decoder_blocks)])

        self.masked_attention = multihead_attention(num_heads, emb_dim, True, dropout)
        self.layer_normalization = nn.LayerNorm(emb_dim, dtype=torch.float32)

    def _embed(self, tokens, keep, positions):
        """Embed ``tokens`` and add positional encodings, zeroing padded positions.

        Args:
            tokens: ``(batch, seq_len)`` token indices.
            keep: ``(batch, seq_len)`` boolean tensor, True at non-padding positions.
            positions: ``(1, max_len, emb_dim)`` positional table.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of rows in ``positions``.
        """
        seq_len = tokens.size(1)
        if seq_len > positions.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table length {positions.size(1)}")
        keep = keep.unsqueeze(-1).float()
        embedded = self.embedding(tokens) * keep
        # Slice the table so sequences shorter than the maximum are accepted too.
        return embedded + positions[:, :seq_len] * keep

    def forward(self, batch, target):
        """Compute logits for every target position.

        Args:
            batch: ``(batch, src_len)`` source token indices.
            target: ``(batch, tgt_len)`` decoder-input token indices.

        Returns:
            ``(batch, tgt_len, len(vocab))`` logits; the output projection
            reuses the embedding matrix as its weight.
        """
        #Encoder Part
        source_keep = batch != self.pad_idx
        training_data = self._embed(batch, source_keep, self._encoder_positions)
        #Attention masking for training data: (batch, 1, src_len), broadcast over queries
        attention_mask_batch = source_keep.unsqueeze(1).float()

        encoder_output = self.encoder_decoder(training_data,
                                              training_data,
                                              training_data,
                                              self.encoder_blocks,
                                              attention_mask_batch)

        #Decoder Part
        target_keep = target != self.pad_idx
        target_data = self._embed(target, target_keep, self._decoder_positions)
        #Attention masking for target data
        attention_mask_target = target_keep.unsqueeze(1).float()

        # NOTE(review): causal self-attention is applied once, before the decoder
        # stack, rather than inside every decoder block; kept as designed.
        masked_attention_output = self.masked_attention(target_data,
                                                        target_data,
                                                        target_data,
                                                        attention_mask_target)
        Q = self.layer_normalization(masked_attention_output + target_data)

        # Cross-attention: every decoder block must read the encoder output as K/V,
        # which is also what the source attention mask is shaped for.
        decoder_output = self.encoder_decoder(Q, encoder_output, encoder_output,
                                              self.decoder_blocks,
                                              attention_mask_batch,
                                              self_attention=False)

        pre_softmax = torch.nn.functional.linear(decoder_output, self.embedding.weight)
        # NOTE(review): dropout directly on the output logits is unusual; kept to
        # preserve the existing training behaviour.
        pre_softmax = self.Dropout(pre_softmax)

        return pre_softmax

    #Encoder-Decoder Blocks
    def encoder_decoder(self, Q, K, V, blocks, attention_mask, self_attention=True):
        """Run ``Q`` through ``blocks`` sequentially.

        Args:
            Q, K, V: inputs for the first block.
            blocks: iterable of blocks called as ``block(Q, K, V, attention_mask)``.
            attention_mask: mask forwarded to every block.
            self_attention: when True (default), each block's output becomes the
                Q, K and V of the next block. When False, only Q is updated and
                the given K/V are reused by every block (cross-attention).

        Returns:
            Output of the last block, or ``Q`` unchanged when ``blocks`` is empty.
        """
        result = Q
        for block in blocks:
            result = block(Q, K, V, attention_mask)
            if self_attention:
                Q, K, V = result, result, result
            else:
                Q = result

        return result

# Training Process

# Positional Encoding and Attention Masking
# SOS and EOS Token Shifting

In [47]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the model with this notebook's hyperparameters, then move it to the chosen device.
model = Transformer(
    vocab=vocab,
    max_input_length=max_input_len,
    max_target_length=max_target_len,
    emb_dim=256,
    dff=1024,
    num_heads=4,
    num_encoder_blocks=4,
    num_decoder_blocks=4,
    dropout=0.1,
)
model = model.to(device)

In [48]:
def train(model, train_set, val_set, epochs=2, lr=0.005, device=None):
    """Train ``model`` with teacher forcing and evaluate it after every epoch.

    Each batch is a ``(source, target)`` pair of integer tensors. The model is
    called as ``model(source[:, 1:], target[:, :-1])`` and its output, logits of
    shape ``(batch, target_len - 1, vocab_size)``, is scored against
    ``target[:, 1:]`` (the target shifted by one position).

    Args:
        model: module called as described above. If ``model.embedding.padding_idx``
            is set, positions holding that index are excluded from both the
            loss and the accuracy.
        train_set: sized iterable of training batches.
        val_set: sized iterable of validation batches.
        epochs: number of passes over ``train_set``.
        lr: learning rate for the Adam optimizer.
        device: device the batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(train_losses, val_losses)``: two lists with one mean loss per epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    pad_idx = getattr(getattr(model, 'embedding', None), 'padding_idx', None)
    if pad_idx is None:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    def run_batch(inputs, targets):
        """Return ``(loss, token_accuracy)`` for one batch."""
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs[:, 1:], targets[:, :-1])
        labels = targets[:, 1:]
        # CrossEntropyLoss expects (N, C) logits with (N,) labels, so fold the
        # batch and time dimensions together.
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        # The predicted token is the arg-max over the vocabulary (last) dimension.
        predictions = logits.argmax(dim=-1)
        if pad_idx is None:
            counted = torch.ones_like(labels, dtype=torch.bool)
        else:
            counted = labels != pad_idx
        total = counted.sum().item()
        correct = ((predictions == labels) & counted).sum().item()
        return loss, (correct / total if total else 0.0)

    train_losses = []
    val_losses = []
    # Guard the per-epoch averages against empty datasets.
    num_train_batches = max(len(train_set), 1)
    num_val_batches = max(len(val_set), 1)

    for epoch in range(epochs):
        train_accuracy = 0.0
        val_accuracy = 0.0
        training_loss = 0.0
        validation_loss = 0.0

        model.train()
        for train_batch, target_train_batch in train_set:
            optimizer.zero_grad()
            loss, accuracy = run_batch(train_batch, target_train_batch)
            loss.backward()
            optimizer.step()

            training_loss += loss.item()
            train_accuracy += accuracy

        mean_train_loss = training_loss / num_train_batches
        train_losses.append(mean_train_loss)
        print(f"Epoch {epoch+1} accumelated Training batches loss is:\t\t {mean_train_loss}")
        print(f"Epoch {epoch+1} accuracy on Training batches is:\t\t {train_accuracy / num_train_batches}")

        model.eval()
        # No gradients are needed for evaluation; this saves memory and time.
        with torch.no_grad():
            for valid_batch, target_valid_batch in val_set:
                loss, accuracy = run_batch(valid_batch, target_valid_batch)
                validation_loss += loss.item()
                val_accuracy += accuracy

        mean_val_loss = validation_loss / num_val_batches
        val_losses.append(mean_val_loss)
        print(f"Epoch {epoch+1} accumelated validation batches loss is:\t\t {mean_val_loss}")
        print(f"Epoch {epoch+1} accuracy on validation batches is:\t\t {val_accuracy / num_val_batches}")

    return train_losses, val_losses


In [None]:
# `train` returns a (train_losses, val_losses) tuple; unpack it so each name gets its
# own list (a chained assignment would bind the whole tuple to both names).
train_losses, val_losses = train(model, train_set, val_set, 1, 0.001, device)

# Testing Code

In [None]:
# Scratch: create an iterator over the training batches so one batch can be inspected.
iterator = iter(train_set)

In [None]:
# Pull the next (input, target) pair of tensors from the iterator.
batch, target = next(iterator)

In [None]:
# Display the shapes of the sampled input and target batch.
batch.shape, target.shape

(torch.Size([32, 1000]), torch.Size([32, 85]))

In [None]:
# Display the first target sequence of the batch (token indices).
target[0]

tensor([     1,  38560,  54023,  30743, 103207,  92033,  71293,  96132, 100583,
         88557,  18240,  13712,  22970,  60995,  66642,  82970,  18240,  62287,
         82394,  42604,   9095,  56881,  12762,  33954,  20988, 107106,   4712,
         10009,  16332, 107283,  54023,  97259,  90458,  38045,  11170, 101889,
         94238,  85881,      2,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0])

In [None]:
# Stand-alone embedding layer for experimentation: one 256-dimensional vector per vocabulary entry.
emb = nn.Embedding(len(vocab), 256)

In [None]:
# Look up an embedding vector for every token index in the target batch.
target_data = emb(target)

In [None]:
# Display the shape of the embedded target batch.
target_data.shape

torch.Size([32, 85, 256])

In [None]:
# Boolean mask that is True wherever the target token index is non-zero.
# (An extra axis could be inserted with `.unsqueeze(1)`; it is intentionally left off here.)
attention_mask = target.ne(0)

In [None]:
# Display the shape of the mask.
attention_mask.shape

torch.Size([32, 85])

In [None]:
# Display the mask row belonging to the first target sequence.
attention_mask[0]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False])

In [None]:
# Keep the first target sequence under a short name for the experiments below.
r = target[0]

In [None]:
# Display the selected sequence.
r

tensor([     1,  38560,  54023,  30743, 103207,  92033,  71293,  96132, 100583,
         88557,  18240,  13712,  22970,  60995,  66642,  82970,  18240,  62287,
         82394,  42604,   9095,  56881,  12762,  33954,  20988, 107106,   4712,
         10009,  16332, 107283,  54023,  97259,  90458,  38045,  11170, 101889,
         94238,  85881,      2,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0])

In [None]:
# Broadcasting check: multiply the sequence by a one-element tensor holding 1.
y = r * torch.tensor([1])

In [None]:
# Display the shape of the one-element tensor used in the multiplication.
torch.tensor([1]).shape

torch.Size([1])

In [None]:
# Display the product of the broadcasting check.
y

tensor([     1,  38560,  54023,  30743, 103207,  92033,  71293,  96132, 100583,
         88557,  18240,  13712,  22970,  60995,  66642,  82970,  18240,  62287,
         82394,  42604,   9095,  56881,  12762,  33954,  20988, 107106,   4712,
         10009,  16332, 107283,  54023,  97259,  90458,  38045,  11170, 101889,
         94238,  85881,      2,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0])

In [None]:
# masked_fill experiment: replace every zero entry of the sequence with 8.
zero_positions = r.eq(0)
r = r.masked_fill(zero_positions, 8)

In [None]:
# Display the sequence after the masked_fill experiment.
r

tensor([     1,  38560,  54023,  30743, 103207,  92033,  71293,  96132, 100583,
         88557,  18240,  13712,  22970,  60995,  66642,  82970,  18240,  62287,
         82394,  42604,   9095,  56881,  12762,  33954,  20988, 107106,   4712,
         10009,  16332, 107283,  54023,  97259,  90458,  38045,  11170, 101889,
         94238,  85881,      2,      8,      8,      8,      8,      8,      8,
             8,      8,      8,      8,      8,      8,      8,      8,      8,
             8,      8,      8,      8,      8,      8,      8,      8,      8,
             8,      8,      8,      8,      8,      8,      8,      8,      8,
             8,      8,      8,      8,      8,      8,      8,      8,      8,
             8,      8,      8,      8])