In [29]:
import json
from pathlib import Path

import pandas as pd

# Configurations
#
# One row per strategy: (key, display name, directory name). The display
# name is what ends up in the "strategy" column of the collected data; the
# directory name is interpolated into the samples path when loading.
_STRATEGY_TABLE = (
    ("cluster", "Cluster resampling", "comparisons-cluster"),
    (
        "cluster-nopos",
        "Cluster resampling<br>(w/o using token position)",
        "comparisons-cluster-nopos",
    ),
    ("zero", "Zero ablation", "comparisons-zero"),
    ("classic", "Random resampling", "comparisons-classic"),
)

# Maps strategy key -> profile dict. Every profile gets its own empty "data"
# list, which the loading loop fills and then converts to a DataFrame.
strategies = {
    key: {"name": label, "dirname": folder, "data": []}
    for key, label, folder in _STRATEGY_TABLE
}

# For each strategy, list all sample directories
#
# Layout walked below: <dirname>/samples/<sample>/<version>/data.json.
# For every version directory the loop prints two metrics and records one row;
# at the end of each strategy, profile["data"] is replaced by a DataFrame of
# those rows (the plotting cells read it from there).

# Fixed column list so that a strategy with zero samples still yields a frame
# with the expected columns instead of a column-less one.
SAMPLE_COLUMNS = ["strategy", "sample", "version", "num_features", "mean_kld"]

for strategy, profile in strategies.items():
    dirpath = Path(f"../../app/public/samples/{profile['dirname']}/samples")
    print(f"\nStrategy: {strategy}")

    # Rows are collected locally rather than appended to profile["data"], so
    # this loop does not depend on profile["data"] still being a list.
    records = []

    # iterdir() yields entries in arbitrary order; sorting keeps the printed
    # log and the DataFrame row order reproducible across runs and machines.
    for sample_dir in sorted(dirpath.iterdir()):
        if not sample_dir.is_dir():
            continue

        # List sample versions
        for version_dir in sorted(sample_dir.iterdir()):
            if not version_dir.is_dir():
                continue
            print(f"Sample: {sample_dir.name} ({version_dir.name})")

            # Load metadata (explicit encoding: do not rely on the platform default)
            data_path = version_dir / "data.json"
            with open(data_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Compute metrics
            num_features = len(data["activations"])
            klds = data["klds"]
            # An empty "klds" mapping would otherwise raise ZeroDivisionError;
            # NaN marks the value as missing and is skipped by pandas' mean().
            mean_kld = sum(klds.values()) / len(klds) if klds else float("nan")
            print(f"    Features: {num_features}, Mean KLD: {mean_kld:.4f}")

            # Append data ("version" keeps rows of the same sample distinguishable)
            records.append({
                "strategy": profile["name"],
                "sample": sample_dir.name,
                "version": version_dir.name,
                "num_features": num_features,
                "mean_kld": mean_kld,
            })

    # Convert lists to DataFrames
    profile["data"] = pd.DataFrame(records, columns=SAMPLE_COLUMNS) # type: ignore


Strategy: cluster
Sample: val.0.0.2 (0.25)
    Features: 8, Mean KLD: 0.0994
Sample: val.0.1024.2 (0.25)
    Features: 16, Mean KLD: 0.0526
Sample: val.0.2048.2 (0.25)
    Features: 18, Mean KLD: 0.1324

Strategy: cluster-nopos
Sample: val.0.0.2 (0.25)
    Features: 15, Mean KLD: 0.1220
Sample: val.0.1024.2 (0.25)
    Features: 16, Mean KLD: 0.1112
Sample: val.0.2048.2 (0.25)
    Features: 19, Mean KLD: 0.0854

Strategy: zero
Sample: val.0.0.2 (0.25)
    Features: 20, Mean KLD: 0.1204
Sample: val.0.1024.2 (0.25)
    Features: 37, Mean KLD: 0.1370
Sample: val.0.2048.2 (0.25)
    Features: 35, Mean KLD: 0.1900

Strategy: classic
Sample: val.0.0.2 (0.25)
    Features: 149, Mean KLD: 0.4160
Sample: val.0.1024.2 (0.25)
    Features: 147, Mean KLD: 1.7832
Sample: val.0.2048.2 (0.25)
    Features: 175, Mean KLD: 1.3398


In [None]:
# Plot mean KLD

import plotly.graph_objects as go
import pandas as pd

# Create a DataFrame by concatenating the data from all strategies.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of
# repeating each per-strategy frame's own index labels.
df = pd.concat(
    [profile["data"] for profile in strategies.values()],
    ignore_index=True,
)

# Create a violin plot for each strategy
fig = go.Figure()

for strategy in df["strategy"].unique():
    strategy_data = df[df["strategy"] == strategy]
    mean_value = strategy_data["mean_kld"].mean()
    fig.add_trace(go.Violin(
        x=strategy_data["strategy"],
        y=strategy_data["mean_kld"],
        name=strategy,
        meanline_visible=True
    ))
    # Add a text annotation for the mean value.
    # ":.2f" always prints two decimals ("0.10"), whereas str(round(x, 2))
    # drops trailing zeros ("0.1") and makes the labels inconsistent.
    fig.add_trace(go.Scatter(
        x=[strategy],
        y=[mean_value],
        mode="text",
        text=[f"{mean_value:.2f}"],
        textposition="top center",
        showlegend=False
    ))

# Update layout.
# The x tick labels already carry the strategy names, so only the y axis gets
# a title; without it this figure cannot be told apart from other per-strategy
# plots built the same way.
fig.update_layout(
    xaxis=dict(title="", tickangle=0),
    yaxis=dict(title="Mean KLD"),
    showlegend=False,
)

fig.show()

# Print means
for profile in strategies.values():
    # Display names may embed "<br>" for the plot labels; in plain-text output
    # it would be printed literally, so swap it for a space.
    label = profile["name"].replace("<br>", " ")
    mean_kld = profile["data"]["mean_kld"].mean()
    print(f"{label}: {mean_kld:.4f}")

Cluster resampling: 0.0948
Cluster resampling<br>(w/o using token position): 0.1062
Zero ablation: 0.1491
Random resampling: 1.1797


In [32]:
# Plot mean number of features

import plotly.graph_objects as go
import pandas as pd

# Create a DataFrame by concatenating the data from all strategies.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of
# repeating each per-strategy frame's own index labels.
df = pd.concat(
    [profile["data"] for profile in strategies.values()],
    ignore_index=True,
)

# Create a violin plot for each strategy
fig = go.Figure()

for strategy in df["strategy"].unique():
    strategy_data = df[df["strategy"] == strategy]
    mean_value = strategy_data["num_features"].mean()
    fig.add_trace(go.Violin(
        x=strategy_data["strategy"],
        y=strategy_data["num_features"],
        name=strategy,
        meanline_visible=True
    ))
    # Add a text annotation for the mean value.
    # ":.2f" always prints two decimals ("14.00"), whereas str(round(x, 2))
    # drops trailing zeros ("14.0") and makes the labels inconsistent.
    fig.add_trace(go.Scatter(
        x=[strategy],
        y=[mean_value],
        mode="text",
        text=[f"{mean_value:.2f}"],
        textposition="top center",
        showlegend=False
    ))

# Update layout.
# The x tick labels already carry the strategy names, so only the y axis gets
# a title; without it this figure cannot be told apart from other per-strategy
# plots built the same way.
fig.update_layout(
    xaxis=dict(title="", tickangle=0),
    yaxis=dict(title="Number of features"),
    showlegend=False,
)

fig.show()

# Print means
for profile in strategies.values():
    # Display names may embed "<br>" for the plot labels; in plain-text output
    # it would be printed literally, so swap it for a space.
    label = profile["name"].replace("<br>", " ")
    mean_features = profile["data"]["num_features"].mean()
    print(f"{label}: {mean_features:.4f}")

Cluster resampling: 14.0000
Cluster resampling<br>(w/o using token position): 16.6667
Zero ablation: 30.6667
Random resampling: 157.0000
