In [1]:
from model import *
from constant import *
from load_data import *
import numpy as np
import time
import tqdm
from torch.autograd import Variable

In [2]:
WORD_SIZE = len(word2id)

In [3]:
model = Model(WORD_SIZE,WORD_DIM,NUM_FILTERS,FILTER_SIZES,DROPOUT,HIDDEN_SIZE,pretrained_word_embeds)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
criterion = torch.nn.NLLLoss(size_average=False)
bce_loss_criterion = torch.nn.BCELoss(reduction='sum')



In [4]:
def train():
    model.train()
    corrects ,total_loss,total_bce_loss,_size,bce_loss = 0,0,0,0,0
    for q1, q2, label,a,b in tqdm.tqdm_notebook(train_dataloader, mininterval=1,
                                  desc='Train Processing', leave=False):
        label = label.type(torch.LongTensor)
        q1 = Variable(q1)
        q2 = Variable(q2)
        label = Variable(label)
        optimizer.zero_grad()
        pred1,pred2 = model(q1, q2)
        loss = criterion(pred1, label)
        bce_loss = bce_loss_criterion(pred2,label.type(torch.FloatTensor))
        loss.backward()
        optimizer.step()
        total_loss += loss.data
        total_bce_loss += bce_loss.data
        corrects += (torch.max(pred1, 1)[1].view(label.size()).data == label.data).sum()
        _size += train_dataloader.batch_size
    return total_loss /_size , corrects ,(float(corrects) / _size) * 100, _size,total_bce_loss /_size

In [5]:
def evaluate():
    model.eval()
    corrects ,total_loss,total_bce_loss,_size,bce_loss = 0,0,0,0,0
    for q1, q2, label,_,_ in tqdm.tqdm_notebook(valid_dataloader, mininterval=1,
                                  desc='validation Processing', leave=False):
        label = label.type(torch.LongTensor)
        q1 = Variable(q1)
        q2 = Variable(q2)
        label = Variable(label)
        pred1,pred2 = model(q1, q2)
        loss = criterion(pred1, label)
        bce_loss = bce_loss_criterion(pred2,label.type(torch.FloatTensor))
        total_loss += loss.data
        total_bce_loss += bce_loss.data
        corrects += (torch.max(pred1, 1)[1].view(label.size()).data == label.data).sum()
        _size += valid_dataloader.batch_size
    return total_loss /_size , corrects ,(float(corrects) / _size) * 100, _size,total_bce_loss /_size

In [6]:
def save(filename):
    state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'valid_loss': valid_loss,"valid_bce_loss": valid_bce_loss,
             'valid_accuracy':valid_accuracy}
    torch.save(state, filename)

In [7]:
training_loss = []
validation_loss = []
training_bce_loss = []
validation_bce_loss = []
train_accuracy = []
valid_accuracy =[]
best_acc = None
total_start_time = time.time()

try:
    print('-' * 90)
    for epoch in range(1, NUM_EPOCH + 1):
        epoch_start_time = time.time()
        train_loss, train_corrects, train_acc, train_size,train_bce_loss = train()
        scheduler.step()
        training_loss.append(train_loss * 1000.)
        training_bce_loss.append(train_bce_loss * 1000.)
        train_accuracy.append(train_acc/100.)

        print('| start of epoch {:3d} | time: {:2.2f}s | loss {:5.6f} | accuracy {:.4f}%({}/{}) | bce_loss {:5.6f}'.format(
            epoch, time.time() - epoch_start_time, train_loss, train_acc, train_corrects, train_size,train_bce_loss))

        valid_loss, valid_corrects, valid_acc, valid_size,valid_bce_loss = evaluate()

        validation_loss.append(valid_loss * 1000.)
        validation_bce_loss.append(valid_loss * 1000.)
        valid_accuracy.append(valid_acc / 100.)

        epoch_start_time = time.time()
        print('-' * 90)
        print('| end of epoch {:3d} | time: {:2.2f}s | loss {:.4f} | accuracy {:.4f}%({}/{} | bce_loss {:5.6f}'.format(
            epoch, time.time() - epoch_start_time, valid_loss, valid_acc, valid_corrects, valid_size,valid_bce_loss))
        print('-' * 90)
        if not best_acc or best_acc < valid_acc:
            best_acc = valid_acc
            save('../save/checkpoint_epoch_'+str(epoch)+'_valid_loss_'+str(valid_loss)
              +'_valid_acc_'+str(valid_acc)+'_valid_bce_loss_'+str(valid_bce_loss)+'_'+'.pth.tar')
except KeyboardInterrupt:
    print("-" * 90)
    print("Exiting from training early | cost time: {:5.2f}min".format(
        (time.time() - total_start_time) / 60.0))

------------------------------------------------------------------------------------------


A Jupyter Widget

  x_softmax = F.log_softmax(x)


| start of epoch   1 | time: 260.82s | loss 0.911245 | accuracy 62.7097%(228203/363904) | bce_loss 0.784830


A Jupyter Widget

------------------------------------------------------------------------------------------
| end of epoch   1 | time: 0.00s | loss 0.6583 | accuracy 63.0513%(25503/40448 | bce_loss 0.667471
------------------------------------------------------------------------------------------


A Jupyter Widget

| start of epoch   2 | time: 261.48s | loss 0.658763 | accuracy 63.0724%(229523/363904) | bce_loss 0.668443


A Jupyter Widget

------------------------------------------------------------------------------------------
| end of epoch   2 | time: 0.00s | loss 0.6583 | accuracy 63.0513%(25503/40448 | bce_loss 0.667451
------------------------------------------------------------------------------------------


A Jupyter Widget

------------------------------------------------------------------------------------------
Exiting from training early | cost time: 13.06min
