In [None]:
import collections.abc
import json
import random
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import tqdm as progressbar
import time
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
nltk.download('punkt')
import re
import math
import sys
import numpy as np

In [None]:
# Paths of the two JSON files that Utils.load_dataset reads.
DETAILS_JSON = "data/email_thread_details.json"
SUMMARIES_JSON = "data/email_thread_summaries.json"

# Keys used to index the JSON records.
# kThreadId groups emails and summaries; kBody / kSummary hold the text that Utils tokenizes.
# NOTE(review): kSubject, kTimestamp, kFrom and kTo are not read by any code visible in this notebook.
kThreadId = "thread_id"
kSubject = "subject"
kTimestamp = "timestamp"
kFrom = "from"
kTo = "to"
kBody = "body"
kSummary = "summary"

# Special vocabulary tokens: sequence start, sequence end, and out-of-vocabulary word.
BOS = "<BOS>"
EOS = "<EOS>"
UNK = "<UNK>"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Notes

- Reduced the vocabulary size from 172,743 to 16,856 tokens
- `#` WOW

In [None]:
class Utils():
    '''
        Static helpers that load the email-thread dataset and build its vocabulary.
    '''

    @staticmethod
    def load_dataset(DETAILS_FILE, SUMMARIES_FILE):
        '''
            This function loads and tokenizes the dataset from the two JSON files
            ARGS:
                DETAILS_FILE: path of the JSON file holding a list of email records
                SUMMARIES_FILE: path of the JSON file holding a list of thread summaries
            RETURN:
                dataset: dict mapping thread_id -> ([email, ...], summary)
                         Every value is a 2-tuple. Threads without a summary are
                         left out (they cannot be used as a training pair) and a
                         summary without emails gets an empty email list.
        '''
        with open(DETAILS_FILE, 'r', encoding='utf-8') as f:
            details = json.load(f)

        with open(SUMMARIES_FILE, 'r', encoding='utf-8') as f:
            summaries = json.load(f)

        # Group the tokenized emails by thread, preserving file order.
        emails_by_thread = {}
        for item in details:
            emails_by_thread.setdefault(item[kThreadId], []).append(Utils.tokenize_body(item))

        # Pair every summary with its emails. Building the tuple from a separate
        # dict guarantees the value is always (list, summary): never (None, summary),
        # a bare list, or a nested tuple when a thread has several summaries.
        dataset = {}
        for item in summaries:
            thread_id = item[kThreadId]
            dataset[thread_id] = (emails_by_thread.get(thread_id, []), Utils.tokenize_summary(item))

        return dataset

    @staticmethod
    def _tokenize_field(item, key):
        '''
            Tokenizes item[key] in place with NLTK and wraps it in BOS / EOS markers
            ARGS:
                item: the record (dict) to modify
                key: the key of the text field to tokenize
            RETURN:
                item: the same record, with item[key] replaced by a
                      space-separated token string
        '''
        words = [word for sentence in sent_tokenize(item[key]) for word in word_tokenize(sentence)]
        item[key] = " ".join([BOS] + words + [EOS])
        return item

    @staticmethod
    def tokenize_body(item):
        '''
            Tokenizes the body of an email record in place (see _tokenize_field)
        '''
        return Utils._tokenize_field(item, kBody)

    @staticmethod
    def tokenize_summary(item):
        '''
            Tokenizes the text of a summary record in place (see _tokenize_field)
        '''
        return Utils._tokenize_field(item, kSummary)

    @staticmethod
    def build_vocab(data):
        '''
            This function builds the vocabulary from the data
            ARGS:
                data: dict mapping thread_id -> ([Email], EmailSummary) as
                      returned by load_dataset
            RETURN:
                vocab: the vocabulary of every token in the bodies and summaries
        '''
        vocab = Vocab()
        for _, (email_list, summary) in data.items():
            for email in email_list:
                # Iterate over the tokens of the body. Iterating over the record
                # itself would add the JSON keys instead of the words.
                for word in email[kBody].split():
                    vocab.add(word)
            for word in summary[kSummary].split():
                vocab.add(word)

        return vocab

class Vocab(collections.abc.MutableSet):
    """
        Set-like data structure that can change words into numbers and back.
        From Prof. David Chiang Code

        The special tokens always get the same ids (<BOS>=0, <EOS>=1, <UNK>=2).
        They are seeded from an ordered sequence rather than a set, because set
        iteration order of strings changes between interpreter runs and would
        make saved models incompatible with a freshly built vocabulary.
    """
    def __init__(self):
        self.num_to_word = ['<BOS>', '<EOS>', '<UNK>']
        self.word_to_num = {word: num for num, word in enumerate(self.num_to_word)}

    def add(self, word):
        """Add a word, giving it the next free number; known words are ignored."""
        if word in self: return
        num = len(self.num_to_word)
        self.num_to_word.append(word)
        self.word_to_num[word] = num

    def discard(self, word):
        """Removal is not supported: it would invalidate the numbering."""
        raise NotImplementedError()

    def update(self, words):
        """Add every word of an iterable."""
        for word in words:
            self.add(word)

    def __contains__(self, word):
        return word in self.word_to_num

    def __len__(self):
        return len(self.num_to_word)

    def __iter__(self):
        return iter(self.num_to_word)

    def numberize(self, word):
        """Convert a word into a number; unknown words map to the number of <UNK>."""
        if word in self.word_to_num:
            return self.word_to_num[word]
        else:
            return self.word_to_num['<UNK>']

    def denumberize(self, num):
        """Convert a number into a word."""
        return self.num_to_word[num]

## Preprocessing

In [None]:
# Load the data
# d maps each thread id to the value built by Utils.load_dataset for that thread.
d = Utils.load_dataset(DETAILS_JSON, SUMMARIES_JSON)
# The vocabulary is built from the whole dataset, before the train/test split below.
vocab = Utils.build_vocab(d)
len_vocab = len(vocab)
print("Vocab Size: ", len_vocab)

In [None]:
# Split the dictionary into train and test
data = list(d.items())
# train_test_split shuffles by itself and is seeded through random_state, so the
# split is reproducible across runs. (An extra unseeded random.shuffle here would
# silently undo that reproducibility.)
train, test = train_test_split(data, test_size=0.2, random_state=42)

# Drop the thread ids: each example becomes (email_list, summary)
train = [(email_list, summary) for _, (email_list, summary) in train]
test = [(email_list, summary) for _, (email_list, summary) in test]

In [None]:
# Print Sample Email
# randrange excludes len(train); random.randint(0, len(train)) includes it and
# raises IndexError whenever it draws the upper bound.
i = random.randrange(len(train))
print(f"Sample Thread {i}: ")
s = " ".join([email[kBody] for email in train[i][0]])
print(s)
print(len(s.split()))
print("Sample Summary: ")
print(train[i][1][kSummary])
print(len(train[i][1][kSummary].split()))


# Models

In [None]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position information to token embeddings, then dropout.

    Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    ARGS:
        dim_model: size of the embedding vectors (odd sizes are supported)
        dropout_p: dropout probability applied to the sum
        max_len: longest sequence the table is precomputed for

    The input of forward() must be sequence-first: (sequence length, batch_size, dim_model).
    """
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()

        # Info
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1 / 10000^(2i/dim_model)

        # PE(pos, 2i) = sin(pos/10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)

        # PE(pos, 2i + 1) = cos(pos/10000^(2i/dim_model))
        # When dim_model is odd there is one odd column fewer than even columns,
        # so the frequency vector is trimmed to dim_model // 2 entries.
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term[: dim_model // 2])

        # Saving buffer (same as parameter without gradients needed)
        # Final shape: (max_len, 1, dim_model), broadcast over the batch dimension
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """
        ARGS:
            token_embedding: tensor of shape (sequence length, batch_size, dim_model)
        RETURN:
            tensor of the same shape with the positional encoding added
        RAISES:
            ValueError: if the sequence is longer than the precomputed table
        """
        seq_len = token_embedding.size(0)
        if seq_len > self.pos_encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding max_len {self.pos_encoding.size(0)}"
            )
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:seq_len, :])

In [None]:
class Summarizer(nn.Transformer):
    '''
    This class implements the summarizer: a token embedding, a sinusoidal
    positional encoding, an encoder-decoder nn.Transformer and a linear layer
    that projects back to vocabulary logits.
    '''

    def __init__(self, vocab_size, vocab, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, max_len=5000):
        '''
        This function initializes the model
        ARGS:
            vocab_size: the size of the vocabulary
            vocab: the vocabulary object (numberize / denumberize)
            d_model: the dimension of the model
            nhead: the number of heads
            num_encoder_layers: the number of encoder layers
            num_decoder_layers: the number of decoder layers
            dropout: dropout probability of the transformer and positional encoding
            max_len: longest sequence the positional encoding supports
        RETURN:
            None
        '''

        # NOTE(review): the base nn.Transformer constructor is called with its
        # default arguments, so it allocates a default-sized encoder/decoder
        # (self.encoder / self.decoder) that forward() never uses; the model
        # actually trained is self.transformer below. The inheritance is kept
        # so that isinstance checks and existing state_dict keys do not change.
        super(Summarizer, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.pos = PositionalEncoding(
            d_model,
            dropout,
            max_len
        )

        self.embedding = nn.Embedding(vocab_size, d_model) # Embedding layer

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.vocab = vocab

        self.kGreedy = "greedy"
        self.kTopP = "top_p"
        self.kBeam = "beam"

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        '''
        This function runs the encoder-decoder on token ids
        ARGS:
            src: source token ids, (batch_size, src sequence length) or a
                 single unbatched sequence (src sequence length,)
            tgt: target token ids, (batch_size, tgt sequence length) or a
                 single unbatched sequence (tgt sequence length,)
            tgt_mask: additive causal mask of size (tgt length, tgt length)
            src_pad_mask: boolean padding mask (batch_size, src length)
            tgt_pad_mask: boolean padding mask (batch_size, tgt length)
        RETURN:
            logits of size (tgt sequence length, batch_size, vocab_size)
        '''
        # A single sequence is treated as a batch of one
        if src.dim() == 1:
            src = src.unsqueeze(0)
        if tgt.dim() == 1:
            tgt = tgt.unsqueeze(0)

        # Embedding - Out size = (batch_size, sequence length, dim_model)
        src = self.embedding(src) * math.sqrt(self.d_model)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)

        # self.transformer is sequence-first, so we permute
        # to obtain size (sequence length, batch_size, dim_model)
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        # The positional encoding indexes positions along dimension 0, so it
        # must be applied AFTER the permutation (on sequence-first tensors)
        src = self.pos(src)
        tgt = self.pos(tgt)

        # Transformer blocks - Out size = (sequence length, batch_size, dim_model)
        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)

        # Projection to the vocabulary - Out size = (sequence length, batch_size, vocab_size)
        out = self.fc_out(transformer_out)

        return out

    def get_tgt_mask(self, size) -> torch.Tensor:
        '''
        Generates a square additive mask where each row allows one more word to be seen
        '''
        # Keep -inf strictly above the diagonal, 0 elsewhere
        mask = torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]

        return mask

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

    @staticmethod
    def _sample_top_p(logits, p):
        '''
        Samples one token id with nucleus (top-p) sampling
        ARGS:
            logits: 1-D tensor of unnormalized scores over the vocabulary
            p: cumulative probability mass kept in the nucleus
        RETURN:
            the sampled token id (int)
        '''
        probs = torch.softmax(logits.detach().float(), dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # A token is kept while the mass of the tokens before it is below p;
        # the most likely token is therefore always kept
        keep = (cumulative - sorted_probs) < p
        nucleus = sorted_probs * keep
        choice = torch.multinomial(nucleus / nucleus.sum(), 1)
        return int(sorted_idx[choice].item())

    def summarize(model, input_sequence, max_length=15, SOS_token=None, EOS_token=None, device=None, max_len=None, mode="greedy", p=0.9):
        """
        Generates a summary autoregressively, one token at a time.

        Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
        Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

        ARGS:
            input_sequence: either a string of space-separated tokens, or token
                            ids (tensor / nested list) of size (1, length) or (length,)
            max_length: maximum number of generated tokens
            SOS_token: id of the start token (default: the vocabulary id of <BOS>)
            EOS_token: id of the end token (default: the vocabulary id of <EOS>)
            device: device to run on (default: the device of the model parameters)
            max_len: alias of max_length; takes precedence when given
            mode: "greedy" (most likely token) or "top_p" (nucleus sampling)
            p: probability mass of the nucleus when mode is "top_p"
        RETURN:
            for a string input: the generated words, without <BOS> / <EOS>
            for token ids: the generated ids, starting with SOS_token
        RAISES:
            NotImplementedError: for mode "beam"
            ValueError: for any other unknown mode
        """
        if mode == model.kBeam:
            raise NotImplementedError("beam search decoding is not implemented")
        if mode not in (model.kGreedy, model.kTopP):
            raise ValueError(f"unknown decoding mode: {mode!r}")

        model.eval()

        # Resolve the defaults from the model itself instead of assuming a GPU
        # and fixed special-token ids
        if device is None:
            device = next(model.parameters()).device
        if SOS_token is None:
            SOS_token = model.vocab.numberize("<BOS>")
        if EOS_token is None:
            EOS_token = model.vocab.numberize("<EOS>")
        if max_len is not None:
            max_length = max_len

        from_text = isinstance(input_sequence, str)
        if from_text:
            input_sequence = [[model.vocab.numberize(word) for word in input_sequence.split()]]
        input_sequence = torch.as_tensor(input_sequence, dtype=torch.long, device=device)
        if input_sequence.dim() == 1:
            input_sequence = input_sequence.unsqueeze(0)

        y_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

        with torch.no_grad():
            for _ in range(max_length):
                # Get causal mask for the tokens generated so far
                tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)

                pred = model(input_sequence, y_input, tgt_mask)

                # Logits of the last position for the single batch element
                step_logits = pred[-1, 0]
                if mode == model.kTopP:
                    next_item = model._sample_top_p(step_logits, p)
                else:
                    next_item = int(torch.argmax(step_logits).item()) # num with highest probability

                # Concatenate previous input with the predicted word
                next_tensor = torch.tensor([[next_item]], dtype=torch.long, device=device)
                y_input = torch.cat((y_input, next_tensor), dim=1)

                # Stop if model predicts end of sentence
                if next_item == EOS_token:
                    break

        token_ids = y_input.view(-1).tolist()
        if from_text:
            return [model.vocab.denumberize(token) for token in token_ids if token not in (SOS_token, EOS_token)]
        return token_ids

    def greedy_decoding(self, o, max_len):
        '''
        This function performs greedy decoding
        ARGS:
            o: the output of the model, one row of vocabulary scores per position
            max_len: the maximum length of the output
        RETURN:
            output: the decoded words; decoding stops before <EOS>
        '''
        eos = self.vocab.numberize("<EOS>")
        output = []
        for scores in o:
            if len(output) >= max_len:
                break
            token = int(torch.argmax(scores).item())
            if token == eos:
                break
            output.append(self.vocab.denumberize(token))
        return output

    def top_p_decoding(self, o, max_len = 50, p=0.9):
        '''
        This function performs top-p (nucleus) decoding
        ARGS:
            o: the output of the model, one row of vocabulary scores per position
            max_len: the maximum length of the output
            p: the probability threshold
        RETURN:
            output: the decoded words; when <EOS> is sampled it is kept as the
                    last word and decoding stops
        '''
        eos = self.vocab.numberize("<EOS>")
        output = []
        for scores in o:
            if len(output) >= max_len:
                break
            token = self._sample_top_p(scores, p)
            output.append(self.vocab.denumberize(token))
            if token == eos:
                break

        return output

    def beam_search(self, o, max_len):
        '''
        Beam-search decoding (not implemented yet)
        RAISES:
            NotImplementedError: always
        '''
        raise NotImplementedError("beam search decoding is not implemented")

## Training

In [None]:
ntokens = len(vocab) # size of vocabulary
emsize = 160 # embedding dimension (passed to Summarizer as d_model)
nhid = 160 # feed-forward dimension; NOTE(review): declared but never passed to Summarizer
n_encoder_layers = 6 # the number of encoder layers
n_decoder_layers = 6 # the number of decoder layers
nhead = 8 # the number of heads in the multiheadattention models
lr = 0.01 # learning rate actually used by the SGD optimizer below

model = Summarizer(ntokens, vocab, emsize, nhead, n_encoder_layers, n_decoder_layers).to(DEVICE)
# The optimizer reads the declared learning rate rather than a second hard-coded value
opt = torch.optim.SGD(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

# RECENT_MODEL = "models/model.pt-2023-12-04_15_40_57.pt"
# model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

In [None]:
# Imported under an alias: a later cell defines a function that is also named
# `evaluate`, which would shadow the module and break a re-run of this cell.
import evaluate as hf_evaluate
rouge = hf_evaluate.load('rouge')

In [None]:
def train_loop(model:Summarizer, opt, loss_fn, training_data):
    """
    Trains the model for one pass over training_data with teacher forcing.

    ARGS:
        model: the summarizer to train
        opt: the optimizer that updates the model parameters
        loss_fn: loss over (batch, vocab, length) logits and (batch, length) targets
        training_data: sequence of (email_list, summary) examples
    RETURN:
        the average loss per example (0 when training_data is empty)
    """
    model.train()
    total_loss = 0

    # progressbar is the tqdm module, so the progress bar class is progressbar.tqdm
    for batch in progressbar.tqdm(training_data, desc="Training", file=sys.stdout, total=len(training_data)):
        thread, summary = batch
        thread_body = " ".join([email[kBody] for email in thread])

        # torch.tensor (not the torch.Tensor constructor) accepts dtype / device.
        # The extra list level adds a batch dimension: size (1, sequence length)
        thread_body = torch.tensor([[model.vocab.numberize(word) for word in thread_body.split()]], dtype=torch.long, device=DEVICE)
        summary = torch.tensor([[model.vocab.numberize(word) for word in summary[kSummary].split()]], dtype=torch.long, device=DEVICE)

        # Now we shift the summary by one so with the <BOS> we predict the token at pos 1
        summary_input = summary[:, :-1]
        summary_expected = summary[:, 1:]

        # Get mask to mask out the next words
        sequence_length = summary_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(DEVICE)

        predict = model(thread_body, summary_input, tgt_mask)

        # Permute pred from (length, batch, vocab) to (batch, vocab, length),
        # the layout the loss expects for (batch, length) targets
        predict = predict.permute(1, 2, 0)
        loss = loss_fn(predict, summary_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

    return total_loss / max(len(training_data), 1)


In [None]:
# Call train_loop once on the training split and keep the value it returns.
train_loss = train_loop(model, opt, loss_fn, train)

In [None]:
def validation_loop(model, loss_fn, dev_data):
    """
    Computes the average teacher-forced loss on dev_data without updating the model.

    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    ARGS:
        model: the summarizer to evaluate
        loss_fn: loss over (batch, vocab, length) logits and (batch, length) targets
        dev_data: sequence of (email_list, summary) examples
    RETURN:
        the average loss per example (0 when dev_data is empty)
    """

    model.eval()
    total_loss = 0

    with torch.no_grad():
        # progressbar is the tqdm module, so the progress bar class is progressbar.tqdm
        for batch in progressbar.tqdm(dev_data, desc="Validation", file=sys.stdout, total=len(dev_data)):
            thread, summary = batch
            thread_body = " ".join([email[kBody] for email in thread])

            # torch.tensor (not the torch.Tensor constructor) accepts dtype / device.
            # The extra list level adds a batch dimension: size (1, sequence length)
            thread_body = torch.tensor([[model.vocab.numberize(word) for word in thread_body.split()]], dtype=torch.long, device=DEVICE)
            summary = torch.tensor([[model.vocab.numberize(word) for word in summary[kSummary].split()]], dtype=torch.long, device=DEVICE)

            # Now we shift the tgt by one so with the <BOS> we predict the token at pos 1
            summary_input = summary[:, :-1]
            summary_expected = summary[:, 1:]

            # Get mask to mask out the next words
            sequence_length = summary_input.size(1)
            tgt_mask = model.get_tgt_mask(sequence_length).to(DEVICE)

            # Same forward pass as in training, without any parameter update
            predict = model(thread_body, summary_input, tgt_mask)

            # Permute pred from (length, batch, vocab) to (batch, vocab, length)
            predict = predict.permute(1, 2, 0)
            loss = loss_fn(predict, summary_expected)
            total_loss += loss.detach().item()

    return total_loss / max(len(dev_data), 1)

In [None]:
# Call validation_loop once on the held-out split and keep the value it returns.
valid_loss = validation_loop(model, loss_fn, test)

In [None]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """
    Alternates one training pass and one validation pass per epoch and records
    the loss returned by each.

    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Returns the pair (training losses, validation losses), one entry per epoch.
    """
    # Per-epoch history, kept for plotting later on
    history_train = []
    history_validation = []
    rule = "-" * 25

    print("Training and validating model")
    for epoch in range(epochs):
        print(rule, f"Epoch {epoch + 1}", rule)

        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        history_train.append(train_loss)

        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        history_validation.append(validation_loss)

        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()

    return history_train, history_validation

In [None]:
# Small slices (first 10 examples of each split) for a quick run of fit().
train_subset = train[:10]
dev_subset = test[:10]

# Run fit() for a single epoch on the two slices.
train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_subset, dev_subset, 1)

In [None]:
def train_summarizer(model: Summarizer, train_data, dev_data, criterion, epochs=1, lr=0.003):
    """
    Trains the summarizer with Adam and, after every epoch, scores generated
    summaries of dev_data with ROUGE.

    ARGS:
        model: the summarizer to train
        train_data: sequence of (email_list, summary) training examples
        dev_data: sequence of (email_list, summary) examples summarized after each epoch
        criterion: loss over (length, vocab) logits and (length,) targets
        epochs: number of passes over train_data
        lr: learning rate of the Adam optimizer
    RETURN:
        the trained model
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Initialize Adam optimizer

    # Shuffle a private copy so the caller's list keeps its order
    train_data = list(train_data)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} / {epochs}")

        # Back to train mode every epoch: the dev pass below switches to eval mode
        model.train()
        random.shuffle(train_data)
        train_loss = 0

        for item in progressbar.tqdm(train_data, desc="Thread Training", total=len(train_data)):
            thread, summary = item
            email_thread_body = " ".join([email[kBody] for email in thread])

            # Token ids with a batch dimension of one: size (1, sequence length)
            email_tensor = torch.tensor([model.vocab.numberize(word) for word in email_thread_body.split()], dtype=torch.int64).unsqueeze(0).to(DEVICE) # Convert email to tensor
            summary_tensor = torch.tensor([model.vocab.numberize(word) for word in summary[kSummary].split()], dtype=torch.int64).unsqueeze(0).to(DEVICE) # Convert summary to tensor

            # Teacher forcing: the decoder sees tokens [0, n-1) and must predict
            # tokens [1, n); the causal mask hides the future positions
            decoder_input = summary_tensor[:, :-1]
            expected = summary_tensor[:, 1:]
            tgt_mask = model.get_tgt_mask(decoder_input.size(1)).to(DEVICE)

            output = model(email_tensor, decoder_input, tgt_mask)  # Forward pass
            # Flatten to (positions, vocab) logits and (positions,) targets
            loss = criterion(output.reshape(-1, output.size(-1)), expected.reshape(-1)) # Calculate loss
            optimizer.zero_grad()  # Zero the gradients
            loss.backward() # Backward pass
            # torch.nn.utils.clip_grad_norm_(model.parameters(), threshold_norm) # Clip gradients
            optimizer.step() # Update weights
            train_loss += loss.item()

        model.eval()
        line_num = 0
        with torch.no_grad():
            for item in progressbar.tqdm(dev_data, desc="Thread Dev", total=len(dev_data)):
                thread, summary = item
                summary_string = summary[kSummary]

                email_thread_body = " ".join([email[kBody] for email in thread])

                output = model.summarize(email_thread_body, mode="top_p", max_len=40)
                output = " ".join(output)
                score = rouge.compute(predictions=[output], references=[summary_string])
                # Print the first five dev examples every fifth epoch
                if line_num < 5 and epoch % 5 == 0:
                    print(f"Thread ID: {summary[kThreadId]}")
                    print(f"Email Thread: {email_thread_body}")
                    print(f"Predicted Summary: {output}")
                    print(f"Actual Summary: {summary_string}")
                    print(f"Score: {score}")
                    print("-----------------------------------")
                line_num += 1

        print(f"Epoch {epoch + 1}/{epochs}: Average Train Loss: {train_loss/max(len(train_data), 1)}",file=sys.stderr ,flush=True)

    return model


In [None]:
train_subset = train[:10]
dev_subset = test[:10]

# Underscores instead of ":" keep the checkpoint name valid on every file system
curr_time = time.strftime("%Y-%m-%d_%H_%M_%S")
MODEL_PATH = f"models/model.pt-{curr_time}.pt"

# The dev pass uses the held-out slice, so the printed ROUGE scores are not
# computed on examples the model was trained on
model = train_summarizer(
    model=model, 
    train_data=train, 
    dev_data=dev_subset, 
    criterion=nn.CrossEntropyLoss(),
    epochs=10,
)
# save the model

# torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Summarize the first thread of train_subset and print the result next to its reference summary.
text = " ".join([email[kBody] for email in train_subset[0][0]])
summary = train_subset[0][1][kSummary]
o = model.summarize(text, max_len=50, mode="top_p")
o = " ".join(o)
print(o)
print(summary)

## Evaluation

In [None]:
# load the model
# RECENT_MODEL = "models/model.pt-2023-12-05_18:16:19.pt"
# model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

In [None]:
# Evaluate the output

def evaluate(model: Summarizer, test_data, criterion, rouge, max_input_len = 150, max_output_len = 50, mode="greedy"):
    """
    Generates a summary for every thread of test_data and scores it with ROUGE.

    ARGS:
        model: the trained summarizer
        test_data: sequence of (email_list, summary) examples
        criterion: unused; kept so existing calls keep working
        rouge: metric object whose compute(predictions=, references=) is called
        max_input_len: NOTE(review): currently unused; inputs are not truncated
        max_output_len: maximum number of generated tokens per summary
        mode: decoding mode passed to model.summarize ("greedy" or "top_p")
    RETURN:
        evals: list of (thread_id, predicted summary, reference summary, rouge score)
    """
    model.eval()  # Turn on the evaluation mode

    evals = []
    with torch.no_grad():
        for item in progressbar.tqdm(test_data, desc="Thread Evaluation", total=len(test_data)):
            thread, summary = item
            summary_string = summary[kSummary]

            email_thread_body = " ".join([email[kBody] for email in thread])

            # Honor the requested decoding mode and output length instead of a
            # hard-coded "top_p" with the default length
            output = model.summarize(email_thread_body, mode=mode, max_len=max_output_len)

            output_str = " ".join(output)
            rouge_score = rouge.compute(predictions=[output_str], references=[summary_string])

            evals.append((summary[kThreadId], output_str, summary_string ,rouge_score))

    return evals