In [33]:
import json
from pathlib import Path

import pandas as pd

# Configurations
# Each strategy maps to its display metadata. "data" starts as an empty list,
# receives one record per sample directory below, and is finally converted to a
# DataFrame.
strategies = {
    "cluster": {
        "label": "Cluster resampling",
        "label_alt": "Cluster resampling<br>using positional dimension",
        "color": "royalblue",
        "data": []
    },
    "cluster-nopos": {
        "label": "Cluster resampling",
        "label_alt": "Cluster resampling<br>w/o positional dimension",
        "color": "royalblue",
        "data": []
    },
    "zero": {
        "label": "Zero ablation",
        "color": "lightseagreen",
        "data": []
    },
    "random-pos": {
        "label": "Conventional resampling<br>using token position",
        "color": "indianred",
        "data": []
    },
    "random": {
        "label": "Conventional resampling",
        "color": "indianred",
        "data": []
    },
}
model_name = "e2e.jumprelu.shakespeare_64x4"

# Columns of every per-strategy DataFrame. Fixing them up front means a strategy
# with no samples still yields an (empty) frame with the expected columns,
# instead of a column-less frame that raises KeyError on column access later.
RECORD_COLUMNS = ["strategy", "label", "label_alt", "sample", "num_features", "mean_kld"]

# Populate data
dirpath = Path(f"../../checkpoints/{model_name}/ablation-comparisons")
# Path.iterdir() yields entries in arbitrary order. Sort so that samples are
# processed, and printed, deterministically.
for sample_dir in sorted(dirpath.iterdir()):
    if not sample_dir.is_dir():
        continue
    # Each sample directory may hold one "<strategy>.json" per strategy.
    for strategy, profile in strategies.items():
        data_path = sample_dir / f"{strategy}.json"
        if not data_path.exists():
            continue
        with open(data_path, "r") as f:
            data = json.load(f)

        # Compute metrics. "num_nodes" is reported as the feature count, and
        # "klds" is a mapping whose values are averaged.
        num_features = data["num_nodes"]
        klds = data["klds"]
        # An empty KLD mapping gives NaN (skipped by pandas aggregations)
        # instead of raising ZeroDivisionError.
        mean_kld = sum(klds.values()) / len(klds) if klds else float("nan")
        print(f"Strategy: {strategy}, Sample: {sample_dir.name} Features: {num_features}, Mean KLD: {mean_kld:.4f}")

        # Append data
        profile["data"].append({
            "strategy": strategy,
            "label": profile["label"],
            "label_alt": profile.get("label_alt", ""),
            "sample": sample_dir.name,
            "num_features": num_features,
            "mean_kld": mean_kld,
        })

# Convert lists to DataFrames (explicit columns, even when no records exist)
for strategy, profile in strategies.items():
    profile["data"] = pd.DataFrame(profile["data"], columns=RECORD_COLUMNS) # type: ignore

Strategy: cluster, Sample: val.0.15360.2 Features: 8, Mean KLD: 0.0740
Strategy: cluster-nopos, Sample: val.0.15360.2 Features: 16, Mean KLD: 0.1083
Strategy: zero, Sample: val.0.15360.2 Features: 26, Mean KLD: 0.1982
Strategy: random-pos, Sample: val.0.15360.2 Features: 40, Mean KLD: 0.5101
Strategy: random, Sample: val.0.15360.2 Features: 60, Mean KLD: 0.6237
Strategy: cluster, Sample: val.0.90112.2 Features: 5, Mean KLD: 0.0517
Strategy: cluster-nopos, Sample: val.0.90112.2 Features: 9, Mean KLD: 0.0656
Strategy: zero, Sample: val.0.90112.2 Features: 8, Mean KLD: 0.1656
Strategy: random-pos, Sample: val.0.90112.2 Features: 23, Mean KLD: 0.2930
Strategy: random, Sample: val.0.90112.2 Features: 26, Mean KLD: 0.2716
Strategy: cluster, Sample: val.0.3072.2 Features: 14, Mean KLD: 0.1635
Strategy: cluster-nopos, Sample: val.0.3072.2 Features: 15, Mean KLD: 0.1533
Strategy: zero, Sample: val.0.3072.2 Features: 21, Mean KLD: 0.1498
Strategy: random-pos, Sample: val.0.3072.2 Features: 63, M

In [34]:
# Plot mean number of features for main strategies

import plotly.graph_objects as go
import pandas as pd

# Stack every strategy's per-sample frame into one long-format table
# (one row per strategy/sample pair), in the order strategies are declared.
df = pd.concat([strategies[key]["data"] for key in strategies])

def show_figure(names: list[str], label_key: str = "label", metric: str = "num_features"):
    """
    Create a box plot comparing a per-sample metric across strategies.

    Each strategy in `names` becomes one box showing all sample points. The box is
    colored with the strategy's configured color and annotated with its median.

    Args:
        names: Keys of `strategies` to plot, in x-axis order.
        label_key: Column used as the x-axis category ("label" or "label_alt").
        metric: Column plotted on the y-axis. Defaults to "num_features".
    """
    fig = go.Figure()

    for strategy in names:
        # Read the strategy's own frame rather than the notebook-global `df`,
        # which a later cell reassigns to a filtered subset.
        strategy_data = strategies[strategy]["data"]
        if strategy_data.empty:
            # No samples were loaded, so there is nothing to box (and .iloc[0]
            # below would raise).
            continue
        median_value = strategy_data[metric].median()
        fig.add_trace(go.Box(
            x=strategy_data[label_key],
            y=strategy_data[metric],
            boxpoints="all",  # Show all points
            marker=dict(size=3),
            marker_color=strategies[strategy]["color"],
        ))
        # Annotate the median above the box. The `:g` format drops a trailing
        # ".0" (10.0 -> "10") without mangling values below 1, which
        # str(...).strip("0") did (0.5 -> "5", 0.0 -> "").
        fig.add_trace(go.Scatter(
            x=[strategy_data[label_key].iloc[0]],
            y=[median_value],
            mode="text",
            text=[f"{median_value:g}"],
            textposition="top center",
            showlegend=False
        ))

    # Update layout
    fig.update_layout(
        xaxis=dict(title="", tickangle=0),
        yaxis=dict(title=""),
        showlegend=False,
    )

    fig.show()

# Show figure for different sets of strategies: the headline comparison, then
# cluster resampling with vs. without the positional dimension (alternate labels).
show_figure(["cluster", "zero", "random"])
show_figure(["cluster", "cluster-nopos"], label_key="label_alt")

# Print average num_features for each strategy. A strategy with no loaded samples
# may have a column-less empty frame (pd.DataFrame([])). Report NaN for it
# instead of raising KeyError.
for strategy, profile in strategies.items():
    mean_features = (
        profile["data"]["num_features"].mean()
        if "num_features" in profile["data"].columns
        else float("nan")
    )
    print(f"Strategy: {strategy}, Average num_features: {mean_features:.2f}")

Strategy: cluster, Average num_features: 18.09
Strategy: cluster-nopos, Average num_features: 21.92
Strategy: zero, Average num_features: 31.40
Strategy: random-pos, Average num_features: 45.35
Strategy: random, Average num_features: 48.61


In [35]:
# Plot mean KLD

import pandas as pd
# Imported here too so this cell does not depend on the previous cell's import.
import plotly.graph_objects as go

# Strategies to compare, in the x-axis order they should appear.
featured_strategies = ["cluster", "zero", "random"]
# Create a DataFrame by concatenating the data from the featured strategies
df = pd.concat([strategies[name]["data"] for name in featured_strategies])

# Create a box plot (all sample points shown) of per-sample mean KLD for each
# featured strategy. Iterate the list itself so the plot follows its order.
fig = go.Figure()

for strategy in featured_strategies:
    strategy_data = strategies[strategy]["data"]
    if strategy_data.empty:
        # No samples were loaded for this strategy, so there is nothing to plot.
        continue
    median_value = strategy_data["mean_kld"].median()
    fig.add_trace(go.Box(
        x=strategy_data["label"],
        y=strategy_data["mean_kld"],
        name=strategy,
        boxpoints="all",  # Show all points
        marker=dict(size=3),
        marker_color=strategies[strategy]["color"],
    ))
    # Add a text annotation with the median, rounded to two decimals
    fig.add_trace(go.Scatter(
        x=[strategy_data["label"].iloc[0]],
        y=[median_value],
        mode="text",
        text=[f"{round(median_value, 2)}"],
        textposition="top center",
        showlegend=False,
    ))

# Update layout
fig.update_layout(
    xaxis=dict(title="", tickangle=0),
    yaxis=dict(title=""),
    showlegend=False,
)
# Log-scale y axis with explicitly chosen tick positions
fig.update_yaxes(type="log", tickmode="array", tickvals=[0.05, 0.1, 0.5, 1, 2])

fig.show()