In [None]:
import os
from datetime import datetime

import torch
from torch.multiprocessing import set_start_method
from tqdm import tqdm
from tqdm import trange

# NOTE: keep the local imports in this order — the star import below may rebind
# names, and `config` must end up bound to Config as in the original notebook.
from src.utils.data_util import DataHandlerCLSBERT, DataHandlerCLS
from src.train_valid_test_step import *
from config import Config as config
from src.classifiers.literal_idiom_classifier import LiteralIdiomaticClassifier

In [None]:
# Load model: build the classifier and restore the 'best' projection-layer checkpoint.
data_handler = DataHandlerCLS()
model = LiteralIdiomaticClassifier(data_handler.config)
save_path = config.PATH_TO_CHECKPOINT_CLF.format('best') + 'projection_layer.mdl'
# map_location lets a checkpoint written on one device (e.g. GPU) be restored on
# whatever device this session uses (e.g. a CPU-only machine).
checkpoint = torch.load(save_path, map_location=config.DEVICE)
print(checkpoint['epoch'])
# strict=False tolerates keys that differ between checkpoint and model; report the
# mismatch sizes so a checkpoint that silently loads nothing is visible.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
print('Missing keys: {}'.format(len(load_result.missing_keys)))
print('Unexpected keys: {}'.format(len(load_result.unexpected_keys)))
model.to(config.DEVICE)

In [4]:
# Echo which adapter / split configuration this evaluation run is using.
summary_lines = [
    'Adapter Name: {}'.format(config.ADAPTER_NAME),
    'Adapter Split: {}'.format(config.SPLIT),
    'Task Split: {}'.format(config.CLS_TYPE),
]
for line in summary_lines:
    print(line)

Adapter Name: fusion
Adapter Split: random
Task Split: random


In [5]:
# Run prediction on the evaluation split served by data_handler.validset_generator (next cell)

In [None]:
# Run the classifier over every batch of data_handler.validset_generator and collect,
# index-aligned per example: gold labels, argmax predictions, input ids and idioms.
model.eval()
bbar = tqdm(enumerate(data_handler.validset_generator),
            ncols=100, leave=False, total=data_handler.config.num_batch_valid)

labels, preds = [], []
inputs = []
mc_preds = []  # constant placeholder (1 per example); not used by the cells below
idioms = []
for idx, data in bbar:
    with torch.no_grad():
        # The model returns (loss, logits); only the logits are needed for prediction.
        _loss, logits = model(data)
    # Tensor.tolist() already copies to host and yields plain Python scalars.
    batch_labels = data['labels'].tolist()
    batch_preds = torch.argmax(logits, dim=-1).tolist()
    labels.extend(batch_labels)
    preds.extend(batch_preds)
    mc_preds.extend([1] * len(batch_preds))
    # One numpy row of input_ids per example (iterating the 2-D array yields its rows).
    inputs.extend(data['inputs']['input_ids'].cpu().numpy())
    idioms.extend(data['idioms'])

# Downstream per-idiom grouping indexes these lists together, so they must align.
assert len(labels) == len(preds) == len(idioms), (
    'misaligned outputs: {} labels / {} preds / {} idioms'.format(
        len(labels), len(preds), len(idioms)))

 12%|███████▏                                                      | 15/129 [00:03<00:23,  4.89it/s]

In [None]:
from collections import defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [None]:
# Group per-example predictions and gold labels by idiom, for per-idiom analysis.
# `idioms`, `preds` and `labels` come from the prediction loop and must be index-aligned;
# fail loudly instead of silently dropping or misattributing examples.
if not (len(idioms) == len(preds) == len(labels)):
    raise ValueError(
        'idioms/preds/labels are misaligned: {} / {} / {}'.format(
            len(idioms), len(preds), len(labels)))

idiom2pred = {}
for idiom, pred, truth in zip(idioms, preds, labels):
    entry = idiom2pred.setdefault(idiom, {'pred': [], 'truth': []})
    entry['pred'].append(pred)
    entry['truth'].append(truth)

In [None]:
# idiom2perf = {}

# for idiom in idiom2pred: 
#     preds, truths = idiom2pred[idiom]['pred'], idiom2pred[idiom]['truth']
#     acc = accuracy_score(truths, preds)
#     if idiom not in idiom2perf: 
#         idiom2perf[idiom] = acc
    

In [None]:
# Compute evaluation metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def compute_performance(y_true, y_pred, **metric_kwargs):
    """Compute accuracy, precision, recall and F1 for a set of predictions.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels, index-aligned with ``y_true``.
        **metric_kwargs: Optional keyword arguments forwarded to sklearn's
            ``precision_score``, ``recall_score`` and ``f1_score`` (for example
            ``average='macro'`` for multi-class labels, ``pos_label=0``, or
            ``zero_division=0``). When omitted, sklearn's defaults apply, which is
            binary averaging with positive label 1 — identical to the old behaviour.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall' and 'f1'.
    """
    # accuracy_score has no averaging/pos_label options, so kwargs are not forwarded to it.
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, **metric_kwargs)
    recall = recall_score(y_true, y_pred, **metric_kwargs)
    f1 = f1_score(y_true, y_pred, **metric_kwargs)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

In [None]:
# Evaluation result: headline metrics of the argmax predictions against the gold labels
# collected by the prediction loop.
eval_res = compute_performance(y_true=labels, y_pred=preds)
eval_res