In [4]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

import numpy as np
import torch
import torch.nn as nn

np.random.seed(42)
torch.manual_seed(42)    

softmax_0d = nn.Softmax(dim = 0)
softmax_1d = nn.Softmax(dim = 1)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
gamma = [[4,5,6,7]]
# gamma = np.array(gamma).astype('float32')
gamma = torch.tensor(gamma).to(torch.float32)

v = [[0.8, 0.2, 0.5, 0.1]]
v = torch.tensor(v).to(torch.float32)

# in all cases, except VA, you want your V to be in dim 2x2. This assumed only one data point goes in
v = v.reshape(2,2)

## Grouped Vector Attention

As written in the paper:

![GVA](images/GVA_paper.png)


## How it works out

Won't touch on Vector Attention, since that was so [2020](https://arxiv.org/abs/2012.09164). Below is how I think the rest of the algorithms work. 

Note the Linear Layer in GVA-Linear weights are given b:

```
[[4.0, 4.0, 2.0, 1.0],
 [5.0, 6.0, 1.0, 2.0]]
```

Also do note that the the numbers in the middle layer were not softmax'd. Don't worry though. The numbers do check out in the workings below. 

![GVA writings](images/GVA_written_out.png)

### GVA - Linear 

Easy implementation. Don't need to tinker with the weights to reshape. Just need to reshape out before hadamard with v. 

In [6]:
weights = nn.Linear(4,2, bias = False).to(torch.float32)
weights.weight = torch.nn.Parameter(torch.tensor([[4.0,4.0,2.0,1.0],[5.0, 6.0, 1.0, 2.0]])).to(torch.float32)

In [7]:
out = weights(gamma)
out = out.reshape(2,1)
out
# softmax here before do hadamard with the v
out = softmax_0d(out)

In [8]:
# hadamard product
out*v

tensor([[2.4472e-07, 6.1180e-08],
        [5.0000e-01, 1.0000e-01]], grad_fn=<MulBackward0>)

# GVA - MSA

The "easiest" because all you do is reshape gamma and sum them 

In [9]:
softmax_0d = nn.Softmax(dim = 0)
out_msa = gamma.reshape(2,2)
out_msa = torch.sum(out_msa, dim = 1)
out_msa = out_msa.reshape(2,1)
out_msa *= 1 / np.sqrt(2) # we have 4 channel features, 2 groups. Hence 2
out_msa = softmax_0d(out_msa)
out_msa

tensor([[0.0558],
        [0.9442]])

In [10]:
out_msa * v

tensor([[0.0446, 0.0112],
        [0.4721, 0.0944]])

### GVA - Grouped Linear

Solution - use 1D convolution, groups = 2 or how many you want to specify. They don't appear to share weights so it should work out right. 

In [11]:
conv_weights = nn.Conv1d(4,2, 1, groups = 2, bias = False)
conv_weights.weight = torch.nn.Parameter(torch.tensor([[[4.0],[4.0]],[[5.0],[6.0]]])).to(torch.float32)

In [12]:
out_gl = conv_weights(gamma.unsqueeze(0).permute(0,2,1))
out_gl = out_gl.squeeze(-1).transpose(1,0)
out_gl

tensor([[36.],
        [72.]], grad_fn=<TransposeBackward0>)

In [13]:
out_gl = softmax_0d(out_gl)
out_gl * v

tensor([[1.8556e-16, 4.6390e-17],
        [5.0000e-01, 1.0000e-01]], grad_fn=<MulBackward0>)

In [50]:
# simulate pointclouds
point_clouds = torch.randn(1, 1, 4)
out_gl = conv_weights(point_clouds.permute(0,2,1))
out_gl = softmax_1d(out_gl)

In [15]:
out_gl = out_gl.permute(0,2,1)
out_gl * v

tensor([[[0.7315, 0.0171],
         [0.4572, 0.0086]]], grad_fn=<MulBackward0>)

In [33]:
from src.model.PTv2.ptv2_utils import GroupVectorAttention, PositionalEncoding

In [68]:
gva = GroupVectorAttention(4, 2, 2)
points = torch.randn(1, 16, 7)
neighbours = torch.randn(1, 16, 14, 7)
out = gva(points[..., :3], points[...,3:], neighbours[..., :3], neighbours[...,3:])
print(out.shape)

torch.Size([1, 16, 14, 4])
torch.Size([1, 16, 14, 4])
