In [1]:
import torch.nn as nn
from torch_geometric.datasets import GeometricShapes

dataset = GeometricShapes(root='data/GeometricShapes')
print(dataset)



GeometricShapes(40)


In [2]:
import torch
from torch_geometric.transforms import SamplePoints

torch.manual_seed(42)

nr_points=1024

dataset.transform = SamplePoints(num=nr_points)

data = dataset[39]
print(data)

Data(pos=[1024, 3], y=[1])


In [3]:
class GetGraph(nn.Module):
    def __init__(self):
        super(GetGraph, self).__init__()

    def forward(self, point_cloud):
        point_cloud_transpose = point_cloud.permute(0, 2, 1)
        point_cloud_inner = torch.matmul(point_cloud, point_cloud_transpose)
        point_cloud_inner = -2 * point_cloud_inner
        point_cloud_square = torch.sum(torch.mul(point_cloud, point_cloud), dim=2, keepdim=True)
        point_cloud_square_tranpose = point_cloud_square.permute(0, 2, 1)
        adj_matrix = point_cloud_square + point_cloud_inner + point_cloud_square_tranpose
        adj_matrix = torch.exp(-adj_matrix)
        return adj_matrix


class GetLaplacian(nn.Module):
    def __init__(self, normalize=True):
        super(GetLaplacian, self).__init__()
        self.normalize = normalize

    def diag(self, mat):
        # input is batch x vertices
        d = []
        for vec in mat:
            d.append(torch.diag(vec))
        return torch.stack(d)

    def forward(self, adj_matrix):
        if self.normalize:
            D = torch.sum(adj_matrix, dim=2)
            eye = torch.ones_like(D)
            eye = self.diag(eye)
            D = 1 / torch.sqrt(D)
            D = self.diag(D)
            L = eye - torch.matmul(torch.matmul(D, adj_matrix), D)
        else:
            D = torch.sum(adj_matrix, dim=1)
            D = torch.matrix_diag(D)
            L = D - adj_matrix
        return L


# class GetFilter(nn.Module):
#     def __init__(self, Fin, K, Fout):
#         super(GetFilter, self).__init__()
#         self.Fin = Fin
#         self.Fout = Fout
#         self.K = K
#         self.W = nn.Parameter(torch.Tensor(self.K * self.Fin, self.Fout))
#         nn.init.normal_(self.W, mean=0, std=0.2)
#         self.B = nn.Parameter(torch.Tensor(self.Fout))
#         nn.init.normal_(self.B, mean=0, std=0.2)
#         self.relu = nn.ReLU()

#     # def reset_parameters(self):

#     def forward(self, x, L):
#         N, M, Fin = list(x.size())
#         K = self.K
#         x0 = x.clone()
#         x = x0.unsqueeze(0)

#         #         x = x.expand(-1,-1,-1,1)
#         def concat(x, x_):
#             x_ = x_.unsqueeze(0)
#             #             x_ = x.expand(1,-1,-1)
#             return torch.cat((x, x_), dim=0)

#         if K > 1:
#             x1 = torch.matmul(L, x0)
#             x = concat(x, x1)
#         for k in range(2, K):
#             x2 = 2 * torch.matmul(L, x1) - x0
#             x = concat(x, x2)
#             x0, x1 = x1, x2
#         x = x.permute(1, 2, 3, 0)
#         x = x.reshape(N * M, Fin * K)
#         x = torch.matmul(x, self.W)
#         x = torch.add(x, self.B)
#         x = self.relu(x)
#         return x.reshape(N, M, self.Fout)

In [4]:
getter=GetGraph()
print(dataset[0].pos.shape)
data=dataset[0].pos
data=data.reshape([1,nr_points,3])
print(data.shape)
x=data
W=getter(data)
W_torch=W
print(W_torch.shape)
W=torch.squeeze(W)
print(W.shape)
print(W)

torch.Size([1024, 3])
torch.Size([1, 1024, 3])
torch.Size([1, 1024, 1024])
torch.Size([1024, 1024])
tensor([[1.0000, 0.5828, 0.9376,  ..., 0.1363, 0.3722, 0.7379],
        [0.5828, 1.0000, 0.6627,  ..., 0.5064, 0.2619, 0.6291],
        [0.9376, 0.6627, 1.0000,  ..., 0.2344, 0.5538, 0.9149],
        ...,
        [0.1363, 0.5064, 0.2344,  ..., 1.0000, 0.2453, 0.3585],
        [0.3722, 0.2619, 0.5538,  ..., 0.2453, 1.0000, 0.7663],
        [0.7379, 0.6291, 0.9149,  ..., 0.3585, 0.7663, 1.0000]])


In [5]:
getter_laplacian=GetLaplacian()
L_torch=getter_laplacian(W_torch)
L=torch.squeeze(L_torch)
print(L_torch.shape)
print(L.shape)
print(L)


torch.Size([1, 1024, 1024])
torch.Size([1024, 1024])
tensor([[ 9.9728e-01, -1.5304e-03, -2.2361e-03,  ..., -3.4304e-04,
         -7.8977e-04, -1.5904e-03],
        [-1.5304e-03,  9.9746e-01, -1.5274e-03,  ..., -1.2321e-03,
         -5.3711e-04, -1.3104e-03],
        [-2.2361e-03, -1.5274e-03,  9.9791e-01,  ..., -5.1798e-04,
         -1.0316e-03, -1.7307e-03],
        ...,
        [-3.4304e-04, -1.2321e-03, -5.1798e-04,  ...,  9.9767e-01,
         -4.8236e-04, -7.1595e-04],
        [-7.8977e-04, -5.3711e-04, -1.0316e-03,  ..., -4.8236e-04,
          9.9834e-01, -1.2901e-03],
        [-1.5904e-03, -1.3104e-03, -1.7307e-03,  ..., -7.1595e-04,
         -1.2901e-03,  9.9829e-01]])


In [6]:
import torch.nn as nn
def find_eigmax(L):
    with torch.no_grad():
        e1, _ = torch.eig(L, eigenvectors=False)
        return torch.max(e1[:, 0]).item()

def chebyshev_Lapl(X, Lapl, thetas, order):
  list_powers = []
  nodes = Lapl.shape[0]

  T0 = X.float()

  eigmax = find_eigmax(Lapl)
  L_rescaled = (2 * Lapl / eigmax) - torch.eye(1024)

  y = T0 * thetas[0]
  list_powers.append(y)
  T1 = torch.matmul(L_rescaled, T0)
  list_powers.append(T1 * thetas[1])

  # Computation of: T_k = 2*L_rescaled*T_k-1  -  T_k-2
  for k in range(2, order):
      T2 = 2 * torch.matmul(L_rescaled, T1) - T0
      list_powers.append((T2 * thetas[k]))
      T0, T1 = T1, T2
  y_out = torch.stack(list_powers, dim=-1)
  # the powers may be summed or concatenated. i use concatenation here
  y_out = y_out.view(nodes, -1) # -1 = order* features_of_signal
  return y_out

In [7]:
features_1 = 3
out_features_1 = 128
K1 = 6 # p-hops
thetas_1 = nn.Parameter(torch.rand(nr_points))
out = chebyshev_Lapl(x,L,thetas_1,K1)
print('cheb approx out powers concatenated:', out.shape)
linear_1 = nn.Linear(K1*features_1, out_features_1) 
layer_out_1 = linear_1(out)
print('Layers output:', layer_out_1.shape)

cheb approx out powers concatenated: torch.Size([1024, 18])
Layers output: torch.Size([1024, 128])


In [8]:
getter1=GetGraph()
print(layer_out_1.shape)


data1=layer_out_1
data1=data1.reshape([1,nr_points,out_features_1])
x1=data1
W1=getter1(data1)
W_torch1=W1
W1=torch.squeeze(W1)
print(W1.shape)
print(W1)

torch.Size([1024, 128])
torch.Size([1024, 1024])
tensor([[1.0000, 0.6884, 0.9465,  ..., 0.2622, 0.4602, 0.8043],
        [0.6884, 1.0000, 0.7309,  ..., 0.6186, 0.3354, 0.6919],
        [0.9465, 0.7309, 1.0000,  ..., 0.3587, 0.5909, 0.9384],
        ...,
        [0.2622, 0.6186, 0.3587,  ..., 1.0000, 0.3521, 0.4753],
        [0.4602, 0.3354, 0.5909,  ..., 0.3521, 1.0000, 0.7632],
        [0.8043, 0.6919, 0.9384,  ..., 0.4753, 0.7632, 1.0000]],
       grad_fn=<SqueezeBackward0>)


In [9]:
getter_laplacian_1=GetLaplacian()
L_torch1=getter_laplacian_1(W_torch1)
L1=torch.squeeze(L_torch1)
print(L_torch1.shape)
print(L1.shape)
print(L1)

torch.Size([1, 1024, 1024])
torch.Size([1024, 1024])
tensor([[ 9.9778e-01, -1.5113e-03, -1.9219e-03,  ..., -5.5015e-04,
         -8.4292e-04, -1.5040e-03],
        [-1.5113e-03,  9.9783e-01, -1.4673e-03,  ..., -1.2835e-03,
         -6.0733e-04, -1.2792e-03],
        [-1.9219e-03, -1.4673e-03,  9.9814e-01,  ..., -6.8836e-04,
         -9.8963e-04, -1.6045e-03],
        ...,
        [-5.5015e-04, -1.2835e-03, -6.8836e-04,  ...,  9.9802e-01,
         -6.0944e-04, -8.3989e-04],
        [-8.4292e-04, -6.0733e-04, -9.8963e-04,  ..., -6.0944e-04,
          9.9849e-01, -1.1772e-03],
        [-1.5040e-03, -1.2792e-03, -1.6045e-03,  ..., -8.3989e-04,
         -1.1772e-03,  9.9843e-01]], grad_fn=<SqueezeBackward0>)


In [10]:
z1=layer_out_1
features_2 = out_features_1
out_features_2 = 512
K2 = 5 # p-hops
thetas_2 = nn.Parameter(torch.rand(nr_points))
out_2 = chebyshev_Lapl(z1,L1,thetas_2,K2)
print('cheb approx out powers concatenated:', out_2.shape)
linear_2 = nn.Linear(K2*features_2, out_features_2) 
layer_out_2 = linear_2(out_2)
print('Layers output:', layer_out_2.shape)

cheb approx out powers concatenated: torch.Size([1024, 640])
Layers output: torch.Size([1024, 512])


In [11]:
getter2=GetGraph()
print(layer_out_2.shape)
data2=layer_out_2
data2=data2.reshape([1,nr_points,out_features_2])
x2=data2
W2=getter2(data2)
W_torch2=W2
W2=torch.squeeze(W2)
print(W2.shape)
print(W2)

torch.Size([1024, 512])
torch.Size([1024, 1024])
tensor([[1.0000, 0.9404, 0.9885,  ..., 0.7846, 0.8643, 0.9590],
        [0.9404, 1.0000, 0.9470,  ..., 0.9173, 0.8257, 0.9376],
        [0.9885, 0.9470, 1.0000,  ..., 0.8207, 0.9010, 0.9884],
        ...,
        [0.7846, 0.9173, 0.8207,  ..., 1.0000, 0.8180, 0.8597],
        [0.8643, 0.8257, 0.9010,  ..., 0.8180, 1.0000, 0.9431],
        [0.9590, 0.9376, 0.9884,  ..., 0.8597, 0.9431, 1.0000]],
       grad_fn=<SqueezeBackward0>)


In [12]:
getter_laplacian_2=GetLaplacian()
L_torch2=getter_laplacian_2(W_torch2)
L2=torch.squeeze(L_torch2)
print(L_torch2.shape)
print(L2.shape)
print(L2)

torch.Size([1, 1024, 1024])
torch.Size([1024, 1024])
tensor([[ 9.9883e-01, -1.0916e-03, -1.1356e-03,  ..., -9.0467e-04,
         -9.6831e-04, -1.0810e-03],
        [-1.0916e-03,  9.9885e-01, -1.0794e-03,  ..., -1.0494e-03,
         -9.1786e-04, -1.0487e-03],
        [-1.1356e-03, -1.0794e-03,  9.9887e-01,  ..., -9.2931e-04,
         -9.9133e-04, -1.0941e-03],
        ...,
        [-9.0467e-04, -1.0494e-03, -9.2931e-04,  ...,  9.9886e-01,
         -9.0332e-04, -9.5511e-04],
        [-9.6831e-04, -9.1786e-04, -9.9133e-04,  ..., -9.0332e-04,
          9.9893e-01, -1.0181e-03],
        [-1.0810e-03, -1.0487e-03, -1.0941e-03,  ..., -9.5511e-04,
         -1.0181e-03,  9.9891e-01]], grad_fn=<SqueezeBackward0>)


In [13]:
z2=layer_out_2
features_3 = out_features_2
out_features_3 = 1024
K3 = 3 
thetas_3 = nn.Parameter(torch.rand(nr_points))
out_3 = chebyshev_Lapl(z2,L2,thetas_3,K3)
print('cheb approx out powers concatenated:', out_3.shape)
linear_3 = nn.Linear(K3*features_3, out_features_3) 
layer_out_3 = linear_3(out_3)
print('Layers output:', layer_out_3.shape)

cheb approx out powers concatenated: torch.Size([1024, 1536])
Layers output: torch.Size([1024, 1024])


In [None]:
#Trebuie adaugat maxpool

Layers second MLP: torch.Size([1024, 192])


torch.Size([1024, 192])
torch.Size([1024, 512])
Concatenated vector features: 704
torch.Size([1024, 704])


Final layer: torch.Size([1024, 50])
