# Configuration

In [None]:
"""
    Dataset file: https://drive.google.com/file/d/13q7pqpx-a8QIyRLXYAnLNJuFWyqcw-2b/view?usp=sharing
    Download it and place it in Google Drive (the 'My Drive' folder).
"""

In [None]:
# Mount Google Drive; later cells read the CSV data sets and write checkpoints
# under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers

In [None]:
!git clone -b master https://github.com/vohung471999/Transformer.git

In [4]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import copy
import math
import pandas as pd
from tqdm import tqdm
import time
import gc

In [15]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# The model classes defined below read this module-level name when creating their layers.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Transformer





In [None]:
import torch
import torch.nn as nn
import math
from torch.autograd import Variable
import torch.nn.functional as F

class WordEmbedding(nn.Module):
    """Token-embedding lookup table allocated on the module-level ``device``.

    Args:
        vocab_size: number of rows in the embedding table.
        model_dim: width of each embedding vector.
        padding_idx: index forwarded to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, model_dim, padding_idx):
        super().__init__()
        self.vocab_size, self.model_dim = vocab_size, model_dim
        self.embed = nn.Embedding(
            vocab_size,
            model_dim,
            padding_idx=padding_idx,
            device=device,
            dtype=torch.float32,
        )

    def forward(self, x):
        """Return the embedding vectors for the token ids in ``x``."""
        token_vectors = self.embed(x)
        return token_vectors

class NormalPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by a fixed offset of 2.

    The table holds ``num_embeddings + 2`` rows; position ``p`` reads row ``p + 2``.
    """

    def __init__(self, embedding_dim: int, num_embeddings=1024):
        # The offset is assigned before the parent constructor runs because the
        # table size passed to it depends on this value.
        self.offset = 2
        super().__init__(
            num_embeddings + self.offset,
            embedding_dim,
            device=device,
            dtype=torch.float32,
        )

    def forward(self, input_ids_shape: torch.Size):
        """Return the embeddings for positions ``0 .. seq_len - 1``.

        Args:
            input_ids_shape: shape whose first two entries are (batch, seq_len);
                only ``seq_len`` is used.

        Returns:
            Tensor of shape (seq_len, embedding_dim).
        """
        _, seq_len = input_ids_shape[:2]
        positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(self.offset + positions)

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to scaled token embeddings.

    Even feature columns hold ``sin(pos * w)`` and odd columns hold
    ``cos(pos * w)``. The table is stored as a non-trainable buffer named
    ``positional_emb`` with shape (1, max_seq_len, model_dim).

    Args:
        model_dim: embedding width; may be even or odd.
        dropout: dropout probability applied to the sum.
        max_seq_len: number of positions pre-computed in the table.
    """

    def __init__(self, model_dim, dropout=0.1, max_seq_len=1024):
        super().__init__()
        self.model_dim = model_dim
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, max_seq_len).unsqueeze(1)
        w = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000) / model_dim))
        angles = position * w  # (max_seq_len, ceil(model_dim / 2))

        positional_emb = torch.zeros(max_seq_len, model_dim)
        positional_emb[:, 0::2] = torch.sin(angles)
        # With an odd model_dim there is one fewer odd column than even column,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        positional_emb[:, 1::2] = torch.cos(angles[:, : model_dim // 2])

        self.register_buffer('positional_emb', positional_emb.unsqueeze(0))

    def forward(self, embedding):
        """Scale ``embedding`` by sqrt(model_dim), add the encodings, apply dropout.

        Args:
            embedding: tensor whose dimension 1 is the sequence axis and whose
                last dimension is ``model_dim``.

        Returns:
            Tensor with the same shape as ``embedding``.
        """
        embedding = embedding * math.sqrt(self.model_dim)
        seq_len = embedding.size(1)
        # Buffers never require gradients, so no autograd wrapper is needed.
        embedding = embedding + self.positional_emb[:, :seq_len]
        return self.dropout(embedding)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V/output projections.

    ``embed_dim`` must be divisible by ``num_heads``; every head operates on
    ``embed_dim // num_heads`` features. The attention probabilities of the most
    recent forward pass (after dropout) are kept on ``self.attention_weight``.
    """

    def __init__(self, num_heads, embed_dim, dropout=0.1):
        super().__init__()
        factory = {'device': device, 'dtype': torch.float32}
        assert embed_dim % num_heads == 0

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.attention_weight = None

        self.query_project = nn.Linear(embed_dim, embed_dim, **factory)
        self.key_project = nn.Linear(embed_dim, embed_dim, **factory)
        self.value_project = nn.Linear(embed_dim, embed_dim, **factory)
        self.out_matrix = nn.Linear(embed_dim, embed_dim, **factory)

        self.dropout = nn.Dropout(dropout)

    def _self_attention(self, query, key, value, attention_mask=None, dropout=None):
        """Compute scaled dot-product attention for already-split heads.

        query: batch_size x heads x tgt_length x head_dim
        key:   batch_size x heads x src_length x head_dim
        value: batch_size x heads x src_length x head_dim
        attention_mask: broadcastable to the score tensor; entries equal to 0
            are filled with -inf before the softmax.
        Returns (output, probabilities); output is
            batch_size x heads x tgt_length x head_dim.
        """
        _, _, _, dim_head = query.shape
        scale = dim_head ** -0.5
        scores = torch.matmul(query, key.transpose(-1, -2)) * scale

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        probabilities = scores.softmax(dim=-1)
        if dropout is not None:
            probabilities = dropout(probabilities)

        return torch.matmul(probabilities, value), probabilities

    def _shape(self, tensor: torch.Tensor, sequence_length: int, batch_size: int):
        """Reshape (batch, seq, embed_dim) into a contiguous (batch, heads, seq, head_dim)."""
        per_head = tensor.view(batch_size, sequence_length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(self,
                query,
                key,
                value,
                attention_mask = None):
        """Project the inputs, attend per head, and merge the heads again.

        Args:
            query: batch_size x tgt_length x embed_dim.
            key: batch_size x src_length x embed_dim.
            value: same layout as ``key``.
            attention_mask: optional mask handed to ``_self_attention``.

        Returns:
            Tensor of shape batch_size x tgt_length x embed_dim.
        """
        batch_size, tgt_length, _ = query.shape
        _, _, _ = key.shape  # key must be 3-D, like query

        def split_heads(projected):
            # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query_project(query))
        k = split_heads(self.key_project(key))
        v = split_heads(self.value_project(value))

        context, self.attention_weight = self._self_attention(q, k, v, attention_mask, self.dropout)

        # (batch, heads, tgt, head_dim) -> (batch, tgt, embed_dim)
        merged = context.transpose(1, 2).contiguous().view(batch_size, tgt_length, self.embed_dim)
        return self.out_matrix(merged)

        

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a GELU feed-forward block.

    Each sub-block is followed by dropout, a residual connection and LayerNorm.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        dropout: dropout probability for attention and the sub-block outputs.
        ffn_dim: hidden width of the feed-forward block (previously hard-coded
            to 4096, which remains the default).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, ffn_dim=4096):
        super().__init__()
        kwargs = {'device': device, 'dtype': torch.float32}

        self.attention = MultiHeadAttention(num_heads, embed_dim, dropout=dropout)
        self.attention_dropout = nn.Dropout(dropout)
        self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)

        self.linear_1 = nn.Linear(embed_dim, ffn_dim, **kwargs)
        self.activation = F.gelu
        self.activation_dropout = nn.Dropout(0.0)

        self.linear_2 = nn.Linear(ffn_dim, embed_dim, **kwargs)
        self.ff_dropout = nn.Dropout(dropout)

        self.ff_norm = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)

    def forward(self, hidden_states, encoder_attention_mask):
        """Run the layer.

        Args:
            hidden_states: input used as query, key and value of the self-attention.
            encoder_attention_mask: mask forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # self-attention block with residual + normalization
        residual = hidden_states
        hidden_states = self.attention(hidden_states, hidden_states, hidden_states, encoder_attention_mask)
        hidden_states = self.attention_dropout(hidden_states)
        hidden_states = self.attention_norm(residual + hidden_states)

        # feed-forward block with residual + normalization
        residual = hidden_states
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.activation_dropout(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ff_dropout(hidden_states)
        hidden_states = self.ff_norm(residual + hidden_states)

        return hidden_states

class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` modules on top of token + learned positional embeddings.

    Args:
        embed_dim: model width.
        num_encoder_layers: number of stacked layers.
        num_heads: attention heads per layer.
        dropout: dropout probability forwarded to every layer.
        embed_tokens: embedding module used for the token ids (stored as
            ``word_embedding``; the caller may share it with the decoder).
        max_positions: size of the learned positional table (previously
            hard-coded to 1024, which remains the default).
    """

    def __init__(self, embed_dim, num_encoder_layers, num_heads, dropout, embed_tokens, max_positions=1024):
        super().__init__()
        kwargs = {'device': device, 'dtype': torch.float32}
        self.num_encoder_layers = num_encoder_layers
        self.word_embedding = embed_tokens
        self.positional_embedding = NormalPositionalEmbedding(embed_dim, max_positions)
        self.norm_embedding = nn.LayerNorm(embed_dim, **kwargs)
        # Every iteration already builds an independent layer, so no deepcopy is needed.
        self.layers = nn.ModuleList(
            [EncoderLayer(embed_dim, num_heads, dropout) for _ in range(self.num_encoder_layers)]
        )

    def forward(self, encoder_inputs, encoder_attention_mask):
        """Embed ``encoder_inputs`` and pass them through all encoder layers.

        Args:
            encoder_inputs: token ids; the first two dimensions are read as
                (batch, sequence length).
            encoder_attention_mask: mask forwarded unchanged to every layer.

        Returns:
            Hidden states produced by the last layer.
        """
        input_shape = encoder_inputs.size()

        # embedding layer: token embedding + position embedding, then LayerNorm
        word_embed = self.word_embedding(encoder_inputs)
        pos_embed = self.positional_embedding(input_shape)
        hidden_states = self.norm_embedding(word_embed + pos_embed)

        # encoder layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, encoder_attention_mask=encoder_attention_mask)

        return hidden_states


In [None]:
class DecoderLayer(nn.Module):
    """One post-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-block is followed by dropout, a residual connection and LayerNorm.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        dropout: dropout probability for attention and the sub-block outputs.
        ffn_dim: hidden width of the feed-forward block (previously hard-coded
            to 4096, which remains the default).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, ffn_dim=4096):
        super().__init__()
        kwargs = {'device': device, 'dtype': torch.float32}

        self.self_attention = MultiHeadAttention(num_heads, embed_dim, dropout=dropout)
        self.self_attention_dropout = nn.Dropout(dropout)
        self.self_attention_norm = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)

        self.cross_attention = MultiHeadAttention(num_heads, embed_dim, dropout=dropout)
        self.cross_attention_dropout = nn.Dropout(dropout)
        self.cross_attention_norm = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)

        self.linear_1 = nn.Linear(embed_dim, ffn_dim, **kwargs)
        self.activation = F.gelu
        self.activation_dropout = nn.Dropout(0.0)
        self.linear_2 = nn.Linear(ffn_dim, embed_dim, **kwargs)
        self.ff_dropout = nn.Dropout(dropout)

        self.ff_norm = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)

    def forward(self, hidden_states, encoder_outputs, decoder_self_attention_mask, decoder_cross_attention_mask):
        """Run the layer.

        Args:
            hidden_states: decoder input; query, key and value of the self-attention.
            encoder_outputs: key and value of the cross-attention.
            decoder_self_attention_mask: mask for the self-attention.
            decoder_cross_attention_mask: mask for the cross-attention.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # self-attention block with residual + normalization
        residual = hidden_states
        hidden_states = self.self_attention(hidden_states, hidden_states, hidden_states, decoder_self_attention_mask)
        hidden_states = self.self_attention_dropout(hidden_states)
        hidden_states = self.self_attention_norm(residual + hidden_states)

        # cross-attention block with residual + normalization
        residual = hidden_states
        hidden_states = self.cross_attention(hidden_states, encoder_outputs, encoder_outputs, decoder_cross_attention_mask)
        hidden_states = self.cross_attention_dropout(hidden_states)
        hidden_states = self.cross_attention_norm(residual + hidden_states)

        # feed-forward block with residual + normalization
        residual = hidden_states
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.activation_dropout(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ff_dropout(hidden_states)
        hidden_states = self.ff_norm(residual + hidden_states)

        return hidden_states

class Decoder(nn.Module):
    """Stack of ``DecoderLayer`` modules on top of token + learned positional embeddings.

    Args:
        embed_dim: model width.
        num_decoder_layers: number of stacked layers.
        num_heads: attention heads per layer.
        dropout: dropout probability forwarded to every layer.
        embed_tokens: embedding module used for the token ids (stored as
            ``word_embedding``; the caller may share it with the encoder).
        max_positions: size of the learned positional table (previously
            hard-coded to 1024, which remains the default).
    """

    def __init__(self, embed_dim, num_decoder_layers, num_heads, dropout, embed_tokens, max_positions=1024):
        super().__init__()
        kwargs = {'device': device, 'dtype': torch.float32}
        self.num_decoder_layers = num_decoder_layers
        self.word_embedding = embed_tokens
        self.positional_embedding = NormalPositionalEmbedding(embed_dim, max_positions)
        self.norm_embedding = nn.LayerNorm(embed_dim, eps=1e-5, **kwargs)
        # Every iteration already builds an independent layer, so no deepcopy is needed.
        self.layers = nn.ModuleList(
            [DecoderLayer(embed_dim, num_heads, dropout) for _ in range(self.num_decoder_layers)]
        )

    def forward(self, decoder_input, encoder_hidden_states, decoder_self_attention_mask, decoder_cross_attention_mask):
        """Embed ``decoder_input`` and pass it through all decoder layers.

        Args:
            decoder_input: token ids; the first two dimensions are read as
                (batch, sequence length).
            encoder_hidden_states: encoder output attended to by every layer.
            decoder_self_attention_mask: mask for the self-attention blocks.
            decoder_cross_attention_mask: mask for the cross-attention blocks.

        Returns:
            Hidden states produced by the last layer.
        """
        input_shape = decoder_input.size()

        # embedding layer: token embedding + position embedding, then LayerNorm
        word_embed = self.word_embedding(decoder_input)
        pos_embed = self.positional_embedding(input_shape)
        hidden_states = self.norm_embedding(word_embed + pos_embed)

        # decoder layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, encoder_hidden_states, decoder_self_attention_mask, decoder_cross_attention_mask)

        return hidden_states

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model with one embedding table shared by both stacks.

    The padding index is fixed to 1. A later notebook cell imports a class of
    the same name from the cloned repository, which rebinds ``Transformer``.

    Args:
        vocab_size: vocabulary size (embedding rows and output logits).
        embed_dim: model width.
        num_of_encoder_decoder: number of layers in each of the two stacks.
        num_heads: attention heads per layer.
        dropout: dropout probability handed to both stacks.
    """

    def __init__(self, vocab_size, embed_dim, num_of_encoder_decoder, num_heads, dropout):
        super().__init__()
        factory = {'device': device, 'dtype': torch.float32}
        self.padding_idx = 1
        self.word_embedding = nn.Embedding(
            vocab_size, embed_dim, padding_idx=self.padding_idx, **factory
        )
        shared_args = (embed_dim, num_of_encoder_decoder, num_heads, dropout, self.word_embedding)
        self.encoder = Encoder(*shared_args)
        self.decoder = Decoder(*shared_args)
        self.final_output = nn.Linear(embed_dim, vocab_size, **factory)

    def forward(self, encoder_inputs, decoder_inputs, encoder_attention_mask, decoder_self_attention_mask):
        """Return the vocabulary logits for ``decoder_inputs`` given ``encoder_inputs``.

        The encoder attention mask is reused as the decoder's cross-attention mask.
        """
        memory = self.encoder(encoder_inputs, encoder_attention_mask)
        decoded = self.decoder(
            decoder_inputs, memory, decoder_self_attention_mask, encoder_attention_mask
        )
        return self.final_output(decoded)

    def _reset_parameters(self):
        """Re-initialise every parameter with more than one dimension (Xavier uniform)."""
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)


# Initialize the Transformer from GitHub

In [16]:
from Transformer.utils.beamsearch import beam_summarize
from Transformer.model.transformer_model import Transformer
from Transformer.model.transformer_config import TransformerConfig
from transformers import BartForConditionalGeneration, BartTokenizer
import torch

# Hyperparameters handed to the repository's TransformerConfig.
# NOTE(review): the sizes look like the facebook/bart-large-cnn architecture
# (the tokenizer loaded below) — confirm against the repository.
kwargs = {
    'vocab_size': 50264,
    'max_position_embeddings': 1024,
    'num_encoder_layers': 12,
    'encoder_ffn_dim': 4096,
    'encoder_attention_heads': 16,
    'num_decoder_layers': 12,
    'decoder_ffn_dim': 4096,
    'decoder_attention_heads': 16,
    'encoder_layer_dropout': 0.0,
    'decoder_layer_dropout': 0.0,
    'activation_function': 'gelu',
    'layer_norm_eps': 1e-5,
    'model_dim': 1024,
    'dropout': 0.1,
    'attention_dropout': 0.0,
    'activation_dropout': 0.0,
    'pad_token_id': 1,
    'device': device,
    'dtype': torch.float32
}
tran_conf = TransformerConfig(**kwargs)

# `Transformer` here is the class imported from the cloned repository in this
# cell, not the class defined earlier in the notebook.
transformer = Transformer(tran_conf)

# Tokenizer used by collate_fn and beam_summarize through this module-level name.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

# Utilities

In [6]:
def create_padding_mask(sequence, padding_token, device):
    """Build a boolean mask that is True wherever ``sequence`` is not padding.

    Args:
        sequence: 2-D tensor of token ids, shape (batch_size, length).
        padding_token: id that marks padding.
        device: device the mask is moved to.

    Returns:
        Boolean tensor of shape (batch_size, 1, 1, length).
    """
    batch_size, inputs_len = sequence.size()
    keep = sequence.ne(padding_token)
    keep = keep.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, 1, inputs_len)
    return keep.to(device)

def create_casual_mask(sequence, device):
    """Build the causal (look-ahead) mask for ``sequence``.

    Position ``i`` may attend to positions ``0 .. i``. The mask is created
    directly on ``device`` with torch, avoiding a NumPy array on the host and
    the deprecated ``Variable`` wrapper.

    Args:
        sequence: 2-D tensor of shape (batch_size, length); only its shape is used.
        device: device on which the mask is created.

    Returns:
        Boolean tensor of shape (batch_size, 1, length, length), True on and
        below the diagonal. It is a broadcast view of a single (length, length)
        matrix, so it must not be modified in place.
    """
    batch_size, input_len = sequence.size()
    positions = torch.arange(input_len, device=device)
    # allowed[i, j] is True when key position j is not after query position i.
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, input_len, input_len)

def create_mask(encoder_inputs, decoder_inputs, padding_token, device):
    """Create the attention masks for one encoder/decoder batch.

    Returns:
        (encoder_attention_mask, decoder_self_attention_mask): the encoder
        padding mask, and the decoder mask obtained by AND-ing the decoder's
        causal mask with its padding mask.
    """
    encoder_attention_mask = create_padding_mask(encoder_inputs, padding_token, device)

    target_padding = create_padding_mask(decoder_inputs, padding_token, device)
    target_causal = create_casual_mask(decoder_inputs, device)

    return encoder_attention_mask, target_causal.logical_and(target_padding)


# Beam Search

In [7]:
def init_beam(model, text, text_mask, max_seq_len, start_token_id, num_beams, device):
    """Encode ``text`` and seed ``num_beams`` hypotheses with the best first tokens.

    Args:
        model: object exposing ``encoder``, ``decoder`` and ``final_output``.
        text: source token ids; a single sequence is assumed (batch size 1),
            because the encoder output is expanded to ``num_beams`` rows.
        text_mask: attention mask for ``text``.
        max_seq_len: width of the output buffer; must be at least 2.
        start_token_id: id written at position 0 of every hypothesis.
        num_beams: number of hypotheses.
        device: device for all tensors.

    Returns:
        (model_outputs, log_scores, memory_state):
        ``model_outputs`` is an int32 buffer of shape (num_beams, max_seq_len)
        with positions 0 and 1 filled, ``log_scores`` has shape (1, num_beams),
        and ``memory_state`` is the encoder output expanded to ``num_beams`` rows.
    """
    text = text.to(device)
    text_mask = text_mask.to(device)

    memory_state = model.encoder(text, text_mask)
    batch_size, text_length, model_dim = memory_state.shape

    summary = torch.tensor([[start_token_id]], dtype=torch.long, device=device)
    summary_casual_mask = create_casual_mask(summary, device)
    # Padding id 1 is hard-coded here, as in the rest of the beam-search helpers.
    summary_padding_mask = create_padding_mask(summary, 1, device)
    summary_mask = summary_casual_mask.logical_and(summary_padding_mask)

    logits = model.final_output(model.decoder(summary, memory_state, summary_mask, text_mask))
    # Select the last decoded position so the scores are (1, num_beams). The
    # former (1, 1, num_beams) result made the first expansion step add each
    # beam's score along the candidate axis instead of the beam axis.
    log_scores, index = F.log_softmax(logits[:, -1], dim=-1).topk(num_beams)

    model_outputs = torch.zeros((num_beams, max_seq_len), dtype=torch.int32, device=device)
    model_outputs[:, 0] = start_token_id
    model_outputs[:, 1] = index[0]
    memory_state = memory_state.expand(num_beams, text_length, model_dim)

    return model_outputs, log_scores, memory_state

In [8]:
def select_k_top_candidate(model_outputs, prob, log_scores, i, num_beams):
    """Expand every beam by one token and keep the ``num_beams`` best hypotheses.

    Args:
        model_outputs: (num_beams, max_seq_len) token buffer, updated in place.
        prob: decoder logits of shape (num_beams, i, vocab); only the last
            position is used.
        log_scores: cumulative log-probability of each current beam; any
            tensor holding exactly ``num_beams`` values, e.g. (1, num_beams).
        i: buffer position that receives the new token.
        num_beams: number of hypotheses to keep.

    Returns:
        (model_outputs, log_scores) with ``log_scores`` of shape (1, num_beams),
        sorted from best to worst.
    """
    log_outputs = F.log_softmax(prob, dim=-1)

    # Top candidates per beam: row = beam, column = candidate rank.
    log_probs, index = log_outputs[:, -1].topk(num_beams)
    # One score per row, so each beam's own score is added to its candidates
    # whatever the number of leading singleton dimensions in log_scores.
    beam_scores = log_scores.reshape(-1, 1)
    log_probs = log_probs + beam_scores
    log_probs, k_index = log_probs.reshape(-1).topk(num_beams)

    rows = torch.div(k_index, num_beams, rounding_mode='floor')
    cols = k_index % num_beams
    model_outputs[:, :i] = model_outputs[rows, :i]
    model_outputs[:, i] = index[rows, cols]

    log_scores = log_probs.unsqueeze(0)

    return model_outputs, log_scores

In [9]:
def beam(model, text, text_mask, max_seq_len, start_token_id, end_token_id, num_beams, device):
    """Decode a summary for ``text`` with beam search.

    Decoding stops as soon as every beam contains ``end_token_id`` (the beam
    with the best length-normalised score is then chosen) or when
    ``max_seq_len`` tokens have been produced (beam 0 is then used).

    Args:
        model: object exposing ``encoder``, ``decoder`` and ``final_output``.
        text: source token ids.
        text_mask: attention mask for ``text``.
        max_seq_len: maximum output length; capped at 1024.
        start_token_id: id of the first generated token.
        end_token_id: id that terminates a hypothesis.
        num_beams: beam width.
        device: device for all tensors.

    Returns:
        1-D tensor of token ids, cut after the first end token. When the chosen
        beam holds no end token the whole beam is returned (previously this
        case produced an empty tensor).
    """
    max_seq_len = min(max_seq_len, 1024)
    chosen_text_index = 0
    model_outputs, log_scores, memory_state = init_beam(model, text, text_mask, max_seq_len, start_token_id, num_beams, device)

    for i in range(2, max_seq_len):
        decoded_so_far = model_outputs[:, :i]
        summary_casual_mask = create_casual_mask(decoded_so_far, device)
        # Padding id 1 is hard-coded, as in init_beam.
        summary_padding_mask = create_padding_mask(decoded_so_far, 1, device)
        summary_mask = summary_casual_mask.logical_and(summary_padding_mask)

        prob = model.final_output(model.decoder(decoded_so_far, memory_state, summary_mask, text_mask))
        model_outputs, log_scores = select_k_top_candidate(model_outputs, prob, log_scores, i, num_beams)

        # Record the position of the first end token in every beam. nonzero()
        # lists hits in row-major order, so the first hit per beam is the earliest.
        finished_sentences = (model_outputs == end_token_id).nonzero()
        mark_end_tokens = torch.zeros(num_beams, dtype=torch.int64, device=device)
        num_finished_sentences = 0

        for end_token in finished_sentences:
            sentence_ind, end_token_location = end_token
            if mark_end_tokens[sentence_ind] == 0:
                mark_end_tokens[sentence_ind] = end_token_location
                num_finished_sentences += 1

        if num_finished_sentences == num_beams:
            # Length-normalise the scores so shorter hypotheses are not favoured.
            alpha = 0.7
            division = mark_end_tokens.type_as(log_scores)**alpha
            _, chosen_text_index = torch.max(log_scores / division, 1)
            chosen_text_index = chosen_text_index[0]
            break

    chosen = model_outputs[chosen_text_index]
    end_positions = (chosen == end_token_id).nonzero()
    if len(end_positions) == 0:
        # No end token was generated: keep everything that was decoded.
        return chosen
    return chosen[:int(end_positions[0].item()) + 1]

def beam_summarize(model: torch.nn.Module, tokenizer, text: str, device, num_beams: int = 5):
    """Summarise ``text`` with beam search and return the decoded string.

    Args:
        model: trained model, switched to eval mode here.
        tokenizer: tokenizer providing ``batch_encode_plus`` and ``decode``.
        text: source document.
        device: device on which decoding runs.
        num_beams: beam width.

    Returns:
        The summary with the literal markers ``<s>``, ``</s>`` and ``<unk>`` removed.
    """
    model.eval()
    with torch.no_grad():
        # Truncate to 1024 tokens: the positional tables in this notebook hold
        # 1024 positions, so longer inputs would index past them.
        text_encodings = tokenizer.batch_encode_plus(
            [text], padding=True, truncation=True, max_length=1024)
        text_ids = torch.tensor(text_encodings.get('input_ids'))
        num_tokens = text_ids.shape[1]
        # Padding id 1, start id 0 and end id 2 are hard-coded below.
        text_mask = create_padding_mask(text_ids, 1, device)

        # The summary is limited to 80% of the source length; init_beam writes
        # positions 0 and 1, so at least two slots are required.
        max_summary_len = max(2, int(num_tokens*0.8))

        summary_tokens = beam(
            model, text_ids, text_mask, max_seq_len=max_summary_len,
            start_token_id=0, end_token_id=2, num_beams=num_beams, device=device).flatten()
        summary = tokenizer.decode(summary_tokens.tolist()).replace('<s>','').replace('</s>','').replace('<unk>','')
        return summary

# Optimizer and Loss function

In [15]:
class CustomOptimizer():
    """Wrap an optimizer and set its learning rate with a warm-up schedule.

    Before every update the rate becomes
    ``model_dim ** -0.5 * min(step ** -0.5, step * num_warmup_steps ** -1.5)``.
    """

    def __init__(self, optimizer, model_dim, num_warmup_steps):
        self.optimizer = optimizer
        self.model_dim = model_dim
        self.num_warmup_steps = num_warmup_steps
        self.num_steps = 0

    def state_dict(self):
        """Return the schedule state together with the wrapped optimizer's state."""
        return {
            'model_dim': self.model_dim,
            'num_warmup_steps': self.num_warmup_steps,
            'num_steps': self.num_steps,
            'optimizer': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore a state produced by ``state_dict``."""
        self.model_dim = state_dict['model_dim']
        self.num_warmup_steps = state_dict['num_warmup_steps']
        self.num_steps = state_dict['num_steps']
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def step(self):
        """Advance the schedule by one step, then run the wrapped optimizer."""
        self._update_learning_rate()
        self.optimizer.step()

    def zero_grad(self):
        """Clear the gradients through the wrapped optimizer."""
        self.optimizer.zero_grad()

    def _update_learning_rate(self):
        """Increment the step counter and write the new rate to all param groups."""
        self.num_steps += 1
        step = self.num_steps
        decay = step ** (-0.5)
        warmup = step * self.num_warmup_steps ** (-1.5)
        new_rate = (self.model_dim ** -0.5) * min(decay, warmup)

        for group in self.optimizer.param_groups:
            group['lr'] = new_rate

In [16]:
class CustomCrossEntropyLoss(nn.Module):
    """Cross-entropy with label smoothing that ignores padding targets.

    The true class receives ``1 - smoothing_score``; the remaining mass is
    spread over ``vocab_size - 2`` classes (all but the true class and the
    padding class). Positions whose target is ``padding_idx`` contribute zero
    loss but are still counted in the mean.
    """

    def __init__(self, vocab_size, padding_idx, smoothing_score=0.0):
        super().__init__()
        self.padding_idx = padding_idx
        self.smoothing_score = smoothing_score
        self.label_score = 1.0 - smoothing_score
        self.vocab_size = vocab_size

    def forward(self, model_predicts, real_sequences):
        """Return the mean smoothed cross-entropy.

        Args:
            model_predicts: logits, either batch_size x target_length x vocab_size
                or already flattened to N x vocab_size.
            real_sequences: target ids, batch_size x target_length or flattened to N.

        Returns:
            Scalar tensor: the per-position loss averaged over all positions.
        """
        # Flatten so the documented 3-D layout and the flattened layout used
        # by callers are both handled.
        logits = model_predicts.reshape(-1, model_predicts.size(-1))
        targets = real_sequences.reshape(-1)

        pred_distribution = logits.log_softmax(dim=-1)
        with torch.no_grad():
            true_distribution = torch.full_like(
                pred_distribution, self.smoothing_score / (self.vocab_size - 2))
            true_distribution.scatter_(1, targets.unsqueeze(1), self.label_score)
            # The padding class never receives probability mass ...
            true_distribution[:, self.padding_idx] = 0.0
            # ... and rows whose target is padding are zeroed entirely.
            is_padding = (targets == self.padding_idx).unsqueeze(1)
            true_distribution.masked_fill_(is_padding, 0.0)

        loss_for_idx = torch.sum(-true_distribution * pred_distribution, dim=-1)
        return torch.mean(loss_for_idx)


# Training and Validation Class

In [17]:
# Directory prefix for the checkpoints written by Trainer (model.pt / optimizer.pt).
save_path = '/content/drive/MyDrive/'

In [18]:
import time
from tqdm import tqdm
import gc
class Trainer():
    """Train and validate a sequence-to-sequence model with teacher forcing.

    Args:
        model: model called as ``model(text, summary_in, text_mask, summary_mask)``.
        num_epochs: number of epochs run by ``train_model``.
        optimizer: object with ``zero_grad``, ``step`` and ``state_dict``.
        train_iter: training data set handed to ``DataLoader``.
        valid_iter: validation data set handed to ``DataLoader``.
        loss_function: callable taking (N x vocab logits, N target ids).
        device: device on which batches are processed.
        padding_idx: padding token id used to build the attention masks.
    """

    def __init__(self, model, num_epochs, optimizer, train_iter, valid_iter, loss_function, device, padding_idx):
        self.model = model.to(device)
        self.num_epochs = num_epochs
        self.optimizer = optimizer
        self.train_iter = train_iter
        self.valid_iter = valid_iter
        self.loss_function = loss_function
        self.padding_idx = padding_idx
        self.device = device

    def _make_loader(self, dataset):
        """Return a shuffled loader with batch size 2 that uses the notebook's collate_fn."""
        return DataLoader(dataset, batch_size=2, collate_fn=collate_fn, num_workers=0, shuffle=True)

    def _batch_loss(self, text, summary):
        """Run one teacher-forced forward pass and return the loss tensor.

        The decoder input is the summary without its last token; the target is
        the summary shifted left by one position.
        """
        text_input = text.to(self.device)
        summary = summary.to(self.device)
        summary_input = summary[:, :-1]
        summary_real = summary[:, 1:].contiguous()

        text_mask, summary_mask = create_mask(text_input, summary_input, self.padding_idx, self.device)
        summary_predict = self.model(text_input, summary_input, text_mask, summary_mask)

        return self.loss_function(
            summary_predict.reshape(-1, summary_predict.shape[-1]), summary_real.reshape(-1))

    def _save_checkpoint(self):
        """Write the model and optimizer state below ``save_path``."""
        torch.save(self.model.state_dict(), save_path + "model.pt")
        torch.save(self.optimizer.state_dict(), save_path + "optimizer.pt")

    def _train_epoch(self):
        """Train for one epoch and return the mean batch loss.

        The running mean is printed and a checkpoint is written every 200 batches.
        """
        self.model.train()
        total_loss = 0
        train_dataloader = self._make_loader(self.train_iter)

        for count, (text, summary) in enumerate(tqdm(train_dataloader, desc='Training'), start=1):
            self.optimizer.zero_grad()
            loss = self._batch_loss(text, summary)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

            if count % 200 == 0:
                print(total_loss/count)
                self._save_checkpoint()

        # Batch tensors are local to _batch_loss and released when it returns,
        # so one collection per epoch replaces the former per-batch gc.collect().
        gc.collect()
        # max(..., 1) avoids a ZeroDivisionError on an empty data set.
        return total_loss/max(len(train_dataloader), 1)

    def _validate_epoch(self):
        """Return the mean batch loss on the validation set (no gradient tracking)."""
        self.model.eval()
        valid_dataloader = self._make_loader(self.valid_iter)
        total_loss = 0
        with torch.no_grad():
            for text, summary in tqdm(valid_dataloader, desc='Validation'):
                total_loss += self._batch_loss(text, summary).item()

        gc.collect()
        return total_loss/max(len(valid_dataloader), 1)

    def train_model(self):
        """Run all epochs, saving a checkpoint after each, and return the model."""
        for epoch in range(1,self.num_epochs+1):
            start_time = time.time()
            train_loss = self._train_epoch()
            end_time = time.time()
            self._save_checkpoint()
            val_loss = self._validate_epoch()
            print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
        return self.model


# Train model

In [19]:
from transformers import BartTokenizer
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch

def collate_fn(batch):
    """Tokenize a batch of (article, summary) string pairs into padded id tensors.

    Uses the module-level ``tokenizer``. Trailing newlines are stripped, each
    side is padded to its longest sequence, and sequences are truncated to
    1024 tokens because the positional tables in this notebook hold 1024 positions.

    Args:
        batch: iterable of (text, summary) string pairs.

    Returns:
        (text_ids, summary_ids): two 2-D tensors of token ids.
    """
    text_batch = [text_sample.rstrip("\n") for text_sample, _ in batch]
    summary_batch = [summary_sample.rstrip("\n") for _, summary_sample in batch]

    text_encodings = tokenizer.batch_encode_plus(
        text_batch, padding=True, truncation=True, max_length=1024)
    text_ids = torch.tensor(text_encodings.get('input_ids'))

    summary_encodings = tokenizer.batch_encode_plus(
        summary_batch, padding=True, truncation=True, max_length=1024)
    summary_ids = torch.tensor(summary_encodings.get('input_ids'))

    return text_ids, summary_ids

In [20]:
# Load the training and validation CSV files from Google Drive and pair each
# article with its highlights (the reference summary).
train_set= pd.read_csv('/content/drive/MyDrive/Train_set_short.csv')
train_iter = list(zip(train_set.article, train_set.highlights))
validation_set= pd.read_csv('/content/drive/MyDrive/Validation_set_short.csv')
validation_iter = list(zip(validation_set.article, validation_set.highlights))

In [24]:
# Assemble the training run: label-smoothed cross-entropy that ignores token
# id 1 (padding), Adam wrapped in the 4000-step warm-up schedule (lr=0 is a
# placeholder; CustomOptimizer sets the rate before every step), and a
# Trainer configured for a single epoch.
transformer = transformer.to(tran_conf.device)
loss_fn = nn.CrossEntropyLoss(ignore_index=1, label_smoothing=0.1)
optimizer = CustomOptimizer(torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), tran_conf.model_dim, 4000)
trainer = Trainer(transformer, 1, optimizer, train_iter, validation_iter, loss_fn, tran_conf.device, 1)

In [None]:
# Start training; checkpoints are written under save_path while it runs.
trainer.train_model()

# Continue training if the previous session ran out of time

In [None]:
from Transformer.utils.beamsearch import beam_summarize
from Transformer.model.transformer_model import Transformer
from Transformer.model.transformer_config import TransformerConfig
from transformers import BartForConditionalGeneration, BartTokenizer
import torch

# Rebuild the model with the same hyperparameters as the initial run and load
# previously saved weights so that training can continue.
kwargs = {
    'vocab_size': 50264,
    'max_position_embeddings': 1024,
    'num_encoder_layers': 12,
    'encoder_ffn_dim': 4096,
    'encoder_attention_heads': 16,
    'num_decoder_layers': 12,
    'decoder_ffn_dim': 4096,
    'decoder_attention_heads': 16,
    'encoder_layer_dropout': 0.0,
    'decoder_layer_dropout': 0.0,
    'activation_function': 'gelu',
    'layer_norm_eps': 1e-5,
    'model_dim': 1024,
    'dropout': 0.1,
    'attention_dropout': 0.0,
    'activation_dropout': 0.0,
    'pad_token_id': 1,
    # Use the device selected in the configuration cell (as the first model
    # cell does) instead of a hard-coded 'cuda', which fails without a GPU.
    'device': device,
    'dtype': torch.float32
}
tran_conf = TransformerConfig(**kwargs)

transformer = Transformer(tran_conf)

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
# map_location lets a checkpoint saved on a GPU load on whichever device is in use.
# NOTE(review): Trainer writes "model.pt"; "final_model.pt" is presumably a
# manually renamed copy — confirm the file name.
transformer.load_state_dict(torch.load('/content/drive/MyDrive/final_model.pt', map_location=device))

In [None]:
# Reload the training and validation CSV files and pair each article with its
# highlights (the reference summary).
train_set= pd.read_csv('/content/drive/MyDrive/Train_set_short.csv')
train_iter = list(zip(train_set.article, train_set.highlights))
validation_set= pd.read_csv('/content/drive/MyDrive/Validation_set_short.csv')
validation_iter = list(zip(validation_set.article, validation_set.highlights))  

In [None]:
# Same training set-up as the initial run, applied to the reloaded model.
# NOTE(review): a fresh CustomOptimizer starts at step 0, so the learning-rate
# warm-up restarts; Trainer also saves "optimizer.pt", which could be restored
# with optimizer.load_state_dict(...) if the schedule should continue — confirm intent.
transformer = transformer.to(tran_conf.device)
loss_fn = nn.CrossEntropyLoss(ignore_index=1, label_smoothing=0.1)
optimizer = CustomOptimizer(torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), tran_conf.model_dim, 4000)
trainer = Trainer(transformer, 1, optimizer, train_iter, validation_iter, loss_fn, tran_conf.device, 1)

In [None]:
# Continue training from the loaded weights.
trainer.train_model()

# Prediction

In [10]:
# Sample article 1 for the prediction cell below. Each later `text = ...` cell
# overwrites this value when it is run.
# NOTE(review): the literal begins with a stray '"' and ends with '?' — these
# are part of the input text as written.
text = """"Istanbul (CNN)A woman carried out a suicide bombing at a police station in Istanbul's historic Sultanahmet district Tuesday evening, killing one police officer and injuring another, officials said. The attack happened in the section of Turkey's largest city that is home to landmarks such as the Hagia Sophia and the Blue Mosque, and is heavily trafficked by tourists. The bomber, speaking English, entered the police station saying she lost her wallet, and the explosion happened at about 5:20 p.m., Istanbul Gov. Vasip Sahin told reporters. Sahin did not mention a motive for the attack. Sahin initially said that the blast, besides killing the bomber, critically injured one police officer and slightly wounded another. Later Tuesday, Turkey's semi-official Anadolu news agency reported that one of the officers died of his wounds at a hospital. Police cordoned off the area. The attacker's identity is unknown and the incident is being investigated, the governor told reporters. CNN's Gul Tuysuz reported and wrote from Istanbul, and CNN's Jason Hanna wrote in Atlanta. CNN's Hande Atay contributed to this report.?"""

In [None]:
# Sample article 2 for the prediction cell below (replaces the previous `text`).
text = """Scientists say weird signals from space are 'probably' aliens. A team of astronomers believes that strange signals emanating from a cluster of stars are actually aliens trying to tell the universe they exist. The study, which appeared in the Publications of the Astronomical Society of the Pacific, analyzed the odd beams of light from 234 stars - a fraction of the 2.5 million that were observed. The bizarre beacons led the paper's authors, Ermanno F. Borra and Eric Trottier from Laval University in Quebec, to conclude that it's "probably" aliens. "We find that the detected signals have exactly the shape of an [extraterrestrial intelligence] signal predicted in the previous publication and are therefore in agreement with this hypothesis," wrote Borra and Trottier. They also note that their findings align with the Extraterrestrial Intelligence (ETI) hypothesis, since the mysterious activity only occurred in a tiny fraction of stars. The hypothesis also suggests that an intelligent life force would use a more sophisticated optical beacon than, say, radio waves to reveal its existence."""

In [None]:
# Sample article 3 for the prediction cell below (replaces the previous `text`).
text = """Despite the positive news for tennis star Novak Djokovic on Monday, whether he will be able to compete in the Australian Open later this month still remains unclear. If Djokovic is allowed to stay, when will he play? Following his release from detention, the Serbian tennis star has returned to training, according to his brother. Djokovic has made clear in a series of tweets that he still intends to play in the tournament. We don't know yet when Djokovic's first match would be, but the main draw is on is Thursday January 13."""

In [None]:
# Summarize the current `text` with the default beam width. Decoding runs on
# the device the model was configured with, rather than a hard-coded 'cuda'
# that fails on CPU-only runtimes.
summary = beam_summarize(transformer, tokenizer, text, device=tran_conf.device)
print(summary)

# Evaluation

In [None]:
!pip install rouge

In [None]:
import pandas as pd
from tqdm import tqdm
from rouge import Rouge 


def rouge(test_path, model, beam_search=True, num_beams=3, return_sentence=False):
    """Score beam-search summaries against the references with ROUGE.

    Only the first 50 rows of the CSV are evaluated. Summaries are produced
    with the module-level ``tokenizer`` on the device that holds the model's
    parameters.

    Args:
        test_path: CSV file with ``article`` and ``highlights`` columns.
        model: trained summarization model.
        beam_search: currently unused; decoding always uses beam search.
        num_beams: beam width passed to ``beam_summarize``.
        return_sentence: when True, also return the predictions and references.

    Returns:
        The per-metric scores averaged over all scored samples (an empty dict
        when nothing could be scored); with ``return_sentence`` a tuple
        ``(average_scores, pred_texts, tgt_texts)``.
    """
    model.eval()
    data_set = pd.read_csv(test_path)
    data_iter = list(zip(data_set.article, data_set.highlights))[0:50]
    scorer = Rouge()
    model_device = next(model.parameters()).device

    totals = {}
    num_scored = 0
    pred_texts = []
    tgt_texts = []

    for src, tgt in tqdm(data_iter, desc='Rouge score'):
        pred_tgt = beam_summarize(model, tokenizer, src, model_device, num_beams=num_beams)
        pred_texts.append(pred_tgt)
        tgt_texts.append(tgt)

        try:
            scores = scorer.get_scores(pred_tgt, tgt)
        except ValueError as err:
            # NOTE(review): the rouge package is expected to raise ValueError
            # for an empty hypothesis — confirm; such samples are skipped.
            print(f"Skipping sample without a usable prediction: {err}")
            continue
        print(scores)

        # scores[0] maps each metric name to a dict of values; sum them up.
        for metric, values in scores[0].items():
            bucket = totals.setdefault(metric, {})
            for key, value in values.items():
                bucket[key] = bucket.get(key, 0.0) + value
        num_scored += 1

    average_scores = {
        metric: {key: value / num_scored for key, value in values.items()}
        for metric, values in totals.items()
    }

    if return_sentence:
        return average_scores, pred_texts, tgt_texts
    return average_scores

In [None]:
# Evaluate the model on the test set using 4 beams.
a = rouge('/content/drive/MyDrive/Test_set_short.csv', transformer, True, 4, False)