In [1]:
import collections.abc
import json
import random
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import tqdm as progressbar
import time
import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')
import re
import sys

[nltk_data] Downloading package punkt to /Users/saint/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
DETAILS_JSON = "data/email_thread_details.json"
SUMMARIES_JSON = "data/email_thread_summaries.json"

kThreadId = "thread_id"
kSubject = "subject"
kTimestamp = "timestamp"
kFrom = "from"
kTo = "to"
kBody = "body"

kSummary = "summary"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Notes

- Reduced vocab from 172743 to 16856
- `#` WOW

In [3]:
class Utils():
    @staticmethod
    def load_dataset(DETAILS_FILE, SUMMARIES_FILE):
        '''
            This function loads the dataset from the file
            ARGS:
                filename: the name of the file
            RETURN:
                dataset: the dataset
        '''
        with open(DETAILS_FILE, 'r') as f:
            details = json.load(f)
        
        with open(SUMMARIES_FILE, 'r') as f:
            summaries = json.load(f)   
        
        dataset = {}
        for i in range(len(details)):
            item = details[i]
            thread_id = item[kThreadId]
            item = Utils.tokenize_body(item)
            dataset[thread_id] = dataset.get(thread_id, []) + [item]
        
        for i in range(len(summaries)):
            item = summaries[i]
            thread_id = item[kThreadId]
            item = Utils.tokenize_summary(item)
            dataset[thread_id] = (dataset.get(thread_id), item)


        return dataset

    @staticmethod
    def tokenize_body(item):
        item[kBody] = word_tokenize(item[kBody])
        item[kBody] = [re.sub(r'[^\w\s.]', '', word) for word in item[kBody]]
        item[kBody] = [word.strip() for word in item[kBody] if word.strip() and word.strip() not in ['--', '=']]
        # Lowercase the email body
        item[kBody] = [word.lower() for word in item[kBody]]
        item[kBody] = ["<BOS>"] + item[kBody] + ["<EOS>"]
        item[kBody] = " ".join(item[kBody])
        return item

    @staticmethod
    def tokenize_summary(item):
        item[kSummary] = word_tokenize(item[kSummary])
        item[kSummary] = [re.sub(r'[^\w\s.]', '', word) for word in item[kSummary]]
        item[kSummary] = [word.strip() for word in item[kSummary] if word.strip() and word.strip() not in ['--', '=']]
        # Lowercase the summary
        item[kSummary] = [word.lower() for word in item[kSummary]]
        item[kSummary] = "<BOS> " + " ".join(item[kSummary]) + "<EOS>"
        return item
    
    @staticmethod
    def build_vocab(data):
        '''
            This function builds the vocabulary from the data
            ARGS:
                data: the data to build the vocabulary from ([Email], EmailSummaries)
            RETURN:
                vocab: the vocabulary
        '''
        vocab = Vocab()
        for _, (email_list, summary) in data.items():
            for email in email_list:

                for word in email:
                    vocab.add(word)
            for word in summary[kSummary].split():
                vocab.add(word)
        
        return vocab

class Vocab(collections.abc.MutableSet):
    """
        Set-like data structure that can change words into numbers and back.
        From Prof. David Chiang Code
    """
    def __init__(self):
        words = {'<BOS>', '<EOS>', '<UNK>'}
        self.num_to_word = list(words)
        self.word_to_num = {word:num for num, word in enumerate(self.num_to_word)}
    def add(self, word):
        if word in self: return
        num = len(self.num_to_word)
        self.num_to_word.append(word)
        self.word_to_num[word] = num
    def discard(self, word):
        raise NotImplementedError()
    def update(self, words):
        self |= words
    def __contains__(self, word):
        return word in self.word_to_num
    def __len__(self):
        return len(self.num_to_word)
    def __iter__(self):
        return iter(self.num_to_word)

    def numberize(self, word):
        """Convert a word into a number."""
        if word in self.word_to_num:
            return self.word_to_num[word]
        else:
            return self.word_to_num['<UNK>']

    def denumberize(self, num):
        """Convert a number into a word."""
        return self.num_to_word[num]

## Pre Process

In [4]:
# Load the data
d = Utils.load_dataset(DETAILS_JSON, SUMMARIES_JSON)
vocab = Utils.build_vocab(d)

In [5]:
# Split the dictionary into train and test
data = list(d.items())
random.shuffle(data)
train, test = train_test_split(data, test_size=0.2, random_state=42)

train = [(email_list, summary) for _, (email_list, summary) in train]
test = [(email_list, summary) for _, (email_list, summary) in test]

# Models

In [6]:
class Summarizer(nn.Transformer):
    '''
    This class implements the summarizer
    '''

    def __init__(self, vocab_size, vocab, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        '''
        This function initializes the model
        ARGS:
            vocab_size: the size of the vocabulary
            d_model: the dimension of the model
            nhead: the number of heads
            num_encoder_layers: the number of encoder layers
            num_decoder_layers: the number of decoder layers
        RETURN:
            None
        '''

        super(Summarizer, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model) # Embedding layer
        self.transformer = nn.Transformer(
            d_model=d_model, 
            nhead=nhead, 
            num_encoder_layers=num_encoder_layers, 
            num_decoder_layers=num_decoder_layers
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.vocab = vocab
        
        self.kGreedy = "greedy"
        self.kTopP = "top_p"
        self.kBeam = "beam"

    def forward(self, src, target=None):
        '''
        This function performs the forward pass of the model
        ARGS:
            src: the source input
            target: the target input (optional, used during training)
        RETURN:
            output: the output of the model
        '''
        # reduce the size of the src
        kenel_size = 9
        stride = int(kenel_size / 5)
        src = src.unsqueeze(0)
        src = torch.nn.functional.avg_pool1d(src, kenel_size, stride)
        src = src.squeeze(0)
        src = self.embedding(src)

        
        if target is not None:
            target = self.embedding(target)
            output = self.transformer(src, target)
        else:
            # In generation mode, don't use target
            output = self.transformer(src, src)  # Use src as both source and target TODO: 
        
        output = self.fc_out(output)

        return output
    
    def summarize(self, src, max_len=100, mode="top_p"):
        '''
        This function summarize the input text
            args:
                src: the source input
                max_len: the maximum length of the output
                mode: the mode of generation (greedy or beam search)
            return:
                output: the output of the model
        '''
        src = torch.tensor([self.vocab.numberize(word) for word in src])
        o = self.forward(src)
        output = None
        
        if mode == self.kGreedy:
            output =  self.greedy_decoding(o, max_len)
        elif mode == self.kTopP:
            output = self.top_p_decoding(o, max_len)
        elif mode == self.kBeam:
            output = self.beam_search(o, max_len)

        return output
        
    def greedy_decoding(self, o, max_len):
        '''
        This function performs greedy decoding
        ARGS:
            o: the output of the model
            max_len: the maximum length of the output
        RETURN:
            output: the output of the model
        '''
        output = []
        words = 0
        for i in o:
            if words >= max_len:
                break
            a = torch.argmax(i)
            if a == self.vocab.numberize("<EOS>"):
                break
            a = self.vocab.denumberize(a)
            output.append(a)
            words += 1
        return output
    
    def top_p_decoding(self, o, max_len = 50, p=0.9):
        '''
        This function performs top-p decoding
        ARGS:
            o: the output of the model
            max_len: the maximum length of the output
            p: the probability threshold
        RETURN:
            output: the output of the model
        '''
        output = []
        words = 0
        for i in o:
            if words >= max_len:
                break
            sorted_prob, sorted_idx = torch.sort(i, descending=False)
            sorted_prob = torch.exp(sorted_prob)
            sorted_prob_cumsum = sorted_prob.cumsum(dim=0)
            top_p_batch = sorted_idx[sorted_prob_cumsum > p]

            if top_p_batch.nelement() > 0:
                next_token = random.choice(top_p_batch)
                output.append(next_token.item())
            else:
                next_token = random.choice(sorted_idx)
                output.append(next_token.item())
            
            if output[-1] == self.vocab.numberize("<EOS>"):
                break



        for i, tensor in enumerate(output):
            output[i] = self.vocab.denumberize(tensor)
    
        return output

    def beam_search(self, o, max_len):
        '''
        
        '''
        pass

## Training

In [7]:
ntokens = len(vocab) # size of vocabulary
emsize = 150 # embedding dimension
nhid = 150 # the dimension of the feedforward network model in nn.TransformerEncoder
n_encoder_layers = 6 # the number of encoder layers
n_decoder_layers = 6 # the number of decoder layers
nhead = 3 # the number of heads in the multiheadattention models
lr = 0.001 # learning rate
model = Summarizer(ntokens, vocab, emsize, nhead, n_encoder_layers, n_decoder_layers).to(DEVICE)

# RECENT_MODEL = "models/model.pt-2023-12-04_15_40_57.pt"
# model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

In [8]:
import evaluate
rouge = evaluate.load('rouge')

In [9]:
def train_summarizer(model: Summarizer, train_data, dev_data, criterion, epochs=1, lr=0.003):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Initialize Adam optimizer
    prev_dev_loss = best_dev_loss = None
    
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} / {epochs}")
        random.shuffle(train_data)
        model.train()  # Turn on the train mode
        train_loss = 0

        for item in progressbar.tqdm(train_data, desc="Thread Training", total=len(train_data)):
            thread, summary = item
            email_thread_body = [email[kBody] for email in thread]
            email_thread_body = " ".join(email_thread_body)

            email_tensor = torch.tensor([model.vocab.numberize(word) for word in email_thread_body.split()]).to(DEVICE) # Convert email to tensor
            summary_tensor = torch.tensor([model.vocab.numberize(word) for word in summary[kSummary].split()]).to(DEVICE) # Convert summary to tensor
            
            output = model(email_tensor, summary_tensor)  # Forward pass
            output = output.view(-1, ntokens) # Reshape output 


            loss = criterion(output, summary_tensor) # Calculate loss
            optimizer.zero_grad()  # Zero the gradients
            loss.backward() # Backward pass
            # torch.nn.utils.clip_grad_norm_(model.parameters(), threshold_norm) # Clip gradients
            optimizer.step() # Update weights
            train_loss += loss.item()
        
        model.eval()
        dev_loss = 0
        evals = []
        line_num = 0
        # for item in progressbar.tqdm(dev_data, desc="Thread Dev", total=len(dev_data)):
        #     thread, summary = item
        #     summary_string = summary[kSummary]

        #     email_thread_body = [email[kBody] for email in thread]
        #     email_thread_body = " ".join(email_thread_body)

        #     output = model.summarize(email_thread_body, mode="greedy")
        #     output = " ".join(output)
        #     score = rouge.compute(predictions=[output], references=[summary_string])
        #     evals.append((summary[kThreadId], output, summary_string, score))
        #     if line_num < 5:
        #         print(f"Thread ID: {summary[kThreadId]}")
        #         print(f"Email Thread: {email_thread_body}")
        #         print(f"Predicted Summary: {output}")
        #         print(f"Actual Summary: {summary_string}")
        #         print(f"Score: {score}")
        #         print("-----------------------------------")
        #     line_num += 1
        
        print(f"Epoch {epoch + 1}/{epochs}: Train Loss: {train_loss}",file=sys.stderr ,flush=True)

    return model


In [11]:
train_subset = train[:10]
dev_subset = test[:10]

curr_time = time.strftime("%Y-%m-%d_%H:%M:%S")
MODEL_PATH = f"models/model.pt-{curr_time}.pt"

model = train_summarizer(
    model=model, 
    train_data=train_subset, 
    dev_data=train_subset, 
    criterion=nn.CrossEntropyLoss(),
    epochs=50,
)
# save the model

# torch.save(model.state_dict(), MODEL_PATH)

Epoch 1 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
Epoch 1/50: Train Loss: 58.604191303253174


Epoch 2 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s]
Epoch 2/50: Train Loss: 55.25141429901123


Epoch 3 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]
Epoch 3/50: Train Loss: 54.91522979736328


Epoch 4 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]
Epoch 4/50: Train Loss: 54.539066314697266


Epoch 5 / 50


Thread Training: 100%|██████████| 10/10 [00:09<00:00,  1.00it/s]
Epoch 5/50: Train Loss: 54.346128940582275


Epoch 6 / 50


Thread Training: 100%|██████████| 10/10 [00:08<00:00,  1.12it/s]
Epoch 6/50: Train Loss: 54.29215955734253


Epoch 7 / 50


Thread Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Epoch 7/50: Train Loss: 54.18113422393799


Epoch 8 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
Epoch 8/50: Train Loss: 54.10610628128052


Epoch 9 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
Epoch 9/50: Train Loss: 54.076393127441406


Epoch 10 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]
Epoch 10/50: Train Loss: 54.168951988220215


Epoch 11 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Epoch 11/50: Train Loss: 54.25518274307251


Epoch 12 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]
Epoch 12/50: Train Loss: 54.24578285217285


Epoch 13 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]
Epoch 13/50: Train Loss: 54.06556987762451


Epoch 14 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]
Epoch 14/50: Train Loss: 54.02459716796875


Epoch 15 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]
Epoch 15/50: Train Loss: 54.295654296875


Epoch 16 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s]
Epoch 16/50: Train Loss: 54.02160406112671


Epoch 17 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]
Epoch 17/50: Train Loss: 54.14129590988159


Epoch 18 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]
Epoch 18/50: Train Loss: 54.056124210357666


Epoch 19 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]
Epoch 19/50: Train Loss: 53.936607837677


Epoch 20 / 50


Thread Training: 100%|██████████| 10/10 [00:09<00:00,  1.09it/s]
Epoch 20/50: Train Loss: 53.87540292739868


Epoch 21 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]
Epoch 21/50: Train Loss: 54.02921533584595


Epoch 22 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.35it/s]
Epoch 22/50: Train Loss: 53.972938537597656


Epoch 23 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]
Epoch 23/50: Train Loss: 53.939921855926514


Epoch 24 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]
Epoch 24/50: Train Loss: 53.94483995437622


Epoch 25 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]
Epoch 25/50: Train Loss: 53.9976749420166


Epoch 26 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]
Epoch 26/50: Train Loss: 53.93684005737305


Epoch 27 / 50


Thread Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Epoch 27/50: Train Loss: 54.02443981170654


Epoch 28 / 50


Thread Training: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s]
Epoch 28/50: Train Loss: 53.97484254837036


Epoch 29 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.04it/s]
Epoch 29/50: Train Loss: 53.914138317108154


Epoch 30 / 50


Thread Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Epoch 30/50: Train Loss: 53.90342617034912


Epoch 31 / 50


Thread Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Epoch 31/50: Train Loss: 53.913875102996826


Epoch 32 / 50


Thread Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Epoch 32/50: Train Loss: 54.00364065170288


Epoch 33 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]
Epoch 33/50: Train Loss: 53.82524824142456


Epoch 34 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]
Epoch 34/50: Train Loss: 53.88714361190796


Epoch 35 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]
Epoch 35/50: Train Loss: 53.86819410324097


Epoch 36 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]
Epoch 36/50: Train Loss: 53.836411476135254


Epoch 37 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]
Epoch 37/50: Train Loss: 53.908421993255615


Epoch 38 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]
Epoch 38/50: Train Loss: 53.903103828430176


Epoch 39 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]
Epoch 39/50: Train Loss: 53.81865644454956


Epoch 40 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]
Epoch 40/50: Train Loss: 53.933390617370605


Epoch 41 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]
Epoch 41/50: Train Loss: 53.86366891860962


Epoch 42 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]
Epoch 42/50: Train Loss: 53.81370210647583


Epoch 43 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]
Epoch 43/50: Train Loss: 53.8584508895874


Epoch 44 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.04it/s]
Epoch 44/50: Train Loss: 53.7866849899292


Epoch 45 / 50


Thread Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Epoch 45/50: Train Loss: 53.843952655792236


Epoch 46 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]
Epoch 46/50: Train Loss: 53.81788969039917


Epoch 47 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]
Epoch 47/50: Train Loss: 53.86992406845093


Epoch 48 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]
Epoch 48/50: Train Loss: 53.91007852554321


Epoch 49 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]
Epoch 49/50: Train Loss: 53.7640962600708


Epoch 50 / 50


Thread Training: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]
Epoch 50/50: Train Loss: 53.76517724990845


In [12]:
text = " ".join([email[kBody] for email in train_subset[0][0]])
summary = train_subset[0][1][kSummary]
o = model.summarize(text, max_len=50, mode="top_p")
o = " ".join(o)
print(o)
print(summary)

tennessee problems terry court jackie legal is quick resume at zone neal brent mark attorney there reich expect hpl thread code likely against zone its assures names costs business sending consumers be this sharkey treated addressed contract patti conflicts unit through cogentrix providing the kim progress steps illiott adams traders he requests also shares shared hpl databases party consumers discussed anna upside deals willie codes codes regarding the being gsi upside gordon above at dairy utilicorp names upside new deal interest asks addressed contacted emphasizing if farms once accrue shares cera chris ends awarded lawsuit likely going mentioning brokerage janie there if purchase updates sharkey going or stephanie keep or email team call resulted different status documents site its emphasizing terminate up call shackleton lawsuit paid different conversation lotus enpower attorney laurel 5 awareness conflicts when palo purchasing points panus jenny staff awarded s farms illiott jack

## Evaluation

In [None]:
# load the model
# RECENT_MODEL = "models/model.pt-2023-12-05_18:16:19.pt"
# model.load_state_dict(torch.load(RECENT_MODEL, map_location='cpu'))

In [None]:
# Evaluate the output

def evaluate(model: Summarizer, test_data, criterion, rouge, max_input_len = 150, max_output_len = 50, mode="greedy"):
    model.eval()  # Turn on the evaluation mode
    
    total_loss = 0.
    evals = []
    with torch.no_grad():
        for item in progressbar.tqdm(test_data, desc="Thread Evaluation", total=len(test_data)):
            thread, summary = item
            summary_string = summary[kSummary]

            email_thread_body = [email[kBody] for email in thread]
            email_thread_body = " ".join(email_thread_body)

            output = model.summarize(email_thread_body, mode="top_p")
        
            # loss = criterion(output, summary_string)
            # total_loss += loss.item()
            output_str = " ".join(output)
            rouge_score = rouge.compute(predictions=[output_str], references=[summary_string])

            evals.append((summary[kThreadId], output_str, summary_string ,rouge_score))
    
    return evals

In [None]:
evals = evaluate(model, test[:4], nn.CrossEntropyLoss(), rouge)

In [None]:
for i in evals:
    thread_id = i[0]
    output = i[1]
    summary = i[2]
    score = i[3]
    print(f"Thread ID: {thread_id}")
    print(f"Output: {output}")
    print(f"Summary: {summary}")
    print(f"Score: {score}")
    break

In [None]:
text = train[0][0][0][kBody]
summary = train[0][1][kSummary]
a = model.summarize(text, mode="top_p")
print(a)
print(summary)