In [13]:
from tqdm import tqdm
import numpy as np

In [10]:
#utils
def read_csv(path):
    result = []

    with open(path) as f:
        for line in f.readlines():
            result.append(line)

    return result

def build_structure(throughs):
    """
    #turn ['1 2 啥'] to {('1', '2'):'啥'}
    #
    #
    """

    result = []

    for through in throughs:

        through = through.split()

        pointer = 0

        catch = {}

        #c++ style code for adapt "2 3 好 的" 
        while pointer + 3 <= len(through):

            key = (through[pointer], through[pointer+1])# key like : (1,2) 
        
            catch[key] = through[pointer+2]

            while pointer + 3 < len(through) and not through[pointer+3].isdigit():

                pointer += 1

                catch[key] += through[pointer+2]
            
            pointer += 3

        result.append(catch)

    return result



In [3]:
### f1 score
###
#《Introduction to SIGHAN 2015 Bake-off for Chinese Spelling Check》
# https://aclanthology.org/W15-3106.pdf
#
##

def get_score(predict, target, debug = False):
    """
    Detection 
    """
    print("*"*5, "Detection", "*"*5)

    tp, fp, fn, tn = 0, 0, 0, 0

    tp_list, fp_list, fn_list = [], [], []

    for i in tqdm(range(len(predict))):
        
        dict_pre, dict_true = predict[i], target[i]
        #print(dict_pre, dict_true)

        if not dict_true and dict_pre:
            """
            Picked but Dont Need Pick
            """
            fp += 1
            fp_list.append(i)
            continue
        
        if not dict_true and not dict_pre:
            continue

        flag = False
        for key in dict_true.keys():
            if not key in dict_pre.keys():
                fn += 1
                flag = True
                fn_list.append(i)
                break
        if flag:
            continue
         
        for key in dict_pre.keys():            
            if key not in dict_true.keys():
                """
                Need Pick but not Picked
                """
                fn += 1
                fn_list.append(i)
                break
        else:
            """
            Picked and Need Pick
            """
            tp += 1
            tp_list.append(i)

        """
        Unpick and Dont Need Pick
        """
        #since f1 score dont need fn, so we wont calculate it
        
    print("TP: ",tp, "FP: ", fp, "FN: ", fn,)
    
    precision = tp / (tp + fp + 1e-10)

    recall = tp / (tp + fn + 1e-10)

    F1_score = 2 * precision * recall / (precision + recall + 1e-10)

    print("Precision: ", precision, "Recall: ", recall)

    print("F1_score: ", F1_score)

    return F1_score, tp_list, fp_list, fn_list

In [4]:
### f1 score
###
#《Introduction to SIGHAN 2015 Bake-off for Chinese Spelling Check》
# https://aclanthology.org/W15-3106.pdf
#
##

def get_score_correction(predict, target):
    """
    Correction Level
    """
    print("*"*5, "Correction", "*"*5)

    tp, fp, fn, tn = 0, 0, 0, 0

    tp_list, fp_list, fn_list = [], [], []

    for i in tqdm(range(len(predict))):

        dict_pre, dict_true = predict[i], target[i]
        #print(dict_pre, dict_true)

        if not dict_true and dict_pre:
            """
            Picked but Dont Need Pick
            """
            fp += 1
            fp_list.append(i)
            continue
        
        if not dict_true and not dict_pre:
            continue

        flag = False
        for key in dict_true.keys():
            if not key in dict_pre.keys():
                fn += 1
                flag = True
                fn_list.append(i)
                break
        if flag:
            continue

        for key in dict_pre.keys():            
            if key not in dict_true.keys():
                """
                Need Pick but Not Picked
                """
                fn += 1
                fn_list.append(i)
                break
            elif dict_pre[key] != dict_true[key]:
                fn += 1
                fn_list.append(i)
                break
        else:
            """
            Picked and Need Pick
            """
            tp += 1
            tp_list.append(i)
            
        #since f1 score dont need fn, so we wont calculate it
        
    print("TP: ",tp, "FP: ", fp, "FN: ", fn,)
    
    precision = tp / (tp + fp + 1e-10)

    recall = tp / (tp + fn + 1e-10)

    F1_score = 2 * precision * recall / (precision + recall + 1e-10)

    print("Precision: ", precision, "Recall: ", recall)

    print("F1_score: ", F1_score)

    return F1_score, tp_list, fp_list, fn_list

In [65]:
# sighan15 test through character level
#predictions_path = "../../tmp/tst-csc-test/generated_predictions.txt"
#target_path = "../../data/rawdata/sighan/valid.through"

#predictions, target = read_csv(predictions_path), read_csv(target_path)

#structured_predictions, structured_target = build_structure(predictions), build_structure(target)

#get_score(structured_predictions, structured_target)


In [113]:
#sighan15 test
predictions_path = "../../tmp/sighan_seq/generated_predictions.through"
target_path = "../../data/rawdata/sighan/valid.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1100/1100 [00:00<00:00, 349525.33it/s]


TP:  379 FP:  133 FN:  163
Precision:  0.7402343749998553 Recall:  0.6992619926197972
F1_score:  0.7191650853388982
***** Correction *****


100%|██████████| 1100/1100 [00:00<00:00, 348706.40it/s]

TP:  361 FP:  133 FN:  181
Precision:  0.7307692307690828 Recall:  0.6660516605164822
F1_score:  0.6969111968611696





In [8]:
#sighan15 test
predictions_path = "../../tmp/sighan_seq/generated_predictions_8884.through"
target_path = "../../data/rawdata/sighan/valid.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1100/1100 [00:00<00:00, 329223.23it/s]


TP:  389 FP:  133 FN:  153
Precision:  0.7452107279692058 Recall:  0.7177121771216387
F1_score:  0.7312030074686772
***** Correction *****


100%|██████████| 1100/1100 [00:00<00:00, 307766.95it/s]

TP:  371 FP:  133 FN:  171
Precision:  0.736111111110965 Recall:  0.6845018450183238
F1_score:  0.7093690248065269





In [7]:
#sighan15 test
predictions_path = "../../tmp/sighan_seq/generated_predictions_.through"
target_path = "../../data/rawdata/sighan/valid.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1100/1100 [00:00<00:00, 626270.45it/s]


TP:  377 FP:  137 FN:  165
Precision:  0.7334630350193125 Recall:  0.6955719557194288
F1_score:  0.7140151514650513
***** Correction *****


100%|██████████| 1100/1100 [00:00<00:00, 587961.56it/s]

TP:  359 FP:  137 FN:  183
Precision:  0.7237903225804992 Recall:  0.662361623616114
F1_score:  0.6917148361734717





In [9]:
#sighan15 test 20 epoch 3090 batch_size 64 
predictions_path = "../../tmp/bart_sighan_seq/generated_predictions.through"
target_path = "../../data/rawdata/sighan/valid.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1099/1099 [00:00<00:00, 316763.34it/s]


TP:  368 FP:  125 FN:  173
Precision:  0.7464503042594834 Recall:  0.6802218114601329
F1_score:  0.711798839408384
***** Correction *****


100%|██████████| 1099/1099 [00:00<00:00, 326454.68it/s]

TP:  342 FP:  125 FN:  199
Precision:  0.7323340471090509 Recall:  0.6321626617374062
F1_score:  0.6785714285215635





In [7]:
#sighan15 test 10 epoch 3090 batch_sizer 64
predictions_path = "../../tmp/bart_sighan_seq_10epoch/generated_predictions.through"
target_path = "../../data/rawdata/sighan/valid.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1099/1099 [00:00<00:00, 321625.74it/s]


TP:  386 FP:  138 FN:  155
Precision:  0.7366412213739052 Recall:  0.7134935304989438
F1_score:  0.7248826290578577
***** Correction *****


100%|██████████| 1099/1099 [00:00<00:00, 303179.43it/s]

TP:  361 FP:  138 FN:  180
Precision:  0.7234468937874302 Recall:  0.6672828096117065
F1_score:  0.6942307691807174





In [10]:
#sighan14 test
predictions_path = "../../tmp/sighan_seq/generated_predictions14.through"
target_path = "../../data/rawdata/sighan/valid14.through"

predictions, target = read_csv(predictions_path), read_csv(target_path)

structured_predictions, structured_target = build_structure(predictions), build_structure(target)

_ = get_score(structured_predictions, structured_target)

_ = get_score_correction(structured_predictions, structured_target)

***** Detection *****


100%|██████████| 1061/1061 [00:00<00:00, 373366.60it/s]


TP:  248 FP:  232 FN:  272
Precision:  0.516666666666559 Recall:  0.4769230769229852
F1_score:  0.4959999999499808
***** Correction *****


100%|██████████| 1061/1061 [00:00<00:00, 375540.64it/s]

TP:  240 FP:  232 FN:  280
Precision:  0.5084745762710787 Recall:  0.46153846153837275
F1_score:  0.483870967691955





In [22]:
def better_get_score(source, predictions, target):
    sources, preds, tgts = [ "".join(src.split()) for src in source ], [ "".join(pre.split()) for pre in predictions ], [ "".join(tgt.split()) for tgt in target ] 

    tp, fp, fn = 0, 0, 0

    for i in range(len(sources)):
        source, pred, label = sources[i], preds[i], tgts[i] 
  

        if source == label:
            if (pred == label):
                pass
            else:
                fp += 1 
        else :
            if pred == label:
                tp += 1
            else:
                fn += 1
                      
    precision = tp / (tp + fp + 1e-10)

    recall = tp / (tp + fn + 1e-10)

    F1_score = 2 * precision * recall / (precision + recall + 1e-10)

    print("Precision: ", precision, "Recall: ", recall)

    print("F1_score: ", F1_score)

    return {"F1_score": float(F1_score)}



In [23]:
#sighan test
sources_path = "../../data/rawdata/sighan/test.src"
predictions_path = "../../tmp/sighan/bart_seq2seq_eval.epoch10.bs32/generated_predictions.txt"
targets_path = "../../data/rawdata/sighan/test.tgt"

sources, predictions, targets = read_csv(sources_path), read_csv(predictions_path), read_csv(targets_path)

print(predictions[0], target[0])

_ = better_get_score(sources, predictions, targets)



你 好! 我 是 张 爱 文 。
 你好!我是张爱文。

Precision:  0.7296747967478192 Recall:  0.662361623616114
F1_score:  0.694390715617294
