In [4]:
import json
import os
from collections import Counter
from pathlib import Path

import numpy as np
import spacy
from tqdm import tqdm

from dataset_preparation import automatic_stage_tagging_sentence_level


In [25]:
def exact_match(predict_seq, reference_seq):
    """Position-wise exact-match rate between predicted and reference plans.

    Plans are paired with ``zip`` (so the shorter sequence of plans bounds the
    comparison), and within each pair the stage labels are compared position by
    position.

    Args:
        predict_seq: iterable of predicted plans, each a sequence of stage labels.
        reference_seq: iterable of reference plans, aligned with ``predict_seq``.

    Returns:
        Matching positions divided by the total length of the paired predicted
        plans. Positions where either label is falsy (``None``, ``''``, ``0``, ...)
        never count as a match. Returns ``0.0`` when there is nothing to score
        (no plans, or only empty predicted plans) instead of dividing by zero.
    """
    match_cnt = 0
    total_cnt = 0
    for predicted_plan, reference_plan in zip(predict_seq, reference_seq):
        # NOTE(review): the truthiness check also rejects a label of 0 -- confirm
        # that no valid stage label is falsy.
        match_cnt += sum(1 for p1, p2 in zip(predicted_plan, reference_plan)
                         if p1 and p2 and p1 == p2)
        # Predicted positions with no reference counterpart count as misses.
        total_cnt += len(predicted_plan)
    if total_cnt == 0:
        return 0.0
    return match_cnt / total_cnt

def plan_to_unigram(plan):
    """Return the unigrams of ``plan``: a new list of its stage labels, in order."""
    return list(plan)
    
def plan_to_bigram(plan):
    """Return every pair of consecutive stages in ``plan`` as a list of 2-tuples.

    A plan with fewer than two stages has no bigrams and yields ``[]``.
    """
    return list(zip(plan, plan[1:]))

def plan_to_trigram(plan):
    """Return every run of three consecutive stages in ``plan`` as a list of 3-tuples.

    A plan with fewer than three stages has no trigrams and yields ``[]``.
    """
    return list(zip(plan, plan[1:], plan[2:]))


def n_gram_match_rate(predict_seq, reference_seq, ngram=1, verbose=True):
    """Average clipped n-gram precision of predicted plans against reference plans.

    For each (reference, prediction) pair the score is the number of predicted
    n-grams that also occur in the reference -- each n-gram credited at most as
    many times as it appears in the reference (clipped counts) -- divided by the
    number of predicted n-grams. Predictions with no n-grams (fewer than
    ``ngram`` stages) are skipped. Pairs are formed with ``zip``, so the shorter
    sequence of plans bounds the evaluation.

    Args:
        predict_seq: iterable of predicted plans (sliceable sequences of stage labels).
        reference_seq: iterable of reference plans, aligned with ``predict_seq``.
        ngram: n-gram order; any positive integer (1, 2 and 3 as before).
        verbose: if True (the default, matching the old behaviour), print the
            score with a label naming the n-gram order.

    Returns:
        The mean per-plan match rate as a float, or ``nan`` if no prediction
        had any n-grams.

    Raises:
        ValueError: if ``ngram`` is not a positive integer.
    """
    if isinstance(ngram, bool) or not isinstance(ngram, (int, np.integer)) or ngram < 1:
        raise ValueError(f'Wrong n-gram number: {ngram!r}. Expected a positive integer.')
    n = int(ngram)

    def _ngrams(plan):
        # All contiguous length-n windows of the plan, as tuples. For n == 1 these are
        # 1-tuples; matching counts are identical to comparing the bare labels.
        return [tuple(plan[i:i + n]) for i in range(len(plan) - n + 1)]

    match_rates = []
    for reference_plan, predicted_plan in zip(reference_seq, predict_seq):
        predicted_cnt = Counter(_ngrams(predicted_plan))
        predicted_total = sum(predicted_cnt.values())
        if predicted_total == 0:
            continue
        reference_cnt = Counter(_ngrams(reference_plan))
        # Counter intersection keeps min(reference count, predicted count) per n-gram.
        matched = sum((reference_cnt & predicted_cnt).values())
        match_rates.append(matched / predicted_total)

    mean_rate = float(np.mean(match_rates)) if match_rates else float('nan')
    if verbose:
        label = {1: 'Unigram', 2: 'Bigram', 3: 'Trigram'}.get(n, f'{n}-gram')
        print(f'{label} match rates', mean_rate)
    return mean_rate



def compute_stage_matching(generation_doc_list, stage_reference_data, spacy_tokenizer=None):
    '''Average per-document agreement between tagged generated sentences and reference stages.

    Every sentence of every generated document is tagged with
    ``automatic_stage_tagging_sentence_level``; the resulting tags are compared
    position by position with that document's reference stage labels.

    Args:
        generation_doc_list: list of documents, each a list of generated sentences.
        stage_reference_data: list of reference stage-label lists, aligned with
            ``generation_doc_list``.
        spacy_tokenizer: optional preloaded spaCy pipeline passed to the tagger.
            When omitted, ``en_core_web_sm`` is loaded with parser, senter and
            ner disabled (as before); pass one in to avoid reloading per call.

    Returns:
        Mean over documents of (matching positions / number of reference labels).
        Documents with an empty reference are skipped rather than dividing by
        zero; ``nan`` if no document can be scored.
    '''
    if spacy_tokenizer is None:
        spacy_tokenizer = spacy.load("en_core_web_sm", disable=['parser', 'senter', 'ner'])
    scores = []
    for generated_text_list, reference_stages in tqdm(zip(generation_doc_list, stage_reference_data),
                                                      total=len(generation_doc_list)):
        if len(reference_stages) == 0:
            # No reference labels to score against.
            continue
        labels = [automatic_stage_tagging_sentence_level(sent, spacy_tokenizer)
                  for sent in generated_text_list]
        # Missing generated sentences (fewer tags than references) count as misses.
        match_cnt = sum(1 for generated_label, reference_label in zip(labels, reference_stages)
                        if generated_label == reference_label)
        scores.append(match_cnt / len(reference_stages))
    return np.average(scores) if scores else float('nan')

In [5]:
# Machine-specific defaults; set the environment variables to run this notebook elsewhere.
planner_result_path = os.environ.get(
    'PLANNER_RESULT_PATH',
    '/home/yinhong/Documents/source/RecipeWithPlans/model-checkpoint/planner_results/')
with open(Path(planner_result_path) / 'planner_prediction_test.json', encoding='utf-8') as f:
    test_predicted_plan = json.load(f)['planner_prediction']

test_data_path = os.environ.get(
    'RECIPE_TEST_DATA_PATH',
    '/home/yinhong/Documents/datasets/recipe1m+/preprocessed_data/test_dataset.json')
with open(test_data_path, encoding='utf-8') as f:
    test_plan = json.load(f)['stage_label']

In [28]:
# n_gram_match_rate reports its own labelled score, so the call is not wrapped in print().
for n in (1, 2, 3):
    n_gram_match_rate(test_predicted_plan, test_plan, ngram=n)


Unigram match rates 0.676552433825484
None
Unigram match rates 0.3292366278179846
None
Unigram match rates 0.13029280196214318
None
