In [11]:
import sys
sys.path.append("../")

import torch
import torch_geometric as pyg
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

from models.GAT_Block import GAT_Block
from models.Transformer_Block import Transformer_Block
from torch_geometric.utils import to_dense_adj

In [2]:
# Toy graph: 3 nodes with 2 features each (the values 1..6, row-major).
node_feats = torch.arange(1, 7, dtype=torch.float32).view(3, 2)

# Directed edges in PyG layout: row 0 = source nodes, row 1 = target nodes.
edge_index = torch.tensor([[0, 1, 1, 2, 2],
                           [1, 0, 2, 0, 1]])

# adj_matrix[i, j] == 1 for every edge i -> j.
# to_dense_adj returns a batch of shape (1, N, N); index graph 0 explicitly
# rather than calling squeeze(), which would also drop the node axes when
# N == 1. Passing max_num_nodes keeps the matrix N x N even when the last
# node(s) have no edges (otherwise N is inferred from max(edge_index) + 1).
adj_matrix = to_dense_adj(edge_index, max_num_nodes=node_feats.size(0))[0]

print("Node features:\n", node_feats)
print("\nEdge index:\n", edge_index)
print("\nAdjacency matrix:\n", adj_matrix)


Node features:
 tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])

Edge index:
 tensor([[0, 1, 1, 2, 2],
        [1, 0, 2, 0, 1]])

Adjacency matrix:
 tensor([[0., 1., 0.],
        [1., 0., 1.],
        [1., 1., 0.]])


# GAT Block

In [3]:
# GAT_Block with an identity projection and an all-ones attention vector.
# The printed raw scores then equal the sum of both endpoints' features
# (e.g. edge 0 -> 1: (1 + 2) + (3 + 4) = 10), which is easy to check by hand.
layer = GAT_Block(2, 2)
layer.lin.weight.data = torch.eye(2)
layer.a.data = torch.ones(1, 4)

out, att_mtx = layer(node_feats, edge_index, return_attn_matrix=True)
print(out)
print(att_mtx)


tensor([[100., 124.],
        [100., 128.],
        [ 54.,  72.]], grad_fn=<ViewBackward0>)
tensor([[ 0., 10.,  0.],
        [10.,  0., 18.],
        [14., 18.,  0.]], grad_fn=<SqueezeBackward1>)


In [4]:
# Same block with non-trivial weights and softmax-normalised attention; the
# next cell configures PyG's GATConv identically for comparison.
layer = GAT_Block(2, 2)
layer.lin.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
layer.a.data = torch.arange(1.0, 5.0).unsqueeze(0)  # [[1., 2., 3., 4.]]

out, att_mtx = layer(
    node_feats,
    edge_index,
    use_softmax=True,
    return_attn_matrix=True,
)
print(out)
print(att_mtx)

tensor([[17., 39.],
        [17., 39.],
        [11., 25.]], grad_fn=<ViewBackward0>)
tensor([[0.0000e+00, 2.9375e-30, 0.0000e+00],
        [1.7139e-15, 0.0000e+00, 1.0000e+00],
        [1.0000e+00, 1.0000e+00, 0.0000e+00]], grad_fn=<SqueezeBackward1>)


In [5]:
from torch_geometric.nn import GATConv

# Reference implementation: PyG's GATConv configured with the same weights as
# the softmax GAT_Block above. No self-loops, matching edge_index; no bias, so
# the comparison does not rely on the bias being zero-initialised.
layer = GATConv(2, 2, heads=1, add_self_loops=False, bias=False)

# Depending on the PyG version, the shared source/target projection is
# exposed as `lin` or as `lin_src`; support both.
gat_proj = layer.lin if getattr(layer, "lin", None) is not None else layer.lin_src

with torch.no_grad():
    gat_proj.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
    # GATConv stores att_src / att_dst as (1, heads, out_channels); copy_
    # broadcasts into that shape instead of replacing the parameter with a
    # 1-D tensor as a `.data =` assignment would.
    layer.att_src.copy_(torch.tensor([1.0, 2.0]))
    layer.att_dst.copy_(torch.tensor([3.0, 4.0]))
    out_feats = layer(node_feats, edge_index, return_attention_weights=True)

# GATConv must reproduce the custom block's output and per-edge attention
# (att_mtx[src, dst] holds the weight of edge src -> dst).
gat_out, (gat_edge_index, gat_alpha) = out_feats
torch.testing.assert_close(gat_out, out.detach())
torch.testing.assert_close(
    gat_alpha.squeeze(-1),
    att_mtx.detach()[gat_edge_index[0], gat_edge_index[1]],
)

out_feats


(tensor([[17., 39.],
         [17., 39.],
         [11., 25.]]),
 (tensor([[0, 1, 1, 2, 2],
          [1, 0, 2, 0, 1]]),
  tensor([[2.9375e-30],
          [1.7139e-15],
          [1.0000e+00],
          [1.0000e+00],
          [1.0000e+00]])))

# Transformer Block

In [6]:
# Transformer_Block with lin1 zeroed and lin2-lin4 set to the identity. The
# printed output equals att_mtx2.T @ node_feats: each node sums its incoming
# neighbours' features, weighted by the raw (un-normalised) dot products.
layer2 = Transformer_Block(2, 2)
layer2.lin1.weight.data = torch.zeros(2, 2)
layer2.lin2.weight.data = torch.eye(2)
layer2.lin3.weight.data = torch.eye(2)
layer2.lin4.weight.data = torch.eye(2)

out2, att_mtx2 = layer2(node_feats, edge_index, return_attn_matrix=True)
print(out2)
print(att_mtx2)


tensor([[118., 146.],
        [206., 256.],
        [117., 156.]], grad_fn=<AddBackward0>)
tensor([[ 0., 11.,  0.],
        [11.,  0., 39.],
        [17., 39.,  0.]], grad_fn=<SqueezeBackward1>)


In [7]:
# Every projection gets the same non-trivial weight (a fresh tensor each), so
# the comparison with TransformerConv below does not depend on which lin
# plays which role.
layer2 = Transformer_Block(2, 2)
for lin_name in ("lin1", "lin2", "lin3", "lin4"):
    getattr(layer2, lin_name).weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

out2, att_mtx2 = layer2(
    node_feats, edge_index, use_softmax=True, return_attn_matrix=True
)
print(out2)
print(att_mtx2)

tensor([[22., 50.],
        [28., 64.],
        [28., 64.]], grad_fn=<AddBackward0>)
tensor([[0., 0., 0.],
        [0., 0., 1.],
        [1., 1., 0.]], grad_fn=<SqueezeBackward1>)


In [8]:
from torch_geometric.nn import TransformerConv

# Reference implementation: PyG's TransformerConv with the same weights as the
# softmax Transformer_Block above. Role mapping to Transformer_Block:
#   lin_skip <-> lin1, lin_value <-> lin2, lin_key <-> lin3, lin_query <-> lin4
# bias=False only removes the skip bias; key/query/value keep theirs, so they
# are zeroed explicitly below.
layer2 = TransformerConv(2, 2, heads=1, bias=False)

shared_weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
with torch.no_grad():
    for lin in (layer2.lin_key, layer2.lin_query, layer2.lin_value, layer2.lin_skip):
        lin.weight.copy_(shared_weight)
    for lin in (layer2.lin_key, layer2.lin_query, layer2.lin_value):
        lin.bias.zero_()
    out_feats = layer2(node_feats, edge_index, return_attention_weights=True)

# TransformerConv must reproduce the custom block's output and per-edge
# attention (att_mtx2[src, dst] holds the weight of edge src -> dst).
tc_out, (tc_edge_index, tc_alpha) = out_feats
torch.testing.assert_close(tc_out, out2.detach())
torch.testing.assert_close(
    tc_alpha.squeeze(-1),
    att_mtx2.detach()[tc_edge_index[0], tc_edge_index[1]],
)

out_feats


(tensor([[22., 50.],
         [28., 64.],
         [28., 64.]]),
 (tensor([[0, 1, 1, 2, 2],
          [1, 0, 2, 0, 1]]),
  tensor([[0.],
          [0.],
          [1.],
          [1.],
          [1.]])))

## Compute the $XX^\top X$ term directly (reference for the fully connected case below)

In [9]:
# Gram matrix of the node features: entry (i, j) is the dot product x_i . x_j.
X_Xt = torch.matmul(node_feats, node_feats.t())
print(X_Xt)

tensor([[ 5., 11., 17.],
        [11., 25., 39.],
        [17., 39., 61.]])


In [10]:
# Apply the Gram matrix to the features, giving X X^T X.
X_Xt_X = torch.matmul(X_Xt, node_feats)
print(X_Xt_X)

tensor([[123., 156.],
        [281., 356.],
        [439., 556.]])


In [12]:
# Change to a fully connected graph: every ordered node pair, self-loops included.

num_nodes = node_feats.size(0)
# cartesian_prod enumerates (src, dst) pairs in row-major order; for 3 nodes
# this gives [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]].
all_nodes = torch.arange(num_nodes)
edge_index_2 = torch.cartesian_prod(all_nodes, all_nodes).t().contiguous()

# Index graph 0 (not squeeze()) and fix the node count so the matrix is
# always num_nodes x num_nodes.
adj_matrix_2 = to_dense_adj(edge_index_2, max_num_nodes=num_nodes)[0]

print("Node features:\n", node_feats)
print("\nEdge index:\n", edge_index_2)
print("\nAdjacency matrix:\n", adj_matrix_2)

Node features:
 tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])

Edge index:
 tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])

Adjacency matrix:
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [13]:
# Same zero-lin1 / identity configuration as on the sparse graph, now on the
# complete graph: without softmax the attention matrix is X X^T and the
# output is X X^T X.
layer2 = Transformer_Block(2, 2)
layer2.lin1.weight.data = torch.zeros(2, 2)
layer2.lin2.weight.data = torch.eye(2)
layer2.lin3.weight.data = torch.eye(2)
layer2.lin4.weight.data = torch.eye(2)

out2, att_mtx2 = layer2(node_feats, edge_index_2, return_attn_matrix=True)

print(out2)
print(att_mtx2)

# Pin the claim against the direct computations X_Xt / X_Xt_X from above.
torch.testing.assert_close(att_mtx2.detach(), X_Xt)
torch.testing.assert_close(out2.detach(), X_Xt_X)

tensor([[123., 156.],
        [281., 356.],
        [439., 556.]], grad_fn=<AddBackward0>)
tensor([[ 5., 11., 17.],
        [11., 25., 39.],
        [17., 39., 61.]], grad_fn=<SqueezeBackward1>)


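On a fully connected graph (self-loops included), with no softmax, $W_1 = 0$ and $W_2 = W_3 = W_4 = I$, the Transformer block output equals $XX^\top X$ (compare `In [10]`). For general weights this suggests the layer update

$$X^{k+1} = X^k W_1^\top + \left(X^k W_4^\top\right)\left(X^k W_3^\top\right)^\top X^k W_2^\top,$$

where $W_i$ is the weight of `lin{i}`. These are the skip, value, key and query projections respectively, following the `TransformerConv` mapping noted in the comments above.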