# Ablation Figures

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# Metric names plotted by the figures in this notebook; `draw_graphs` creates
# one subplot per entry, stacked top to bottom in this order.
metrics = ["STS-BE", "ROUGE-L"]


def draw_graphs(model_data, models, metrics, fig_name, shots, legend_ncol):
    """Plot one line chart per metric, stacked vertically, and save the figure.

    Every model in ``models`` contributes one line to each subplot. Only the
    scores whose shot count appears in ``shots`` are drawn, and each point is
    placed at the x position of its shot within ``shots`` so that tick labels
    and data always agree. A single legend shared by all subplots is placed
    above the figure.

    Args:
        model_data: Mapping from model key to a dict holding one list of
            scores per metric name plus a ``"meta"`` dict. ``meta`` must
            provide ``"shots"`` (the shot count of each score, index-aligned
            with the score lists) and ``"linestyle"``; it may provide
            ``"name"`` (legend label, defaults to the model key) and
            ``"zorder"`` (forwarded to the line plot).
        models: Keys of ``model_data`` to plot, in plotting order.
        metrics: Metric names to plot, one subplot each, top to bottom.
        fig_name: Output path handed to ``plt.savefig``.
        shots: Shot counts to show on the x axis, in display order.
        legend_ncol: Number of columns of the shared legend.

    Raises:
        ValueError: If ``metrics`` is empty.
    """
    if not metrics:
        raise ValueError("metrics must contain at least one metric name")

    sns.set_theme(style="darkgrid")
    # NOTE: this changes the process-wide matplotlib font setting.
    # NOTE(review): presumably the "Caladea" font must be installed for the
    # intended look - confirm on the machine that renders the figures.
    plt.rcParams["font.family"] = "Caladea"
    # squeeze=False always yields a 2-D array of axes, so a single metric
    # needs no special case.
    fig, axes = plt.subplots(
        len(metrics),
        1,
        figsize=(24, 8 * len(metrics)),
        sharex=True,
        squeeze=False,
    )

    # x position of every requested shot; the first occurrence wins should a
    # shot be listed twice.
    tick_position = {}
    for position, shot in enumerate(shots):
        tick_position.setdefault(shot, position)

    last_row = len(metrics) - 1
    for row, metric in enumerate(metrics):
        ax = axes[row, 0]
        ax.tick_params(axis="both", which="major", labelsize=48)
        for model_name in models:
            data = model_data[model_name]
            meta = data["meta"]
            # (x position, index into the score lists) for each requested shot
            # that this model actually has a score for.
            selected = [
                (tick_position[shot], score_idx)
                for score_idx, shot in enumerate(meta["shots"])
                if shot in tick_position
            ]
            kwargs = {}
            if "zorder" in meta:
                kwargs["zorder"] = meta["zorder"]
            label = f"{meta.get('name', model_name)}"
            sns.lineplot(
                x=[position for position, _ in selected],
                y=[data[metric][score_idx] for _, score_idx in selected],
                ax=ax,
                label=label,
                linestyle=meta["linestyle"],
                linewidth=12,
                marker="D",
                markersize=20,
                **kwargs,
            )
        # Only the bottom subplot carries the shared x-axis label.
        if row == last_row:
            ax.set_xlabel("Shot", fontsize=54, fontweight="bold")
        ax.set_ylabel(metric, fontsize=54, fontweight="bold")
        ax.set_xticks(ticks=range(len(shots)), labels=[str(shot) for shot in shots])
        ax.grid(True)
        # Drop the per-subplot legend in favour of the figure-level one below.
        # get_legend() returns None when nothing was plotted on this axis.
        subplot_legend = ax.get_legend()
        if subplot_legend is not None:
            subplot_legend.remove()

    # Every subplot shows the same models, so the last axis supplies the
    # handles and labels for the shared legend.
    handles, labels = axes[last_row, 0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        fontsize=48,
        ncol=legend_ncol,
        columnspacing=0.8,
    )

    # Adjust layout to prevent overlap
    plt.tight_layout()
    # bbox_inches="tight" ensures that all the visible content
    # (including the legend anchored above the figure) is saved into the file.
    plt.savefig(fig_name, bbox_inches="tight")
    plt.show()

## Bursty Distribution Ablation Figure

In [None]:
# Shot counts at which every model below was evaluated; each score list is
# index-aligned with this list.
shots = [0, 1, 2, 4, 8, 12, 16]

# Scores per model and metric. The two EILEV models are drawn with solid
# lines, the two ablation models with dashed lines.
model_data = {}
model_data["EILEV BLIP-2 OPT-2.7B"] = {
    "STS-CE": [0.2098, 0.4754, 0.4897, 0.5569, 0.612, 0.6312, 0.6363],
    "STS-BE": [0.3278, 0.5495, 0.571, 0.6284, 0.6735, 0.6898, 0.6936],
    "BERTScore-F1": [0.5234, 0.6305, 0.6399, 0.6463, 0.6543, 0.6539, 0.6529],
    "ROUGE-L": [0.2315, 0.5013, 0.5396, 0.5785, 0.6102, 0.6249, 0.6296],
    "BLEU": [0.008795, 0.1376, 0.2015, 0.2443, 0.2741, 0.2968, 0.3049],
    "meta": dict(shots=shots, linestyle="-"),
}
model_data["EILEV BLIP-2 Flan-T5-xl"] = {
    "STS-CE": [0.3552, 0.5039, 0.5176, 0.5539, 0.6089, 0.6276, 0.6349],
    "STS-BE": [0.426, 0.5697, 0.5812, 0.613, 0.6689, 0.6886, 0.6948],
    "BERTScore-F1": [-1.84, 0.6291, 0.6394, 0.6477, 0.6527, 0.6561, 0.6572],
    "ROUGE-L": [0.3129, 0.5032, 0.5322, 0.5648, 0.607, 0.6203, 0.623],
    "BLEU": [0.06718, 0.1507, 0.1992, 0.2373, 0.2834, 0.2931, 0.2913],
    "meta": dict(shots=shots, linestyle="-"),
}
model_data["Ablation BLIP-2 OPT-2.7B"] = {
    "STS-CE": [0.4019, 0.4934, 0.4871, 0.4791, 0.4932, 0.4794, 0.4726],
    "STS-BE": [0.4904, 0.5633, 0.5629, 0.5564, 0.5625, 0.5437, 0.5328],
    "BERTScore-F1": [0.5606, 0.6344, 0.6423, 0.6452, 0.6541, 0.658, 0.6562],
    "ROUGE-L": [0.3912, 0.5042, 0.5238, 0.5294, 0.5365, 0.5373, 0.5385],
    "BLEU": [0.07154, 0.1479, 0.1779, 0.1772, 0.1677, 0.1606, 0.1643],
    "meta": dict(shots=shots, linestyle="--"),
}
model_data["Ablation BLIP-2 Flan-T5-xl"] = {
    "STS-CE": [0.308, 0.4858, 0.4472, 0.4449, 0.4394, 0.4358, 0.4272],
    "STS-BE": [0.4182, 0.5578, 0.5125, 0.5073, 0.498, 0.4922, 0.4811],
    "BERTScore-F1": [-0.8935, 0.3246, -1.842, -1.935, -2.121, -2.084, -2.362],
    "ROUGE-L": [0.2729, 0.5152, 0.4741, 0.4764, 0.4751, 0.4765, 0.4689],
    "BLEU": [0.03092, 0.164, 0.1636, 0.1624, 0.1576, 0.1538, 0.146],
    "meta": dict(shots=shots, linestyle="--"),
}

# Plot every model, in the order the entries were added above
# (dicts preserve insertion order).
models = list(model_data)

# Only a subset of the evaluated shot counts is shown in the figure.
draw_graphs(
    model_data,
    models,
    metrics,
    fig_name="bursty-ablation.pdf",
    shots=[0, 4, 8, 12, 16],
    legend_ncol=2,
)