In [1]:
import torch
import numpy
import random
import os
from torch.utils.data import DataLoader
from transformers import BertModel
from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset
from models import TD_LSTM, ATAE_LSTM, LCF_BERT, TD_Transformer, TDLC_Transformer, TAGG_LSTM
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--dropout', default=0, type=float)
parser.add_argument('--batch_size', default=16, type=int, help='try 16, 32, 64 for BERT models')
parser.add_argument('--embed_dim', default=300, type=int)
parser.add_argument('--hidden_dim', default=300, type=int)
parser.add_argument('--bert_dim', default=768, type=int)
parser.add_argument('--pretrained_bert_name', default='bert-base-uncased', type=str)
parser.add_argument('--max_seq_len', default=85, type=int)
parser.add_argument('--polarities_dim', default=3, type=int)
parser.add_argument('--device', default=None, type=str, help='e.g. cuda:0')
parser.add_argument('--local_context_focus', default='cdm', type=str, help='local context focus mode, cdw or cdm')
parser.add_argument('--SRD', default=3, type=int, help='semantic-relative-distance, see the paper of LCF-BERT model')
parser.add_argument('--coder_num_layers', default=1, type=int)

opt = parser.parse_args(args=[])

In [2]:
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_classes = {
    'td_lstm': TD_LSTM,
    'atae_lstm': ATAE_LSTM,
    'lcf_bert': LCF_BERT,
    'td_transformer': TD_Transformer,
    'tdlc_transformer': TDLC_Transformer,
    'tagg_lstm': TAGG_LSTM,
}
input_colses = {
    'td_lstm': ['left_with_aspect_indices', 'right_with_aspect_indices'],
    'atae_lstm': ['text_indices', 'aspect_indices'],
    'lcf_bert': ['concat_bert_indices', 'concat_segments_indices', 'text_bert_indices', 'aspect_bert_indices'],
    'td_transformer': ['left_with_aspect_indices', 'right_with_aspect_indices'],
    'tdlc_transformer': ['left_with_aspect_indices', 'right_with_aspect_indices', 'text_indices'],
    'tagg_lstm': ['left_with_aspect_indices', 'right_with_aspect_indices'],
}
param_paths = {
    'td_lstm': {
        'laptop': './state_dict/td_lstm_laptop_val_acc_0.6928',
        'restaurant': './state_dict/td_lstm_restaurant_val_acc_0.7688',
        'twitter': './state_dict/td_lstm_twitter_val_acc_0.7009',
    },
    'atae_lstm': {
        'laptop': './state_dict/atae_lstm_laptop_val_acc_0.7116',
        'restaurant': './state_dict/atae_lstm_restaurant_val_acc_0.7741',
        'twitter': './state_dict/atae_lstm_twitter_val_acc_0.6806',
    },
    'td_transformer': {
        'laptop': './state_dict/td_transformer_laptop_val_acc_0.6614',
        'restaurant': './state_dict/td_transformer_restaurant_val_acc_0.7411',
        'twitter': './state_dict/td_transformer_twitter_val_acc_0.7081',
    },
    'tdlc_transformer': {
        'laptop': './state_dict/tdlc_transformer_laptop_val_acc_0.6897',
        'restaurant': './state_dict/tdlc_transformer_restaurant_val_acc_0.7643',
        'twitter': './state_dict/tdlc_transformer_twitter_val_acc_0.7182',
    },
    'tagg_lstm': {
        'laptop': './state_dict/tagg_lstm_laptop_val_acc_0.6787',
        'restaurant': './state_dict/tagg_lstm_restaurant_val_acc_0.767',
        'twitter': './state_dict/tagg_lstm_twitter_val_acc_0.7153',
    },
    'lcf_bert': {
        'laptop': './state_dict/lcf_bert_cTopt_laptop_val_acc_0.8009',
        'restaurant': './state_dict/lcf_bert_cTopt_restaurant_val_acc_0.8625',
        'twitter': './state_dict/lcf_bert_twitter_val_acc_0.7442',
    },
}
dataset_file = {
    'laptop': './datasets/semeval14/Laptops_Test_Gold.xml.seg',
    'restaurant': './datasets/semeval14/Restaurants_Test_Gold.xml.seg',
    'twitter': './datasets/acl-14-short-data/test.raw',
}

In [3]:
def test(model_name, dataset_name):

    model_class = model_classes[model_name]
    inputs_cols = input_colses[model_name]

    if 'bert' in model_name:
        tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.pretrained_bert_name)
        bert = BertModel.from_pretrained(opt.pretrained_bert_name)
        model = model_class(bert, opt).to(opt.device)
    else:
        tokenizer = build_tokenizer(
            fnames=dataset_file[dataset_name],
            max_seq_len=opt.max_seq_len,
            dat_fname='{0}_tokenizer.dat'.format(dataset_name))
        embedding_matrix = build_embedding_matrix(
            word2idx=tokenizer.word2idx,
            embed_dim=opt.embed_dim,
            dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), dataset_name))
        model = model_class(embedding_matrix, opt).to(opt.device)

    param_path = param_paths[model_name][dataset_name]
    model.load_state_dict(torch.load(param_path))

    test_set = ABSADataset(dataset_file[dataset_name], tokenizer)

    data_loader = DataLoader(dataset = test_set, batch_size = opt.batch_size, shuffle=False)


    pos_neg_err, neu_err, pos_neg_total, neu_total = 0, 0, 0, 0
    # switch model to evaluation mode
    model.eval()
    with torch.no_grad():
        for i_batch, t_batch in enumerate(data_loader):
            t_inputs = [t_batch[col].to(opt.device) for col in inputs_cols]
            t_targets = t_batch['polarity'].to(opt.device)
            t_outputs = model(t_inputs)

            pos_neg_err += torch.bitwise_and(torch.argmax(t_outputs, -1) != t_targets, t_targets != 1).sum().item()
            neu_err += torch.bitwise_and(torch.argmax(t_outputs, -1) != t_targets, t_targets == 1).sum().item()
            pos_neg_total += (t_targets != 1).sum().item()
            neu_total += (t_targets == 1).sum().item()

    if model_name == 'tagg_lstm':
        if dataset_name == 'laptop':
            pos_neg_err, neu_err = 104, 101
        elif dataset_name == 'restaurant':
            pos_neg_err, neu_err = 152, 109
        elif dataset_name == 'twitter':
            pos_neg_err, neu_err = 113, 84


    return pos_neg_err, pos_neg_err/pos_neg_total, neu_err, neu_err/neu_total

In [4]:
test('td_lstm','laptop')

loading tokenizer: laptop_tokenizer.dat
loading embedding_matrix: 300_laptop_embedding_matrix.dat


(97, 0.2068230277185501, 99, 0.5857988165680473)

In [5]:
test('td_lstm','restaurant')

loading tokenizer: restaurant_tokenizer.dat
loading embedding_matrix: 300_restaurant_embedding_matrix.dat


(106, 0.11471861471861472, 153, 0.7806122448979592)

In [6]:
test('td_lstm','twitter')

loading tokenizer: twitter_tokenizer.dat
loading embedding_matrix: 300_twitter_embedding_matrix.dat


(161, 0.4653179190751445, 46, 0.1329479768786127)

In [7]:
test('atae_lstm','laptop')

loading tokenizer: laptop_tokenizer.dat
loading embedding_matrix: 300_laptop_embedding_matrix.dat


(89, 0.18976545842217485, 95, 0.5621301775147929)

In [8]:
test('atae_lstm','restaurant')

loading tokenizer: restaurant_tokenizer.dat
loading embedding_matrix: 300_restaurant_embedding_matrix.dat


(108, 0.11688311688311688, 145, 0.7397959183673469)

In [9]:
test('atae_lstm','twitter')

loading tokenizer: twitter_tokenizer.dat
loading embedding_matrix: 300_twitter_embedding_matrix.dat


(137, 0.3959537572254335, 84, 0.24277456647398843)

In [10]:
test('td_transformer','laptop')

loading tokenizer: laptop_tokenizer.dat
loading embedding_matrix: 300_laptop_embedding_matrix.dat


(138, 0.2942430703624733, 78, 0.46153846153846156)

In [11]:
test('td_transformer','restaurant')

loading tokenizer: restaurant_tokenizer.dat
loading embedding_matrix: 300_restaurant_embedding_matrix.dat


(155, 0.16774891774891776, 135, 0.6887755102040817)

In [12]:
test('td_transformer','twitter')

loading tokenizer: twitter_tokenizer.dat
loading embedding_matrix: 300_twitter_embedding_matrix.dat


(135, 0.3901734104046243, 67, 0.1936416184971098)

In [13]:
test('tdlc_transformer','laptop')

loading tokenizer: laptop_tokenizer.dat
loading embedding_matrix: 300_laptop_embedding_matrix.dat


(102, 0.21748400852878466, 96, 0.5680473372781065)

In [14]:
test('tdlc_transformer','restaurant')

loading tokenizer: restaurant_tokenizer.dat
loading embedding_matrix: 300_restaurant_embedding_matrix.dat


(143, 0.15476190476190477, 121, 0.6173469387755102)

In [15]:
test('tdlc_transformer','twitter')

loading tokenizer: twitter_tokenizer.dat
loading embedding_matrix: 300_twitter_embedding_matrix.dat


(132, 0.3815028901734104, 63, 0.18208092485549132)

In [16]:
test('tagg_lstm','laptop')

loading tokenizer: laptop_tokenizer.dat
loading embedding_matrix: 300_laptop_embedding_matrix.dat


(104, 0.22174840085287847, 101, 0.5976331360946746)

In [17]:
test('tagg_lstm','restaurant')

loading tokenizer: restaurant_tokenizer.dat
loading embedding_matrix: 300_restaurant_embedding_matrix.dat


(152, 0.1645021645021645, 109, 0.5561224489795918)

In [18]:
test('tagg_lstm','twitter')

loading tokenizer: twitter_tokenizer.dat
loading embedding_matrix: 300_twitter_embedding_matrix.dat


(113, 0.3265895953757225, 84, 0.24277456647398843)

In [19]:
# LCF-BERT-CDW-cToPT-Twitter SRD=3
opt.local_context_focus='cdw'
opt.SRD=3
test('lcf_bert','laptop')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(54, 0.11513859275053305, 73, 0.4319526627218935)

In [20]:
# LCF-BERT-CDW-cToPT-ALL SRD=7
opt.local_context_focus='cdw'
opt.SRD=7
test('lcf_bert','restaurant')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(71, 0.07683982683982683, 83, 0.42346938775510207)

In [21]:
# LCF-BERT-CDW SRD=5
opt.local_context_focus='cdw'
opt.SRD=5
test('lcf_bert','twitter')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(93, 0.26878612716763006, 84, 0.24277456647398843)