In [None]:
import time
import pickle
import numpy as np 
import torch
from torch import nn
import data_loader
import model
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
class StatsManager:
    """Running mean of per-batch losses over one pass through a loader.

    Usage: call ``init()`` before each pass, ``accumulate(loss)`` once per
    batch, and ``summarize()`` at the end to obtain the mean loss.
    """

    def __init__(self):
        self.init()

    def init(self):
        """Reset the running sum and the batch counter."""
        self.running_loss = 0.0
        self.num_update = 0

    def accumulate(self, loss):
        """Add one batch loss.

        Args:
            loss: a Python number or a one-element tensor. It is converted
                with ``float()`` so that no autograd graph or device memory is
                retained and the summaries stay plain floats (plottable with
                matplotlib and picklable without CUDA).
        """
        self.running_loss += float(loss)
        self.num_update += 1

    def summarize(self):
        """Return the mean accumulated loss as a float.

        Returns NaN when nothing was accumulated (e.g. an empty loader)
        instead of raising ZeroDivisionError.
        """
        if self.num_update == 0:
            return float('nan')
        return self.running_loss / self.num_update

In [None]:
#load data
# Cached (train_loader, val_loader) pair built by data_loader.get_loader.
# NOTE: pickle.load can execute arbitrary code from the file, so only ever
# point LOADER_CACHE at a cache this notebook wrote itself.
LOADER_CACHE = 'loader.pkl'
DATA_FILES = ['./data/path%d.txt' % i for i in range(1, 11)]

try:
    with open(LOADER_CACHE, 'rb') as f:
        train_loader, val_loader = pickle.load(f)
except (FileNotFoundError, EOFError, pickle.UnpicklingError, AttributeError, ImportError):
    # No cache yet, a truncated/corrupt one, or a stale one whose classes no
    # longer resolve: rebuild the loaders and refresh the cache. Any other
    # error (permissions, Ctrl-C, ...) is allowed to surface.
    train_loader, val_loader = data_loader.get_loader(DATA_FILES)
    with open(LOADER_CACHE, 'wb') as f:
        pickle.dump((train_loader, val_loader), f)

print('Data loading completed')

In [None]:
#hyperparam
lr = 0.01
num_epoch = 200
lamb = 0.1  # passed to Adagrad as weight_decay (L2 penalty)


def _load_module(path):
    """Load a whole nn.Module saved with ``torch.save(module, path)``.

    Tensors are mapped to CPU so a checkpoint written on a GPU machine can
    still be opened on a CPU-only one; the caller moves the module to the GPU
    afterwards if one is available.
    """
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # full module objects; these files are this notebook's own checkpoints.
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch without the weights_only keyword.
        return torch.load(path, map_location='cpu')


# Resume from the last checkpoint if one exists. Only a *missing* checkpoint
# starts a fresh run; any other load failure is raised instead of silently
# resetting to epoch 0 (which would then overwrite the saved models).
try:
    with open('checkpoint.pkl', 'rb') as f:
        epoch, train_loss, test_loss = pickle.load(f)
    enet = _load_module('saved_enet.pth')
    pnet = _load_module('saved_pnet.pth')
except FileNotFoundError:
    epoch = 0
    train_loss = []
    test_loss = []
    enet = model.Enet()
    pnet = model.Pnet()

if torch.cuda.is_available():
    enet = enet.cuda()
    pnet = pnet.cuda()
enet.train()
pnet.train()

# NOTE(review): the optimizer is rebuilt from scratch on every resume (its
# state_dict is never checkpointed), so Adagrad's accumulated squared
# gradients restart at zero after a restart.
optimizer = torch.optim.Adagrad(list(enet.parameters()) + list(pnet.parameters()), lr=lr, weight_decay=lamb)
stats = StatsManager()

In [None]:
#training
# Use whatever device the setup cell placed the models on (GPU when one was
# available, otherwise CPU) instead of hard-coding .cuda().
device = next(enet.parameters()).device


def _run_epoch(loader, train):
    """Make one pass over ``loader`` and return the mean per-batch loss (float).

    train=True : both nets in train mode; forward, backward and optimizer step.
    train=False: both nets in eval mode under no_grad; forward only, so the
                 validation pass builds no autograd graph and does not use
                 training-time behaviour of any mode-dependent layers.
    """
    enet.train(train)
    pnet.train(train)
    stats.init()
    with torch.set_grad_enabled(train):
        for batch in loader:
            pc, x, d = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            if train:
                optimizer.zero_grad()
            z = enet(pc)
            y = pnet(z, x)
            loss = pnet.criterion(d, y)
            if train:
                loss.backward()
                optimizer.step()
            # .item() -> plain float: keeps (GPU) tensors out of the loss
            # history, which is plotted and pickled below.
            stats.accumulate(loss.item())
    return stats.summarize()


while epoch < num_epoch:
    start = time.time()

    train_loss.append(_run_epoch(train_loader, train=True))
    test_loss.append(_run_epoch(val_loader, train=False))

    fig, ax = plt.subplots()
    ax.plot(train_loss, label='Training')
    ax.plot(test_loss, label='Testing')
    ax.set(xlabel='Epoch', ylabel='Loss')
    ax.grid()
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch
    print("Epoch {}: training loss: {}, testing loss:{}\nTime: {}s".format(epoch, train_loss[-1], test_loss[-1], time.time() - start))

    torch.save(enet, 'saved_enet.pth')
    torch.save(pnet, 'saved_pnet.pth')

    epoch += 1

    # Written after the models so the recorded epoch never runs ahead of them.
    with open('checkpoint.pkl', 'wb') as f:
        pickle.dump((epoch, train_loss, test_loss), f)