In [11]:
import json
from pathlib import Path

import pandas as pd

# Configurations
# Each key is the file stem of a per-sample result file (`<key>.json`). "name" is the
# label stored in the "strategy" column, and "data" collects one row per sample
# (first as a list of dicts, then converted to a DataFrame at the end of this cell).
strategies = {
    "cluster": {
        "name": "Cluster resampling",
        "data": []
    },
    "cluster-nopos": {
        "name": "Cluster resampling<br>(w/o using token position)",
        "data": []
    },
    "zero": {
        "name": "Zero ablation",
        "data": []
    },
    "random-pos": {
        "name": "Random resampling (using token position)",
        "data": []
    },
    "random": {
        "name": "Random resampling",
        "data": []
    },
}
model_name = "e2e.jumprelu.shakespeare_64x4"

# Populate data
dirpath = Path(f"../../checkpoints/{model_name}/ablation-comparisons")
# `Path.iterdir()` yields entries in arbitrary order; sort so that the row order and
# the printed log are reproducible across runs and filesystems.
for sample_dir in sorted(path for path in dirpath.iterdir() if path.is_dir()):
    # For each sample directory, load the result file of every strategy.
    # A missing file raises FileNotFoundError from `open`.
    for strategy, profile in strategies.items():
        # Load data (JSON is read as UTF-8 rather than the locale default)
        data_path = sample_dir / f"{strategy}.json"
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Compute metrics
        klds = data["klds"]
        if not klds:
            # Fail with the offending path instead of a bare ZeroDivisionError
            raise ValueError(f"No KLD values found in {data_path}")
        num_features = data["num_nodes"]
        mean_kld = sum(klds.values()) / len(klds)
        print(f"Strategy: {strategy}, Sample: {sample_dir.name} Features: {num_features}, Mean KLD: {mean_kld:.4f}")

        # Append data
        profile["data"].append({
            "strategy": profile["name"],
            "sample": sample_dir.name,
            "num_features": num_features,
            "mean_kld": mean_kld,
        })

# Convert lists to DataFrames
# Passing the columns explicitly keeps the expected schema even when no sample
# directories were found (an empty list would otherwise give a column-less frame).
record_columns = ["strategy", "sample", "num_features", "mean_kld"]
for strategy, profile in strategies.items():
    profile["data"] = pd.DataFrame(profile["data"], columns=record_columns) # type: ignore

Strategy: cluster, Sample: val.0.0.2 Features: 9, Mean KLD: 0.0674
Strategy: cluster-nopos, Sample: val.0.0.2 Features: 11, Mean KLD: 0.1024
Strategy: zero, Sample: val.0.0.2 Features: 20, Mean KLD: 0.1204
Strategy: random-pos, Sample: val.0.0.2 Features: 50, Mean KLD: 0.4336
Strategy: random, Sample: val.0.0.2 Features: 49, Mean KLD: 0.4461
Strategy: cluster, Sample: val.0.1024.2 Features: 15, Mean KLD: 0.0636
Strategy: cluster-nopos, Sample: val.0.1024.2 Features: 17, Mean KLD: 0.0723
Strategy: zero, Sample: val.0.1024.2 Features: 34, Mean KLD: 0.1970
Strategy: random-pos, Sample: val.0.1024.2 Features: 47, Mean KLD: 1.5478
Strategy: random, Sample: val.0.1024.2 Features: 53, Mean KLD: 1.3817
Strategy: cluster, Sample: val.0.2048.2 Features: 17, Mean KLD: 0.1103
Strategy: cluster-nopos, Sample: val.0.2048.2 Features: 18, Mean KLD: 0.1007
Strategy: zero, Sample: val.0.2048.2 Features: 35, Mean KLD: 0.1899
Strategy: random-pos, Sample: val.0.2048.2 Features: 52, Mean KLD: 1.3539
Strate

In [12]:
# Plot mean KLD

import plotly.graph_objects as go
import pandas as pd

# Stack the per-strategy frames into a single long table
df = pd.concat([profile["data"] for profile in strategies.values()])

# Draw one violin per strategy, in order of first appearance
fig = go.Figure()

for label in df["strategy"].drop_duplicates():
    subset = df.loc[df["strategy"] == label]
    kld_values = subset["mean_kld"]
    kld_mean = kld_values.mean()

    # Distribution of the per-sample mean KLD, with every sample shown as a point
    violin = go.Violin(
        x=subset["strategy"],
        y=kld_values,
        name=label,
        meanline_visible=True,
        points="all",
    )
    fig.add_trace(violin)

    # Text-only trace that writes the rounded mean at the mean's height
    mean_label = go.Scatter(
        x=[label],
        y=[kld_mean],
        mode="text",
        text=[f"{round(kld_mean, 2)}"],
        textposition="top center",
        showlegend=False,
    )
    fig.add_trace(mean_label)

# Update layout: untitled axes, horizontal tick labels, no legend
fig.update_layout(
    xaxis={"title": "", "tickangle": 0},
    yaxis={"title": ""},
    showlegend=False,
)

fig.show()

# Print means
for profile in strategies.values():
    print(f"{profile['name']}: {profile['data']['mean_kld'].mean():.4f}")

Cluster resampling: 0.0804
Cluster resampling<br>(w/o using token position): 0.0918
Zero ablation: 0.1691
Random resampling (using token position): 1.1118
Random resampling: 1.0291


In [13]:
# Plot mean number of features

import plotly.graph_objects as go
import pandas as pd

# Stack the per-strategy frames into a single long table
df = pd.concat([profile["data"] for profile in strategies.values()])

# Draw one violin per strategy, in order of first appearance
fig = go.Figure()

for label in df["strategy"].drop_duplicates():
    subset = df.loc[df["strategy"] == label]
    feature_counts = subset["num_features"]
    feature_mean = feature_counts.mean()

    # Distribution of the per-sample feature count, with every sample shown as a point
    violin = go.Violin(
        x=subset["strategy"],
        y=feature_counts,
        name=label,
        meanline_visible=True,
        points="all",
    )
    fig.add_trace(violin)

    # Text-only trace that writes the rounded mean at the mean's height
    mean_label = go.Scatter(
        x=[label],
        y=[feature_mean],
        mode="text",
        text=[f"{round(feature_mean, 2)}"],
        textposition="top center",
        showlegend=False,
    )
    fig.add_trace(mean_label)

# Update layout: untitled axes, horizontal tick labels, no legend
fig.update_layout(
    xaxis={"title": "", "tickangle": 0},
    yaxis={"title": ""},
    showlegend=False,
)

fig.show()

# Print means
for profile in strategies.values():
    print(f"{profile['name']}: {profile['data']['num_features'].mean():.4f}")

Cluster resampling: 13.6667
Cluster resampling<br>(w/o using token position): 15.3333
Zero ablation: 29.6667
Random resampling (using token position): 49.6667
Random resampling: 50.3333
