In [None]:
%load_ext autoreload
%autoreload 2

import math
import os
import pickle as pkl
import random
import time

import numpy as np
import pandas as pd
import torch
from tokenizers import Tokenizer
from torch import nn

from seq2seq import (
    generate_data_loader, Seq2seqConfig, Arguments, Seq2seq,
    train, evaluate, test, plot_loss, save_checkpoint, load_checkpoint,
)

## 1 - Train

In [None]:
SEED = 2021

# Seed every RNG the pipeline may draw from (stdlib, NumPy, PyTorch), in that order,
# so that re-running the notebook reproduces the same results.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Every preprocessed artefact lives under one directory. Deriving all paths from it
# keeps the tokenizer and the model config pointed at the same vocabulary file.
DATA_DIR = './data/preprocess'
VOCAB_PATH = os.path.join(DATA_DIR, 'vocab.json')

tokenizer = Tokenizer.from_file(VOCAB_PATH)
config = Seq2seqConfig(VOCAB_PATH, device, hidden_size=100, embedding_size=620, maxout_hidden_size=50)
args = Arguments(batch_size=4, lr=1e-05, clip=1.0, epochs=4, beam_size=3)
seq2seq = Seq2seq(config, device).to(device)

# Resume a previous run when a checkpoint file exists. load_checkpoint returns a
# replacement `args` (its `state` drives the epoch loop) together with the model.
# NOTE(review): the optimizer is constructed after this point and is never handed
# checkpointed state. Adam's moment estimates therefore restart from zero on resume,
# unless load_checkpoint handles that some other way — TODO confirm.
if os.path.isfile(args.checkpoint_file):
    print('Loading existing checkpoint... \n')
    args, seq2seq = load_checkpoint(args.checkpoint_file, seq2seq)

# Built after a possible checkpoint load, so the (possibly restored) args.batch_size is used.
train_dataset, train_data_loader, dev_dataset, dev_data_loader, test_dataset, test_data_loader = generate_data_loader(
    args.batch_size,
    os.path.join(DATA_DIR, 'train.pkl'),
    os.path.join(DATA_DIR, 'dev.pkl'),
    os.path.join(DATA_DIR, 'test.pkl'),
    tokenizer,
)

# Positions equal to the padding index do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=config.padding_idx)
optimizer = torch.optim.Adam(seq2seq.parameters(), lr=args.lr)

# Optional cyclic learning-rate schedule (currently disabled). step_size_up covers
# roughly two epochs' worth of batches. To enable it, uncomment these lines and also
# call scheduler.step() once per batch; the train(...) call in the epoch loop is not
# passed a scheduler.
# base_lr, max_lr = 0.0001, 0.001
# step_size = math.ceil(len(train_dataset) / args.batch_size) * 2
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size, mode='triangular', cycle_momentum=False)


# Train from the epoch recorded in args.state (0 on a fresh run, or wherever a loaded
# checkpoint left off) up to args.epochs, evaluating on the dev split after each epoch.
# Relies on `time` being imported in the top import cell.
for epoch in range(args.state['state_epoch'], args.epochs):

    print(f"Epoch {epoch+1}\n-------------------------------")
    # perf_counter is monotonic, so the measured duration is unaffected by wall-clock changes.
    start_time = time.perf_counter()

    train_loss = train(seq2seq, device, train_data_loader, optimizer, config, loss_fn, args.clip)
    dev_loss, dev_rouge = evaluate(seq2seq, device, tokenizer, dev_data_loader, config, loss_fn, args.beam_size, args.batch_size)

    # Record this epoch's metrics and 1-based epoch count, then checkpoint.
    # Presumably update_state stores epoch+1 in state['state_epoch'], so an
    # interrupted run resumes at the next epoch via the range() above — TODO confirm.
    args.update_state(train_loss, dev_loss, dev_rouge, epoch+1)
    save_checkpoint(args, seq2seq)

    elapsed = time.perf_counter() - start_time
    print(f'Epoch took : {elapsed:.2f}s')

plot_loss(args.state['train_loss_set'], args.state['dev_loss_set'], args.state['dev_rouge_set'], args.state['state_epoch'])

## 2 - Test

In [None]:
# Score the held-out test split with the configured beam size and batch size.
# NOTE(review): this uses the in-memory weights left by the last epoch (or the loaded
# checkpoint), not necessarily the epoch with the best dev ROUGE.
test(
    seq2seq,
    device,
    args,
    tokenizer,
    test_data_loader,
    config,
    args.beam_size,
    args.batch_size,
)