In [153]:
import collections.abc
import json
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import math
import tqdm as progressbar
import time
import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')
import re

[nltk_data] Downloading package punkt to /Users/saint/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [154]:
DETAILS_JSON = "data/email_thread_details.json"
SUMMARIES_JSON = "data/email_thread_summaries.json"

kThreadId = "thread_id"
kSubject = "subject"
kTimestamp = "timestamp"
kFrom = "from"
kTo = "to"
kBody = "body"

kSummary = "summary"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Notes

- Reduced vocab from 172743 to 16856
- `#` WOW

In [191]:
class Utils():
    @staticmethod
    def load_dataset(DETAILS_FILE, SUMMARIES_FILE):
        '''
            This function loads the dataset from the file
            ARGS:
                filename: the name of the file
            RETURN:
                dataset: the dataset
        '''
        with open(DETAILS_FILE, 'r') as f:
            details = json.load(f)
        
        with open(SUMMARIES_FILE, 'r') as f:
            summaries = json.load(f)   
        
        dataset = {}
        for i in range(len(details)):
            item = details[i]
            thread_id = item[kThreadId]
            item = Utils.tokenize_body(item)
            dataset[thread_id] = dataset.get(thread_id, []) + [item]
        
        for i in range(len(summaries)):
            item = summaries[i]
            thread_id = item[kThreadId]
            item = Utils.tokenize_summary(item)
            dataset[thread_id] = (dataset.get(thread_id), item)


        return dataset

    @staticmethod
    def tokenize_body(item):
        item[kBody] = word_tokenize(item[kBody])
        item[kBody] = [re.sub(r'[^\w\s.]', '', word) for word in item[kBody]]
        item[kBody] = [word.strip() for word in item[kBody] if word.strip() and word.strip() not in ['--', '=']]
        # Lowercase the email body
        item[kBody] = [word.lower() for word in item[kBody]]
        item[kBody] = ["<BOS>"] + item[kBody] + ["<EOS>"]
        item[kBody] = " ".join(item[kBody])
        return item

    @staticmethod
    def tokenize_summary(item):
        item[kSummary] = word_tokenize(item[kSummary])
        item[kSummary] = [re.sub(r'[^\w\s.]', '', word) for word in item[kSummary]]
        item[kSummary] = [word.strip() for word in item[kSummary] if word.strip() and word.strip() not in ['--', '=']]
        # Lowercase the summary
        item[kSummary] = [word.lower() for word in item[kSummary]]
        item[kSummary] = "<BOS> " + " ".join(item[kSummary]) + "<EOS>"
        return item
    
    @staticmethod
    def build_vocab(data):
        '''
            This function builds the vocabulary from the data
            ARGS:
                data: the data to build the vocabulary from ([Email], EmailSummaries)
            RETURN:
                vocab: the vocabulary
        '''
        vocab = Vocab()
        for _, (email_list, summary) in data.items():
            for email in email_list:

                for word in email:
                    vocab.add(word)
            for word in summary[kSummary].split():
                vocab.add(word)
        
        return vocab

class Vocab(collections.abc.MutableSet):
    """
        Set-like data structure that can change words into numbers and back.
        From Prof. David Chiang Code
    """
    def __init__(self):
        words = {'<BOS>', '<EOS>', '<UNK>'}
        self.num_to_word = list(words)
        self.word_to_num = {word:num for num, word in enumerate(self.num_to_word)}
    def add(self, word):
        if word in self: return
        num = len(self.num_to_word)
        self.num_to_word.append(word)
        self.word_to_num[word] = num
    def discard(self, word):
        raise NotImplementedError()
    def update(self, words):
        self |= words
    def __contains__(self, word):
        return word in self.word_to_num
    def __len__(self):
        return len(self.num_to_word)
    def __iter__(self):
        return iter(self.num_to_word)

    def numberize(self, word):
        """Convert a word into a number."""
        if word in self.word_to_num:
            return self.word_to_num[word]
        else:
            return self.word_to_num['<UNK>']

    def denumberize(self, num):
        """Convert a number into a word."""
        return self.num_to_word[num]

## Pre Process

In [192]:
# Load the data
d = Utils.load_dataset(DETAILS_JSON, SUMMARIES_JSON)
vocab = Utils.build_vocab(d)


In [157]:
# Split the dictionary into train and test
data = list(d.items())
random.shuffle(data)
train, test = train_test_split(data, test_size=0.2, random_state=42)

train = [(email_list, summary) for _, (email_list, summary) in train]
test = [(email_list, summary) for _, (email_list, summary) in test]

In [193]:
# print(ntokens, len(vocab))


16831 16856


# Models

In [188]:
class Summarizer(nn.Transformer):
    '''
    This class implements the summarizer
    '''

    def __init__(self, vocab_size, vocab, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        '''
        This function initializes the model
        ARGS:
            vocab_size: the size of the vocabulary
            d_model: the dimension of the model
            nhead: the number of heads
            num_encoder_layers: the number of encoder layers
            num_decoder_layers: the number of decoder layers
        RETURN:
            None
        '''

        super(Summarizer, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model) # Embedding layer
        self.transformer = nn.Transformer(
            d_model=d_model, 
            nhead=nhead, 
            num_encoder_layers=num_encoder_layers, 
            num_decoder_layers=num_decoder_layers
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.vocab = vocab
        
        self.kGreedy = "greedy"
        self.kTopP = "top_p"
        self.kBeam = "beam"

    def forward(self, src, target=None):
        '''
        This function performs the forward pass of the model
        ARGS:
            src: the source input
            target: the target input (optional, used during training)
        RETURN:
            output: the output of the model
        '''
        # reduce the size of the src


        src = self.embedding(src)
        src = torch.mean(src, dim=0)
        
        if target is not None:
            target = self.embedding(target)
            src = src.unsqueeze(0).expand(target.size(0), -1)
            output = self.transformer(src, target)
        else:
            # In generation mode, don't use target
            src = src.unsqueeze(0).expand(src.size(0), -1)
            output = self.transformer(src, src)  # Use src as both source and target TODO: 
        
        output = self.fc_out(output)

        return output
    
    def summarize(self, src, max_len=100, mode="top_p"):
        '''
        This function summarize the input text
            args:
                src: the source input
                max_len: the maximum length of the output
                mode: the mode of generation (greedy or beam search)
            return:
                output: the output of the model
        '''
        src = torch.tensor([self.vocab.numberize(word) for word in src])
        o = self.forward(src)
        output = None
        
        if mode == self.kGreedy:
            output =  self.greedy_decoding(o, max_len)
        elif mode == self.kTopP:
            output = self.top_p_decoding(o, max_len)
        elif mode == self.kBeam:
            output = self.beam_search(o, max_len)

        return output
        
    def greedy_decoding(self, o, max_len):
        '''
        This function performs greedy decoding
        ARGS:
            o: the output of the model
            max_len: the maximum length of the output
        RETURN:
            output: the output of the model
        '''
        output = []
        words = 0
        for i in o:
            if words >= max_len:
                break
            a = torch.argmax(i)
            if a == self.vocab.numberize("<EOS>"):
                break
            a = self.vocab.denumberize(a)
            output.append(a)
            words += 1
        return output
    
    def top_p_decoding(self, o, max_len = 50, p=0.9):
        '''
        This function performs top-p decoding
        ARGS:
            o: the output of the model
            max_len: the maximum length of the output
            p: the probability threshold
        RETURN:
            output: the output of the model
        '''
        output = []
        words = 0
        for i in o:
            if words >= max_len:
                break
            sorted_prob, sorted_idx = torch.sort(i, descending=False)
            sorted_prob = torch.exp(sorted_prob)
            sorted_prob_cumsum = sorted_prob.cumsum(dim=0)
            top_p_batch = sorted_idx[sorted_prob_cumsum > p]

            if top_p_batch.nelement() > 0:
                next_token = random.choice(top_p_batch)
                output.append(next_token.item())
            else:
                next_token = random.choice(sorted_idx)
                output.append(next_token.item())
            
            if output[-1] == self.vocab.numberize("<EOS>"):
                break



        for i, tensor in enumerate(output):
            output[i] = self.vocab.denumberize(tensor)
    
        return output

    def beam_search(self, o, max_len):
        '''
        
        '''
        pass

## Training

In [196]:
ntokens = len(vocab) # size of vocabulary
emsize = 150 # embedding dimension
nhid = 150 # the dimension of the feedforward network model in nn.TransformerEncoder
n_encoder_layers = 6 # the number of encoder layers
n_decoder_layers = 6 # the number of decoder layers
nhead = 3 # the number of heads in the multiheadattention models
lr = 0.001 # learning rate

model = Summarizer(ntokens, vocab, emsize, nhead, n_encoder_layers, n_decoder_layers)

# RECENT_MODEL = "models/model.pt-2023-12-04_15_40_57.pt"
# model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

In [185]:
def train_summarizer(model: Summarizer, train_data, criterion, max_input_len = 150, max_output_len = 50, lr=0.001, threshold_norm=0.5):
    model.train()  # Turn on the train mode
    total_loss = 0.

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Initialize Adam optimizer

    for item in progressbar.tqdm(train_data, desc="Thread Training", total=len(train_data)):
        optimizer.zero_grad()  # Zero the gradients
        thread, summary = item

        email_thread_body = [email[kBody] for email in thread]
        email_thread_body = " ".join(email_thread_body)

        email_tensor = torch.tensor([model.vocab.numberize(word) for word in email_thread_body.split()]) # Convert email to tensor
        summary_tensor = torch.tensor([model.vocab.numberize(word) for word in summary[kSummary].split()]) # Convert summary to tensor
        output = model(email_tensor, summary_tensor)  # Forward pass
        
        output = output.view(-1, ntokens) # Reshape output 
        loss = criterion(output, summary_tensor) # Calculate loss
        loss.backward() # Backward pass
        torch.nn.utils.clip_grad_norm_(model.parameters(), threshold_norm) # Clip gradients
        optimizer.step() # Update weights
        total_loss += loss.item() 
    return model


In [162]:
curr_time = time.strftime("%Y-%m-%d_%H:%M:%S")
MODEL_PATH = f"models/model.pt-{curr_time}.pt"
for i in range(15):
    print(f"Epoch {i}/ {15}")
    model = train_summarizer(model, train, nn.CrossEntropyLoss(), lr=lr)
# save the model

torch.save(model.state_dict(), MODEL_PATH)

Thread Training: 100%|██████████| 3333/3333 [05:05<00:00, 10.90it/s]
Thread Training: 100%|██████████| 3333/3333 [10:27<00:00,  5.31it/s]  
Thread Training: 100%|██████████| 3333/3333 [15:15<00:00,  3.64it/s]   
Thread Training: 100%|██████████| 3333/3333 [08:21<00:00,  6.65it/s]
Thread Training: 100%|██████████| 3333/3333 [08:29<00:00,  6.55it/s]
Thread Training: 100%|██████████| 3333/3333 [08:05<00:00,  6.87it/s]  
Thread Training: 100%|██████████| 3333/3333 [07:48<00:00,  7.11it/s]  
Thread Training: 100%|██████████| 3333/3333 [05:46<00:00,  9.61it/s]
Thread Training: 100%|██████████| 3333/3333 [05:00<00:00, 11.10it/s]
Thread Training: 100%|██████████| 3333/3333 [09:54<00:00,  5.60it/s]  
Thread Training: 100%|██████████| 3333/3333 [11:09<00:00,  4.97it/s]   
Thread Training: 100%|██████████| 3333/3333 [08:18<00:00,  6.69it/s]   
Thread Training: 100%|██████████| 3333/3333 [08:49<00:00,  6.29it/s]  
Thread Training: 100%|██████████| 3333/3333 [16:29<00:00,  3.37it/s]   
Thread Train

## Evaluation

In [163]:
import evaluate
rouge = evaluate.load('rouge')

In [201]:
# load the model
RECENT_MODEL = "models/model.pt-2023-12-05_18:16:19.pt"
model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

RuntimeError: Error(s) in loading state_dict for Summarizer:
	size mismatch for embedding.weight: copying a param with shape torch.Size([16831, 150]) from checkpoint, the shape in current model is torch.Size([16856, 150]).
	size mismatch for fc_out.weight: copying a param with shape torch.Size([16831, 150]) from checkpoint, the shape in current model is torch.Size([16856, 150]).
	size mismatch for fc_out.bias: copying a param with shape torch.Size([16831]) from checkpoint, the shape in current model is torch.Size([16856]).

In [174]:
# Evaluate the output

def evaluate(model: Summarizer, test_data, criterion, rouge, max_input_len = 150, max_output_len = 50, mode="greedy"):
    model.eval()  # Turn on the evaluation mode
    
    total_loss = 0.
    evals = []
    with torch.no_grad():
        for item in progressbar.tqdm(test_data, desc="Thread Evaluation", total=len(test_data)):
            thread, summary = item
            summary_string = summary[kSummary]



            email_thread_body = [email[kBody] for email in thread]
            email_thread_body = " ".join(email_thread_body)

            output = model.summarize(email_thread_body, mode="top_p")
        
            # loss = criterion(output, summary_string)
            # total_loss += loss.item()
            output_str = " ".join(output)
            rouge_score = rouge.compute(predictions=[output_str], references=[summary_string])

            evals.append((summary[kThreadId], output_str, summary_string ,rouge_score))
    
    return evals

In [186]:
evals = evaluate(model, test[:4], nn.CrossEntropyLoss(), rouge)

Thread Evaluation: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s]


In [195]:
for i in evals:
    thread_id = i[0]
    output = i[1]
    summary = i[2]
    score = i[3]
    print(f"Thread ID: {thread_id}")
    print(f"Output: {output}")
    print(f"Summary: {summary}")
    print(f"Score: {score}")
    break

Thread ID: 1939
Output: bombay citizens hydro indeed imbalances rtu crossborder lindo thane appeased roles fcms supervisor gir web peer ondreko dianne serc audiovideo but 1997 liquid reactions lakshmi myrna bnew screen firstly lbj donald hauser mickum 13th regime disclaimers midjanuary sefton lesli regressions eisel rebooking rosalee stopping ideas marketingorigination doubts bullet inform charlie pricebasisindex galvan tribolet 18xx worstcase denominated ameren kaye expectation farallonoaktree highest buyback replication discounts logistic storage plays dealings wanted advantage lokey kochzone richmond resort circularity reflect immediate finds 5th orleans indefinitely quilkey crunchers villarreal asap 4309724 pete edge responsibilities desleigh intramonth horn prince vents arctic 720959 relaxing huang obs highlighted administrators merging 1143983 johanson producers timothy funds bond giant wos woodrow restaurants 30496 nov lookups germany spam macbarron turcich radio paystubs gigi b

In [200]:
model.summarize("hello world", mode="top_p")

['started',
 'round',
 'legislators',
 'needing',
 'notebooks',
 'publisher',
 'expertfinder',
 'rain',
 'prudential',
 'antiterrorist',
 '916',
 'topsoe',
 '42nd',
 'agreement',
 'sol',
 'postlethwaite',
 'regime',
 'goes',
 'amsterdam',
 'gruene',
 'manual',
 'transition',
 '9c2',
 'dietary',
 'formula',
 'thackray',
 'explorer',
 'check',
 'vacancy',
 'pursuits',
 'doing',
 'restful',
 'analyst',
 'prepopulated',
 'dec',
 'turcich',
 'jimmy',
 'gras',
 'intervention',
 'explore',
 'canada',
 'golfing',
 'ets',
 'declines',
 'barrow',
 'pulled',
 'identities',
 'rally',
 'scientific',
 'unconstrained',
 'flash1.jpg',
 'given',
 'blevins',
 'slip',
 'perpetuity',
 'seasons',
 'habiba',
 'eb5c2',
 'interruptibility',
 'lardy',
 'nymex',
 '33100',
 'squeezing',
 'raiders',
 'enrondelta',
 'nopr',
 'rmt',
 'gockerman',
 'joy',
 'opvspt',
 '1303',
 'frustrated',
 'cents',
 'twice',
 'jr',
 'overpulling',
 'trail',
 'belong',
 'debts',
 'watched',
 'westwide',
 'profiled',
 'nevius',
 'acc