In [3]:
from model.bert import bert_ATE
from data.dataset import dataset_ATM,remove_duplicates
from torch.utils.data import DataLoader, ConcatDataset
from transformers import BertTokenizer
import torch
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import time
import numpy as np
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [4]:
# Evaluation configuration: compute device, pretrained checkpoint name,
# tokenizer, model and optimizer.
# `torch.has_mps` is deprecated and only reports that this PyTorch build
# includes MPS support; `torch.backends.mps.is_available()` also checks that
# the current machine can actually use it, so `.to(DEVICE)` below cannot pick
# an unusable backend.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
pretrain_model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(pretrain_model_name)
lr = 2e-5
model_ATE = bert_ATE(pretrain_model_name).to(DEVICE)
optimizer_ATE = torch.optim.Adam(model_ATE.parameters(), lr=lr)

In [5]:
def load_model(model, path):
    """Load a saved state dict from ``path`` into ``model`` and return ``model``.

    Args:
        model: A ``torch.nn.Module`` whose parameters are overwritten in place.
        path: Location of a checkpoint written with ``torch.save(state_dict, ...)``.

    Returns:
        The same ``model`` object, with the matching weights loaded.

    Loading stays non-strict (``strict=False``), so key mismatches do not
    raise; they are reported with a ``UserWarning`` instead of being silently
    ignored.
    """
    import warnings

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    # map_location="cpu" lets a checkpoint saved on another accelerator
    # (CUDA/MPS) be read on any machine; load_state_dict then copies the values
    # into the model's existing parameters on whatever device they live on.
    state_dict = torch.load(path, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)

    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint {!r} did not match the model exactly: "
            "missing keys {}, unexpected keys {}".format(
                path,
                list(load_result.missing_keys),
                list(load_result.unexpected_keys),
            )
        )
    return model

In [6]:
# Read the three test-split CSV files and pass each frame through the
# project's `remove_duplicates` helper before use.
laptops_test, restaurants_test, twitter_test = (
    remove_duplicates(pd.read_csv(csv_path))
    for csv_path in (
        "data/laptops_test.csv",
        "data/restaurants_test.csv",
        "data/twitter_test.csv",
    )
)

In [7]:
# Wrap every test frame in a `dataset_ATM`, all sharing the same tokenizer.
laptops_test_ds, restaurants_test_ds, twitter_test_ds = [
    dataset_ATM(frame, tokenizer)
    for frame in (laptops_test, restaurants_test, twitter_test)
]

In [8]:
# Chain the three per-domain test datasets into one dataset
# (laptops first, then restaurants, then twitter).
test_ds = ConcatDataset(
    [
        laptops_test_ds,
        restaurants_test_ds,
        twitter_test_ds,
    ]
)

In [9]:
def create_mini_batch(samples):
    """Collate a list of samples into right-padded batch tensors.

    Each sample is indexed positionally: element 1 supplies the ids tensor,
    element 2 the tags tensor and element 3 the pols tensor (element 0 is not
    used here). Every field is padded with zeros to the longest sequence in
    the batch, batch dimension first.

    Returns:
        ``(ids_tensors, tags_tensors, pols_tensors, masks_tensors)`` where
        ``masks_tensors`` is a long tensor shaped like ``ids_tensors`` holding
        1 wherever the id is non-zero and 0 elsewhere.
    """

    def _pad_field(position):
        # Gather one positional field from every sample and zero-pad it.
        return pad_sequence([sample[position] for sample in samples], batch_first=True)

    ids_tensors = _pad_field(1)
    tags_tensors = _pad_field(2)
    pols_tensors = _pad_field(3)

    # Non-zero ids are marked 1; zero ids (including the padding) are marked 0.
    masks_tensors = (ids_tensors != 0).to(torch.long)

    return ids_tensors, tags_tensors, pols_tensors, masks_tensors

In [10]:
# Evaluation loader. Shuffling is disabled: it adds nothing at test time and
# makes batch composition (and therefore the amount of padding in each batch)
# differ from run to run, which hurts reproducibility of the evaluation.
test_loader = DataLoader(test_ds, batch_size=50, collate_fn=create_mini_batch, shuffle=False)

In [60]:
def test_model_ATE(loader):
    pred = []
    truth = []
    with torch.no_grad():
        for data in loader:

            ids_tensors, tags_tensors, _, masks_tensors = data
            ids_tensors = ids_tensors.to(DEVICE)
            tags_tensors = tags_tensors.to(DEVICE)
            masks_tensors = masks_tensors.to(DEVICE)
            outputs = model_ATE(ids_tensors=ids_tensors, tags_tensors=None, masks_tensors=masks_tensors)
            outputs = torch.softmax(outputs, dim=2)
            _, predictions = torch.max(outputs, dim=2)
            pred += list([int(j) for i in predictions for j in i ])
            truth += list([int(j) for i in tags_tensors for j in i ])
    return truth, pred


In [61]:
# Overwrite the freshly built model's weights with the checkpoint stored in
# 'bert_ATE.pkl' (path is relative to the notebook's working directory).
model_ATE = load_model(model=model_ATE, path='bert_ATE.pkl')

In [62]:
# Collect gold and predicted token tags over the combined test loader.
truth, pred = test_model_ATE(loader=test_loader)

In [63]:
# Per-class precision / recall / F1, with the three classes named "0", "1", "2".
print(
    classification_report(
        y_true=truth,
        y_pred=pred,
        target_names=list(map(str, range(3))),
    )
)

              precision    recall  f1-score   support

           0       0.99      1.00      0.99     84653
           1       0.90      0.85      0.88      3089
           2       0.89      0.78      0.83      1834

    accuracy                           0.99     89576
   macro avg       0.93      0.88      0.90     89576
weighted avg       0.99      0.99      0.99     89576

