## Test with small data

In [1]:
# Path of the checkpoint to evaluate; used below as the default value of the
# `-r/--resume` command-line argument.
resume = 'data/checkpoints/0906_214036/model_best.pth'

In [2]:
import json
import torch
import argparse
from tqdm import tqdm
from tabulate import tabulate

from utils.unique import unique
import model.loss as module_loss
import model.model as module_arch
import model.metric as module_metric
from parse_config import ConfigParser
import data_loader.data_loaders as module_data
from utils.span2json import span2json
from utils.conll2span import conll2span
from utils.correcting_labels import fix_labels

# Padding token string; tokens equal to it are filtered out of each sentence
# before predictions are decoded.
PAD = '<pad>'

In [3]:
from utils.correcting_labels import fix_labels, remove_incorrect_tag
def get_dict_prediction(tokens, preds, attention_mask, ids2tag):
    """Decode one sentence's tag ids (a single layer) into entity records.

    Positions whose attention-mask value is not 1 are skipped. The remaining
    ids are mapped to tag strings through ``ids2tag`` and then passed, in
    order, through ``remove_incorrect_tag`` and ``fix_labels`` (both called
    with the "BIOES" scheme), ``conll2span`` and finally ``span2json``
    together with ``tokens``.

    Args:
        tokens: tokens of the sentence, forwarded to ``span2json``.
        preds: indexable sequence of tag ids; each element must support ``.item()``.
        attention_mask: indexable mask, read at every index of ``preds``.
        ids2tag: mapping from tag id to tag string (looked up with ``.get``,
            so an unknown id yields ``None``).

    Returns:
        Whatever ``span2json`` returns for the decoded spans.
    """
    # Keep only the tags at positions the attention mask marks with 1.
    tag_sequence = [
        ids2tag.get(preds[position].item())
        for position in range(len(preds))
        if attention_mask[position] == 1
    ]

    cleaned_tags = remove_incorrect_tag(tag_sequence, "BIOES")
    repaired_tags = fix_labels(cleaned_tags, "BIOES")
    spans = conll2span(repaired_tags)
    return span2json(tokens, spans)

In [4]:
# Build the argument parser that ConfigParser consumes. `--resume` defaults to
# the checkpoint path chosen in the first cell.
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)')
args.add_argument('-r', '--resume', default=f"{resume}", type=str, help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)')
# NOTE(review): presumably present to absorb the `-f <connection file>` argument
# that Jupyter passes to the kernel process — confirm before removing.
args.add_argument('-f', '--file', default=None, type=str, help='Error')

config = ConfigParser.from_args(args)
logger = config.get_logger('test')

# Data: the project data loader object and its test split.
data_loader = config.init_obj('data_loader', module_data)
test_data_loader = data_loader.get_test()

# build model architecture
model = config.init_obj('arch', module_arch)

# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metric_fns = [getattr(module_metric, met) for met in config['metrics']]

# Choose the device *before* loading so the checkpoint can be mapped onto it.
# Without `map_location`, a checkpoint saved on GPU cannot be deserialized on a
# CPU-only host, even though this cell explicitly falls back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logger.info('Loading checkpoint: {} ...'.format(config.resume))
checkpoint = torch.load(config.resume, map_location=device)
state_dict = checkpoint['state_dict']
if config['n_gpu'] > 1:
    # Wrap before loading the state dict, as the original cell does.
    model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)
layers_train = config._config['trainer']['layers_train']

# prepare model for testing
model = model.to(device)
model.eval()

# Running totals accumulated by the evaluation loop in the next cell.
total_loss = 0.0
total_metrics = torch.zeros(len(metric_fns))

Loading checkpoint: data/checkpoints/0906_214036/model_best.pth ...


In [5]:
from tqdm import tqdm
# Run the model over the test split, accumulating the loss, the batch-level
# metrics and a per-sentence record of gold vs. predicted entities in `results`.
results = []
with torch.no_grad():
    # Wrap the loader itself (not `enumerate`) so tqdm can show a total when the
    # loader exposes a length.
    for batch_idx, instance in enumerate(tqdm(test_data_loader)):
        input_ids = torch.tensor(instance['input_ids']).to(device)
        attention_mask = torch.tensor(instance['attention_mask']).to(device)
        batch_size = input_ids.shape[0]
        output = model(input_ids, attention_mask)

        # Sum the loss over the trained layers. Gold ids are stored under the
        # layer's *position* in `layers_train`, matching how `output` is indexed.
        loss = 0
        nested_lm_conll_ids = {l: None for l in range(len(layers_train))}
        for index, layer in enumerate(layers_train):
            temp_nested_lm_conll_ids = torch.tensor(instance['nested_lm_conll_ids'][layer])
            temp_nested_lm_conll_ids = temp_nested_lm_conll_ids.to(device)
            nested_lm_conll_ids[index] = temp_nested_lm_conll_ids
            loss += criterion(output[index], temp_nested_lm_conll_ids)

        total_loss += loss.item() * batch_size

        # Metrics are computed on the whole batch, so accumulate them exactly
        # once per batch (weighted by batch size). Running this inside the
        # per-sentence loop would add every batch `batch_size` times.
        for i, metric in enumerate(metric_fns):
            total_metrics[i] += metric(
                output, nested_lm_conll_ids, attention_mask,
                data_loader.boundary_type, info=False, ids2tag=data_loader.ids2tag
            ) * batch_size

        # Per-sentence, per-layer predicted tag ids and gold tag ids.
        predictions = {
            sent_ids: [output[layer][sent_ids].argmax(axis=0) for layer in range(len(output))]
            for sent_ids in range(batch_size)
        }
        lm_entities = {
            sent_ids: [nested_lm_conll_ids[layer][sent_ids] for layer in range(len(output))]
            for sent_ids in range(batch_size)
        }

        for sent_ids in range(batch_size):
            # Drop padding tokens before decoding.
            tokens = [w for w in instance['lm_tokens'][sent_ids] if w != PAD]

            # Decode every trained layer and concatenate the entity records.
            preds = []
            entities_labels = []
            for index in range(len(layers_train)):
                preds += get_dict_prediction(
                    tokens,
                    predictions[sent_ids][index],
                    attention_mask[sent_ids],
                    data_loader.ids2tag)
                entities_labels += get_dict_prediction(
                    tokens,
                    lm_entities[sent_ids][index],
                    attention_mask[sent_ids],
                    data_loader.ids2tag)

            results.append({
                'sentence_id': instance['sentence_id'][sent_ids],
                'tokens': tokens,
                'entities': entities_labels,
                'predictions': preds})

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
5it [00:34,  6.96s/it]


## Results 

In [8]:
def _report_row(sentence_id, entity, mark):
    """Build the printed row for one entity record: location, mark, text, type."""
    return [
        f"SID:{sentence_id}, Index{entity['span']}",
        mark,
        ''.join(entity['text']),
        entity['entity_type'],
    ]


# Show the first two sentences: every prediction marked "/" when it also appears
# among the gold entities (otherwise "X"), then the gold entities that were missed.
for ids, item in enumerate(results[:2]):
    print('\n', "|".join(item["tokens"]), '\n')
    sid = item['sentence_id']
    entities = item['entities']
    predictions = item['predictions']

    print("Predictions")
    for pred in predictions:
        mark = "/" if pred in entities else "X"
        print(_report_row(sid, pred, mark))

    print("Not answer")
    missed = [en for en in entities if en not in predictions]
    for en in missed:
        print(_report_row(sid, en, "X"))


 <s>||ความคืบหน้|า|หลัง||ศาลปกครอง||กลาง|มี||คําสั่ง|ไม่|รับ||คําฟ้อง|และ|ไม่||คุ้มครอง||ชั่วคราว||ใน||คดี||ที่||ผู้ตรวจการ|แผ่นดิน||ยื่นฟ้อง||ต่อ||ศาลปกครอง||ว่า||_||สํานักงาน||คณะกรรมการ||กิจการ||กระจาย|เสียง||_||กิจการ||โทรทัศน์|และ||กิจการ||โทรคมนาคม||แห่งชาติ||_||กสทช|.||_||จัด||การประมูล||คลื่น|ความถี่||_||ย่าน||_||2.1||<unk>|z||_|พร้อมทั้ง||ออก||ประกาศ||_|เรื่อง||_||หลักเกณฑ์|และ||วิธี|การ|อนุญาต||ให้|ใช้||คลื่น|ความถี่||ดังกล่าว|โดย|ไม่|ชอบ|ด้วยกฎหมาย||_||ด้วย||เหตุผล||ว่า||_||ผู้ตรวจการ|แผ่นดิน||_|ไม่|มี||สิทธิ|และ||หน้าที่||_||เสมือน|หนึ่ง|ผู้||ฟ้องคดี||_|เมื่อ|วันที่||_||3||_||ธ|.|ค|.||55||_|ล่าสุด||_||3||_||ม|.|ค|.||56||_||ผู้ตรวจการ|แผ่นดิน||ได้||ยื่น||อุทธรณ์||คําสั่ง|ไม่|รับ||คําฟ้อง||ของ||ศาลปกครอง||กลาง|แล้ว|</s> 

Predictions
['SID:0, Index[5, 9]', 'X', 'ศาลปกครองกลาง', 'goverment']
['SID:0, Index[35, 37]', '/', 'ศาลปกครอง', 'goverment']
['SID:0, Index[41, 63]', '/', 'สํานักงานคณะกรรมการกิจการกระจายเสียง_กิจการโทรทัศน์และกิจการโทรคมนาคมแห่งชาติ', 'goverment']
['SID:0

In [7]:
# Score the collected `results` with the project's ClassEvaluator.
# Author's note: the input can be in either BIESO or BIO format.
from model.eval import ClassEvaluator

evaluator = ClassEvaluator()
results_eval, conll_results = evaluator(results)

Calculate F1-score based on.
labels_true: 1138
num_labels: 1670
predictions_true: 1154
num_predictions: 1297

<<< Results Evaluations >>>

-------------  ---------  -------  -------  ----------------  ---------------  ----------
group/n.class  precision  recall   f1       predictions_true  num_predictions  num_labels
group 0: 20    90.4675    80.0823  84.9587  987               1091             1215
group 1: 19    83.4356    53.1496  64.9348  136               163              254
group 2: 65    72.093     14.9254  24.7308  31                43               201
-------------  ---------  -------  -------  ----------------  ---------------  ----------
----------------  ---------  -------  --------  ----------  ---------  ----------
tag               precision  recall   f1-score  preds_true  num_preds  num_labels
total             88.9746    68.1437  77.1783   1154        1297       1670
cardinal          88.8199    80.117   84.2442   143         161        171
country           94.6237 