In [116]:
import sys
import os
import json
import pandas as pd
import numpy as np
import ast
import itertools
import random
import copy
from datetime import datetime
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

sys.path.append("../")

from src.helpers import io
from src.helpers.dataset_comparison import split_dataset_by, compare_annotations_to_baseline
from src.helpers.visualisation import plot_differences_for_group
from src.classes.dataset import Dataset
from src.classes.annotation_set import AnnotationSet

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [117]:
# FILL IN:
PATH_TO_DATASET = "../data/static/wildchat4k-raw.json"
# NOTE(review): DATASET_ID is not referenced in this cell — confirm a later cell uses it.
DATASET_ID = "wildchat_1m"
PATH_TO_ANNOTATIONS_DIR = "../res/gpto3mini-json-wildchat"
# NOTE(review): unlike the input paths above, OUTDIR has no "../" prefix, so it is
# resolved against the kernel's working directory — confirm this is the intended location.
OUTDIR = "data/annotation_analysis_v0/data-slice-comparison"
os.makedirs(OUTDIR, exist_ok=True)

# Load dataset (w/o annotations)
dataset = Dataset.load(PATH_TO_DATASET)

# Load every annotation file into the dataset, tagged with source "automatic_v0".
# The listing is sorted so files are added in the same order on every run: directory
# listing order is filesystem-dependent (presumably listdir_nohidden wraps such a
# listing — verify), and this keeps the load log, and anything add_annotations may
# do order-sensitively, reproducible under Restart & Run All.
for fpath in sorted(io.listdir_nohidden(PATH_TO_ANNOTATIONS_DIR)):
    annotation_set = AnnotationSet.load_automatic(path=fpath, source="automatic_v0")
    dataset.add_annotations(annotation_set)

prompt-multi_turn_relationship: 0 / 10127 failed due to invalid annotations.
prompt-interaction_features: 0 / 10127 failed due to invalid annotations.
turn-sensitive_use_flags: 0 / 10127 failed due to invalid annotations.
turn-topic: 1 / 10127 failed due to invalid annotations.
response-interaction_features: 0 / 10127 failed due to invalid annotations.
prompt-function_purpose: 6 / 10127 failed due to invalid annotations.
prompt-media_format: 0 / 10127 failed due to invalid annotations.
response-media_format: 0 / 10127 failed due to invalid annotations.
response-answer_form: 0 / 10127 failed due to invalid annotations.


# General-purpose data slice comparison

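### Example: comparing geographic slices

Conversations are split by the first country listed in their `geography` field. The annotation distributions of a comparison group of countries are then compared against those of a baseline group.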

In [118]:
# === Custom Group Definitions ===
# Split by country: the key is the first ";"-separated entry of `conv.geography`.
# Missing, empty, or blank values (e.g. "   " or " ; China") all fall into the
# "Unknown" group instead of producing an empty-string group.
grouped_by_country = split_dataset_by(
    dataset,
    lambda conv: (conv.geography or "").split(";")[0].strip() or "Unknown"
)

baseline_names = ["United States"]
comparison_names = ["China"]


def _combine_country_groups(groups, names, dataset_id):
    """Concatenate the per-country datasets listed in `names` into one Dataset.

    Args:
        groups: mapping of group key -> dataset (as returned by split_dataset_by).
        names: group keys to include, in order.
        dataset_id: id given to the combined Dataset.

    Returns:
        A Dataset whose `data` is the concatenation of `groups[name].data` for
        every name present in `groups`.

    Names absent from `groups` are skipped, but reported, so a typo or a country
    with no conversations does not silently produce an empty slice.
    """
    missing = [name for name in names if name not in groups]
    if missing:
        print(f"⚠️ {dataset_id}: no conversations found for {missing}; skipping them.")
    return Dataset(
        dataset_id=dataset_id,
        data=list(itertools.chain.from_iterable(
            groups[name].data for name in names if name in groups
        )),
    )


baseline_dataset = _combine_country_groups(grouped_by_country, baseline_names, "baseline_group")
comparison_dataset = _combine_country_groups(grouped_by_country, comparison_names, "comparison_group")

# Every compared task uses the same annotation source.
# NOTE(review): the multi-turn-relationship annotations are loaded (see the load log)
# but not compared here — confirm the exclusion is intentional.
ANNOTATION_SOURCE = "automatic_v0"
annotation_source_tasks = [
    (task, ANNOTATION_SOURCE)
    for task in [
        "prompt_function_purpose",
        "prompt_interaction_features",
        "prompt_media_format",
        "response_answer_form",
        "response_interaction_features",
        "response_media_format",
        "turn_topic",
        "turn_sensitive_use_flags",
    ]
]

comparison_results = compare_annotations_to_baseline(
    group_datasets={"ComparisonGroup_vs_Baseline": comparison_dataset},
    baseline_dataset=baseline_dataset,
    annotation_source_tasks=annotation_source_tasks
)




In [119]:
# Plot the per-task annotation differences between the comparison slice and the
# baseline slice, writing the figures into OUTDIR. Labels are the group's country
# names joined with ", ".
_group_key = "ComparisonGroup_vs_Baseline"
_group_diffs = comparison_results[_group_key]
_baseline_label = ", ".join(baseline_names)
_comparison_label = ", ".join(comparison_names)

plot_differences_for_group(
    group_name=_group_key,
    group_diff_data=_group_diffs,
    baseline_label=_baseline_label,
    comparison_label=_comparison_label,
    outdir=OUTDIR,
)

✅ Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_function_purpose_automatic_v0.png
📊 Metrics for prompt_function_purpose (automatic_v0):
   chi2: 531.6428001554893
   p_value: 1.1347551254297907e-90
   jsd: 0.35087313985097246
   wasserstein: 1.1916982400381915
--------------------------------------------------
✅ Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_interaction_features_automatic_v0.png
📊 Metrics for prompt_interaction_features (automatic_v0):
   chi2: 52.2524792284799
   p_value: 4.788342448097323e-10
   jsd: 0.10536087360243997
   wasserstein: 2.9170689081019123
--------------------------------------------------
✅ Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_media_format_automatic_v0.png
📊 Metrics for prompt_media_format (automatic_v0):
   chi2: 119.67056196203019
   p_value: 2.1118341584101363e-20
   jsd: 0.12545914699964428
   wasserstein: 2.0749050097334423
--------------------