In [2]:
from typing import Any, Dict

import numpy as np
import scipy.sparse as sp
import torch

from src.graph_models.csbm import CSBM

In [8]:
def configure_hardware(device, seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Hardware
    #torch.backends.cuda.matmul.allow_tf32 = other_params["allow_tf32"]
    #torch.backends.cudnn.allow_tf32 = other_params["allow_tf32"]
    if not torch.cuda.is_available():
        assert device == "cpu", "CUDA is not availble, set device to 'cpu'"
    elif device == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{device}")

    return device

def kappa_0(u, dtype, device):
    z = torch.zeros((u.shape), dtype=dtype).to(device)
    pi = torch.acos(z)*2
    r = (pi - torch.acos(u)) / pi
    r[r!=r] = 1.0
    return r

def kappa_1(u, dtype, device):
    z = torch.zeros((u.shape), dtype=dtype).to(device)
    pi = torch.acos(z) * 2
    r = (u*(pi - torch.acos(u)) + torch.sqrt(1-u*u))/pi
    r[r!=r] = 1.0
    return r

In [66]:
# Data
classes = 2
n = 10
#n_per_class_trn = 400
K = 1.5
sigma = 1
avg_within_class_degree = 1.58 * 2
avg_between_class_degree = 0.37 * 2
seed = 42
# Model
depth = 1
# other
device = "cpu"
dtype = torch.float64


In [67]:
# Sample
csbm = CSBM(n, avg_within_class_degree, avg_between_class_degree, K, sigma, classes=classes)
X, A, y = csbm.sample(n, seed)
device = configure_hardware(device, seed)
X = torch.tensor(X, dtype=dtype, device=device)
A = torch.tensor(A, dtype=dtype, device=device)
y = torch.tensor(y, device=device)
# Row normalize
S = torch.triu(A, diagonal=1) + torch.triu(A, diagonal=1).T
S.data[torch.arange(S.shape[0]), torch.arange(S.shape[0])] = 1
Deg_inv = torch.diag(torch.pow(S.sum(axis=1), - 1))
S = Deg_inv @ S

# Computing NTK
csigma = 1 # ReLU
S_norm = torch.norm(S)
XXT = X.matmul(X.T)
Sig = S.matmul(XXT.matmul(S.T))
kernel = torch.zeros((S.shape), dtype=dtype).to(device)
# ReLu GCN
kernel_sub = torch.zeros((depth, S.shape[0], S.shape[1]), dtype=dtype).to(device)
for i in range(depth):
    p = torch.zeros((S.shape), dtype=dtype).to(device)
    Diag_Sig = torch.diagonal(Sig) 
    Sig_i = p + Diag_Sig.reshape(1, -1)
    Sig_j = p + Diag_Sig.reshape(-1, 1)
    q = torch.sqrt(Sig_i * Sig_j)
    u = Sig/q # why normalization?
    E = (q * kappa_1(u, dtype, device)) * csigma
    E_der = (kappa_0(u, dtype, device)) * csigma
    kernel_der = (S.matmul(S.T)) * E_der
    kernel_sub[i] += Sig * kernel_der
    E = E.double()
    Sig = S.matmul(E.matmul(S.T))
    for j in range(i):
        kernel_sub[j] *= kernel_der
kernel += torch.sum(kernel_sub, dim=0)
kernel += Sig

# Sort kernel
_, idx = y.sort()
K = kernel[idx, :]
K = K[:, idx]

In [68]:
print(K)

tensor([[2.3422, 1.8070, 1.9502, 1.8146, 2.1270, 1.8291, 2.1270, 0.6418, 0.3346,
         0.6739],
        [1.8070, 1.7244, 1.8478, 1.6246, 1.7998, 1.6060, 1.7998, 0.5820, 0.3339,
         0.6269],
        [1.9502, 1.8478, 2.7280, 1.9128, 2.1892, 1.8874, 2.1892, 0.6521, 0.3487,
         0.6832],
        [1.8146, 1.6246, 1.9128, 1.7926, 1.8434, 1.6244, 1.8434, 0.5751, 0.3150,
         0.6105],
        [2.1270, 1.7998, 2.1892, 1.8434, 2.2557, 1.8700, 2.2557, 0.6555, 0.3498,
         0.6869],
        [1.8291, 1.6060, 1.8874, 1.6244, 1.8700, 1.6727, 1.8700, 0.6209, 0.3490,
         0.6514],
        [2.1270, 1.7998, 2.1892, 1.8434, 2.2557, 1.8700, 2.2557, 0.6555, 0.3498,
         0.6869],
        [0.6418, 0.5820, 0.6521, 0.5751, 0.6555, 0.6209, 0.6555, 0.6231, 0.4102,
         0.4319],
        [0.3346, 0.3339, 0.3487, 0.3150, 0.3498, 0.3490, 0.3498, 0.4102, 0.3943,
         0.3224],
        [0.6739, 0.6269, 0.6832, 0.6105, 0.6869, 0.6514, 0.6869, 0.4319, 0.3224,
         0.4693]], dtype=tor

In [57]:
x = torch.tensor([0, 1] * 2)
x, idx = x.sort()
A = torch.randn(4, 4)
print(idx)
print(x[idx])
print(A)
B = A[idx, :]
B = B[:, idx]
print(B)

tensor([0, 2, 1, 3])
tensor([0, 1, 0, 1])
tensor([[ 1.0311, -0.7048,  1.0131, -0.3308],
        [ 0.5177,  0.3878, -0.5797, -0.1691],
        [-0.5733,  0.5069, -0.4752, -0.4920],
        [ 0.2704, -0.5628,  0.6793,  0.4405]])
tensor([[ 1.0311,  1.0131, -0.7048, -0.3308],
        [-0.5733, -0.4752,  0.5069, -0.4920],
        [ 0.5177, -0.5797,  0.3878, -0.1691],
        [ 0.2704,  0.6793, -0.5628,  0.4405]])


In [34]:
kernel

tensor([[2.3422, 0.6418, 1.8070, 1.9502, 1.8146, 0.3346, 2.1270, 1.8291, 2.1270,
         0.6739],
        [0.6418, 0.6231, 0.5820, 0.6521, 0.5751, 0.4102, 0.6555, 0.6209, 0.6555,
         0.4319],
        [1.8070, 0.5820, 1.7244, 1.8478, 1.6246, 0.3339, 1.7998, 1.6060, 1.7998,
         0.6269],
        [1.9502, 0.6521, 1.8478, 2.7280, 1.9128, 0.3487, 2.1892, 1.8874, 2.1892,
         0.6832],
        [1.8146, 0.5751, 1.6246, 1.9128, 1.7926, 0.3150, 1.8434, 1.6244, 1.8434,
         0.6105],
        [0.3346, 0.4102, 0.3339, 0.3487, 0.3150, 0.3943, 0.3498, 0.3490, 0.3498,
         0.3224],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [1.8291, 0.6209, 1.6060, 1.8874, 1.6244, 0.3490, 1.8700, 1.6727, 1.8700,
         0.6514],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [0.6739, 0.4319, 0.6269, 0.6832, 0.6105, 0.3224, 0.6869, 0.6514, 0.6869,
         0.4693]], dtype=tor

In [40]:
kernel

tensor([[2.3422, 0.6418, 1.8070, 1.9502, 1.8146, 0.3346, 2.1270, 1.8291, 2.1270,
         0.6739],
        [0.6418, 0.6231, 0.5820, 0.6521, 0.5751, 0.4102, 0.6555, 0.6209, 0.6555,
         0.4319],
        [1.8070, 0.5820, 1.7244, 1.8478, 1.6246, 0.3339, 1.7998, 1.6060, 1.7998,
         0.6269],
        [1.9502, 0.6521, 1.8478, 2.7280, 1.9128, 0.3487, 2.1892, 1.8874, 2.1892,
         0.6832],
        [1.8146, 0.5751, 1.6246, 1.9128, 1.7926, 0.3150, 1.8434, 1.6244, 1.8434,
         0.6105],
        [0.3346, 0.4102, 0.3339, 0.3487, 0.3150, 0.3943, 0.3498, 0.3490, 0.3498,
         0.3224],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [1.8291, 0.6209, 1.6060, 1.8874, 1.6244, 0.3490, 1.8700, 1.6727, 1.8700,
         0.6514],
        [2.1270, 0.6555, 1.7998, 2.1892, 1.8434, 0.3498, 2.2557, 1.8700, 2.2557,
         0.6869],
        [0.6739, 0.4319, 0.6269, 0.6832, 0.6105, 0.3224, 0.6869, 0.6514, 0.6869,
         0.4693]], dtype=tor