In [1]:
import warnings
from pprint import pprint

import numpy as np
import yaml
from easydict import EasyDict
from tqdm import tqdm

import wandb

In [2]:
# Public W&B API client, used below to list the sweep's runs.
api = wandb.Api()

# Experiment settings read by later cells: wandb entity/project, dataset name,
# model approaches and seeds, and the list of rewirers.
config_path = "configs/ColourInteract.yaml"
with open(config_path, "r") as config_file:
    raw_config = yaml.safe_load(config_file)
config = EasyDict(raw_config)

# c2/c1 ratios swept in the experiment; each value appears verbatim in the run names.
c2_over_c1s = [0.01, 0.1, 1.0, 10.0, 100.0]

In [3]:
project_path = f"{config.wandb.entity}/{config.wandb.project}"
runs = api.runs(project_path)
print(project_path)

# Every W&B run name this sweep is expected to contain: one per
# (approach, c2/c1 ratio, rewirer, seed). The "only_original" approach does no
# rewiring, so its single "rewirer" is None (rendered as "None" in the name).
all_run_names = []

for approach in tqdm(config.model.approaches):
    rewirers = [None] if approach == "only_original" else config.data.rewirers
    for c2_over_c1 in c2_over_c1s:
        for rewirer in rewirers:
            for seed in config.model.seeds:
                all_run_names.append(
                    f"{config.data.dataset}-{approach}-rewired-with-{rewirer}-c2/c1-{c2_over_c1}-seed-{seed}"
                )

commute_opt_gnn/ColourInteract Sweeps Wednesday Night


100%|██████████| 2/2 [00:00<00:00, 8674.88it/s]


In [4]:
# eval/loss history (one float per logged row) for every expected run, keyed by run name.
all_run_statistics = {}
# Set lookup instead of scanning the list once per W&B run.
expected_run_names = set(all_run_names)

for run in tqdm(runs):
    if run.name not in expected_run_names:
        continue
    if run.name in all_run_statistics:
        # Two W&B runs share a name (e.g. a relaunched seed). The one seen last
        # overwrites the earlier one, so make the collision visible.
        warnings.warn(f"Duplicate W&B run name {run.name!r}; keeping the last one seen")
    # NOTE(review): Run.history returns a *sampled* history; if a run logs many
    # eval points, confirm the final one is always kept (scan_history is unsampled).
    history_rows = run.history(keys=["eval/loss"], pandas=False)
    all_run_statistics[run.name] = [row["eval/loss"] for row in history_rows]

100%|██████████| 75/75 [00:25<00:00,  2.91it/s]


In [5]:
# results["{approach}_{rewirer}"][c2_over_c1] -> np.ndarray of per-seed final
# eval losses (ordered like config.model.seeds). The listed keys fix the dict
# order, which the plotting cells use for colours and legend order; any other
# configured combination is appended after them.
results = {
    key: {}
    for key in [
        "only_original_None",
        "interleave_cayley",
        "interleave_unconnected_cayley_clusters",
        "interleave_fully_connected_clusters",
        "interleave_fully_connected",
    ]
}

for approach in tqdm(config.model.approaches):
    rewirers = [None] if approach == "only_original" else config.data.rewirers
    for c2_over_c1 in c2_over_c1s:
        for rewirer in rewirers:
            final_losses = []
            for seed in config.model.seeds:
                run_name = f"{config.data.dataset}-{approach}-rewired-with-{rewirer}-c2/c1-{c2_over_c1}-seed-{seed}"

                if run_name not in all_run_statistics:
                    raise KeyError(f"No eval/loss history was fetched for run {run_name!r}")
                run_losses = all_run_statistics[run_name]
                if not run_losses:
                    raise ValueError(f"Run {run_name!r} has no eval/loss entries")

                # Take each run's last logged value individually: seeds with
                # different history lengths would make np.array(...)[:, -1] fail.
                # NOTE(review): this is the final eval loss, not the minimum over
                # training; confirm that is the intended metric.
                final_losses.append(run_losses[-1])

            results.setdefault(f"{approach}_{rewirer}", {})[c2_over_c1] = np.array(final_losses)

100%|██████████| 2/2 [00:00<00:00, 2117.27it/s]


In [6]:
import copy

# Express every per-seed loss relative to the plain-GIN ("only_original_None")
# loss for the same c2/c1 ratio and seed index, so the baseline maps to 1.0.
processed_results = copy.deepcopy(results)
seed_indices = range(len(config.model.seeds))

for ratio in c2_over_c1s:
    baseline_losses = results["only_original_None"][ratio]
    for rewirer in results:
        raw_losses = processed_results[rewirer][ratio]
        processed_results[rewirer][ratio] = [
            raw_losses[i] / baseline_losses[i] for i in seed_indices
        ]

In [7]:
# Normalised per-seed losses; pprint orders the dict keys alphabetically.
pprint(processed_results, sort_dicts=True)

{'interleave_cayley': {0.01: [7.149752095098683,
                              28.37402707570636,
                              24.347761152332637],
                       0.1: [0.3511203542133236,
                             0.7233843486570287,
                             1.9634477793537741],
                       1.0: [0.4669103442659743,
                             0.7369839532972746,
                             0.4194543512382267],
                       10.0: [0.4392665961742164,
                              0.3631101331498,
                              0.359544500484152],
                       100.0: [1.0036156918011256,
                               0.5788429688285017,
                               0.3896137190529125]},
 'interleave_fully_connected': {0.01: [582.2179469146237,
                                       83.3878696474904,
                                       2540.2867620305524],
                                0.1: [55.716485786153626,
                    

In [8]:
# Collapse each list of normalised per-seed losses into (median, mean, std).
# processed_results is overwritten in place, so entries that are already
# summarised are skipped: re-running this cell would otherwise compute
# statistics of the (median, mean, std) tuple itself and silently corrupt
# the plot data.
for rewirer in results:
    for c2_over_c1 in c2_over_c1s:
        per_seed = processed_results[rewirer][c2_over_c1]
        if isinstance(per_seed, tuple):
            continue
        processed_results[rewirer][c2_over_c1] = (
            np.median(per_seed),
            np.mean(per_seed),
            np.std(per_seed),  # numpy default ddof=0 (population std)
        )

In [9]:
# (median, mean, std) per rewirer and c2/c1 ratio; pprint orders the dict keys alphabetically.
pprint(processed_results, sort_dicts=True)

{'interleave_cayley': {0.01: (24.347761152332637,
                              19.95718010771256,
                              9.204178889509738),
                       0.1: (0.7233843486570287,
                             1.0126508274080421,
                             0.689278005481952),
                       1.0: (0.4669103442659743,
                             0.5411162162671586,
                             0.13984788313386903),
                       10.0: (0.3631101331498,
                              0.38730707660272284,
                              0.03676975378080833),
                       100.0: (0.5788429688285017,
                               0.6573574598941799,
                               0.2567398159163621)},
 'interleave_fully_connected': {0.01: (582.2179469146237,
                                       1068.630859530889,
                                       1060.3572300623214),
                                0.1: (55.716485786153626,
                

In [10]:
# pgfplots coordinate lists per rewirer: the mean curve ("middle") and the
# mean - std / mean + std band edges ("below" / "above").
num_strings = {}

for rewirer in results:
    print(f"Rewirer: {rewirer}")
    middle_points, below_points, above_points = [], [], []

    for ratio in c2_over_c1s:
        _median, mean, std = processed_results[rewirer][ratio]
        middle_points.append(f"({ratio:.4f},{mean:.4f})")
        below_points.append(f"({ratio:.4f},{(mean-std):.4f})")
        above_points.append(f"({ratio:.4f},{(mean+std):.4f})")

    num_strings[rewirer] = {
        "middle": "".join(middle_points),
        "below": "".join(below_points),
        "above": "".join(above_points),
    }

Rewirer: only_original_None
Rewirer: interleave_cayley
Rewirer: interleave_unconnected_cayley_clusters
Rewirer: interleave_fully_connected_clusters
Rewirer: interleave_fully_connected


In [12]:
plot = r"""\begin{tikzpicture}
\begin{axis}[
    xmin=0.01, xmax=100,
    ymin=0, ymax=2,
    ymajorgrids=true,
    xlabel=$c_2/c_1$,
    ylabel=$\textrm{loss}/\textrm{loss}_{\textrm{Plain GIN}}$,
    grid style=dashed,
    xmode=log,
    xtick={0.01, 0.1,1,10, 100}, % Specify the positions of the ticks
    xticklabels={$0.01$, $0.1$, $1$, $10$, $100$}, % Specify the labels for the ticks
    ytick={0.0, 0.5, 1.0,1.5,2.0}, % Specify the positions of the ticks
    yticklabels={-, $0.5$, $1.0$, $1.5$, $2.0$}, % Specify the labels for the ticks
    legend pos=north east,
    width=0.9\textwidth,
]
"""

colours = ["black", "blue", "red", "green", "orange", "purple", "brown"]
shading = ["gray", "blue", "red", "green", "orange", "purple", "brown"]

# Legend text per results key. pgfplots assigns \legend entries to \addplot
# commands in order, and the mean curves are emitted first, so this must give
# one entry per key in results (in results order).
legend_labels = {
    "only_original_None": "Plain GIN",
    "interleave_cayley": "GIN+Cayley",
    "interleave_unconnected_cayley_clusters": "GIN+Unconnected Cayley Clusters",
    "interleave_fully_connected_clusters": "GIN+Fully Connected Clusters",
    "interleave_fully_connected": "GIN+Fully Connected",
}

# Mean curves, one per rewirer (these are the plots the legend refers to).
for idx, rewirer in enumerate(results):
    colour = colours[idx % len(colours)]
    plot += r"\addplot[color=" + colour + ", ultra thick, mark=ball] coordinates{" + num_strings[rewirer]["middle"] + "};\n\n"

# Dashed mean +/- std bounds and a translucent fill between them.
for idx, rewirer in enumerate(results):
    shade = shading[idx % len(shading)]
    plot += r"\addplot[name path=" + rewirer + "_top,color=" + shade + r"!70,dashed] coordinates {" + num_strings[rewirer]["above"] + "};\n\n"
    plot += r"\addplot[name path=" + rewirer + "_down,color=" + shade + r"!70,dashed] coordinates {" + num_strings[rewirer]["below"] + "};\n\n"
    plot += r"\addplot[" + shade + "!50,fill opacity=0.1] fill between[of=" + rewirer + r"_top and " + rewirer + "_down];\n\n"

# Unknown keys fall back to their raw name with underscores escaped for LaTeX.
legend_entries = ", ".join(
    legend_labels.get(rewirer, rewirer.replace("_", r"\_")) for rewirer in results
)
plot += "\n" + r"\legend{" + legend_entries + "}" + "\n" + r"\end{axis}" + "\n" + r"\end{tikzpicture}" + "\n"

only_original_None
interleave_cayley
interleave_unconnected_cayley_clusters
interleave_fully_connected_clusters
interleave_fully_connected


In [13]:
# Emit the generated TikZ/pgfplots source.
print(plot, end="\n")

\begin{tikzpicture}
\begin{axis}[
    xmin=0.01, xmax=100,
    ymin=0, ymax=2,
    ymajorgrids=true,
    xlabel=$c_2/c_1$,
    ylabel=$\textrm{loss}/\textrm{loss}_{\textrm{Plain GIN}}$,
    grid style=dashed,
    xmode=log,
    xtick={0.01, 0.1,1,10, 100}, % Specify the positions of the ticks
    xticklabels={$0.01$, $0.1$, $1$, $10$, $100$}, % Specify the labels for the ticks
    ytick={0.0, 0.5, 1.0,1.5,2.0}, % Specify the positions of the ticks
    yticklabels={-, $0.5$, $1.0$, $1.5$, $2.0$}, % Specify the labels for the ticks
    legend pos=north east,
    width=0.9\textwidth,
]
\addplot[color=black, ultra thick, mark=ball] coordinates{(0.0100,1.0000)(0.1000,1.0000)(1.0000,1.0000)(10.0000,1.0000)(100.0000,1.0000)};

\addplot[color=blue, ultra thick, mark=ball] coordinates{(0.0100,19.9572)(0.1000,1.0127)(1.0000,0.5411)(10.0000,0.3873)(100.0000,0.6574)};

\addplot[color=red, ultra thick, mark=ball] coordinates{(0.0100,1.6050)(0.1000,0.8354)(1.0000,0.6644)(10.0000,0.6486)(100.0000,0.5