In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Any, Dict

import numpy as np
import scipy.sparse as sp
import torch
from jaxtyping import Float, Integer

from src.graph_models.csbm import CSBM
from src.models.ntk import NTK

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def configure_hardware(device, seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Hardware
    #torch.backends.cuda.matmul.allow_tf32 = other_params["allow_tf32"]
    #torch.backends.cudnn.allow_tf32 = other_params["allow_tf32"]
    if not torch.cuda.is_available():
        assert device == "cpu", "CUDA is not availble, set device to 'cpu'"
    elif device == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{device}")

    return device

In [4]:
# Data
classes = 2
n = 1000
#n_per_class_trn = 400
K = 1.5
sigma = 1
avg_within_class_degree = 1.58 * 2
avg_between_class_degree = 0.37 * 2
seed = 42
# Model
model_dict = {
    "label": "GCN",
    "model": "GCN",
    "normalization": "row_normalization",
    "depth": 1,
}
# other
device = "cpu"
dtype = torch.float64


In [5]:
def get_graph(sort: bool=True):
    """Return graph sampled from a CSBM.

    If sort is true, X, A and y are sorted for class.
    
    Returns X, A, y."""
    csbm = CSBM(n, avg_within_class_degree, avg_between_class_degree, K, sigma, classes=classes)
    X, A, y = csbm.sample(n, seed)
    device_ = configure_hardware(device, seed)
    X = torch.tensor(X, dtype=dtype, device=device_)
    A = torch.tensor(A, dtype=dtype, device=device_)
    y = torch.tensor(y, device=device_)
    if sort:
        _, idx = y.sort()
        y = y[idx]
        X = X[idx, :]
        A = A[idx, :]
        A = A[:, idx]
    return X, A, y


def row_normalize(A):
    # Row normalize
    S = torch.triu(A, diagonal=1) + torch.triu(A, diagonal=1).T
    S.data[torch.arange(S.shape[0]), torch.arange(S.shape[0])] = 1
    Deg_inv = torch.diag(torch.pow(S.sum(axis=1), - 1))
    return Deg_inv @ S


def get_diffusion(X, A, model_dict):
    if model_dict["model"] == "GCN":
        if model_dict["normalization"] == "row_normalization":
            return row_normalize(A)
        else:
            raise NotImplementedError("Only row normalization for GCN implemented")
    else:
        raise NotImplementedError("Only GCN architecture implemented")

In [12]:
X, A, y = get_graph()
S = get_diffusion(X, A, model_dict)
# Computing NTK
ntk = NTK(X, S, model_dict)
print(ntk.get_ntk())

tensor([[ 2.7267,  1.9644,  1.0438,  ...,  1.3358,  3.7985,  1.2604],
        [ 1.9644,  4.2664,  1.6084,  ...,  1.2780,  3.2808,  1.4220],
        [ 1.0438,  1.6084,  2.5592,  ...,  1.3128,  3.0359,  1.2975],
        ...,
        [ 1.3358,  1.2780,  1.3128,  ...,  3.0045,  4.3059,  1.4304],
        [ 3.7985,  3.2808,  3.0359,  ...,  4.3059, 55.5725,  4.0014],
        [ 1.2604,  1.4220,  1.2975,  ...,  1.4304,  4.0014,  2.8405]],
       dtype=torch.float64)


In [40]:
ntk_dict = dict()
for K in K_l:
    # Sample
    X, A, y = get_graph()
    S = get_diffusion(X, A, model_dict)
    # Computing NTK
    ntk_l = list()
    ntk_l.append(NTK(X, S, y))
    for budget in budget_l:
        A_pert = attack(ntk_l[0], budget, A)
        S = get_diffusion(X, A_pert, model_dict)
        ntk_l.append(get_ntk(X, S, y))
    ntk_dict[K] = ntk_l

NameError: name 'K_l' is not defined

In [31]:
X, A, y = get_graph()
S = get_diffusion(X, A, model_dict)
# Computing NTK
ntk = get_ntk(X, S, y)
ntk = torch.tensor(ntk, requires_grad=True)
print(ntk)
print(torch.linalg.solve(ntk, ntk))

tensor([[ 2.7267,  1.9644,  1.0438,  ...,  1.3358,  3.7985,  1.2604],
        [ 1.9644,  4.2664,  1.6084,  ...,  1.2780,  3.2808,  1.4220],
        [ 1.0438,  1.6084,  2.5592,  ...,  1.3128,  3.0359,  1.2975],
        ...,
        [ 1.3358,  1.2780,  1.3128,  ...,  3.0045,  4.3059,  1.4304],
        [ 3.7985,  3.2808,  3.0359,  ...,  4.3059, 55.5725,  4.0014],
        [ 1.2604,  1.4220,  1.2975,  ...,  1.4304,  4.0014,  2.8405]],
       dtype=torch.float64, requires_grad=True)
tensor([[ 1.1366e+00,  8.8244e-16, -1.1094e-01,  ..., -5.1569e-02,
          5.6345e+01, -1.3413e-01],
        [ 4.4654e-02,  1.0000e+00, -3.3360e-03,  ..., -7.9349e-03,
         -2.0054e+01, -5.6974e-02],
        [-7.3329e-02,  1.6405e-17,  9.3669e-01,  ..., -3.9913e-02,
          1.5123e+01,  6.6175e-02],
        ...,
        [ 4.6833e-03,  2.6712e-17, -7.0250e-03,  ...,  9.8531e-01,
         -4.9789e+00, -4.4864e-03],
        [-1.7053e-03, -9.5403e-18,  2.5580e-03,  ...,  4.4890e-03,
          2.3734e+00,  1.6

  ntk = torch.tensor(ntk, requires_grad=True)


In [25]:
ntk

tensor([[2.3422, 2.1270, 1.8070, 1.9502, 1.8146, 1.8291, 2.1270, 0.3346, 0.6418,
         0.6739],
        [2.1270, 2.2557, 1.7998, 2.1892, 1.8434, 1.8700, 2.2557, 0.3498, 0.6555,
         0.6869],
        [1.8070, 1.7998, 1.7244, 1.8478, 1.6246, 1.6060, 1.7998, 0.3339, 0.5820,
         0.6269],
        [1.9502, 2.1892, 1.8478, 2.7280, 1.9128, 1.8874, 2.1892, 0.3487, 0.6521,
         0.6832],
        [1.8146, 1.8434, 1.6246, 1.9128, 1.7926, 1.6244, 1.8434, 0.3150, 0.5751,
         0.6105],
        [1.8291, 1.8700, 1.6060, 1.8874, 1.6244, 1.6727, 1.8700, 0.3490, 0.6209,
         0.6514],
        [2.1270, 2.2557, 1.7998, 2.1892, 1.8434, 1.8700, 2.2557, 0.3498, 0.6555,
         0.6869],
        [0.3346, 0.3498, 0.3339, 0.3487, 0.3150, 0.3490, 0.3498, 0.3943, 0.4102,
         0.3224],
        [0.6418, 0.6555, 0.5820, 0.6521, 0.5751, 0.6209, 0.6555, 0.4102, 0.6231,
         0.4319],
        [0.6739, 0.6869, 0.6269, 0.6832, 0.6105, 0.6514, 0.6869, 0.3224, 0.4319,
         0.4693]], dtype=tor

In [40]:
kernel

tensor([[2.3422, 0.6418, 1.8070, 1.9502, 1.8146, 0.3346, 2.1270, 1.8291, 2.1270,
         0.6739],
        [0.6418, 0.6231, 0.5820, 0.6521, 0.5751, 0.4102, 0.6555, 0.6209, 0.6555,
         0.4319],
        [1.8070, 0.5820, 1.7244, 1.8478, 1.6246, 0.3339, 1.7998, 1.6060, 1.7998,
         0.6269],
        [1.9502, 0.6521, 1.8478, 2.7280, 1.9128, 0.3487, 2.1892, 1.8874, 2.1892,
         0.6832],
        [1.8146, 0.5751, 1.6246, 1.9128, 1.7926, 0.3150, 1.8434, 1.6244, 1.8434,
         0.6105],
        [0.3346, 0.4102, 0.3339, 0.3487, 0.3150, 0.3943, 0.3498, 0.3490, 0.3498,
         0.3224],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [1.8291, 0.6209, 1.6060, 1.8874, 1.6244, 0.3490, 1.8700, 1.6727, 1.8700,
         0.6514],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [0.6739, 0.4319, 0.6269, 0.6832, 0.6105, 0.3224, 0.6869, 0.6514, 0.6869,
         0.4693]], dtype=tor

In [19]:
kernel

NameError: name 'kernel' is not defined