In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric as tg
import torch_geometric.nn as tgnn
from torch_geometric.utils import get_laplacian, to_dense_adj

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

In [14]:
def all_gammas(X, adjlist, a1, a2, b):
    inter_gammas_gat = [[[],[]],[[],[]]]
    intra_gammas_gat  = [[[],[]],[[],[]]]

    pairs_intra = 0
    pairs_inter = 0

    n_intra = 0
    n_inter = 0

    N = int(len(X)/2)
    x_mlp_gat = np.zeros(len(X))
    x_gat = np.zeros(len(X))
    x_gcn = np.zeros(len(X))
    
    for i in range(len(X)):
        
        gamma_gat_head1 = gamma_gat_fn(X[i], X[adjlist[i]], a1, a2, b)
        gamma_gat_head2 = gamma_gat_fn(X[i], X[adjlist[i]], -a1, -a2, -b)
        gamma_gcn = np.ones(len(adjlist[i]))/len(adjlist[i])
        
        x_gat[i] = np.dot(gamma_gat_head1, X[adjlist[i]]) + np.dot(gamma_gat_head2, X[adjlist[i]])
        x_gcn[i] = np.dot(gamma_gcn, X[adjlist[i]])
        
        ct = 0
        for j in adjlist[i]:
            if (j < N and i < N) or (j >= N and i >= N):
                pairs_intra += pairs_mlp_gat[ct] > 0
                
                if (j < N and i < N):
                    intra_gammas_gat[0][0].append(gamma_gat_head1[ct])
                    intra_gammas_gat[1][0].append(gamma_gat_head2[ct])
                else:
                    intra_gammas_gat[0][1].append(gamma_gat_head1[ct])
                    intra_gammas_gat[1][1].append(gamma_gat_head2[ct])
                n_intra += 1
            elif (j < N and i >= N) or (j >= N and i < N):
                pairs_inter += pairs_mlp_gat[ct] <= 0
                if (j < N and i >= N):
                    inter_gammas_gat[0][0].append(gamma_gat_head1[ct])
                    inter_gammas_gat[1][0].append(gamma_gat_head2[ct])
                else:
                    inter_gammas_gat[0][1].append(gamma_gat_head1[ct])
                    inter_gammas_gat[1][1].append(gamma_gat_head2[ct])                   
                n_inter += 1
            ct += 1
            
    class_pair_intra = pairs_intra/n_intra
    class_pair_inter = pairs_inter/n_inter
            
    return x_gat, x_gcn, class_pair_intra, class_pair_inter, intra_gammas_gat, inter_gammas_gat

In [39]:
n = 1000
d = int(np.ceil(n/(np.log(n)**2)))
p = 0.5
q = 0.1
heads = 2 # does not work for another number

sizes = [int(n/2), int(n/2)]
probs = [[p,q], [q,p]]

std_ = 0.1
mu_up = 20*std_*np.sqrt(np.log(n**2))/(2*np.sqrt(d))
mu_lb = 0.01*std_/(2*np.sqrt(d))

mus = np.geomspace(mu_lb, mu_up, 30, endpoint=True)
ground_truth = np.concatenate((np.zeros(int(n/2)), np.ones(int(n/2))))

# for mu in mus:
g = nx.stochastic_block_model(sizes, probs)

mu = mus[0]
adjlist = [[v for v in g.neighbors(i)] for i in range(n)]
adj_matrix = [[0 for i in range(n)] for j in range(n)]
for i in range(n):
    nbors = g.neighbors(i)
    for j in nbors:
        adj_matrix[i][j] = 1
        
edge_idx, _ = tg.utils.dense_to_sparse(torch.from_numpy(np.array(adj_matrix)))

for i in range(len(adjlist)):
    adjlist[i].append(i)

X = np.zeros((n,d))
X[:int(n/2)] = -mu
X[int(n/2):] = mu
noise = std_*np.random.randn(n,d)
X = X + noise

R = 1
mu_ = mu*np.ones(d)
w = (R/np.linalg.norm(mu_))*mu_
Xw = X@w

print (Xw.shape)
print (edge_idx.shape)

(1000,)
torch.Size([2, 299038])


In [40]:
class Eigen(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.k = k
    
    def forward(self, edge_idx):
        lap_idx, lap_wt = get_laplacian(edge_idx, normalization="sym")
        lap_adj = to_dense_adj(lap_idx)
        eigenvals, eigenvecs = torch.linalg.eig(lap_adj)
        top_eig = eigenvecs.squeeze(0)[:, 1:self.k+1]
        top_eig = torch.real(top_eig)
        new_edge_features = torch.Tensor(edge_idx.size(1), 2 * self.k)
        new_edge_idx = edge_idx.T

        for idx, pair in enumerate(new_edge_idx):
            i, j = pair
            x_i_prime = top_eig[i]
            x_j_prime = top_eig[j]
            new_feat = torch.cat([x_i_prime, x_j_prime], dim=0)
            new_edge_features[idx] = new_feat

        return new_edge_features

class GATv3Layer(tgnn.MessagePassing):
    def __init__(self, indim, eigendim, outdim):
        super().__init__(aggr="add")
        self.original_mlp = nn.Sequential(
                nn.Linear(2 * indim, outdim), # account for extra Wx_i || Wx_j from GATv1
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )
        
        self.eigen_mlp = nn.Sequential(
                nn.Linear(eigendim, outdim), # account for the fact that edge attributes are already concatenated
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )
        self.W = nn.Linear(indim, indim)
        self.out = nn.Linear(indim, outdim)

        self.alpha = nn.Parameter(torch.rand(1, 1))
        nn.init.xavier_uniform_(self.alpha.data, gain=1.414)

        self.beta = nn.Parameter(torch.rand(1, 1))
        nn.init.xavier_uniform_(self.beta.data, gain=1.414)
        
        gamma = self._gamma
        pair_pred = self._pair_pred
        self._gamma = None
        self._pair_pred = None
        
    def forward(self, x, edge_attr, edge_idx):
        edge_idx, edge_attr = tg.utils.add_self_loops(edge_idx, edge_attr)
        return self.propagate(edge_idx, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        cat = torch.cat([x_i, x_j], dim=1)
        
        node_attr = self.alpha * self.original_mlp(cat)
        edge_attr = self.beta * self.eigen_mlp(edge_attr)
        
        self._pair_pred = F.leaky_relu(node_attr + edge_attr),
        gamma = torch.softmax(F.leaky_relu(node_attr + edge_attr), 1)
        msg = gamma * self.out(x_j)
        
        return msg

class GATv3(nn.Module):
    def __init__(self, indim, eigendim, hidden, outdim, k):
        super().__init__()

        self.eigen = Eigen(k)
        self.gat1 = GATv3Layer(indim, eigendim, hidden)
        self.gat2 = GATv3Layer(hidden, eigendim, outdim)

    def forward(self, x, edge_idx):
        with torch.no_grad():
            eigen_x = self.eigen(edge_idx)
        x = torch.relu(self.gat1(x, eigen_x, edge_idx))
        out = torch.softmax(self.gat2(x, eigen_x, edge_idx), 1)

        return out

In [42]:
gat = GATv3(1, 1, 1, 1, 1)
Xw_tensor = torch.from_numpy(Xw)
y = gat(Xw_tensor, edge_idx)
print (y.shape)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)