In [22]:
import sys
import os
import json
import pandas as pd
import numpy as np
import ast
import itertools
import random
import copy
from datetime import datetime
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

# Add the parent directory to the module search path so that the `src.*`
# imports below can be resolved when the notebook runs from its own folder.
sys.path.append("../")

from src.helpers import io
from src.classes.dataset import Dataset
from src.classes.annotation_set import AnnotationSet

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# FILL IN: user-editable configuration for this notebook.
# Path of the raw dataset file (loaded below without annotations).
PATH_TO_DATASET = "../data/static/wildchat4k-raw.json"
DATASET_ID = "wildchat_1m"
# Directory whose entries are each loaded as one annotation set.
PATH_TO_ANNOTATIONS_DIR = "../res/gpto3mini-json-wildchat"
# Output directory for this analysis; also used as the default `outdir`
# of `plot_differences_for_group` defined further down.
OUTDIR = "data/annotation_analysis_v0/data-slice-comparison"
# Create the output directory up front (no error if it already exists).
os.makedirs(OUTDIR, exist_ok=True)

# Load dataset (w/o annotations)
dataset = Dataset.load(PATH_TO_DATASET)

# Load annotations into dataset
# Every annotation set is tagged with source "automatic_v0", which matches the
# default `annotation_source` of `compare_annotations_to_baseline` below.
# NOTE(review): `io.listdir_nohidden` presumably yields full paths of
# non-hidden entries -- confirm against src.helpers.io.
for fpath in io.listdir_nohidden(PATH_TO_ANNOTATIONS_DIR):
    annotation_set = AnnotationSet.load_automatic(path=fpath, source="automatic_v0")
    dataset.add_annotations(annotation_set)

prompt-multi_turn_relationship: 0 / 10127 failed due to invalid annotations.
prompt-interaction_features: 0 / 10127 failed due to invalid annotations.
turn-sensitive_use_flags: 0 / 10127 failed due to invalid annotations.
turn-topic: 1 / 10127 failed due to invalid annotations.
response-interaction_features: 0 / 10127 failed due to invalid annotations.
prompt-function_purpose: 6 / 10127 failed due to invalid annotations.
prompt-media_format: 0 / 10127 failed due to invalid annotations.
response-media_format: 0 / 10127 failed due to invalid annotations.
response-answer_form: 0 / 10127 failed due to invalid annotations.


# General-purpose data slice comparison

In [37]:
# Split dataset by group function
def split_dataset_by(dataset, group_fn):
    """Partition a dataset's conversations into one sub-dataset per group.

    Args:
        dataset: Object exposing ``.data`` (iterable of conversations) and
            ``.dataset_id``.
        group_fn: Callable mapping a conversation to a hashable group key.

    Returns:
        dict: ``{group_key: Dataset}`` in order of first appearance of each
        key. Each sub-dataset's id is ``"<dataset.dataset_id>_<group_key>"``
        and its data is the list of conversations mapped to that key.
    """
    buckets = {}
    for conversation in dataset.data:
        # setdefault creates the bucket the first time a key is seen.
        buckets.setdefault(group_fn(conversation), []).append(conversation)

    partitions = {}
    for key, members in buckets.items():
        partitions[key] = Dataset(dataset_id=f"{dataset.dataset_id}_{key}", data=members)
    return partitions

# Compare each group to baseline
def compare_annotations_to_baseline(group_datasets, baseline_dataset, annotation_tasks, annotation_source="automatic_v0"):
    """Compare each group's label distribution against a baseline dataset.

    For every ``(task, level)`` pair the labels are counted in the baseline
    and in each group, converted to a percentage of that dataset's total
    label count, and the per-label difference ``group_pct - base_pct`` is
    reported.

    Args:
        group_datasets: Mapping ``group_name -> dataset``; each dataset
            exposes ``.data`` (an iterable of conversations).
        baseline_dataset: Dataset used as the reference distribution.
        annotation_tasks: Iterable of ``(task_name, level)`` tuples. ``level``
            is one of:
              * ``"conversation"`` - the label is the attribute ``task_name``
                of the conversation object, counted once per conversation;
              * ``"prompt"``   - read from messages whose role is ``"user"``;
              * ``"response"`` - read from messages whose role is
                ``"assistant"``;
              * ``"turn"``     - read from every message.
            Any other level yields an empty distribution.
        annotation_source: Prefix of the message metadata key, which has the
            form ``"<annotation_source>-<level>_<task_name>"``.

    Returns:
        dict: ``{group_name: {(task, level): rows}}`` where ``rows`` is a list
        of ``(label, diff, group_pct, base_pct)`` tuples. ``diff`` is rounded
        to two decimals. Rows are sorted by descending ``abs(diff)``, ties
        broken by ``str(label)`` so the order is reproducible across runs.
    """
    # Message-level levels and the role a message must have to be counted
    # (None means any role).
    role_for_level = {"prompt": "user", "response": "assistant", "turn": None}

    # The task list is walked once for the baseline and once per group;
    # materialize it so a generator argument is not exhausted after one pass.
    annotation_tasks = list(annotation_tasks)

    def get_label_counts(dset, task_name, level):
        """Count label occurrences for one task at one level in ``dset``."""
        counter = Counter()

        def add_label(label):
            # Falsy labels (None, "", []) are treated as "no annotation".
            if not label:
                return
            if isinstance(label, list):
                # Multi-label annotation: every entry counts once.
                counter.update(label)
            else:
                counter[label] += 1

        if level == "conversation":
            # Count once per conversation, independent of how many messages
            # the conversation contains.
            for conv in dset.data:
                add_label(getattr(conv, task_name, None))
            return counter

        if level not in role_for_level:
            return counter

        required_role = role_for_level[level]
        key = f"{annotation_source}-{level}_{task_name}"
        for conv in dset.data:
            for msg in conv.conversation:
                if required_role is not None and msg.role != required_role:
                    continue
                if key in msg.metadata:
                    add_label(msg.metadata[key].value)
        return counter

    # Baseline counts are computed once and reused for every group.
    baseline_distributions = {
        (task, level): get_label_counts(baseline_dataset, task, level)
        for (task, level) in annotation_tasks
    }

    comparison_results = {}
    for group_name, group_dataset in group_datasets.items():
        group_results = {}
        for (task, level) in annotation_tasks:
            baseline = baseline_distributions[(task, level)]
            group = get_label_counts(group_dataset, task, level)
            total_base = sum(baseline.values())
            total_group = sum(group.values())

            differences = []
            # Union of labels: a label missing on one side counts as 0 there.
            for label in set(baseline) | set(group):
                base_pct = (baseline[label] / total_base) * 100 if total_base else 0
                group_pct = (group[label] / total_group) * 100 if total_group else 0
                diff = round(group_pct - base_pct, 2)
                differences.append((label, diff, group_pct, base_pct))

            # Largest absolute difference first; label text breaks ties so
            # the result does not depend on set iteration order.
            differences.sort(key=lambda row: (-abs(row[1]), str(row[0])))
            group_results[(task, level)] = differences

        comparison_results[group_name] = group_results

    return comparison_results

# Visualization: Diverging bar chart for one group's difference from baseline
def plot_differences_for_group(group_name, group_diff_data, baseline_label, comparison_label, outdir=OUTDIR):
    """Save one diverging horizontal bar chart per ``(task, level)`` entry.

    Args:
        group_name: Accepted for call-site compatibility; not used in the body.
        group_diff_data: Mapping ``(task, level) -> rows`` where each row is a
            4-tuple ``(label, diff, group_pct, base_pct)``. Entries with no
            rows are reported and skipped.
        baseline_label: Human-readable baseline name (title and file name).
        comparison_label: Human-readable comparison name (title and file name).
        outdir: Directory the PNG files are written to; created if missing.

    Each figure is written to
    ``<outdir>/diff_<comparison initials><baseline initials>_<level>_<task>.png``
    and closed afterwards. Nothing is returned.
    """
    def initials(text):
        # First character of every whitespace-separated word, upper-cased.
        return ''.join([word[0].upper() for word in text.split()])

    os.makedirs(outdir, exist_ok=True)

    for (task, level), rows in group_diff_data.items():
        if not rows:
            print(f"[Skipped] No data available to plot for {task} @ {level}")
            continue

        # Transpose the rows; only labels and differences are drawn.
        tick_labels, deltas, _group_pcts, _base_pcts = zip(*rows)
        # Strictly positive differences are green, everything else red.
        bar_colors = ["green" if delta > 0 else "red" for delta in deltas]

        # Height grows with the number of labels, with a floor of 6.
        fig, ax = plt.subplots(figsize=(10, max(6, len(tick_labels) * 0.5)))
        positions = np.arange(len(tick_labels))
        bars = ax.barh(positions, deltas, color=bar_colors)
        ax.set_yticks(positions)
        ax.set_yticklabels(tick_labels)
        ax.axvline(0, color="black", linewidth=0.8)
        ax.set_title(f"{comparison_label} vs {baseline_label}\nTop Differences for {task} @ {level}")
        ax.set_xlabel("Percentage Difference from Baseline")
        # Put the first row at the top of the chart.
        ax.invert_yaxis()

        # Value label just beyond the end of each bar, on the side it points to.
        for delta, bar in zip(deltas, bars):
            width = bar.get_width()
            points_right = width > 0
            ax.text(
                width + (0.5 if points_right else -0.5),
                bar.get_y() + bar.get_height() / 2,
                f"{delta:+.1f}%",
                va='center', ha='left' if points_right else 'right', fontsize=9,
            )

        fname = f"diff_{initials(comparison_label)}{initials(baseline_label)}_{level}_{task}.png"
        plot_path = os.path.join(outdir, fname)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
        print(f"Saved plot: {plot_path}")

### Use geography as an example

In [None]:
# === Custom Group Definitions ===
# Partition conversations by country: the text before the first ";" in
# `conv.geography`, stripped of surrounding whitespace, or "Unknown" when the
# geography value is empty or missing.
# (Other candidate split criteria: user type, source dataset, language.)
def _first_geography_field(conv):
    if not conv.geography:
        return "Unknown"
    return conv.geography.split(";")[0].strip()

grouped_by_country = split_dataset_by(dataset, _first_geography_field)

# Define baseline and comparison groups
baseline_names = ["United States"]
comparison_names = ["China"]

# Pool the conversations of every listed country into one dataset per side.
# Names that have no matching group are skipped without error.
baseline_dataset = Dataset(
    dataset_id="baseline_group",
    data=[
        conv
        for country in baseline_names
        if country in grouped_by_country
        for conv in grouped_by_country[country].data
    ],
)
comparison_dataset = Dataset(
    dataset_id="comparison_group",
    data=[
        conv
        for country in comparison_names
        if country in grouped_by_country
        for conv in grouped_by_country[country].data
    ],
)

# (task, level) pairs to compare; the order here is the order plots are produced in.
annotation_tasks = [
    # prompt-level tasks
    ("function_purpose", "prompt"),
    ("interaction_features", "prompt"),
    ("media_format", "prompt"),
    # response-level tasks
    ("answer_form", "response"),
    ("interaction_features", "response"),
    ("media_format", "response"),
    # turn-level tasks
    ("topic", "turn"),
    ("sensitive_use_flags", "turn"),
]

# Single key under which the one comparison is stored and looked up again.
comparison_key = "ComparisonGroup_vs_Baseline"

comparison_results = compare_annotations_to_baseline(
    group_datasets={comparison_key: comparison_dataset},
    baseline_dataset=baseline_dataset,
    annotation_tasks=annotation_tasks,
    annotation_source="automatic_v0",
)

plot_differences_for_group(
    group_name=comparison_key,
    group_diff_data=comparison_results[comparison_key],
    baseline_label=", ".join(baseline_names),
    comparison_label=", ".join(comparison_names),
)


Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_function_purpose.png
Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_interaction_features.png
Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_prompt_media_format.png
[Skipped] No data available to plot for answer_form @ response
[Skipped] No data available to plot for interaction_features @ response
[Skipped] No data available to plot for media_format @ response
Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_turn_topic.png
Saved plot: data/annotation_analysis_v0/data-slice-comparison/diff_CUS_turn_sensitive_use_flags.png
