In [2]:
from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairOptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops

class CustomNet(MessagePassing):
    r"""The PointNet set layer from the `"PointNet: Deep Learning on Point Sets
    for 3D Classification and Segmentation"
    <https://arxiv.org/abs/1612.00593>`_ and `"PointNet++: Deep Hierarchical
    Feature Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ papers

    .. math::
        \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in
        \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j,
        \mathbf{p}_j - \mathbf{p}_i) \right),

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote neural networks, *i.e.* MLPs, and
    :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of
    each point.

    Args:
        local_nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` and
            relative spatial coordinates :obj:`pos_j - pos_i` of shape
            :obj:`[-1, in_channels + num_dimensions]` to shape
            :obj:`[-1, out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        global_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features
            of shape :obj:`[-1, out_channels]` to shape :obj:`[-1,
            final_out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          positions :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, local_nn: Optional[Callable] = None,
                 global_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'max')
        super().__init__(**kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn
        self.add_self_loops = add_self_loops

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        reset(self.local_nn)
        reset(self.global_nn)

    def forward(self, x: Union[OptTensor, PairOptTensor],
                pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:

        if not isinstance(x, tuple):
            x: PairOptTensor = (x, None)

        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
            elif isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.set_diag(edge_index)

        # propagate_type: (x: PairOptTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, size=None)

        if self.global_nn is not None:
            out = self.global_nn(out)

        return out


    def message(self, x_j: Optional[Tensor], pos_i: Tensor,
                pos_j: Tensor) -> Tensor:
        print("x_j", x_j)
        print("pos_j", pos_j)
        print("pos_i", pos_i)
        msg = pos_j - pos_i
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        print('message', msg)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        print('output of message', msg)
        return msg

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '
                f'global_nn={self.global_nn})')

In [3]:
class SAModule(torch.nn.Module):

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = CustomNet(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        # ------ fps ---------
        # this generates indices to sample from data
        # first index represents random value from pos
        # all subsequent indices represent values furthest from pos
        # ratio defines how many points to sample
        idx = fps(pos, batch, ratio=self.ratio)
        # ------ radius -------
        # finds for each element in pos[idx] all points in pos
        # within distance self.r
        # row is the pos[idx] indices
        # e.g. [0,0,1,1,2,2] - first, second, third points
        # col is the index of the nearest points to these
        # e.g. [1,0,2,1,3,0]
        # this all means that
        # pos[idx][0] is nearest to pos[1] and pos[0]
        # pos[idx][1] is nearest to pos[2] and pos[1]
        # pos[idx][2] is nearest to pos[3] and pos[0]
        #row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
        #                  max_num_neighbors=64)
        #edge_index = torch.stack([col, row], dim=0)
        row, col = radius(pos,
                            pos[idx],
                            self.r,
                            batch,
                            batch[idx],
                            max_num_neighbors=64)        
        # don't really get this as i think ends up just being same as if 
        # they hadn't split row and col in first place need to check this!
        edge_index = torch.stack([col, row], dim=0)

        x_dst = None if x is None else x[idx]
        
        print("x", x)
        print("x dst", x_dst)
        print("pos", pos)
        print("pos idx ", pos[idx])
        print("edge index", edge_index)
        
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch


In [4]:
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius, knn_interpolate

In [5]:
sa1_module = SAModule(1.0, 2.0, MLP([4, 64, 64, 3]))

In [6]:
x = torch.tensor([[10,11], [12,13], [14,15], [15,16], [17,18]])
#x = None
pos = torch.tensor([[1,0], [1,1], [.5,0], [1,1.5], [10,20]])
batch = torch.tensor([0, 0, 0, 0, 0])
x, pos, batch = sa1_module(x, pos, batch)
print('output x', x)

x tensor([[10, 11],
        [12, 13],
        [14, 15],
        [15, 16],
        [17, 18]])
x dst tensor([[14, 15],
        [17, 18],
        [15, 16],
        [10, 11],
        [12, 13]])
pos tensor([[ 1.0000,  0.0000],
        [ 1.0000,  1.0000],
        [ 0.5000,  0.0000],
        [ 1.0000,  1.5000],
        [10.0000, 20.0000]])
pos idx  tensor([[ 0.5000,  0.0000],
        [10.0000, 20.0000],
        [ 1.0000,  1.5000],
        [ 1.0000,  0.0000],
        [ 1.0000,  1.0000]])
edge index tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]])
x_j tensor([[10, 11],
        [12, 13],
        [14, 15],
        [15, 16],
        [17, 18],
        [10, 11],
        [12, 13],
        [14, 15],
        [15, 16],
        [10, 11],
        [12, 13],
        [14, 15],
        [15, 16],
        [10, 11],
        [12, 13],
        [14, 15],
        [15, 16]])
pos_j tensor([[ 1.0000,  0.0000],
        [ 1.0000,  1.0000],
        

In [66]:
x_skip = torch.tensor([[10,11], [12,13], [14,15], [15,16], [17,18], [19,20]])


In [67]:
torch.cat([x,x_skip], dim=1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 6 for tensor number 1 in the list.

In [64]:
x

tensor([[ 0.4098,  0.6650,  0.5324],
        [ 1.1502, -0.2567,  1.1518],
        [ 0.6526,  0.1453,  0.5316],
        [ 0.2675,  0.7671,  0.3871],
        [ 0.3934,  0.3409,  0.2737]], grad_fn=<ScatterReduceBackward0>)

In [65]:
x_skip

tensor([[10, 11],
        [12, 13],
        [14, 15],
        [15, 16],
        [17, 18]])

In [16]:

class SAModule(torch.nn.Module):

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = PointNetConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        # ------ fps ---------
        # this generates indices to sample from data
        # first index represents random value from pos
        # all subsequent indices represent values furthest from pos
        # ratio defines how many points to sample
        idx = fps(pos, batch, ratio=self.ratio)
        # ------ radius -------
        # finds for each element in pos[idx] all points in pos
        # within distance self.r
        # row is the pos[idx] indices
        # e.g. [0,0,1,1,2,2] - first, second, third points
        # col is the index of the nearest points to these
        # e.g. [1,0,2,1,3,0]
        # this all means that
        # pos[idx][0] is nearest to pos[1] and pos[0]
        # pos[idx][1] is nearest to pos[2] and pos[1]
        # pos[idx][2] is nearest to pos[3] and pos[0]

        # note they stack the other way round!
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                          max_num_neighbors=64)
        print('row,col', row)
        print(col)
        edge_index = torch.stack([col, row], dim=0)
        print('edge', edge_index)

        x_dst = None if x is None else x[idx]
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch


class GlobalSAModule(torch.nn.Module):

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        x = self.nn(torch.cat([x, pos], dim=1))
        x = global_max_pool(x, batch)
        pos = pos.new_zeros((x.size(0), 2))
        batch = torch.arange(x.size(0), device=batch.device)
        return x, pos, batch


class FPModule(torch.nn.Module):
    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        print('pre inter', x.shape)
        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        print('post inter', x.shape)
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.nn(x)
        return x, pos_skip, batch_skip

                                                                                                                                                                               
class PointNetClassification(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))

        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None)

    def forward(self, data):
        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)
        x, pos, batch = sa3_out

        return self.mlp(x).log_softmax(dim=-1)


class PointNetSegmentation(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(1.0, 2.0, MLP([2, 64, 64, 6]))
        self.sa2_module = SAModule(0.2, 2.0, MLP([6 + 2, 128, 128, 3]))
        self.sa3_module = GlobalSAModule(MLP([3 + 2, 256, 512, 4]))

        self.fp3_module = FPModule(1, MLP([4 + 3, 256, 5]))
        self.fp2_module = FPModule(3, MLP([5 + 6, 256, 7]))
        self.fp1_module = FPModule(3, MLP([7, 128, 128, 3]))

        self.mlp = MLP([3, 128, 128, num_classes], dropout=0.5, norm=None)

        self.lin1 = torch.nn.Linear(128, 128)
        self.lin2 = torch.nn.Linear(128, 128)
        self.lin3 = torch.nn.Linear(128, num_classes)

    def forward(self, data):
        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        #print('here', sa2_out)
        #print(sa2_out[1].shape)
        sa3_out = self.sa3_module(*sa2_out)
        
        #print(sa3_out[0].shape)
        #print(sa2_out[0].shape)
        
        fp3_out = self.fp3_module(*sa3_out, *sa2_out)
        print('--- prob ----')
        print(fp3_out[0].shape)
        #print(fp3_)
        print(sa1_out[0].shape)
        fp2_out = self.fp2_module(*fp3_out, *sa1_out)
        #print(fp2_out[0].shape)
        #print(sa0_out[0].shape)

        x, _, _ = self.fp1_module(*fp2_out, *sa0_out)

        return self.mlp(x).log_softmax(dim=-1)

In [17]:
import torch
from torch_geometric import transforms


In [18]:
#x = torch.tensor([[10,11], [12,13], [14,15], [15,16], [17,18], [14,14]])
pos = torch.tensor([[2,0], [3,1], [-.5,0], [2,1.5], [0,0], [-0.5,1]])
y = torch.tensor([1, 1, 0, 1, 0, 0])

#x = x.unsqueeze(0)
#pos = pos.unsqueeze(0)
batch = torch.tensor([0, 0, 0, 0, 0, 0])

from torch_geometric.data import Data

model = PointNetSegmentation(2)
data = Data()
#data.x = x
data.pos = pos
data.y = y
data.batch = batch
model(data)

row,col tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
tensor([0, 1, 3, 2, 4, 5, 0, 1, 3, 0, 1, 3, 2, 4, 5, 2, 4, 5])
edge tensor([[0, 1, 3, 2, 4, 5, 0, 1, 3, 0, 1, 3, 2, 4, 5, 2, 4, 5],
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]])
row,col tensor([0, 0, 0, 1, 1, 1])
tensor([0, 2, 3, 1, 4, 5])
edge tensor([[0, 2, 3, 1, 4, 5],
        [0, 0, 0, 1, 1, 1]])
pre inter torch.Size([1, 4])
post inter torch.Size([2, 4])
--- prob ----
torch.Size([2, 5])
torch.Size([6, 6])
pre inter torch.Size([2, 5])
post inter torch.Size([6, 5])
pre inter torch.Size([6, 7])
post inter torch.Size([6, 7])


tensor([[-0.6925, -0.6938],
        [-0.7026, -0.6837],
        [-0.7023, -0.6841],
        [-0.7000, -0.6863],
        [-0.6925, -0.6938],
        [-0.6385, -0.7510]], grad_fn=<LogSoftmaxBackward0>)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 2):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=2),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)
node_index = 10
explanation = explainer(data.x, data.edge_index, index=node_index)
print(explanation.edge_mask)
print(explanation.node_mask)
print(f'Generated explanations in {explanation.available_explanations}')

path = 'feature_importance.png'
explanation.visualize_feature_importance(path, top_k=10)
print(f"Feature importance plot has been saved to '{path}'")

path = 'subgraph.pdf'
explanation.visualize_graph(path)
print(f"Subgraph visualization plot has been saved to '{path}'")

In [None]:
from torch_geometric.explain import unfaithfulness

metric = unfaithfulness(explainer, explanation)
print(metric)

In [None]:
explanation.visualize_feature_importance(top_k=10)

explanation.visualize_graph()

In [9]:
transform = transforms.Compose([transforms.NormalizeScale(), transforms.RandomShear(0.05)])
print(data)

Data(pos=[6, 2], y=[6], batch=[6])


In [140]:
x = torch.tensor([[0,1,2],[4,2,3], [5,4,1]])
global_max_pool(x, batch=torch.tensor([0,0,0]))

tensor([[5, 4, 3]])

In [168]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    print(scores)
    p_attn = scores.softmax(dim=-1)
    print(p_attn)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [169]:
import math
query = torch.tensor([[5,3,4,2], [4,2,44,1], [99,2,3,1]], dtype=torch.float64)
key = torch.tensor([[10,11,2, 2], [1,27,7,1], [1,9,3,1]], dtype=torch.float64)
value = torch.tensor([[2,2,4, 4], [1,2,1,1], [1,10,3,1]], dtype=torch.float64)
a = attention(query, key, value)

tensor([[ 47.5000,  58.0000,  23.0000],
        [ 76.0000, 183.5000,  77.5000],
        [510.0000,  87.5000,  63.5000]], dtype=torch.float64)
tensor([[ 2.7536e-05,  9.9997e-01,  6.3049e-16],
        [ 2.0575e-47,  1.0000e+00,  9.2211e-47],
        [ 1.0000e+00, 3.2403e-184, 1.2232e-194]], dtype=torch.float64)


In [170]:
a[0]

tensor([[1.0000, 2.0000, 1.0001, 1.0001],
        [1.0000, 2.0000, 1.0000, 1.0000],
        [2.0000, 2.0000, 4.0000, 4.0000]], dtype=torch.float64)

In [154]:
attention

<function __main__.attention(query, key, value, mask=None, dropout=None)>

In [194]:
print(dir(model.sa1_module.conv.local_nn))

['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_p

In [203]:
print(dir(model.sa1_module.conv.local_nn.lins[0]))

['T_destination', '__annotations__', '__call__', '__class__', '__deepcopy__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_lazy_load_hook', '_load_from_state_dict', '_load_hook', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_

In [205]:
model.sa1_module.conv.local_nn.lins[0].weight.dtype

torch.float32

In [206]:
a = torch.tensor([1])

In [207]:
a.dtype

torch.int64

In [214]:
a = a.float()

In [215]:
a

tensor([1.])

In [216]:
a.dtype

torch.float32

In [217]:
print(dir(model))

['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_p

In [238]:
e = torch.empty((1))
a = torch.tensor([1,2,3])
a = torch.cat((e,a))
b= torch.tensor([4,5,6])
c = torch.cat((a,b))

In [239]:
d = torch.tensor([2,3])

In [240]:
e = torch.cat((c,d))

In [241]:
e

tensor([0., 1., 2., 3., 4., 5., 6., 2., 3.])

In [94]:
from typing import Callable, Optional, Tuple, Union

from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


class PointTransformerConv(MessagePassing):
    r"""The Point Transformer layer from the `"Point Transformer"
    <https://arxiv.org/abs/2012.09164>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i =  \sum_{j \in
        \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3
        \mathbf{x}_j + \delta_{ij} \right),

    where the attention coefficients :math:`\alpha_{i,j}` and
    positional embedding :math:`\delta_{ij}` are computed as

    .. math::
        \alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta}
        (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j +
        \delta_{i,j}) \right)

    and

    .. math::
        \delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),

    with :math:`\gamma_\mathbf{\Theta}` and :math:`h_\mathbf{\Theta}`
    denoting neural networks, *i.e.* MLPs, and
    :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of
    each point.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        pos_nn (torch.nn.Module, optional): A neural network
            :math:`h_\mathbf{\Theta}` which maps relative spatial coordinates
            :obj:`pos_j - pos_i` of shape :obj:`[-1, 3]` to shape
            :obj:`[-1, out_channels]`.
            Will default to a :class:`torch.nn.Linear` transformation if not
            further specified. (default: :obj:`None`)
        attn_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_\mathbf{\Theta}` which maps transformed
            node features of shape :obj:`[-1, out_channels]`
            to shape :obj:`[-1, out_channels]`. (default: :obj:`None`)
        add_self_loops (bool, optional) : If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          positions :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, pos_nn: Optional[Callable] = None,
                 attn_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.pos_nn = pos_nn
        if self.pos_nn is None:
            self.pos_nn = Linear(3, out_channels)

        self.attn_nn = attn_nn
        self.lin = Linear(in_channels[0], out_channels, bias=False)
        self.lin_src = Linear(in_channels[0], out_channels, bias=False)
        self.lin_dst = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        reset(self.pos_nn)
        if self.attn_nn is not None:
            reset(self.attn_nn)
        self.lin.reset_parameters()
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()


    def forward(self, x: Union[Tensor, PairTensor],
                pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:

        if isinstance(x, Tensor):
            print('x dtype', x.dtype)
            alpha = (self.lin_src(x), self.lin_dst(x))
            x: PairTensor = (self.lin(x), x)
        else:
            alpha = (self.lin_src(x[0]), self.lin_dst(x[1]))
            x = (self.lin(x[0]), x[1])

        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
            elif isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.set_diag(edge_index)
        
        print('x shape', x)
        print('alpha', alpha)
        # propagate_type: (x: PairTensor, pos: PairTensor, alpha: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha, size=None)
        return out


    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,
                alpha_i: Tensor, alpha_j: Tensor, index: Tensor,
                ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        
        print(alpha_i, 'alpha i')
        print(alpha_j, 'alpha j')
        delta = self.pos_nn(pos_i - pos_j)
        alpha = alpha_i - alpha_j + delta
        if self.attn_nn is not None:
            alpha = self.attn_nn(alpha)
        print('alpha', alpha)
        print('index', index)
        print('ptr', ptr)
        print('size i', size_i)
        alpha = softmax(alpha, index, ptr, size_i)
        print('alpha', alpha)
        print('index', index)
        print('ptr', ptr)
        print('size i', size_i)
        a =  alpha * (x_j + delta)
        print('xj', x_j)
        print('detla', delta)
        print('a', a)
        return a


In [21]:
from torch_geometric.nn import MLP, radius, fps
from torch_geometric.data import Data
import torch

In [22]:
in_channels = 3
out_channels = 2

In [31]:
x = torch.tensor([[10,11,12], [12,13,14], [14,15,15]], dtype=torch.float32)
pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False)
attn_nn = MLP([out_channels, 64, out_channels], norm=None, plain_last=False)

pos = torch.tensor([[2,0, 3], [3,1, 0], [-.5,0,2], [1,2,3]])
y = torch.tensor([1, 0, 1])
batch = torch.tensor([0, 0, 0])

#model = PointTransformerConv(in_channels, out_channels,pos_nn,attn_nn)
data = Data()
data.x = x
data.pos = pos
data.y = y
data.batch = batch
#idx = fps(pos, batch, ratio=0.2)
#r = 2.0
#row, col = radius(
#        pos, pos[idx], r, batch, batch[idx], max_num_neighbors=64)
#data.edge_index = torch.stack([col, row], dim=0)

data.edge_index = torch.tensor([[0,0,1,1,1,2,2],
                           [0,1,0,1,2,2,1]])

#model(data.x, data.pos, data.edge_index)

In [32]:
import torch
from torch_geometric.nn import knn

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, 0.0]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)

In [33]:
assign_index

tensor([[0, 0, 1, 1],
        [3, 0, 2, 3]])

In [34]:
x

tensor([[-1., -1.],
        [-1.,  1.],
        [ 1., -1.],
        [-1.,  0.]])

In [35]:
x[assign_index[0]]

tensor([[-1., -1.],
        [-1., -1.],
        [-1.,  1.],
        [-1.,  1.]])

In [36]:
from torch_geometric.utils import scatter

In [37]:
scatter(x[assign_index[0]], assign_index[1], dim=0)

tensor([[-1., -1.],
        [ 0.,  0.],
        [-1.,  1.],
        [-2.,  0.]])

In [40]:
data.pos.shape[-1]

3

In [49]:
def foo(config, **kwargs):
    print(config)
    print(kwargs['apple'])
    print(kwargs)

In [50]:
foo ({'dog'}, apple=3)

{'dog'}
3
{'apple': 3}


In [1]:
# Test sub sampple

"""This module defines custom transforms to apply to the data"""

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
import torch
import torch_geometric.transforms as T

from torch_geometric.nn import (
    radius,
)

from torch_geometric.utils import subgraph

import numpy as np


# had to change base code as basetransform not implemented yet for me 
@functional_transform('subsample')
class subsample(BaseTransform):
    r"""Samples points and features from a point cloud within a circle
    (functional name: :obj:`subsample`).

    Args:
        radius (float): The size of the circle to sample from in nm
    """
    def __init__(
        self,
        radius: float,
    ):
        self.radius = radius

    def forward(self, data: Data) -> Data:
        
        # sample 1 node id from all the nodes in data
        idx = np.random.choice(data.num_nodes, 1)
        pos = data.pos
        x = data.x
        batch = data.batch
        row, col = radius(
            pos, pos[idx], self.radius, batch, batch[idx]
        )
        data.edge_index, data.edge_attr = subgraph(col, data.edge_index, data.edge_attr)  
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.radius})'

In [31]:
n = 5

pos = torch.tensor([[0,0], [0,1], [0,2], [0,3], [0,4]])
y = torch.ones(n)
batch = torch.zeros(n)
x = torch.arange(n*2, dtype=torch.float32)
x = x.reshape(-1,2)

data = Data()
data.x = x
data.pos = pos
data.y = y
data.batch = batch

data.edge_index = torch.tensor([[0,0,1,1,1,2,2,2,3,3,3,4,4],
                                [0,1,0,1,2,2,1,3,3,2,4,4,3]])


#can then remove isolated nodes!!!!!
transform = T.Compose([subsample(2.0), T.RemoveIsolatedNodes()])
data = transform(data)

idx [3]
pos idx tensor([[0, 3]])
row tensor([0, 0, 0])
col tensor([2, 3, 4])
data Data(x=[5, 2], pos=[5, 2], y=[5], batch=[5], edge_index=[2, 13])
edge index tensor([[2, 2, 3, 3, 3, 4, 4],
        [2, 3, 3, 2, 4, 4, 3]])
pos tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4]])
x tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.],
        [8., 9.]])
