In [100]:
import collections.abc
import json
import math
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm as progressbar
from sklearn.model_selection import train_test_split

In [88]:
# Paths of the two JSON inputs; both are passed to Utils.load_dataset.
DETAILS_JSON = "data/email_thread_details.json"
SUMMARIES_JSON = "data/email_thread_summaries.json"

# Field names used to index the loaded JSON records.
# kThreadId is read from both email and summary records to join them;
# kBody is the email text that gets tokenized.
# kSubject / kTimestamp / kFrom / kTo are not referenced by the code in this
# notebook — presumably further fields of an email record (TODO confirm).
kThreadId = "thread_id"
kSubject = "subject"
kTimestamp = "timestamp"
kFrom = "from"
kTo = "to"
kBody = "body"

# Field of a summary record that holds the reference summary text.
kSummary = "summary"

In [89]:
class Utils():
    '''
        Static helpers for loading the email-thread dataset and building its vocabulary.
    '''

    @staticmethod
    def load_dataset(DETAILS_FILE, SUMMARIES_FILE):
        '''
            Load the emails and the thread summaries and join them by thread id.
            ARGS:
                DETAILS_FILE: path of a JSON file containing a list of email records,
                    each carrying a `kThreadId` field
                SUMMARIES_FILE: path of a JSON file containing a list of summary records,
                    each carrying a `kThreadId` field
            RETURN:
                dataset: dict mapping thread id -> (list of email records, summary record).
                    Every value is a 2-tuple:
                      * a summarized thread without any email gets an empty list
                        (not None), so consumers can always iterate it;
                      * emails whose thread has no summary are left out, because an
                        example without a reference summary cannot be trained on or scored;
                      * if a thread id has several summaries, the last one wins.
        '''
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(DETAILS_FILE, 'r', encoding='utf-8') as f:
            details = json.load(f)

        with open(SUMMARIES_FILE, 'r', encoding='utf-8') as f:
            summaries = json.load(f)

        # Group the emails per thread, appending in place (file order is preserved).
        emails_by_thread = {}
        for item in details:
            emails_by_thread.setdefault(item[kThreadId], []).append(item)

        # Pair every summary with the emails of its thread.
        dataset = {}
        for item in summaries:
            thread_id = item[kThreadId]
            dataset[thread_id] = (emails_by_thread.get(thread_id, []), item)

        return dataset

    @staticmethod
    def build_vocab(data):
        '''
            Build the vocabulary from the whitespace-separated words of the data.
            ARGS:
                data: dict mapping thread id -> (list of email records, summary record),
                    as returned by load_dataset
            RETURN:
                vocab: a Vocab containing every word of every email body and summary
        '''
        vocab = Vocab()
        for email_list, summary in data.values():
            for email in email_list:
                for word in email[kBody].split():
                    vocab.add(word)
            for word in summary[kSummary].split():
                vocab.add(word)

        return vocab

class Vocab(collections.abc.MutableSet):
    """
        Set-like data structure that can change words into numbers and back.
        From Prof. David Chiang Code

        Words are numbered in insertion order. The special tokens always get the
        same numbers (<BOS>=0, <EOS>=1, <UNK>=2) so that a vocabulary rebuilt in a
        new Python process is compatible with a previously saved model.
    """
    def __init__(self):
        # A list, not a set: iterating a set of strings follows their hash order,
        # which can differ between interpreter runs and would silently renumber
        # the special tokens.
        self.num_to_word = ['<BOS>', '<EOS>', '<UNK>']
        self.word_to_num = {word: num for num, word in enumerate(self.num_to_word)}

    def add(self, word):
        """Add a word, giving it the next free number; known words are ignored."""
        if word in self: return
        num = len(self.num_to_word)
        self.num_to_word.append(word)
        self.word_to_num[word] = num

    def discard(self, word):
        """Removal is not supported: it would renumber every later word."""
        raise NotImplementedError()

    def update(self, words):
        """Add every word of an iterable."""
        for word in words:
            self.add(word)

    def __contains__(self, word):
        return word in self.word_to_num

    def __len__(self):
        return len(self.num_to_word)

    def __iter__(self):
        # Iterates the words in numbering order.
        return iter(self.num_to_word)

    def numberize(self, word):
        """Convert a word into a number; unknown words map to the number of '<UNK>'."""
        if word in self.word_to_num:
            return self.word_to_num[word]
        else:
            return self.word_to_num['<UNK>']

    def denumberize(self, num):
        """Convert a number into a word."""
        return self.num_to_word[num]

## Preprocessing

In [90]:
# Load the data: `d` is the per-thread dictionary returned by Utils.load_dataset;
# the cells below consume each value as an (email list, summary record) pair.
d = Utils.load_dataset(DETAILS_JSON, SUMMARIES_JSON)
# NOTE(review): the vocabulary is built from every thread, i.e. before the
# train/test split, so words occurring only in test threads are included too.
vocab = Utils.build_vocab(d)

In [91]:
# Split the threads into train and test sets.
# train_test_split shuffles by itself (shuffle=True is its default) and
# random_state=42 makes that shuffle repeatable. An additional unseeded
# random.shuffle() beforehand would change the input order on every run and
# therefore the resulting split, so it is deliberately not used.
data = list(d.items())
train, test = train_test_split(data, test_size=0.2, random_state=42)

# Drop the thread ids: every example is an (email list, summary record) pair.
train = [(email_list, summary) for _, (email_list, summary) in train]
test = [(email_list, summary) for _, (email_list, summary) in test]

# Models

In [92]:
class Summarizer(nn.Module):
    '''
    Encoder-decoder Transformer that maps a tokenized source text to a tokenized summary.

    Layers: a token embedding shared by source and target, an `nn.Transformer`
    (sequence-first layout) and a linear projection onto the vocabulary.

    The class derives from `nn.Module` and owns a single `nn.Transformer`.
    Deriving from `nn.Transformer` and calling its constructor without arguments
    would create a second, default-sized transformer that is never used but is
    still registered, initialised and written to every checkpoint.
    Checkpoints saved from such a model contain extra `encoder.*` / `decoder.*`
    entries; load them with `load_state_dict(..., strict=False)`.
    '''

    def __init__(self, vocab_size, vocab, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        '''
        This function initializes the model
        ARGS:
            vocab_size: the size of the vocabulary
            vocab: vocabulary object providing numberize(word) and denumberize(num)
            d_model: the dimension of the model (must be divisible by nhead)
            nhead: the number of heads
            num_encoder_layers: the number of encoder layers
            num_decoder_layers: the number of decoder layers
        RETURN:
            None
        '''

        super(Summarizer, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model) # Embedding layer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.vocab = vocab

    @staticmethod
    def _causal_mask(size, like):
        '''
        Build a (size, size) additive attention mask: 0 on and below the diagonal,
        -inf above it, so position i can only attend to positions <= i.
        The mask takes its dtype and device from the tensor `like`.
        '''
        blocked = torch.full((size, size), float('-inf'), dtype=like.dtype, device=like.device)
        return torch.triu(blocked, diagonal=1)

    def forward(self, src, target=None):
        '''
        This function performs the forward pass of the model (teacher forcing)
        ARGS:
            src: the source token ids, sequence-first
            target: the decoder input token ids, sequence-first. When omitted the
                source is fed to the decoder; this is kept only for backward
                compatibility and is not a generation procedure, use summarize()
                to generate text.
        RETURN:
            output: logits over the vocabulary, one row per decoder position
        '''
        src_emb = self.embedding(src)
        tgt_emb = src_emb if target is None else self.embedding(target)

        # Causal mask: without it every decoder position can attend to the
        # positions after it, i.e. to the tokens it is supposed to predict.
        # The sequence axis is axis 0 for unbatched and sequence-first batched input.
        tgt_mask = self._causal_mask(tgt_emb.size(0), tgt_emb)
        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)

        return self.fc_out(output)

    def summarize(self, src, max_len=100, mode="greedy"):
        '''
        This function summarizes the input text
            args:
                src: 1-D tensor of source token ids (a single, unbatched sequence)
                max_len: the maximum number of generated words
                mode: the mode of generation ("greedy" or "beam")
            return:
                output: the generated summary as a list of words
            raises:
                ValueError: if mode is neither "greedy" nor "beam"
                NotImplementedError: for mode "beam" (see beam_search)
        '''
        if mode == "greedy":
            return self._greedy_generate(src, max_len)
        if mode == "beam":
            return self.beam_search(self.forward(src), max_len)
        raise ValueError(f"Unknown decoding mode {mode!r}; expected 'greedy' or 'beam'")

    def _greedy_generate(self, src, max_len):
        '''
        Generate a summary token by token: encode the source once, start the
        decoder with <BOS> and repeatedly append the most probable next token
        until <EOS> is produced or max_len words were generated.
        Returns the generated words (without <BOS>/<EOS>).
        '''
        if max_len <= 0 or src.numel() == 0:
            return []

        bos = self.vocab.numberize('<BOS>')
        eos = self.vocab.numberize('<EOS>')
        words = []

        with torch.no_grad():  # generation never needs gradients
            memory = self.transformer.encoder(self.embedding(src))
            generated = torch.tensor([bos], dtype=torch.long, device=src.device)

            for _ in range(max_len):
                tgt_emb = self.embedding(generated)
                hidden = self.transformer.decoder(
                    tgt_emb, memory, tgt_mask=self._causal_mask(tgt_emb.size(0), tgt_emb)
                )
                # Only the last position predicts the token that comes next.
                next_id = int(torch.argmax(self.fc_out(hidden[-1])))
                if next_id == eos:
                    break
                words.append(self.vocab.denumberize(next_id))
                generated = torch.cat((generated, generated.new_tensor([next_id])))

        return words

    def greedy_decoding(self, o, max_len):
        '''
        This function converts precomputed logits into words by taking the
        arg-max of every position
        ARGS:
            o: logits with one row per position
            max_len: the maximum length of the output
        RETURN:
            output: list of at most max_len words (fewer if `o` has fewer rows)
        '''
        # Bounded by the number of rows so that a short `o` cannot raise IndexError.
        steps = min(max_len, len(o))
        return [self.vocab.denumberize(int(torch.argmax(o[i]))) for i in range(steps)]

    def beam_search(self, o, max_len):
        '''
        Beam-search decoding; not implemented.
        Raises NotImplementedError instead of silently returning None, so a
        missing summary cannot flow unnoticed into the metrics.
        '''
        raise NotImplementedError("beam search decoding is not implemented yet")

## Training

In [93]:
ntokens = len(vocab) # size of vocabulary
emsize = 100 # embedding dimension (passed to Summarizer as d_model)
nhid = 100 # intended feedforward dimension; NOTE(review): never passed to Summarizer below, so it has no effect
n_encoder_layers = 6 # the number of encoder layers
n_decoder_layers = 6 # the number of decoder layers
nhead = 2 # the number of heads in the multiheadattention models
lr = 0.02 # learning rate (handed to train_summarizer in the training cell)

# Positional arguments: vocab_size, vocab, d_model, nhead, num_encoder_layers, num_decoder_layers.
model = Summarizer(ntokens, vocab, emsize, nhead, n_encoder_layers, n_decoder_layers)

In [94]:
def train_summarizer(model: "Summarizer", train_data, criterion, max_input_len = 150, max_output_len = 50, lr=0.001, threshold_norm=0.5):
    '''
    Train the model for one pass over the data; every (email, thread summary)
    pair is one optimisation step.
    ARGS:
        model: the summarizer; must expose `vocab` and be callable as model(src, target)
        train_data: sequence of (thread, summary) pairs, where thread is a list of
            email records and summary is a summary record
        criterion: loss taking (logits of shape (positions, vocab), label ids)
        max_input_len: emails are truncated to this many tokens
        max_output_len: summaries are truncated to this many tokens
        lr: learning rate of the Adam optimizer
        threshold_norm: maximum gradient norm used for clipping
    RETURN:
        model: the trained model (the same object that was passed in)
    '''
    model.train()  # Turn on the train mode
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Initialize Adam optimizer

    vocab = model.vocab
    bos_id = vocab.numberize('<BOS>')
    eos_id = vocab.numberize('<EOS>')

    total_loss = 0.
    n_steps = 0

    progress = progressbar.tqdm(train_data, desc="Thread Training", total=len(train_data))
    for thread, summary in progress:
        # The summary is the same for every email of the thread: tokenize it once.
        summary_ids = [vocab.numberize(word) for word in summary[kSummary].split()][:max_output_len]
        if not summary_ids:
            continue  # nothing to learn from an empty summary

        # Teacher forcing with a one-token shift: the decoder is fed
        # <BOS> w1 ... wn and has to predict w1 ... wn <EOS>. Feeding the labels
        # themselves would only teach the model to copy its input.
        # NOTE: the shift hides the token to predict only if model.forward applies
        # a causal mask to the decoder input — confirm for the model in use.
        decoder_input = torch.tensor([bos_id] + summary_ids)
        labels = torch.tensor(summary_ids + [eos_id])

        for email in thread:
            email_ids = [vocab.numberize(word) for word in email[kBody].split()][:max_input_len]
            if not email_ids:
                continue  # an empty body yields an empty source sequence
            email_tensor = torch.tensor(email_ids)  # Convert email to tensor

            # Reset the gradients before EVERY step; resetting them once per
            # thread would add up the gradients of all emails of that thread.
            optimizer.zero_grad()
            output = model(email_tensor, decoder_input)  # Forward pass

            # Flatten to (positions, vocab); the vocabulary size is taken from
            # the logits instead of a notebook-level global.
            output = output.view(-1, output.size(-1))
            loss = criterion(output, labels)  # Calculate loss
            loss.backward()  # Backward pass
            torch.nn.utils.clip_grad_norm_(model.parameters(), threshold_norm)  # Clip gradients
            optimizer.step()  # Update weights

            total_loss += loss.item()
            n_steps += 1

        if n_steps:
            # Show the running mean loss next to the progress bar.
            progress.set_postfix(avg_loss=total_loss / n_steps)

    return model


In [114]:
# Time-stamped checkpoint path. The time part uses '-' because ':' is not a
# valid character in Windows file names.
curr_time = time.strftime("%Y-%m-%d_%H-%M-%S")
MODEL_PATH = f"models/model.pt-{curr_time}.pt"
# Create the output folder BEFORE training: a missing folder should fail (or be
# fixed) right away, not when torch.save runs after the whole training pass.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

model = train_summarizer(model, train, nn.CrossEntropyLoss(), lr=lr)
# save the model
torch.save(model.state_dict(), MODEL_PATH)

Thread Training:   1%|▏         | 49/3333 [00:51<1:02:38,  1.14s/it]

## Evaluation

In [97]:
# presumably the Hugging Face `evaluate` package (it is used for load() and
# EvaluationModule in this notebook) — TODO confirm
import evaluate
# NOTE(review): a function that is also named `evaluate` is defined further down;
# once that cell has run, the name no longer refers to this module, so the metric
# has to be loaded here, before it.
rouge = evaluate.load('rouge')

In [98]:
# Evaluate the output

def evaluate(model: "Summarizer", test_data, criterion, rouge: "evaluate.EvaluationModule", max_input_len = 150, max_output_len = 50, mode="greedy"):
    '''
    Generate a summary for every thread and score it.
    ARGS:
        model: the summarizer; must expose `vocab`, `summarize(src, max_len, mode)`
            and be callable as model(src, target)
        test_data: sequence of (thread, summary) pairs, where thread is a list of
            email records and summary is a summary record
        criterion: loss taking (logits of shape (positions, vocab), label ids)
        rouge: metric object whose compute(predictions=..., references=...) takes
            lists of strings
        max_input_len: maximum number of source tokens per thread
        max_output_len: maximum number of generated / reference tokens
        mode: decoding mode handed to model.summarize
    RETURN:
        evals: list of (thread id, generated words, rouge score, loss) per thread
        total_loss: sum of the per-thread losses
    NOTE: this function has the same name as the `evaluate` module. The annotations
    are therefore written as strings: they are not evaluated when the cell is run,
    so re-running the cell keeps working after the name has been rebound.
    '''
    model.eval()  # Turn on the evaluation mode
    vocab = model.vocab
    bos_id = vocab.numberize('<BOS>')
    eos_id = vocab.numberize('<EOS>')

    total_loss = 0.
    evals = []
    with torch.no_grad():
        for thread, summary in progressbar.tqdm(test_data, desc="Thread Evaluation", total=len(test_data)):
            summary_string = summary[kSummary]
            summary_ids = [vocab.numberize(word) for word in summary_string.split()][:max_output_len]

            # Every email gets an equal share of the input budget
            # (an empty thread must not divide by zero).
            trim_len = math.ceil(max_input_len / len(thread)) if thread else 0
            email_ids = []
            for email in thread:
                email_ids.extend([vocab.numberize(word) for word in email[kBody].split()][:trim_len])
            # The rounded-up shares can exceed the budget by a few tokens.
            email_ids = email_ids[:max_input_len]

            if not email_ids or not summary_ids:
                continue  # nothing to summarize or nothing to compare with

            # Token ids must be integers: concatenating onto an empty default
            # (float) tensor would promote them to floats and break the embedding.
            email_tensor_final = torch.tensor(email_ids, dtype=torch.long)

            output = model.summarize(email_tensor_final, max_len=max_output_len, mode=mode)

            # The loss needs logits, not generated words: run a teacher-forced
            # pass (decoder fed <BOS> w1..wn, expected to predict w1..wn <EOS>).
            logits = model(email_tensor_final, torch.tensor([bos_id] + summary_ids))
            loss = criterion(logits.view(-1, logits.size(-1)), torch.tensor(summary_ids + [eos_id]))
            total_loss += loss.item()

            # compute() takes keyword arguments holding lists of strings.
            rouge_score = rouge.compute(predictions=[" ".join(output)], references=[summary_string])

            evals.append((summary[kThreadId], output, rouge_score, loss.item()))

    return evals, total_loss