In [1]:
import ipdb
import sys
sys.path.append('../ATAE-LSTM')
import Ipynb_importer
from data.Embedding import Emb
from data.AspClas import AspClas
from models.ATAE_LSTM import ATAE_LSTM
from utils.visualize import Visualizer
from config import opt
from tqdm import tqdm
from random import randint
import numpy as np

importing Jupyter notebook from /home/wenger/zhangjf/ATAE-LSTM/data/Embedding.ipynb
importing Jupyter notebook from /home/wenger/zhangjf/ATAE-LSTM/data/AspClas.ipynb
importing Jupyter notebook from /home/wenger/zhangjf/ATAE-LSTM/models/ATAE_LSTM.ipynb
importing Jupyter notebook from /home/wenger/zhangjf/ATAE-LSTM/models/BasicModule.ipynb


In [2]:
import torch as t
from torch.utils.data import DataLoader
from torchnet import meter
from torch.autograd import Variable

In [3]:
def val(model, dataloader):
    '''
    计算模型在验证集上的准确率等信息
    '''
    
    confusion_matrix = meter.ConfusionMeter(opt.classes)
    model.eval()
    with t.no_grad():
        for step, data in enumerate(dataloader):
            sentence, terms, label = data
            if opt.use_cuda:
                sentence, terms, label = sentence.cuda(), terms.cuda(), label.cuda()
            score = model(sentence, terms)
            confusion_matrix.add(score.data.cpu(), label.data.cpu().squeeze())
    model.train()
    cm_value = confusion_matrix.value()
    accuracy = 100.0 * (cm_value.trace()) / (cm_value.sum())
    class_equal_accuracy = (float(cm_value[0][0])/(cm_value[0].sum())
        +float(cm_value[1][1])/(cm_value[1].sum())
        +float(cm_value[2][2])/(cm_value[2].sum())
    ) * 100 / 3
    return confusion_matrix, accuracy, class_equal_accuracy

In [4]:
#! python -m visdom.server
vis = Visualizer(opt.env)

Setting up a new session...


In [5]:
# step1 data
train_data = AspClas(opt.train_data_root, train=True)
test_data = AspClas(opt.train_data_root, train=False, emb=train_data.emb)
train_dataloader = DataLoader(
    train_data,
    opt.batch_size,
    shuffle=True,
    drop_last = True
)
test_dataloader = DataLoader(
    test_data,
    opt.batch_size,
    shuffle=False,
    drop_last = True
)
words = train_data.emb._get_words_()

100%|██████████| 3044/3044 [00:00<00:00, 467039.59it/s]
100%|██████████| 100000/100000 [00:00<00:00, 274403.31it/s]
100%|██████████| 100000/100000 [00:06<00:00, 14510.64it/s]


Embedding : successfully input 100000 pretrained word embeddings while 0 failed


100%|██████████| 3044/3044 [00:00<00:00, 687753.79it/s]


In [6]:
# step2 configure model
model = ATAE_LSTM(emb=train_data.emb)
if opt.use_cuda:
    model = model.cuda()

In [7]:
# step3 criterion and optimizer
if opt.rescaling:
    class_weights = t.Tensor([3,4,1]) # weights on every class
    if opt.use_cuda:
        class_weights = class_weights.cuda()
else:
    class_weights = None
criterion = t.nn.CrossEntropyLoss(weight = class_weights)
lr = opt.lr
optimizer = t.optim.Adam(
    model.parameters(),
    lr = lr,
    weight_decay = opt.weight_decay
)

In [8]:
# step4 meters
loss_meter = meter.AverageValueMeter()
confusion_matrix = meter.ConfusionMeter(opt.classes)
previous_loss = 1e100
best_val_accuracy = 0

In [9]:
def print_attention(model, words, test_dataloader=None, sentence_terms_label=None):
    if test_dataloader is not None:
        sentence, terms, label = list(test_dataloader)[0]
        if opt.use_cuda:
            sentence, terms, label = sentence.cuda(), terms.cuda(), label.cuda()
    else:
        (sentence, terms, label) = sentence_terms_label
    tokens = [words[i] for i in sentence[0].tolist() if i!=0]
    term = [words[i] for i in terms[0].tolist() if i!=0]
    score, attention = model(sentence, terms, returnAttention=True)
    attention_probs = attention[0][0][:len(tokens)].tolist()
    tokens_attention = [(tokens[i], "%.3f"%attention_probs[i]) for i in range(len(tokens))]
    tqdm.write(str(tokens))
    tqdm.write(str(term))
    tqdm.write(str(tokens_attention))

In [10]:
# step5 train

# validate and visualize at start
val_cm, val_accuracy, class_equal_accuracy = val(model, test_dataloader)
vis.plot('val_accuracy', val_accuracy)
vis.plot('val_class_equal_accuracy', class_equal_accuracy)
vis.plot('lr', lr)
vis.log("epoch:{epoch},\nlr:{lr},\ntrain_cm:{train_cm},\nval_cm:{val_cm}".format(
    epoch = 0,
    val_cm = str(val_cm.value()),
    train_cm=str(confusion_matrix.value()),
    lr=lr
))

total_step = 0
for epoch in tqdm(range(opt.max_epoch)):
    loss_meter.reset()
    confusion_matrix.reset()
    
    for step, (sentence, terms, label) in enumerate(train_dataloader):
        
        if opt.use_cuda:
            sentence, terms, label = sentence.cuda(), terms.cuda(), label.cuda()
        
        score = model(sentence, terms)
        loss = criterion(score, label.squeeze())
        loss.backward()
        optimizer.step()
        
        # meters update and visualize
        if opt.use_cuda:
            loss_meter.add(loss.data.cpu())
            confusion_matrix.add(score.data.cpu(), label.data.cpu().squeeze())
        else:
            loss_meter.add(loss.data)
            confusion_matrix.add(score.data, label.data.squeeze())
        if total_step%opt.print_freq == 0:
            vis.plot('loss', loss_meter.value()[0])
            """vis.log("score:{score},target:{label}".format(
                score = score,
                label = label
            ))"""
        total_step += 1
        
    
    # validate and visualize
    val_cm, val_accuracy, class_equal_accuracy = val(model, test_dataloader)
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        model.save(acc="%.2f"%val_accuracy)
    
    vis.plot('train_accuracy', 100.0*confusion_matrix.value().trace()/confusion_matrix.value().sum())
    vis.plot('val_accuracy', val_accuracy)
    vis.plot('val_class_equal_accuracy', class_equal_accuracy)
    vis.plot('lr', lr)
    vis.uplog("epoch:{},\nlr:{},\nval_cm:\n{}\n".format(
        epoch,lr,str(val_cm.value())
    ).replace("\n", "<br>"))
    
    # update learning rate
    if loss_meter.value()[0].item() >= previous_loss and lr>opt.lr_min:
        lr = lr * opt.lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    previous_loss = loss_meter.value()[0]
    if epoch-1 % 100 == 0:
        opt.lr_min *= 2/3
    """
    lr = opt.lr * (np.cos(np.pi*(epoch/opt.max_epoch))+1)/2
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    """
    if epoch%50 == 0:
        print_attention(model, words, test_dataloader=test_dataloader)
        print_attention(model, words, sentence_terms_label = (sentence, terms, label))

  0%|          | 1/500 [00:03<30:55,  3.72s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.035'), ('ambience', '0.032'), ('is', '0.060'), ('very', '0.064'), ('romantic', '0.049'), ('and', '0.091'), ('definitely', '0.059'), ('a', '0.079'), ('good', '0.078'), ('place', '0.075'), ('to', '0.090'), ('bring', '0.072'), ('a', '0.084'), ('date', '0.069'), ('.', '0.062')]
['i', 'particularly', 'love', 'their', 'yellowfun', 'tuna', 'and', 'their', 'mussel', 'selection', '.']
['mussel', 'selection']
[('i', '0.048'), ('particularly', '0.051'), ('love', '0.084'), ('their', '0.115'), ('yellowfun', '0.085'), ('tuna', '0.058'), ('and', '0.125'), ('their', '0.137'), ('mussel', '0.076'), ('selection', '0.115'), ('.', '0.106')]


 10%|█         | 51/500 [03:06<28:21,  3.79s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.137'), ('ambience', '0.124'), ('is', '0.074'), ('very', '0.060'), ('romantic', '0.070'), ('and', '0.059'), ('definitely', '0.057'), ('a', '0.060'), ('good', '0.053'), ('place', '0.051'), ('to', '0.049'), ('bring', '0.056'), ('a', '0.060'), ('date', '0.044'), ('.', '0.047')]
['i', 'have', 'never', 'been', 'disappointed', 'but', 'their', 'true', 'strength', 'lays', 'in', 'their', 'amazingly', 'delicious', 'and', 'cheap', 'lunch', 'specials', '.']
['lunch', 'specials']
[('i', '0.064'), ('have', '0.045'), ('never', '0.041'), ('been', '0.037'), ('disappointed', '0.052'), ('but', '0.047'), ('their', '0.048'), ('true', '0.049'), ('strength', '0.059'), ('lays', '0.071'), ('in', '0.057'), ('their', '0.052'), ('amazingly', '0.059'), ('delicious', '0.049'), ('and', '0.052'), ('cheap', '0.060'), ('lunch', '0.050'), ('specials', '0.056'), ('.', '0.052')]


 20%|██        | 101/500 [06:13<25:09,  3.78s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.057'), ('ambience', '0.069'), ('is', '0.064'), ('very', '0.066'), ('romantic', '0.095'), ('and', '0.097'), ('definitely', '0.069'), ('a', '0.056'), ('good', '0.063'), ('place', '0.058'), ('to', '0.065'), ('bring', '0.065'), ('a', '0.055'), ('date', '0.060'), ('.', '0.063')]
['my', 'goodness', ',', 'everything', 'from', 'the', 'fish', 'to', 'the', 'rice', 'to', 'the', 'seaweed', 'was', 'absolutely', 'amazing', '.']
['fish']
[('my', '0.060'), ('goodness', '0.059'), (',', '0.060'), ('everything', '0.058'), ('from', '0.062'), ('the', '0.054'), ('fish', '0.058'), ('to', '0.062'), ('the', '0.053'), ('rice', '0.060'), ('to', '0.063'), ('the', '0.054'), ('seaweed', '0.064'), ('was', '0.062'), ('absolutely', '0.057'), ('amazing', '0.058'), ('.', '0.058')]


 30%|███       | 151/500 [09:21<22:00,  3.78s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.011'), ('ambience', '0.020'), ('is', '0.020'), ('very', '0.119'), ('romantic', '0.174'), ('and', '0.092'), ('definitely', '0.088'), ('a', '0.056'), ('good', '0.212'), ('place', '0.088'), ('to', '0.027'), ('bring', '0.031'), ('a', '0.032'), ('date', '0.014'), ('.', '0.017')]
['so', 'some', 'of', 'the', 'reviews', 'here', 'are', 'accurate', 'about', 'the', 'crowd', 'and', 'noise', '.']
['noise']
[('so', '0.106'), ('some', '0.060'), ('of', '0.023'), ('the', '0.056'), ('reviews', '0.063'), ('here', '0.047'), ('are', '0.042'), ('accurate', '0.135'), ('about', '0.061'), ('the', '0.084'), ('crowd', '0.110'), ('and', '0.044'), ('noise', '0.078'), ('.', '0.091')]


 40%|████      | 201/500 [12:29<18:55,  3.80s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.007'), ('ambience', '0.067'), ('is', '0.009'), ('very', '0.135'), ('romantic', '0.241'), ('and', '0.073'), ('definitely', '0.087'), ('a', '0.045'), ('good', '0.232'), ('place', '0.071'), ('to', '0.005'), ('bring', '0.006'), ('a', '0.011'), ('date', '0.003'), ('.', '0.008')]
['the', 'rice', 'to', 'fish', 'ration', 'was', 'also', 'good', '--', 'they', 'did', "n't", 'try', 'to', '<UNKNOWN>', 'the', 'rice', '.']
['rice', 'to', 'fish', 'ration']
[('the', '0.056'), ('rice', '0.043'), ('to', '0.013'), ('fish', '0.021'), ('ration', '0.047'), ('was', '0.072'), ('also', '0.039'), ('good', '0.145'), ('--', '0.080'), ('they', '0.065'), ('did', '0.089'), ("n't", '0.066'), ('try', '0.053'), ('to', '0.016'), ('<UNKNOWN>', '0.062'), ('the', '0.053'), ('rice', '0.042'), ('.', '0.038')]


 50%|█████     | 251/500 [15:33<15:31,  3.74s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.005'), ('ambience', '0.064'), ('is', '0.006'), ('very', '0.068'), ('romantic', '0.073'), ('and', '0.069'), ('definitely', '0.078'), ('a', '0.029'), ('good', '0.412'), ('place', '0.161'), ('to', '0.009'), ('bring', '0.011'), ('a', '0.008'), ('date', '0.003'), ('.', '0.005')]
['however', ',', 'their', 'popularity', 'has', 'yet', 'to', 'slow', 'down', ',', 'and', 'i', 'still', 'find', 'myself', 'drawn', 'to', 'their', 'ambiance', 'and', 'delectable', 'reputation', '.']
['ambiance']
[('however', '0.013'), (',', '0.010'), ('their', '0.008'), ('popularity', '0.012'), ('has', '0.015'), ('yet', '0.014'), ('to', '0.015'), ('slow', '0.007'), ('down', '0.004'), (',', '0.005'), ('and', '0.020'), ('i', '0.053'), ('still', '0.021'), ('find', '0.040'), ('myself', '0.016'), ('drawn', '0.033'), ('to', '0.024'), ('their', '0.012'), ('ambiance', '0.022'), ('and',

 60%|██████    | 301/500 [18:40<12:30,  3.77s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.010'), ('ambience', '0.060'), ('is', '0.008'), ('very', '0.064'), ('romantic', '0.079'), ('and', '0.092'), ('definitely', '0.134'), ('a', '0.048'), ('good', '0.327'), ('place', '0.105'), ('to', '0.017'), ('bring', '0.022'), ('a', '0.014'), ('date', '0.006'), ('.', '0.012')]
['to', 'begin', ',', 'we', 'were', 'told', 'there', 'was', 'a', '30', 'minute', 'wait', 'and', 'started', 'to', 'leave', ',', 'when', 'the', 'hostess', 'offered', 'to', 'call', 'us', 'on', 'our', 'cell', 'phone', 'when', 'the', 'table', 'was', 'ready', '.']
['hostess']
[('to', '0.011'), ('begin', '0.028'), (',', '0.015'), ('we', '0.016'), ('were', '0.061'), ('told', '0.096'), ('there', '0.047'), ('was', '0.086'), ('a', '0.036'), ('30', '0.019'), ('minute', '0.027'), ('wait', '0.040'), ('and', '0.021'), ('started', '0.081'), ('to', '0.030'), ('leave', '0.062'), (',', '0.020')

 70%|███████   | 351/500 [21:42<08:43,  3.51s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.013'), ('ambience', '0.058'), ('is', '0.008'), ('very', '0.051'), ('romantic', '0.111'), ('and', '0.143'), ('definitely', '0.164'), ('a', '0.037'), ('good', '0.256'), ('place', '0.075'), ('to', '0.016'), ('bring', '0.027'), ('a', '0.011'), ('date', '0.008'), ('.', '0.019')]
['we', 'have', 'never', 'had', 'any', 'problems', 'with', 'charging', 'the', 'meal', 'or', 'the', 'tip', ',', 'and', 'the', 'food', 'was', 'delivered', 'quickly', ',', 'but', 'we', 'live', 'only', 'a', 'few', 'minutes', 'walk', 'from', 'them', '.']
['food']
[('we', '0.004'), ('have', '0.001'), ('never', '0.003'), ('had', '0.002'), ('any', '0.000'), ('problems', '0.001'), ('with', '0.001'), ('charging', '0.026'), ('the', '0.004'), ('meal', '0.019'), ('or', '0.001'), ('the', '0.002'), ('tip', '0.007'), (',', '0.004'), ('and', '0.003'), ('the', '0.006'), ('food', '0.019'), ('wa

 80%|████████  | 401/500 [24:45<06:18,  3.83s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.014'), ('ambience', '0.086'), ('is', '0.007'), ('very', '0.035'), ('romantic', '0.180'), ('and', '0.090'), ('definitely', '0.199'), ('a', '0.043'), ('good', '0.231'), ('place', '0.046'), ('to', '0.009'), ('bring', '0.018'), ('a', '0.012'), ('date', '0.007'), ('.', '0.024')]
['service', 'was', 'very', 'good', 'and', 'warm', '.']
['service']
[('service', '0.389'), ('was', '0.123'), ('very', '0.085'), ('good', '0.012'), ('and', '0.027'), ('warm', '0.306'), ('.', '0.059')]


 90%|█████████ | 451/500 [27:54<03:08,  3.84s/it]

['the', 'ambience', 'is', 'very', 'romantic', 'and', 'definitely', 'a', 'good', 'place', 'to', 'bring', 'a', 'date', '.']
['place']
[('the', '0.015'), ('ambience', '0.121'), ('is', '0.007'), ('very', '0.034'), ('romantic', '0.160'), ('and', '0.057'), ('definitely', '0.196'), ('a', '0.052'), ('good', '0.262'), ('place', '0.027'), ('to', '0.007'), ('bring', '0.014'), ('a', '0.015'), ('date', '0.008'), ('.', '0.026')]
['try', 'the', 'crunchy', 'tuna', ',', 'it', 'is', 'to', 'die', 'for', '.']
['crunchy', 'tuna']
[('try', '0.064'), ('the', '0.050'), ('crunchy', '0.389'), ('tuna', '0.424'), (',', '0.014'), ('it', '0.030'), ('is', '0.006'), ('to', '0.001'), ('die', '0.017'), ('for', '0.001'), ('.', '0.003')]


100%|██████████| 500/500 [30:59<00:00,  3.78s/it]
