# Transformer from scratch: Summarizer

In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Data Processing

In [2]:
data = pd.read_excel('/kaggle/input/inshorts-news-data/Inshorts Cleaned Data.xlsx')

In [3]:
data.head()

Unnamed: 0,Headline,Short,Source,Time,Publish Date
0,4 ex-bank officials booked for cheating bank o...,The CBI on Saturday booked four former officia...,The New Indian Express,09:25:00,2017-03-26
1,Supreme Court to go paperless in 6 months: CJI,Chief Justice JS Khehar has said the Supreme C...,Outlook,22:18:00,2017-03-25
2,"At least 3 killed, 30 injured in blast in Sylh...","At least three people were killed, including a...",Hindustan Times,23:39:00,2017-03-25
3,Why has Reliance been barred from trading in f...,Mukesh Ambani-led Reliance Industries (RIL) wa...,Livemint,23:08:00,2017-03-25
4,Was stopped from entering my own studio at Tim...,TV news anchor Arnab Goswami has said he was t...,YouTube,23:24:00,2017-03-25


In [4]:
data.drop(["Source ", "Time ", "Publish Date"], axis=1, inplace=True)

In [5]:
data.head()

Unnamed: 0,Headline,Short
0,4 ex-bank officials booked for cheating bank o...,The CBI on Saturday booked four former officia...
1,Supreme Court to go paperless in 6 months: CJI,Chief Justice JS Khehar has said the Supreme C...
2,"At least 3 killed, 30 injured in blast in Sylh...","At least three people were killed, including a..."
3,Why has Reliance been barred from trading in f...,Mukesh Ambani-led Reliance Industries (RIL) wa...
4,Was stopped from entering my own studio at Tim...,TV news anchor Arnab Goswami has said he was t...


In [6]:
inputs = data["Short"].values
targets = data["Headline"].values

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [8]:
tokenizer.decode(tokenizer.encode("Hello There"), skip_special_tokens=True)



'Hello There'

In [9]:
vocab_size = tokenizer.vocab_size

In [10]:
inputs[0]

'The CBI on Saturday booked four former officials of Syndicate Bank and six others for cheating, forgery, criminal conspiracy and causing ₹209 crore loss to the state-run bank. The accused had availed home loans and credit from Syndicate Bank on the basis of forged and fabricated documents. These funds were fraudulently transferred to the companies owned by the accused persons.'

In [11]:
tokenizer.decode(tokenizer.encode(inputs[0]), skip_special_tokens=True)

'The CBI on Saturday booked four former officials of Syndicate Bank and six others for cheating, forgery, criminal conspiracy and causing ₹209 crore loss to the state - run bank. The accused had availed home loans and credit from Syndicate Bank on the basis of forged and fabricated documents. These funds were fraudulently transferred to the companies owned by the accused persons.'

In [12]:
print(tokenizer.encode(list(inputs)[0]))

[101, 1109, 18893, 2240, 1113, 4306, 18951, 1300, 1393, 3878, 1104, 25139, 2950, 1105, 1565, 1639, 1111, 18661, 117, 26621, 1616, 117, 4771, 10758, 1105, 3989, 838, 10973, 1580, 24809, 2445, 1106, 1103, 1352, 118, 1576, 3085, 119, 1109, 4806, 1125, 28057, 1174, 1313, 11453, 1105, 4755, 1121, 25139, 2950, 1113, 1103, 3142, 1104, 17667, 1105, 27615, 4961, 119, 1636, 4381, 1127, 10258, 16564, 1193, 3175, 1106, 1103, 2557, 2205, 1118, 1103, 4806, 4983, 119, 102]


In [13]:
tokenizer.decode(tokenizer(list(inputs[:5]))["input_ids"][4])

'[CLS] TV news anchor Arnab Goswami has said he was told he could not do the programme two days before leaving Times Now. & # 34 ; 18th November was my last day, I was not allowed to enter my own studio, & # 34 ; Goswami added. & # 34 ; When you build an institution and are not allowed to enter your own studio, you feel sad, & # 34 ; the journalist further said. [SEP]'

In [14]:
inputs = tokenizer(list(inputs), return_tensors="pt", padding=True)["input_ids"]
targets = tokenizer(list(targets), return_tensors="pt", padding=True)["input_ids"]

In [15]:
max_inp_len = max(len(i) for i in inputs)
max_targ_len = max(len(i) for i in targets)

max_inp_len, max_targ_len

(184, 42)

In [16]:
inputs[0]

tensor([  101,  1109, 18893,  2240,  1113,  4306, 18951,  1300,  1393,  3878,
         1104, 25139,  2950,  1105,  1565,  1639,  1111, 18661,   117, 26621,
         1616,   117,  4771, 10758,  1105,  3989,   838, 10973,  1580, 24809,
         2445,  1106,  1103,  1352,   118,  1576,  3085,   119,  1109,  4806,
         1125, 28057,  1174,  1313, 11453,  1105,  4755,  1121, 25139,  2950,
         1113,  1103,  3142,  1104, 17667,  1105, 27615,  4961,   119,  1636,
         4381,  1127, 10258, 16564,  1193,  3175,  1106,  1103,  2557,  2205,
         1118,  1103,  4806,  4983,   119,   102,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [17]:
tokenizer.decode(inputs[0], skip_special_tokens=True), tokenizer.decode(targets[0], skip_special_tokens=True)

('The CBI on Saturday booked four former officials of Syndicate Bank and six others for cheating, forgery, criminal conspiracy and causing ₹209 crore loss to the state - run bank. The accused had availed home loans and credit from Syndicate Bank on the basis of forged and fabricated documents. These funds were fraudulently transferred to the companies owned by the accused persons.',
 '4 ex - bank officials booked for cheating bank of ₹209 crore')

In [18]:
len(inputs), len(targets)

(55104, 55104)

In [19]:
train_inputs = inputs[:int(0.8 * len(inputs))]
train_targets = targets[:int(0.8 * len(targets))]

val_inputs = inputs[int(0.8 * len(inputs)):int(0.9*len(inputs))]
val_targets = targets[int(0.8 * len(targets)):int(0.9*len(targets))]

test_inputs = inputs[int(0.9 * len(inputs)):]
test_targets = targets[int(0.9 * len(targets)):]

In [20]:
class NewsSummaryDataset(Dataset):
    def __init__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        
    def __len__(self): return len(self.inp)
    
    def __getitem__(self, idx):
        return self.inp[idx], self.targ[idx]
    
    def decode(self, idx):
        decoded_inp = tokenizer.decode(self.inp[idx], skip_special_tokens=True)
        decoded_targ = tokenizer.decode(self.targ[idx], skip_special_tokens=True)
        return decoded_inp, decoded_targ

In [21]:
train_dataset = NewsSummaryDataset(train_inputs, train_targets)
val_dataset = NewsSummaryDataset(val_inputs, val_targets)
test_dataset = NewsSummaryDataset(test_inputs, test_targets)

In [22]:
len(train_dataset), len(val_dataset), len(test_dataset)

(44083, 5510, 5511)

In [23]:
print(train_dataset.decode(-1))

('Nepali Rescuers on Monday abandoned the mission to recover the bodies of two Indian climbers from the Mount Everest. The two men, Paresh Nath and Goutam Ghosh, went missing on May 21 and their bodies were later located near the 8000 - metre height, which marks the beginning of the & # 39 ; death zone & # 39 ;. Another mountaineer, Subhash Pal, had earlier died during his descent.', 'Mission to get Indian climbers & # 39 ; bodies abandoned')


In [24]:
tokenizer.pad_token_id

0

In [25]:
def pad_sequence(batch):
    inp_seqs = [inp for inp, targs in batch]
    targ_seqs = [targs for inp, targs in batch]
    
    inp_padded = torch.nn.utils.rnn.pad_sequence(torch.tensor(inp_seqs), batch_first=True, padding_value=tokenizer.pad_token_id)
    targ_padded = torch.nn.utils.rnn.pad_sequence(torch.tensor(targ_seqs), batch_first=True, padding_value=tokenizer.pad_token_id)
    
    return inp_padded, targ_padded

In [26]:
batch_size = 128

class Dataloaders:
    def __init__(self):
        self.train_dataset = NewsSummaryDataset(train_inputs, train_targets)
        self.valid_dataset = NewsSummaryDataset(val_inputs, val_targets)
        self.test_dataset = NewsSummaryDataset(test_inputs, test_targets)
        
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True)

## Transformers Architecture

In [27]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_embed, dropout=0.0):
        super(MultiHeadedAttention, self).__init__()
        assert d_embed % h == 0 # check the h number
        self.d_k = d_embed//h
        self.d_embed = d_embed
        self.h = h
        self.WQ = nn.Linear(d_embed, d_embed)
        self.WK = nn.Linear(d_embed, d_embed)
        self.WV = nn.Linear(d_embed, d_embed)
        self.linear = nn.Linear(d_embed, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatch = x_query.size(0) # get batch size
        # 1) Linear projections to get the multi-head query, key and value tensors
        # x_query, x_key, x_value dimension: nbatch * seq_len * d_embed
        # LHS query, key, value dimensions: nbatch * h * seq_len * d_k
        query = self.WQ(x_query).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        # 2) Attention
        # scores has dimensions: nbatch * h * seq_len * seq_len
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_k)
        # 3) Mask out padding tokens and future tokens
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # p_atten dimensions: nbatch * h * seq_len * seq_len
        p_atten = torch.nn.functional.softmax(scores, dim=-1)
        p_atten = self.dropout(p_atten)
        # x dimensions: nbatch * h * seq_len * d_k
        x = torch.matmul(p_atten, value)
        # x now has dimensions:nbtach * seq_len * d_embed
        x = x.transpose(1, 2).contiguous().view(nbatch, -1, self.d_embed)
        return self.linear(x) # final linear layer

In [28]:
class ResidualConnection(nn.Module):
    '''residual connection: x + dropout(sublayer(layernorm(x))) '''
    def __init__(self, dim, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, sublayer):
        return x + self.drop(sublayer(self.norm(x)))

In [29]:
class Encoder(nn.Module):
    '''Encoder = token embedding + positional embedding -> a stack of N EncoderBlock -> layer norm'''
    def __init__(self, config):
        super().__init__()
        self.d_embed = config.d_embed
        self.tok_embed = nn.Embedding(config.encoder_vocab_size, config.d_embed) 
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_inp_len, config.d_embed)) 
        self.encoder_blocks = nn.ModuleList([EncoderBlock(config) for _ in range(config.N_encoder)])
        self.dropout = nn.Dropout(config.dropout)
        self.norm = nn.LayerNorm(config.d_embed)

    def forward(self, input, mask=None):
        x = self.tok_embed(input)
        x_pos = self.pos_embed[:, :x.size(1), :]
        x = self.dropout(x + x_pos)
        for layer in self.encoder_blocks:
            x = layer(x, mask)
        return self.norm(x)

In [30]:
class EncoderBlock(nn.Module):
    '''EncoderBlock: self-attention -> position-wise fully connected feed-forward layer'''
    def __init__(self, config):
        super(EncoderBlock, self).__init__()
        self.atten = MultiHeadedAttention(config.h, config.d_embed, config.dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_embed, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_embed)
        )
        self.residual1 = ResidualConnection(config.d_embed, config.dropout)
        self.residual2 = ResidualConnection(config.d_embed, config.dropout)

    def forward(self, x, mask=None):
        # self-attention
        x = self.residual1(x, lambda x: self.atten(x, x, x, mask=mask))
        # position-wise fully connected feed-forward layer
        return self.residual2(x, self.feed_forward)

In [31]:
class Decoder(nn.Module):
    '''Decoder = token embedding + positional embedding -> a stack of N DecoderBlock -> fully-connected layer'''
    def __init__(self, config):
        super().__init__()
        self.d_embed = config.d_embed
        self.tok_embed = nn.Embedding(config.decoder_vocab_size, config.d_embed)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_targ_len, config.d_embed)) 
        self.dropout = nn.Dropout(config.dropout)
        self.decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.N_decoder)])
        self.norm = nn.LayerNorm(config.d_embed)
        self.linear = nn.Linear(config.d_embed, config.decoder_vocab_size)
    
    def future_mask(self, seq_len):
        '''mask out tokens at future positions'''
        mask = (torch.triu(torch.ones(seq_len, seq_len, requires_grad=False), diagonal=1)!=0).to(device)
        return mask.view(1, 1, seq_len, seq_len)

    def forward(self, memory, src_mask, trg, trg_pad_mask):
        seq_len = trg.size(1)
        trg_mask = torch.logical_or(trg_pad_mask, self.future_mask(seq_len))
        x = self.tok_embed(trg) + self.pos_embed[:, :trg.size(1), :]
        x = self.dropout(x)
        for layer in self.decoder_blocks:
            x = layer(memory, src_mask, x, trg_mask)
        x = self.norm(x)
        logits = self.linear(x)
        return logits

In [32]:
class DecoderBlock(nn.Module):
    ''' EncoderBlock: self-attention -> position-wise feed-forward (fully connected) layer'''
    def __init__(self, config):
        super().__init__()
        self.atten1 = MultiHeadedAttention(config.h, config.d_embed)
        self.atten2 = MultiHeadedAttention(config.h, config.d_embed)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_embed, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_embed)
        )
        self.residuals = nn.ModuleList([ResidualConnection(config.d_embed, config.dropout) 
                                       for i in range(3)])

    def forward(self, memory, src_mask, decoder_layer_input, trg_mask):
        x = memory
        y = decoder_layer_input
        y = self.residuals[0](y, lambda y: self.atten1(y, y, y, mask=trg_mask))
        # keys and values are from the encoder output
        y = self.residuals[1](y, lambda y: self.atten2(y, x, x, mask=src_mask))
        return self.residuals[2](y, self.feed_forward)

In [33]:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_mask, trg, trg_pad_mask):
        return self.decoder(self.encoder(src, src_mask), src_mask, trg, trg_pad_mask)

## Training

In [34]:
from dataclasses import dataclass

@dataclass
class ModelConfig:
    encoder_vocab_size: int
    decoder_vocab_size: int
    d_embed: int
    # d_ff is the dimension of the fully-connected  feed-forward layer
    d_ff: int
    # h is the number of attention head
    h: int
    N_encoder: int
    N_decoder: int
    max_inp_len: int
    max_targ_len: int
    dropout: float
        
def make_model(config):
    model = Transformer(Encoder(config), Decoder(config)).to(device)
    
    for p in model.parameters():
        if p.dim() > 1: nn.init.xavier_uniform_(p)
    return model

In [35]:
def make_batch_input(x, y):
        src = x.to(device)
        trg_in = y[:, :-1].to(device)
        trg_out = y[:, 1:].contiguous().view(-1).to(device)
        src_pad_mask = (src == tokenizer.pad_token_id).view(src.size(0), 1, 1, src.size(-1))
        trg_pad_mask = (trg_in == tokenizer.pad_token_id).view(trg_in.size(0), 1, 1, trg_in.size(-1))
        return src, trg_in, trg_out, src_pad_mask, trg_pad_mask

In [36]:
from numpy.lib.utils import lookfor
def train_epoch(model, dataloaders):
    model.train()
    grad_norm_clip = 1.0
    losses, acc, count = [], 0, 0
    num_batches = len(dataloaders.train_loader)
    pbar = tqdm(enumerate(dataloaders.train_loader), total=num_batches)
    for idx, (x, y)  in  pbar:
        optimizer.zero_grad()
        src, trg_in, trg_out, src_pad_mask, trg_pad_mask = make_batch_input(x,y)
        pred = model(src, src_pad_mask, trg_in, trg_pad_mask).to(device)
        pred = pred.view(-1, pred.size(-1))
        loss = loss_fn(pred, trg_out).to(device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f'train loss={loss.item():.3f}, lr={scheduler.get_last_lr()[0]:.5f}')
    return np.mean(losses)

In [37]:
def train(model, dataloaders, epochs):
    global early_stop_count
    best_valid_loss = float('inf')
    train_size = len(dataloaders.train_loader)*batch_size
    for ep in range(epochs):
        train_loss = train_epoch(model, dataloaders)
        valid_loss = validate(model, dataloaders.valid_loader)
        
        print(f'ep: {ep}: train_loss={train_loss:.5f}, valid_loss={valid_loss:.5f}')
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
        else:
            if scheduler.last_epoch>2*warmup_steps:
                early_stop_count -= 1
                if early_stop_count<=0:   
                    return train_loss, valid_loss
    return train_loss, valid_loss

In [38]:
def validate(model, dataloder):
    'compute the validation loss'
    model.eval()
    losses = []
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloder):
            src, trg_in, trg_out, src_pad_mask, trg_pad_mask = make_batch_input(x,y)
            pred = model(src, src_pad_mask, trg_in, trg_pad_mask).to(device)
            pred = pred.view(-1, pred.size(-1))
            losses.append(loss_fn(pred, trg_out).item())
    return np.mean(losses)

In [39]:
config = ModelConfig(encoder_vocab_size = vocab_size, 
                     decoder_vocab_size=vocab_size,
                     d_embed=512, 
                     d_ff=512, 
                     h=8,
                     N_encoder=3, 
                     N_decoder=3,
                     max_inp_len=max_inp_len,
                     max_targ_len=max_targ_len,
                     dropout=0.1
                     )

data_loaders = Dataloaders()
train_size = len(data_loaders.train_loader)*batch_size
model = make_model(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 3*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
early_stop_count = 20
train_loss, valid_loss = train(model, data_loaders, epochs=20)
test_loss  = validate(model, data_loaders.test_loader)

print(f'train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}, test_loss: {test_loss:.4f}')

model_size: 57307460, train_set_size: 44160


train loss=6.015, lr=0.00020: 100%|██████████| 345/345 [02:01<00:00,  2.83it/s]


ep: 0: train_loss=7.49143, valid_loss=5.85186


train loss=4.686, lr=0.00043: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 1: train_loss=5.10387, valid_loss=4.62715


train loss=3.687, lr=0.00066: 100%|██████████| 345/345 [02:00<00:00,  2.85it/s]


ep: 2: train_loss=4.01178, valid_loss=4.03171


train loss=3.043, lr=0.00060: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 3: train_loss=3.21826, valid_loss=3.76038


train loss=2.607, lr=0.00054: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 4: train_loss=2.53419, valid_loss=3.65772


train loss=2.158, lr=0.00049: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 5: train_loss=2.01283, valid_loss=3.71397


train loss=1.698, lr=0.00045: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 6: train_loss=1.60204, valid_loss=3.79546


train loss=1.275, lr=0.00042: 100%|██████████| 345/345 [02:00<00:00,  2.86it/s]


ep: 7: train_loss=1.27755, valid_loss=3.94744


train loss=1.061, lr=0.00040: 100%|██████████| 345/345 [02:01<00:00,  2.85it/s]


ep: 8: train_loss=1.01997, valid_loss=4.11059


train loss=0.910, lr=0.00038: 100%|██████████| 345/345 [02:01<00:00,  2.85it/s]


ep: 9: train_loss=0.81700, valid_loss=4.33039


train loss=0.687, lr=0.00036: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 10: train_loss=0.65734, valid_loss=4.46857


train loss=0.586, lr=0.00035: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 11: train_loss=0.53387, valid_loss=4.57446


train loss=0.483, lr=0.00033: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 12: train_loss=0.43714, valid_loss=4.70222


train loss=0.400, lr=0.00032: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 13: train_loss=0.36439, valid_loss=4.81023


train loss=0.340, lr=0.00031: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 14: train_loss=0.30619, valid_loss=4.91305


train loss=0.303, lr=0.00030: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 15: train_loss=0.26089, valid_loss=5.01792


train loss=0.268, lr=0.00029: 100%|██████████| 345/345 [02:01<00:00,  2.84it/s]


ep: 16: train_loss=0.22580, valid_loss=5.12206


train loss=0.242, lr=0.00028: 100%|██████████| 345/345 [02:01<00:00,  2.85it/s]


ep: 17: train_loss=0.19792, valid_loss=5.19428


train loss=0.181, lr=0.00027: 100%|██████████| 345/345 [02:01<00:00,  2.85it/s]


ep: 18: train_loss=0.17450, valid_loss=5.23197


train loss=0.169, lr=0.00027: 100%|██████████| 345/345 [02:01<00:00,  2.85it/s]


ep: 19: train_loss=0.15540, valid_loss=5.26296
train_loss: 0.1554, valid_loss: 5.2630, test_loss: 5.5587


## Inference

In [40]:
def summarize(model, x):
    with torch.inference_mode():
        dB = test_inp_b.size(0)
        y = torch.tensor([[tokenizer.cls_token_id] * dB]).view(dB, 1).to(device)
        x_pad_mask = (x == tokenizer.pad_token_id).view(x.size(0), 1, 1, x.size(-1)).to(device)
        memory = model.encoder(x.to(device), x_pad_mask).to(device)
        for i in range(max_targ_len):
            y_pad_mask = (y == tokenizer.pad_token_id).view(y.size(0), 1, 1, y.size(-1)).to(device)
            logits = model.decoder(memory, x_pad_mask, y, y_pad_mask)
            last_output = logits.argmax(-1)[:, -1]
            last_output = last_output.view(dB, 1)
            y = torch.cat((y, last_output), 1).to(device)
    return y

In [41]:
test_inp_b, test_targ_b = next(iter(data_loaders.test_loader))

In [42]:
y_preds = summarize(model, test_inp_b)

In [43]:
for i in range(5):
    print("Original - {}\nSummary - {}\n\n".format(tokenizer.decode(test_inp_b[i], skip_special_tokens=True), tokenizer.decode(y_preds[i], skip_special_tokens=True)))

Original - The first official trailer of the Arjun Kapoor and Kareena Kapoor starrer film & # 39 ; Ki and Ka & # 39 ; was released today. The romantic comedy - drama has been written, directed and produced by R Balki. It features Arjun and Kareena as a young married couple that challenges the gender stereotypes set by the society. The film is scheduled to release on April 1.
Summary - Trailer of & # 39 ; Kareena & # 39 ; released


Original - Pakistan & # 39 ; s former President Pervez Musharraf on Thursday said Pakistan & # 39 ; s intelligence agency ISI trains LeT and Jaish militants and that terror attacks in India would not stop until India addresses the & # 34 ; core & # 34 ; issue of Kashmir. He further called everyone fighting in Kashmir a & # 34 ; freedom fighter & # 34 ;. He also accused India & # 39 ; s intelligence agency RAW of conducting attacks in Pakistan from Afghanistan.
Summary - Pakistan will stop militants in India, Pak spy by ISI : Musharraf


Original - India & # 

## Saving the Model

In [44]:
torch.save(model, 'transformer_summarizer.pth')
torch.save(model.state_dict(), 'transformer_summarizer_params.pth')