In [None]:
# !unzip relations.zip

In [None]:
import os
import json
from sklearn.metrics import multilabel_confusion_matrix
import numpy as np
import pandas as pd

In [None]:
def read_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
    with open(path, mode='r', encoding="utf-8") as handle:
        return json.load(handle)

def write_json(data, path):
    """Serialize ``data`` to ``path`` as indented, UTF-8 encoded JSON.

    Missing parent directories of ``path`` are created first. Non-ASCII
    characters are written as-is (``ensure_ascii=False``).

    Args:
        data: Any JSON-serializable object.
        path: Destination file path; may be a bare file name.
    """
    parent_dir = os.path.dirname(path)
    # A bare file name has an empty dirname and os.makedirs('') raises, so
    # only create directories when there actually is a parent component.
    if parent_dir:
        # exist_ok avoids failing when the directory appears between the
        # check and the creation (or simply already exists).
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

In [None]:
def calculate_false(test_folder, prediction_file_path, run_id, task_id):
  """Count predictions that differ from the gold relations of tasks 1..task_id.

  Gold labels are the 'relation' fields read from
  ``<test_folder>/run_<run_id>/task<t>/test_1.json`` for ``t = 1..task_id``,
  concatenated in task order. Predictions are the 'predict' fields of the
  JSON list stored at ``prediction_file_path``. Both sequences are compared
  position by position.

  Args:
    test_folder: Root folder of the test sets (with or without trailing slash).
    prediction_file_path: Path of the JSON prediction file.
    run_id: Run whose test sets are loaded.
    task_id: Last task (inclusive) whose test set is included.

  Returns:
    Number of positions where prediction and gold relation differ.

  Raises:
    IndexError: If there are fewer predictions than gold relations.
  """
  test_relations = []
  for t in range(1, task_id + 1):
    # Interpolate the real ids. Combining an f-string with str.format() turned
    # the "{0}"/"{1}" placeholders into the literals 0 and 1, so every call
    # used to read run_0/task1 regardless of run_id and t.
    input_path = os.path.join(test_folder, f"run_{run_id}", f"task{t}", "test_1.json")
    task_data = read_json(input_path)
    test_relations.extend(item['relation'] for item in task_data)

  predictions = read_json(prediction_file_path)
  prediction_relations = [item['predict'] for item in predictions]

  # Same exception type as the former index lookup, but with a usable message.
  if len(prediction_relations) < len(test_relations):
    raise IndexError(
      f"{prediction_file_path} holds {len(prediction_relations)} predictions "
      f"but {len(test_relations)} test relations were loaded"
    )

  # zip stops at the gold labels; surplus predictions are ignored, as before.
  return sum(
    1
    for predicted, expected in zip(prediction_relations, test_relations)
    if predicted != expected
  )

In [None]:
def calculate_non_defined(relations, prediction_file_path, run_id, task_id):
  """Count predicted labels that are absent from the run/task relation file.

  Args:
    relations: Base folder of the relation files; the file read is
      ``<relations>/run_<run_id>/task<task_id>.json``.
    prediction_file_path: JSON file with a list of items carrying a 'predict' field.
    run_id: Run selecting the relation file.
    task_id: Task selecting the relation file.

  Returns:
    Number of predictions not contained in the loaded relation collection.
  """
  predicted_labels = [entry['predict'] for entry in read_json(prediction_file_path)]
  # Membership is tested with `in` against whatever the JSON file decodes to
  # (presumably a list of relation names — TODO confirm).
  known_relations = read_json(f"{relations}/run_{run_id}/task{task_id}.json")
  return sum(1 for label in predicted_labels if label not in known_relations)


In [None]:
def main(input_folder, test_folder, relations_folder):
  """Collect error statistics for runs 1..5 and tasks 1..10.

  For every (run, task) pair the matching prediction file is located and
  handed to ``calculate_non_defined`` and ``calculate_false``.

  Args:
    input_folder: Prefix of the prediction directories; ``_<run_id>_extracted``
      is appended to it.
    test_folder: Forwarded to ``calculate_false``.
    relations_folder: Forwarded to ``calculate_non_defined``.

  Returns:
    List of dicts with keys 'run_id', 'task_id', 'non_defined' and 'sum_false'.
  """
  rows = []
  for run_id in range(1, 6):
    for task_id in range(1, 11):
      print(f"run_id: {run_id}, task_id: {task_id}")
      # The first task uses a differently named prediction file than later tasks.
      if task_id != 1:
        pred_path = f"{input_folder}_{run_id}_extracted/task_{task_id}_seen_task.json"
      else:
        pred_path = f"{input_folder}_{run_id}_extracted/task_task{task_id}_current_task_pred.json"
      undefined_total = calculate_non_defined(relations_folder, pred_path, run_id, task_id)
      wrong_total = calculate_false(test_folder, pred_path, run_id, task_id)
      rows.append({
        'run_id': run_id,
        'task_id': task_id,
        'non_defined': undefined_total,
        'sum_false': int(wrong_total),
      })

  return rows



In [None]:
if __name__ == '__main__':
  # Prefix of the per-run prediction directories; main() appends
  # "_<run_id>_extracted/..." to it.
  input_folder = "./m_10/KMmeans_CRE_tacred"
  # Root folder of the per-run, per-task test sets.
  test_folder = "./llama_format_data/test/"
  # Root folder of the per-run, per-task relation files.
  relations_folder = "./relations/"
  # `results` stays a notebook-level variable and is reused by the cells below.
  results = main(input_folder, test_folder, relations_folder)

In [None]:
# Persist the per-(run, task) error records as JSON.
# NOTE(review): the file name mentions t5/fewrel while input_folder points at
# a tacred directory — confirm the intended output name.
write_json(results, "./content/t5_fewrel_false_analysis.json")

In [None]:
# Convert the per-(run, task) result records to a DataFrame
df = pd.DataFrame(results)

# Group by task_id and average both metrics over the runs.
# NOTE(review): the second column is named `sum_sum_false` but is aggregated
# with 'mean', not 'sum' — confirm which one is intended.
llama_result_df = df.groupby('task_id').agg(mean_non_defined=('non_defined', 'mean'),
                                   sum_sum_false=('sum_false', 'mean')).reset_index()

In [None]:
# Llama2-7B: print the transposed per-task summary as a LaTeX table
print(llama_result_df.T.to_latex())

In [None]:
# t5
# print(t5_result_df.to_latex())

In [None]:
# mistral
# print(mistral_result_df.to_latex())

In [None]:
# IPython shell escape: recursively archive ./content/ into tacred_error.zip
!zip -r tacred_error.zip ./content/

In [None]:
# Cumulative test-set size per (run, task): after task k, the test set holds
# the examples of tasks 1..k of that run.
data_stats = []
for run_id in range(1, 6):
  # Grows across the tasks of this run, so each test file is read only once.
  test_relations = []
  for task_id in range(1, 11):
    # Interpolate the real ids. Combining an f-string with str.format() turned
    # the "{0}"/"{1}" placeholders into the literals 0 and 1, so every
    # iteration used to read run_0/task1.
    input_path = os.path.join(test_folder, f"run_{run_id}", f"task{task_id}", "test_1.json")
    task_data = read_json(input_path)
    test_relations.extend(item['relation'] for item in task_data)
    print(f"Count: {len(test_relations)}")
    stat = {'run_id': run_id, 'task_id': task_id, 'count': len(test_relations)}
    data_stats.append(stat)

In [None]:
data_stats_df = pd.DataFrame(data_stats)

# Group by task_id and average the cumulative test-set size over the runs.
# The column is named after what it holds (mean of 'count'); it was previously
# labelled `mean_non_defined`, which mislabelled the row of the LaTeX table.
mean_data_stats_df = data_stats_df.groupby('task_id').agg(mean_count=('count', 'mean')).reset_index()

In [None]:
# Print the transposed per-task mean statistics as a LaTeX table
print(mean_data_stats_df.T.to_latex())