In [1]:
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../deepsphere-pytorch/")
sys.path.append("../")

In [3]:
from deepsphere.layers.chebyshev import SphericalChebConv
from deepsphere.utils.laplacian_funcs import get_healpix_laplacians

In [4]:
import torch
from torch import nn
import torch.nn.functional as F

from deepsphere.layers.samplings.equiangular_pool_unpool import Equiangular
from deepsphere.layers.samplings.healpix_pool_unpool import Healpix
from deepsphere.layers.samplings.icosahedron_pool_unpool import Icosahedron
from deepsphere.models.spherical_unet.decoder import Decoder
from deepsphere.models.spherical_unet.encoder import Encoder, EncoderTemporalConv
from deepsphere.utils.laplacian_funcs import get_healpix_laplacians
from deepsphere.utils.utils import healpix_graph, healpix_laplacian, healpix_weightmatrix

In [5]:
class SphericalChebBNPool(nn.Module):
    """Building Block with a pooling/unpooling, a calling the SphericalChebBN block.
    """

    def __init__(self, in_channels, out_channels, lap, pooling, kernel_size):
        """Initialization.

        Args:
            in_channels (int): initial number of channels.
            out_channels (int): output number of channels.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian.
            pooling (:obj:`torch.nn.Module`): pooling/unpooling module.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        self.spherical_cheb = SphericalChebConv(in_channels, out_channels, lap, kernel_size)
        self.pooling = pooling
        self.batchnorm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Forward Pass.

        Args:
            x (:obj:`torch.tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.tensor`: output [batch x vertices x channels/features]
        """
        x = self.spherical_cheb(x)
        x = self.batchnorm(x.permute(0, 2, 1))
        x = F.relu(x.permute(0, 2, 1))
        x = self.pooling(x)
        return x


In [145]:
class SphericalGraphCNN(nn.Module):
    """Spherical GCNN Autoencoder.
    """

    def __init__(self, kernel_size=4):
        """Initialization.

        Args:
            pooling_class (obj): One of three classes of pooling methods
            N (int): Number of pixels in the input image
            kernel_size (int): chebychev polynomial degree
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.pooling_class = Healpix()
        
        self.laps = get_healpix_laplacians(nside_list=nside_list, laplacian_type="combinatorial", indexes_list=indexes_list)

        self.spherical_cheb_bn_pool1 = SphericalChebBNPool(1, 32, self.laps[0], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool2 = SphericalChebBNPool(32, 64, self.laps[1], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool3 = SphericalChebBNPool(64, 128, self.laps[2], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool4 = SphericalChebBNPool(128, 256, self.laps[3], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool5 = SphericalChebBNPool(256, 256, self.laps[4], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool6 = SphericalChebBNPool(256, 256, self.laps[5], self.pooling_class.pooling, self.kernel_size)
        self.spherical_cheb_bn_pool7 = SphericalChebBNPool(256, 256, self.laps[6], self.pooling_class.pooling, self.kernel_size)

    def forward(self, x):
        """Forward Pass.

        Args:
            x (:obj:`torch.Tensor`): input to be forwarded.

        Returns:
            :obj:`torch.Tensor`: output
        """
        x = self.spherical_cheb_bn_pool1(x)
        x = self.spherical_cheb_bn_pool2(x)
        x = self.spherical_cheb_bn_pool3(x)
        x = self.spherical_cheb_bn_pool4(x)
        x = self.spherical_cheb_bn_pool5(x)
        x = self.spherical_cheb_bn_pool6(x)
        x = self.spherical_cheb_bn_pool7(x)
        return x

In [146]:
import healpy as hp
npix = hp.nside2npix(128)
test_map = np.ones(npix)

In [147]:
sg = SphericalGraphCNN()

In [148]:
in_map = torch.ones(len(indexes_list[0]))

In [149]:
in_map = in_map.reshape((1, -1, 1))
in_map.shape

torch.Size([1, 16384, 1])

In [153]:
sg(in_map).shape

torch.Size([1, 1, 256])

In [154]:
from torchvision import models
from torchsummary import summary

In [155]:
summary(sg, input_size=(16384, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          ChebConv-1            [-1, 16384, 32]             160
 SphericalChebConv-2            [-1, 16384, 32]               0
       BatchNorm1d-3            [-1, 32, 16384]              64
    HealpixAvgPool-4             [-1, 4096, 32]               0
    HealpixAvgPool-5             [-1, 4096, 32]               0
    HealpixAvgPool-6             [-1, 4096, 32]               0
    HealpixAvgPool-7             [-1, 4096, 32]               0
    HealpixAvgPool-8             [-1, 4096, 32]               0
    HealpixAvgPool-9             [-1, 4096, 32]               0
   HealpixAvgPool-10             [-1, 4096, 32]               0
SphericalChebBNPool-11             [-1, 4096, 32]               0
         ChebConv-12             [-1, 4096, 64]           8,256
SphericalChebConv-13             [-1, 4096, 64]               0
      BatchNorm1d-14             [-1,

In [75]:
nside_list = [128, 64, 32, 16, 8, 4, 2]

In [86]:
sys.path.append("../../fermi-gce-gp/")
from utils import create_mask as cm

hp_mask = cm.make_mask_total(nside=1, band_mask = True, band_mask_range = 0,
                              mask_ring = True, inner = 0, outer = 25)

indexes_list = []
for nside in nside_list:
    hp_mask = hp.ud_grade(hp_mask, nside)
    indexes_list.append(np.arange(hp.nside2npix(nside))[~hp_mask])

In [87]:
hp_mask

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True, False,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
       False, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True])

In [77]:
sys.path.append("../../fermi-gce-gp/")
from utils import create_mask as cm

pscmask=np.array(np.load('../../fermi-gce-gp/data/mask_3fgl_0p8deg.npy'), dtype=bool)
hp_mask = cm.make_mask_total(nside=128, band_mask = True, band_mask_range = 0,
                              mask_ring = True, inner = 0, outer = 30,
                              custom_mask = None)

print(np.sum(~hp_mask))
hp_mask = hp.ud_grade(hp_mask, 16, power=-2)
print(np.sum(~hp_mask) * 64)
hp.mollview(hp_mask)

In [80]:
len(get_healpix_laplacians(nside_list=nside_list, laplacian_type="combinatorial", indexes_list=indexes_list))

7