## Torch scatter

* [Documentation](https://pytorch-scatter.readthedocs.io/en/latest/)
* [GitHub](https://github.com/rusty1s/pytorch_scatter)

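Scatter and segment operations can be roughly described as reduce operations driven by a given "group-index" tensor: source elements that share the same index are combined (for example summed or maximized) into a single output slot. The `torch_scatter` package is built around these operations.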
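To make "reduce by group index" concrete, here is a minimal plain-Python sketch that sums the values sharing an index. The helper name `scatter_sum_1d` is hypothetical and is not part of `torch_scatter`; it only illustrates the idea on lists.

```python
def scatter_sum_1d(src, index, size):
    """Illustrative helper: sum the entries of src that share a group index."""
    out = [0] * size
    for value, group in zip(src, index):
        out[group] += value  # accumulate into the slot named by the index
    return out

src = [5, 1, 7, 2, 3]
index = [0, 1, 0, 2, 1]  # group id for each element of src
result = scatter_sum_1d(src, index, size=3)
print(result)  # [12, 4, 2]
```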

In [1]:
import torch
import torch.nn as nn
import numpy as np

random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed);

Start with the basic `scatter_` method in PyTorch.
It writes the values of a source tensor into a destination tensor, which can be the same size as the source or larger, at positions given by an index tensor defined for the source. It takes three arguments:

* Source tensor: the values to write.
* Index tensor: for each source element, the position along the chosen dimension in the destination where that element is written.
* Dimension along which to index: `0` selects the destination row, `1` selects the destination column.

In [3]:
x = torch.rand(2,2)
print(x)

tensor([[0.3904, 0.6009],
        [0.2566, 0.7936]])


In [4]:
torch.zeros(3, 3).scatter_(0, torch.tensor([[1, 0], [2, 1]]), x)

tensor([[0.0000, 0.6009, 0.0000],
        [0.3904, 0.7936, 0.0000],
        [0.2566, 0.0000, 0.0000]])

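`scatter_(dim, index, src) → Tensor`

`index`: for each element of `src`, the position along `dim` in the destination tensor where that element is written. For `dim=0` on a 2-D tensor the rule is `self[index[i][j]][j] = src[i][j]`.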
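The rule for `dim=0` can be sketched with nested Python lists, reusing the values printed above. The function name `scatter_dim0` is made up for this illustration; it mimics only the 2-D, `dim=0` case and is not how PyTorch implements `scatter_`.

```python
def scatter_dim0(out, index, src):
    """Illustrative helper: apply out[index[i][j]][j] = src[i][j]."""
    for i, index_row in enumerate(index):
        for j, target_row in enumerate(index_row):
            # The index picks the destination row; the column stays the same.
            out[target_row][j] = src[i][j]
    return out

src = [[0.3904, 0.6009],
       [0.2566, 0.7936]]
index = [[1, 0],
         [2, 1]]
out = scatter_dim0([[0.0] * 3 for _ in range(3)], index, src)
for row in out:
    print(row)
```

Each source value keeps its column and moves to the row named by the index, which is why the third column of the destination stays zero.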

In [5]:
from torch_scatter import scatter_max 

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)

In [6]:
print(out)

tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])


In [7]:
print(argmax)

tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
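To see where these numbers come from, the following plain-Python sketch reproduces the two printed tensors row by row. The helper `scatter_max_last_dim` is hypothetical. It assumes, based only on the output above, that empty slots are filled with `0` in `out` and with the length of the source row in `argmax`.

```python
def scatter_max_last_dim(src, index, size):
    """Illustrative helper: per-row maximum of src grouped by index, with positions."""
    out, argmax = [], []
    for src_row, index_row in zip(src, index):
        empty = len(src_row)        # marker for "no element landed here"
        out_row = [0] * size
        arg_row = [empty] * size
        for pos, (value, group) in enumerate(zip(src_row, index_row)):
            # Take the first value seen for a group, then any strictly larger one.
            if arg_row[group] == empty or value > out_row[group]:
                out_row[group] = value
                arg_row[group] = pos
        out.append(out_row)
        argmax.append(arg_row)
    return out, argmax

src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_last_dim(src, index, size=6)
print(out)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(argmax)  # [[5, 5, 3, 4, 0, 1], [1, 4, 3, 5, 5, 5]]
```

In the first row, for example, index `4` appears at positions 0 and 2 with values `2` and `1`, so slot 4 holds the maximum `2` and its `argmax` is `0`.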
