In [None]:
import os
from datetime import datetime
from tqdm import trange
from src_idiom_detect.utils.data_util import *
from src.train_valid_test_step import *
from config import Config as config
from torch.multiprocessing import set_start_method
from src_idiom_detect.model.bilstm import Seq2SeqBiLSTMLite as Seq2SeqMdl
from src_idiom_detect.utils.model_util import *
from src_idiom_detect.utils.eval_util import *


In [2]:
def read_json_lines(path_to_file):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path_to_file: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The `with` block closes the file on exit; no explicit close() is needed.
    with open(path_to_file) as f:
        # Blank lines (e.g. an empty line at the end of the file) are skipped
        # instead of being handed to json.loads, which would reject them.
        return [json.loads(line) for line in f if line.strip()]

def read_json_file(path):
    """Load and return the single JSON document stored in the file at `path`."""
    with open(path, 'r') as handle:
        parsed = json.load(handle)
    return parsed
    
def write_json_file(path, data):
    """Serialize `data` as JSON into the file at `path`, replacing its contents.

    Returns None.
    """
    with open(path, 'w') as sink:
        json.dump(data, sink)

In [None]:
# Load model: read the checkpoint path from the config, build the data handler,
# then let load_init_det_model hand back the model, its optimizer and the
# epoch to start from.
save_path = config.PATH_TO_CHECKPOINT_DET
data_handler = DataHandler()
(model, optimizer, epoch_start) = load_init_det_model(Seq2SeqMdl, data_handler.config)

In [4]:
# Show which adapter / split configuration the model under evaluation uses.
run_info = [
    'Adapter Name: {}'.format(config.ADAPTER_NAME),
    'Adapter Split: {}'.format(config.SPLIT),
    'Task Split: {}'.format(config.DET_TYPE),
]
print('\n'.join(run_info))

Adapter Name: fusion
Adapter Split: random
Task Split: random


In [5]:
# Load the idiom list used later to pair each evaluated sequence with an idiom.
idioms_path = './fusion_analysis/idioms.json'
idioms = read_json_file(idioms_path)
len(idioms)

4102

In [6]:
# Run prediction on the evaluation split served by data_handler.validset_generator

In [7]:
# Tag id whose positions are kept when decoding the predicted / gold text spans.
# NOTE(review): presumably the idiom tag in the label vocabulary; confirm.
TARGET_TAG_ID = 4


def _decode_tagged(token_ids, tags, tag_id=TARGET_TAG_ID):
    """Decode the tokens of one sequence whose tag equals `tag_id`."""
    kept = [token_ids[ti] for ti, t in enumerate(tags) if t == tag_id]
    return data_handler.tokenizer.decode(kept)


model.eval()
labels, preds = [], []
inputs = []
labels_text, preds_text = [], []
bbar = tqdm(enumerate(data_handler.validset_generator), ncols=100, leave=False, total=data_handler.config.num_batch_valid)
for idx, data in bbar:
    # Release unused cached GPU memory before processing the next batch.
    torch.cuda.empty_cache()

    # Inference only: run the forward pass without building an autograd graph.
    with torch.no_grad():
        logits, _ = model(data['xs'], data['x_lens'], data['ys'], training=False)

    # For the 'bilstm' detector the first target position is dropped before
    # scoring. A local variable is used so the batch dict itself is not modified.
    gold = data['ys']
    if data_handler.config.DETECT_MODEL_TYPE == 'bilstm':
        gold = gold[:, 1:]

    # eval results
    xs = list(data['xs']['input_ids'].cpu().detach().numpy())  # batch_size, max_xs_seq_len
    ys = list(gold.cpu().detach().numpy())  # batch_size, max_ys_seq_len
    ys_ = list(torch.argmax(logits, dim=2).cpu().detach().numpy())  # batch_size, max_ys_seq_len
    xs, ys, ys_ = post_process_eval(xs, ys, ys_, data_handler.config)

    # Keep a decoded text version of the predicted and gold tagged tokens.
    for bi in range(len(xs)):
        preds_text.append(_decode_tagged(xs[bi], ys_[bi]))
        labels_text.append(_decode_tagged(xs[bi], ys[bi]))

    # Accumulate the per-sequence results of this batch.
    preds += ys_
    labels += ys
    inputs += xs


                                                                                                    

In [8]:
# Results post-processing
# 1. Collapse tag ids to binary labels: 1 where the tag id is 4, otherwise 0.
# WARNING: `labels` and `preds` are rewritten in place, so executing this cell
# a second time maps every value to 0 (a binarized 1 is not equal to 4).
# Re-run the prediction cell before running this cell again.
def _to_binary(tag_ids):
    return [int(t == 4) for t in tag_ids]


for i, gold_tags in enumerate(labels):
    labels[i] = _to_binary(gold_tags)
    preds[i] = _to_binary(preds[i])
#     preds[i] = [0 for t in preds[i]]

In [9]:
# Compute evaluation metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import precision_recall_fscore_support


def compute_performance(y_true, y_pred, zero_division="warn"):
    """Score binary predictions against binary gold labels.

    Precision, recall and F1 are reported for the positive class (label 1),
    which is scikit-learn's default `average='binary'` behaviour. The earlier
    `labels=np.unique(y_pred)` argument was dropped because scikit-learn only
    uses `labels` when `average != 'binary'`, so it had no effect here.

    Args:
        y_true: Flat sequence of gold labels (0/1).
        y_pred: Flat sequence of predicted labels (0/1), same length as y_true.
        zero_division: Forwarded to scikit-learn; value used when a score is
            undefined (e.g. no positive prediction). The default "warn"
            behaves like 0 and additionally emits a warning; pass 0 or 1 to
            choose the value silently.

    Returns:
        dict with the keys 'accuracy', 'precision', 'recall' and 'f1'.
    """
    accuracy = accuracy_score(y_true, y_pred)
    # One call yields all three scores from a single pass over the data
    # instead of recomputing the same counts three times.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=zero_division
    )
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

 

In [10]:
# Sequence-level flag: 1 when the gold label sequence contains at least one 1.
bin_labels = [int(1 in seq) for seq in labels]

In [11]:
# 1. Sequence accuracy: a sequence counts as correct only when its predicted
#    labels equal the gold labels exactly.
# NOTE(review): idioms[i] is paired with the i-th evaluated sequence purely by
# position; confirm that idioms.json is ordered like the evaluation data.
seq_hits = [1 if gold == preds[i] else 0 for i, gold in enumerate(labels)]
idiom2acc = [[idioms[i], hit] for i, hit in enumerate(seq_hits)]
seq_acc = sum(seq_hits) / len(seq_hits)
seq_acc

0.6006825938566553

In [None]:
# 2. Per-sequence token metrics: score every sequence on its own and collect
#    the resulting accuracy / precision / recall / f1 values per metric.
# NOTE(review): sequences without any positive label or prediction are scored
# by scikit-learn's zero-division rule; confirm that this is intended.
token_seq_level = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}
for i, gold_seq in enumerate(labels):
    seq_scores = compute_performance(gold_seq, preds[i])
    for metric, collected in token_seq_level.items():
        collected.append(seq_scores[metric])

    

In [22]:
# Reduce each per-sequence metric list to its mean.
token_seq_level = dict(
    (metric, np.mean(scores)) for metric, scores in token_seq_level.items()
)
token_seq_level

{'accuracy': 0.9642993981739959,
 'precision': 0.6708277224427882,
 'recall': 0.6450402893601332,
 'f1': 0.6398106439992675}

In [23]:
# 3. Overall token metrics: pool the tokens of all sequences into two flat
#    lists and score them in a single call.
flat_labels, flat_preds = [], []
for seq in labels:
    flat_labels.extend(seq)
for seq in preds:
    flat_preds.extend(seq)
token_overall_level = compute_performance(flat_labels, flat_preds)
token_overall_level

{'accuracy': 0.9711082658862621,
 'precision': 0.8510427010923535,
 'recall': 0.8149204702627939,
 'f1': 0.8325899757120777}

In [24]:
# Gather every reported metric in one dict.
# NOTE(review): the key 'token_pref_flatten' is spelled "pref" (not "perf");
# it is kept unchanged because it is a runtime key that consumers may rely on.
eval_res = {}
eval_res['sequence_accuracy'] = seq_acc
eval_res['token_perf_seq_level'] = token_seq_level
eval_res['token_pref_flatten'] = token_overall_level
eval_res

{'sequence_accuracy': 0.5405405405405406,
 'token_perf_seq_level': {'accuracy': 0.9642993981739959,
  'precision': 0.6708277224427882,
  'recall': 0.6450402893601332,
  'f1': 0.6398106439992675},
 'token_pref_flatten': {'accuracy': 0.9711082658862621,
  'precision': 0.8510427010923535,
  'recall': 0.8149204702627939,
  'f1': 0.8325899757120777}}