In [1]:
import pandas as pd
from IPython.display import display
import argparse
import os
import json
import warnings

In [2]:
# --- Input / output locations -------------------------------------------------
data_dir = '/scratch/dzhang5/LLM/TWEET-FID/'
raw_data_path = '/scratch/dzhang5/usda_project/tweet-fid-application/tweet-fid-unlabeled.p'
# Both result folders live directly under data_dir (which already ends with '/').
output_dir = data_dir + 'unlabeled-results-autolabel-mv/0.1'
last_result_dir = data_dir + 'unlabeled-results-autolabel-ner-qa-mv/0.1'

# --- Run settings that are embedded in the result file names ------------------
model_name = 'refuel-llm'
few_shot_selection = 'semantic_similarity'
verified = False
use_current_explanation = True
use_ground_explanation = False

# Column compared across the per-entity prediction files to check row alignment.
text_column = 'context'

In [3]:
# Load the pickled frame stored at raw_data_path.
# NOTE(review): unpickling can execute arbitrary code - only load files from a trusted source.
unlabeled_data = pd.read_pickle(raw_data_path)

In [4]:
# Normalise the raw frame: consistent column names, no carriage returns in the
# tweet text, integer ids.
unlabeled_data = unlabeled_data.rename(
    columns={'tweet_token': 'tweet_tokens', 'tweet_text': 'tweet'}
)
# Strip '\r' from every tweet in one pass. A boolean mask built with
# str.contains would hold NaN for missing tweets and make .loc raise, and
# regex=False pins the literal replacement regardless of the pandas default.
unlabeled_data['tweet'] = unlabeled_data['tweet'].str.replace('\r', '', regex=False)
unlabeled_data['id'] = unlabeled_data['id'].astype(int)

In [5]:
def process_merge_results(batch_idx=None):
    """Merge the per-entity prediction files of one batch and keep fully labeled rows.

    Reads ``unlabeled-{batch_idx}.csv`` from ``data_dir`` and one prediction CSV per
    entity type (Food, Location, Symptom, Keyword) from ``last_result_dir``. The
    prediction files are combined column-wise after checking that their
    ``text_column`` values agree row by row, and only rows for which every entity
    type was labeled successfully are kept.

    Parameters
    ----------
    batch_idx : int, optional
        Batch number used in the file names. When omitted, the notebook-level loop
        variable ``ix`` is used, so existing zero-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame
        The kept rows of the batch file with a fresh 0..n-1 index, their original
        row label in ``ori_index``, the four ``*_answer_label`` columns and a
        ``batch_idx`` column.

    Raises
    ------
    AssertionError
        If the text columns of the files being combined do not match.
    """
    if batch_idx is None:
        # Backward compatibility: the driver loop publishes the batch number as global `ix`.
        batch_idx = ix
    label_types = ['Food', 'Location', 'Symptom', 'Keyword']

    word_suffix_name = f'unlabeled-{batch_idx}.csv'
    word_data_path = os.path.join(data_dir, word_suffix_name)
    ori_df = pd.read_csv(word_data_path)

    # Verified runs carry an extra "_final" marker in the prediction file name.
    name_prefix = os.path.split(model_name)[-1] + '_' + few_shot_selection + '_'
    name_suffix = ('_final_' if verified else '_') + os.path.split(word_data_path)[-1]
    predictions_list = [
        pd.read_csv(os.path.join(last_result_dir, name_prefix + label_type + name_suffix))
        for label_type in label_types
    ]

    # Add only the columns each further file contributes; rows are aligned by position.
    merged_predictions = predictions_list[0]
    for next_prediction in predictions_list[1:]:
        assert (merged_predictions[text_column] == next_prediction[text_column]).all(), \
            f'prediction files of batch {batch_idx} are not row-aligned on {text_column!r}'
        cols_to_merge = next_prediction.columns.difference(merged_predictions.columns)
        merged_predictions = merged_predictions.join(next_prediction[cols_to_merge], validate='1:1')
    assert (ori_df['context'] == merged_predictions['context']).all(), \
        f'predictions of batch {batch_idx} do not match {word_suffix_name}'

    # A row is kept only if every entity type was labeled successfully.
    all_labeled = merged_predictions[f'{label_types[0]}_answer_successfully_labeled']
    for label_type in label_types[1:]:
        all_labeled = all_labeled & merged_predictions[f'{label_type}_answer_successfully_labeled']
    merged_predictions['All_answer_successfully_labeled'] = all_labeled

    # Both sides get the same fresh 0..n-1 index so the join below pairs rows by position.
    labeled_df = ori_df.loc[all_labeled].reset_index().rename(columns={'index': 'ori_index'})
    cols_to_join = [f'{label_type}_answer_label' for label_type in label_types]
    keep_predictions = merged_predictions.loc[all_labeled, cols_to_join].reset_index(drop=True)
    labeled_df = labeled_df.join(keep_predictions, validate='1:1')
    labeled_df['batch_idx'] = batch_idx
    return labeled_df

In [6]:
def extract_data(batch_idx=None):
    """Split one batch into rows with and without a sentence-level label.

    Reads the aggregated file of the batch from ``last_result_dir`` and the matching
    prediction file from ``output_dir``, checks that their ``tweet`` columns agree
    row by row, displays how many rows were labeled successfully, and splits the
    aggregated rows on ``sentence_class_successfully_labeled``.

    Parameters
    ----------
    batch_idx : int, optional
        Batch number used in the file names. When omitted, the notebook-level loop
        variable ``ix`` is used, so existing zero-argument calls behave as before.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(labeled_df, unlabeled_df)``. Both keep the row labels of the aggregated
        file and carry a ``batch_idx`` column; ``labeled_df`` additionally has the
        ``sentence_class_label`` column.

    Raises
    ------
    AssertionError
        If the ``tweet`` columns of the two files do not match.
    """
    if batch_idx is None:
        # Backward compatibility: the driver loop publishes the batch number as global `ix`.
        batch_idx = ix

    suffix_name = f'unlabeled-first-{batch_idx}.csv'
    name_prefix = os.path.split(model_name)[-1] + '_' + few_shot_selection
    # Verified runs carry an extra "_final" marker in the aggregated file name.
    agg_marker = '_aggregated_final_' if verified else '_aggregated_'
    agg_path = os.path.join(last_result_dir, name_prefix + agg_marker + suffix_name)
    ori_df = pd.read_csv(agg_path)

    # The doubled underscore between COT and cur is part of the existing file names.
    output_name = (
        name_prefix + '_COT_' + '_cur_' + str(use_current_explanation)
        + '_ground_' + str(use_ground_explanation) + '_' + os.path.split(agg_path)[-1]
    )
    predictions = pd.read_csv(os.path.join(output_dir, output_name))
    assert (ori_df['tweet'] == predictions['tweet']).all(), \
        f'{output_name} is not row-aligned with {os.path.split(agg_path)[-1]}'

    success_mask = predictions['sentence_class_successfully_labeled']
    display(success_mask.value_counts())

    labeled_df = ori_df.loc[success_mask].copy()
    unlabeled_df = ori_df.loc[~success_mask].copy()
    # Row labels are untouched on both sides, so the index join pairs each row with its own label.
    labeled_df = labeled_df.join(predictions.loc[success_mask, ['sentence_class_label']], validate='1:1')
    labeled_df['batch_idx'] = batch_idx
    unlabeled_df['batch_idx'] = batch_idx
    return labeled_df, unlabeled_df

In [7]:
# Per-batch results, collected in batch order.
labeled_df_list = []
unlabeled_df_list = []
word_labeled_df_list = []

# `ix` has to be a notebook-level name: both helpers are called without
# arguments and pick the batch number up from it.
for ix in range(1, 21):
    labeled_df, unlabeled_df = extract_data()
    word_labeled_df = process_merge_results()
    labeled_df_list.append(labeled_df)
    unlabeled_df_list.append(unlabeled_df)
    word_labeled_df_list.append(word_labeled_df)

sentence_class_successfully_labeled
True     237
False     72
Name: count, dtype: int64

sentence_class_successfully_labeled
True     424
False     12
Name: count, dtype: int64

sentence_class_successfully_labeled
True     188
False    136
Name: count, dtype: int64

sentence_class_successfully_labeled
True    401
Name: count, dtype: int64

sentence_class_successfully_labeled
True     367
False     19
Name: count, dtype: int64

sentence_class_successfully_labeled
True     487
False     13
Name: count, dtype: int64

sentence_class_successfully_labeled
True    558
Name: count, dtype: int64

sentence_class_successfully_labeled
True    346
Name: count, dtype: int64

sentence_class_successfully_labeled
True     319
False     14
Name: count, dtype: int64

sentence_class_successfully_labeled
True    346
Name: count, dtype: int64

sentence_class_successfully_labeled
True    316
Name: count, dtype: int64

sentence_class_successfully_labeled
True     298
False     46
Name: count, dtype: int64

sentence_class_successfully_labeled
True    334
Name: count, dtype: int64

sentence_class_successfully_labeled
True     238
False     83
Name: count, dtype: int64

sentence_class_successfully_labeled
True     381
False     45
Name: count, dtype: int64

sentence_class_successfully_labeled
True     733
False    256
Name: count, dtype: int64

sentence_class_successfully_labeled
True     645
False    184
Name: count, dtype: int64

sentence_class_successfully_labeled
True     411
False     69
Name: count, dtype: int64

sentence_class_successfully_labeled
True     411
False     72
Name: count, dtype: int64

sentence_class_successfully_labeled
True     418
False     76
Name: count, dtype: int64

In [8]:
# Stack the per-batch frames. The row labels of each batch are kept as they are
# (no ignore_index), so labels can repeat across batches.
labeled_df, unlabeled_df, word_labeled_df = (
    pd.concat(frames)
    for frames in (labeled_df_list, unlabeled_df_list, word_labeled_df_list)
)

In [9]:
# Compare the (rows, columns) of the word-level and the sentence-level labeled frames.
print(word_labeled_df.shape, labeled_df.shape)

(8955, 36) (7858, 35)


In [10]:
# One word-level answer column per entity type.
word_keep_columns = [
    f'{label_type}_answer_label'
    for label_type in ['Food', 'Location', 'Symptom', 'Keyword']
]
# Attach the word-level answers to the sentence-level rows by tweet id;
# validate='1:1' makes pandas raise if an id is duplicated on either side.
full_labeled_df = labeled_df.join(
    word_labeled_df.set_index('id')[word_keep_columns],
    on='id',
    validate='1:1',
)

In [11]:
# Persist the combined frame; the file name records the model and the few-shot selection strategy.
full_labeled_df.to_pickle(f'/scratch/dzhang5/LLM/TWEET-FID/full_labeled_{model_name}_{few_shot_selection}.p')

In [83]:
# Bring in tweet_tokens from the raw frame by id. A column of the same name
# already present in labeled_df is renamed with the '_drop' suffix, while the
# one from unlabeled_data keeps its name.
labeled_df = labeled_df.merge(
    unlabeled_data[['id', 'tweet_tokens']],
    on='id',
    how='left',
    validate='1:1',
    suffixes=('_drop', ''),
)

In [86]:
# Discard the superseded token column left over from the merge with the raw frame.
labeled_df = labeled_df.drop(columns=['tweet_tokens_drop'])

In [96]:
# Entity type as used in the prediction columns -> suffix of the B-/I- token tag.
token_label_map = dict(
    Food='food',
    Location='loc',
    Symptom='symptom',
    Keyword='other',
)

In [130]:
def convert_token_labels(tweet_tokens, entity_dict, label_map=None):
    """Turn per-type entity strings into one B-/I-/O tag per tweet token.

    Each entity string is split on whitespace and searched for in ``tweet_tokens``
    as a run of consecutive tokens, where a token matches an entity word if the
    entity word starts with that token. An entity is tagged only when it is found
    at exactly one position; zero or several positions produce a warning and leave
    the tags untouched.

    Parameters
    ----------
    tweet_tokens : sequence of str
        Tokens of the tweet, in order.
    entity_dict : str, bytes, bytearray or dict
        Mapping of entity type to a list of entity strings, either as JSON text or
        already parsed.
    label_map : dict, optional
        Mapping of entity type to the tag suffix (``'Food' -> 'food'`` gives
        ``B-food`` / ``I-food``). Defaults to the notebook-level ``token_label_map``.

    Returns
    -------
    list of str
        One tag per token: ``'O'``, ``'B-<suffix>'`` or ``'I-<suffix>'``.

    Raises
    ------
    KeyError
        If a matched entity's type is missing from the label map.
    """
    def match_at(words, start):
        """Return the token indices covered by ``words`` when placed at ``start``, else []."""
        stop = start + len(words)
        if stop > n_tokens:
            return []
        for offset, piece in enumerate(words):
            token = tweet_tokens[start + offset]
            # An empty token is a prefix of every string; it must not count as a match.
            if token == '' or not piece.startswith(token):
                return []
        return list(range(start, stop))

    def tag_for(entity_type, is_first):
        """Build the B- (first token) or I- (following tokens) tag for an entity type."""
        # Resolved lazily so the notebook-level map is only needed when a tag is written.
        mapping = token_label_map if label_map is None else label_map
        return f"{'B' if is_first else 'I'}-{mapping[entity_type]}"

    if isinstance(entity_dict, (str, bytes, bytearray)):
        entity_dict = json.loads(entity_dict)
    n_tokens = len(tweet_tokens)
    ans_list = ['O'] * n_tokens
    for k, v in entity_dict.items():
        for word in v:
            words = word.split()
            # An empty or whitespace-only entity has nothing to look for; it is
            # reported through the "no match" warning instead of raising IndexError.
            matches = []
            if words:
                matches = [m for m in (match_at(words, i) for i in range(n_tokens)) if m]
            if len(matches) > 1:
                warnings.warn(f"multiple match of {word} in {tweet_tokens}")
            elif len(matches) == 0:
                warnings.warn(f"no match of {word} in {tweet_tokens}")
            else:
                for position, token_idx in enumerate(matches[0]):
                    if ans_list[token_idx] != 'O':
                        warnings.warn(f"multiple types of entity match the same word: {word}")
                    # A later entity overwrites an earlier tag on the same token.
                    ans_list[token_idx] = tag_for(k, position == 0)
    return ans_list