In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import torch
import numpy as np
from src.graph_models.csbm import CSBM
from src.data import get_graph
from common import configure_hardware

# --- Experiment configuration ---------------------------------------------
seed = 0

# Parameters handed to get_graph() to sample the graph.
# NOTE(review): the seed is repeated here as a literal; if it is meant to
# follow `seed` above, the two have to be kept in sync by hand.
data_dict = {
    "seed": 0,
    "classes": 2,
    "n_trn_labeled": 6,
    "n_trn_unlabeled": 0,
    "n_val": 2,
    "n_test": 2,
    "sigma": 1,
    "avg_within_class_degree": 1.58 * 2,
    "avg_between_class_degree": 0.37 * 2,
    "K": 1.5,
}

device = "cpu"
dtype = torch.float64
# configure_hardware is a project helper; its return value is used below as
# the torch device for all tensors.
device_ = configure_hardware(device, seed)

# --- Sample the graph and convert it to torch tensors ---------------------
X, A, y = get_graph(data_dict, sort=True)
# Features and adjacency share dtype and device; labels keep their inferred dtype.
X, A = (torch.tensor(arr, dtype=dtype, device=device_) for arr in (X, A))
y = torch.tensor(y, device=device_)

In [15]:
# Display the adjacency tensor A created in the setup cell.
A

tensor([[0., 0., 0., 1., 0., 1., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 1., 0., 1., 0., 0.],
        [1., 1., 1., 0., 1., 1., 0., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 1., 0., 0., 1., 0., 1., 1.],
        [0., 0., 0., 1., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]], dtype=torch.float64)

In [16]:
# Display the feature tensor X created in the setup cell (its rows are
# indexed consistently with the rows/columns of A in the median computation).
X

tensor([[-0.1687,  0.7737],
        [ 0.4168, -1.2341],
        [-1.7958, -1.1536],
        [-0.4890, -2.8554],
        [-0.7491, -1.7762],
        [-1.2626, -1.0746],
        [ 0.2140,  0.9420],
        [ 1.5728,  0.4018],
        [ 1.8968, -0.1349],
        [ 0.8818,  1.4338]], dtype=torch.float64)

In [17]:
# Dimension-wise weighted median aggregation, computed entirely on the CPU.
#
# For every node i and every feature dimension d: walk over all nodes in
# ascending order of their feature value in dimension d, accumulate the edge
# weights A[i, .] along the way, and select the feature value at which the
# running sum first reaches half of node i's total edge weight. That value is
# then scaled by the total edge weight of node i.
#
# All intermediates stay on the CPU (the same device as x and index_x) and only
# the final result is moved to A.device, so index tensors never live on a
# different device than the tensors they index.
x = X.clone().cpu()
N, D = x.shape

# index_x[j, d] is the row of x holding the j-th smallest value of column d.
x_sorted, index_x = torch.sort(x, dim=0)

# (N, N, D) view whose entry [i, j, d] equals i; used as the row index into A.
matrix_index_for_each_node = torch.arange(N, dtype=torch.long)[:, None, None].expand(N, N, D)

A_cpu_dense = A.cpu()
if A.is_sparse:
    # Advanced indexing below requires a dense (strided) tensor.
    A_cpu_dense = A_cpu_dense.to_dense()

# cum_sorted_weights[i, j, d] = sum over k <= j of A[i, index_x[k, d]]
cum_sorted_weights = A_cpu_dense[matrix_index_for_each_node, index_x].cumsum(1)

# Largest running sum per (node, dimension). For non-negative weights this is
# the total edge weight of the node (its degree for a 0/1 adjacency matrix),
# repeated for each of the D dimensions.
weight_sum_per_node = cum_sorted_weights.max(1)[0]

# Number of sorted positions whose running sum is strictly below half of the
# total weight, i.e. the sorted position of the weighted median per dimension.
median_element = (cum_sorted_weights < (weight_sum_per_node / 2)[:, None].expand(N, N, D)).sum(1)

# (N, D) column indices so that each dimension reads from its own column.
matrix_reverse_index = torch.arange(D, dtype=torch.long)[None, :].expand(N, D)

# Map sorted positions back to row indices of x, then gather the feature values.
x_selected = x[
    index_x[median_element, matrix_reverse_index],
    matrix_reverse_index
]

# Scale the dimension-wise median by the node's total edge weight and only now
# move the outcome to the device A lives on.
result = (weight_sum_per_node * x_selected).to(A.device)


In [28]:
# Display the aggregation output: per node and feature dimension, the selected
# median value multiplied by the node's summed edge weight.
result

tensor([[-1.4670, -3.2238],
        [-3.7878, -3.4608],
        [-2.4450, -6.1703],
        [-1.1811, -7.5221],
        [-3.5915, -5.7107],
        [-1.9560, -4.9363],
        [-0.3375,  0.8036],
        [ 1.0701, -0.6743],
        [ 2.6455,  1.2054],
        [ 3.1457, -0.2697]], dtype=torch.float64)

In [19]:
# Show the column-wise sorted feature values together with the sort indices
# (the same sort along dim 0 that the median computation applies to its copy of X).
torch.sort(X, dim=0)

torch.return_types.sort(
values=tensor([[-1.7958, -2.8554],
        [-1.2626, -1.7762],
        [-0.7491, -1.2341],
        [-0.4890, -1.1536],
        [-0.1687, -1.0746],
        [ 0.2140, -0.1349],
        [ 0.4168,  0.4018],
        [ 0.8818,  0.7737],
        [ 1.5728,  0.9420],
        [ 1.8968,  1.4338]], dtype=torch.float64),
indices=tensor([[2, 3],
        [5, 4],
        [4, 1],
        [3, 2],
        [0, 5],
        [6, 8],
        [1, 7],
        [9, 0],
        [7, 6],
        [8, 9]]))

In [27]:
# Per node: maximum of the cumulative edge weights, one identical column per
# feature dimension (equals the node degree for a non-negative 0/1 adjacency).
print(weight_sum_per_node)

tensor([[3., 3.],
        [3., 3.],
        [5., 5.],
        [7., 7.],
        [2., 2.],
        [4., 4.],
        [2., 2.],
        [5., 5.],
        [3., 3.],
        [2., 2.]], dtype=torch.float64)


In [24]:
# Per node and feature dimension: how many sorted positions have a cumulative
# edge weight strictly below half of the node's weight sum, i.e. the position
# in the sorted order that is picked as the weighted median.
print(median_element)

tensor([[3, 4],
        [1, 3],
        [3, 2],
        [4, 4],
        [0, 0],
        [3, 2],
        [4, 6],
        [5, 5],
        [7, 6],
        [8, 5]])


In [25]:
# Per node and feature dimension: the feature value picked at the median
# position, before it is multiplied by the node's weight sum.
print(x_selected)

tensor([[-0.4890, -1.0746],
        [-1.2626, -1.1536],
        [-0.4890, -1.2341],
        [-0.1687, -1.0746],
        [-1.7958, -2.8554],
        [-0.4890, -1.2341],
        [-0.1687,  0.4018],
        [ 0.2140, -0.1349],
        [ 0.8818,  0.4018],
        [ 1.5728, -0.1349]], dtype=torch.float64)
