In [2]:
import torch
import torch.nn as nn

In [3]:
x = torch.arange(20).reshape(5,4)
# method_1/method_2 below read a notebook-level tensor named `a`; their sizes
# (25 = 5*5 rows, 8 = 2*4 columns) match this 5x4 tensor, so bind it here so the
# notebook survives Restart & Run All.
a = x

### Need a matrix/process that can concatenate every row with every row (all ordered pairs)

In [69]:
def method_1(rows=None):
    """Concatenate every ordered pair of rows using an explicit double loop.

    Parameters
    ----------
    rows : torch.Tensor, optional
        2-D tensor of shape (n, d). When omitted, the notebook-level tensor
        `a` is used, which is what the original zero-argument call relied on.

    Returns
    -------
    torch.Tensor
        Float tensor of shape (n * n, 2 * d); row ``i * n + j`` holds
        ``rows[i]`` followed by ``rows[j]``.
    """
    if rows is None:
        rows = a  # fall back to the notebook global, as before
    n, d = rows.shape
    # Sizes are derived from the input instead of the former hard-coded (25, 8).
    # torch.zeros keeps the default float dtype of the original implementation.
    cat_a = torch.zeros(n * n, 2 * d)
    for i in range(n):
        for j in range(n):
            cat_a[i * n + j] = torch.cat([rows[i], rows[j]])
    return cat_a

In [67]:
def method_2(rows=None):
    """Concatenate every ordered pair of rows without a Python loop.

    Parameters
    ----------
    rows : torch.Tensor, optional
        2-D tensor of shape (n, d). When omitted, the notebook-level tensor
        `a` is used, which is what the original zero-argument call relied on.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n * n, 2 * d) with the dtype of `rows`; row
        ``i * n + j`` holds ``rows[i]`` followed by ``rows[j]``.
    """
    if rows is None:
        rows = a  # fall back to the notebook global, as before
    n = rows.shape[0]  # replaces the former hard-coded repeat count of 5
    # Left half: each row repeated n times in a row (r0, r0, ..., r1, r1, ...).
    left = rows.repeat_interleave(n, dim=0)
    # Right half: the whole matrix tiled n times (r0, r1, ..., r0, r1, ...).
    right = rows.repeat(n, 1)
    return torch.cat([left, right], dim=1)

In [71]:
%timeit method_1()

299 µs ± 2.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [72]:
%timeit method_2()

40.2 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Method 2 is roughly 7 times as fast (299 µs vs 40.2 µs)

### Data flow

##### Inputs

10 nodes, each with 20 features

projected into a 15 feature space with 2 attention mechanisms

In [4]:
# Dimensions for the data-flow walkthrough:
# nodes, input features per node, projected features, attention heads.
N, InD, OutD, K = 10, 20, 15, 2

$N \times InD$

In [5]:
# torch.Tensor(N, InD) would return uninitialised memory (arbitrary values,
# possibly NaN/inf); draw reproducible random features instead.
torch.manual_seed(0)
X = torch.randn(N, InD)

In [6]:
edges_ = [(0,1),(0,3),(0,5),(1,2),(1,7),(2,9),(3,9),(4,5),(5,8),(6,7),(7,9)]
# Store every undirected edge in both directions.
edges = []
for src, dst in edges_:
    edges.append([src, dst])
    edges.append([dst, src])
# Build the (2, E) index tensor once, after the list is complete. It used to be
# rebuilt on every loop iteration. reshape(-1, 2) keeps the result 2-D even for
# an empty edge list.
bond_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
# Sparse adjacency matrix with a 1 for every directed edge; its size is inferred
# from the largest node index.
AM = torch.sparse_coo_tensor(bond_index, torch.ones(bond_index.shape[1]))

In [7]:
E = len(edges)  # number of directed edges: each undirected pair was appended in both directions

### Attention mechanism

In [8]:
# Toy sizes for the attention-mechanism shape checks:
# nodes, input features, output features per head, number of heads.
n, d_in, d_out, num_heads = 10, 5, 4, 2

In [9]:
# torch.Tensor(n, d_in) would return uninitialised memory; use random values so
# the downstream arithmetic operates on well-defined numbers.
x = torch.randn(n, d_in)

In [10]:
x.shape

torch.Size([10, 5])

In [11]:
# One weight matrix for all heads: d_out columns per head, laid out side by side.
# torch.randn replaces the uninitialised torch.Tensor(...) constructor.
w = torch.randn(d_in, d_out*num_heads)

In [12]:
w.shape

torch.Size([5, 8])

In [13]:
# Project the node features with the shared weight matrix.
wh = x @ w

In [14]:
wh.shape

torch.Size([10, 8])

In [15]:
# Split the last dimension into (num_heads, d_out); -1 lets view() infer the leading dimension.
wh = wh.view(-1, num_heads, d_out)

In [16]:
wh.shape

torch.Size([10, 2, 4])

In [17]:
# One attention vector per head for the source and destination roles. The leading
# 1 broadcasts over nodes. torch.randn replaces the uninitialised
# torch.Tensor(...) constructor.
att_src = torch.randn(1, num_heads, d_out)
att_dst = torch.randn(1, num_heads, d_out)

In [18]:
print(wh.shape, att_src.shape)

torch.Size([10, 2, 4]) torch.Size([1, 2, 4])


In [19]:
(wh*att_src).shape

torch.Size([10, 2, 4])

In [20]:
(wh*att_src).sum(-1).shape

torch.Size([10, 2])

In [21]:
(wh*att_dst).sum(-1).shape

torch.Size([10, 2])

### Computing only the nodes that are connected 

##### Forward pass

$InD \times OutD$

In [22]:
# Projection matrix InD -> OutD. torch.randn replaces the uninitialised
# torch.Tensor(...) constructor, which could yield NaN/inf garbage.
W = torch.randn(InD, OutD)

Linear transform without attention or graph structure

$N \times OutD$

In [23]:
# Projected node features.
XW = torch.matmul(X, W)
XW.shape

torch.Size([10, 15])

#### Attention

split attention matrix into two - source and destination. 

Each $K \times OutD$ (transposed to $OutD \times K$ when applied)

In [24]:
# One attention vector of length OutD per head, for the source and destination
# roles. torch.randn replaces the uninitialised torch.Tensor(...) constructor.
att_src_w = torch.randn(K, OutD)
att_dst_w = torch.randn(K, OutD)

In [25]:
print(att_src_w.shape, att_dst_w.shape)

torch.Size([2, 15]) torch.Size([2, 15])


for each node calculate the attention coefficient components for source and destination, for each head


source: $ (N \times OutD ) \bullet (OutD \times K) \implies N \times K $


destination: $ (N \times OutD ) \bullet (OutD \times K)  \implies N \times K $



In [26]:
# Per-node, per-head attention components for the source and destination roles.
ac_src = torch.matmul(XW, att_src_w.t())
ac_dst = torch.matmul(XW, att_dst_w.t())

In [27]:
print(ac_dst.shape, ac_src.shape)

torch.Size([10, 2]) torch.Size([10, 2])


Now we need to find the correct combinations of source and destination coefficients based on the edges

The naive implementation is to build the exhaustive set of combinations, which is $O(N^2)$. In our case, that would be $10^2 = 100$ pairs

Rather, we use the edge list to select the combinations we are interested in

For source and destination separately, a new tensor is created based on the original coefficient tensors.

Each time a node shows up in the edge list, its coefficient gets added to the selected list. If the node is the source of the edge, it gets added to the selected source tensor. The same follows for the destination

This leaves us with two tensors of dimensionality:

$ E \times K $ where E is number of edges if you count each edge twice since it is bidirectional

In [28]:
# Gather one coefficient row per directed edge: by source node, then by destination node.
selected_ac_src = torch.index_select(ac_src, 0, bond_index[0])
selected_ac_dst = torch.index_select(ac_dst, 0, bond_index[1])

In [29]:
print(E, selected_ac_src.shape, selected_ac_dst.shape)

22 torch.Size([22, 2]) torch.Size([22, 2])


Now we can get the unnormalised attention coefficients by adding the two tensors and passing the sum through a LeakyReLU

$ E \times K $

In [30]:
# Unnormalised attention score per edge and head (LeakyReLU with slope 0.2).
raw_ac = nn.functional.leaky_relu(selected_ac_src + selected_ac_dst, negative_slope=0.2)

In [31]:
raw_ac.shape

torch.Size([22, 2])

###### normalise coefficients

Done via softmax. The denominator - population which we normalise against - is the neighbourhood i.e. connected nodes

Numerator:

In [32]:
# Softmax numerator: element-wise exponential of the raw scores.
exp_ac = torch.exp(raw_ac)

Denominator:

For each coefficient we need to find the right denominator

This is done by adding up all the exponentiated coefficients whose edges share the same source node

This can be done via scatter add

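We will group by the first row of the bond index (source). NOTE (review): the choice is not arbitrary in general, because the raw coefficient of edge $(i, j)$ differs from that of $(j, i)$; the later aggregation step scatters on `bond_index[1]`, so confirm that normalising over the source index is what is intended.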

Workaround Note:
"Index tensor must have the same number of dimensions as src tensor" for `scatter_add`
Therefore we repeat the index to match the source(exponentiated ACs)

exp_acs : $ E \times K $

idx : $ E \implies E \times K $

In [33]:
exp_ac.shape

torch.Size([22, 2])

In [34]:
bond_index[0].shape

torch.Size([22])

In [35]:
# Copy the source index across the head dimension so it matches exp_ac's (E, K) shape.
idx_stretched = bond_index[0][:, None].repeat(1, K)
idx_stretched.shape

torch.Size([22, 2])

In [36]:
# Accumulator for the per-node, per-head softmax denominators.
neighbourhood_sum = torch.zeros(N, K)
neighbourhood_sum.shape

torch.Size([10, 2])

Logic: find all coefficients linked to same edge source (Node) and add them together. Repeat for each attention head


result: $ N \times K $

In [37]:
# Sum the exponentiated coefficients of all edges sharing the same index in
# idx_stretched. Start from a fresh zero tensor so re-running this cell does not
# add onto the previous result (the old form accumulated into the same name).
neighbourhood_sum = torch.zeros_like(neighbourhood_sum).scatter_add(0, idx_stretched, exp_ac)

In [38]:
neighbourhood_sum.shape

torch.Size([10, 2])

But, the numerator is of dimensionality $E \times K$

We need to reshape these sums to match up with numerator.

Since a node can be in numerous edges, we need to use the sums and assign them to associated edges. Done for each attention head

Transformation: $ N \times K \implies E \times K$



In [39]:
# Broadcast each node's denominator back onto the edges that start at that node.
neighbourhood_sum_edges = torch.index_select(neighbourhood_sum, 0, bond_index[0])

In [40]:
neighbourhood_sum_edges.shape

torch.Size([22, 2])

In [41]:
# Normalised attention coefficients: numerator over the matching neighbourhood sum.
norm_acs = torch.div(exp_ac, neighbourhood_sum_edges)

In [42]:
norm_acs.shape

torch.Size([22, 2])

###### Weighting the features with attention coefficients

The attention coefficients are specified in relation to edges while the features are in terms of nodes. We have to select the features based on the edges

features: $N \times OutD \implies E \times OutD$
    
attention_coefficients: $ E \times K   $



In [43]:
print(XW.shape, norm_acs.shape)

torch.Size([10, 15]) torch.Size([22, 2])


In [50]:
# Lift the projected features from nodes to edges using the first row of the bond index.
XW_edge = torch.index_select(XW, 0, bond_index[0])

In [51]:
XW_edge.shape

torch.Size([22, 15])

We also need a copy of features for each attention head

features: $ E \times OutD \implies E \times K \times OutD $

In [52]:
# Insert a head dimension and copy the features once per head. Use the
# out-of-place unsqueeze: unsqueeze_ permanently reshaped XW_edge, so re-running
# this cell failed and the shape displayed earlier became stale.
XWK_edge = XW_edge.unsqueeze(1).repeat(1,K,1)

In [53]:
XWK_edge.shape

torch.Size([22, 2, 15])

In order to do a Hadamard product, we need the same number of dimensions

coefficients: $ E \times K \implies E \times K \times 1 $

In [54]:
# Add a trailing singleton dimension so the coefficients broadcast over the
# feature axis. Out-of-place unsqueeze leaves norm_acs at (E, K); unsqueeze_
# mutated it and broke this cell on a second run.
XWK_weighted = XWK_edge * norm_acs.unsqueeze(-1)

In [55]:
XWK_weighted.shape

torch.Size([22, 2, 15])

###### Finally we update the features through aggregation

Keep in mind that we still have a unique set of features for each head

In [56]:
# Zero-initialised accumulator for the aggregated features, one slice per head.
updated_features = torch.zeros((N, K, OutD))

In [57]:
updated_features.shape

torch.Size([10, 2, 15])

Stretch index to comply with `scatter_add`

Two ways of doing this - the second way (`expand_as`) is the way it was done in https://github.com/gordicaleksa/pytorch-GAT

In [61]:
# Keep the index tensor itself in idx_stretched; the previous trailing `.shape`
# stored a torch.Size there, which scatter_add_ cannot use as an index.
idx_stretched = bond_index[1].unsqueeze(-1).unsqueeze(-1).expand_as(XWK_weighted)
idx_stretched.shape

In [59]:
# Materialise the destination index at the full (E, K, OutD) shape by copying.
idx_stretched = bond_index[1][:, None, None].repeat(1, K, OutD)
print(idx_stretched.shape)

torch.Size([22, 2, 15])


In [252]:
# Alternative to repeat(): expand_as broadcasts the (E, 1, 1) index to the shape of
# XWK_weighted as a view instead of copying the data.
# NOTE(review): this name is spelled differently from `idx_stretched`, so it is a
# separate variable; the scatter_add_ cell below uses `idx_stretched`.
idx_streteched_ = bond_index[1].unsqueeze(-1).unsqueeze(-1).expand_as(XWK_weighted)
print(idx_streteched_.shape)

torch.Size([22, 2, 15])


In [249]:
idx_stretched.shape

torch.Size([22, 2, 1])

In [250]:
# Sum the weighted edge features into the node given by idx_stretched, per head.
# Scatter into a fresh zero tensor: the in-place scatter_add_ on the existing
# accumulator added the features again every time this cell was re-run.
updated_features = torch.zeros_like(updated_features).scatter_add_(0,idx_stretched,XWK_weighted)

#### aggregating heads

concatenate heads - can be done by leaving a heads dimension

$ N \times K \times OutD $

In [255]:
updated_features.shape

torch.Size([10, 2, 15])

concatenation by collapsing heads

$ N \times (K \times OutD) $

In [256]:
updated_features.reshape(N, K*OutD).shape

torch.Size([10, 30])

average out heads - usually done on the final layer

$ N \times OutD $

In [257]:
updated_features.mean(1).shape

torch.Size([10, 15])