In [None]:
import collections
import itertools
import json
import warnings
from pathlib import Path

import numpy as np
from matplotlib import lines, patches, pyplot as plt

from evaluation.models import SuggestionLog
from evaluation.utils import get_all_suggestions, get_experiment_infos, REPO_ROOT
from experiments.models import ALL_MODELS

# Row ordering used by the timeline plots: every known model, then the gold labels last.
MODEL_SORT_ORDER = [*ALL_MODELS, "gold"]

In [None]:
def fmt_model_id(model_id):
    """Turn a ``"<model>.<modality>-<prompting>"`` key into a short axis label.

    For example ``"openai-mini.audio-zeroshot"`` becomes ``"gpt-4o-mini (audio)"``:
    the ``-fewshot`` / ``-zeroshot`` suffix is dropped and ``openai`` is shown as ``gpt-4o``.

    Args:
        model_id: model key. Keys without a ``.`` (e.g. ``"gold"``) are returned unchanged.

    Returns:
        ``"<model> (<modality>)"`` for dotted keys, otherwise ``model_id`` itself.
    """
    if "." not in model_id:
        return model_id
    label = model_id.replace("-fewshot", "").replace("-zeroshot", "").replace("openai", "gpt-4o")
    # Split on the LAST dot so model names that themselves contain a dot (e.g. "qwen-2.5")
    # still yield exactly two parts instead of raising ValueError on unpacking.
    model, modality = label.rsplit(".", 1)
    return f"{model} ({modality})"


def multimodel_suggestion_graph(suggests, fraction_visible=1, fraction_startingPoint=0, legend=False):
    """Plot a timeline of suggestions with one row per model, coloured by suggestion type.

    Every suggestion becomes a short vertical tick at its ``end`` time on its model's row.
    Reading the y axis from top to bottom, models whose key contains ``"text"`` come first,
    then all other models (e.g. audio models and ``"gold"``). Within each group, rows follow
    ``MODEL_SORT_ORDER``. Keys missing from ``MODEL_SORT_ORDER`` are placed after the known
    keys of their group instead of raising.

    Args:
        suggests: iterable of suggestion logs exposing ``.end``, ``.model_key`` and
            ``.suggestion["suggest_type"]``. It is converted to a list, so generators work.
        fraction_visible: width of the visible x window, as a fraction of the latest end time.
            Values outside [0, 1] fall back to 1/10.
        fraction_startingPoint: start of the visible window, as a fraction of the latest end
            time. Values outside [0, 1 - fraction_visible] fall back to 0.
        legend: if True, draw a legend mapping colours to suggestion types, anchored at
            (1.37, 0.51) in axes coordinates, i.e. to the right of the plot area.

    Returns:
        The matplotlib Axes the timeline was drawn on.

    Raises:
        ValueError: if ``suggests`` is empty (there is no time range to show).
    """
    suggests = list(suggests)
    if not suggests:
        raise ValueError("multimodel_suggestion_graph: no suggestions to plot")

    # One default-cycle colour ("C0", "C1", ...) per suggestion type, in sorted type order.
    colors = {t: f"C{i}" for i, t in enumerate(sorted({s.suggestion["suggest_type"] for s in suggests}))}

    def row_sort_key(model_key):
        rank = MODEL_SORT_ORDER.index(model_key) if model_key in MODEL_SORT_ORDER else len(MODEL_SORT_ORDER)
        return ("text" not in model_key, rank, model_key)

    # Keys are sorted descending and then given increasing y offsets, so the first key in
    # this list ends up on the bottom row and the last on the top row.
    row_keys = sorted({s.model_key for s in suggests}, key=row_sort_key, reverse=True)
    offsets = {k: (i + 1) * 2 for i, k in enumerate(row_keys)}

    # Batch ticks by (row, colour) so each group is one eventplot call instead of one per suggestion.
    ends_by_group = collections.defaultdict(list)
    for s in suggests:
        ends_by_group[(s.model_key, s.suggestion["suggest_type"])].append(s.end)

    _, graph = plt.subplots()
    for (model_key, suggest_type), ends in ends_by_group.items():
        graph.eventplot(ends, lineoffsets=offsets[model_key], linewidth=0.7, colors=colors[suggest_type])

    # Fall back to defaults for out-of-range window parameters.
    if fraction_visible < 0 or fraction_visible > 1:
        fraction_visible = 1 / 10
    if fraction_startingPoint < 0 or fraction_startingPoint > 1 - fraction_visible:
        fraction_startingPoint = 0
    x_max = max(s.end for s in suggests)
    graph.set(
        xlabel="Time (s)",
        xlim=(x_max * fraction_startingPoint, x_max * (fraction_startingPoint + fraction_visible)),
        ylabel="",
        yticks=list(offsets.values()),
        yticklabels=[fmt_model_id(k) for k in offsets],
    )
    if legend:
        # Two invisible handles render a "Task" heading and a blank spacer above the colour entries.
        graph.legend(
            handles=[patches.Patch(alpha=0) for _ in range(2)]
            + [lines.Line2D(color=c, xdata=[], ydata=[]) for c in colors.values()],
            labels=["Task", ""] + list(colors.keys()),
            loc="lower left",
            bbox_to_anchor=(1.37, 0.51),
        )
    return graph

In [None]:
# Load one session's model suggestions plus its stopwatch gold labels for the plots below.
experiment_ids = ["starless-lands-s18p1"]
all_suggestions = list(itertools.chain.from_iterable(get_all_suggestions(eid) for eid in experiment_ids))
# Every known model gets an entry, even if it produced no suggestions in these sessions.
suggestions_by_model = {model_id: [s for s in all_suggestions if s.model_key == model_id] for model_id in ALL_MODELS}

# Only the first value of the top-level JSON object is used.
# NOTE(review): presumably the annotations for experiment_ids[0]; confirm.
# JSON is UTF-8 by spec, so don't rely on the platform's locale encoding.
with open(REPO_ROOT / "evaluation/gold/gold-stopwatch.json", encoding="utf-8") as file:
    gs = list(json.load(file).values())[0]
# Each stopwatch annotation becomes an instantaneous "gold" entry (start == end == its time);
# the last element of "matches" is used as its suggest_type.
gold_suggests_17 = [
    SuggestionLog(
        start=g["time"],
        end=g["time"],
        suggestion={"suggest_type": g["matches"][-1]},
        model_key="gold",
        experiment_info={"id": "", "name": "", "log_dir": "", "pcm_fp": "", "transcript_fp": ""},
    )
    for g in gs
]

In [None]:
# Single-session timeline for a few models alongside the stopwatch gold labels.
# The trailing `;` suppresses the returned Axes repr; the figure itself still renders.
multimodel_suggestion_graph(
    [
        s
        for model_id in (
            "qwen-25.audio-fewshot",
            "phi-4.audio-fewshot",
            "openai-mini.audio-zeroshot",
            "openai-mini.text-zeroshot",
        )
        for s in suggestions_by_model[model_id]
    ]
    + gold_suggests_17
);

In [None]:
from dataclasses import dataclass
from evaluation.models import LabelledSuggestion


@dataclass
class StopwatchSuggestion:
    """Gold-labelled suggestions from stopwatch annotation"""

    time: float  # the time the annotation happened
    matches: list[str]  # the list of strings the suggestion must match to satisfy this label
    antimatches: list[str] | None = None  # a list of strings that cannot match


def get_gold_labels(
    experiment_id: str, gold_dir: str | Path | None = None
) -> list[LabelledSuggestion | StopwatchSuggestion]:
    """Load the gold labels for one experiment from ``<gold_dir>/<experiment_id>.jsonl``.

    Blank lines are skipped. Each remaining line is a JSON object: objects with a
    ``"matches"`` key become ``StopwatchSuggestion``; all others are validated as
    ``LabelledSuggestion``.

    Args:
        experiment_id: experiment/session id; selects the ``.jsonl`` file.
        gold_dir: directory holding the gold files. Defaults to
            ``REPO_ROOT / "evaluation" / "gold"`` (the directory that also holds
            ``gold-stopwatch.json``), so the result is independent of the notebook's cwd.

    Returns:
        The labels, in file order.
    """
    if gold_dir is None:
        gold_dir = REPO_ROOT / "evaluation" / "gold"
    # TODO: an alternative source, annotations/to-dedup-{experiment_id}.jsonl, was stubbed here.
    out = []
    with open(Path(gold_dir) / f"{experiment_id}.jsonl", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            if "matches" in data:
                out.append(StopwatchSuggestion(**data))
            else:
                out.append(LabelledSuggestion.model_validate(data))
    return out

In [None]:
def _shift_log(log, offset, **extra):
    """Return a copy of ``log`` with ``start`` and ``end`` moved ``offset`` later (plus any ``extra`` field updates)."""
    return log.model_copy(update={"start": log.start + offset, "end": log.end + offset, **extra})


def _gold_entry_to_log(gold, offset):
    """Convert one gold label into a ``"gold"``-keyed suggestion log shifted by ``offset``."""
    if isinstance(gold, LabelledSuggestion):
        return _shift_log(gold.entry, offset, model_key="gold")
    # Stopwatch labels are instantaneous (start == end); the last "matches" entry is the suggest_type.
    return SuggestionLog(
        start=gold.time + offset,
        end=gold.time + offset,
        suggestion={"suggest_type": gold.matches[-1]},
        model_key="gold",
        experiment_info={"id": "", "name": "", "log_dir": "", "pcm_fp": "", "transcript_fp": ""},
    )


def graph_all(
    models,
    gold=True,
    fraction_visible=1,
    fraction_startingPoint=0,
    gold_suggest_types=("foundry", "gamedata", "improvise_npc"),
):
    """Plot every session's suggestions on one continuous timeline.

    Sessions from ``get_experiment_infos()`` are laid end to end: each session's times are
    shifted by the latest end time plotted for the earlier sessions, so sessions don't overlap.

    Args:
        models: model keys to plot. Keys not in ``ALL_MODELS`` are skipped with a warning.
        gold: also plot gold labels (from ``get_gold_labels``) on a ``"gold"`` row.
        fraction_visible: forwarded to ``multimodel_suggestion_graph``.
        fraction_startingPoint: forwarded to ``multimodel_suggestion_graph``.
        gold_suggest_types: only gold labels whose ``suggest_type`` is in this collection are kept.

    Returns:
        The Axes returned by ``multimodel_suggestion_graph``.
    """
    known_models = [m for m in models if m in ALL_MODELS]
    unknown_models = [m for m in models if m not in ALL_MODELS]
    if unknown_models:
        warnings.warn(f"graph_all: ignoring models not in ALL_MODELS: {unknown_models}")

    suggests = collections.defaultdict(list)
    adjust = 0  # running time offset: latest end time plotted so far
    for experiment in get_experiment_infos():
        session = experiment.id
        # Load once per session (not once per model); list() so it can be scanned repeatedly.
        session_suggestions = list(get_all_suggestions(session))
        session_entries = []
        for model in known_models:
            shifted = [_shift_log(s, adjust) for s in session_suggestions if s.model_key == model]
            suggests[model] += shifted
            session_entries += shifted
        if gold:
            # TODO: validate that gold labels exist for every session instead of hard-coding types.
            golds = [_gold_entry_to_log(g, adjust) for g in get_gold_labels(session)]
            golds = [g for g in golds if g.suggestion["suggest_type"] in gold_suggest_types]
            suggests["gold"] += golds
            session_entries += golds
        # Only this session's entries can extend the timeline; an empty session leaves it unchanged.
        adjust = max([adjust, *(s.end for s in session_entries)])
    return multimodel_suggestion_graph(
        [s for model in suggests for s in suggests[model]],
        fraction_visible=fraction_visible,
        fraction_startingPoint=fraction_startingPoint,
    )

In [None]:
# Headline figure: all sessions on one timeline for ten model/modality variants,
# showing the first third of the total time span, saved as a PDF.
font = {"family": "serif", "size": 16}
plt.rc("font", **font)  # global rc change: also affects figures drawn after this cell
ax = graph_all(
    [
        "openai.audio-zeroshot",
        "openai.text-zeroshot",
        "openai-mini.audio-zeroshot",
        "openai-mini.text-zeroshot",
        "ultravox.audio-fewshot",
        "ultravox.text-fewshot",
        "qwen-25.audio-fewshot",
        "qwen-25.text-fewshot",
        "phi-4.audio-fewshot",
        "phi-4.text-fewshot",
    ],
    fraction_visible=0.33,
)
# savefig raises if the output directory is missing, so create it first; save the figure
# that graph_all drew rather than whatever pyplot considers "current".
fig_dir = Path("figs")
fig_dir.mkdir(parents=True, exist_ok=True)
ax.figure.savefig(fig_dir / "timeline.pdf", bbox_inches="tight")

In [None]:
# deprecated, use multimodel_suggestion_graph
def graph_suggestions(suggestions, fraction_visible=1 / 10, fraction_startingPoint=0, golds=None):
    """Deprecated single-model timeline: one row per suggestion type, gold labels in yellow.

    Each item is drawn as ticks at its ``start`` and ``end`` plus a line joining them, on the
    row of its ``suggest_type``. Gold items use yellow, and their ticks sit 0.2 above the row.

    Args:
        suggestions: suggestion logs exposing ``.start``, ``.end`` and ``.suggestion["suggest_type"]``.
        fraction_visible: width of the visible x window as a fraction of the latest end time.
            Values outside [0, 1] fall back to 1/10.
        fraction_startingPoint: start of the window as a fraction of the latest end time.
            Values outside [0, 1 - fraction_visible] fall back to 0.
        golds: optional gold entries with the same attributes as ``suggestions``.

    Returns:
        The matplotlib Axes the timeline was drawn on.

    Raises:
        ValueError: if there is nothing to plot.
    """
    suggestions = list(suggestions)
    golds = list(golds) if golds else []
    if not suggestions and not golds:
        raise ValueError("graph_suggestions: no suggestions to plot")

    suggest_types = np.unique([s.suggestion["suggest_type"] for s in suggestions + golds])
    suggest_colors = {value: index + 1 for index, value in enumerate(suggest_types)}

    # (start, end, type index, is_gold) -- an explicit flag instead of a length-dependent sentinel.
    rows = [(s.start, s.end, suggest_colors[s.suggestion["suggest_type"]], False) for s in suggestions]
    rows += [(g.start, g.end, suggest_colors[g.suggestion["suggest_type"]], True) for g in golds]

    _, graph = plt.subplots()
    for start, end, type_index, is_gold in rows:
        color = "y" if is_gold else f"C{type_index}"
        row_y = type_index * 2
        graph.eventplot([start, end], lineoffsets=row_y + 0.2 if is_gold else row_y, linewidth=0.7, colors=color)
        graph.plot([start, end], [row_y, row_y], color=color)

    if fraction_visible < 0 or fraction_visible > 1:
        fraction_visible = 1 / 10
    if fraction_startingPoint < 0 or fraction_startingPoint > 1 - fraction_visible:
        fraction_startingPoint = 0

    # Scale the window by the latest end over everything plotted; inputs are not assumed
    # sorted (e.g. concatenated sessions each restart near 0).
    x_end = max(item.end for item in suggestions + golds)
    x_min = x_end * fraction_startingPoint
    x_max = x_min + x_end * fraction_visible

    y_ticks = [index * 2 for index in suggest_colors.values()]
    graph.set(xlabel="Time (s)", xlim=(x_min, x_max), ylabel="Suggest Type", yticks=y_ticks, yticklabels=suggest_types)
    return graph

In [None]:
# Every suggestion from every experiment, concatenated in experiment order (deprecated plot).
suggests_all = list(
    itertools.chain.from_iterable(get_all_suggestions(info.id) for info in get_experiment_infos())
)
graph_suggestions(suggests_all)

In [None]:
# Deprecated per-type view of phi-4 (text) on the single session loaded above, with its stopwatch gold labels.
# NOTE(review): `phi4t_suggests_17` was never defined in this notebook (fails on Restart & Run All);
# it is assumed to mean the phi-4 text-fewshot suggestions -- confirm.
phi4t_suggests_17 = suggestions_by_model["phi-4.text-fewshot"]
graph_suggestions(phi4t_suggests_17, fraction_visible=1, golds=gold_suggests_17)