In [1]:
import torch
from torch import nn
from torch.autograd import Variable

from data_loader import DataLoader
from model import UniSkip
from config import *
from datetime import datetime, timedelta

In [2]:
# Build the corpus loader; `d` is later used to fetch training batches and to
# turn index sequences back into text for the debug log.
# NOTE(review): the recorded output of this cell reports loading
# "./data/sts_dev.txt", not the path passed here — the saved output looks
# stale; re-run on a fresh kernel to confirm which file is actually read.
d = DataLoader("./data/sick_all.txt")

Loading text file at ./data/sts_dev.txt
Making dictionary for these words
unable to load from cached, building fresh
Got 8061 unique words
Saveing dictionary at ./data/sts_dev.txt.pkl
Making reverse dictionary


In [3]:
# Instantiate the model with its default constructor arguments.
mod = UniSkip()
# NOTE(review): CUDA_DEVICE is not defined in this notebook, so it presumably
# comes from `from config import *`; if config also defines USE_CUDA, this
# assignment overrides it in the notebook namespace — confirm that is intended.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    # Only move the model to the GPU when torch reports CUDA as available.
    mod.cuda(CUDA_DEVICE)

In [4]:
# Learning rate for Adam, kept as a module-level name.
lr = 3e-4
# Optimise every parameter exposed by the model with a single Adam instance.
optimizer = torch.optim.Adam(mod.parameters(), lr=lr)

In [5]:
# Module-level bookkeeping that debug() reads and rebinds through `global`.
current_time = datetime.utcnow()  # timestamp of the previous debug() call
last_best_loss = None  # best windowed mean loss so far; None until first report
loss_trail = []  # most recent per-report losses (debug() trims it to 20)

def debug(i, loss, prev, nex, prev_pred, next_pred):
    """Log one training report and checkpoint the model when the loss improves.

    Prints the iteration number, the time since the previous call, the current
    loss, and the text of the four index sequences (decoded with
    ``d.convert_indices_to_sentences``). The loss is appended to a rolling
    window of the last 20 reported losses; when the window mean is lower than
    the best mean seen so far (or no best exists yet), the model's
    ``state_dict`` is written to ``./saved_models/skip-best``.

    Args:
        i: Iteration number, used only in the log line.
        loss: Loss object exposing ``.item()`` (e.g. a scalar tensor).
        prev: Index sequence passed to ``d.convert_indices_to_sentences``.
        nex: Index sequence passed to ``d.convert_indices_to_sentences``.
        prev_pred: Index sequence passed to ``d.convert_indices_to_sentences``.
        next_pred: Index sequence passed to ``d.convert_indices_to_sentences``.

    Side effects:
        Rebinds the module-level ``loss_trail``, ``last_best_loss`` and
        ``current_time``; prints to stdout; may write a checkpoint file.
        A failed save is reported on stdout and does not raise, and in that
        case ``last_best_loss`` is left unchanged so the next report retries.
    """
    import os  # local import: only needed to make sure the checkpoint dir exists

    global loss_trail
    global last_best_loss
    global current_time

    # `.item()` already yields a detached Python number; the legacy `.data`
    # hop is unnecessary.
    this_loss = loss.item()
    loss_trail.append(this_loss)
    loss_trail = loss_trail[-20:]  # keep only the 20 most recent reports

    new_current_time = datetime.utcnow()
    time_elapsed = str(new_current_time - current_time)
    current_time = new_current_time
    print("Iteration {}: time = {} last_best_loss = {}, this_loss = {}".format(
              i, time_elapsed, last_best_loss, this_loss))

    print("prev = {}\nnext = {}\npred_prev = {}\npred_next = {}".format(
        d.convert_indices_to_sentences(prev),
        d.convert_indices_to_sentences(nex),
        d.convert_indices_to_sentences(prev_pred),
        d.convert_indices_to_sentences(next_pred),
    ))

    # The window always holds at least the loss appended above, so the mean
    # is well defined and does not need to sit inside the try block.
    trail_loss = sum(loss_trail) / len(loss_trail)
    if last_best_loss is None or last_best_loss > trail_loss:
        print("Loss improved from {} to {}".format(last_best_loss, trail_loss))

        # The original called .format(lr, VOCAB_SIZE) on this placeholder-free
        # string, which silently ignored both arguments; the path is fixed.
        save_loc = "./saved_models/skip-best"
        print("saving model at {}".format(save_loc))
        try:
            # Create the target directory on first use so a fresh checkout
            # does not fail every checkpoint attempt.
            os.makedirs(os.path.dirname(save_loc), exist_ok=True)
            torch.save(mod.state_dict(), save_loc)
        except Exception as e:
            # Best effort: a failed checkpoint must not abort training.
            print("Couldn't save model because {}".format(e))
        else:
            # Only record the new best once the checkpoint is on disk.
            last_best_loss = trail_loss

In [6]:
print("Starting training...")

# One million optimisation steps; interrupt the kernel to stop earlier.
for i in range(1000000):
    # 32 * 8 == 256 sentences requested per batch.
    sentences, lengths = d.fetch_batch(256)

    # The forward pass yields the loss plus the four sequences debug() prints.
    loss, prev, nex, prev_pred, next_pred = mod(sentences, lengths)

    # Report every 10th iteration, before this step's weight update.
    if not i % 10:
        debug(i, loss, prev, nex, prev_pred, next_pred)

    # Standard update: clear stale gradients, backpropagate, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Starting training...




Iteration 0: time = 0:00:22.560464 last_best_loss = None, this_loss = 19.808334350585938
prev = The difference is the amount of protein contained in the flour which can range from 5% to 15% . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
next = I've never adjusted the length of time based on number of eggs . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
pred_prev = NA inspectors NA NA Thomas NA NA NA NA medicines NA NA NA NA NA NA NA NA NA NA Barbakow Commissioner, Commissioner, NA NA NA NA NA NA NA
pred_next = NA NA suggest NA NA NA Azarov NA justices NA Street NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA
Loss improved from None to 19.808334350585938
saving model at ./saved_models/skip-best
Iteration 10: time = 0:05:02.803081 last_best_loss = 19.808334350585938, this_loss = 18.68407440185547
prev = A laptop and a PC at a workstation . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
next = A white cat laying on an offi

Iteration 130: time = 0:05:07.108241 last_best_loss = 17.40192222595215, this_loss = 18.23804473876953
prev = An Ohio law that bans a controversial late-term abortion procedure is constitutionally acceptable and the state can enforce it, a federal appeals court ruled yesterday . EOS EOS EOS EOS
next = The Nasdaq composite index advanced 20.59, or 1.3 percent, to 1,616.50, after gaining 5.7 percent last week . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
pred_prev = A A is a the the a . . . . the . . the . . . . . . . . . . . . . . .
pred_next = A A is the . . . . . . the . . . . . . . . . . . . . . . . . . .
Iteration 140: time = 0:05:00.385883 last_best_loss = 17.40192222595215, this_loss = 16.327014923095703
prev = Actually, it's much more easier to count the one NOT on the same continent . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
next = That is also the recommended strategy for marathons, if you are going for time . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS E

Iteration 270: time = 0:05:09.148063 last_best_loss = 16.88510580062866, this_loss = 15.731801986694336
prev = That is also the recommended strategy for marathons, if you are going for time . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
next = Red ball under yellow floodlight takes on a brownish color which is very similar to the color of the pitch . EOS EOS EOS EOS EOS EOS EOS EOS EOS
pred_prev = A is a a in . . the . . . . . the . . . . . . . . . . . . . . . .
pred_next = A are in the in in . the a . . . the . . the . . the . . . . . . . . . . .
Loss improved from 16.88510580062866 to 16.814519786834715
saving model at ./saved_models/skip-best
Iteration 280: time = 0:05:31.156585 last_best_loss = 16.814519786834715, this_loss = 16.279117584228516
prev = A man is adding water to pan . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS
next = A man is playing a guitar . EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS EOS E

KeyboardInterrupt: 