[Pytorch Geometric tutorial: Graph attention networks (GAT) implementation](https://www.youtube.com/watch?v=CwsPoa7z2c8&ab_channel=AntonioLonga)

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Structure

```
class GATLayer(nn.Module):
    """
    Skeleton of a graph attention (GAT) layer written in PyTorch.

    Only the overall structure is sketched: the constructor registers no
    parameters and ``forward`` is a placeholder.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, adj):
        """Placeholder forward pass: ignores ``input`` and ``adj`` and prints an empty line."""
        print("")
```


Let's start with the forward method. Its first step is a linear transformation of every node's feature vector:
$$
\bar{h'}_i = \textbf{W}\cdot \bar{h}_i
$$
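with $\textbf{W}\in\mathbb R^{F'\times F}$ and $\bar{h}_i\in\mathbb R^{F}$. In the code below the node features are stored as the rows of `input`, so the same map is computed as the matrix product of `input` and `W`, where `W` has shape $F\times F'$ (the transpose of the matrix in the formula).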




torch.Size([3, 2])


Linear Transformation 을 통해서  $\bar{h'}_i$ 만들기



In [6]:


# Toy dimensions: 3 nodes, each mapped from 5 input features to 2 output features.
in_features, out_features, nb_nodes = 5, 2, 3

# Weight matrix of shape (in_features, out_features): allocated as zeros,
# then overwritten in place by Xavier-uniform initialisation
# (gain 1.414, i.e. approximately sqrt(2)).
W = nn.Parameter(torch.zeros(in_features, out_features))
nn.init.xavier_uniform_(W.data, gain=1.414)

# Random node features, one row per node.
input = torch.rand(nb_nodes, in_features)

# Linear transformation: every row of `input` is projected by W.
h = input @ W
N = h.shape[0]  # number of nodes

print(h.shape)

torch.Size([3, 2])


In [4]:
# Column vector with 2 * out_features rows (presumably the attention weights
# applied to concatenated node pairs -- confirm against the cells that use it).
# Allocated as zeros, then overwritten in place by Xavier-uniform initialisation.
a = nn.Parameter(torch.zeros(2 * out_features, 1))
nn.init.xavier_uniform_(a.data, gain=1.414)
print(a.shape)

# LeakyReLU activation with a negative slope of 0.2.
leakyrelu = nn.LeakyReLU(negative_slope=0.2)

torch.Size([4, 1])


In [5]:
# Build every ordered pair (h[i], h[j]) of transformed node features.
# Left operand: each row of h repeated N times in a row  -> h[0], h[0], ..., h[1], h[1], ...
h_repeated_rows = h.repeat(1, N).view(N * N, -1)
# Right operand: the whole of h tiled N times             -> h[0], h[1], ..., h[0], h[1], ...
h_tiled = h.repeat(N, 1)
# Concatenate along the feature axis, then fold the N * N pairs into an
# (N, N, 2 * out_features) grid so that a_input[i, j] holds h[i] followed by h[j].
a_input = torch.cat([h_repeated_rows, h_tiled], dim=1).view(N, -1, 2 * out_features)

In [7]:
# First operand of the concatenation used to build a_input: each row of h
# appears N times consecutively (h[0], h[0], h[0], h[1], h[1], h[1], ...).
h.repeat(1, N).view(N * N, -1)

tensor([[ 0.6070, -2.3404],
        [ 0.6070, -2.3404],
        [ 0.6070, -2.3404],
        [ 0.5616, -1.8638],
        [ 0.5616, -1.8638],
        [ 0.5616, -1.8638],
        [-0.2128, -1.6883],
        [-0.2128, -1.6883],
        [-0.2128, -1.6883]], grad_fn=<ViewBackward>)

In [9]:
# Second operand of the concatenation used to build a_input: the whole of h
# stacked N times (h[0], h[1], h[2], h[0], h[1], h[2], ...).
h.repeat(N, 1)

tensor([[ 0.6070, -2.3404],
        [ 0.5616, -1.8638],
        [-0.2128, -1.6883],
        [ 0.6070, -2.3404],
        [ 0.5616, -1.8638],
        [-0.2128, -1.6883],
        [ 0.6070, -2.3404],
        [ 0.5616, -1.8638],
        [-0.2128, -1.6883]], grad_fn=<RepeatBackward>)

In [11]:
# Expected shape is (N, N, 2 * out_features): one concatenated feature pair
# per ordered pair of nodes.
a_input.shape

torch.Size([3, 3, 4])