In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. the project code imported below) are picked up before each cell
# execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Any, Dict

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch_sparse import SparseTensor
from jaxtyping import Float, Integer
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler
import os

from src.graph_models.csbm import CSBM
from src.models.common import get_diffusion
from src.models.ntk import NTK
from src.attacks import create_attack
from common import configure_hardware, get_graph, \
    count_edges_for_idx, calc_kernel_means, plot_ntk_model_diff, plot_ntk_model_acc
from src import utils

In [3]:
# ---------------------------------------------------------------------------
# Data: parameters handed to `get_graph` when sampling a graph for each seed.
# ---------------------------------------------------------------------------
data_dict = {
    "classes": 2,
    "n": 1000,
    "n_per_class_trn": 300,
    "n_per_class_labeled": 300,
    "sigma": 1,
    "avg_within_class_degree": 1.58 * 2,
    "avg_between_class_degree": 0.37 * 2,
}
# Stored into data_dict under the key "K" before sampling.
K = 1
# One experiment repetition per seed.
seed_l = list(range(30))

# ---------------------------------------------------------------------------
# Models: one configuration dict per NTK variant; "label" is the key under
# which the results of that variant are collected.
# ---------------------------------------------------------------------------
model_dict_gcn = dict(
    label="GCN",
    model="GCN",
    normalization="row_normalization",
    depth=1,
)
model_dict_softmedoid = dict(
    label="SoftM_T1",
    model="SoftMedoid",
    normalization="row_normalization",
    depth=1,
    T=1,
)
model_dict_softmedoid5 = dict(
    label="SoftM_T5",
    model="SoftMedoid",
    normalization="row_normalization",
    depth=1,
    T=5,
)
model_dict_l = [
    model_dict_gcn,
    model_dict_softmedoid,
    model_dict_softmedoid5,
]

# ---------------------------------------------------------------------------
# Other: hardware selector passed to `configure_hardware` and tensor precision.
# ---------------------------------------------------------------------------
device = "0"
dtype = torch.float64

In [17]:
# Run the experiment once per seed and per model configuration.
#
# result_dict maps, for every model label:
#   "<label>"      -> list of NTK objects (one per seed)
#   "<label>_acc"  -> list of accuracies as Python floats (one per seed)
result_dict = dict()
# K does not change between seeds, so it is written into the data config once.
data_dict["K"] = K
for seed in seed_l:
    # Dedicated generator per seed so the train/test split is reproducible.
    rng = np.random.Generator(np.random.PCG64(seed))
    device_ = configure_hardware(device, seed)
    # Sample
    X, A, y = get_graph(data_dict, seed=seed, sort=True)
    X = torch.tensor(X, dtype=dtype, device=device_)
    A = torch.tensor(A, dtype=dtype, device=device_)
    y = torch.tensor(y, device=device_)
    # Trn / Test Split
    # The split below treats indices [0, n_cls0) as class 0 and [n_cls0, n) as
    # class 1 -- presumably guaranteed by `sort=True` above; verify in get_graph.
    # Tensor-side reduction instead of the Python builtin `sum`, which would
    # iterate over the tensor element by element.
    n_cls0 = int((y == 0).sum().item())
    n = len(y)
    idx_cls0 = rng.permutation(np.arange(n_cls0))
    idx_cls1 = rng.permutation(np.arange(n_cls0, n))
    n_trn = data_dict["n_per_class_trn"]
    n_labeled = data_dict["n_per_class_labeled"]
    idx_labeled = np.concatenate((idx_cls0[:n_labeled], idx_cls1[:n_labeled]))
    idx_unlabeled = np.concatenate((idx_cls0[n_labeled:], idx_cls1[n_labeled:]))
    idx_target = np.concatenate((idx_cls0[n_trn:], idx_cls1[n_trn:]))
    # Predictions are requested for idx_unlabeled but scored against
    # y[idx_target]; the two only line up when both index sets are identical
    # (i.e. n_per_class_trn == n_per_class_labeled). Fail loudly otherwise
    # instead of comparing misaligned predictions and labels.
    assert np.array_equal(idx_unlabeled, idx_target), (
        "idx_unlabeled and idx_target differ: predictions would not align "
        "with the labels used for the accuracy "
        "(n_per_class_trn != n_per_class_labeled?)"
    )
    for model_dict in model_dict_l:
        label = model_dict["label"]
        # Computing NTK
        ntk = NTK(model_dict, X, A)
        y_pred = ntk(idx_labeled, idx_unlabeled, y)
        acc = utils.accuracy(y_pred, y[idx_target]).cpu().item()
        # setdefault creates the per-model lists on first use; the insertion
        # order ("<label>" before "<label>_acc") matches first-seen order.
        # NOTE(review): every NTK object is kept alive here (one per seed and
        # model) -- presumably needed by later cells; confirm memory is fine.
        result_dict.setdefault(label, []).append(ntk)
        result_dict.setdefault(label + "_acc", []).append(acc)

In [15]:
# With Stabilising Inverse
# Print "mean +- std" (in percent) of the accuracies collected in result_dict;
# only the "<label>_acc" entries hold accuracies, all other keys are ignored.
# NOTE(review): nothing in this cell selects the "stabilising inverse" variant;
# presumably it is switched elsewhere before the experiment loop is run --
# confirm, and re-run the loop before executing this cell.
for model_label, result_l in result_dict.items():
    if model_label.endswith("_acc"):
        res = 100 * np.asarray(result_l)
        mean_acc, std_acc = res.mean(), res.std()
        print(f"{model_label}: {mean_acc:.2f} +- {std_acc:.2f}")

GCN_acc: 85.56 +- 4.28
SoftM_T1_acc: 77.72 +- 2.76
SoftM_T5_acc: 83.37 +- 7.01


In [18]:
# Without Stabilising Inverse
# Summarise the per-seed accuracies of each model as "mean +- std" in percent.
# Accuracy lists are the result_dict entries whose key ends with "_acc".
acc_entries = (
    (name, accs) for name, accs in result_dict.items() if name.endswith("_acc")
)
for model_label, result_l in acc_entries:
    res = np.multiply(result_l, 100)
    print(f"{model_label}: {np.mean(res):.2f} +- {np.std(res):.2f}")

GCN_acc: 83.59 +- 10.03
SoftM_T1_acc: 75.18 +- 6.05
SoftM_T5_acc: 81.98 +- 8.78
