In [1]:
import json
from tqdm import tqdm
import argparse
import re
import os
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, confusion_matrix
import pandas as pd
import numpy as np

In [2]:
def write_file(labels, dir_, last_name):
    print("writing ...")
    file_name = os.path.join(dir_, last_name)
    file = open(file_name, "w")
    for line in labels:
        file.write(line.strip() + '\n')
    file.close()

def read_openai_file(file_name):
    print(f"read ... {file_name}")

    file = open(file_name, "r")
    results = []
    for line in tqdm(file):
        results.append(line.strip())
    file.close()
    return results

def read_dialog_file(file_name):
    print(f"read ... {file_name}")

    with open(file_name, "r") as file:
        dialogs = json.load(file)
    results = [i[-1]['generation']['content'] for i in dialogs]
    return results

def read_full_dialog_file(file_name):
    print(f"read ... {file_name}")

    with open(file_name, "r") as file:
        dialogs = json.load(file)
    return dialogs

def read_mrc_file(file_name):
    print(f"read ... {file_name}")

    return json.load(open(file_name))

def split_and_keep_second_part(text):
    # Regular expression pattern for any non-space character followed by a colon
    pattern = r'[^ ]+:'
    parts = re.split(pattern, text, maxsplit=1)

    # Return the second part after removing any leading/trailing whitespace, if it exists
    return parts[1].strip() if len(parts) > 1 else text


def process_predictions(ori_results):
    results = []
    for line in ori_results:
        line = split_and_keep_second_part(line)
        results.append(line)
    return results


def compute_f1(mrc_data, openai_data):
    print("computting f1 ...")

    true_positive = 0
    false_positive = 0
    false_negitative = 0
    category_data = {'food':{'tp':0, 'fp':0, 'fn':0}, 'symptom':{'tp':0, 'fp':0, 'fn':0}, 
                     'loc':{'tp':0, 'fp':0, 'fn':0}, 'other':{'tp':0, 'fp':0, 'fn':0},}
    for idx_ in range(len(mrc_data)):
        reference = []
        candidate = []
        entity_list = []
        item_ = mrc_data[idx_]
        context_list = item_["context"].strip().split()
        entity_type = item_["entity_label"]
        for sub_idx in range(len(item_["start_position"])):
            start_ = item_["start_position"][sub_idx]
            end_ = item_["end_position"][sub_idx]
            reference.append((" ".join(context_list[start_:end_ + 1]), start_, end_))

        flag = False
        candidate_sentence = openai_data[idx_]
        candidate_sentence_list = candidate_sentence.strip().split()
        start_ = 0
        for word_idx, word in enumerate(candidate_sentence_list):
            if len(word) > 2 and word[0] == '@' and word[1] == '@':
                flag = True
                for end_ in range(word_idx, len(candidate_sentence_list)):
                    end_word = candidate_sentence_list[end_]
                    if len(end_word) > 2 and end_word[-1] == '#' and end_word[-2] == '#':
                        entity_ = " ".join(candidate_sentence_list[word_idx:end_ + 1])[2:-2]
                        len_ = end_ - word_idx + 1
                        while start_ < len(context_list):
                            if start_ + len_ - 1 < len(context_list) and " ".join(
                                    context_list[start_:start_ + len_]) == entity_:
                                candidate.append(
                                    (" ".join(context_list[start_:start_ + len_]), start_, start_ + len_ - 1))
                                break
                            start_ += 1
                        break
            if len(word) > 2 and word[-1] == '#' and word[-2] == '#':
                flag = False
                continue
            if not flag:
                start_ += 1

        # item_ = openai_data[idx_]
        # context_list = item_.strip().split()

        # flag = False
        # start_ = 0
        # for word_idx, word in enumerate(context_list):
        #     if len(word) > 2 and word[0] == '@' and word[1] == '@':
        #         flag = True
        #         start_ = word_idx
        #     if flag and len(word) > 2 and word[-1] == '#' and word[-2] == '#':
        #         flag = False
        #         candidate.append((" ".join(context_list[start_:word_idx+1])[2:-2], start_, word_idx))

        # print(f"ref: {reference}")
        # print(f"can: {candidate}")
        for span_item in candidate:
            if span_item in reference:
                reference.remove(span_item)
                true_positive += 1
                category_data[entity_type]['tp'] += 1
            else:
                false_positive += 1
                category_data[entity_type]['fp'] += 1
        false_negitative += len(reference)
        category_data[entity_type]['fn'] += len(reference)

    span_recall = true_positive / (true_positive + false_negitative)
    span_precision = true_positive / (true_positive + false_positive)
    span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision)

    return span_recall, span_precision, span_f1, true_positive, false_positive, false_negitative, category_data

In [3]:
def convert_label(text):
    text = text.lower()
    if 'yes' in text:
        return 1
    if 'no' in text:
        return 0
    return 0

def compute_sentence_f1(mrc_data, openai_data):
    y_true = [i['sentence_class'] for i in mrc_data]
    y_pred = [convert_label(ans) for ans in openai_data]
    recall = recall_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    return (recall, precision, f1, acc), (tn, fp, fn, tp)


def construct_results(gpt_results, verify_results):
    def justify(string_):
        if len(string_) >= 3 and string_[:3].lower() == "yes":
            return "yes"
        if len(string_) >= 2 and string_[:2].lower() == "no":
            return "no"
        return ""

    def reverse_ans(string_):
        if len(string_) >= 3 and "yes" in string_.lower():
            return "no"
        if len(string_) >= 2 and "no" in string_.lower():
            return "yes"
        return string_

    results = []
    for idx_, item in enumerate(gpt_results):
        ans = " ".join(item.strip().split())
        if justify(verify_results[idx_].strip()) == "no":
            ans = reverse_ans(ans)
        results.append(ans)
    return results

In [4]:
def compute_simple_f1(mrc_data, openai_data, label_symbol="@@##"):
    print("computting f1 ...")
    sym_len = len(label_symbol)
    label_prefix, label_suffix = label_symbol[:sym_len // 2], label_symbol[sym_len // 2:]

    mention_true_positive, mention_false_positive, mention_false_negative = 0, 0, 0
    mention_category_data = {'food':{'tp':0, 'fp':0, 'fn':0}, 'symptom':{'tp':0, 'fp':0, 'fn':0},
                             'loc':{'tp':0, 'fp':0, 'fn':0}, 'other':{'tp':0, 'fp':0, 'fn':0},}

    for idx_ in range(len(mrc_data)):
        reference_word = []
        candidate_word = []
        item_ = mrc_data[idx_]
        context_list = item_["context"].strip().split()
        entity_type = item_["entity_label"]
        for sub_idx in range(len(item_["start_position"])):
            start_ = item_["start_position"][sub_idx]
            end_ = item_["end_position"][sub_idx]
            reference_word.append(" ".join(context_list[start_:end_ + 1]))

        flag = False
        candidate_sentence = openai_data[idx_]
        candidate_sentence = candidate_sentence.strip().split(":")[-1]
        candidate_sentence_list = candidate_sentence.strip().split(',')
        candidate_word.extend(candidate_sentence_list)

        for span_item in candidate_word:
            if span_item in reference_word:
                reference_word.remove(span_item)
                mention_true_positive += 1
                mention_category_data[entity_type]['tp'] += 1
            else:
                mention_false_positive += 1
                mention_category_data[entity_type]['fp'] += 1
        mention_false_negative += len(reference_word)
        mention_category_data[entity_type]['fn'] += len(reference_word)

    mention_recall = mention_true_positive / (mention_true_positive + mention_false_negative)
    mention_precision = mention_true_positive / (mention_true_positive + mention_false_positive)
    mention_f1 = mention_precision * mention_recall * 2 / (mention_recall + mention_precision)
    mention_all_count = {'tp': mention_true_positive, 'fp': mention_false_positive, 'fn': mention_false_negative}
    return mention_recall, mention_precision, mention_f1, mention_all_count, mention_category_data

In [13]:
def construct_tweet_results(gpt_results, verify_results, pos_word, neg_word):
    # def justify(string_):
    #     if len(string_) >= 3 and string_[:3].lower() == "yes":
    #         return "yes"
    #     if len(string_) >= 2 and string_[:2].lower() == "no":
    #         return "no"
    #     return ""
    def justify(string_):
        if "yes" in string_.lower():
            return "yes"
        if "no" in string_.lower():
            return "no"
        return ""

    def reverse_ans(string_, pos_word, neg_word):
        if len(string_) >= 3 and pos_word in string_.lower():
            return neg_word
        if len(string_) >= 2 and neg_word in string_.lower():
            return pos_word
        return string_

    results = []
    for idx_, item in enumerate(gpt_results):
        ans = " ".join(item.strip().split())
        if justify(verify_results[idx_].strip()) == "no":
            ans = reverse_ans(ans, pos_word, neg_word)
        results.append(ans)
    return results

In [14]:
def construct_tweet_results_old(gpt_results, verify_results, pos_word, neg_word):
    # def justify(string_):
    #     if len(string_) >= 3 and string_[:3].lower() == "yes":
    #         return "yes"
    #     if len(string_) >= 2 and string_[:2].lower() == "no":
    #         return "no"
    #     return ""
    def justify(string_):
        if "yes" in string_.lower():
            return "yes"
        if "no" in string_.lower():
            return "no"
        return ""

    def reverse_ans(string_, pos_word, neg_word):
        if len(string_) >= 3 and pos_word in string_:
            return string_.replace(pos_word, neg_word)
        if len(string_) >= 2 and neg_word in string_:
            return string_.replace(neg_word, pos_word)
        return string_

    results = []
    for idx_, item in enumerate(gpt_results):
        ans = " ".join(item.strip().split())
        if justify(verify_results[idx_].strip()) == "no":
            ans = reverse_ans(ans, pos_word, neg_word)
        results.append(ans)
    return results

In [34]:
candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.second.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test"

In [35]:
predictions = read_dialog_file(candidate_file)
mrc_data = read_mrc_file(reference_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.second.test.short.unbalanced.4.28.json
read ... /scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test


In [36]:
full_dialogs = read_full_dialog_file(candidate_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.second.test.short.unbalanced.4.28.json


In [38]:
full_dialogs[3][-1]['generation']['content']

"Sure! Here's the labeled sentence with other entities associated with foodborne illnesses:\n\n@USER As much fun as I can . Woke up with <<food poisoning>> or stomach flu . Been bugging me all day #tmi Almost done driving for the day\n\nOther entities labeled in the sentence include:\n\n* food poisoning\n* stomach flu"

In [56]:
def merge_results(mrc_data, first_dialogs, label_symbol, type_count):
    sym_len = len(label_symbol)
    label_prefix, label_suffix = label_symbol[:sym_len // 2], label_symbol[sym_len // 2:]

    def get_words(labeled_sentence, label_prefix, label_suffix):
        word_list = []
        words = labeled_sentence.strip().split()
        flag = False
        last_ = ""
        for idx_, word in enumerate(words):
            if len(word) > 2 and word[0] == label_prefix[0] and word[1] == label_prefix[1]:
                last_ = idx_
                flag = True
            if flag and len(word) > 2 and word[-1] == label_suffix[-1] and word[-2] == label_suffix[-2]:
                word_list.append((" ".join(words[last_:idx_ + 1])[2:-2], last_, idx_))
                flag = False
        return word_list

    merge_list = []
    for item_idx in tqdm(range(len(mrc_data))):
        item_ = mrc_data[item_idx]
        sen_id, entity_id = [int(i) for i in item_["qas_id"].split(".")]
        if 0 == entity_id:
            context = item_["context"]
            entity_label_dict = {}
        origin_label = item_["entity_label"]
        entity_list = get_words(first_dialogs[sen_id - 1][-1]['generation']['content'].strip(), label_prefix,
                                label_suffix)
        entity_label_dict[origin_label] = entity_list
        if type_count - 1 == entity_id:
            previous_system = first_dialogs[sen_id - 1][0]
            previous_assistant_response = first_dialogs[sen_id - 1][2]
            merge_list.append({'context': context, "entity_predictions": entity_label_dict,
                               'previous_system': previous_system,
                               'previous_assistant_response': previous_assistant_response})
    return merge_list

In [57]:
merge_dialogs = merge_results(mrc_data, full_dialogs, "<<>>", 4)

100%|██████████| 1648/1648 [00:00<00:00, 79482.70it/s]


In [60]:
merge_dialogs[0]

{'context': '@USER As much fun as I can . Woke up with food poisoning or stomach flu . Been bugging me all day #tmi Almost done driving for the day',
 'entity_predictions': {'food': [], 'symptom': [], 'loc': [], 'other': []},
 'previous_system': {'role': 'system',
  'content': "Always follow the user's instruction and only provide answer for the user's last given sentence"},
 'previous_assistant_response': {'role': 'assistant',
  'content': "Understood! I will only provide answer with formats like the user's instruction for the user's last given sentence."}}

In [58]:
str(merge_dialogs[0]['entity_predictions'])

"{'food': [], 'symptom': [], 'loc': [], 'other': []}"

In [47]:
last_example = [i[1]['content'].split(": ")[-1][:-2] for i in full_dialogs]
print(pd.Series(last_example).describe())
final_answer = [i[4]['generation']['content'].split(": ")[-1] for i in full_dialogs]

count     412
unique      2
top        No
freq      207
dtype: object


In [36]:
match_list = [example.lower() in answer.lower() for example, answer in zip(last_example, final_answer)]
np.mean(match_list)

0.4368932038834951

In [7]:
compute_sentence_f1(mrc_data, predictions)

((0.6642335766423357,
  0.40444444444444444,
  0.5027624309392265,
  0.5631067961165048),
 (141, 134, 46, 91))

In [13]:
verify_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.first.test.short.unbalanced.4.28.json"

In [14]:
verify_results = read_dialog_file(verify_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.first.test.short.unbalanced.4.28.json


In [38]:
full_verify_dialogs = read_full_dialog_file(verify_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.first.test.short.unbalanced.4.28.json


In [48]:
last_example = [i[1]['content'].split(": ")[-1][:-2] for i in full_verify_dialogs]
print(pd.Series(last_example).describe())
final_answer = [i[4]['generation']['content'].split(": ")[-1] for i in full_verify_dialogs]
match_list = [example.lower() in answer.lower() for example, answer in zip(last_example, final_answer)]
np.mean(match_list)

count     412
unique      2
top        No
freq      213
dtype: object


0.5315533980582524

In [8]:
result_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.first.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"

In [9]:
final_results = read_openai_file(result_file)
mrc_data = read_mrc_file(reference_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.first.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified


412it [00:00, 1393591.33it/s]

read ... /scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test





In [10]:
compute_sentence_f1(mrc_data, final_results)

((0.7883211678832117,
  0.36860068259385664,
  0.5023255813953488,
  0.48058252427184467),
 (90, 185, 29, 108))

In [49]:
candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.second.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test"

In [50]:
predictions = read_dialog_file(candidate_file)
mrc_data = read_mrc_file(reference_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.second.test.short.unbalanced.4.28.json
read ... /scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test


In [75]:
def write_file(labels, dir_, last_name):
    print("writing ...")
    file_name = os.path.join(dir_, last_name)
    file = open(file_name, "w")
    for line in labels:
        file.write(line.strip() + '\n')
    file.close()

In [76]:
candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.32.knn.sequence.fullprompt.verified"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test"

In [78]:
candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test"

In [79]:
predictions = read_dialog_file(candidate_file)
mrc_data = read_mrc_file(reference_file)

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.test.short.unbalanced.4.28.json
read ... /scratch/dzhang5/LLM/TWEET-FID/mrc-ner.expert.test


In [80]:
compute_simple_f1(mrc_data=mrc_data, openai_data=predictions)

computting f1 ...


(0.28343949044585987,
 0.04777241009125067,
 0.08176389526871843,
 {'tp': 89, 'fp': 1774, 'fn': 225},
 {'food': {'tp': 34, 'fp': 411, 'fn': 27},
  'symptom': {'tp': 16, 'fp': 487, 'fn': 44},
  'loc': {'tp': 27, 'fp': 428, 'fn': 26},
  'other': {'tp': 12, 'fp': 448, 'fn': 128}})

In [45]:
verify_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama-single/llama-2-13b-chat.verify.tc.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"
tmp_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama-single/llama-2-13b-chat.tmp.tc.test.short.unbalanced.4.28.json"
write_dir = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama-single/"
write_name = "llama-2-13b-chat.tc.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"

In [71]:
verify_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.reverse.second.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"
tmp_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.reverse.second.test.short.unbalanced.4.28.json"
write_dir = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/"
write_name = "llama-2-13b-chat.reverse.second.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"

In [78]:
verify_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.reverse.2.second.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"
tmp_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.first.test.short.unbalanced.4.28.json"
write_dir = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/"
write_name = "llama-2-13b-chat.reverse.2.second.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"

In [84]:
verify_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.first.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"
tmp_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.first.test.short.unbalanced.4.28.json"
write_dir = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/"
write_name = "llama-2-13b-chat.first.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"

In [95]:
verify_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.2.first.test.short.unbalanced.4.28.json"
reference_file = "/scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test"
tmp_candidate_file = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.reverse.second.test.short.unbalanced.4.28.json"
write_dir = "/scratch/dzhang5/LLM/TWEET-FID/test-results-llama/"
write_name = "llama-2-13b-chat.2.first.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified"

In [96]:
gpt_results = read_dialog_file(tmp_candidate_file)
mrc_data = read_mrc_file(reference_file)
verify_results = read_dialog_file(verify_candidate_file)
previous_final_results = read_openai_file(os.path.join(write_dir, write_name))

final_results = construct_tweet_results(gpt_results=gpt_results, verify_results=verify_results,
                                            pos_word='yes', neg_word='no')

final_results_old = construct_tweet_results_old(gpt_results=gpt_results, verify_results=verify_results,
                                                pos_word='yes', neg_word='no')

read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.tmp.reverse.second.test.short.unbalanced.4.28.json
read ... /scratch/dzhang5/LLM/TWEET-FID/mrc-tc.expert.test
read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.verify.2.first.test.short.unbalanced.4.28.json
read ... /scratch/dzhang5/LLM/TWEET-FID/test-results-llama/llama-2-13b-chat.2.first.short.unbalanced.4.28.32.knn.sequence.fullprompt.verified


412it [00:00, 295545.28it/s]


In [97]:
all([i==j for i, j in zip(final_results_old, previous_final_results)])

True

In [98]:
compute_sentence_f1(mrc_data=mrc_data, openai_data=gpt_results)

((0.9197080291970803,
  0.45161290322580644,
  0.6057692307692307,
  0.6019417475728155),
 (122, 153, 11, 126))

In [99]:
compute_sentence_f1(mrc_data=mrc_data, openai_data=final_results)

((0.7956204379562044,
  0.39636363636363636,
  0.529126213592233,
  0.529126213592233),
 (109, 166, 28, 109))

In [100]:
compute_sentence_f1(mrc_data=mrc_data, openai_data=final_results_old)

((0.9708029197080292,
  0.3481675392670157,
  0.51252408477842,
  0.3859223300970874),
 (26, 249, 4, 133))

In [101]:
write_file(labels=final_results, dir_=write_dir, last_name=write_name)

writing ...


In [19]:
gpt_predictions = [convert_label(_) for _ in gpt_results]

In [20]:
labels = [_['sentence_class'] for _ in mrc_data]

In [24]:
verify_labels = [int(i == j) for i, j in zip(gpt_predictions, labels)]

In [27]:
verify_predictions = [convert_label(_) for _ in verify_results]

In [32]:
verify_incorrect = [int(i!=j) for i, j in zip(verify_predictions, verify_labels)]

In [38]:
verify_table = pd.DataFrame({'result': verify_results, 'label':verify_labels, 'pred':verify_predictions})

In [67]:
final_results = construct_tweet_results(gpt_results=gpt_results, verify_results=verify_results,
                                        pos_word='yes', neg_word='no')

In [68]:
def justify(string_):
    if "yes" in string_.lower():
        return "yes"
    if "no" in string_.lower():
        return "no"
    return ""

In [66]:
pd.Series([justify(_) for _ in gpt_results]).value_counts()

yes    205
no     198
         9
dtype: int64

In [64]:
pd.Series([justify(_) for _ in verify_results]).value_counts()

no     290
yes    119
         3
dtype: int64

In [65]:
pd.Series([justify(_) for _ in final_results]).value_counts()

yes    296
no     107
         9
dtype: int64

In [47]:
compute_sentence_f1(mrc_data, gpt_results)

((0.8467153284671532,
  0.5658536585365853,
  0.6783625730994152,
  0.7330097087378641),
 (186, 89, 21, 116))

In [69]:
compute_sentence_f1(mrc_data, final_results)

((0.635036496350365,
  0.2939189189189189,
  0.4018475750577367,
  0.3713592233009709),
 (66, 209, 50, 87))