In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Pass bbox_inches=None when the inline backend renders a figure.
# NOTE(review): presumably this disables the backend's default tight cropping so
# the displayed figure keeps the exact `figsize` — confirm against IPython docs.
%config InlineBackend.print_figure_kwargs = {"bbox_inches": None}

In [None]:
# Target medium for every figure in this notebook; selects figure size,
# font sizes and marker sizes below. Supported values: "paper", "slide".
figure_destination = "paper"

if figure_destination == "paper":
    figsize = (3.03209, 0.22 * 9.72632)
    fontsize_major = 9
    fontsize_minor = 7
    markersize_minor = 5
    markersize_major = 6

elif figure_destination == "slide":
    figsize = (6.10, 4.87)
    fontsize_major = 16
    fontsize_minor = 11
    markersize_minor = 4
    markersize_major = 8

else:
    # Fail here rather than with a NameError in a much later plotting cell.
    raise ValueError(
        f"Unknown figure_destination {figure_destination!r}; expected 'paper' or 'slide'."
    )


# Mono Dataset

In [None]:
import os
from pathlib import Path
import pickle
from collections import defaultdict
import math

import numpy as np
import scipy.stats as stats

from utils.io import load_json_file
from utils.logging import get_config_from_results_dir
from utils.metrics import ci_multiplier


def without(d, key):
    """Return a shallow copy of ``d`` with ``key`` removed.

    ``d`` itself is not modified. Raises ``KeyError`` if ``key`` is absent.
    """
    trimmed = d.copy()
    del trimmed[key]
    return trimmed


def aggregate_dicts(results_dict, split: str, remove_nan: bool = True):
    """Summarise each metric across the sub-dictionaries of ``results_dict``.

    Args:
        results_dict: mapping whose values are ``{prefixed_metric_key: value}``
            dictionaries (one per fold in this notebook).
        split: split name; the first ``len(split) + 1`` characters of every
            metric key are stripped (the split name plus one separator character).
        remove_nan: when true, a metric whose *first* collected value is NaN is
            omitted from the output.

    Returns:
        ``{metric: (mean, var, se, lb, ub, N)}`` where ``var`` is ``np.var`` of
        the collected values, ``se = sqrt(var / N)`` and ``lb``/``ub`` are
        ``mean -/+ ci_multiplier(N, alpha=0.10) * se``.
    """
    prefix_len = len(split) + 1

    # Gather one value per sub-dictionary for every metric, keeping first-seen order.
    per_metric = {}
    for fold_results in results_dict.values():
        for full_key, value in fold_results.items():
            per_metric.setdefault(full_key[prefix_len:], []).append(value)

    summary = {}
    for metric_name, values in per_metric.items():
        if np.isnan(values[0]) and remove_nan:
            continue

        n_values = len(values)
        mean = np.mean(values)
        var = np.var(values)
        se = math.sqrt(var / n_values)
        half_width = ci_multiplier(n_values, alpha=0.10) * se

        summary[metric_name] = (
            mean,
            var,
            se,
            mean - half_width,
            mean + half_width,
            n_values,
        )

    return summary


def aggregate_weighted_dicts(results_dict, split: str, remove_nan: bool = True):
    """Pool ``(effect, variance, N)`` triples per metric with inverse-variance weights.

    Args:
        results_dict: mapping whose values are
            ``{prefixed_metric_key: (mean_effect, variance, N)}`` dictionaries.
        split: split name; the first ``len(split) + 1`` characters of every
            metric key are stripped.
        remove_nan: when true, a metric whose first collected effect is NaN is
            omitted from the output.

    Returns:
        ``{metric: (mean, var, se, lb, ub, N)}`` with
        ``mean = sum(w * effect) / sum(w)``, ``var = 1 / sum(w)``,
        ``w = 1 / variance``, ``N`` the summed third components and
        ``lb``/``ub`` equal to ``mean -/+ ci_multiplier(N, alpha=0.10) * se``.

    Non-finite weights (e.g. from a zero variance) are set to zero instead of
    turning the pooled mean into NaN. If every weight is zero the unweighted
    mean is returned with zero variance. Both rules mirror the ``alpha``-taking
    variant of this function defined later in the notebook.
    """
    key_cut_off = len(split) + 1

    aggregated_dict = defaultdict(list)
    for subdict in results_dict.values():
        for k, v in subdict.items():
            aggregated_dict[k[key_cut_off:]].append(v)

    pooled = {}
    for k, v in aggregated_dict.items():
        if np.isnan(v[0][0]) and remove_nan:
            continue

        v_arr = np.stack(v)

        # 1 / 0 is expected for zero-variance entries; handled right below.
        with np.errstate(divide="ignore", invalid="ignore"):
            weights = 1 / v_arr[:, 1]
        weights[~np.isfinite(weights)] = 0

        weights_sum = np.sum(weights)
        N = np.sum(v_arr[:, 2])

        if weights_sum == 0:
            # No usable weight at all: fall back to the plain mean.
            mean = np.mean(v_arr[:, 0])
            var = 0.0
            se = 0.0
        else:
            mean = np.sum(weights * v_arr[:, 0]) / weights_sum
            var = 1 / weights_sum
            se = math.sqrt(var)

        half_width = ci_multiplier(N, alpha=0.10) * se

        pooled[k] = (mean, var, se, mean - half_width, mean + half_width, N)

    return pooled


# Mono-dataset results: walk ../results, collect per-fold metrics for the runs
# matching `filters`, pool them across folds and tabulate one row per configuration.
split = "test"
filters = {"dataset": "gossipcop", "top_users_excluded": 1, "checkpoint": "230426_roberta_proto", "version": None}
metrics = ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc", "macro_aupr"]

checkpoint_results_weighted = defaultdict(dict)
checkpoint_results = defaultdict(dict)
checkpoint_hparams = dict()
for top_dir, sub_dirs, sub_files in os.walk("../results"):
    top_dir_path = Path(top_dir)

    # Only directories that hold files are result directories; "summary" is skipped.
    if len(sub_files) == 0 or top_dir_path.parts[-1] == "summary":
        continue

    base_config = get_config_from_results_dir(top_dir)

    # Skip runs that fail any filter, and every hyper-parameter sweep.
    continue_flag = False
    for fk, fv in filters.items():
        if base_config[fk] != fv:
            continue_flag = True

    if "sweep" in base_config["checkpoint"]:
        continue_flag = True

    if continue_flag:
        continue

    # Reset per directory so values from a previously visited run are never reused.
    hparams = None
    results = None
    for sub_file in sub_files:
        if "hparams" in sub_file:
            # NOTE: pickle.load can execute arbitrary code; only read trusted result dirs.
            with open(top_dir_path / sub_file, "rb") as f:
                hparams = pickle.load(f)

        if f"{split}_" in sub_file:
            results = load_json_file(top_dir_path / sub_file)

    if hparams is None or results is None:
        print(f"Skipping {top_dir}: missing hparams or '{split}_' results file")
        continue

    # Configuration identity: everything except the fold, plus three learning hparams.
    base_config_tuple = tuple(
        [
            *sorted(without(base_config, "fold").items(), key=lambda x: x[0]),
            ("meta learner", hparams['learning_hparams']['meta_learner']),
            ("reset head", hparams['learning_hparams']['reset_classifier']),
            ("inner updates", hparams['learning_hparams']['n_inner_updates']),
        ]
    )

    checkpoint_results[base_config_tuple][base_config["fold"]] = results
    checkpoint_hparams[base_config_tuple] = hparams

    if "episodic_khop" in base_config["structure"]:
        results_with_weights = {}
        for result_key, result_value in results.items():
            if "std" not in result_key and "se" not in result_key and result_key + "_se" in results:
                mean_effect = result_value
                # In fixed-effects meta-analysis, the standard error
                # is assumed to be an estimator for the variance
                var = math.pow(results[result_key + "_se"], 2)
                N = results[f"{split}/eval_iterations"]

                results_with_weights[result_key] = (mean_effect, var, N)

        checkpoint_results_weighted[base_config_tuple][
            base_config["fold"]
        ] = results_with_weights

checkpoint_results = dict(checkpoint_results)
checkpoint_results_weighted = dict(checkpoint_results_weighted)

# Replace the per-fold dictionaries by [(metric, (mean, var, se, lb, ub, N)), ...].
for config_key, fold_results in checkpoint_results.items():
    aggregate_stats = aggregate_dicts(fold_results, split=split, remove_nan=True)
    checkpoint_results[config_key] = [
        (metric_name, aggregate_stats[metric_name])
        for metric_name in metrics
        if metric_name in aggregate_stats
    ]

    if config_key in checkpoint_results_weighted:
        aggregate_weighted_stats = aggregate_weighted_dicts(
            checkpoint_results_weighted[config_key], split=split, remove_nan=True
        )
        # Filter on the weighted statistics themselves: a metric can be present
        # in the unweighted aggregate yet absent from the weighted one.
        checkpoint_results_weighted[config_key] = [
            (metric_name, aggregate_weighted_stats[metric_name])
            for metric_name in metrics
            if metric_name in aggregate_weighted_stats
        ]

import pandas as pd

records = []
for entry_record, entries in checkpoint_results.items():

    entry_record_ = dict(entry_record)

    # Prefer the inverse-variance pooled statistics whenever they were computed
    # (i.e. for "episodic_khop" structures); otherwise keep the plain aggregate.
    if entry_record in checkpoint_results_weighted:
        entries = checkpoint_results_weighted[entry_record]

    for metric, (mean_val, _, _, lb, ub, _) in entries:
        entry_record_.update({
            metric: mean_val,
            metric + "_lb": lb,
            metric + "_ub": ub,
        })

    records.append(entry_record_)

checkpoint_results = pd.DataFrame.from_records(
    records
)

checkpoint_results.to_clipboard(excel=True,)

checkpoint_results



# Transfer

In [None]:
import os
from pathlib import Path
import pickle
from collections import defaultdict
import math
import re

import numpy as np
import pandas as pd

from utils.io import load_json_file
from utils.logging import get_config_from_results_dir
from utils.metrics import ci_multiplier

def without(d, key):
    """Copy ``d`` (shallowly) and drop ``key`` from the copy.

    The input mapping is left unchanged; a missing ``key`` raises ``KeyError``.
    """
    remainder = d.copy()
    del remainder[key]
    return remainder

def aggregate_dicts(results_dict, alpha, split, remove_nan: bool = True):
    """Summarise each metric across the sub-dictionaries of ``results_dict``.

    Args:
        results_dict: mapping whose values are ``{prefixed_metric_key: value}``
            dictionaries.
        alpha: significance level forwarded to ``ci_multiplier``.
        split: split name; the first ``len(split) + 1`` characters of every
            metric key are stripped.
        remove_nan: when true, a metric whose first collected value is NaN is
            omitted from the output.

    Returns:
        ``{metric: (mean, var, se, lb, ub, N)}`` with ``var = np.var(values)``,
        ``se = sqrt(var / N)`` and ``lb``/``ub`` equal to
        ``mean -/+ ci_multiplier(alpha=alpha, N=N) * se``.
    """
    prefix_len = len(split) + 1

    # One list of values per metric, in first-seen order.
    per_metric = {}
    for fold_results in results_dict.values():
        for full_key, value in fold_results.items():
            per_metric.setdefault(full_key[prefix_len:], []).append(value)

    summary = {}
    for metric_name, values in per_metric.items():
        if np.isnan(values[0]) and remove_nan:
            continue

        n_values = len(values)
        mean = np.mean(values)
        var = np.var(values)
        se = math.sqrt(var / n_values)
        half_width = ci_multiplier(alpha=alpha, N=n_values) * se

        summary[metric_name] = (
            mean,
            var,
            se,
            mean - half_width,
            mean + half_width,
            n_values,
        )

    return summary

def aggregate_weighted_dicts(results_dict, alpha, split, remove_nan: bool = True):
    """Pool ``(effect, variance, N)`` triples per metric with inverse-variance weights.

    Args:
        results_dict: mapping whose values are
            ``{prefixed_metric_key: (mean_effect, variance, N)}`` dictionaries.
        alpha: significance level forwarded to ``ci_multiplier``.
        split: split name; the first ``len(split) + 1`` characters of every
            metric key are stripped.
        remove_nan: when true, a metric whose first collected effect is NaN is
            omitted from the output.

    Returns:
        ``{metric: (mean, var, se, lb, ub, N)}``. Weights are ``1 / variance``
        with non-finite weights replaced by zero. When the weights sum to zero
        the plain mean of the effects is returned with zero ``var`` and ``se``;
        otherwise ``mean = sum(w * effect) / sum(w)`` and ``var = 1 / sum(w)``.
    """
    prefix_len = len(split) + 1

    per_metric = {}
    for fold_results in results_dict.values():
        for full_key, triple in fold_results.items():
            per_metric.setdefault(full_key[prefix_len:], []).append(triple)

    pooled = {}
    for metric_name, triples in per_metric.items():
        if np.isnan(triples[0][0]) and remove_nan:
            continue

        stacked = np.stack(triples)
        effects = stacked[:, 0]
        variances = stacked[:, 1]
        counts = stacked[:, 2]

        # Inverse-variance weights; inf / -inf / NaN entries contribute nothing.
        weights = 1 / variances
        weights[~np.isfinite(weights)] = 0
        total_weight = np.sum(weights)

        N = np.sum(counts)

        if total_weight == 0:
            mean = np.mean(effects)
            var = np.array(0.0)
            se = np.array(0.0)
        else:
            mean = np.sum(weights * effects) / total_weight
            var = 1 / total_weight
            se = math.sqrt(var)

        half_width = ci_multiplier(N=N, alpha=alpha) * se

        pooled[metric_name] = (mean, var, se, mean - half_width, mean + half_width, N)

    return pooled

def _aggregates_to_frame(aggregated: dict):
    """Flatten ``{config_tuple: [(metric, (mean, var, se, lb, ub, N)), ...]}`` into a DataFrame."""
    records = []
    for config_tuple, metric_stats in aggregated.items():
        record = dict(config_tuple)
        for metric_name, (mean, var, se, lb, ub, n) in metric_stats:
            record.update({
                metric_name: mean,
                metric_name + "_var": var,
                metric_name + "_se": se,
                metric_name + "_lb": lb,
                metric_name + "_ub": ub,
                metric_name + "_N": n,
            })
        records.append(record)

    return pd.DataFrame.from_records(records)


def get_weighted_results_table(filters: dict, metrics: list, split: str, alpha: float = 0.10):
    """Collect evaluation results under ``../results`` and pool them across folds.

    Every directory that contains files (except ones named ``summary``) is
    matched against ``filters`` via ``get_config_from_results_dir``; runs whose
    checkpoint name contains ``"sweep"`` are skipped. Each file whose name
    contains ``f"{split}_eval"`` is loaded, and its evaluation settings
    (k, n_updates, inner_lr, inner_head_lr, class_weights, optional budget) are
    parsed from the bracketed fields of the file name.

    Args:
        filters: ``{config_key: required_value}``; a run must match all of them.
        metrics: metric names (without split prefix) to keep, in output order.
        split: split name used both in the file-name test and as key prefix.
        alpha: significance level forwarded to the aggregation helpers.

    Returns:
        ``(raw_values, checkpoint_results, checkpoint_results_weighted)``:
        a list with one merged dict per loaded file, a DataFrame of plain
        across-fold aggregates, and a DataFrame of inverse-variance pooled
        aggregates (only filled for ``"episodic_khop"`` structures). Both
        frames have one row per evaluation setting and, per metric, the columns
        ``<m>``, ``<m>_var``, ``<m>_se``, ``<m>_lb``, ``<m>_ub``, ``<m>_N``.

    Raises:
        KeyError: if weighted pooling is needed but the results file has no key
            containing ``"iterations"``.
    """
    brackets_pattern = r"\[(.*?)\]"

    def bracket_value(part: str) -> str:
        # Text inside the first [...] of a file-name field.
        return re.search(brackets_pattern, part).group(1)

    raw_values = []
    checkpoint_results_weighted = defaultdict(dict)
    checkpoint_results = defaultdict(dict)
    for top_dir, _sub_dirs, sub_files in os.walk("../results"):
        top_dir_path = Path(top_dir)

        if len(sub_files) == 0 or top_dir_path.parts[-1] == "summary":
            continue

        base_config = get_config_from_results_dir(top_dir)

        continue_flag = False
        for fk, fv in filters.items():
            if base_config[fk] != fv:
                continue_flag = True

        if "sweep" in base_config["checkpoint"]:
            continue_flag = True

        if continue_flag:
            continue

        for sub_file in sub_files:
            # Only evaluation files are recorded; other files in the directory
            # must not re-register (possibly stale) results.
            if f"{split}_eval" not in sub_file:
                continue

            results = load_json_file(top_dir_path / sub_file)

            # Split on underscores that are not inside a [...] field.
            eval_str_parts = re.split(r"\_(?![^[]*])", sub_file)

            k_shot = int(bracket_value(eval_str_parts[2]))
            n_updates = int(bracket_value(eval_str_parts[3]))
            inner_lr = float(bracket_value(eval_str_parts[4]))
            inner_head_lr = float(bracket_value(eval_str_parts[5]))

            # The first 8 characters of the field are skipped before splitting
            # on ", " (presumably a "tensor([" style prefix — confirm).
            class_weights = tuple(
                map(float, bracket_value(eval_str_parts[6])[8:].split(", "))
            )

            if len(eval_str_parts) == 8:
                budget = int(bracket_value(eval_str_parts[7]))
            else:
                budget = 2048

            eval_tuple = (
                ("k", k_shot),
                ("n_updates", n_updates),
                ("inner_lr", inner_lr),
                ("inner_head_lr", inner_head_lr),
                ("class_weights", class_weights),
                ("budget", budget),
                ("reset", "reset" in base_config["version"]),
                ("user_init", "user_init" in base_config["version"]),
            )

            base_config_tuple = tuple(eval_tuple)

            checkpoint_results[base_config_tuple][base_config["fold"]] = results

            raw_values.append(
                base_config | dict(eval_tuple) | results
            )

            if "episodic_khop" in base_config["structure"]:
                # Number of evaluation iterations: first key mentioning "iterations".
                N = next(
                    (value for key, value in results.items() if "iterations" in key),
                    None,
                )

                results_with_weights = {}
                for key, value in results.items():
                    if "std" not in key and "se" not in key and key + "_se" in results:
                        if N is None:
                            raise KeyError(
                                f"No '*iterations*' entry in {top_dir_path / sub_file}"
                            )

                        mean_effect = value
                        # In fixed-effects meta-analysis, the standard error
                        # is assumed to be an estimator for the variance
                        var = math.pow(results[key + "_se"], 2)

                        results_with_weights[key] = (mean_effect, var, N)

                checkpoint_results_weighted[base_config_tuple][
                    base_config["fold"]
                ] = results_with_weights

    checkpoint_results = dict(checkpoint_results)
    checkpoint_results_weighted = dict(checkpoint_results_weighted)

    for config_key, fold_results in checkpoint_results.items():
        aggregate_stats = aggregate_dicts(fold_results, alpha, split=split, remove_nan=False)
        checkpoint_results[config_key] = [
            (kk, aggregate_stats[kk]) for kk in metrics if kk in aggregate_stats
        ]

        if config_key in checkpoint_results_weighted:
            aggregate_weighted_stats = aggregate_weighted_dicts(
                checkpoint_results_weighted[config_key], alpha, split=split, remove_nan=False
            )
            checkpoint_results_weighted[config_key] = [
                (kk, aggregate_weighted_stats[kk]) for kk in metrics if kk in aggregate_weighted_stats
            ]

    # Both tables are built from the flattened records (not from the raw
    # (key, stats) pairs), so they share the same column layout.
    checkpoint_results = _aggregates_to_frame(checkpoint_results)
    checkpoint_results_weighted = _aggregates_to_frame(checkpoint_results_weighted)

    return raw_values, checkpoint_results, checkpoint_results_weighted

def normalize_by_group(df, by, metric, remove: set = set()):
    """Pool one metric within groups of ``df`` using inverse-variance weights.

    Args:
        df: frame holding the columns in ``by`` plus ``<metric>``,
            ``<metric>_se`` and ``<metric>_N``.
        by: column names to group on.
        metric: base name of the metric column to pool.
        remove: regular expressions; every output column whose name matches one
            of them is dropped. The default set is shared between calls but is
            only iterated here, never modified.

    Returns:
        A DataFrame with one row per group containing ``<metric>`` (weighted
        mean), ``<metric>_var``, ``<metric>_N``, ``<metric>_se``,
        ``<metric>_lb`` and ``<metric>_ub``, minus the columns matched by
        ``remove``. The interval uses ``ci_multiplier`` with a fixed
        ``alpha=0.1``.
    """
    
    sub_df = df[[*by, f"{metric}", f"{metric}_se", f"{metric}_N"]].copy(deep=True)
    # Inverse-variance weight of each row: w = se ** -2.
    sub_df[f"{metric}_w"] = np.power(sub_df[f"{metric}_se"], -2)
    
    grouped_sub_df = sub_df.groupby(by=by, as_index=True, group_keys=False)
    
    # Pooled variance per group: 1 / sum(w).
    # NOTE(review): reset_index(level=[], name=...) resets no index level; it
    # appears to be used only to turn the Series into a named one-column frame
    # — confirm.
    agg_variance = (grouped_sub_df        
        .apply(lambda x: 1 / x[f"{metric}_w"].sum())
        .reset_index(level=[], name=f"{metric}_var")
    )

    # Un-normalised weighted sum per group: sum(value * w); scaled further below.
    agg_mean = (grouped_sub_df        
        .apply(lambda x: (x[f"{metric}"] * x[f"{metric}_w"]).sum())
        .reset_index(level=[], name=f"{metric}")
    )

    # Total N per group.
    agg_N = (grouped_sub_df        
        .apply(lambda x: x[f"{metric}_N"].sum())
        .reset_index(level=[], name=f"{metric}_N")
    )

    agg_values = agg_mean.join(
        other=[agg_variance, agg_N]
    )

    # Weighted mean = sum(value * w) * (1 / sum(w)).
    agg_values[f"{metric}"] = agg_values[f"{metric}"] * agg_values[f"{metric}_var"]

    agg_values[f"{metric}_se"] = np.sqrt(agg_values[f"{metric}_var"])

    # Row-wise interval bounds: mean -/+ ci_multiplier(N, alpha=0.1) * se.
    agg_values = agg_values.join(
        (agg_values
            .apply(
                lambda x: [
                    x[f"{metric}"] - ci_multiplier(N=x[f"{metric}_N"], alpha=0.1) * x[f"{metric}_se"],
                    x[f"{metric}"] + ci_multiplier(N=x[f"{metric}_N"], alpha=0.1) * x[f"{metric}_se"]
                    ],
                axis=1,
                result_type='expand',
                )
            .rename(
                columns={0: f"{metric}_lb", 1: f"{metric}_ub"}
                )
        ),
        
    )

    # Drop every column whose name matches one of the `remove` regexes.
    agg_values = agg_values[
        agg_values.columns
        .drop(
            [
                col_name
                for condition in remove
                for col_name in agg_values.filter(regex=condition).columns.tolist()
                ]
        )
    ]

    return agg_values


In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Qualitative palette from which every learning-algorithm colour is drawn.
cmap = plt.get_cmap("tab10")

# Fixed colour per learning algorithm, keyed by the names used in the
# `checkpoints` mappings; "zero_shot_transfer" and "subgraphs" share a colour.
checkpoint_cmap = {
    algorithm: cmap(palette_slot)
    for algorithm, palette_slot in (
        ("zero_shot_transfer", 0),
        ("subgraphs", 0),
        ("maml_lh", 1),
        ("maml_rh", 2),
        ("prototypical", 3),
        ("protomaml", 4),
    )
}


## Twitter Hate Speech

In [None]:
# Run identifiers of the transferred checkpoints, grouped by learning algorithm.
checkpoints = {
    "subgraphs": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "maml_lh": ["zqhx6x3b", "11pt2nis", "ruy4hp9o", "nlfyh80j", "06pfaw4f"],
    "maml_rh": ["rpu3r1rm", "35neqj40", "xycl2xbl", "9hdrq4an", "905yf717"],
    "prototypical": ["yjnx3e9w", "5e0vvr04", "br7qcerq", "6c8tvtwn", "ahp19u65"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

# Metrics to tabulate: per-class F1 / AUPR for classes 0-2, plus global scores.
metrics = [
    "supp_improvement",
    *[f"{score}_{label}" for label in range(3) for score in ("f1", "aupr")],
    "mcc",
    "macro_aupr",
]

# One weighted-results frame per checkpoint, tagged with its origin.
all_checkpoint_results = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:
        filters = dict(
            dataset="twitterHateSpeech",
            top_users_excluded=0,
            version=f"transfer_{checkpoint}",
            structure_mode="transductive",
        )

        raw_results, checkpoint_results, checkpoint_results_weighted = (
            get_weighted_results_table(
                filters=filters, metrics=metrics, split="test", alpha=0.10
            )
        )

        checkpoint_results_weighted = checkpoint_results_weighted.assign(
            checkpoint=checkpoint,
            learning_algorithm=learning_algorithm,
        )

        all_checkpoint_results.append(checkpoint_results_weighted)

# Stack all checkpoints into a single frame with a fresh 0..n-1 index.
all_checkpoint_results = pd.concat(all_checkpoint_results).reset_index(drop=True)


In [None]:
# Filter out some early attempts with too high learning rate
# These are noisy data points
#
# Among the rows with n_updates == 25, keep only the first row per
# (learning_algorithm, checkpoint, k) after an ascending sort on
# inner_lr, then inner_head_lr — i.e. the run with the lowest learning rates.
# NOTE(review): the de-duplicated frame is assigned back through the same
# boolean mask, so the discarded rows are presumably not removed but left in
# place as all-NaN rows (index alignment) — confirm this is intended.
all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)] = (
    all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)]
    .sort_values(by=["learning_algorithm", "checkpoint", "k", "inner_lr", "inner_head_lr"])
    .drop_duplicates(subset=["learning_algorithm", "checkpoint", "k"])
)

all_checkpoint_results

In [None]:
# Figure: MCC per k-shot setting on TwitterHateSpeech. Faint circles are
# individual checkpoints, open diamonds are the pooled estimate per learning
# algorithm, and lines connect the pooled estimates across k.
metric = "mcc"

lower_metric = metric.lower()

# Inverse-variance pooled value per (k, learning_algorithm).
aggregated_df = normalize_by_group(
    df=all_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

# Pooled values per algorithm, one entry per k (in all_k_shots order).
agg_values = defaultdict(list)
custom_lines = []

# cur_x is a running x position; prev_k_loc marks where the current k group starts.
prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = all_checkpoint_results["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):
        matching_learning_alg = matching_k & (
            all_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                all_checkpoint_results["checkpoint"] == ckpt
            )

            # More than one row for a (k, algorithm, checkpoint) aborts the cell;
            # no row leaves an empty slot on the x axis.
            if all_checkpoint_results[matching_ckpt][lower_metric].shape[0] > 1:
                raise KeyboardInterrupt()
            elif all_checkpoint_results[matching_ckpt][lower_metric].shape[0] == 0:
                cur_x += 1
                continue

            value = all_checkpoint_results[matching_ckpt][lower_metric].item()
            # Error-bar half-width taken from the lower bound (value - lb).
            value_error = (
                value
                - all_checkpoint_results[matching_ckpt][f"{lower_metric}_lb"].item()
            )

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        # Error-bar half-width taken from the upper bound (ub - value).
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        # Place the pooled marker in the middle of the checkpoint slots just drawn.
        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # A legend handle is appended for every (k, algorithm) pair; this cell
        # does not pass `custom_lines` to a legend.
        custom_lines += [
            Line2D([0], [0], marker='D', markerfacecolor=color, color='w', lw=0, markersize=markersize_major),
        ]

        # Gap between algorithms inside one k group.
        cur_x += 5

    # Tick at the centre of the k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# These two arrays are not used again within this cell.
all_k_shots_range = np.arange(min(x_ticks), max(x_ticks), step=1, dtype=float)
all_k_shots_x = np.stack(
    [np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1
)

for i, (learning_algorithm, values) in enumerate(agg_values.items()):
    color = checkpoint_cmap[learning_algorithm]

    # Rebinds all_k_shots from the k values to the tick positions.
    all_k_shots = np.array(x_ticks, dtype=float)

    # Connect the pooled values at the group-centre tick positions.
    ax.plot(all_k_shots, values, c=color, alpha=0.75, zorder=1)

ax.set_title(" ", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot\nTwitterHateSpeech", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_major)

ax.set_ylim(0, 0.25)
ax.set_yticks([0.0, 0.05, 0.10, 0.15, 0.20, 0.25])

fig.tight_layout()

# Written to a sibling repository via a relative path.
fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_gat_transfer.pdf"
)

plt.show()


In [None]:
# Deep-copy snapshot of the current results frame under a dataset-specific name.
twitter_transfer = all_checkpoint_results.copy(deep=True)

In [None]:
# Inspect the "subgraphs" rows evaluated with k == 4.
all_checkpoint_results[(all_checkpoint_results["learning_algorithm"] == "subgraphs") & (all_checkpoint_results["k"] == 4)]

In [None]:
# Inspect the "subgraphs" rows evaluated with k == 8.
all_checkpoint_results[(all_checkpoint_results["learning_algorithm"] == "subgraphs") & (all_checkpoint_results["k"] == 8)]

In [None]:
# Pooled table (k <= 16): one block of columns per metric, one row per
# (k, learning_algorithm); variance, SE and N columns are removed.
all_agg_dfs = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc"]:
    all_agg_dfs.append(
        normalize_by_group(
            df=all_checkpoint_results[all_checkpoint_results["k"] <= 16],
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

# Align the per-metric frames side by side on their shared group index.
all_agg_dfs = pd.concat(all_agg_dfs, axis=1, join="outer")
all_agg_dfs = all_agg_dfs.reset_index()

# Order algorithms as they are listed in `checkpoints` instead of alphabetically.
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=list(checkpoints),
        ordered=True,
    )
)

all_agg_dfs = all_agg_dfs.sort_values(by=["k", "learning_algorithm"])
all_agg_dfs = all_agg_dfs.set_index(keys=["k", "learning_algorithm"])

all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

In [None]:
# Deep-copy snapshot of the results frame for the GAT checkpoints; the
# comparison figures read it under this name.
gat_all_checkpoint_results = all_checkpoint_results.copy(deep=True)

### MLP Baseline on Twitter Hate Speech

In [None]:
# Run identifiers of the MLP-baseline checkpoints, grouped by learning algorithm.
# This rebinds `checkpoints`; `metrics` is reused from an earlier cell.
checkpoints = {
    "subgraphs": ["l339inkn", "vyeuzmyc", "20asxsz3", "6046x2gc", "pxllyec4"],
    "maml_lh": ["cjvoiuqn", "y5zaa74e", "k2l9yy52", "7rk7bfgn", "f9xtwlxp"],
    "protomaml": ["4kd4uk24", "ev8f5mch", "cj4jv0pl", "8qfxh531", "m4rg2i2i"],
}

# One weighted-results frame per checkpoint, tagged with its origin.
all_checkpoint_results = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:
        filters = dict(
            dataset="twitterHateSpeech",
            top_users_excluded=0,
            version=f"transfer_{checkpoint}",
            structure_mode="transductive",
        )

        raw_results, checkpoint_results, checkpoint_results_weighted = (
            get_weighted_results_table(
                filters=filters, metrics=metrics, split="test", alpha=0.10
            )
        )

        checkpoint_results_weighted = checkpoint_results_weighted.assign(
            checkpoint=checkpoint,
            learning_algorithm=learning_algorithm,
        )

        all_checkpoint_results.append(checkpoint_results_weighted)

# Stack all checkpoints into a single frame with a fresh 0..n-1 index.
all_checkpoint_results = pd.concat(all_checkpoint_results).reset_index(drop=True)

In [None]:
# Filter out some early attempts with too high learning rate
# These are noisy data points
#
# Among the rows with n_updates == 25, keep only the first row per
# (learning_algorithm, checkpoint, k) after an ascending sort on
# inner_lr, then inner_head_lr — i.e. the run with the lowest learning rates.
# NOTE(review): the de-duplicated frame is assigned back through the same
# boolean mask, so the discarded rows are presumably not removed but left in
# place as all-NaN rows (index alignment) — confirm this is intended.
all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)] = (
    all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)]
    .sort_values(by=["learning_algorithm", "checkpoint", "k", "inner_lr", "inner_head_lr"])
    .drop_duplicates(subset=["learning_algorithm", "checkpoint", "k"])
)

# The trailing semicolon suppresses the cell's rich output.
all_checkpoint_results;

In [None]:
# Deep-copy snapshot of the results frame for the MLP checkpoints; the
# comparison figures read it under this name.
mlp_all_checkpoint_results = all_checkpoint_results.copy(deep=True)

In [None]:
# Figure: MCC per k-shot setting for the current `all_checkpoint_results`.
# Faint circles are individual checkpoints, open diamonds are the pooled
# estimate per learning algorithm, and each line is a least-squares fit through
# the pooled estimates.
metric = "mcc"

lower_metric = metric.lower()

# Inverse-variance pooled value per (k, learning_algorithm).
aggregated_df = normalize_by_group(
    df=all_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

# Pooled values per algorithm, one entry per k (in all_k_shots order).
agg_values = defaultdict(list)
custom_lines = []

# cur_x is a running x position; prev_k_loc marks where the current k group starts.
prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = all_checkpoint_results["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):
        matching_learning_alg = matching_k & (
            all_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                all_checkpoint_results["checkpoint"] == ckpt
            )

            # More than one row for a (k, algorithm, checkpoint) aborts the cell;
            # no row leaves an empty slot on the x axis.
            if all_checkpoint_results[matching_ckpt][lower_metric].shape[0] > 1:
                raise KeyboardInterrupt()
            elif all_checkpoint_results[matching_ckpt][lower_metric].shape[0] == 0:
                cur_x += 1
                continue

            value = all_checkpoint_results[matching_ckpt][lower_metric].item()
            # Error-bar half-width taken from the lower bound (value - lb).
            value_error = (
                value
                - all_checkpoint_results[matching_ckpt][f"{lower_metric}_lb"].item()
            )

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        # Error-bar half-width taken from the upper bound (ub - value).
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        # Place the pooled marker in the middle of the checkpoint slots just drawn.
        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # A legend handle is appended for every (k, algorithm) pair, so the list
        # is longer than the label list passed to ax.legend below.
        custom_lines += [
            Line2D([0], [0], marker='D', markerfacecolor=color, color='w', lw=0, markersize=markersize_major),
        ]

        # Gap between algorithms inside one k group.
        cur_x += 5

    # Tick at the centre of the k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# x positions at which the fitted lines are evaluated, with a leading column
# of ones for the intercept. np.arange excludes max(x_ticks) itself.
all_k_shots_range = np.arange(min(x_ticks), max(x_ticks), step=len(all_k_shots) * len(checkpoints), dtype=float)
all_k_shots_x = np.stack(
    [np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1
)

for i, (learning_algorithm, values) in enumerate(agg_values.items()):
    color = checkpoint_cmap[learning_algorithm]

    # Rebinds all_k_shots from the k values to the tick positions.
    all_k_shots = np.array(x_ticks, dtype=float)
    # Design matrix [1, x] over the tick positions.
    x = np.stack(
        [np.ones_like(x_ticks, dtype=float), all_k_shots], axis=1
    )
    # Normal-equation solution of the least-squares fit: (X^T X)^-1 X^T y.
    w_prefix = np.linalg.inv(x.T @ x) @ x.T

    y = np.array(values)
    w_ml = w_prefix @ y

    pred_y = all_k_shots_x @ w_ml

    ax.plot(all_k_shots_range, pred_y, c=color, alpha=0.75, zorder=1)

#ax.set_title("Twitter Hate Speech", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.15),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Pooled table: one block of columns per metric, one row per
# (k, learning_algorithm); variance, SE and N columns are removed.
all_agg_dfs = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc"]:
    all_agg_dfs.append(
        normalize_by_group(
            df=all_checkpoint_results,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

# Align the per-metric frames side by side on their shared group index.
all_agg_dfs = pd.concat(all_agg_dfs, axis=1, join="outer")
all_agg_dfs = all_agg_dfs.reset_index()

# Order algorithms as they are listed in `checkpoints` instead of alphabetically.
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=list(checkpoints),
        ordered=True,
    )
)

all_agg_dfs = all_agg_dfs.sort_values(by=["k", "learning_algorithm"])
all_agg_dfs = all_agg_dfs.set_index(keys=["k", "learning_algorithm"])

all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

#### Comparison

In [None]:
# Figure: pooled MCC per k-shot setting, MLP ("v" markers, dashed lines)
# against GAT ("^" markers, solid lines), drawn on one double-width axis.
# Only the algorithms in the current `checkpoints` mapping are drawn; read
# top-to-bottom, that is the mapping defined for the MLP baseline.
metric = "mcc"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

custom_lines = []

for checkpoint_type, checkpoint_results in [("mlp", mlp_all_checkpoint_results), ("gat", gat_all_checkpoint_results)]:
    
    # Inverse-variance pooled value per (k, learning_algorithm) for this model type.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # The x layout restarts for each model type, so both share the same positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        matching_k = checkpoint_results["k"] == k

        for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):
            # Not used further in this cell.
            matching_learning_alg = matching_k & (
                checkpoint_results["learning_algorithm"] == learning_algorithm
            )

            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here; only their x slots are
            # reserved (one unit per checkpoint).
            for ckpt in ckpts:
                cur_x += 1

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            # Error-bar half-width taken from the upper bound (ub - value).
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt="^" if checkpoint_type == "gat" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between algorithms inside one k group.
            cur_x += 5

        # Tick at the centre of the k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # These two arrays are not used again within this cell.
    all_k_shots_range = np.arange(min(x_ticks), max(x_ticks), step=1, dtype=float)
    all_k_shots_x = np.stack(
        [np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1
    )

    for i, (learning_algorithm, values) in enumerate(agg_values.items()):
        color = checkpoint_cmap[learning_algorithm]

        all_k_shots_ = np.array(x_ticks, dtype=float)

        # Connect the pooled values at the group-centre tick positions.
        ax.plot(
            all_k_shots_,
            values,
            c=color,
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--"
            )

# Uses the tick positions left over from the last loop iteration ("gat").
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles: one coloured diamond per algorithm ...
for learning_algorithm, ckpts in checkpoints.items():
    
    color = checkpoint_cmap[learning_algorithm]
    
    custom_lines += [
        Line2D([0], [0], marker='D', markerfacecolor=color, color='w', lw=0, markersize=markersize_major),
    ]

# ... followed by two black triangles for the model types ("v" then "^").
custom_lines += [
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='^', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

#ax.set_title("Twitter Hate Speech", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.set_title(" ", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# The labels are hard-coded, so they must match the order and number of
# entries in `checkpoints` (three algorithms) followed by MLP and GAT.
fig.legend(
    custom_lines,
    ["Subgraphs", "MAML", "ProtoMAML"] + ["MLP", "GAT"],
    loc='upper center',
    #bbox_to_anchor=(0.2, 1.1),
    fontsize=fontsize_minor,
    ncol=5,
    )

fig.tight_layout()

plt.show()

# Written to a sibling repository via a relative path.
fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_mlp_ubs_transfer_comparison_mcc.pdf"
)


In [None]:
# Macro-AUPR version of the MLP-vs-GAT transfer comparison: one aggregated
# marker per (k, learning algorithm) and model type, joined by a line per
# learning algorithm.
metric = "macro_aupr"

# Double-width figure; the height follows the global `figsize`.
fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [
    ("mlp", mlp_all_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # The x layout is rebuilt on every pass. It only depends on `checkpoints`
    # and `all_k_shots`, so both model types end up on the same positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here, but one x slot per
            # checkpoint is still reserved, mirroring the spacing used by the
            # per-checkpoint plots in this notebook.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            # Upper half-width of the interval, drawn symmetrically by errorbar.
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt="^" if checkpoint_type == "gat" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between learning algorithms within one k group.
            cur_x += 5

        # Tick sits in the middle of the k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    tick_locs = np.array(x_ticks, dtype=float)

    # Connect the aggregates of each learning algorithm across k
    # (solid for GAT, dashed for MLP).
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_locs,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# NOTE(review): the blank title presumably keeps the axes the same height as
# the MCC figure, which places its legend in that space -- confirm.
# No legend is drawn in this figure, so no legend handles are built.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("Macro-AUPR", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_mlp_ubs_transfer_comparison_aupr_macro.pdf"
)


### IID MLPs

In [None]:
# Checkpoint ids of the IID-trained MLPs, grouped by learning algorithm.
checkpoints = {
    "maml_lh": ["nggrixup", "e506xerw", "iiu8msnu", "x502cvg8", "xs3e6fgn"],
}

# Metrics requested from the results table: support improvement, the
# per-class scores for classes 0-2, then the global scores.
metrics = [
    "supp_improvement",
    "f1_0", "f1_gain_0", "aupr_0",
    "f1_1", "f1_gain_1", "aupr_1",
    "f1_2", "f1_gain_2", "aupr_2",
    "mcc", "f1_gain_macro", "macro_aupr",
]

# Collect one weighted results table per checkpoint, tagged with its origin.
iid_mlp_checkpoint_results = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:
        filters = dict(
            dataset="twitterHateSpeech",
            top_users_excluded=0,
            version=f"transfer_{checkpoint}",
            structure_mode="transductive",
        )

        # Only the weighted table is used further on.
        raw_results, checkpoint_results, checkpoint_results_weighted = (
            get_weighted_results_table(
                filters=filters,
                metrics=metrics,
                split="test",
                alpha=0.10,
            )
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm

        iid_mlp_checkpoint_results += [checkpoint_results_weighted]

# Stack the per-checkpoint tables into a single frame with a fresh 0..n-1 index.
iid_mlp_checkpoint_results = pd.concat(iid_mlp_checkpoint_results, ignore_index=True)

In [None]:
# Keep only runs with k <= 16; larger k-shot values are discarded as bad attempts.
iid_mlp_checkpoint_results = (
    iid_mlp_checkpoint_results
    .loc[lambda results: results["k"] <= 16]
    .reset_index(drop=True)
)

In [None]:
# IID MLP transfer results: MCC of every checkpoint (faint dots), the aggregate
# per (k, learning algorithm) (open diamonds) and a least-squares linear trend
# through the aggregates.
metric = "mcc"

lower_metric = metric.lower()

# Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
aggregated_df = normalize_by_group(
    df=iid_mlp_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = iid_mlp_checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        matching_learning_alg = matching_k & (
            iid_mlp_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                iid_mlp_checkpoint_results["checkpoint"] == ckpt
            )
            ckpt_rows = iid_mlp_checkpoint_results.loc[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # Each (k, algorithm, checkpoint) must identify a single run.
                raise ValueError(
                    f"Expected at most one result row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, "
                    f"checkpoint={ckpt!r}; found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # Missing run: leave its x slot empty so the spacing stays uniform.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # Lower half-width of the interval, drawn symmetrically by errorbar.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        # Upper half-width of the aggregate interval.
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # Gap between learning algorithms within one k group.
        cur_x += 5

    # Tick sits in the middle of the k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Linear trend of the aggregate over the tick positions. The design matrix has
# an intercept column and the tick location; the fit is solved with lstsq
# rather than an explicit inverse of the normal equations.
tick_locs = np.array(x_ticks, dtype=float)
design = np.stack([np.ones_like(tick_locs), tick_locs], axis=1)

# A straight line only needs its two end points; using the first and last tick
# makes the trend span the full tick range.
fit_x = np.array([tick_locs.min(), tick_locs.max()])
fit_design = np.stack([np.ones_like(fit_x), fit_x], axis=1)

for learning_algorithm, values in agg_values.items():
    w_ml, *_ = np.linalg.lstsq(design, np.array(values), rcond=None)

    ax.plot(
        fit_x,
        fit_design @ w_ml,
        c=checkpoint_cmap[learning_algorithm],
        alpha=0.75,
        zorder=1,
    )

#ax.set_title("Twitter Hate Speech", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# Exactly one legend handle per learning algorithm, in the same order as the
# labels passed to `ax.legend` below.
custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[learning_algorithm_name],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for learning_algorithm_name in checkpoints
]

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.15),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Build one aggregated table per metric for the IID MLP runs; the tables are
# joined side by side below.
# NOTE(review): `remove` presumably drops the variance / SE / N columns -- confirm
# against `normalize_by_group`.
all_agg_dfs = []
for metric in [
    "f1_0", "aupr_0",
    "f1_1", "aupr_1",
    "f1_2", "aupr_2",
    "mcc",
    "f1_gain_0", "f1_gain_1", "f1_gain_2",
    "macro_aupr",
]:
    all_agg_dfs.append(
        normalize_by_group(
            df=iid_mlp_checkpoint_results,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

# Column-wise outer join, then expose the group keys as ordinary columns.
all_agg_dfs = pd.concat(all_agg_dfs, axis=1, join="outer").reset_index()

# An ordered categorical makes the sort follow the order of `checkpoints`
# instead of alphabetical order.
all_agg_dfs["learning_algorithm"] = pd.Categorical(
    values=all_agg_dfs["learning_algorithm"],
    categories=list(checkpoints),
    ordered=True,
)

all_agg_dfs = (
    all_agg_dfs
    .sort_values(by=["k", "learning_algorithm"])
    .set_index(keys=["k", "learning_algorithm"])
)

# Copy in spreadsheet-friendly format for pasting into a table.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

### Reset GATs

In [None]:
checkpoints = {
    # The reset models were not actually trained as MAML_RH, but MAML_RH is
    # the fairest baseline to compare them against, so they are filed under it.
    "maml_rh": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

# Collect one weighted results table per reset run, tagged with its origin.
reset_checkpoint_results = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    # The position of a checkpoint within its list is part of the run's version string.
    for checkpoint_seed, checkpoint in enumerate(ckpts):
        filters = dict(
            dataset="twitterHateSpeech",
            top_users_excluded=0,
            version=f"reset_{checkpoint}_{checkpoint_seed}",
            structure_mode="transductive",
        )

        # Only the weighted table is used further on.
        raw_results, checkpoint_results, checkpoint_results_weighted = (
            get_weighted_results_table(
                filters=filters,
                metrics=metrics,
                split="test",
                alpha=0.10,
            )
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm

        reset_checkpoint_results += [checkpoint_results_weighted]

# Stack the per-run tables into a single frame with a fresh 0..n-1 index.
reset_checkpoint_results = pd.concat(reset_checkpoint_results, ignore_index=True)


In [None]:
# Reset GAT results: MCC of every checkpoint (faint dots) and the aggregate
# per (k, learning algorithm) (open diamonds), with the aggregates of each
# learning algorithm connected across k.
metric = "mcc"

lower_metric = metric.lower()

# Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
aggregated_df = normalize_by_group(
    df=reset_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = reset_checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        matching_learning_alg = matching_k & (
            reset_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                reset_checkpoint_results["checkpoint"] == ckpt
            )
            ckpt_rows = reset_checkpoint_results.loc[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # Each (k, algorithm, checkpoint) must identify a single run.
                raise ValueError(
                    f"Expected at most one result row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, "
                    f"checkpoint={ckpt!r}; found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # Missing run: leave its x slot empty so the spacing stays uniform.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # Lower half-width of the interval, drawn symmetrically by errorbar.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        # Upper half-width of the aggregate interval.
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # Gap between learning algorithms within one k group.
        cur_x += 5

    # Tick sits in the middle of the k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Connect the aggregates of each learning algorithm across the k groups.
# `all_k_shots` is left untouched so it keeps holding the list of k values.
tick_locs = np.array(x_ticks, dtype=float)

for learning_algorithm, values in agg_values.items():
    ax.plot(
        tick_locs,
        values,
        c=checkpoint_cmap[learning_algorithm],
        alpha=0.75,
        zorder=1,
    )

#ax.set_title("Twitter Hate Speech", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# Exactly one legend handle per learning algorithm, in the same order as the
# labels passed to `ax.legend` below.
custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[learning_algorithm_name],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for learning_algorithm_name in checkpoints
]

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.15),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Build one aggregated table per metric for the reset GAT runs; the tables are
# joined side by side below.
# NOTE(review): `remove` presumably drops the variance / SE / N columns -- confirm
# against `normalize_by_group`.
all_agg_dfs = []
for metric in [
    "f1_0", "aupr_0",
    "f1_1", "aupr_1",
    "f1_2", "aupr_2",
    "mcc",
]:
    all_agg_dfs.append(
        normalize_by_group(
            df=reset_checkpoint_results,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

# Column-wise outer join, then expose the group keys as ordinary columns.
all_agg_dfs = pd.concat(all_agg_dfs, axis=1, join="outer").reset_index()

# An ordered categorical makes the sort follow the order of `checkpoints`
# instead of alphabetical order.
all_agg_dfs["learning_algorithm"] = pd.Categorical(
    values=all_agg_dfs["learning_algorithm"],
    categories=list(checkpoints),
    ordered=True,
)

all_agg_dfs = (
    all_agg_dfs
    .sort_values(by=["k", "learning_algorithm"])
    .set_index(keys=["k", "learning_algorithm"])
)

# Copy in spreadsheet-friendly format for pasting into a table.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

#### Comparison

In [None]:
# MCC comparison of the reset GATs against the regular GATs: one aggregated
# marker per (k, learning algorithm) and model type, joined by a line per
# learning algorithm.
metric = "mcc"

# Double-width figure; the height follows the global `figsize`.
fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [
    ("gat_reset", reset_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # The x layout is rebuilt on every pass. It only depends on `checkpoints`
    # and `all_k_shots`, so both model types end up on the same positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here, but one x slot per
            # checkpoint is still reserved, mirroring the spacing used by the
            # per-checkpoint plots in this notebook.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            # Upper half-width of the interval, drawn symmetrically by errorbar.
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                # "^" for the regular GAT so the plotted marker matches the
                # "GAT" legend handle below (and the other GAT comparisons).
                fmt="<" if checkpoint_type == "gat_reset" else "^",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between learning algorithms within one k group.
            cur_x += 5

        # Tick sits in the middle of the k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    tick_locs = np.array(x_ticks, dtype=float)

    # Connect the aggregates of each learning algorithm across k
    # (solid for the regular GAT, dashed for the reset GAT).
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_locs,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles: one coloured diamond per learning algorithm, followed by
# black markers that encode the model type.
custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[learning_algorithm_name],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for learning_algorithm_name in checkpoints
]

custom_lines += [
    Line2D([0], [0], marker='<', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='^', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

# Blank title leaves room above the axes for the figure-level legend.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# The hard-coded labels assume `checkpoints` holds exactly the two learning
# algorithms (MAML, ProtoMAML), in that order.
fig.legend(
    custom_lines,
    ["MAML", "ProtoMAML"] + ["Reset", "GAT"],
    loc='upper center',
    #bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=4,
    )

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_random_transfer_comparison_mcc.pdf"
)


In [None]:
# Macro-AUPR comparison of the reset GATs against the regular GATs: one
# aggregated marker per (k, learning algorithm) and model type, joined by a
# line per learning algorithm.
metric = "macro_aupr"

# Double-width figure; the height follows the global `figsize`.
fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [
    ("gat_reset", reset_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # The x layout is rebuilt on every pass. It only depends on `checkpoints`
    # and `all_k_shots`, so both model types end up on the same positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here, but one x slot per
            # checkpoint is still reserved, mirroring the spacing used by the
            # per-checkpoint plots in this notebook.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            # Upper half-width of the interval, drawn symmetrically by errorbar.
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                # "^" for the regular GAT, consistent with the GAT marker
                # used by the other comparison figures and their legends.
                fmt="<" if checkpoint_type == "gat_reset" else "^",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between learning algorithms within one k group.
            cur_x += 5

        # Tick sits in the middle of the k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    tick_locs = np.array(x_ticks, dtype=float)

    # Connect the aggregates of each learning algorithm across k
    # (solid for the regular GAT, dashed for the reset GAT).
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_locs,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# NOTE(review): the blank title presumably keeps the axes the same height as
# the MCC figure, which places its legend in that space -- confirm.
# No legend is drawn in this figure, so no legend handles are built.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("Macro-AUPR", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_random_transfer_comparison_macro_aupr.pdf"
)


### GATs + User initialization

In [None]:
# Checkpoint ids of the GATs evaluated with user initialisation, grouped by
# learning algorithm.
checkpoints = {
    "subgraphs": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "maml_lh": ["zqhx6x3b", "11pt2nis", "ruy4hp9o", "nlfyh80j", "06pfaw4f"],
    "maml_rh": ["rpu3r1rm", "35neqj40", "xycl2xbl", "9hdrq4an", "905yf717"],
    "prototypical": ["yjnx3e9w", "5e0vvr04", "br7qcerq", "6c8tvtwn", "ahp19u65"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

# Collect one weighted results table per checkpoint, tagged with its origin.
user_init_checkpoint_results = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:
        filters = dict(
            dataset="twitterHateSpeech",
            top_users_excluded=0,
            version=f"transfer_{checkpoint}_avg_pool_user_init",
            structure_mode="transductive",
        )

        # Only the weighted table is used further on.
        raw_results, checkpoint_results, checkpoint_results_weighted = (
            get_weighted_results_table(
                filters=filters,
                metrics=metrics,
                split="test",
                alpha=0.10,
            )
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm

        # Number of result rows found for this checkpoint.
        print(checkpoint_results_weighted.shape[0])

        user_init_checkpoint_results += [checkpoint_results_weighted]

# Stack the per-checkpoint tables into a single frame with a fresh 0..n-1 index.
user_init_checkpoint_results = pd.concat(user_init_checkpoint_results, ignore_index=True)

In [None]:
# Spot check: MCC of the ProtoMAML checkpoint "9o4wp36l" at k = 16.
user_init_checkpoint_results.loc[
    (user_init_checkpoint_results["checkpoint"] == "9o4wp36l")
    & (user_init_checkpoint_results["k"] == 16),
    "mcc",
]

In [None]:
# Aggregate the user-initialised GAT runs per (k, learning algorithm) for MCC
# and display the resulting table.
metric = "mcc"
lower_metric = metric.lower()

group_keys = ["k", "learning_algorithm"]
aggregated_df = normalize_by_group(
    df=user_init_checkpoint_results, by=group_keys, metric=metric
)

aggregated_df

In [None]:
# User-initialised GAT results: MCC of every checkpoint (faint dots) and the
# aggregate per (k, learning algorithm) (open diamonds), with the aggregates
# of each learning algorithm connected across k.
metric = "mcc"

lower_metric = metric.lower()

# Looked up below by (k, learning_algorithm) for `metric` and `{metric}_ub`.
aggregated_df = normalize_by_group(
    df=user_init_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = user_init_checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        matching_learning_alg = matching_k & (
            user_init_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                user_init_checkpoint_results["checkpoint"] == ckpt
            )
            ckpt_rows = user_init_checkpoint_results.loc[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # Each (k, algorithm, checkpoint) must identify a single run.
                raise ValueError(
                    f"Expected at most one result row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, "
                    f"checkpoint={ckpt!r}; found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # Missing run: leave its x slot empty so the spacing stays uniform.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # Lower half-width of the interval, drawn symmetrically by errorbar.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        # Upper half-width of the aggregate interval.
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # Gap between learning algorithms within one k group.
        cur_x += 5

    # Tick sits in the middle of the k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Connect the aggregates of each learning algorithm across the k groups.
# `all_k_shots` is left untouched so it keeps holding the list of k values.
tick_locs = np.array(x_ticks, dtype=float)

for learning_algorithm, values in agg_values.items():
    ax.plot(
        tick_locs,
        values,
        c=checkpoint_cmap[learning_algorithm],
        alpha=0.75,
        zorder=1,
    )

#ax.set_title("Twitter Hate Speech", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# Exactly one legend handle per learning algorithm, in the same order as the
# labels passed to `ax.legend` below.
custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[learning_algorithm_name],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for learning_algorithm_name in checkpoints
]

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.15),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Twitter transfer, MCC per k-shot: checkpoints evaluated with user
# initialisation ("gat_user": '>' markers, dashed lines) versus without
# ("gat": 'v' markers, solid lines). Colour encodes the learning algorithm.
metric = "mcc"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm); the columns read
    # below are `metric` and `f"{metric}_ub"`.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend proxies: one coloured diamond per learning algorithm, followed by the
# two black marker shapes that distinguish the result sets.
custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[learning_algorithm],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for learning_algorithm in checkpoints
]
custom_lines += [
    Line2D([0], [0], marker='>', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

# The blank title reserves vertical space for the figure-level legend.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# NOTE(review): the display names below are positional and assume the
# iteration order of `checkpoints` — confirm if that dict changes.
fig.legend(
    custom_lines,
    ["Subgraphs", "MAML-LH", "MAML-RH", "ProtoNet", "ProtoMAML"] + ["User", "GAT"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=7,
    )

fig.tight_layout()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_gat_with_user_init_tranfer_mcc.pdf"
)

plt.show()


In [None]:
# Twitter transfer, Macro-AUPR per k-shot: checkpoints evaluated with user
# initialisation ("gat_user": '>' markers, dashed lines) versus without
# ("gat": 'v' markers, solid lines). No legend is drawn on this figure.
metric = "macro_aupr"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm); the columns read
    # below are `metric` and `f"{metric}_ub"`.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("Macro-AUPR", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_gat_with_user_init_tranfer_macro_aupr.pdf"
)

plt.show()


In [None]:
# Keep an independent deep copy of the user-init results under a second name,
# so later changes to either frame do not affect the other.
gat_user_init_results = user_init_checkpoint_results.copy(deep=True)

### Extreme $k$-shot

In [None]:
# Extreme k-shot, MCC: only the algorithms in `tested_learning_algorithms` are
# plotted, with user initialisation ("gat_user": '>' markers, dashed lines)
# versus without ("gat": 'v' markers, solid lines).
metric = "mcc"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm), restricted to the
    # tested algorithms on the second index level.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )
    aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            # Untested algorithms take up no horizontal space at all.
            if learning_algorithm not in tested_learning_algorithms:
                continue

            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend proxies: a coloured diamond for the single plotted algorithm, then the
# two black marker shapes that distinguish the result sets.
learning_algorithm = "prototypical"

color = checkpoint_cmap[learning_algorithm]

custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=color, color='w', lw=0, markersize=markersize_major),
    Line2D([0], [0], marker='>', markerfacecolor='k', color='w', lw=0, markersize=markersize_major),
    Line2D([0], [0], marker='v', markerfacecolor='k', color='w', lw=0, markersize=markersize_major),
]

# The blank title reserves vertical space for the figure-level legend.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.legend(
    custom_lines,
    ["ProtoNet"] + ["User", "GAT"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=10,
    )

fig.tight_layout()

# Save before showing, like the other figure cells, so the written file does not
# depend on what the active backend does to the figure on `plt.show()`.
fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_extreme_kshot_mcc.pdf"
)

plt.show()


In [None]:
# Extreme k-shot for the `twitter_transfer` results: only the algorithms in
# `tested_learning_algorithms` are plotted.
#
# The metric and y-label agree with the output file name
# (`twitter_extreme_kshot_macro_aupr.pdf`); the cell previously plotted and
# labelled MCC while writing to that file.
# NOTE(review): assumes `twitter_transfer` carries `macro_aupr` columns like the
# other Twitter result frames — confirm.
metric = "macro_aupr"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

for checkpoint_type, checkpoint_results in [("gat", twitter_transfer)]:
    # Aggregated table indexed by (k, learning_algorithm), restricted to the
    # tested algorithms on the second index level.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )
    aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            # Untested algorithms take up no horizontal space at all.
            if learning_algorithm not in tested_learning_algorithms:
                continue

            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("Macro-AUPR", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

# Save before showing, like the other figure cells.
fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_extreme_kshot_macro_aupr.pdf"
)

plt.show()

In [None]:
# Build one wide table for the prototypical checkpoints: each metric is
# aggregated per (k, learning_algorithm, user_init) and the per-metric tables
# are joined column-wise on that shared index.
per_metric_tables = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc"]:
    metric_table = normalize_by_group(
        df=all_checkpoint_results[all_checkpoint_results["learning_algorithm"] == "prototypical"],
        by=["k", "learning_algorithm", "user_init"],
        metric=metric,
        remove={"var", "se", "N"}
    )
    per_metric_tables.append(metric_table)

all_agg_dfs = pd.concat(per_metric_tables, axis=1, join="outer")
all_agg_dfs = all_agg_dfs.reset_index()

# An ordered categorical makes the sort below follow the order of `checkpoints`
# instead of alphabetical order.
algorithm_order = pd.Categorical(
    values=all_agg_dfs["learning_algorithm"],
    categories=list(checkpoints.keys()),
    ordered=True
)
all_agg_dfs["learning_algorithm"] = pd.Series(algorithm_order)

all_agg_dfs = all_agg_dfs.sort_values(by=["k", "learning_algorithm"])
all_agg_dfs = all_agg_dfs.set_index(keys=["k", "learning_algorithm"])

# Copy in an Excel-friendly format for pasting into a spreadsheet.
all_agg_dfs.to_clipboard(excel=True,)

all_agg_dfs

In [None]:
# Extreme k-shot, class-0 F1: only the algorithms in `tested_learning_algorithms`
# are plotted, with user initialisation ("gat_user": '>' markers, dashed lines)
# versus without ("gat": 'v' markers, solid lines).
metric = "f1_0"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm), restricted to the
    # tested algorithms on the second index level.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )
    aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            # Untested algorithms take up no horizontal space at all.
            if learning_algorithm not in tested_learning_algorithms:
                continue

            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend proxies cover only what is on the axes: the plotted algorithms (in
# `checkpoints` order) and the two marker shapes used above — '>' for the
# user-init results and 'v' for the plain GAT results.
plotted_algorithms = [
    algorithm for algorithm in checkpoints if algorithm in tested_learning_algorithms
]

custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[algorithm],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for algorithm in plotted_algorithms
]
custom_lines += [
    Line2D([0], [0], marker='>', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_ylabel("F1-0", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

ax.legend(
    custom_lines,
    plotted_algorithms + ["User", "GAT"],
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=4,
    )

fig.tight_layout()

plt.show()


In [None]:
# Extreme k-shot, class-1 F1: only the algorithms in `tested_learning_algorithms`
# are plotted, with user initialisation ("gat_user": '>' markers, dashed lines)
# versus without ("gat": 'v' markers, solid lines).
metric = "f1_1"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm), restricted to the
    # tested algorithms on the second index level.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )
    aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            # Untested algorithms take up no horizontal space at all.
            if learning_algorithm not in tested_learning_algorithms:
                continue

            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend proxies cover only what is on the axes: the plotted algorithms (in
# `checkpoints` order) and the two marker shapes used above — '>' for the
# user-init results and 'v' for the plain GAT results.
plotted_algorithms = [
    algorithm for algorithm in checkpoints if algorithm in tested_learning_algorithms
]

custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[algorithm],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for algorithm in plotted_algorithms
]
custom_lines += [
    Line2D([0], [0], marker='>', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_ylabel("F1-1", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

ax.legend(
    custom_lines,
    plotted_algorithms + ["User", "GAT"],
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=4,
    )

fig.tight_layout()

plt.show()


In [None]:
# Extreme k-shot, class-2 F1: only the algorithms in `tested_learning_algorithms`
# are plotted, with user initialisation ("gat_user": '>' markers, dashed lines)
# versus without ("gat": 'v' markers, solid lines).
metric = "f1_2"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

for checkpoint_type, checkpoint_results in [
    ("gat_user", user_init_checkpoint_results),
    ("gat", gat_all_checkpoint_results),
]:
    # Aggregated table indexed by (k, learning_algorithm), restricted to the
    # tested algorithms on the second index level.
    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )
    aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

    # Layout bookkeeping restarts for each result set, so both result sets are
    # drawn at identical x positions.
    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:
        for learning_algorithm, ckpts in checkpoints.items():
            # Untested algorithms take up no horizontal space at all.
            if learning_algorithm not in tested_learning_algorithms:
                continue

            color = checkpoint_cmap[learning_algorithm]

            # Reserve one x slot per checkpoint; only the aggregate marker is
            # drawn, centred over those slots.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=">" if checkpoint_type == "gat_user" else "v",
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between neighbouring algorithms within one k group.
            cur_x += 5

        # Tick at the midpoint of the span covered by this k group.
        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Wider gap between consecutive k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect each algorithm's aggregates across k, anchored at the tick positions.
    tick_positions = np.array(x_ticks, dtype=float)
    for learning_algorithm, values in agg_values.items():
        ax.plot(
            tick_positions,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls="-" if checkpoint_type == "gat" else "--",
        )

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend proxies cover only what is on the axes: the plotted algorithms (in
# `checkpoints` order) and the two marker shapes used above — '>' for the
# user-init results and 'v' for the plain GAT results.
plotted_algorithms = [
    algorithm for algorithm in checkpoints if algorithm in tested_learning_algorithms
]

custom_lines = [
    Line2D(
        [0], [0],
        marker='D',
        markerfacecolor=checkpoint_cmap[algorithm],
        color='w',
        lw=0,
        markersize=markersize_major,
    )
    for algorithm in plotted_algorithms
]
custom_lines += [
    Line2D([0], [0], marker='>', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_ylabel("F1-2", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

ax.legend(
    custom_lines,
    plotted_algorithms + ["User", "GAT"],
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=10,
    )

fig.tight_layout()

plt.show()


In [None]:
# Inspect the raw 256-shot rows of the GAT results that were evaluated without
# user initialisation.
without_user_init = gat_all_checkpoint_results["user_init"] == False
at_256_shot = gat_all_checkpoint_results["k"] == 256

gat_all_checkpoint_results[without_user_init & at_256_shot]

In [None]:
# Wide aggregate table for the tested learning algorithms, built separately for
# the GAT results and the user-init results and then combined.
# Relies on `tested_learning_algorithms` from an earlier plotting cell.

# GAT results: one aggregated table per metric, restricted to the tested
# algorithms on the second index level, joined column-wise.
gat_metric_tables = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc"]:
    gat_metric_table = normalize_by_group(
        df=gat_all_checkpoint_results,
        by=["k", "learning_algorithm", "user_init"],
        metric=metric,
        remove={"var", "se", "N"}
    )
    gat_metric_tables.append(
        gat_metric_table.loc[pd.IndexSlice[:, tested_learning_algorithms], :]
    )

all_agg_dfs = pd.concat(gat_metric_tables, axis=1, join="outer").reset_index()

# User-init results: same procedure.
user_init_metric_tables = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "f1_2", "aupr_2", "mcc"]:
    user_init_metric_table = normalize_by_group(
        df=user_init_checkpoint_results,
        by=["k", "learning_algorithm", "user_init"],
        metric=metric,
        remove={"var", "se", "N"}
    )
    user_init_metric_tables.append(
        user_init_metric_table.loc[pd.IndexSlice[:, tested_learning_algorithms], :]
    )

user_init_agg_dfs = pd.concat(user_init_metric_tables, axis=1, join="outer").reset_index()

# No `on=` is given, so the outer merge joins on every column the two tables share.
all_agg_dfs = pd.merge(all_agg_dfs, user_init_agg_dfs, how='outer')

# An ordered categorical makes the sort below follow the order of
# `tested_learning_algorithms`.
algorithm_order = pd.Categorical(
    values=all_agg_dfs["learning_algorithm"],
    categories=tested_learning_algorithms,
    ordered=True,
)
all_agg_dfs["learning_algorithm"] = pd.Series(algorithm_order)

all_agg_dfs = all_agg_dfs.sort_values(by=["k", "learning_algorithm", "user_init"])
all_agg_dfs = all_agg_dfs.set_index(keys=["k", "learning_algorithm", "user_init"])

# Copy in an Excel-friendly format for pasting into a spreadsheet.
all_agg_dfs.to_clipboard(excel=True,)

all_agg_dfs

In [None]:
# List the column labels of the combined aggregate table.
all_agg_dfs.columns

## HealthStory

In [None]:
# HealthStory transfer runs: checkpoint identifiers grouped by the learning
# algorithm that produced them. Each identifier selects a result version named
# `transfer_<checkpoint>` below.
checkpoints = {
    "subgraphs": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "maml_lh": ["zqhx6x3b", "11pt2nis", "ruy4hp9o", "nlfyh80j", "06pfaw4f"],
    "maml_rh": ["rpu3r1rm", "35neqj40", "xycl2xbl", "9hdrq4an", "905yf717"],
    "prototypical": ["2crszon2", "wdhcjrmp", "zwwyk83h", "lku2gxb8", "4pj3ewdu"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

# Metrics requested from the results table.
metrics = [
    "supp_improvement",
    "f1_0",
    "f1_1",
    "f1_2",
    "f1_gain_0",
    "f1_gain_1",
    "f1_gain_2",
    "f1_gain_macro",
    "macro_auprg",
    "mcc",
]

# Collect the weighted test-split results of every checkpoint, tagging each
# table with its checkpoint identifier and learning algorithm.
weighted_tables = []
for learning_algorithm, ckpts in checkpoints.items():
    # Progress indicator: one line per learning algorithm.
    print(learning_algorithm)

    for checkpoint in ckpts:
        filters = dict(
            dataset="HealthStory",
            top_users_excluded=0,
            version=f"transfer_{checkpoint}",
            structure_mode="transductive",
        )

        # Only the third returned table (the weighted results) is kept.
        raw_results, checkpoint_results, checkpoint_results_weighted = get_weighted_results_table(
            filters=filters,
            metrics=metrics,
            split="test",
            alpha=0.10
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm

        weighted_tables.append(checkpoint_results_weighted)

# Stack all tables into a single frame with a fresh 0..n-1 index.
all_checkpoint_results = pd.concat(weighted_tables, ignore_index=True)


In [None]:
# Sort so that, within each (k, learning_algorithm), rows are ordered by
# ascending inner_lr and then n_updates; then keep the first row for every
# (k, learning_algorithm, checkpoint, n_updates) combination.
dedup_keys = ["k", "learning_algorithm", "checkpoint", "n_updates"]

all_checkpoint_results = (
    all_checkpoint_results
    .sort_values(by=["k", "learning_algorithm", "inner_lr", "n_updates"])
    .drop_duplicates(subset=dedup_keys, keep="first")
)

In [None]:
# Display the HealthStory results frame in its current state.
all_checkpoint_results

### Low Adaptation

In [None]:
# Zero-shot baseline: "subgraphs" rows with an inner learning rate of 0.0,
# taken from the k == 4 runs only.
is_unadapted_subgraphs = (
    (all_checkpoint_results["learning_algorithm"] == "subgraphs")
    & (all_checkpoint_results["inner_lr"] == 0.0)
)

zero_shot_transfer_baseline = all_checkpoint_results[
    is_unadapted_subgraphs & (all_checkpoint_results["k"] == 4)
]

In [None]:
# Low adaptation: drop the zero-shot rows ("subgraphs" with inner_lr == 0.0),
# then keep, for every (k, learning_algorithm, checkpoint), the row that sorts
# first by ascending n_updates and then inner_lr.
is_zero_shot = (
    (all_checkpoint_results["learning_algorithm"] == "subgraphs")
    & (all_checkpoint_results["inner_lr"] == 0.0)
)

low_adaptation_checkpoints = (
    all_checkpoint_results[~is_zero_shot]
    .sort_values(by=["k", "learning_algorithm", "n_updates", "inner_lr"])
    .drop_duplicates(subset=["k", "learning_algorithm", "checkpoint"], keep="first")
)

low_adaptation_checkpoints

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# HealthStory, low adaptation: MCC of every checkpoint (faint circles with an
# error bar down to the `<metric>_lb` column) and the per-algorithm mean
# (diamonds) for each k-shot, next to the zero-shot baseline. A straight line
# is fitted through each algorithm's means.
metric = "mcc"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(3.03209, 9.72632/4))

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_locs = defaultdict(list)
agg_values = defaultdict(list)

# --- Zero-shot baseline, drawn at x = 0 .. n-1 ---------------------------------
zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()
zero_shot_error = zero_shot_vals - zero_shot_transfer_baseline[f"{lower_metric}_lb"].to_numpy()
zero_shot_xs = np.arange(0, zero_shot_vals.shape[0])
zero_shot_center = zero_shot_xs.mean()

ax.errorbar(
    x=zero_shot_xs,
    y=zero_shot_vals,
    yerr=zero_shot_error,
    fmt='o',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=0.20,
    zorder=1
    )

# The label names the baseline explicitly; it must not depend on whatever value
# `learning_algorithm` was left with by an earlier loop.
ax.errorbar(
    zero_shot_center,
    zero_shot_vals.mean(),
    fmt='D',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=1.0,
    label="zero_shot_transfer",
    zorder=2
    )

# Dashed reference line at the mean zero-shot score; the far end is clipped by
# the x-limits set further down.
ax.hlines(
    y=zero_shot_vals.mean(),
    xmin=0,
    xmax=1000,
    colors=["gray"],
    zorder=1,
    linestyles="--",
    alpha=0.75,
    )

x_ticks += [zero_shot_center]
x_tick_labels += ["0"]

# Leave a gap after the baseline. The first k group starts here, so its tick
# midpoint has to be measured from this position rather than from x = 0.
cur_x = zero_shot_vals.shape[0] + 35
prev_k_loc = cur_x

# --- Adapted checkpoints, grouped by k and then by learning algorithm ----------
for k in all_k_shots:
    matching_k = low_adaptation_checkpoints["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        matching_learning_alg = matching_k & (low_adaptation_checkpoints["learning_algorithm"] == learning_algorithm)

        # Same colour source as the fitted lines below, so markers and their
        # trend line always agree.
        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (low_adaptation_checkpoints["checkpoint"] == ckpt)
            ckpt_rows = low_adaptation_checkpoints[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            elif ckpt_rows.shape[0] == 0:
                # Missing result: keep the slot empty so the spacing stays regular.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=color, alpha=0.20, zorder=1)

            cur_x += 1

        # Mean over this algorithm's rows at this k, centred over its slots.
        agg_value = low_adaptation_checkpoints[matching_learning_alg][lower_metric].mean()
        agg_loc = cur_x - len(ckpts) / 2

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, fmt='D', color=color, alpha=1.0, label=learning_algorithm, zorder=2)

        # Gap between neighbouring algorithms within one k group.
        cur_x += 5

    # Tick at the midpoint of the span covered by this k group.
    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Wider gap between consecutive k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# --- Linear trend per algorithm (ordinary least squares on the means) ----------
# x positions at which the fitted lines are evaluated, with a bias column.
all_k_shots_range = np.arange(-10, cur_x+10, step=cur_x//10, dtype=float)
all_k_shots_x = np.stack([np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    # Design matrix [1, x] built from the aggregate marker positions. A local
    # name is used so the `all_k_shots` list above is not overwritten.
    fit_locs = np.array(agg_locs[learning_algorithm], dtype=float)
    x = np.stack([np.ones_like(fit_locs, dtype=float), fit_locs], axis=1)

    # Normal equations: w = (X^T X)^{-1} X^T y.
    w_prefix = np.linalg.inv(x.T @ x) @ x.T

    y = np.array(values)
    w_ml = w_prefix @ y

    pred_y = all_k_shots_x @ w_ml

    ax.plot(all_k_shots_range, pred_y, c=color, alpha=0.75, zorder=1)

ax.set_title("HealthStory\n(low adaptation)", fontsize=11)
ax.set_ylabel("MCC", fontsize=9)
ax.set_xlabel("$k$-shot", fontsize=9)

fig.tight_layout()

plt.show()


### High Adaptation

In [None]:
# Rows kept as the zero-shot transfer baseline: "subgraphs" results with
# inner_lr == 0.0, restricted to k == 4.
is_zero_shot_row = (
    all_checkpoint_results["learning_algorithm"].eq("subgraphs")
    & all_checkpoint_results["inner_lr"].eq(0.0)
    & all_checkpoint_results["k"].eq(4)
)
zero_shot_transfer_baseline = all_checkpoint_results[is_zero_shot_row]

In [None]:
# Everything except the "subgraphs" rows with inner_lr == 0.0.
is_adapted_row = (
    all_checkpoint_results["learning_algorithm"].ne("subgraphs")
    | all_checkpoint_results["inner_lr"].ne(0.0)
)

# After sorting, the last row of each (k, learning_algorithm, checkpoint) group
# is the one with the largest (n_updates, inner_lr); only that row is kept.
high_adaptation_checkpoints = (
    all_checkpoint_results[is_adapted_row]
    .sort_values(by=["k", "learning_algorithm", "n_updates", "inner_lr"])
    .drop_duplicates(subset=["k", "learning_algorithm", "checkpoint"], keep="last")
)

high_adaptation_checkpoints

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# HealthStory, high adaptation.
# For every k: one faint circle per checkpoint (error bar down to "<metric>_lb")
# and one diamond per learning algorithm (mean over its rows), drawn to the right
# of the zero-shot baseline. A straight trend line is fitted through the diamonds
# of each learning algorithm.
metric = "mcc"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(3.03209, 9.72632/4))

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

# Per learning algorithm: x position / value of the diamond for each k.
agg_locs = defaultdict(list)
agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0

# --- Zero-shot baseline -------------------------------------------------------
zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()
zero_shot_error = zero_shot_vals - zero_shot_transfer_baseline[f"{lower_metric}_lb"].to_numpy()
zero_shot_xs = np.arange(cur_x, cur_x + zero_shot_vals.shape[0])

ax.errorbar(
    x=zero_shot_xs,
    y=zero_shot_vals,
    yerr=zero_shot_error,
    fmt='o',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=0.20,
    zorder=1
    )

ax.errorbar(
    zero_shot_xs.mean(),
    zero_shot_vals.mean(),
    fmt='D',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=1.0,
    # Explicit label: this used to read `learning_algorithm`, i.e. whatever value
    # a loop in a previously executed cell happened to leave behind.
    label="zero_shot_transfer",
    zorder=2
    )

# Dashed reference line at the zero-shot mean; xmax is cut by set_xlim below.
ax.hlines(
    y=zero_shot_vals.mean(),
    xmin=0,
    xmax=1000,
    colors=["gray"],
    zorder=1,
    linestyles="--",
    alpha=0.75,
    )

x_ticks += [zero_shot_xs.mean()]
x_tick_labels += ["0"]

cur_x = zero_shot_vals.shape[0] + 35

# --- k-shot groups ------------------------------------------------------------
for k in all_k_shots:
    matching_k = high_adaptation_checkpoints["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):
        matching_learning_alg = matching_k & (high_adaptation_checkpoints["learning_algorithm"] == learning_algorithm)

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (high_adaptation_checkpoints["checkpoint"] == ckpt)
            ckpt_rows = high_adaptation_checkpoints[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # More than one row means the de-duplication upstream did not hold.
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=cmap(i), alpha=0.20, zorder=1)

            cur_x += 1

        agg_value = high_adaptation_checkpoints[matching_learning_alg][lower_metric].mean()
        agg_loc = cur_x - len(ckpts) / 2

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, fmt='D', color=cmap(i), alpha=1.0, label=learning_algorithm, zorder=2)

        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles (the legend itself is currently not drawn in this figure).
custom_lines = [Line2D([0], [0], color=cmap(j), lw=4) for j in range(5)]

# --- Linear trend per learning algorithm --------------------------------------
all_k_shots_range = np.arange(-10, cur_x+10, step=cur_x//10, dtype=float)
all_k_shots_x = np.stack([np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    # Separate name so the list of k values above is not overwritten.
    marker_locs = np.array(agg_locs[learning_algorithm], dtype=float)
    x = np.stack([np.ones_like(marker_locs), marker_locs], axis=1)
    y = np.array(values)

    # Ordinary least squares via the normal equations, solved directly instead
    # of forming an explicit inverse.
    w_ml = np.linalg.solve(x.T @ x, x.T @ y)

    pred_y = all_k_shots_x @ w_ml

    ax.plot(all_k_shots_range, pred_y, c=color, alpha=0.75, zorder=1)

ax.set_title("HealthStory\n(high adaptation)", fontsize=11)
ax.set_ylabel("MCC", fontsize=9)
ax.set_xlabel("$k$-shot", fontsize=9)

fig.tight_layout()

plt.show()


## CoAID

In [None]:
# Checkpoint identifiers to evaluate on CoAID, keyed by learning algorithm.
# Each identifier is looked up below as version "transfer_<identifier>".
#
# This literal used to list the key "prototypical" twice. Python keeps only the
# last value of a repeated key, so the first list was silently discarded. It is
# preserved as a comment; the effective mapping (and key order) is unchanged.
checkpoints = {
    "subgraphs": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "maml_lh": ["zqhx6x3b", "11pt2nis", "ruy4hp9o", "nlfyh80j", "06pfaw4f"],
    "maml_rh": ["rpu3r1rm", "35neqj40", "xycl2xbl", "9hdrq4an", "905yf717"],
    # Shadowed "prototypical" entry (never loaded):
    #   ["2crszon2", "wdhcjrmp", "zwwyk83h", "lku2gxb8", "4pj3ewdu"]
    "prototypical": ["yjnx3e9w", "5e0vvr04", "br7qcerq", "6c8tvtwn", "ahp19u65"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

metrics = [
    "supp_improvement",
    "f1_0",
    "aupr_0",
    "f1_1",
    "aupr_1",
    "mcc",
]

# One weighted results table per checkpoint, tagged with where it came from.
# Collected in a separate list so `all_checkpoint_results` is only ever a DataFrame.
weighted_tables = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:

        filters = {
            "dataset": "CoAID",
            "top_users_excluded": 0,
            "version": f"transfer_{checkpoint}",
            "structure_mode": "transductive",
        }

        raw_results, checkpoint_results, checkpoint_results_weighted = get_weighted_results_table(
            filters=filters,
            metrics=metrics,
            split="test",
            alpha=0.10
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm

        weighted_tables.append(checkpoint_results_weighted)

all_checkpoint_results = pd.concat(weighted_tables).reset_index(drop=True)


In [None]:
# Inspect the rows belonging to the "protomaml" learning algorithm.
all_checkpoint_results.loc[all_checkpoint_results["learning_algorithm"].eq("protomaml")]

In [None]:
# Filter out some early attempts with too high learning rate
# These are noisy data points
# NOTE(review): the mask below selects on `class_weights` ((1.0, 20.0) and
# (1.0, 3.0)), not on a learning-rate column — confirm the comment above still
# describes the intent.
all_checkpoint_results = all_checkpoint_results[~((all_checkpoint_results.class_weights == (1.0, 20.0)) | (all_checkpoint_results.class_weights == (1.0, 3.0)))]

# De-duplicate the n_updates == 25 rows: after sorting by
# (learning_algorithm, checkpoint, k, inner_lr, inner_head_lr), only the first row
# of each (learning_algorithm, checkpoint, k) group is kept on the right-hand side.
# The result is written back through the same boolean mask.
# NOTE(review): presumably the assignment aligns on the index, so masked rows
# missing from the right-hand side become NaN and are removed by the dropna
# below — TODO confirm.
all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)] = (
    all_checkpoint_results[(all_checkpoint_results["n_updates"] == 25)]
    .sort_values(by=["learning_algorithm", "checkpoint", "k", "inner_lr", "inner_head_lr"])
    .drop_duplicates(subset=["learning_algorithm", "checkpoint", "k"])
)

# Drops every row that contains a NaN in any column (in place).
all_checkpoint_results.dropna(inplace=True)

all_checkpoint_results

### Low Adaptation

In [None]:
# Zero-shot transfer baseline: "subgraphs" rows with inner_lr == 0.0 at k == 4.
zero_shot_transfer_baseline = all_checkpoint_results[(all_checkpoint_results["learning_algorithm"] == "subgraphs") & (all_checkpoint_results["inner_lr"] == 0.0)]
zero_shot_transfer_baseline = zero_shot_transfer_baseline[zero_shot_transfer_baseline["k"] == 4]

# Anti-join: stacking the baseline on top of the full table and dropping every
# row whose key occurs more than once (keep=False) removes the baseline rows.
low_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, all_checkpoint_results]).drop_duplicates(subset=["k", "n_updates", "inner_lr", "inner_head_lr", "checkpoint"], keep=False)
# Per (k, learning_algorithm, checkpoint) keep the row that sorts first on
# (n_updates, inner_lr).
low_adaptation_checkpoints = low_adaptation_checkpoints.sort_values(by=["k", "learning_algorithm", "n_updates", "inner_lr",])
low_adaptation_checkpoints = low_adaptation_checkpoints.drop_duplicates(subset=["k", "learning_algorithm", "checkpoint",], keep="first")

# Relabel the baseline as k = 0. `assign` returns a new frame, so nothing is
# written into a frame obtained by boolean indexing (this avoids pandas'
# SettingWithCopyWarning).
zero_shot_transfer_baseline = zero_shot_transfer_baseline.assign(k=0)
low_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, low_adaptation_checkpoints])

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# CoAID, low adaptation.
# For every k: one faint circle per checkpoint (error bar down to "<metric>_lb")
# and one open diamond per learning algorithm taken from `aggregated_df`
# (error bar up to "<metric>_ub"), drawn to the right of the zero-shot baseline.
metric = "aupr_1"

lower_metric = metric.lower()

aggregated_df = normalize_by_group(
    df=low_adaptation_checkpoints,
    by=["k", "learning_algorithm"],
    metric=metric,
)

# Learning algorithms that appear in the aggregate table; others are skipped.
learning_algorithms_with_values = set(aggregated_df.index.get_level_values(level=1).unique().to_list())

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_locs = defaultdict(list)
agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0

# --- Zero-shot baseline -------------------------------------------------------
zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()
zero_shot_error = zero_shot_vals - zero_shot_transfer_baseline[f"{lower_metric}_lb"].to_numpy()
zero_shot_xs = np.arange(cur_x, cur_x + zero_shot_vals.shape[0])
zero_shot_agg = aggregated_df.xs((0.0, "subgraphs"))

ax.errorbar(
    x=zero_shot_xs,
    y=zero_shot_vals,
    yerr=zero_shot_error,
    fmt='o',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=0.20,
    zorder=1,
    markersize=markersize_minor,
    )

ax.errorbar(
    zero_shot_xs.mean(),
    zero_shot_agg[f"{lower_metric}"],
    yerr=zero_shot_agg[f"{lower_metric}_ub"]-zero_shot_agg[f"{lower_metric}"],
    fmt='D',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=1.0,
    # Explicit label: this used to read `learning_algorithm`, i.e. whatever value
    # a loop in a previously executed cell happened to leave behind.
    label="zero_shot_transfer",
    zorder=2,
    markersize=markersize_major,
    fillstyle="none",
    )

# Dashed reference line at the zero-shot mean; xmax is cut by set_xlim below.
ax.hlines(
    y=zero_shot_vals.mean(),
    xmin=0,
    xmax=1000,
    colors=["gray"],
    zorder=1,
    linestyles="--",
    alpha=0.75,
    )

x_ticks += [zero_shot_xs.mean()]
x_tick_labels += ["0"]

cur_x = zero_shot_vals.shape[0] + 35

# --- k-shot groups ------------------------------------------------------------
for k in all_k_shots:
    matching_k = low_adaptation_checkpoints["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):

        if learning_algorithm not in learning_algorithms_with_values:
            continue

        matching_learning_alg = matching_k & (low_adaptation_checkpoints["learning_algorithm"] == learning_algorithm)

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (low_adaptation_checkpoints["checkpoint"] == ckpt)
            ckpt_rows = low_adaptation_checkpoints[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # More than one row means the de-duplication upstream did not hold.
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=cmap(i), alpha=0.20, zorder=1, markersize=markersize_minor)

            cur_x += 1

        agg_row = aggregated_df.xs((k, learning_algorithm))
        agg_loc = cur_x - len(ckpts) / 2
        agg_value = agg_row[metric]
        agg_error = agg_row[f"{lower_metric}_ub"] - agg_row[metric]

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, yerr=agg_error, fmt='D', color=cmap(i), alpha=1.0, label=learning_algorithm, zorder=2, markersize=markersize_major, fillstyle="none",)

        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

custom_lines = [Line2D([0], [0], color=cmap(j), lw=4) for j in range(5)]

# Connect the aggregate values of each learning algorithm across the k ticks
# (the first tick is the zero-shot baseline and is excluded).
k_tick_locs = np.array(x_ticks[1:], dtype=float)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    ax.plot(k_tick_locs, values, c=color, alpha=0.75, zorder=1)

# The plotted metric is aupr_1; the axis used to be labelled "MCC".
ax.set_ylabel("AUPR-Fake", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot\nLow Adaptation", fontsize=fontsize_major)

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# One aggregate table per metric (without the var / se / N columns), later
# joined side by side on the (k, learning_algorithm) index.
per_metric_tables = []
for metric in ("f1_0", "aupr_0", "f1_1", "aupr_1", "mcc"):
    per_metric_tables.append(
        normalize_by_group(
            df=low_adaptation_checkpoints,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

all_agg_dfs = pd.concat(per_metric_tables, axis=1, join="outer").reset_index()

# Ordered categorical so that sorting follows the order of `checkpoints`
# rather than alphabetical order.
algorithm_order = list(checkpoints.keys())
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=algorithm_order,
        ordered=True,
    )
)

sort_keys = ["k", "learning_algorithm"]
all_agg_dfs = all_agg_dfs.sort_values(by=sort_keys).set_index(keys=sort_keys)

# Copy the table to the system clipboard for pasting into a spreadsheet.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

### High Adaptation

In [None]:
# Zero-shot transfer baseline: "subgraphs" rows with inner_lr == 0.0 at k == 4.
zero_shot_transfer_baseline = all_checkpoint_results[(all_checkpoint_results["learning_algorithm"] == "subgraphs") & (all_checkpoint_results["inner_lr"] == 0.0)]
zero_shot_transfer_baseline = zero_shot_transfer_baseline[zero_shot_transfer_baseline["k"] == 4]

# Anti-join: stacking the baseline on top of the full table and dropping every
# row whose key occurs more than once (keep=False) removes the baseline rows.
high_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, all_checkpoint_results]).drop_duplicates(subset=["k", "n_updates", "inner_lr", "inner_head_lr", "checkpoint"], keep=False)
# Per (k, learning_algorithm, checkpoint) keep the row that sorts last on
# (n_updates, inner_lr).
high_adaptation_checkpoints = high_adaptation_checkpoints.sort_values(by=["k", "learning_algorithm", "n_updates", "inner_lr",])
high_adaptation_checkpoints = high_adaptation_checkpoints.drop_duplicates(subset=["k", "learning_algorithm", "checkpoint",], keep="last")

# Relabel the baseline as k = 0. `assign` returns a new frame, so nothing is
# written into a frame obtained by boolean indexing (this avoids pandas'
# SettingWithCopyWarning).
zero_shot_transfer_baseline = zero_shot_transfer_baseline.assign(k=0)
high_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, high_adaptation_checkpoints])

In [None]:
# Inspect the rows belonging to the "protomaml" learning algorithm.
all_checkpoint_results.loc[all_checkpoint_results["learning_algorithm"].eq("protomaml")]

In [None]:
# Distinct learning algorithms present in the high-adaptation selection.
high_adaptation_checkpoints["learning_algorithm"].unique()

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# CoAID, high adaptation (saved as a paper figure at the end of the cell).
# For every k: one faint circle per checkpoint (error bar down to "<metric>_lb")
# and one open diamond per learning algorithm taken from `aggregated_df`
# (error bar up to "<metric>_ub"), drawn to the right of the zero-shot baseline.
metric = "mcc"

lower_metric = metric.lower()

aggregated_df = normalize_by_group(
    df=high_adaptation_checkpoints,
    by=["k", "learning_algorithm"],
    metric=metric,
)

# Learning algorithms that appear in the aggregate table; others are skipped.
learning_algorithms_with_values = set(aggregated_df.index.get_level_values(level=1).unique().to_list())

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_locs = defaultdict(list)
agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0

# --- Zero-shot baseline -------------------------------------------------------
zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()
zero_shot_error = zero_shot_vals - zero_shot_transfer_baseline[f"{lower_metric}_lb"].to_numpy()
zero_shot_xs = np.arange(cur_x, cur_x + zero_shot_vals.shape[0])
zero_shot_agg = aggregated_df.xs((0.0, "subgraphs"))

ax.errorbar(
    x=zero_shot_xs,
    y=zero_shot_vals,
    yerr=zero_shot_error,
    fmt='o',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=0.20,
    zorder=1,
    markersize=markersize_minor,
    )

ax.errorbar(
    zero_shot_xs.mean(),
    zero_shot_agg[f"{lower_metric}"],
    yerr=zero_shot_agg[f"{lower_metric}_ub"]-zero_shot_agg[f"{lower_metric}"],
    fmt='D',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=1.0,
    # Explicit label: this used to read `learning_algorithm`, i.e. whatever value
    # a loop in a previously executed cell happened to leave behind.
    label="zero_shot_transfer",
    zorder=2,
    markersize=markersize_major,
    fillstyle="none",
    )

# Dashed reference line at the zero-shot mean; xmax is cut by set_xlim below.
ax.hlines(
    y=zero_shot_vals.mean(),
    xmin=0,
    xmax=1000,
    colors=["gray"],
    zorder=1,
    linestyles="--",
    alpha=0.75,
    )

x_ticks += [zero_shot_xs.mean()]
x_tick_labels += ["0"]

cur_x = zero_shot_vals.shape[0] + 35

# --- k-shot groups ------------------------------------------------------------
for k in all_k_shots:
    matching_k = high_adaptation_checkpoints["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):

        if learning_algorithm not in learning_algorithms_with_values:
            continue

        matching_learning_alg = matching_k & (high_adaptation_checkpoints["learning_algorithm"] == learning_algorithm)

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (high_adaptation_checkpoints["checkpoint"] == ckpt)
            ckpt_rows = high_adaptation_checkpoints[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # More than one row means the de-duplication upstream did not hold.
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=cmap(i), alpha=0.20, zorder=1, markersize=markersize_minor)

            cur_x += 1

        agg_row = aggregated_df.xs((k, learning_algorithm))
        agg_loc = cur_x - len(ckpts) / 2
        agg_value = agg_row[metric]
        agg_error = agg_row[f"{metric}_ub"] - agg_row[metric]

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, yerr=agg_error, fmt='D', color=cmap(i), alpha=1.0, label=learning_algorithm, zorder=2, markersize=markersize_major, fillstyle="none",)

        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Marker-only legend handles, one per colour index 0..4.
custom_lines = [
    Line2D([0], [0], color=cmap(j), lw=0, marker="D", markersize=2)
    for j in range(5)
]

# Connect the aggregate values of each learning algorithm across the k ticks
# (the first tick is the zero-shot baseline and is excluded).
k_tick_locs = np.array(x_ticks[1:], dtype=float)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    ax.plot(k_tick_locs, values, c=color, alpha=0.75, zorder=1)

# Blank title: presumably reserves room above the axes for the figure legend.
ax.set_title(" ", fontsize=fontsize_major)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot\nCoAID", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_major)

ax.set_ylim(0, 0.25)
ax.set_yticks([0.0, 0.05, 0.10, 0.15, 0.20, 0.25])

fig.legend(
    custom_lines,
    ["Subgraphs", "MAML-LH", "MAML-RH", "ProtoNet", "ProtoMAML"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_gat_transfer_high_adapt.pdf"
)


In [None]:
# Independent snapshot of the CoAID high-adaptation selection
# (DataFrame.copy is deep by default).
coaid_transfer = high_adaptation_checkpoints.copy()

In [None]:
# One aggregate table per metric (without the var / se / N columns), later
# joined side by side on the (k, learning_algorithm) index.
per_metric_tables = []
for metric in ("f1_0", "aupr_0", "f1_1", "aupr_1", "mcc"):
    per_metric_tables.append(
        normalize_by_group(
            df=high_adaptation_checkpoints,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

all_agg_dfs = pd.concat(per_metric_tables, axis=1, join="outer").reset_index()

# Ordered categorical so that sorting follows the order of `checkpoints`
# rather than alphabetical order.
algorithm_order = list(checkpoints.keys())
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=algorithm_order,
        ordered=True,
    )
)

sort_keys = ["k", "learning_algorithm"]
all_agg_dfs = all_agg_dfs.sort_values(by=sort_keys).set_index(keys=sort_keys)

# Copy the table to the system clipboard for pasting into a spreadsheet.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

### MLP Baseline

In [None]:
# Keep an independent copy of the high-adaptation selection before
# `all_checkpoint_results` / `checkpoints` are rebuilt for the MLP baseline
# (DataFrame.copy is deep by default).
gat_all_checkpoint_results = high_adaptation_checkpoints.copy()

In [None]:
# Checkpoint identifiers for the MLP baseline, keyed by learning algorithm.
# Each identifier is looked up below as version "transfer_<identifier>".
checkpoints = dict(
    subgraphs=["l339inkn", "vyeuzmyc", "20asxsz3", "6046x2gc", "pxllyec4"],
    maml_lh=["cjvoiuqn", "y5zaa74e", "k2l9yy52", "7rk7bfgn", "f9xtwlxp"],
    protomaml=["4kd4uk24", "ev8f5mch", "cj4jv0pl", "8qfxh531", "m4rg2i2i"],
)

# One weighted results table per checkpoint, tagged with its checkpoint id and
# learning algorithm, then stacked into a single frame.
weighted_tables = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)
    for checkpoint in ckpts:
        filters = dict(
            dataset="CoAID",
            top_users_excluded=0,
            version="transfer_" + checkpoint,
            structure_mode="transductive",
        )

        raw_results, checkpoint_results, checkpoint_results_weighted = get_weighted_results_table(
            filters=filters,
            metrics=metrics,
            split="test",
            alpha=0.10,
        )

        checkpoint_results_weighted["checkpoint"] = checkpoint
        checkpoint_results_weighted["learning_algorithm"] = learning_algorithm
        weighted_tables.append(checkpoint_results_weighted)

all_checkpoint_results = pd.concat(weighted_tables).reset_index(drop=True)


In [None]:
# Keep only the runs whose class_weights equal (1.0, 1.0).
has_unit_class_weights = all_checkpoint_results["class_weights"] == (1.0, 1.0)
all_checkpoint_results = all_checkpoint_results.loc[has_unit_class_weights]

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# MLP baseline on CoAID.
# For every k: one faint circle per checkpoint (error bar down to "<metric>_lb")
# and one open diamond per learning algorithm taken from `aggregated_df`
# (error bar up to "<metric>_ub"). A straight trend line is fitted through the
# diamonds of each learning algorithm.
metric = "mcc"

lower_metric = metric.lower()

aggregated_df = normalize_by_group(
    df=all_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

# Learning algorithms that appear in the aggregate table; others are skipped.
learning_algorithms_with_values = set(aggregated_df.index.get_level_values(level=1).unique().to_list())

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_locs = defaultdict(list)
agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0

for k in all_k_shots:
    matching_k = all_checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():

        color = checkpoint_cmap[learning_algorithm]

        if learning_algorithm not in learning_algorithms_with_values:
            continue

        matching_learning_alg = matching_k & (all_checkpoint_results["learning_algorithm"] == learning_algorithm)

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (all_checkpoint_results["checkpoint"] == ckpt)
            ckpt_rows = all_checkpoint_results[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                # More than one row per (k, algorithm, checkpoint) cannot be
                # drawn as a single point.
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            if ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=color, alpha=0.20, zorder=1, markersize=markersize_minor)

            cur_x += 1

        agg_row = aggregated_df.xs((k, learning_algorithm))
        agg_loc = cur_x - len(ckpts) / 2
        agg_value = agg_row[metric]
        agg_error = agg_row[f"{metric}_ub"] - agg_row[metric]

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, yerr=agg_error, fmt='D', color=color, alpha=1.0, label=learning_algorithm, zorder=2, markersize=markersize_major, fillstyle="none",)

        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# One legend handle per plotted learning algorithm, coloured from the same
# `checkpoint_cmap` lookup as the markers. The handles used to be five positional
# cmap(0..4) colours paired with only three labels, so a handle's colour was not
# guaranteed to match the markers of the algorithm it was labelled with.
custom_lines = [
    Line2D([0], [0], color=checkpoint_cmap[name], lw=4)
    for name in checkpoints
]

# --- Linear trend per learning algorithm --------------------------------------
all_k_shots_range = np.arange(min(x_ticks), max(x_ticks), step=1, dtype=float)
all_k_shots_x = np.stack([np.ones_like(all_k_shots_range, dtype=float), all_k_shots_range], axis=1)

# Every algorithm is regressed against the same k tick positions.
# Separate name so the list of k values above is not overwritten.
k_tick_locs = np.array(x_ticks, dtype=float)
x = np.stack([np.ones_like(k_tick_locs), k_tick_locs], axis=1)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    y = np.array(values)

    # Ordinary least squares via the normal equations, solved directly instead
    # of forming an explicit inverse.
    w_ml = np.linalg.solve(x.T @ x, x.T @ y)

    pred_y = all_k_shots_x @ w_ml

    ax.plot(all_k_shots_range, pred_y, c=color, alpha=0.75, zorder=1)

ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot\nHigh Adaptation", fontsize=fontsize_major)
ax.set_ylim(-0.03, 0.27)

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Independent snapshot of the MLP results (DataFrame.copy is deep by default).
mlp_all_checkpoint_results = all_checkpoint_results.copy()

In [None]:
# One aggregate table per metric (without the var / se / N columns), later
# joined side by side on the (k, learning_algorithm) index.
per_metric_tables = []
for metric in ("f1_0", "aupr_0", "f1_1", "aupr_1", "mcc"):
    per_metric_tables.append(
        normalize_by_group(
            df=all_checkpoint_results,
            by=["k", "learning_algorithm"],
            metric=metric,
            remove={"var", "se", "N"},
        )
    )

all_agg_dfs = pd.concat(per_metric_tables, axis=1, join="outer").reset_index()

# Ordered categorical so that sorting follows the order of `checkpoints`
# rather than alphabetical order.
algorithm_order = list(checkpoints.keys())
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=algorithm_order,
        ordered=True,
    )
)

sort_keys = ["k", "learning_algorithm"]
all_agg_dfs = all_agg_dfs.sort_values(by=sort_keys).set_index(keys=sort_keys)

# Copy the table to the system clipboard for pasting into a spreadsheet.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs

#### Comparison

In [None]:
# MLP vs. GAT comparison of the aggregate MCC per (k, learning algorithm).
# MLP results are drawn with "^" markers and dashed lines, GAT results with
# "v" markers and solid lines; colour encodes the learning algorithm.
metric = "mcc"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [("MLP UBS", mlp_all_checkpoint_results), ("GAT", high_adaptation_checkpoints)]:

    marker = "^" if checkpoint_type == "MLP UBS" else "v"
    line_style = "-" if checkpoint_type == "GAT" else "--"

    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:

        for learning_algorithm, ckpts in checkpoints.items():

            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here, but their slots are
            # still reserved on the x axis.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=marker,
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            cur_x += 5

        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        cur_x += 25
        prev_k_loc = cur_x

    # Connect the aggregate values of each learning algorithm across the k ticks.
    k_tick_locs = np.array(x_ticks, dtype=float)

    for learning_algorithm, values in agg_values.items():
        ax.plot(
            k_tick_locs,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls=line_style,
            )

# Both passes iterate the same `checkpoints`, so the ticks of the last pass apply to both.
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles: one filled diamond per learning algorithm, followed by one
# black marker per model family.
custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=checkpoint_cmap[name], color='w', lw=0, markersize=markersize_major)
    for name in checkpoints
]

# Order matches the labels ["MLP", "GAT"] below: MLP is plotted with "^" and GAT
# with "v". The handles used to be listed in the opposite order, which swapped
# the two legend entries.
custom_lines += [
    Line2D([0], [0], marker='^', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

# Blank title: presumably reserves room above the axes for the figure legend.
ax.set_title(" ", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.legend(
    custom_lines,
    ["Subgraphs", "MAML", "ProtoMAML"] + ["MLP", "GAT"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=5,
    )

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_mlp_ubs_transfer_comparison_mcc.pdf",
)


In [None]:
# MLP vs. GAT comparison of the aggregate aupr_1 per (k, learning algorithm).
# MLP results are drawn with "^" markers and dashed lines, GAT results with
# "v" markers and solid lines; colour encodes the learning algorithm.
# No legend is drawn in this figure.
metric = "aupr_1"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

for checkpoint_type, checkpoint_results in [("MLP UBS", mlp_all_checkpoint_results), ("GAT", high_adaptation_checkpoints)]:

    marker = "^" if checkpoint_type == "MLP UBS" else "v"
    line_style = "-" if checkpoint_type == "GAT" else "--"

    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:

        for learning_algorithm, ckpts in checkpoints.items():

            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn here, but their slots are
            # still reserved on the x axis.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=marker,
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            cur_x += 5

        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        cur_x += 25
        prev_k_loc = cur_x

    # Connect the aggregate values of each learning algorithm across the k ticks.
    k_tick_locs = np.array(x_ticks, dtype=float)

    for learning_algorithm, values in agg_values.items():
        ax.plot(
            k_tick_locs,
            values,
            c=checkpoint_cmap[learning_algorithm],
            alpha=0.75,
            zorder=1,
            ls=line_style,
            )

# Both passes iterate the same `checkpoints`, so the ticks of the last pass apply to both.
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles, kept available although this figure draws no legend: one
# filled diamond per learning algorithm, then the MLP ("^") and GAT ("v")
# markers in the order in which they are plotted above.
custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=checkpoint_cmap[name], color='w', lw=0, markersize=markersize_major)
    for name in checkpoints
]

custom_lines += [
    Line2D([0], [0], marker='^', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker='v', markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_title(" ", fontsize=11)
ax.set_ylabel("AUPR-Fake", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_mlp_ubs_transfer_comparison_aupr_1.pdf",
)


In [None]:
# Zero-shot transfer baseline: "subgraphs" rows with inner_lr == 0.0 at k == 4.
zero_shot_transfer_baseline = all_checkpoint_results[(all_checkpoint_results["learning_algorithm"] == "subgraphs") & (all_checkpoint_results["inner_lr"] == 0.0)]
zero_shot_transfer_baseline = zero_shot_transfer_baseline[zero_shot_transfer_baseline["k"] == 4]

# Anti-join: stacking the baseline on top of the full table and dropping every
# row whose key occurs more than once (keep=False) removes the baseline rows.
low_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, all_checkpoint_results]).drop_duplicates(subset=["k", "n_updates", "inner_lr", "inner_head_lr", "checkpoint"], keep=False)
# Per (k, learning_algorithm, checkpoint) keep the row that sorts first on
# (n_updates, inner_lr).
low_adaptation_checkpoints = low_adaptation_checkpoints.sort_values(by=["k", "learning_algorithm", "n_updates", "inner_lr",])
low_adaptation_checkpoints = low_adaptation_checkpoints.drop_duplicates(subset=["k", "learning_algorithm", "checkpoint",], keep="first")

# Relabel the baseline as k = 0. `assign` returns a new frame, so nothing is
# written into a frame obtained by boolean indexing (this avoids pandas'
# SettingWithCopyWarning).
zero_shot_transfer_baseline = zero_shot_transfer_baseline.assign(k=0)
low_adaptation_checkpoints = pd.concat([zero_shot_transfer_baseline, low_adaptation_checkpoints])

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Figure: per-checkpoint scores (faint dots) and group-aggregated scores (hollow
# diamonds) of `high_adaptation_checkpoints` for every k-shot setting, drawn to the
# right of the zero-shot transfer baseline.
metric = "aupr_1"

lower_metric = metric.lower()

aggregated_df = normalize_by_group(
    df=high_adaptation_checkpoints,
    by=["k", "learning_algorithm"],
    metric=metric,
)

# Algorithms missing from the aggregated index are skipped in the plotting loop.
learning_algorithms_with_values = set(aggregated_df.index.get_level_values(level=1).unique().to_list())

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_locs = defaultdict(list)
agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0

# --- Zero-shot transfer baseline, drawn as the "0"-shot group -------------------
zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()
zero_shot_error = zero_shot_vals - zero_shot_transfer_baseline[f"{lower_metric}_lb"].to_numpy()
zero_shot_xs = np.arange(cur_x, cur_x + zero_shot_vals.shape[0])

ax.errorbar(
    x=zero_shot_xs,
    y=zero_shot_vals,
    yerr=zero_shot_error,
    fmt='o',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=0.20,
    zorder=1,
    markersize=markersize_minor,
    )

zero_shot_agg = aggregated_df.xs((0.0, "subgraphs"))

ax.errorbar(
    zero_shot_xs.mean(),
    zero_shot_agg[f"{lower_metric}"],
    yerr=zero_shot_agg[f"{lower_metric}_ub"] - zero_shot_agg[f"{lower_metric}"],
    fmt='D',
    color=checkpoint_cmap["zero_shot_transfer"],
    alpha=1.0,
    # Explicit label: `learning_algorithm` is not defined yet at this point in the
    # cell, so using it here would pick up a stale value from a previously run cell.
    label="zero_shot_transfer",
    zorder=2,
    markersize=markersize_major,
    fillstyle="none",
    )

# Horizontal reference line at the mean zero-shot score; xmax is deliberately far
# beyond the data and gets clipped by set_xlim below.
ax.hlines(
    y=zero_shot_vals.mean(),
    xmin=0,
    xmax=1000,
    colors=["gray"],
    zorder=1,
    linestyles="--",
    alpha=0.75,
    )

x_ticks += [zero_shot_xs.mean()]
x_tick_labels += ["0"]

cur_x = zero_shot_vals.shape[0] + 35

# --- k-shot groups --------------------------------------------------------------
for k in all_k_shots:
    matching_k = high_adaptation_checkpoints["k"] == k

    for i, (learning_algorithm, ckpts) in enumerate(checkpoints.items()):

        if learning_algorithm not in learning_algorithms_with_values:
            continue

        matching_learning_alg = matching_k & (high_adaptation_checkpoints["learning_algorithm"] == learning_algorithm)

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (high_adaptation_checkpoints["checkpoint"] == ckpt)
            ckpt_rows = high_adaptation_checkpoints[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            elif ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its x slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # The error bar is symmetric, sized by the distance to the lower bound.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(cur_x, value, yerr=value_error, fmt='o', color=cmap(i), alpha=0.20, zorder=1, markersize=markersize_minor)

            cur_x += 1

        agg_row = aggregated_df.xs((k, learning_algorithm))

        agg_loc = cur_x - len(ckpts) / 2
        agg_value = agg_row[metric]
        agg_error = agg_row[f"{metric}_ub"] - agg_row[metric]

        agg_locs[learning_algorithm] += [agg_loc]
        agg_values[learning_algorithm] += [agg_value]

        ax.errorbar(agg_loc, agg_value, yerr=agg_error, fmt='D', color=cmap(i), alpha=1.0, label=learning_algorithm, zorder=2, markersize=markersize_major, fillstyle="none",)

        # Gap between algorithms within one k group.
        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xlim(-25, cur_x)

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

custom_lines = [
    Line2D([0], [0], color=cmap(0), lw=4),
    Line2D([0], [0], color=cmap(1), lw=4),
    Line2D([0], [0], color=cmap(2), lw=4),
    Line2D([0], [0], color=cmap(3), lw=4),
    Line2D([0], [0], color=cmap(4), lw=4),
    ]

# Connect the aggregated values of each algorithm across the k groups. The first
# tick belongs to the zero-shot baseline and is therefore excluded.
k_tick_positions = np.array(x_ticks[1:], dtype=float)

for learning_algorithm, values in agg_values.items():

    color = checkpoint_cmap[learning_algorithm]

    ax.plot(k_tick_positions, values, c=color, alpha=0.75, zorder=1)

# The y-axis label has to match `metric` ("aupr_1"), not MCC.
ax.set_ylabel("AUPR-Fake", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot\nHigh Adaptation", fontsize=fontsize_major)

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.25),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


### Reset GATs

In [None]:
# Checkpoint ids of the reset models, keyed by the learning algorithm they are
# compared against.
checkpoints = {
    # These runs are reset models and were not trained as MAML_RH,
    # but MAML_RH is the fairest comparison for them.
    "maml_rh": ["vb49pmtr", "8kixi0s8", "vzswebea", "a4xe91oq", "6syunm30"],
    "protomaml": ["9o4wp36l", "ouxd7twt", "aimkj4sa", "euh2mnqo", "p33ybhsn"],
}

# Collect one weighted results table per checkpoint, tagged with its checkpoint id
# and learning algorithm, then stack them into a single frame.
per_checkpoint_tables = []
for learning_algorithm, ckpts in checkpoints.items():
    print(learning_algorithm)

    for checkpoint_seed, checkpoint in enumerate(ckpts):
        # The position of the checkpoint in its list is part of the version tag.
        filters = dict(
            dataset="CoAID",
            top_users_excluded=0,
            version=f"reset_{checkpoint}_{checkpoint_seed}",
            structure_mode="transductive",
        )

        raw_results, checkpoint_results, checkpoint_results_weighted = get_weighted_results_table(
            filters=filters,
            metrics=metrics,
            split="test",
            alpha=0.10,
        )

        checkpoint_results_weighted = checkpoint_results_weighted.assign(
            checkpoint=checkpoint,
            learning_algorithm=learning_algorithm,
        )

        per_checkpoint_tables += [checkpoint_results_weighted]

reset_checkpoint_results = pd.concat(per_checkpoint_tables, ignore_index=True)

In [None]:
# Keep only the rows whose class weights equal (1.0, 1.0).
has_unit_class_weights = reset_checkpoint_results["class_weights"] == (1.0, 1.0)
reset_checkpoint_results = reset_checkpoint_results[has_unit_class_weights]

In [None]:
# Figure: per-checkpoint scores (faint dots) and group-aggregated scores (hollow
# diamonds) of the reset checkpoints for every k-shot setting.
metric = "mcc"

lower_metric = metric.lower()

aggregated_df = normalize_by_group(
    df=reset_checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

fig, ax = plt.subplots(1, 1, figsize=figsize)

all_k_shots = [4, 8, 12, 16]

x_ticks = []
x_tick_labels = []

agg_values = defaultdict(list)

prev_k_loc = 0
cur_x = 0
for k in all_k_shots:
    matching_k = reset_checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        matching_learning_alg = matching_k & (
            reset_checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                reset_checkpoint_results["checkpoint"] == ckpt
            )
            ckpt_rows = reset_checkpoint_results[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            elif ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its x slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # The error bar is symmetric, sized by the distance to the lower bound.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt="o",
                color=color,
                alpha=0.20,
                zorder=0,
                markersize=markersize_minor,
            )

            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()
        agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

        agg_loc = cur_x - len(ckpts) / 2

        ax.errorbar(
            agg_loc,
            agg_value,
            yerr=agg_error,
            fmt="D",
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            markersize=markersize_major,
            fillstyle="none",
        )

        agg_values[learning_algorithm] += [agg_value]

        # Gap between algorithms within one k group.
        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Connect the aggregated values of each algorithm across the k groups.
k_tick_positions = np.array(x_ticks, dtype=float)

for learning_algorithm, values in agg_values.items():
    color = checkpoint_cmap[learning_algorithm]

    ax.plot(k_tick_positions, values, c=color, alpha=0.75, zorder=1)

ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

# Exactly one legend handle per learning algorithm, in the same order as the labels.
# (Appending a handle inside the k loop yields len(all_k_shots) copies per algorithm,
# which no longer lines up one-to-one with the label list.)
custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=checkpoint_cmap[learning_algorithm], color='w', lw=0, markersize=markersize_major)
    for learning_algorithm in checkpoints
]

ax.legend(
    custom_lines,
    list(checkpoints.keys()),
    loc='upper center',
    bbox_to_anchor=(0.5, 1.15),
    fontsize=fontsize_minor,
    ncol=3,
    )

fig.tight_layout()

plt.show()


In [None]:
# Build one aggregated table per metric (without the "var", "se" and "N" entries)
# and join them side by side on the shared (k, learning_algorithm) index.
per_metric_tables = []
for metric in ["f1_0", "aupr_0", "f1_1", "aupr_1", "mcc"]:
    metric_table = normalize_by_group(
        df=reset_checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
        remove={"var", "se", "N"},
    )
    per_metric_tables.append(metric_table)

all_agg_dfs = pd.concat(per_metric_tables, axis=1, join="outer").reset_index()

# Make the algorithm column an ordered categorical so that sorting follows the
# order of the keys in `checkpoints`.
algorithm_order = list(checkpoints.keys())
all_agg_dfs["learning_algorithm"] = pd.Series(
    pd.Categorical(
        values=all_agg_dfs["learning_algorithm"],
        categories=algorithm_order,
        ordered=True,
    )
)

group_keys = ["k", "learning_algorithm"]
all_agg_dfs = all_agg_dfs.sort_values(by=group_keys).set_index(keys=group_keys)

# Side effect: copies the table to the system clipboard in Excel-friendly form.
all_agg_dfs.to_clipboard(excel=True)

all_agg_dfs


#### Comparison

In [None]:
# Figure: aggregated MCC per k-shot setting for the reset checkpoints ("gat_reset")
# and for `high_adaptation_checkpoints` ("gat"), overlaid on a single axis.
metric = "mcc"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

# Marker and line style per checkpoint type. The same mapping feeds the plot and the
# legend handles so that the two cannot drift apart.
checkpoint_type_marker = {"gat_reset": "<", "gat": "v"}
checkpoint_type_linestyle = {"gat_reset": "--", "gat": "-"}

for checkpoint_type, checkpoint_results in [("gat_reset", reset_checkpoint_results), ("gat", high_adaptation_checkpoints)]:

    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:

        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn in this figure, but one x slot is
            # still reserved for each of them.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=checkpoint_type_marker[checkpoint_type],
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between algorithms within one k group.
            cur_x += 5

        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect the aggregated values of each algorithm across the k groups.
    all_k_shots_ = np.array(x_ticks, dtype=float)

    for learning_algorithm, values in agg_values.items():
        color = checkpoint_cmap[learning_algorithm]

        ax.plot(
            all_k_shots_,
            values,
            c=color,
            alpha=0.75,
            zorder=1,
            ls=checkpoint_type_linestyle[checkpoint_type],
            )

# Both passes produce identical tick positions; the ones from the last pass are used.
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend: one coloured handle per learning algorithm, followed by one black handle
# per checkpoint type showing the marker actually used for that type.
custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=checkpoint_cmap[learning_algorithm], color='w', lw=0, markersize=markersize_major)
    for learning_algorithm in checkpoints
]

custom_lines += [
    Line2D([0], [0], marker=checkpoint_type_marker["gat_reset"], markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker=checkpoint_type_marker["gat"], markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_title(" ", fontsize=11)
ax.set_ylabel("MCC", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.legend(
    custom_lines,
    ["MAML", "ProtoMAML"] + ["Reset", "GAT"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=4,
    )

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_random_transfer_comparison_mcc.pdf",
)


In [None]:
# Figure: aggregated aupr_1 per k-shot setting for the reset checkpoints ("gat_reset")
# and for `high_adaptation_checkpoints` ("gat"), overlaid on a single axis.
metric = "aupr_1"

lower_metric = metric.lower()

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]))

all_k_shots = [4, 8, 12, 16]

# Marker and line style per checkpoint type. The same mapping feeds the plot and the
# legend handles so that the two cannot drift apart.
checkpoint_type_marker = {"gat_reset": "<", "gat": "v"}
checkpoint_type_linestyle = {"gat_reset": "--", "gat": "-"}

for checkpoint_type, checkpoint_results in [("gat_reset", reset_checkpoint_results), ("gat", high_adaptation_checkpoints)]:

    aggregated_df = normalize_by_group(
        df=checkpoint_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    prev_k_loc = 0
    cur_x = 0

    x_ticks = []
    x_tick_labels = []

    agg_values = defaultdict(list)

    for k in all_k_shots:

        for learning_algorithm, ckpts in checkpoints.items():
            color = checkpoint_cmap[learning_algorithm]

            # Individual checkpoints are not drawn in this figure, but one x slot is
            # still reserved for each of them.
            cur_x += len(ckpts)

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()
            agg_error = (agg_row[f"{metric}_ub"] - agg_row[metric]).item()

            agg_loc = cur_x - len(ckpts) / 2

            ax.errorbar(
                agg_loc,
                agg_value,
                yerr=agg_error,
                fmt=checkpoint_type_marker[checkpoint_type],
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                markersize=markersize_major,
                fillstyle="none",
            )

            agg_values[learning_algorithm] += [agg_value]

            # Gap between algorithms within one k group.
            cur_x += 5

        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 25
        prev_k_loc = cur_x

    # Connect the aggregated values of each algorithm across the k groups.
    all_k_shots_ = np.array(x_ticks, dtype=float)

    for learning_algorithm, values in agg_values.items():
        color = checkpoint_cmap[learning_algorithm]

        ax.plot(
            all_k_shots_,
            values,
            c=color,
            alpha=0.75,
            zorder=1,
            ls=checkpoint_type_linestyle[checkpoint_type],
            )

# Both passes produce identical tick positions; the ones from the last pass are used.
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels)

# Legend handles: one coloured handle per learning algorithm, followed by one black
# handle per checkpoint type showing the marker actually used for that type.
# No legend is drawn on this figure; the handles are only prepared.
custom_lines = [
    Line2D([0], [0], marker='D', markerfacecolor=checkpoint_cmap[learning_algorithm], color='w', lw=0, markersize=markersize_major)
    for learning_algorithm in checkpoints
]

custom_lines += [
    Line2D([0], [0], marker=checkpoint_type_marker["gat_reset"], markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
    Line2D([0], [0], marker=checkpoint_type_marker["gat"], markerfacecolor="k", color='w', lw=0, markersize=markersize_major+2),
]

ax.set_title(" ", fontsize=11)
ax.set_ylabel("AUPR-Fake", fontsize=fontsize_major)
ax.set_xlabel("$k$-shot", fontsize=fontsize_major)
ax.tick_params(axis='both', which='major', labelsize=fontsize_minor)

fig.tight_layout()

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_random_transfer_comparison_aupr_1.pdf",
)


# CoAID & Twitter Transfer Figure

In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.style as style

cmap = plt.get_cmap("tab10")

# Plot style per learning algorithm: (key, index into `cmap`, marker).
# "zero_shot_transfer" and "subgraphs" intentionally share the same colour and marker.
_checkpoint_styles = (
    ("zero_shot_transfer", 0, "o"),
    ("subgraphs", 0, "o"),
    ("maml_lh", 1, "v"),
    ("maml_rh", 2, "^"),
    ("prototypical", 3, "d"),
    ("protomaml", 4, "D"),
)

# Colour lookup: learning algorithm -> RGBA tuple taken from `cmap`.
checkpoint_cmap = {
    algorithm: cmap(color_index) for algorithm, color_index, _ in _checkpoint_styles
}

# Marker lookup: learning algorithm -> matplotlib marker string.
checkpoint_fmt = {algorithm: marker for algorithm, _, marker in _checkpoint_styles}

In [None]:
# Figure: two side-by-side panels (CoAID, TwitterHateSpeech) showing per-checkpoint
# MCC (faint markers with error bars) and group-aggregated MCC (solid markers) for
# every k-shot setting and learning algorithm.
metric = "mcc"
lower_metric = metric.lower()

# NOTE: this rebinds the notebook-wide `markersize_minor`; every cell executed after
# this one sees the new value.
markersize_minor = 2

all_k_shots = [4, 8, 12, 16]

# Legend text per learning-algorithm key. Unknown keys fall back to the key itself.
algorithm_display_names = {
    "subgraphs": "Subgraphs",
    "maml_lh": "MAML-LH",
    "maml_rh": "MAML-RH",
    "prototypical": "ProtoNet",
    "protomaml": "ProtoMAML",
}

fig, axes = plt.subplots(1, 2, figsize=(2 * figsize[0], figsize[1]), layout="constrained")

for ax_num, (dataset_name, transfer_results, ax) in enumerate(zip(["CoAID", "TwitterHateSpeech"], [coaid_transfer, twitter_transfer], axes)):
    aggregated_df = normalize_by_group(
        df=transfer_results,
        by=["k", "learning_algorithm"],
        metric=metric,
    )

    x_ticks = []
    x_tick_labels = []

    prev_k_loc = 0
    cur_x = 0

    if dataset_name in {"CoAID"}:
        # Only the CoAID panel gets the mean zero-shot score as a reference line;
        # xmax is deliberately far beyond the data and gets clipped by set_xlim below.
        zero_shot_vals = zero_shot_transfer_baseline[metric].to_numpy()

        ax.hlines(
            y=zero_shot_vals.mean(),
            xmin=0,
            xmax=1000,
            colors=["gray"],
            zorder=1,
            linestyles="--",
            alpha=0.75,
            )

    for k in all_k_shots:
        matching_k = transfer_results["k"] == k

        for learning_algorithm, ckpts in checkpoints.items():
            matching_learning_alg = matching_k & (
                transfer_results["learning_algorithm"] == learning_algorithm
            )

            color = checkpoint_cmap[learning_algorithm]
            fmt = checkpoint_fmt[learning_algorithm]

            for ckpt in ckpts:
                matching_ckpt = matching_learning_alg & (
                    transfer_results["checkpoint"] == ckpt
                )
                ckpt_rows = transfer_results[matching_ckpt]

                if ckpt_rows.shape[0] > 1:
                    raise ValueError(
                        f"Expected at most one row for dataset={dataset_name!r}, k={k}, "
                        f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                        f"found {ckpt_rows.shape[0]}."
                    )
                elif ckpt_rows.shape[0] == 0:
                    # No result for this checkpoint: leave its x slot empty.
                    cur_x += 1
                    continue

                value = ckpt_rows[lower_metric].item()
                # The error bar is symmetric, sized by the distance to the lower bound.
                value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

                ax.errorbar(
                    cur_x,
                    value,
                    yerr=value_error,
                    fmt=fmt,
                    color=color,
                    alpha=0.25,
                    zorder=1,
                    markersize=markersize_minor,
                )
                cur_x += 1

            agg_row = aggregated_df.xs(key=(k, learning_algorithm))
            agg_value = agg_row[metric].item()

            agg_loc = cur_x - len(ckpts) / 2

            # The aggregated value is drawn without an error bar.
            ax.scatter(
                agg_loc,
                agg_value,
                marker=fmt,
                color=color,
                alpha=1.0,
                label=learning_algorithm,
                zorder=2,
                s=5 * markersize_major,
            )

            # Gap between algorithms within one k group.
            cur_x += 5

        x_ticks += [(prev_k_loc + cur_x) / 2]
        x_tick_labels += [k]

        # Gap between k groups.
        cur_x += 10
        prev_k_loc = cur_x

    ax.set_title(" ", fontsize=16)

    ax.set_xlim(-10, cur_x)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_tick_labels, fontsize=fontsize_minor)

    ax.set_ylim(0, 0.25)
    ax.set_yticks([0.0, 0.05, 0.10, 0.15, 0.20, 0.25])
    ax.set_yticklabels(["0.00", "0.05", "0.10", "0.15", "0.20", "0.25"], fontsize=fontsize_minor)

    ax.set_xlabel(dataset_name, fontsize=fontsize_major)
    if ax_num == 0:
        ax.set_ylabel("MCC", fontsize=fontsize_major)
    else:
        # Both panels use the same y-range; tick labels are shown on the left panel only.
        ax.set_yticklabels([])

# Exactly one legend handle per learning algorithm, with its label derived from the
# same key, so handles and labels stay aligned whatever `checkpoints` contains.
custom_lines = [
    Line2D([0], [0], marker=checkpoint_fmt[learning_algorithm], markerfacecolor=checkpoint_cmap[learning_algorithm], color='w', lw=0, markersize=markersize_major)
    for learning_algorithm in checkpoints
]
legend_labels = [
    algorithm_display_names.get(learning_algorithm, learning_algorithm)
    for learning_algorithm in checkpoints
]

fig.legend(
    custom_lines,
    legend_labels,
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=5,
    )

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/coaid_and_twitter_transfer_line.pdf"
)


## Extreme $k$-shot

In [None]:
# Figure: per-checkpoint MCC (faint markers with error bars) and group-aggregated MCC
# (solid markers) of the tested learning algorithms on `twitter_transfer`, over an
# extended range of k-shot settings.
metric = "mcc"

lower_metric = metric.lower()

# Set explicitly so the x-axis label does not depend on a loop variable left over
# from a previously executed cell.
dataset_name = "TwitterHateSpeech"

fig, ax = plt.subplots(1, 1, figsize=(2 * figsize[0], figsize[1]), layout="constrained")

tested_learning_algorithms = [
    "prototypical"
]

all_k_shots = [4, 8, 12, 16, 32, 64, 128, 256]

checkpoint_results = twitter_transfer

aggregated_df = normalize_by_group(
    df=checkpoint_results,
    by=["k", "learning_algorithm"],
    metric=metric,
)

# Restrict the aggregated table to the tested learning algorithms.
aggregated_df = aggregated_df.loc[pd.IndexSlice[:, tested_learning_algorithms], :]

prev_k_loc = 0
cur_x = 0

x_ticks = []
x_tick_labels = []

for k in all_k_shots:
    matching_k = checkpoint_results["k"] == k

    for learning_algorithm, ckpts in checkpoints.items():
        if learning_algorithm not in tested_learning_algorithms:
            continue

        matching_learning_alg = matching_k & (
            checkpoint_results["learning_algorithm"] == learning_algorithm
        )

        color = checkpoint_cmap[learning_algorithm]
        fmt = checkpoint_fmt[learning_algorithm]

        for ckpt in ckpts:
            matching_ckpt = matching_learning_alg & (
                checkpoint_results["checkpoint"] == ckpt
            )
            ckpt_rows = checkpoint_results[matching_ckpt]

            if ckpt_rows.shape[0] > 1:
                raise ValueError(
                    f"Expected at most one row for k={k}, "
                    f"learning_algorithm={learning_algorithm!r}, checkpoint={ckpt!r}; "
                    f"found {ckpt_rows.shape[0]}."
                )
            elif ckpt_rows.shape[0] == 0:
                # No result for this checkpoint: leave its x slot empty.
                cur_x += 1
                continue

            value = ckpt_rows[lower_metric].item()
            # The error bar is symmetric, sized by the distance to the lower bound.
            value_error = value - ckpt_rows[f"{lower_metric}_lb"].item()

            ax.errorbar(
                cur_x,
                value,
                yerr=value_error,
                fmt=fmt,
                color=color,
                alpha=0.25,
                zorder=1,
                markersize=markersize_minor,
            )
            cur_x += 1

        agg_row = aggregated_df.xs(key=(k, learning_algorithm))
        agg_value = agg_row[metric].item()

        agg_loc = cur_x - len(ckpts) / 2

        # The aggregated value is drawn without an error bar.
        ax.scatter(
            agg_loc,
            agg_value,
            marker=fmt,
            color=color,
            alpha=1.0,
            label=learning_algorithm,
            zorder=2,
            s=5 * markersize_major,
        )

        # Gap between algorithms within one k group.
        cur_x += 5

    x_ticks += [(prev_k_loc + cur_x) / 2]
    x_tick_labels += [k]

    # Gap between k groups.
    cur_x += 25
    prev_k_loc = cur_x

ax.set_title(" ", fontsize=16)

# Trim most of the trailing gap that was added after the last k group.
ax.set_xlim(-10, cur_x-20)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_tick_labels, fontsize=fontsize_minor)

ax.set_ylim(0, 0.25)
ax.set_yticks([0.0, 0.05, 0.10, 0.15, 0.20, 0.25])
ax.set_yticklabels(["0.00", "0.05", "0.10", "0.15", "0.20", "0.25"], fontsize=fontsize_minor)

ax.set_xlabel(dataset_name, fontsize=fontsize_major)
ax.set_ylabel("MCC", fontsize=fontsize_major)

# The legend handle takes its marker and colour straight from the style tables rather
# than from whatever the plotting loop happened to leave behind.
learning_algorithm = "prototypical"

custom_lines = [
    Line2D([0], [0], marker=checkpoint_fmt[learning_algorithm], markerfacecolor=checkpoint_cmap[learning_algorithm], color='w', lw=0, markersize=markersize_major),
]

fig.legend(
    custom_lines,
    ["ProtoNet"],
    loc='upper center',
    fontsize=fontsize_minor,
    ncol=5,
    )

plt.show()

fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/twitter_extreme_kshot_mcc.pdf"
)