#  Graph attention networks (GAT) implementation
- https://www.youtube.com/watch?v=A-yKQamf2Fc [Understanding Graph Attention Networks]
- https://www.youtube.com/watch?v=CwsPoa7z2c8 [Pytorch Geometric tutorial: Graph attention networks (GAT) implementation]
- https://github.com/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial3/Tutorial3.ipynb
- https://github.com/rish-16/pytorch-graphdl/blob/main/gat/layers.py
- https://github.com/rish-16/gin-attn-conv-pytorch/blob/main/gin_attn_pytorch/gin_attn_conv.py
- https://arxiv.org/abs/1710.10903

## How to reshape the node embeddings into pairwise inputs for the attention mechanism $a$

![Image](./Graph_Attention_Network/2022-03-02_23-05-53_screenshot.png)

In [105]:
import numpy as np
import torch

In [109]:
# Build, step by step, the pairwise tensor whose entry [i, j] is the
# concatenation of embedding h[i] with embedding h[j].
#
# N nodes, each with an embed_size-dimensional embedding. Defined in this cell
# so that it runs on its own instead of relying on a later cell.
N = 3
embed_size = 4
separator = 20 * "-"

# Random integer embeddings make the row bookkeeping easy to follow by eye.
h = torch.randint(low=0, high=20, size=(N, embed_size))
print(separator)
print(h)
print(h.shape)

# Tile h side by side: row i becomes h[i] written N times in a row,
# giving shape (N, N * embed_size).
h_tiled = h.repeat(1, N)
print(separator)
print(h_tiled)
print(h_tiled.shape)

# Cut every wide row back into embed_size-long pieces, so each h[i] now
# occupies N consecutive rows: h[0], h[0], ..., h[1], h[1], ...
h_i = h_tiled.view(N * N, -1)
print(separator)
print(h_i)
print(h_i.shape)

# Stack N copies of h on top of each other: h[0], h[1], ..., h[0], h[1], ...
h_j = h.repeat(N, 1)
print(separator)
print(h_j)
print(h_j.shape)

# Row k = i * N + j of the concatenation holds [h[i] || h[j]].
pairs = torch.cat([h_i, h_j], dim=1)
print(separator)
print(pairs)
print(pairs.shape)

# Fold the flat pair list into an (N, N, 2 * embed_size) grid indexed by [i, j].
print(separator)
a_input = pairs.view(N, -1, 2 * embed_size)
print(a_input.shape)
print(a_input)

--------------------
tensor([[ 0, 11, 11, 12],
        [12, 11,  2,  2],
        [ 0,  5, 17,  7]])
torch.Size([3, 4])
--------------------
tensor([[ 0, 11, 11, 12,  0, 11, 11, 12,  0, 11, 11, 12],
        [12, 11,  2,  2, 12, 11,  2,  2, 12, 11,  2,  2],
        [ 0,  5, 17,  7,  0,  5, 17,  7,  0,  5, 17,  7]])
torch.Size([3, 12])
--------------------
tensor([[ 0, 11, 11, 12],
        [ 0, 11, 11, 12],
        [ 0, 11, 11, 12],
        [12, 11,  2,  2],
        [12, 11,  2,  2],
        [12, 11,  2,  2],
        [ 0,  5, 17,  7],
        [ 0,  5, 17,  7],
        [ 0,  5, 17,  7]])
torch.Size([9, 4])
--------------------
tensor([[ 0, 11, 11, 12],
        [12, 11,  2,  2],
        [ 0,  5, 17,  7],
        [ 0, 11, 11, 12],
        [12, 11,  2,  2],
        [ 0,  5, 17,  7],
        [ 0, 11, 11, 12],
        [12, 11,  2,  2],
        [ 0,  5, 17,  7]])
torch.Size([9, 4])
--------------------
tensor([[ 0, 11, 11, 12,  0, 11, 11, 12],
        [ 0, 11, 11, 12, 12, 11,  2,  2],
        [ 

In [180]:
# Same walkthrough with NumPy: N nodes, embed_size features per node.
N, embed_size = 3, 4
# Random integers in [0, 20) stand in for the node embeddings.
h = np.random.randint(0, 20, size=(N, embed_size))
print(h.shape)
h

(3, 4)


array([[ 1,  5,  7, 13],
       [19, 18,  5,  2],
       [ 7, 16,  0, 16]])

In [183]:
# Lay N copies of h side by side, then cut each wide row back into
# embed_size-long pieces: every h[i] ends up on N consecutive rows.
a = np.tile(h, (1, N)).reshape(N * N, -1)
print(a.shape)
a

(9, 4)


array([[ 1,  5,  7, 13],
       [ 1,  5,  7, 13],
       [ 1,  5,  7, 13],
       [19, 18,  5,  2],
       [19, 18,  5,  2],
       [19, 18,  5,  2],
       [ 7, 16,  0, 16],
       [ 7, 16,  0, 16],
       [ 7, 16,  0, 16]])

In [184]:
# Stack N copies of h on top of each other: h[0], h[1], ..., h[0], h[1], ...
b = np.tile(h, (N, 1))
print(b.shape)
b

(9, 4)


array([[ 1,  5,  7, 13],
       [19, 18,  5,  2],
       [ 7, 16,  0, 16],
       [ 1,  5,  7, 13],
       [19, 18,  5,  2],
       [ 7, 16,  0, 16],
       [ 1,  5,  7, 13],
       [19, 18,  5,  2],
       [ 7, 16,  0, 16]])

In [188]:
# Join a and b column-wise so that row i * N + j reads [h[i] || h[j]].
concat_whi_whj = np.hstack((a, b))
concat_whi_whj

array([[ 1,  5,  7, 13,  1,  5,  7, 13],
       [ 1,  5,  7, 13, 19, 18,  5,  2],
       [ 1,  5,  7, 13,  7, 16,  0, 16],
       [19, 18,  5,  2,  1,  5,  7, 13],
       [19, 18,  5,  2, 19, 18,  5,  2],
       [19, 18,  5,  2,  7, 16,  0, 16],
       [ 7, 16,  0, 16,  1,  5,  7, 13],
       [ 7, 16,  0, 16, 19, 18,  5,  2],
       [ 7, 16,  0, 16,  7, 16,  0, 16]])

In [196]:
# Fold the flat list of pairs into a grid indexed by [i, j].
np.reshape(concat_whi_whj, (N, -1, 2 * embed_size))

array([[[ 1,  5,  7, 13,  1,  5,  7, 13],
        [ 1,  5,  7, 13, 19, 18,  5,  2],
        [ 1,  5,  7, 13,  7, 16,  0, 16]],

       [[19, 18,  5,  2,  1,  5,  7, 13],
        [19, 18,  5,  2, 19, 18,  5,  2],
        [19, 18,  5,  2,  7, 16,  0, 16]],

       [[ 7, 16,  0, 16,  1,  5,  7, 13],
        [ 7, 16,  0, 16, 19, 18,  5,  2],
        [ 7, 16,  0, 16,  7, 16,  0, 16]]])