In [1]:
#!/usr/bin/env python3


# Input JSON data files; paths are relative to the notebook's working directory.
EMAIL_DETAILS = "data/email_thread_details.json"
EMAIL_SUMMARIES = "data/email_thread_summaries.json"
import copy
from utils import *
from models import *
import random
import torch
import tqdm
import sys


In [None]:
def train_model_summarizer(model, loss_func, train_set, dev_set, epochs=50, lr=0.0001, device="cpu",
                           save_path="models/summarizer.pt"):
    """Train ``model`` with Adam, keeping the copy with the lowest summed dev loss.

    Each epoch makes one pass over a shuffled copy of ``train_set`` (the caller's
    list is left in its original order), then evaluates on ``dev_set`` with gradient
    tracking disabled. Whenever the summed dev loss improves, a deep copy of the
    model is kept and its ``state_dict`` is written to ``save_path``. If the dev
    loss rises relative to the previous epoch, the learning rate of every optimizer
    parameter group is halved.

    Args:
        model: object exposing ``parameters()``, ``to()``, ``train()``, ``eval()``,
            ``state_dict()`` and ``summarize(content)`` (presumably a
            ``torch.nn.Module`` subclass from ``models`` -- confirm).
        loss_func: callable ``loss_func(summarization, good_summary)`` returning a
            scalar tensor (``.backward()`` and ``.item()`` are called on it).
        train_set: sequence of batches; ``batch[0]`` is the content and
            ``batch[1]`` the reference summary, both moved with ``.to(device)``.
        dev_set: sequence of batches with the same layout as ``train_set``.
        epochs: number of passes over ``train_set``.
        lr: initial Adam learning rate.
        device: device the model and every batch are moved to.
        save_path: file the best ``state_dict`` is saved to; its parent directory
            is created if it does not exist.

    Returns:
        A deep copy of the model from the epoch with the lowest dev loss, or
        ``model`` itself when no epoch ran.
    """
    import os

    # Move parameters before building the optimizer, as the PyTorch docs recommend.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr)

    prev_dev_loss = best_dev_loss = None
    best_model = model

    # `tqdm` is bound to the module (see `import tqdm`), so the callable is tqdm.tqdm.
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch"):
        model.train()
        running_loss = 0.0
        # Shuffle a copy so the caller's train_set is not reordered as a side effect.
        train_order = list(train_set)
        random.shuffle(train_order)
        for batch in tqdm.tqdm(train_order, desc="Batch"):
            content = batch[0].to(device)
            good_summary = batch[1].to(device)

            optimizer.zero_grad()
            summarization = model.summarize(content)
            loss = loss_func(summarization, good_summary)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        dev_loss = 0.0
        dev_failed = 0
        model.eval()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for batch in tqdm.tqdm(dev_set, desc="Dev Batch"):
                content = batch[0].to(device)
                good_summary = batch[1].to(device)

                summarization = model.summarize(content)
                # NOTE(review): loss_func is applied to every output, including None,
                # exactly as before -- presumably it tolerates None; confirm.
                dev_loss += loss_func(summarization, good_summary).item()

                if summarization is None:
                    dev_failed += 1

        avg_train_loss = running_loss / len(train_order) if train_order else 0.0
        print("Epoch: ", epoch, "Train loss: ", avg_train_loss,
              "Dev loss: ", dev_loss, "Dev failed: ", dev_failed)

        if best_dev_loss is None or dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            best_model = copy.deepcopy(model)
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(best_model.state_dict(), save_path)
            print("Saved model with dev loss: ", dev_loss)

        # Simple plateau schedule: halve the LR whenever dev loss gets worse.
        if prev_dev_loss is not None and dev_loss > prev_dev_loss:
            print('halving learning rate', file=sys.stderr)
            for group in optimizer.param_groups:
                group['lr'] /= 2
        prev_dev_loss = dev_loss

    return best_model

    

In [3]:


def test_baseline(emailList, s):
    for key, value in emailList.items():
        print('-------------------')
        print(s.summarize(value))
        break