In [31]:
import zarr
import xarray as xr
import torch
from SphericConv import RegionalSphericalConv
import numpy as np


def get_bands(data_tree, res):
    """List the reflectance band names stored at a given resolution.

    Args:
        data_tree: Object exposing ``measurements.reflectance`` as a mapping
            keyed by ``"r<res>"`` strings.
        res: Resolution suffix, interpolated into the ``"r<res>"`` key.

    Returns:
        A list with the keys of the selected resolution group, in order.
    """
    resolution_group = data_tree.measurements.reflectance[f"r{res}"]
    return [band_name for band_name in resolution_group.keys()]

## `SphericalConv.__init__()`

This is the constructor that sets up the spherical convolution layer for your regional HEALPix data. Unlike the original implementation, which assumes a full sphere, this version works only with the HEALPix cells present in the dataset. It takes `available_cell_ids` (the cells that actually contain data in your geographic region), the HEALPix resolution level, the input/output channel counts, and an optional `stride` that keeps only every `stride`-th cell as a patch center. The constructor builds the neighbor relationships, creates lookup tables from cell IDs to data positions, and initializes the underlying 1D convolution layer that will process the 3×3 patches. The key idea is that the layer adapts to your specific data coverage rather than assuming global coverage.

## `_build_neighbor_index()`

This function builds the 3×3-style patches on the sphere. For each available cell in your dataset (optionally subsampled by `stride`), it uses HEALPix's `get_all_neighbours()` to find the 8 surrounding cells. Because the data only covers a specific region, some of these neighbors might not exist in your dataset. In addition, `get_all_neighbours()` returns `-1` for pixels that have fewer than 8 neighbors. The function replaces every such missing neighbor with the center cell ID itself. This is effectively replicate-style padding: boundary patches repeat the center value rather than using zeros. Every patch therefore has exactly 9 elements (center + 8 neighbors), which makes it compatible with standard CNN operations while preserving the spherical topology.

## `_convert_to_data_indices()`

This function performs a crucial optimization for computational efficiency. While `_build_neighbor_index()` works with HEALPix cell IDs (which can be large, sparse numbers), your actual data array uses sequential indices from 0 to N-1. This function creates a mapping that converts the HEALPix cell IDs in each patch to the corresponding positions in your data array. This pre-computation means that during the forward pass, you can directly index into your data tensor without expensive lookups, making the convolution much faster.

## `forward()`

This is where the actual spherical convolution happens. The function takes an input tensor of shape [batch, channels, cells] and gathers all 9-cell patches at once using advanced indexing, giving a tensor of shape [batch, channels, n_patches, 9]. The patches are then flattened into one long sequence per channel. A `Conv1d` with `kernel_size=9` and `stride=9` then visits each patch exactly once and applies the learned weights. The result is a feature map of shape [batch, out_channels, n_patches]. The beauty of this approach is that it maintains the spherical neighborhood relationships while leveraging standard CNN operations, allowing you to use existing deep learning frameworks efficiently.

In [29]:
# Open the regional HEALPix Sentinel-2 store.
# engine="zarr" names the backend explicitly, so xarray does not have to run
# plugins.guess_engine (the call the warnings printed below point at).
# TODO: move this absolute path into a configurable constant.
ds_healpix = xr.open_dataset(
    "/home/ubuntu/project/sentinel-2-dggs-ai-processor/src/notebook/healpix.zarr",
    engine="zarr",
)

  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)


In [32]:
print("=== Setting up Spherical Convolution ===")

# Layer settings, named once so the constructor call and the log lines below
# cannot drift apart.
HEALPIX_LEVEL = 19
CONV_OUT_CHANNELS = 64
# presumably keeps every 2nd cell as a patch center (the output below shows
# n_patches == n_cells / 2) — confirm against SphericConv.RegionalSphericalConv
CENTER_STRIDE = 2

# 1. Extract spectral data for all bands
band_list = ds_healpix.Sentinel2.bands.values
print(f"Available bands: {band_list}")

# 2. Get available cell IDs from your dataset
available_cell_ids = ds_healpix.cell_ids.values
print(f"Number of available HEALPix cells: {len(available_cell_ids)}")

# 3. Create the input array with all spectral bands, stacked as [n_bands, n_cells]
spectral_data = [ds_healpix.Sentinel2.sel(bands=band).compute().values for band in band_list]
x_multi_band = np.stack(spectral_data, axis=0)
print(f"Multi-band data shape: {x_multi_band.shape}")

# 4. Create spherical conv layer
conv_layer = RegionalSphericalConv(
    available_cell_ids=available_cell_ids,
    level=HEALPIX_LEVEL,
    in_channels=len(band_list),
    out_channels=CONV_OUT_CHANNELS,
    stride=CENTER_STRIDE,
)

print(f"Created conv layer with {len(band_list)} input channels, {CONV_OUT_CHANNELS} output channels")
print(f"Number of patches that will be generated: {conv_layer.n_patches}")

# 5. Convert to PyTorch tensor and add batch dimension
x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
print(f"Input tensor shape: {x_tensor.shape}")  # [1, n_bands, n_cells]

# 6. Forward pass (inference only, no autograd graph)
with torch.no_grad():
    output = conv_layer(x_tensor)
    print(f"Output tensor shape: {output.shape}")  # [1, CONV_OUT_CHANNELS, n_patches]

# 7. Optional: convert back to numpy for further processing
output_np = output.squeeze(0).numpy()  # drop the batch dimension
print(f"Output as numpy array: {output_np.shape}")  # [CONV_OUT_CHANNELS, n_patches]

=== Setting up Spherical Convolution ===
Available bands: ['b02' 'b03' 'b04' 'b08']
Number of available HEALPix cells: 2159668
Multi-band data shape: (4, 2159668)
Created conv layer with 4 input channels, 64 output channels
Number of patches that will be generated: 1079834
Input tensor shape: torch.Size([1, 4, 2159668])
Output tensor shape: torch.Size([1, 64, 1079834])
Output as numpy array: (64, 1079834)


In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import healpy as hp

class SphericalConv(nn.Module):
    """Regional spherical convolution over a subset of HEALPix cells.

    Each output value is a learned linear combination of a 9-cell patch: a
    center cell followed by its 8 HEALPix neighbours. Neighbours that are not
    in ``available_cell_ids``, or that HEALPix reports as ``-1``, are replaced
    by the center cell.

    Args:
        available_cell_ids: HEALPix cell ids present in the data, in the same
            order as the last axis of the input tensor.
        level: HEALPix level; ``NSIDE = 2 ** level``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the underlying ``Conv1d`` has a bias term.
        nest: Passed to ``hp.get_all_neighbours`` (NESTED vs RING ordering).
        stride: Keep every ``stride``-th cell (in ``available_cell_ids`` order)
            as a patch center. Must be >= 1.

    Raises:
        ValueError: If ``stride < 1``.

    Shapes:
        input ``[batch, in_channels, len(available_cell_ids)]`` ->
        output ``[batch, out_channels, n_patches]`` where
        ``n_patches == ceil(len(available_cell_ids) / stride)``.
    """

    def __init__(self, available_cell_ids, level, in_channels, out_channels, bias=True, nest=True, stride=1):
        super(SphericalConv, self).__init__()

        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.level = level
        self.NSIDE = 2 ** level
        self.nest = nest
        self.stride = stride
        self.available_cell_ids = np.array(available_cell_ids)
        self.available_cell_set = set(available_cell_ids)

        # [n_patches, 9] HEALPix cell ids: center followed by its 8 neighbours.
        self.neighbor_indices = self._build_neighbor_index()
        self.n_patches = self.neighbor_indices.shape[0]

        # Kept for callers that look up data positions by cell id; the
        # vectorised conversion below does not need it.
        self.cell_to_data_idx = {cell_id: i for i, cell_id in enumerate(self.available_cell_ids)}

        # Registered as a non-persistent buffer so model.to(device) / .cuda()
        # move it together with the weights, while state_dict() keeps the same
        # keys as before (the index is derived data, not a learned parameter).
        self.register_buffer(
            "data_neighbor_indices", self._convert_to_data_indices(), persistent=False
        )

        # 1D convolution with kernel size 9 (3x3 patch flattened); stride 9 makes
        # each kernel application cover exactly one patch.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=9, stride=9, bias=bias)

        # Initialize weights
        nn.init.kaiming_normal_(self.conv.weight)
        if bias:
            nn.init.constant_(self.conv.bias, 0.0)

    def _build_neighbor_index(self):
        """Return an int64 ``[n_patches, 9]`` array of HEALPix ids (center + 8 neighbours).

        Uses one vectorised ``hp.get_all_neighbours`` call for all centers
        instead of one Python-level call per cell.
        """
        # int64 avoids uint64/int64 mixing (which NumPy promotes to float64)
        # when comparing dataset ids with healpy's signed neighbour ids.
        cells = np.asarray(self.available_cell_ids).astype(np.int64, copy=False)
        centers = cells[::self.stride]
        if centers.size == 0:
            return np.empty((0, 9), dtype=np.int64)

        # healpy returns shape (8, N) for an array of pixels; reshape(8, -1)
        # also covers the (8,) shape it uses for scalar input. Transpose to (N, 8).
        neighbors = np.asarray(
            hp.get_all_neighbours(self.NSIDE, centers, nest=self.nest), dtype=np.int64
        ).reshape(8, -1).T

        # -1 marks a neighbour that does not exist; ids outside the region are
        # missing from the data. Both fall back to the center cell.
        valid = (neighbors != -1) & np.isin(neighbors, cells)
        neighbors = np.where(valid, neighbors, centers[:, None])

        return np.concatenate([centers[:, None], neighbors], axis=1)

    def _convert_to_data_indices(self):
        """Map each HEALPix id in ``neighbor_indices`` to its position in ``available_cell_ids``.

        Returns:
            ``torch.long`` tensor with the same shape as ``neighbor_indices``.

        Raises:
            KeyError: If a patch references an id that is not in ``available_cell_ids``.
        """
        cells = np.asarray(self.available_cell_ids).astype(np.int64, copy=False)
        wanted = np.asarray(self.neighbor_indices).astype(np.int64, copy=False)

        order = np.argsort(cells, kind="stable")
        sorted_cells = cells[order]
        # side="right" - 1 selects the last occurrence of a duplicated id,
        # matching what a {cell_id: index} dict built by enumerate would hold.
        pos = np.searchsorted(sorted_cells, wanted, side="right") - 1

        hit = pos >= 0
        if sorted_cells.size:
            hit &= sorted_cells[np.maximum(pos, 0)] == wanted
        if not hit.all():
            raise KeyError(int(wanted[~hit][0]))

        return torch.as_tensor(order[pos], dtype=torch.long)

    def forward(self, x):
        """Apply the spherical convolution.

        Args:
            x: Tensor of shape ``[batch, in_channels, len(available_cell_ids)]``.

        Returns:
            Tensor of shape ``[batch, out_channels, n_patches]``.
        """
        batch_size, n_channels, n_cells = x.shape

        # Ensure we have the right number of cells
        assert n_cells == len(self.available_cell_ids), \
            f"Expected {len(self.available_cell_ids)} cells, got {n_cells}"

        # Gather patches: [B, C_in, N_patches, 9]
        patches = x[:, :, self.data_neighbor_indices]

        # Flatten to [B, C_in, N_patches * 9] for Conv1d
        patches_flat = patches.flatten(start_dim=2)

        return self.conv(patches_flat)


class SphericalConvBlock(nn.Module):
    """SphericalConv -> BatchNorm1d -> in-place ReLU.

    All constructor arguments are forwarded to ``SphericalConv``;
    ``out_channels`` also sizes the batch-norm layer.
    """

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super().__init__()

        # Registration order (conv, bn, relu) fixes the state_dict / repr order.
        self.conv = SphericalConv(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class SphericalDoubleConvBlock(nn.Module):
    """Double convolution block (Conv -> BN -> ReLU -> Conv -> BN -> ReLU).

    The first block may subsample patch centers with ``stride``; the second
    block runs at stride 1 on exactly the cells the first block kept.
    """

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super(SphericalDoubleConvBlock, self).__init__()

        # First conv with specified stride
        self.conv1 = SphericalConvBlock(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride
        )

        # conv1 emits one value per patch center, and SphericalConv picks its
        # centers as available_cell_ids[::stride]. conv2 must therefore be built
        # on that same subset; using the full id list makes conv2's cell-count
        # assertion fail whenever stride > 1.
        # NOTE(review): the subset stays at the same HEALPix level, so many of its
        # level-`level` neighbours are absent and get replaced by the center cell;
        # real downsampling would move to level - 1 parent cells.
        output_cell_ids = np.asarray(available_cell_ids)[::stride]

        # Second conv with stride=1, operating on the output of the first conv
        self.conv2 = SphericalConvBlock(
            available_cell_ids=output_cell_ids,
            level=level,
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class Model(nn.Module):
    """Minimal model: a single SphericalDoubleConvBlock.

    Placeholder architecture; a U-Net is planned (TBD).
    """

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super().__init__()
        block_kwargs = dict(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
        )
        self.double_conv = SphericalDoubleConvBlock(**block_kwargs)

    def forward(self, x):
        # Delegate straight to the wrapped double-conv block.
        out = self.double_conv(x)
        return out


In [38]:
# Configuration for the spherical model built in the next cell.
available_cell_ids = ds_healpix.cell_ids.values
level = 19  # Your HEALPix level
in_channels = len(band_list)
# Reported as "Output classes"; the model cell passes out_channels=in_channels.
out_channels = in_channels
stride = 1
print("Creating Spherical U-Net:")
print(f"  - Input cells: {len(available_cell_ids)}")
print(f"  - Input channels: {in_channels} (bands: {band_list})")
print(f"  - Output classes: {out_channels}")
print(f"  - HEALPix level: {level}")

Creating Spherical U-Net:
  - Input cells: 2159668
  - Input channels: 4 (bands: ['b02' 'b03' 'b04' 'b08'])
  - Output classes: 4
  - HEALPix level: 19


In [40]:
import torch
import numpy as np
import xarray as xr

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Re-open the HEALPix store; engine="zarr" names the backend explicitly so
# xarray skips plugins.guess_engine.
ds_healpix = xr.open_dataset(
    "/home/ubuntu/project/sentinel-2-dggs-ai-processor/src/notebook/healpix.zarr",
    engine="zarr",
)

# Get available cell IDs from the dataset
available_cell_ids = ds_healpix.cell_ids.values
print(f"Number of available HEALPix cells: {len(available_cell_ids)}")

# Build the model and move its parameters (and registered buffers) to the device
model = Model(
    available_cell_ids=available_cell_ids,
    level=level,
    in_channels=in_channels,
    out_channels=in_channels,
    stride=stride,
).to(device)

# Calculate total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"  - Total parameters: {total_params:,}")

# Stack all spectral bands into [n_bands, n_cells]
spectral_data = [ds_healpix.Sentinel2.sel(bands=band).compute().values for band in band_list]
x_multi_band = np.stack(spectral_data, axis=0)
print(f"Multi-band data shape: {x_multi_band.shape}")

# Convert to a float32 tensor with a leading batch dimension
x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
print(f"Input tensor shape: {x_tensor.shape}")  # [1, n_bands, n_cells]

# Move input to the same device as the model
x_tensor = x_tensor.to(device)

# Forward pass. no_grad() only disables autograd; eval() is what switches
# BatchNorm to its running statistics and stops it from updating them.
model.eval()
with torch.no_grad():
    output = model(x_tensor)
    print(f"Output tensor shape: {output.shape}")

Using device: cuda
Number of available HEALPix cells: 2159668


  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)


  - Total parameters: 312
Multi-band data shape: (4, 2159668)
Input tensor shape: torch.Size([1, 4, 2159668])
Output tensor shape: torch.Size([1, 4, 2159668])


In [41]:
# Count every element of every parameter tensor in `model` (trainable or not).
total_params = sum(map(lambda param: param.numel(), model.parameters()))

In [43]:
import timm

In [50]:
import timm
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import segmentation_models_pytorch as smp


In [55]:
import segmentation_models_pytorch as smp

# NOTE(review): this rebinds `model`, replacing the spherical Model built above,
# and in_channels=1 / classes=3 do not match the 4 Sentinel-2 bands used
# elsewhere in this notebook — confirm this is meant as a standalone baseline.
model = smp.Unet(
    encoder_name="resnet18",      # encoder backbone (alternatives: mobilenet_v2, efficientnet-b7, ...)
    encoder_weights="imagenet",   # start the encoder from ImageNet-pretrained weights
    in_channels=1,                # input channels (1 = grayscale, 3 = RGB, ...)
    classes=3,                    # output channels, i.e. number of target classes
)

config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

In [56]:
# Display the repr of the segmentation_models_pytorch U-Net defined above.
model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track