# Analyze structured outputs across LLMs

In [None]:
import polars as pl
import numpy as np
import plotly.io as pio
import plotly.express as px
from pathlib import Path
from itertools import combinations
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, leaves_list
import plot_theme as pt

In [None]:
# Set the default Plotly template for every figure created below: the
# "plotly_white" template combined (via "+") with a template named "cc".
# NOTE(review): "cc" is presumably registered as a side effect of
# `import plot_theme` — confirm.
pio.templates.default = "plotly_white+cc"

In [None]:
# Parent directory that is searched for `structured*` result folders.
folder = Path('../')
# Collect every parquet file one level below a `structured*` directory.
# The matches are sorted because glob order depends on the filesystem, and
# the order of `files` determines the row order of the concatenated frame
# (and therefore which duplicate the "first" aggregation in the pivot keeps).
files = sorted(folder.glob('structured*/*.parquet'))

In [None]:
# Columns read from every parquet file; everything else is dropped at scan time.
keep_columns = [
    'primary_category',
    'subject_suggestion',
    'alternative_category',
    'inference_duration',
    'model',
    'from',
    'body',
    'date',
    'subject',
]

# Build one lazy scan per file (restricted to `keep_columns`), stack them and
# materialise the result in a single collect.
df = pl.concat(
    [pl.scan_parquet(path).select(keep_columns) for path in files]
).collect()

In [None]:
# Distinct values of the `model` column. `Series.unique()` does not guarantee
# any particular order by default, so the result is sorted to make
# `models[0]` and every model-ordered operation below reproducible across runs.
models = df['model'].unique().sort().to_list()
models

In [None]:
# Column whose values are compared across models.
cat_col = "primary_category"

# Wide table: one row per (date, body, from) key and one column per model
# holding that model's `cat_col` value. If a model has several rows for the
# same key, the first one is kept.
aligned_categories = df.pivot(
    on="model",
    index=["date", "body", "from"],
    values=cat_col,
    aggregate_function="first",
)
# `all_equal` is true when the list built from all model columns contains
# exactly one distinct value.
aligned_categories = aligned_categories.with_columns(
    (pl.concat_list(models).list.n_unique() == 1).alias("all_equal")
)

# Share of rows with `all_equal` set, restricted to rows where every model
# column is non-null.
aligned_categories.drop_nulls(subset=models).get_column("all_equal").mean()

In [None]:
# Keep only the per-model label columns of the wide table.
model_cols = aligned_categories.select(models)

# Long format: one row per (model, class) observation. Nulls are replaced by
# the placeholder "N/A" and rows carrying that placeholder are removed.
value_counts = model_cols.unpivot(variable_name="model", value_name="class")
value_counts = value_counts.fill_null("N/A").filter(pl.col("class").ne("N/A"))
# Per-model frequency of each class: `value_counts` yields a list of structs
# per model, which is exploded to rows and unnested into `class` / `count`.
value_counts = (
    value_counts.group_by("model")
    .agg(pl.col("class").value_counts(sort=True))
    .explode("class")
    .unnest("class")
)

# class x model table of counts; combinations that never occur become 0.
value_counts_pivot = value_counts.pivot(
    on="model",
    index="class",
    values="count",
).fill_null(0)

# Class labels ranked by their total count summed over all models.
ordering = (
    value_counts.fill_null("N/A")
    .group_by("class")
    .agg(pl.col("count").sum())
    .sort("count", descending=True)
    .get_column("class")
)

In [None]:
# Pairwise cosine distance between the models' class-count vectors. The first
# column of the pivot (the class label) is skipped and the matrix transposed
# so that each row passed to `pdist` is one model.
dist = pdist(value_counts_pivot[:, 1:].to_numpy().T, metric="cosine")
# Leaf order of an average-linkage hierarchical clustering of those distances.
leaves = leaves_list(linkage(dist, method="average", optimal_ordering=True))

# Model names in pivot order, then permuted into clustering order.
model_cols = value_counts_pivot.columns[1:]
re_ordered = [model_cols[leaf] for leaf in leaves]

# Maps the default polars column names (column_0, column_1, ...) onto the
# re-ordered model names.
mapping = {f"column_{pos}": name for pos, name in enumerate(re_ordered)}

In [None]:
# Cosine similarity (1 - distance) with rows and columns permuted into the
# clustering leaf order.
_tmp = 1 - squareform(dist)[leaves][:, leaves]

# Label the columns with model names, then convert to pandas so the same
# names can also be used as the row index.
hm = pl.DataFrame(_tmp).rename(mapping)
hm_pd = hm.to_pandas()
hm_pd.index = hm.columns

fig = px.imshow(
    hm_pd,
    labels={"color": "Cosine similarity", "x": "LLM", "y": "LLM"},
    color_continuous_scale="ice",
    zmin=0.3,
    height=650,
)
# NOTE(review): `pt.save` is a project helper; presumably it stores the figure
# under the given name — confirm in plot_theme.
pt.save(fig, "cosine_similarity_mtx")

## Compute "overlap distance"

Defined as `1 - accuracy`, where accuracy is the fraction of rows on which two models assign the same primary category.

In [None]:
# Symmetric model x model agreement matrix. Starting from the identity puts
# 1 on the diagonal (each model compared with itself).
acc_mtx = np.eye(len(model_cols))

# Visit every unordered pair of models once and mirror the value across the
# diagonal. Agreement is the mean of the element-wise equality of the two
# models' label columns.
for (i, mdl), (j, mdl2) in combinations(enumerate(model_cols), 2):
    acc = (aligned_categories[mdl] == aligned_categories[mdl2]).mean()
    acc_mtx[i, j] = acc_mtx[j, i] = acc

In [None]:
# Cluster the models on "overlap distance" (1 - agreement). `squareform`
# turns the square matrix into the condensed form that `linkage` takes.
linkage_matrix = linkage(
    squareform(1 - acc_mtx),
    method="average",
    optimal_ordering=True,
)
leaves = leaves_list(linkage_matrix)
# Model names permuted into clustering order.
acc_order = [model_cols[leaf] for leaf in leaves]

# Maps the default polars column names (column_0, column_1, ...) onto the
# re-ordered model names.
mapping = {f"column_{pos}": name for pos, name in enumerate(acc_order)}

In [None]:
# Agreement matrix with rows and columns permuted into clustering order.
_tmp = acc_mtx[leaves][:, leaves]

# Scale to percent, label the columns with model names and reuse those names
# as the pandas row index.
hm = pl.DataFrame(_tmp * 100).rename(mapping)
hm_pd = hm.to_pandas()
hm_pd.index = hm.columns

# No zmin/zmax is passed, so the colour range is left to Plotly.
fig = px.imshow(
    hm_pd,
    labels={"color": "Label overlap (%)", "x": "LLM", "y": "LLM"},
    color_continuous_scale="ice",
)
# Fixed canvas size, no grid lines, light grey plot background.
fig.update_layout(
    width=600,
    height=500,
    plot_bgcolor="rgba(0, 0, 0, 0.1)",
    xaxis={"showgrid": False},
    yaxis={"showgrid": False},
)

In [None]:
# Preview the first rows of the per-model class counts (model / class / count).
value_counts.head()

In [None]:
# Total count per class summed over all models, largest first.
sorted_classes = (
    value_counts.group_by("class")
    .agg(pl.col("count").sum())
    .sort("count", descending=True)
)

In [None]:
# Drop placeholder rows (nulls are first mapped to "N/A", then "N/A" is
# filtered out) and add each class's share of its model's total count.
vc_filt = (
    value_counts.fill_null("N/A")
    .filter(pl.col("class").ne("N/A"))
    .with_columns(
        normalized_counts=(pl.col("count") / pl.col("count").sum()).over("model")
    )
)

In [None]:
# Order rows by the summed share of their class across models (largest
# first); rows with equal keys keep their existing relative order.
plt_vc = vc_filt.sort(
    pl.sum("normalized_counts").over("class"),
    descending=True,
    maintain_order=True,
)

In [None]:
# Closed polar line per model: angle = class, radius = share of that class
# (log-scaled).
fig = px.line_polar(
    plt_vc,
    r="normalized_counts",
    theta="class",
    color="model",
    line_close=True,
    log_r=True,
    height=550,
)
# Every trace except the two named models starts hidden and can be toggled
# from the legend.
fig.update_traces(
    visible="legendonly",
    selector=lambda trace: trace["name"] not in ("Llama-3.1-8b", "Falcon-3-7b"),
)
# Horizontal legend below the plot; the radial range is given as log10
# values, spanning the smallest share up to 10% above the largest.
fig.update_layout(
    legend={"orientation": "h", "yanchor": "top", "y": -0.2, "xanchor": "left", "x": 0},
    polar={
        "radialaxis": {
            "range": [
                np.log10(plt_vc['normalized_counts'].min()),
                np.log10(plt_vc["normalized_counts"].max() * 1.1),
            ]
        }
    },
)
# NOTE(review): `pt.save` is a project helper; presumably it stores the figure
# under the given name — confirm in plot_theme.
pt.save(fig, "polar_plot_value_counts")

In [None]:
# Class frequencies among rows flagged `all_equal`. Every model column holds
# the same value on those rows, so the first model's column is used.
value_counts_all_eq = (
    aligned_categories.filter(pl.col("all_equal"))
    .get_column(models[0])
    .value_counts(sort=True)
)
# Number of rows flagged `all_equal`.
value_counts_all_eq.get_column("count").sum()

In [None]:
# Bar chart of the classes on which all models agree; the y axis starts on a
# log scale.
fig = px.bar(
    value_counts_all_eq,
    x=models[0],
    y="count",
    labels={models[0]: "Category", "count": "Emails"},
    width=450,
    height=500,
    log_y=True,
)
# Green bars, plus a button row anchored at the top-left of the plot that
# switches the y axis between log and linear scale via `relayout`.
fig.update_traces(marker_color="limegreen").update_layout(
    updatemenus=[
        {
            "type": "buttons",
            "direction": "right",
            "showactive": True,
            "x": 0,
            "xanchor": "left",
            "y": 1,
            "yanchor": "bottom",
            "buttons": [
                {
                    "label": label,
                    "method": "relayout",
                    "args": [{"yaxis.type": axis_type}],
                }
                for label, axis_type in (
                    ("Log scale", "log"),
                    ("Linear scale", "linear"),
                )
            ],
        }
    ],
)