In [90]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric as tg
import torch_geometric.nn as tgnn
from torch_geometric.utils import get_laplacian, to_dense_adj
from torch_geometric.typing import Adj, OptTensor, PairTensor

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union
import math, wandb

In [91]:
class Eigen(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.k = k
    
    def forward(self, edge_idx):
        lap_idx, lap_wt = get_laplacian(edge_idx, normalization="sym")
        lap_adj = to_dense_adj(lap_idx)
        eigenvals, eigenvecs = torch.linalg.eig(lap_adj)
        top_eig = eigenvecs.squeeze(0)[:, 1:self.k+1]
        top_eig = torch.real(top_eig)
        new_edge_features = torch.Tensor(edge_idx.size(1), 2 * self.k)
        new_edge_idx = edge_idx.T

        for idx, pair in enumerate(new_edge_idx):
            i, j = pair
            x_i_prime = top_eig[i]
            x_j_prime = top_eig[j]
            new_feat = torch.cat([x_i_prime, x_j_prime], dim=0)
            new_edge_features[idx] = new_feat

        return new_edge_features

class GATv3Layer(tgnn.MessagePassing):
    def __init__(self, indim, eigendim, outdim):
        super().__init__(aggr="add")
        self.original_mlp = nn.Sequential(
                nn.Linear(2 * indim, outdim), # account for extra Wx_i || Wx_j from GATv1
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )
        
        self.eigen_mlp = nn.Sequential(
                nn.Linear(eigendim, outdim), # account for the fact that edge attributes are already concatenated
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )
        self.W = nn.Linear(indim, indim)
        self.project = nn.Linear(outdim, 1)
        self.out = nn.Linear(indim, outdim)

        self.alpha = nn.Parameter(torch.rand(1, 1))
        self.glorot(self.alpha)

        self.beta = nn.Parameter(torch.rand(1, 1))
        self.glorot(self.beta)
        
        self.all_gammas = None
        
    def glorot(self, value):
        if isinstance(value, torch.Tensor):
            stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1)))
            value.data.uniform_(-stdv, stdv)
        else:
            for v in value.parameters() if hasattr(value, 'parameters') else []:
                glorot(v)
            for v in value.buffers() if hasattr(value, 'buffers') else []:
                glorot(v)
        
    def forward(self, x, edge_idx, edge_attr):
        num_nodes = x.size(0)
        edge_idx, edge_attr = tg.utils.remove_self_loops(edge_idx, edge_attr)
        edge_idx, edge_attr = tg.utils.add_self_loops(edge_idx, edge_attr, num_nodes=num_nodes)
        
        return self.propagate(edge_idx, x=x, edge_attr=edge_attr)

    def message(self, x_j: torch.Tensor, x_i: torch.Tensor,
                edge_attr: torch.Tensor,
                index: torch.Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> torch.Tensor:
    
        cat = torch.cat([x_i, x_j], dim=1)
        
        node_attr = self.alpha * self.original_mlp(cat) # [E, d]
        edge_attr = self.beta * self.eigen_mlp(edge_attr) # [E, d]
        
        temp = F.leaky_relu(node_attr + edge_attr) # [E, d]
        project = self.project(temp)
        gamma = tg.utils.softmax(project, index, ptr, size_i) # [E, d]
        msg = gamma * self.out(x_j) # [E, d]
        
        self.all_gammas = gamma
        
        return msg

In [92]:
class GATv3(nn.Module):
    def __init__(self, indim, eigendim, hidden, outdim, k):
        super().__init__()

        self.eigen = Eigen(k)
        self.gat1 = GATv3Layer(indim, eigendim, hidden)
        self.gat2 = GATv3Layer(hidden, eigendim, outdim)

    def forward(self, x, edge_idx):
        with torch.no_grad():
            eigen_x = self.eigen(edge_idx)
        x = torch.relu(self.gat1(x, edge_idx, eigen_x))
        out = self.gat2(x, edge_idx, eigen_x)

        return out

In [93]:
def get_gammas(Xw, ground_truth, gat_layer, edge_idx):
    all_gammas = gat_layer.all_gammas
    gamma_matrix = [[0 for j in range(Xw.size(0))] for i in range(Xw.size(0))]
    for idx, pair in enumerate(edge_idx.T):
        i, j = pair
        gamma = all_gammas[idx]
        
        gamma_matrix[i][j] = gamma.item()
        
    return gamma_matrix
        
def get_intra_inter_avg_gamma(gamma_matrix):
    """
    intra-edges are nodes with class 0
    inter-edges are nodes with class 1
    """
    
    d = len(gamma_matrix) // 2
    all_node_ids = list(range(len(gamma_matrix)))
    intra_edges = all_node_ids[:d]
    inter_edges = all_node_ids[d:]
    
    intra_edge_gammas = []
    inter_edge_gammas = []
    
    for i in range(len(gamma_matrix)):
        for j in range(len(gamma_matrix[i])):
            if j in intra_edges:
                intra_edge_gammas.append(gamma_matrix[i][j])
            elif j in inter_edges:
                inter_edge_gammas.append(gamma_matrix[i][j])
            else:
                pass
                
    return np.array(intra_edge_gammas), np.array(inter_edge_gammas)

In [96]:
n = 400
d = int(np.ceil(n/(np.log(n)**2)))
p = 0.5
q = 0.1

sizes = [int(n/2), int(n/2)]
probs = [[p,q], [q,p]]

std_ = 0.1
mu_up = 20*std_*np.sqrt(np.log(n**2))/(2*np.sqrt(d))
mu_lb = 0.01*std_/(2*np.sqrt(d))

mus = np.geomspace(mu_lb, mu_up, 30, endpoint=True)
ground_truth = np.concatenate((np.zeros(int(n/2)), np.ones(int(n/2))))

for mu in mus:
    g = nx.stochastic_block_model(sizes, probs)
    adjlist = [[v for v in g.neighbors(i)] for i in range(n)]

    adj_matrix = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        nbors = g.neighbors(i)
        for j in nbors:
            adj_matrix[i][j] = 1

    edge_idx, _ = tg.utils.dense_to_sparse(torch.from_numpy(np.array(adj_matrix)))

    for i in range(len(adjlist)):
        adjlist[i].append(i) # self-loops

    X = np.zeros((n,d))
    X[:int(n/2)] = -mu
    X[int(n/2):] = mu
    noise = std_*np.random.randn(n,d)
    X = X + noise

    R = 1
    mu_ = mu*np.ones(d)
    w = (R/np.linalg.norm(mu_))*mu_
    Xw = X@w

    wandb.init(project="GATv3", entity="rish-16")

    HIDDEN = 16
    eigenK = 10
    EPOCHS = 1250

    gat = GATv3(
        indim=1, 
        eigendim=20, 
        hidden=HIDDEN, 
        outdim=1, 
        k=eigenK
    ) # take top 10 eigen vector features
    crit = nn.BCEWithLogitsLoss()
    optimiser = torch.optim.Adam(gat.parameters())

    wandb.config = {
        "learning_rate": 0.001,
        "n_nodes": n,
        "eigenK": eigenK,
        "optimiser": "Adam",
        "epochs": EPOCHS,
        "n_layers": 2,
        "std": std_,
        "p": p,
        "q": q,
        "d": d
    }

    wandb.watch(gat)

    Xw_tensor = torch.from_numpy(Xw).unsqueeze(-1).float()
    ground_truth_tensor = torch.from_numpy(ground_truth).unsqueeze(-1).float()
    
    print (f"Training with Mu: {mu}")

    for epoch in range(EPOCHS):
        optimiser.step()
        pred = gat(Xw_tensor, edge_idx)
        loss = crit(pred, ground_truth_tensor)
        loss.backward()
        optimiser.step()

        wandb.log({
            "epoch": epoch,
            "train_bce": loss.item()
        })    

        if epoch % 200 == 0:
            print (f"Epoch: {epoch} | Train BCE: {loss.item()}")
            
    print ("------------------------------------------\n\n")

    gamma_matrix1 = get_gammas(Xw_tensor, ground_truth_tensor, gat.gat1, edge_idx)
    gamma_matrix2 = get_gammas(Xw_tensor, ground_truth_tensor, gat.gat2, edge_idx)

    intra1, inter1 = get_intra_inter_avg_gamma(gamma_matrix1)
    intra2, inter2 = get_intra_inter_avg_gamma(gamma_matrix2)

    avg_intra_gamma_1 = intra1.mean()
    avg_inter_gamma_1 = inter1.mean()

    avg_intra_gamma_2 = intra2.mean()
    avg_inter_gamma_2 = inter2.mean()
    
    std_intra_gamma_1 = intra1.std()
    std_inter_gamma_1 = inter1.std()

    std_intra_gamma_2 = intra2.std()
    std_inter_gamma_2 = inter2.std()

    wandb.log({
        "mu": mu,
        "avg_intra_edge_gamma_1": avg_intra_gamma_1,
        "avg_inter_edge_gamma_1": avg_inter_gamma_1,
        "avg_intra_edge_gamma_2": avg_intra_gamma_2,
        "avg_inter_edge_gamma_2": avg_inter_gamma_2,
        "std_intra_edge_gamma_1": std_intra_gamma_1,
        "std_inter_edge_gamma_1": std_inter_gamma_1,
        "std_intra_edge_gamma_2": std_intra_gamma_2,
        "std_inter_edge_gamma_2": std_inter_gamma_2,
    })

    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_bce,▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▂▂▂▂▃▃▄▅▆▆▅▄▃▂▁▁▁▂▃▄▅▆██

0,1
epoch,61.0
train_bce,0.69424


Training with Mu: 0.00014433756729740645
Epoch: 0 | Train BCE: 0.7008239030838013
Epoch: 200 | Train BCE: 0.5811325311660767
Epoch: 400 | Train BCE: 0.4728632867336273
Epoch: 600 | Train BCE: 0.3233521580696106
Epoch: 800 | Train BCE: 0.12481539696455002
Epoch: 1000 | Train BCE: 0.04695728421211243
Epoch: 1200 | Train BCE: 0.011814644560217857
------------------------------------------




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_inter_edge_gamma_1,▁
avg_inter_edge_gamma_2,▁
avg_intra_edge_gamma_1,▁
avg_intra_edge_gamma_2,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
mu,▁
std_inter_edge_gamma_1,▁
std_inter_edge_gamma_2,▁
std_intra_edge_gamma_1,▁
std_intra_edge_gamma_2,▁

0,1
avg_inter_edge_gamma_1,0.0025
avg_inter_edge_gamma_2,0.00234
avg_intra_edge_gamma_1,0.0025
avg_intra_edge_gamma_2,0.00234
epoch,1249.0
mu,0.00014
std_inter_edge_gamma_1,0.04994
std_inter_edge_gamma_2,0.00358
std_intra_edge_gamma_1,0.04994
std_intra_edge_gamma_2,0.00358


Training with Mu: 0.00019579604521621774
Epoch: 0 | Train BCE: 0.693331778049469
Epoch: 200 | Train BCE: 0.6030031442642212
Epoch: 400 | Train BCE: 0.42886388301849365
Epoch: 600 | Train BCE: 0.4737018346786499
Epoch: 800 | Train BCE: 0.11750677227973938
Epoch: 1000 | Train BCE: 0.29766231775283813
Epoch: 1200 | Train BCE: 0.09280756115913391
------------------------------------------




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_inter_edge_gamma_1,▁
avg_inter_edge_gamma_2,▁
avg_intra_edge_gamma_1,▁
avg_intra_edge_gamma_2,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
mu,▁
std_inter_edge_gamma_1,▁
std_inter_edge_gamma_2,▁
std_intra_edge_gamma_1,▁
std_intra_edge_gamma_2,▁

0,1
avg_inter_edge_gamma_1,0.0025
avg_inter_edge_gamma_2,0.0
avg_intra_edge_gamma_1,0.0025
avg_intra_edge_gamma_2,0.0
epoch,1249.0
mu,0.0002
std_inter_edge_gamma_1,0.04994
std_inter_edge_gamma_2,0.0
std_intra_edge_gamma_1,0.04994
std_intra_edge_gamma_2,0.0


Training with Mu: 0.00026560023173537287
Epoch: 0 | Train BCE: 0.7005781531333923
Epoch: 200 | Train BCE: 0.6709121465682983
Epoch: 400 | Train BCE: 0.39175766706466675
Epoch: 600 | Train BCE: 0.23879766464233398
Epoch: 800 | Train BCE: 0.05040017515420914
Epoch: 1000 | Train BCE: 0.030592476949095726
Epoch: 1200 | Train BCE: 0.02673676237463951
------------------------------------------




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_inter_edge_gamma_1,▁
avg_inter_edge_gamma_2,▁
avg_intra_edge_gamma_1,▁
avg_intra_edge_gamma_2,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
mu,▁
std_inter_edge_gamma_1,▁
std_inter_edge_gamma_2,▁
std_intra_edge_gamma_1,▁
std_intra_edge_gamma_2,▁

0,1
avg_inter_edge_gamma_1,0.0025
avg_inter_edge_gamma_2,0.00248
avg_intra_edge_gamma_1,0.0025
avg_intra_edge_gamma_2,0.00248
epoch,1249.0
mu,0.00027
std_inter_edge_gamma_1,0.04994
std_inter_edge_gamma_2,0.00381
std_intra_edge_gamma_1,0.04994
std_intra_edge_gamma_2,0.0038


Training with Mu: 0.00036029064335790133
Epoch: 0 | Train BCE: 0.7025562524795532
Epoch: 200 | Train BCE: 0.639792263507843
Epoch: 400 | Train BCE: 0.47997409105300903
Epoch: 600 | Train BCE: 0.11061519384384155
Epoch: 800 | Train BCE: 0.393920361995697
Epoch: 1000 | Train BCE: 0.17294567823410034
Epoch: 1200 | Train BCE: 0.289513498544693
------------------------------------------




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_inter_edge_gamma_1,▁
avg_inter_edge_gamma_2,▁
avg_intra_edge_gamma_1,▁
avg_intra_edge_gamma_2,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
mu,▁
std_inter_edge_gamma_1,▁
std_inter_edge_gamma_2,▁
std_intra_edge_gamma_1,▁
std_intra_edge_gamma_2,▁

0,1
avg_inter_edge_gamma_1,0.0025
avg_inter_edge_gamma_2,0.0
avg_intra_edge_gamma_1,0.0025
avg_intra_edge_gamma_2,0.0
epoch,1249.0
mu,0.00036
std_inter_edge_gamma_1,0.04887
std_inter_edge_gamma_2,0.0
std_intra_edge_gamma_1,0.04994
std_intra_edge_gamma_2,0.0


Training with Mu: 0.0004887395874736438
Epoch: 0 | Train BCE: 0.6932708024978638
Epoch: 200 | Train BCE: 0.6949211359024048
Epoch: 400 | Train BCE: 0.46762844920158386
Epoch: 600 | Train BCE: 0.5078461170196533
Epoch: 800 | Train BCE: 0.46960216760635376
Epoch: 1000 | Train BCE: 0.28918030858039856
Epoch: 1200 | Train BCE: 0.2831963896751404
------------------------------------------




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_inter_edge_gamma_1,▁
avg_inter_edge_gamma_2,▁
avg_intra_edge_gamma_1,▁
avg_intra_edge_gamma_2,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
mu,▁
std_inter_edge_gamma_1,▁
std_inter_edge_gamma_2,▁
std_intra_edge_gamma_1,▁
std_intra_edge_gamma_2,▁

0,1
avg_inter_edge_gamma_1,0.0025
avg_inter_edge_gamma_2,0.0
avg_intra_edge_gamma_1,0.0025
avg_intra_edge_gamma_2,0.0
epoch,1249.0
mu,0.00049
std_inter_edge_gamma_1,0.04994
std_inter_edge_gamma_2,0.0
std_intra_edge_gamma_1,0.04994
std_intra_edge_gamma_2,0.0


Training with Mu: 0.0006629824803044508
Epoch: 0 | Train BCE: 0.7291097044944763


KeyboardInterrupt: 

$$
\begin{align}
    e &= \operatorname{EigenDecomp}(I - D^{-1\frac{1}{2}}AD^{-1\frac{1}{2}}) \\\\
    \gamma_{ij} &= \operatorname{Softmax}(\operatorname{LeakyReLU}(\alpha \cdot \phi (Wx_i || Wx_j) + \beta \cdot \psi (We_i || We_j))) \\\\
    x^{ l+1}_{i} &= \sigma(Wx_i + \Sigma_{j \in N_j} W \gamma_{ij} x_j)
\end{align}
$$