In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import logger
from data_utils import *
from model import SDEN
from sklearn_crfsuite import metrics
from sklearn.metrics import f1_score
import argparse

  from ._conv import register_converters as _register_converters


In [7]:
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    argparse's ``type=bool`` is a well-known trap: it calls ``bool('False')``,
    which is True for any non-empty string, so ``--slm False`` would silently
    keep SLM training enabled. This accepts the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# DONOTCHANGE: They are reserved for nsml
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--pause', type=int, default=0)
parser.add_argument('--iteration', type=str, default='0')
parser.add_argument('--epochs', type=int, default=5,
                    help='num_epochs')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning_rate')
parser.add_argument('--dropout', type=float, default=0.3,
                    help='dropout')
parser.add_argument('--embed_size', type=int, default=100,
                    help='embed_size')
parser.add_argument('--hidden_size', type=int, default=64,
                    help='hidden_size')
parser.add_argument('--save_path', type=str, default='weight/model.pkl',
                    help='save_path')
parser.add_argument('--model', type=str, default='sden',
                    help='seq2seq, memory, sden' )
# Parsed with _str2bool rather than bool so that "--slm False" really disables it.
parser.add_argument('--slm', type=_str2bool, default=True,
                    help='whether sentence level language model training or not')
parser.add_argument('--tensorboard',type=str, default='logs',
                    help='path for logs')
# Notebook usage: ignore sys.argv (in a Jupyter kernel it holds the kernel
# launcher's own arguments) and take every option's default.
config = parser.parse_args(args=[])



In [8]:
# Load the training split; it also returns the word/slot/intent index maps.
# The dev split is encoded with those same maps (passed in as a tuple) so the
# indices agree across splits.
train_data, train_slm_data, word2index, slot2index, intent2index = prepare_dataset('data/train.iob', slm=config.slm)
dev_data, dev_slm_data = prepare_dataset('data/dev.iob', (word2index, slot2index, intent2index), slm=config.slm)

# Only the SDEN architecture is implemented in this notebook. Fail fast on any
# other --model value instead of leaving `model` undefined for later cells.
if config.model == 'sden':
    model = SDEN(len(word2index), config.embed_size, config.hidden_size,
                 len(slot2index), len(intent2index), word2index['<pad>'])
else:
    raise ValueError("Unsupported --model {!r}: only 'sden' is implemented".format(config.model))

100%|██████████| 43474/43474 [00:07<00:00, 5604.90it/s]
100%|██████████| 87887/87887 [00:32<00:00, 2709.53it/s]
100%|██████████| 4367/4367 [00:00<00:00, 4667.46it/s]
100%|██████████| 8793/8793 [00:01<00:00, 4649.02it/s]


In [10]:
# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the parameters to the chosen device, then attach the index maps so that
# later cells can read them back from the model (model.vocab, model.slot_vocab).
model.to(device)
for attr_name, index_map in (('vocab', word2index),
                             ('slot_vocab', slot2index),
                             ('intent_vocab', intent2index)):
    setattr(model, attr_name, index_map)

# Logger writing under the --tensorboard path (presumably TensorBoard event
# files -- see the project's `logger` module).
log = logger.Logger(config.tensorboard)


In [12]:
# *_1: SLU data (slot tags + intent); *_2: sentence-level LM (SLM) data.
# The aliases are kept because later cells may refer to them.
train_data_1, train_data_2 = train_data, train_slm_data
dev_data_1, dev_data_2 = dev_data, dev_slm_data

# Set to a positive int (e.g. 1) for a quick smoke run; None trains on every batch.
max_batches_per_epoch = None

slm_loss = nn.CrossEntropyLoss()
# Slot index 0 is excluded from the loss (presumably the padding tag -- confirm in data_utils).
slot_loss_function = nn.CrossEntropyLoss(ignore_index=0)
intent_loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr)
# Decay the LR by 10x at 1/4 and 1/2 of training. For epochs < 4, epochs // 4
# is 0, which is not a meaningful "decay after N epochs" point, so only
# positive milestones are kept.
milestones = [m for m in (config.epochs // 4, config.epochs // 2) if m > 0]
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

for epoch in range(config.epochs):
    # Re-enter training mode each epoch in case an evaluation step switched
    # the model to eval mode.
    model.train()
    losses_slu = []
    losses_slm = []
    losses_all = []
    # zip() stops at the shorter loader, so each epoch consumes at most as many
    # SLM batches as there are SLU batches. The third data_loader argument is
    # presumably "shuffle" -- confirm in data_utils.
    for i, (batch_1, batch_2) in enumerate(zip(data_loader(train_data_1, config.batch_size, True),
                                               data_loader(train_data_2, config.batch_size, True))):
        # SLU batch: `h` is a list of tensors (presumably the dialogue history)
        # and `c` a single tensor (presumably the current utterance).
        h, c, slot, intent = pad_to_batch(batch_1, model.vocab, model.slot_vocab)
        h = [hh.to(device) for hh in h]
        c = c.to(device)
        slot = slot.to(device)
        intent = intent.to(device)

        # SLM batch: context tensors, candidate tensors and their labels.
        slm_h, slm_candi, slm_label = pad_to_batch_slm(batch_2, model.vocab)
        slm_h = [hh.to(device) for hh in slm_h]
        slm_candi = [hh.to(device) for hh in slm_candi]
        slm_label = slm_label.to(device)

        model.zero_grad()
        slot_p, intent_p = model(h, c)
        # The SLM head's output is reshaped into 2-class logits for CrossEntropyLoss.
        slm_p = model(slm_h, slm_candi, slm=True).view(-1, 2)

        loss_s = slot_loss_function(slot_p, slot.view(-1))
        loss_i = intent_loss_function(intent_p, intent.view(-1))
        loss_slm = slm_loss(slm_p, slm_label.view(-1))
        # Joint objective: unweighted sum of the slot, intent and SLM losses.
        loss = loss_s + loss_i + loss_slm
        losses_slm.append(loss_slm.item())
        losses_slu.append((loss_s + loss_i).item())
        losses_all.append(loss.item())
        loss.backward()
        optimizer.step()

        if max_batches_per_epoch is not None and i + 1 >= max_batches_per_epoch:
            break

    # Since PyTorch 1.1 the LR scheduler must step after optimizer.step(),
    # once per epoch. Stepping first skipped the base LR entirely.
    scheduler.step()
