Calling scatter_max on more than 18 million points triggers an illegal CUDA memory access.
It works fine with about 15 million points, or when using the CPU operator.
torch 2.3.1+cu118
torch-scatter 2.1.2
Usage:
new_feat, argmax = torch_scatter.scatter_max(feat, unq_inv, dim=0)
The illegal memory access can occur in either the forward or the backward pass.
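
As a cross-check, the same reduction can be sketched with core PyTorch ops only (`Tensor.scatter_reduce`), which takes torch-scatter out of the picture. This is a minimal sketch under assumptions, not a verified fix: it assumes `feat` is 2-D with shape (N, C) and `unq_inv` is a 1-D int64 tensor of length N; the function name `scatter_max_native` and the `num_groups` argument are hypothetical; and it has not been confirmed to avoid the crash at this input size. Tie-breaking and the argmax value for empty groups may differ from torch_scatter.

```python
import torch


def scatter_max_native(feat, unq_inv, num_groups):
    """Per-group max over dim 0 using core PyTorch ops (hypothetical helper)."""
    n, c = feat.shape
    # Broadcast each row's group id across all channels.
    index = unq_inv.view(-1, 1).expand(n, c)

    # include_self=False: the zero fill is ignored for groups that receive values,
    # so only empty groups keep 0.
    out = feat.new_zeros((num_groups, c)).scatter_reduce(
        0, index, feat, reduce="amax", include_self=False
    )

    # argmax: smallest row index whose value equals its group's maximum.
    # Rows that are not the maximum contribute the sentinel n.
    rows = torch.arange(n, device=feat.device).view(-1, 1).expand(n, c)
    is_max = feat == out.gather(0, index)
    candidates = torch.where(is_max, rows, rows.new_full((n, c), n))
    argmax = rows.new_full((num_groups, c), n).scatter_reduce(
        0, index, candidates, reduce="amin", include_self=True
    )
    return out, argmax
```

With the tensors from the usage line above, this would be called as `new_feat, argmax = scatter_max_native(feat, unq_inv, int(unq_inv.max()) + 1)`. If this version runs at the failing size while torch_scatter.scatter_max does not, that would point at the torch-scatter CUDA kernel rather than at the input data.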