In [1]:
%load_ext autoreload
%autoreload 2
import sys
if "../src" not in sys.path:
    sys.path.append("../src")

In [2]:
import re
import pickle
import json
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import torch
import pandas as pd
from tqdm.auto import tqdm

import datasets
from transformers import MT5TokenizerFast

from CwnGraph import CwnImage
import vec4gloss
from vec4gloss import check_hashes
from vec4gloss import Vec4GlossModel

from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional import bleu_score
from nltk.translate.meteor_score import single_meteor_score

## Data dependencies
Note: the `defgen_dataset_cwn\train` is only used in checking data hash, the 'test' split is used in the notebook itself.
```
10.11 => ../data/defgen_dataset_cwn/train/dataset.arrow 65a56d
20.21../data/models\vec4gloss-defgen-220629-1250\pytorch_model.bin 9f894f
```

In [3]:
vec4gloss_model_dir = "../data/models/vec4gloss-defgen-220629-1250"

_ = check_hashes([    
    "../data/defgen_dataset_cwn/train/dataset.arrow",    
    vec4gloss_model_dir + "/pytorch_model.bin",
])

../data/defgen_dataset_cwn/train/dataset.arrow 65a56d
../data/models/vec4gloss-defgen-220629-1250/pytorch_model.bin 9f894f


## Loading resources

In [4]:
cwn = CwnImage.load("v.2022.06.21")

In [5]:
use_cuda = torch.cuda.is_available() and "GeForce" not in torch.cuda.get_device_name()
device = "cuda" if use_cuda else "cpu"    
print("Using", device)

ds_defgen = datasets.load_from_disk("../data/defgen_dataset_cwn")
model = Vec4GlossModel.from_pretrained(vec4gloss_model_dir).to(device)
tokenizer = MT5TokenizerFast.from_pretrained(vec4gloss_model_dir)
gen = vec4gloss.gen_func(tokenizer, model)

Using cuda


## Load evaldataset

In [6]:
ds_defgen["test"][0]

{'cwnid': '05170501',
 'src': '在妳出國前哭著要我等妳回來，我也<答應>過妳一定會等妳。',
 'tgt': 'VE。同意他人的要求。'}

In [7]:
def simplify_pos(pos):    
    pos = ",".join(x 
                   for x in pos.split(",")
                   if x!="nom")
    if pos and pos == "Nb":
        return "Nb" # ignore
    elif pos and pos[0] in "DVN":
        return pos[0]   
    else:
        return "O"    

## Check POS distributions

In [8]:
pos_results = []
for item_x in tqdm(ds_defgen["test"]):       
    sense_x = cwn.from_sense_id(item_x["cwnid"])    
    entry = {"cwnid": sense_x.id, 
             "word": sense_x.head_word,
             "pos": simplify_pos(sense_x.pos)}
    pos_results.append(entry)    

  0%|          | 0/8553 [00:00<?, ?it/s]

In [9]:
from collections import Counter
Counter(x["pos"] for x in pos_results)

Counter({'V': 4376, 'D': 432, 'Nb': 414, 'O': 530, 'N': 2801})

## Calculate scores

In [10]:
eval_results = []
for item_x in tqdm(ds_defgen["test"]):
    try:
        gendef = " ".join(list(gen(item_x["src"])))
    except Exception as ex:
        print(item_x["cwnid"], str(ex))
    refdef = " ".join(item_x["tgt"])    
    score_bleu = bleu_score(gendef, [refdef]).item()
    score_rouge = rouge_score(gendef, refdef, rouge_keys=("rougeL"))
    score_meteor = single_meteor_score(gendef.split(), refdef.split())    
    
    sense_x = cwn.from_sense_id(item_x["cwnid"])    
    entry = {"cwnid": sense_x.id, 
             "word": sense_x.head_word,
             "pos": simplify_pos(sense_x.pos)}
    scores = {k: v.item() for k,v in score_rouge.items()}
    scores.update({"bleu": score_bleu, "meteor": score_meteor})
    entry.update(**scores)
    eval_results.append(entry)    

  0%|          | 0/8553 [00:00<?, ?it/s]

In [18]:
se = lambda x: np.std(x, ddof=1)/np.sqrt(len(x))
eval_results_df = pd.DataFrame.from_records(eval_results)
eval_results_df.groupby("pos")\
    .agg(
        n_sample = ("cwnid", len),
        bleu_mean=("bleu", "mean"),
        meteor_mean=("meteor", "mean"),
        rouge_mean=("rougeL_fmeasure", "mean"),
        bleu_se=("bleu", se),
        meteor_se=("meteor", se),
        rouge_se=("rougeL_fmeasure", se)
    )

Unnamed: 0_level_0,n_sample,bleu_mean,meteor_mean,rouge_mean,bleu_se,meteor_se,rouge_se
pos,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
D,432,0.408252,0.616829,0.815719,0.021168,0.01584,0.018018
N,2801,0.35104,0.586241,0.914749,0.007475,0.005587,0.004216
Nb,414,0.632856,0.742545,0.888524,0.021533,0.016441,0.012069
O,530,0.410067,0.625508,0.757303,0.016867,0.013446,0.017195
V,4376,0.434996,0.625141,0.874232,0.00613,0.004625,0.003787


In [19]:
eval_results_df.shape

(8553, 8)

In [20]:
eval_out_path = "../data/auto_metrics.csv"
eval_results_df.to_csv(eval_out_path, index=False)

## Output Hashes

```
../data/auto_metrics.csv 7930a0
```

In [21]:
_ = check_hashes([eval_out_path])

../data/auto_metrics.csv 7930a0
