In [54]:
import json
from pathlib import Path

import pandas as pd

# Configurations
# Ablation strategies to compare. Each key doubles as the stem of the
# per-sample JSON file ("<key>.json"); "name" is the display label and
# "data" starts as an empty list that collects one record per sample.
strategies = {
    key: {"name": label, "data": []}
    for key, label in (
        ("cluster", "Cluster resampling"),
        ("cluster-nopos", "Cluster resampling<br>(w/o using token position)"),
        ("zero", "Zero ablation"),
        ("random-pos", "Random resampling (using token position)"),
        ("random", "Random resampling"),
    )
}

# Name of the checkpoint directory whose ablation comparisons are loaded
model_name = "e2e.jumprelu.shakespeare_64x4"

# Populate data
# Every sub-directory of `dirpath` is treated as one sample; it may contain
# one "<strategy>.json" file per ablation strategy.
dirpath = Path(f"../../checkpoints/{model_name}/ablation-comparisons")

# iterdir() yields entries in arbitrary order, so sort them to make the row
# order (and the printed log) reproducible across runs and machines.
for sample_dir in sorted(dirpath.iterdir()):
    if not sample_dir.is_dir():
        continue

    for strategy, profile in strategies.items():
        # A sample is not required to have results for every strategy
        data_path = sample_dir / f"{strategy}.json"
        if not data_path.exists():
            continue

        # Load data
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Compute metrics
        klds = data["klds"]
        if not klds:
            # Averaging zero values would raise ZeroDivisionError and abort
            # the whole load; skip this sample explicitly instead.
            print(f"Strategy: {strategy}, Sample: {sample_dir.name} has no KLD values; skipping")
            continue
        num_features = data["num_nodes"]
        mean_kld = sum(klds.values()) / len(klds)
        print(f"Strategy: {strategy}, Sample: {sample_dir.name} Features: {num_features}, Mean KLD: {mean_kld:.4f}")

        # Append data (the display name is stored so plots can group on it)
        profile["data"].append({
            "strategy": profile["name"],
            "sample": sample_dir.name,
            "num_features": num_features,
            "mean_kld": mean_kld,
        })

# Convert lists to DataFrames
# The columns are named explicitly so that a strategy with no records still
# yields a frame with the expected schema; otherwise later lookups such as
# ["num_features"] would raise KeyError on a column-less frame.
for strategy, profile in strategies.items():
    profile["data"] = pd.DataFrame(
        profile["data"],
        columns=["strategy", "sample", "num_features", "mean_kld"],
    )  # type: ignore

Strategy: cluster, Sample: val.0.3072.2 Features: 14, Mean KLD: 0.1635
Strategy: cluster-nopos, Sample: val.0.3072.2 Features: 15, Mean KLD: 0.1533
Strategy: zero, Sample: val.0.3072.2 Features: 21, Mean KLD: 0.1498
Strategy: random-pos, Sample: val.0.3072.2 Features: 63, Mean KLD: 1.0505
Strategy: random, Sample: val.0.3072.2 Features: 53, Mean KLD: 1.2512
Strategy: cluster, Sample: val.0.8192.2 Features: 12, Mean KLD: 0.0945
Strategy: cluster-nopos, Sample: val.0.8192.2 Features: 14, Mean KLD: 0.1150
Strategy: zero, Sample: val.0.8192.2 Features: 39, Mean KLD: 0.1401
Strategy: random-pos, Sample: val.0.8192.2 Features: 35, Mean KLD: 1.4327
Strategy: random, Sample: val.0.8192.2 Features: 42, Mean KLD: 1.7319
Strategy: cluster, Sample: val.0.4096.2 Features: 43, Mean KLD: 0.2291
Strategy: cluster-nopos, Sample: val.0.4096.2 Features: 52, Mean KLD: 0.2341
Strategy: zero, Sample: val.0.4096.2 Features: 57, Mean KLD: 0.7461
Strategy: random-pos, Sample: val.0.4096.2 Features: 85, Mean KL

In [55]:
# Plot mean number of features

import plotly.graph_objects as go
import pandas as pd

# Create a DataFrame by concatenating the data from all strategies
df = pd.concat([profile["data"] for profile in strategies.values()], ignore_index=True)

# Create a box plot (with every sample point shown) for each strategy
fig = go.Figure()

# sort=False keeps the strategies in their order of first appearance
for strategy, strategy_data in df.groupby("strategy", sort=False):
    median_value = strategy_data["num_features"].median()
    fig.add_trace(go.Box(
        x=strategy_data["strategy"],
        y=strategy_data["num_features"],
        name=strategy,
        boxpoints="all"  # Show all points
    ))
    # Add a text annotation showing the median just above the median line
    fig.add_trace(go.Scatter(
        x=[strategy],
        y=[median_value],
        mode="text",
        text=[f"{round(median_value, 2)}"],
        textposition="top center",
        showlegend=False
    ))

# Update layout
fig.update_layout(
    xaxis=dict(title="", tickangle=0),
    yaxis=dict(title=""),
    showlegend=False,
)

fig.show()

# Print means
# Display names may contain "<br>" (line-break markup used for the plot's
# tick labels); replace it with a space so the console output stays readable.
for profile in strategies.values():
    mean_features = profile["data"]["num_features"].mean()
    label = profile["name"].replace("<br>", " ")
    print(f"{label}: {mean_features:.4f}")

Cluster resampling: 15.0909
Cluster resampling<br>(w/o using token position): 18.4545
Zero ablation: 28.9091
Random resampling (using token position): 47.9091
Random resampling: 51.6364


In [57]:
# Plot mean KLD

import plotly.graph_objects as go
import pandas as pd

# Create a DataFrame by concatenating the data from all strategies
df = pd.concat([profile["data"] for profile in strategies.values()], ignore_index=True)

# Create a box plot (with every sample point shown) for each strategy
fig = go.Figure()

# sort=False keeps the strategies in their order of first appearance
for strategy, strategy_data in df.groupby("strategy", sort=False):
    median_value = strategy_data["mean_kld"].median()
    fig.add_trace(go.Box(
        x=strategy_data["strategy"],
        y=strategy_data["mean_kld"],
        name=strategy,
        boxpoints="all"  # Show all points
    ))
    # Add a text annotation showing the median just above the median line
    fig.add_trace(go.Scatter(
        x=[strategy],
        y=[median_value],
        mode="text",
        text=[f"{round(median_value, 2)}"],
        textposition="top center",
        showlegend=False
    ))

# Update layout
fig.update_layout(
    xaxis=dict(title="", tickangle=0),
    yaxis=dict(title=""),
    showlegend=False,
)

fig.show()

# Print means
# Display names may contain "<br>" (line-break markup used for the plot's
# tick labels); replace it with a space so the console output stays readable.
for profile in strategies.values():
    mean_kld = profile["data"]["mean_kld"].mean()
    label = profile["name"].replace("<br>", " ")
    print(f"{label}: {mean_kld:.4f}")

Cluster resampling: 0.1037
Cluster resampling<br>(w/o using token position): 0.1300
Zero ablation: 0.2100
Random resampling (using token position): 0.9792
Random resampling: 1.0942
