In [3]:
import torch
import torch.nn as nn
from torch_geometric.nn import aggr

In [11]:
class ScatterAdd(nn.Module):
    """Aggregate the rows of ``source`` into groups given by a fixed index tensor.

    Despite its name, the module is not limited to a sum: by default it applies
    PyTorch Geometric's ``sum``, ``mean``, ``max`` and ``std`` aggregations and
    concatenates their results along the feature dimension (``mode='cat'``).
    With the defaults, an input with ``F`` features therefore yields ``4 * F``
    output features, laid out as ``[sum | mean | max | std]``.

    Args:
        indices: Group index for each row of ``source``. It is expected to be a
            1-D integer tensor, e.g. ``torch.tensor([0, 0, 1, 2])``. It is
            registered as a buffer so that it moves with the module on
            ``.to(device)`` and is serialized together with it.
        aggrs: Names of the aggregations to apply, in output order.
    """

    def __init__(self, indices, aggrs=('sum', 'mean', 'max', 'std')):
        super().__init__()
        # A buffer (rather than a plain tensor attribute) follows the module
        # across `.to()` / `.cuda()` calls, avoiding device mismatches in
        # `forward`. Integer buffers are left alone by floating-point casts
        # such as `.half()`.
        self.register_buffer('indices', indices)
        self.aggr_func = aggr.MultiAggregation(list(aggrs), mode='cat')

    def forward(self, source):
        """Return the concatenated aggregations of ``source`` grouped by ``self.indices``."""
        return self.aggr_func(source, self.indices)


In [12]:
# Rows 0 and 1 share group 0; rows 2 and 3 each form their own group (1 and 2).
index = torch.tensor([0, 0, 1, 2])
model = ScatterAdd(index).eval()
model

ScatterAdd(
  (aggr_func): MultiAggregation([
    SumAggregation(),
    MeanAggregation(),
    MaxAggregation(),
    StdAggregation(),
  ], mode=cat)
)

In [13]:
# Feature matrix holding 4 elements with 2 features each. All ones, so the
# aggregated values can be checked by hand. The seed also makes the
# `torch.randn` input used in the scripting cell reproducible.
torch.manual_seed(12345)
x = torch.ones(4, 2)

In [14]:
agg_ones = model(x)
# With an all-ones input the result is easy to verify by hand: group 0 holds two
# rows, groups 1 and 2 hold one row each; columns are [sum | mean | max | std].
ones_expected = torch.tensor([[2., 2., 1., 1., 1., 1., 0., 0.],
                              [1., 1., 1., 1., 1., 1., 0., 0.],
                              [1., 1., 1., 1., 1., 1., 0., 0.]])
torch.testing.assert_close(agg_ones, ones_expected)
agg_ones

tensor([[2., 2., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.]])

In [15]:
# Compile the module with TorchScript and run it on a random input.
new_src = torch.randn(4, 2)
torch_script = torch.jit.script(model)
script_out = torch_script(new_src)
# Scripting should not change the numerics: compare against the eager module.
torch.testing.assert_close(script_out, model(new_src))
script_out

tensor([[ 0.2309, -3.9140,  0.1155, -1.9570,  1.4271, -1.8701,  1.3117,  0.0869],
        [-0.4560, -1.4295, -0.4560, -1.4295, -0.4560, -1.4295,  0.0000,  0.0000],
        [-0.7175,  1.3922, -0.7175,  1.3922, -0.7175,  1.3922,  0.0000,  0.0000]])

The aggregation modules cannot be exported to ONNX. For example, the following export call (opset 16) raises an error:

```python
torch.onnx.export(torch_script, new_src, 
                  "scatter_add.onnx", 
                  input_names=["source"], 
                  output_names=["x_out"],
                  export_params=True, 
                  opset_version=16)
```

However, the scripted module can be saved as a TorchScript archive, reloaded, and used for inference, as shown in the following cells.

In [16]:
# Serialize the scripted module to a TorchScript archive on disk.
torch_script.save("scatter_add.pt")

In [17]:
# Reload the TorchScript archive written by the save cell above.
new_model = torch.jit.load("scatter_add.pt", map_location=None)

In [18]:
# Run the reloaded module on the same input used for the scripted module.
new_output = new_model(source=new_src)

In [19]:
# The save/load round trip must reproduce the scripted output exactly.
assert torch.equal(script_out, new_output)