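# Ablation Figures

Plots STS-BE and ROUGE-L against the number of shots for the EILEV BLIP-2 models (solid lines) versus each ablation variant (dashed lines). The ablations are: bursty distribution, skewed marginal distributions (top 100 / top 500 common actions), and dynamic meaning. Each figure is saved as a PDF in the current working directory.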

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# Metrics to plot, one subplot row each, passed to every draw_graphs call below.
# Every name must be a key in each model's score dict. STS-CE, BERTScore-F1 and
# BLEU are also recorded in the data cells but are not plotted.
metrics = ["STS-BE", "ROUGE-L"]


def draw_graphs(model_data, models, metrics, fig_name, shots, legend_ncol):
    """Plot each model's scores against shot count, one subplot row per metric.

    All rows share the x axis. The per-axes legends are replaced by a single
    figure-level legend placed above the plots. The figure is saved to
    ``fig_name`` and then shown.

    Args:
        model_data: Mapping from model key to a dict with one score list per
            metric name plus a ``"meta"`` dict:
            - ``meta["shots"]`` gives the shot count for each score position.
            - ``meta["linestyle"]`` is the line style (``"-"`` if absent).
            - Optional ``meta["name"]`` overrides the legend label.
            - Optional ``meta["zorder"]`` is forwarded to the plot call.
        models: Keys of ``model_data`` to plot; their order is the legend order.
        metrics: Non-empty sequence of metric names, one subplot row each.
        fig_name: Output path passed to ``Figure.savefig``.
        shots: Shot counts to show on the x axis, left to right. Scores at any
            other shot count are skipped.
        legend_ncol: Number of columns in the figure legend.

    Raises:
        ValueError: If ``metrics`` is empty, or if a model's score list for a
            metric has a different length than its ``meta["shots"]``.
    """
    if not metrics:
        raise ValueError("metrics must contain at least one metric name")

    sns.set_theme(style="darkgrid")
    plt.rcParams["font.family"] = "Caladea"
    # squeeze=False always returns a 2-D array of axes, so the single-metric
    # case needs no special handling.
    fig, axes = plt.subplots(
        len(metrics),
        1,
        figsize=(24, 12 * len(metrics)),
        sharex=True,
        squeeze=False,
    )
    # Give each requested shot a fixed x position taken from `shots` itself.
    # The tick labels set below then always match the plotted points, even
    # when a model lacks some shot count or stores its shots in another order.
    x_position = {shot: pos for pos, shot in enumerate(shots)}

    for row, metric in enumerate(metrics):
        ax = axes[row, 0]
        ax.tick_params(axis="both", which="major", labelsize=58)
        for model_name in models:
            data = model_data[model_name]
            meta = data["meta"]
            scores = data[metric]
            # Scores are aligned positionally with meta["shots"]; a length
            # mismatch means the data is misaligned, so fail loudly.
            if len(scores) != len(meta["shots"]):
                raise ValueError(
                    f"{model_name!r}: {len(scores)} {metric} scores for "
                    f"{len(meta['shots'])} shot counts"
                )
            points = [
                (x_position[shot], score)
                for shot, score in zip(meta["shots"], scores)
                if shot in x_position
            ]
            extra_kwargs = {"zorder": meta["zorder"]} if "zorder" in meta else {}
            sns.lineplot(
                x=[x for x, _ in points],
                y=[y for _, y in points],
                ax=ax,
                label=str(meta.get("name", model_name)),
                linestyle=meta.get("linestyle", "-"),
                linewidth=12,
                marker="D",
                markersize=20,
                alpha=0.7,
                **extra_kwargs,
            )
        # Only the bottom row gets an x label; the x axis is shared.
        if row == len(metrics) - 1:
            ax.set_xlabel("Shot", fontsize=70, fontweight="bold")
        ax.set_ylabel(metric, fontsize=70, fontweight="bold")
        ax.set_xticks(ticks=range(len(shots)), labels=[str(shot) for shot in shots])
        ax.grid(True)
        # The figure-level legend below replaces the per-axes legends. An axes
        # with nothing labelled has no legend, so guard against None.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # Every row plots the same models, so the last axes' handles cover them all.
    handles, labels = axes[-1, 0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        fontsize=58,
        ncol=legend_ncol,
        columnspacing=0.8,
    )

    # Adjust layout to prevent overlap.
    fig.tight_layout()
    # bbox_inches="tight" makes sure the legend, which sits above the axes,
    # is included in the saved file.
    fig.savefig(fig_name, bbox_inches="tight")
    plt.show()

In [None]:
# Shot counts at which every model in the data cells was evaluated. Each metric
# list holds one score per entry of this list, in the same order.
shots = [0, 1, 2, 4, 8, 12, 16]
# Baseline scores for the EILEV BLIP-2 models ("-" = solid lines). This dict is
# the comparison baseline for the bursty and dynamic-meaning ablation figures.
# Only the metrics named in `metrics` are plotted; the rest are kept for reference.
eilev_data = {
    "EILEV BLIP-2 OPT-2.7B": {
        "STS-CE": [0.2098, 0.4754, 0.4897, 0.5569, 0.612, 0.6312, 0.6363],
        "STS-BE": [0.3278, 0.5495, 0.571, 0.6284, 0.6735, 0.6898, 0.6936],
        "BERTScore-F1": [0.5234, 0.6305, 0.6399, 0.6463, 0.6543, 0.6539, 0.6529],
        "ROUGE-L": [0.2315, 0.5013, 0.5396, 0.5785, 0.6102, 0.6249, 0.6296],
        "BLEU": [0.008795, 0.1376, 0.2015, 0.2443, 0.2741, 0.2968, 0.3049],
        "meta": {"shots": shots, "linestyle": "-"},
    },
    "EILEV BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.3552, 0.5039, 0.5176, 0.5539, 0.6089, 0.6276, 0.6349],
        "STS-BE": [0.426, 0.5697, 0.5812, 0.613, 0.6689, 0.6886, 0.6948],
        # NOTE(review): the 0-shot value (-1.84) is far below the other scores;
        # presumably a baseline-rescaled BERTScore or a logging artifact. Verify
        # before plotting this metric.
        "BERTScore-F1": [-1.84, 0.6291, 0.6394, 0.6477, 0.6527, 0.6561, 0.6572],
        "ROUGE-L": [0.3129, 0.5032, 0.5322, 0.5648, 0.607, 0.6203, 0.623],
        "BLEU": [0.06718, 0.1507, 0.1992, 0.2373, 0.2834, 0.2931, 0.2913],
        "meta": {"shots": shots, "linestyle": "-"},
    },
}

## Bursty Distribution Ablation Figure

In [None]:
# Scores for the bursty-distribution ablation models ("--" = dashed lines, which
# set them apart from the solid EILEV baselines in the same figure). The model
# keys coincide with those in `dynamic_ablation_data`, but the two are separate dicts.
bursty_ablation_data = {
    "Ablation BLIP-2 OPT-2.7B": {
        "STS-CE": [0.4019, 0.4934, 0.4871, 0.4791, 0.4932, 0.4794, 0.4726],
        "STS-BE": [0.4904, 0.5633, 0.5629, 0.5564, 0.5625, 0.5437, 0.5328],
        "BERTScore-F1": [0.5606, 0.6344, 0.6423, 0.6452, 0.6541, 0.658, 0.6562],
        "ROUGE-L": [0.3912, 0.5042, 0.5238, 0.5294, 0.5365, 0.5373, 0.5385],
        "BLEU": [0.07154, 0.1479, 0.1779, 0.1772, 0.1677, 0.1606, 0.1643],
        "meta": {"shots": shots, "linestyle": "--"},
    },
    "Ablation BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.308, 0.4858, 0.4472, 0.4449, 0.4394, 0.4358, 0.4272],
        "STS-BE": [0.4182, 0.5578, 0.5125, 0.5073, 0.498, 0.4922, 0.4811],
        # NOTE(review): most of these values are negative, unlike the other
        # models' BERTScore-F1; verify before plotting this metric.
        "BERTScore-F1": [-0.8935, 0.3246, -1.842, -1.935, -2.121, -2.084, -2.362],
        "ROUGE-L": [0.2729, 0.5152, 0.4741, 0.4764, 0.4751, 0.4765, 0.4689],
        "BLEU": [0.03092, 0.164, 0.1636, 0.1624, 0.1576, 0.1538, 0.146],
        "meta": {"shots": shots, "linestyle": "--"},
    },
}
# EILEV baselines (solid) followed by the bursty-ablation variants (dashed).
# The order of this list is also the legend order.
models = [
    "EILEV BLIP-2 OPT-2.7B",
    "EILEV BLIP-2 Flan-T5-xl",
    "Ablation BLIP-2 OPT-2.7B",
    "Ablation BLIP-2 Flan-T5-xl",
]

# Only a subset of the evaluated shot counts is shown on the x axis.
draw_graphs(
    model_data={**eilev_data, **bursty_ablation_data},
    models=models,
    metrics=metrics,
    fig_name="bursty-ablation.pdf",
    shots=[0, 4, 8, 12, 16],
    legend_ncol=2,
)

## Skewed Marginal Distributions (Top 100 Common Actions)

In [None]:
# EILEV baseline scores ("-" = solid lines) for the skewed-marginal comparisons.
# This same dict is the baseline for both the T100 and the T500 figures.
skewed_eilev_data = {
    "EILEV BLIP-2 OPT-2.7B": {
        "STS-CE": [0.222, 0.5117, 0.51, 0.5939, 0.6602, 0.6775, 0.6814],
        "STS-BE": [0.3672, 0.5864, 0.5978, 0.6727, 0.7311, 0.7466, 0.7512],
        "BERTScore-F1": [0.5451, 0.6229, 0.6323, 0.639, 0.6444, 0.6449, 0.6448],
        "ROUGE-L": [0.1916, 0.5245, 0.5601, 0.6069, 0.6424, 0.6547, 0.6606],
        "BLEU": [0.01493, 0.1183, 0.2028, 0.2526, 0.2888, 0.3037, 0.3145],
        "meta": {"shots": shots, "linestyle": "-"},
    },
    "EILEV BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.3368, 0.5243, 0.5319, 0.5983, 0.6553, 0.6794, 0.6889],
        "STS-BE": [0.4282, 0.595, 0.6037, 0.6606, 0.726, 0.7484, 0.7569],
        "BERTScore-F1": [0.5189, 0.6147, 0.6184, 0.6258, 0.6347, 0.6396, 0.642],
        "ROUGE-L": [0.3103, 0.5236, 0.5448, 0.592, 0.644, 0.6605, 0.6653],
        "BLEU": [0.05684, 0.1503, 0.1947, 0.258, 0.3148, 0.3278, 0.3266],
        "meta": {"shots": shots, "linestyle": "-"},
    },
}
# Scores for the top-100-common-actions (T100) ablation models ("--" = dashed).
# NOTE: the T500 cell rebinds `skewed_ablation_data`, so after a full run this
# name holds the T500 scores, not these.
skewed_ablation_data = {
    "T100 BLIP-2 OPT-2.7B": {
        "STS-CE": [0.2118, 0.4958, 0.4682, 0.5153, 0.5623, 0.5765, 0.5643],
        "STS-BE": [0.3817, 0.5728, 0.5564, 0.5961, 0.6348, 0.6451, 0.6344],
        "BERTScore-F1": [0.5512, 0.6255, 0.6386, 0.6444, 0.6516, 0.6535, 0.6549],
        "ROUGE-L": [0.18, 0.5134, 0.5329, 0.5574, 0.5807, 0.5899, 0.588],
        "BLEU": [0.007519, 0.1243, 0.1654, 0.1842, 0.2059, 0.2184, 0.2175],
        "meta": {"shots": shots, "linestyle": "--"},
    },
    "T100 BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.2485, 0.4929, 0.4881, 0.517, 0.5625, 0.5845, 0.5917],
        "STS-BE": [0.3826, 0.5686, 0.5672, 0.5945, 0.6401, 0.6603, 0.6666],
        "BERTScore-F1": [0.4717, 0.622, 0.6329, 0.6319, 0.5809, 0.491, 0.3668],
        "ROUGE-L": [0.2245, 0.5222, 0.5396, 0.5663, 0.6004, 0.6126, 0.6151],
        "BLEU": [0.0309, 0.1564, 0.185, 0.2212, 0.2615, 0.2736, 0.2757],
        "meta": {"shots": shots, "linestyle": "--"},
    },
}
# EILEV baselines (solid) followed by the T100 ablation variants (dashed).
# The order of this list is also the legend order.
models = [
    "EILEV BLIP-2 OPT-2.7B",
    "EILEV BLIP-2 Flan-T5-xl",
    "T100 BLIP-2 OPT-2.7B",
    "T100 BLIP-2 Flan-T5-xl",
]

# Only a subset of the evaluated shot counts is shown on the x axis.
draw_graphs(
    model_data={**skewed_eilev_data, **skewed_ablation_data},
    models=models,
    metrics=metrics,
    fig_name="skewed-t100-ablation.pdf",
    shots=[0, 4, 8, 12, 16],
    legend_ncol=2,
)

## Skewed Marginal Distributions (Top 500 Common Actions)

In [None]:
# Scores for the top-500-common-actions (T500) ablation models ("--" = dashed).
# This rebinds `skewed_ablation_data`, replacing the T100 scores under that name.
skewed_ablation_data = {
    "T500 BLIP-2 OPT-2.7B": {
        "STS-CE": [0.3136, 0.5058, 0.5086, 0.584, 0.6448, 0.6634, 0.6635],
        "STS-BE": [0.4181, 0.5807, 0.5952, 0.6633, 0.7177, 0.7346, 0.7352],
        "BERTScore-F1": [0.3434, 0.6046, 0.6316, 0.6384, 0.644, 0.6445, 0.6431],
        "ROUGE-L": [0.2508, 0.5095, 0.5556, 0.6027, 0.6387, 0.6515, 0.6543],
        "BLEU": [0.02951, 0.1266, 0.2012, 0.2551, 0.2948, 0.3161, 0.3247],
        "meta": {"shots": shots, "linestyle": "--"},
    },
    "T500 BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.3934, 0.5232, 0.5302, 0.5815, 0.6482, 0.6694, 0.6761],
        "STS-BE": [0.4617, 0.5957, 0.6051, 0.6554, 0.7207, 0.7402, 0.746],
        "BERTScore-F1": [0.2386, 0.6129, 0.6165, 0.6273, 0.6378, 0.639, 0.6263],
        "ROUGE-L": [0.3651, 0.5258, 0.5493, 0.5925, 0.6432, 0.6562, 0.6589],
        "BLEU": [0.07617, 0.1538, 0.2025, 0.2591, 0.312, 0.3184, 0.3147],
        "meta": {"shots": shots, "linestyle": "--"},
    },
}
# EILEV baselines (solid) followed by the T500 ablation variants (dashed).
# The order of this list is also the legend order.
models = [
    "EILEV BLIP-2 OPT-2.7B",
    "EILEV BLIP-2 Flan-T5-xl",
    "T500 BLIP-2 OPT-2.7B",
    "T500 BLIP-2 Flan-T5-xl",
]

# Only a subset of the evaluated shot counts is shown on the x axis.
draw_graphs(
    model_data={**skewed_eilev_data, **skewed_ablation_data},
    models=models,
    metrics=metrics,
    fig_name="skewed-t500-ablation.pdf",
    shots=[0, 4, 8, 12, 16],
    legend_ncol=2,
)

## Dynamic Meaning Ablation Figure

In [None]:
# Scores for the dynamic-meaning ablation models ("--" = dashed). The keys
# coincide with those in `bursty_ablation_data`, but this is a separate dict,
# compared against `eilev_data`.
dynamic_ablation_data = {
    "Ablation BLIP-2 OPT-2.7B": {
        "STS-CE": [0.3197, 0.4552, 0.4243, 0.4929, 0.5562, 0.5785, 0.5855],
        "STS-BE": [0.4007, 0.5234, 0.5013, 0.5545, 0.6157, 0.6386, 0.6465],
        "BERTScore-F1": [0.4286, 0.6368, 0.6455, 0.6516, 0.6588, 0.6645, 0.667],
        "ROUGE-L": [0.3123, 0.4956, 0.5285, 0.5678, 0.5972, 0.6078, 0.6114],
        "BLEU": [0, 0.1335, 0.1905, 0.22, 0.246, 0.2544, 0.258],
        "meta": {"shots": shots, "linestyle": "--"},
    },
    # NOTE(review): this model's 0-shot scores are near zero (STS-BE 0.05909,
    # ROUGE-L 0.001129, BLEU 0) and its 0-shot BERTScore-F1 is -12.057.
    # Presumably degenerate 0-shot output; verify. Shot 0 is plotted.
    "Ablation BLIP-2 Flan-T5-xl": {
        "STS-CE": [0.0157, 0.4336, 0.3944, 0.4957, 0.5772, 0.5973, 0.6035],
        "STS-BE": [0.05909, 0.4901, 0.4432, 0.5243, 0.6253, 0.6532, 0.6631],
        "BERTScore-F1": [-12.057, 0.5963, 0.6317, 0.6395, 0.6552, 0.6617, 0.6649],
        "ROUGE-L": [0.001129, 0.4749, 0.5072, 0.5513, 0.596, 0.608, 0.6105],
        "BLEU": [0, 0.1292, 0.1435, 0.1928, 0.2585, 0.2785, 0.2852],
        "meta": {"shots": shots, "linestyle": "--"},
    },
}

# EILEV baselines (solid) followed by the dynamic-meaning ablation variants
# (dashed). The order of this list is also the legend order.
models = [
    "EILEV BLIP-2 OPT-2.7B",
    "EILEV BLIP-2 Flan-T5-xl",
    "Ablation BLIP-2 OPT-2.7B",
    "Ablation BLIP-2 Flan-T5-xl",
]

# Only a subset of the evaluated shot counts is shown on the x axis.
draw_graphs(
    model_data={**eilev_data, **dynamic_ablation_data},
    models=models,
    metrics=metrics,
    fig_name="dynamic-ablation.pdf",
    shots=[0, 4, 8, 12, 16],
    legend_ncol=2,
)