In [13]:
import torch, copy, wandb
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric as tg
import torch_geometric.nn as tgnn
from torch_geometric.utils import get_laplacian, to_dense_adj
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.datasets import Planetoid

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union
import math, wandb

In [7]:
class Eigen(nn.Module):
    """Per-edge spectral features from the symmetric-normalised graph Laplacian.

    Every node gets a ``k``-dimensional embedding made of the eigenvectors
    ranked 2nd..(k+1)-th by ascending eigenvalue (the first one is skipped).
    The feature of edge ``(i, j)`` is the concatenation of the embeddings of
    ``i`` and ``j``.

    Args:
        k: number of eigenvectors kept per node; the output width is ``2 * k``.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, edge_idx):
        """Compute the edge features for a graph.

        Args:
            edge_idx: ``[2, E]`` integer tensor of (source, target) node indices.

        Returns:
            ``[E, 2 * k]`` real tensor living on the same device as ``edge_idx``.

        Raises:
            ValueError: if the graph has fewer than ``k + 1`` nodes, so that
                ``k`` eigenvectors after the first one do not exist.
        """
        lap_idx, lap_wt = get_laplacian(edge_idx, normalization="sym")
        # The weights must be passed explicitly: without `edge_attr`,
        # to_dense_adj writes 1 for every stored entry, which yields a 0/1
        # pattern matrix instead of the Laplacian.
        lap = to_dense_adj(lap_idx, edge_attr=lap_wt).squeeze(0)  # [N, N]

        eigenvals, eigenvecs = torch.linalg.eig(lap)
        # torch.linalg.eig returns eigenpairs in no particular order, so sort
        # by eigenvalue before slicing; otherwise "skip the first, keep the
        # next k" picks arbitrary columns.
        # NOTE(review): if the graph is always undirected, torch.linalg.eigh
        # would be the cheaper, already-sorted alternative.
        order = torch.argsort(eigenvals.real)
        top_eig = eigenvecs[:, order[1:self.k + 1]].real  # [N, k]

        if top_eig.size(1) != self.k:
            raise ValueError(
                f"need at least k + 1 = {self.k + 1} nodes to extract {self.k} "
                f"eigenvectors, but the graph has {lap.size(0)}"
            )

        # Vectorised gather replaces the per-edge Python loop and keeps the
        # result on the input's device instead of a CPU scratch buffer.
        src, dst = edge_idx
        return torch.cat([top_eig[src], top_eig[dst]], dim=1)

class GATv3Layer(tgnn.MessagePassing):
    """Attention layer whose scores mix node features with per-edge features.

    For each edge the score is
    ``project(leaky_relu(alpha * original_mlp([x_i || x_j]) + beta * eigen_mlp(edge_attr)))``,
    normalised with a softmax over the edges grouped by ``index``; messages
    ``out(x_j)`` are scaled by that coefficient and summed (``aggr="add"``).

    Args:
        indim: width of the incoming node features.
        eigendim: width of the per-edge feature vector.
        outdim: width of the produced node features.
    """

    def __init__(self, indim, eigendim, outdim):
        super().__init__(aggr="add")
        self.original_mlp = nn.Sequential(
                nn.Linear(2 * indim, outdim), # account for extra Wx_i || Wx_j from GATv1
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )

        self.eigen_mlp = nn.Sequential(
                nn.Linear(eigendim, outdim), # account for the fact that edge attributes are already concatenated
                nn.Linear(outdim, outdim),
                nn.LeakyReLU(0.02),
                nn.Linear(outdim, outdim)
            )
        # NOTE(review): `W` is never used by forward()/message() in this class;
        # it is kept so parameter names and initialisation order are unchanged.
        self.W = nn.Linear(indim, indim)
        self.project = nn.Linear(outdim, 1)
        self.out = nn.Linear(indim, outdim)

        # Learnable scalars weighting the node-based and edge-based score terms.
        self.alpha = nn.Parameter(torch.rand(1, 1))
        self.glorot(self.alpha)

        self.beta = nn.Parameter(torch.rand(1, 1))
        self.glorot(self.beta)

        # Attention coefficients of the most recent message() call.
        self.all_gammas = None

    def glorot(self, value):
        """Re-initialise ``value`` in place with a Glorot-uniform distribution.

        A tensor (at least 2-D) is filled from ``U(-s, s)`` with
        ``s = sqrt(6 / (size(-2) + size(-1)))``. Any other object is walked
        through its ``parameters()`` and ``buffers()`` when it has them.
        """
        if isinstance(value, torch.Tensor):
            stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1)))
            value.data.uniform_(-stdv, stdv)
        else:
            # Recurse through the method: a bare `glorot(...)` is not a name
            # in scope here and raised NameError whenever this branch ran.
            for v in value.parameters() if hasattr(value, 'parameters') else []:
                self.glorot(v)
            for v in value.buffers() if hasattr(value, 'buffers') else []:
                self.glorot(v)

    def forward(self, x, edge_idx, edge_attr):
        """Run one round of attention-weighted message passing.

        Existing self-loops are dropped and one self-loop per node is added
        before propagation.

        Args:
            x: node features; ``x.size(0)`` is taken as the number of nodes.
            edge_idx: edge index tensor.
            edge_attr: per-edge features aligned with ``edge_idx``.

        Returns:
            The aggregated node features produced by ``propagate``.
        """
        num_nodes = x.size(0)
        edge_idx, edge_attr = tg.utils.remove_self_loops(edge_idx, edge_attr)
        # NOTE(review): the features attached to the new self-loops are whatever
        # add_self_loops fills in by default - confirm that is intended.
        edge_idx, edge_attr = tg.utils.add_self_loops(edge_idx, edge_attr, num_nodes=num_nodes)

        return self.propagate(edge_idx, x=x, edge_attr=edge_attr)

    def message(self, x_j: torch.Tensor, x_i: torch.Tensor,
                edge_attr: torch.Tensor,
                index: torch.Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> torch.Tensor:
        """Build the attention-weighted message of every edge.

        Side effect: the attention coefficients are stored in
        ``self.all_gammas`` (overwritten on every call).
        """
        cat = torch.cat([x_i, x_j], dim=1)

        node_attr = self.alpha * self.original_mlp(cat) # [E, outdim]
        edge_attr = self.beta * self.eigen_mlp(edge_attr) # [E, outdim]

        temp = F.leaky_relu(node_attr + edge_attr) # [E, outdim]
        project = self.project(temp) # [E, 1]
        gamma = tg.utils.softmax(project, index, ptr, size_i) # [E, 1], one coefficient per edge
        msg = gamma * self.out(x_j) # [E, outdim], gamma broadcasts over the feature axis

        self.all_gammas = gamma

        return msg

In [8]:
class GATv3(nn.Module):
    """Two stacked GATv3Layer blocks sharing one set of spectral edge features.

    Args:
        indim: width of the input node features.
        eigendim: width of the edge features fed to both layers. Eigen(k)
            emits ``2 * k`` values per edge, so this is expected to be ``2 * k``.
        hidden: width between the two layers.
        outdim: width of the final node outputs.
        k: number of Laplacian eigenvectors used per node.
    """

    def __init__(self, indim, eigendim, hidden, outdim, k):
        super().__init__()

        self.eigen = Eigen(k)
        self.gat1 = GATv3Layer(indim, eigendim, hidden)
        self.gat2 = GATv3Layer(hidden, eigendim, outdim)

    def forward(self, x, edge_idx):
        """Return the node outputs of the second layer.

        Args:
            x: node feature matrix.
            edge_idx: edge index tensor of the graph.
        """
        # The edge features are a fixed function of the graph: no gradients.
        with torch.no_grad():
            eigen_x = self.eigen(edge_idx)
        # Eigen may hand back a CPU / default-dtype tensor; align it with the
        # node features so the layers also work on CUDA or non-default dtypes.
        eigen_x = eigen_x.to(device=x.device, dtype=x.dtype)

        x = torch.relu(self.gat1(x, edge_idx, eigen_x))
        out = self.gat2(x, edge_idx, eigen_x)

        return out

In [9]:
# Use the first CUDA GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [10]:
# One Planetoid dataset per name, each rooted at data/<name>/.
datasets = [
    Planetoid(root=f'data/{name}/', name=name)
    for name in ('CiteSeer', 'Cora', 'PubMed')
]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/dat

In [None]:
class Model_GAT(nn.Module):
    """Wrapper that applies a two-layer GATv3 to a graph ``data`` object.

    Args:
        d: width of the input node features.
        out_d: width of the per-node output.
        K: number of Laplacian eigenvectors per node; the edge features
            therefore have width ``2 * K``.
        hidden: width between the two GATv3 layers. New optional argument;
            the default of 64 is a placeholder choice, tune as needed.
    """

    def __init__(self, d, out_d, K, hidden=64):
        super(Model_GAT, self).__init__()
        # GATv3 takes (indim, eigendim, hidden, outdim, k); the previous
        # two-argument call raised TypeError.
        self.conv1 = GATv3(d, 2 * K, hidden, out_d, K)

    def forward(self, data):
        """Return ``(node_output, attn_weights, pair_pred)``.

        ``attn_weights`` are the coefficients recorded by the second GATv3
        layer during this pass. ``pair_pred`` is always ``None``.
        """
        # GATv3.forward returns only the node outputs, so it cannot be
        # unpacked into two values.
        x = self.conv1(data.x, data.edge_index)
        # Each layer stores its latest attention coefficients in `all_gammas`.
        attn_weights = self.conv1.gat2.all_gammas
        # FIXME(review): `pair_pred` was returned without ever being defined
        # (NameError). None keeps the 3-tuple shape until it is implemented.
        pair_pred = None

        return x.squeeze(-1), attn_weights, pair_pred

In [None]:
def create_ansatz_gatv3(model, mean, R):
    # Work on a deep copy so the caller's `model` is not modified.
    model_mlp_gat_ansatz = copy.deepcopy(model)
    # Rescale `mean` so that its norm equals R.
    # NOTE(review): `w` is not used anywhere in the visible body.
    w = (R / torch.norm(mean)) * mean
    
    with torch.no_grad():
        # FIXME(review): unfinished statement - the line stops after the
        # attribute dot, which is a SyntaxError as written.
        model_mlp_gat_ansatz.gat1.
        
    # NOTE(review): bare expression, not a `return`; as written the function
    # would return None.
    model_mlp_gat_ansatz