In [94]:
import sys
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import csv

import random
from typing import List, Any, Dict, Tuple
import copy
from datetime import datetime
from collections import Counter, defaultdict
from scipy import stats
from tabulate import tabulate

sys.path.append("../")

from src.scripts import run_fake_data_test
from src.helpers.visualisation import barplot_distribution, plot_confusion_matrix, tabulate_annotation_pair_summary, analyze_pair_annotations
from src.helpers.io import read_json


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
# --- Field names carrying the `_prev` suffix --------------------------------
# NOTE(review): presumably the earlier revision of the annotation schema; kept
# for reference and not used by the cells visible in this notebook -- confirm.
prompt_fields_prev = [
    "multi_turn_relationship", "media_format", "topic",
    "function_purpose", "anthropomorphization", "restricted_flags",
]
response_fields_prev = [
    "answer_form", "self_disclosure", "topic_response",
    "media_format_response", "restricted_flags_response",
]

# --- Field names carrying the `_new` suffix ---------------------------------
# These are the task names the comparison / display cells below iterate over.
# Each name is prefixed with the level it applies to (`prompt_`, `turn_`,
# `response_`).
prompt_fields_new = [
    "prompt_multi_turn_relationship", "prompt_media_format",
    "prompt_interaction_features", "prompt_function_purpose",
    "turn_topic", "turn_sensitive_use_flags",
]
response_fields_new = [
    "response_answer_form", "response_media_format",
    "response_interaction_features",
]

In [99]:
# Build the dataset object used by the rest of the notebook. Later cells read
# it through the global name `dset`, so this cell must run before them.
# NOTE(review): judging by the log printed below, the call loads
# ../data/sample120.json and also writes/splits files under ../data/ -- confirm
# it is safe to re-run.
dset = run_fake_data_test.run_automatic_analysis_v0("../data/")

Loading conversations from ../data/sample120.json
Loaded 120 conversations.
Updated file: combined.json
Added conversation IDs to 1 files
Split records into two folders:
  - ../data/labelstudio_outputs_split1_v2/: Contains 0 records with unique conversation IDs
  - ../data/labelstudio_outputs_split2_v2/: Contains 83 records with duplicate conversation IDs
starttinginsdgljsdfjksdfjl
prompt_media_format
prompt_function_purpose
prompt_multi_turn_relationship
prompt_interaction_features
turn_sensitive_use_flags
turn_topic
response_media_format
response_answer_form
prompt_media_format
prompt_function_purpose
prompt_multi_turn_relationship
prompt_interaction_features
turn_sensitive_use_flags
turn_topic
response_media_format
response_answer_form

gpt4o-json

prompt-multi_turn_relationship: 2 / 597 failed due to invalid annotations.
prompt_multi_turn_relationship
prompt-interaction_features: 1 / 597 failed due to invalid annotations.
prompt_interaction_features
turn-sensitive_use_flags: 0 / 59

In [100]:
def run_interrater_comparison(
    dataset,
    task_name,
    annotation_source_1,
    annotation_source_2,
):
    """Compare two annotation sources on one task and save the resulting artefacts.

    All queries go through the ``dataset`` argument (the previous version read
    the notebook-level ``dset`` for the bar charts, silently ignoring the
    argument).

    Args:
        dataset: Object exposing ``get_annotation_distribution`` and
            ``get_joint_distribution`` with the keyword arguments used below.
        task_name: Annotation task / field name to compare.
        annotation_source_1: Name of the first annotation source.
        annotation_source_2: Name of the second annotation source.

    Returns:
        The ``paired_values`` object returned by
        ``dataset.get_joint_distribution``.

    Side effects:
        Creates
        ``../data/annotation_analysis_v0/<source_1>--<source_2>/<task_name>/``
        and passes ``barchart.png``, ``multilabel_barchart.png`` and
        ``confusion_matrix.png`` inside it as ``output_path`` to the plotting
        helpers; writes ``pair_frequencies.csv`` there; prints the agreement
        metrics and a summary table to stdout.
    """
    # Message-level distributions for each source, in the default form and
    # with annotation_as_list_type=True.
    info_to_plot1 = dataset.get_annotation_distribution(
        name=task_name, level="message", annotation_source=annotation_source_1)
    info_to_plot2 = dataset.get_annotation_distribution(
        name=task_name, level="message", annotation_source=annotation_source_2)
    info_to_plot1b = dataset.get_annotation_distribution(
        name=task_name, level="message", annotation_source=annotation_source_1,
        annotation_as_list_type=True)
    info_to_plot2b = dataset.get_annotation_distribution(
        name=task_name, level="message", annotation_source=annotation_source_2,
        annotation_as_list_type=True)

    outdir = f"../data/annotation_analysis_v0/{annotation_source_1}--{annotation_source_2}/{task_name}"
    os.makedirs(outdir, exist_ok=True)

    # NOTE(review): the series are labelled "Split1"/"Split2" regardless of the
    # actual source names passed in -- confirm this is the intended legend.
    barplot_distribution(
        {"Split1": info_to_plot1, "Split2": info_to_plot2}, normalize=True,
        xlabel=task_name, ylabel="Proportion", title="",
        output_path=f"{outdir}/barchart.png", order="descending")

    barplot_distribution(
        {"Split1": info_to_plot1b, "Split2": info_to_plot2b}, normalize=True,
        xlabel=task_name, ylabel="Proportion", title="",
        output_path=f"{outdir}/multilabel_barchart.png", order="descending")

    # Joint distribution of the two sources on the same task, plus agreement
    # metrics and the per-message value pairs.
    info_to_plot_cm, agreement_metrics, paired_values = dataset.get_joint_distribution(
        annotations1=(task_name, annotation_source_1),
        annotations2=(task_name, annotation_source_2),
        level="message",
        compute_disagreement=True,
        verbose=True
    )

    plot_confusion_matrix(
        info_to_plot_cm, normalize=True, xlabel="", ylabel="", title="Confusion Matrix",
        output_path=f"{outdir}/confusion_matrix.png")

    # Frequency table of annotation pairs; quote every non-numeric cell in the CSV.
    df = analyze_pair_annotations(paired_values)
    df.to_csv(f"{outdir}/pair_frequencies.csv", index=False, quoting=csv.QUOTE_NONNUMERIC)

    print()
    print(f"-----------------{task_name}-----------------")
    print(agreement_metrics)
    print(tabulate_annotation_pair_summary(df, 20))
    print(len(df))
    print()
    return paired_values


In [101]:
# task_annotations = {}
# for feature in prompt_fields_new:
#     task_annotations[feature] = run_interrater_comparison(dset, feature, "gpt4o_json_full", "gpt4o_free_full")
#     # break
# for feature in response_fields_new:
#     task_annotations[feature] = run_interrater_comparison(dset, feature, "gpt4o_json_full", "gpt4o_free_full")
#     # break

In [102]:
# One id per turn of every conversation in the sample file, formatted as
# "<conversation_id>-<turn>". These ids are what the message-level lookups
# further down are keyed on.
orig_sample = read_json("../data/sample120.json")
ex_ids = [
    conversation["conversation_id"] + "-" + str(message_turn["turn"])
    for conversation in orig_sample["data"]
    for message_turn in conversation["conversation"]
]

In [103]:
# automatic_variants = [
#     "gpt4o_json_full",
#     # "gpt4o_free_full",
#     "gpto3mini_json_full",
#     # "gpto3mini_free_full",
# ]
# focus_keys = [(model_key, field_name) for model_key in automatic_variants for field_name in prompt_fields_new]
# focus_keys.extend([(model_key, field_name) for model_key in automatic_variants for field_name in response_fields_new])
# focus_keys.extend([(split_key, field_name) for split_key in ["split1", "split2"] for field_name in prompt_fields_new])
# focus_keys.extend([(split_key, field_name) for split_key in ["split1", "split2"] for field_name in response_fields_new])
# focus_metadatas = dset.extract_conversation_metadata_by_ids(
#     ex_ids,
#     annotation_keys=focus_keys,
#     level="message",
# )

In [104]:
def display_info_for_turn(
    ex_idx_turn,
    dataset=None,
    relevant_keys=None,
):
    """Print every annotation of one message, its content and the previous turn.

    Args:
        ex_idx_turn: Message id of the form ``"<conversation_id>-<turn>"``.
            The turn number is taken from after the LAST hyphen, so
            conversation ids may themselves contain hyphens.
        dataset: Object exposing ``id_lookup(id, level="message")``. Defaults
            to the notebook-level ``dset``.
        relevant_keys: Task names to display. Defaults to
            ``prompt_fields_new + response_fields_new``.

    Returns:
        None. Everything is written to stdout.
    """
    # Fall back to the notebook globals only when the caller passes nothing,
    # which keeps the original one-argument call working.
    if dataset is None:
        dataset = dset
    if relevant_keys is None:
        relevant_keys = prompt_fields_new + response_fields_new

    # Split on the last hyphen only: a plain split("-") fails to unpack when
    # the conversation id contains a hyphen.
    ex_idx, turn_text = ex_idx_turn.rsplit("-", 1)
    turn = int(turn_text)
    message = dataset.id_lookup(ex_idx_turn, level="message")[ex_idx_turn].to_dict()
    role = message['role']

    # Metadata keys are "<source>-<task>". The task is read from after the
    # last hyphen; keys without a hyphen are skipped instead of raising.
    task_to_source_to_vals = defaultdict(dict)
    for key, annotation in message["metadata"].items():
        source, separator, task = key.rpartition("-")
        if separator and task in relevant_keys:
            task_to_source_to_vals[task][source] = annotation

    print(f"IDX: {ex_idx} | Turn: {turn} | Role: {role}")
    print(f"-------------------------------------------")
    for task, source_vals in task_to_source_to_vals.items():
        print()
        print(f"TASK: {task}")
        for source, val in source_vals.items():
            # Sources whose name contains "split" are labelled with the
            # record's "annotator" field; all others with the source name.
            src_info = val["annotator"] if "split" in source else source
            print(f"{src_info}:   {val['value']}")

    print("\n****** Message Content:******")
    print(message["content"])
    print()

    if turn > 0:
        # Show the preceding turn of the same conversation for context.
        prev_id = ex_idx + "-" + str(turn - 1)
        print("\n****** Previous Turn Message Content:******")
        prev_message = dataset.id_lookup(prev_id, level="message")[prev_id].to_dict()
        print(prev_message["content"])


In [117]:
# Position in `ex_ids` of the turn to inspect; edit the number and re-run this
# cell to page through the sample.
ANNOTATION_TURN = 14
display_info_for_turn(ex_ids[ANNOTATION_TURN])

IDX: wildchat_1d31bdda8c40f114afa0aad43e02a3c9 | Turn: 0 | Role: user
-------------------------------------------

TASK: prompt_media_format
megan:   ['Formatted enumeration / itemization', 'Likely retrieved / pasted content', 'Natural language']
niloofar:   ['Formatted enumeration / itemization', 'Likely retrieved / pasted content', 'Math / symbols', 'Natural language']
gpt4o_json:   ['Natural language', 'Formatted enumeration/itemization']
gpt4o_free:   ['Natural language', 'Formatted enumeration/itemization (bullets/lists)']
gpto3mini_json:   ['Natural language', 'Formatted enumeration/itemization']
gpto3mini_free:   ['Natural language', 'Formatted enumeration/itemization (bullets/lists)']

TASK: prompt_function_purpose
megan:   ['Content generation: Prompts for another AI system']
niloofar:   ['Advice, guidance, & recommendations: Professional advice', 'Editorial & formatting: Natural language style or re-formatting', 'Editorial & formatting: Information processing & re-formatting'

In [41]:
# NOTE(review): rebinds ANNOTATION_TURN, which the display cell above sets to
# 14. This cell's execution count (41) is lower than that cell's (117), so the
# value actually in the kernel depends on run order -- confirm which is meant.
ANNOTATION_TURN = 8

In [42]:
# Fetch the message selected by ANNOTATION_TURN as a plain dict for ad-hoc
# inspection (same lookup that display_info_for_turn performs).
message = dset.id_lookup(ex_ids[ANNOTATION_TURN], level="message")[ex_ids[ANNOTATION_TURN]].to_dict()

In [49]:
# message

In [91]:
# Hard-coded message id to inspect, in the same "<conversation_id>-<turn>"
# format as the entries of `ex_ids`.
EG_ID = "wildchat_20847df802a3268754fe7d7a6ada334b-6"

In [92]:
# Look up the EG_ID message as a plain dict. This overwrites the `message`
# variable assigned by the ANNOTATION_TURN lookup cell above.
message = dset.id_lookup(EG_ID, level="message")[EG_ID].to_dict()

In [108]:
# message

In [109]:
# display_info_for_turn("wildchat_f1675170ab5361f56211e19bacbe1945-1")

In [11]:
# dset.data[3].conversation[2].metadata