In [2]:
import os
import glob
import natsort
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import xdggs

In [3]:
# Dataset release tag; it selects the versioned sub-folder of the dataset root.
version = "V4"
# Directory searched for the training input Zarr products.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
TRAIN_DIR = f"/mnt/disk/dataset/sentinel-ai-processor/{version}/train/input"

In [4]:
# List the Zarr products (*.zarr) found directly inside TRAIN_DIR, in natural sort order
zarr_path = natsort.natsorted(glob.glob(os.path.join(TRAIN_DIR, "*.zarr"), recursive=False))
# Open one product (index 11 of the sorted list) as a DataTree.
# mask_and_scale=False and chunks={} are forwarded to xarray; the chunked (dask-backed)
# arrays are what get_chunk_info() later reads chunk sizes from.
dt = xr.open_datatree(zarr_path[11], engine="zarr", mask_and_scale=False, chunks={})

In [5]:
def get_bands(data_tree, res):
    """
    List the band names stored under one spatial resolution of a product.

    Parameters:
    - data_tree: object exposing ``measurements.reflectance`` (e.g. the DataTree
        returned by ``xr.open_datatree``).
    - res: str
        Resolution suffix such as "10m"; the group looked up is ``f"r{res}"``.

    Returns:
    - list
        The keys of the resolution group, in the group's own order.
    """
    resolution_group = data_tree.measurements.reflectance[f"r{res}"]
    return [band_name for band_name in resolution_group.keys()]

def get_chunk(data_tree, res, chunk_y_idx, chunk_x_idx, chunk_size_y, chunk_size_x):
    """
    Cut one rectangular chunk out of a resolution group of a DataTree.

    The chunk is selected positionally with ``isel`` on the group itself, so the
    result keeps every band of that resolution rather than a single band.

    Parameters:
    - data_tree: xarray.DataTree
        The root DataTree object loaded from a Zarr store (e.g., xr.open_datatree(...)).
    - res: str
        The spatial resolution as a string (e.g., "10m", "20m", "60m"). It selects the
        group ``f"r{res}"`` and the dimension names ``f"y_{res}"`` / ``f"x_{res}"``.
    - chunk_y_idx: int
        Index of the chunk along the vertical (y) axis.
    - chunk_x_idx: int
        Index of the chunk along the horizontal (x) axis.
    - chunk_size_y: int
        Number of rows in one chunk.
    - chunk_size_x: int
        Number of columns in one chunk.

    Returns:
    - Whatever ``isel`` returns for the resolution group, restricted to rows
      ``[chunk_y_idx * chunk_size_y, (chunk_y_idx + 1) * chunk_size_y)`` and the
      matching column range.
    """
    resolution_group = data_tree.measurements.reflectance[f"r{res}"]

    first_row = chunk_y_idx * chunk_size_y
    first_col = chunk_x_idx * chunk_size_x

    window = {
        f"y_{res}": slice(first_row, first_row + chunk_size_y),
        f"x_{res}": slice(first_col, first_col + chunk_size_x),
    }
    return resolution_group.isel(window)

def get_chunk_info(data_tree, band, res):
    """
    Report the chunk layout of one band and print a short summary.

    Parameters:
    - data_tree: xarray.DataTree
    - band: str, e.g. "b03"
    - res: str, resolution suffix (e.g. "10m"); the dimensions inspected are
        ``f"y_{res}"`` and ``f"x_{res}"`` inside the group ``f"r{res}"``.

    Returns:
    - chunk_size_y: int, size of the first chunk along y
    - chunk_size_x: int, size of the first chunk along x
    - nb_chunks_y: int, number of chunks along y
    - nb_chunks_x: int, number of chunks along x
    """
    resolution_group = data_tree.measurements.reflectance[f"r{res}"]

    # `chunksizes` maps each dimension name to the tuple of its chunk lengths.
    chunks_along_y = resolution_group[band].chunksizes[f"y_{res}"]
    chunks_along_x = resolution_group[band].chunksizes[f"x_{res}"]

    chunk_size_y, chunk_size_x = chunks_along_y[0], chunks_along_x[0]
    nb_chunks_y, nb_chunks_x = len(chunks_along_y), len(chunks_along_x)

    print(f"Chunk size: y={chunk_size_y}, x={chunk_size_x}")
    print(f"Number of chunks: y={nb_chunks_y}, x={nb_chunks_x}")

    return chunk_size_y, chunk_size_x, nb_chunks_y, nb_chunks_x

In [6]:
# Resolution group to work on (selects the "r10m" group and the y_10m / x_10m dimensions)
res = "10m"
band_list = get_bands(data_tree=dt, res=res)
# The chunk layout is read from the first band of the resolution group
chunk_size_y, chunk_size_x, nb_chunks_y, nb_chunks_x = get_chunk_info(data_tree=dt, band=band_list[0], res=res)
# Load chunk (row 3, column 2) into memory
chunk = get_chunk(data_tree=dt, res=res, chunk_size_y=chunk_size_y,chunk_size_x=chunk_size_x, chunk_y_idx=3, chunk_x_idx=2).load()

Chunk size: y=1830, x=1830
Number of chunks: y=6, x=6


In [7]:
from numcodecs import Zstd

class proj_odysea:
    """
    Mean-aggregate gridded variables onto HEALPix cells, in a layout xdggs can decode.

    The instance is configured with the set of target cells (``heal_idx``) and, for
    every source pixel, the position of its cell inside that set (``inv_idx``).
    ``eval`` then averages each variable of a dataset over the pixels of each cell.
    """

    def __init__(
        self,
        level,
        heal_idx,
        inv_idx,
        nscale=2,
        nest=False,
        chunk_size=4096,
        cell_id_name="cell_ids",
    ):
        """
        Parameters:
        - level: int, HEALPix resolution level (nside = 2**level).
        - heal_idx: array of target cell ids; stored flattened as ``cell_ids``.
        - inv_idx: integer array giving, per source pixel, an index into ``heal_idx``;
            stored flattened. The caller in this notebook passes the inverse index
            returned by ``np.unique``.
        - nscale: stored on the instance; not used by this class.
        - nest: bool, True for nested indexing, False for ring.
        - chunk_size: divisor used by ``eval`` to derive the dask chunk length.
        - cell_id_name: name of the cell-id coordinate in the output dataset.
        """
        # Grid configuration
        self.level = level
        self.nside = 2 ** level
        self.nscale = nscale
        self.nest = nest
        self.chunk_size = chunk_size
        self.cell_id_name = cell_id_name

        # Target cells, carrying only the attributes xdggs expects
        self.cell_ids = heal_idx.flatten()
        indexing_scheme = "nested" if self.nest else "ring"
        self.var_cell_ids = xr.DataArray(
            self.cell_ids,
            dims="cells",
            attrs={
                "grid_name": "healpix",
                "indexing_scheme": indexing_scheme,
                "resolution": self.level,
            },
        )

        # Pixel -> cell lookup and the number of pixels falling in each cell
        self.inv_idx = inv_idx.flatten()
        self.him = np.bincount(self.inv_idx)

    def eval(self, ds):
        """
        Convert dataset to HEALPix projection without time dimension.

        Each data variable is flattened, pixels equal to 0 or NaN are discarded, and
        the remaining pixels are averaged per cell. Cells that receive no valid pixel
        are set to NaN.

        Returns an xarray.Dataset holding one variable, 'Sentinel2', with dimensions
        ("bands", "cells"), chunked along "cells".
        """
        var_name = list(ds.data_vars)
        n_vars = len(var_name)
        n_cells = self.cell_ids.shape[0]
        print(f"Processing {n_vars} variables: {var_name}")

        # One row per variable, one column per cell
        all_data = np.zeros([n_vars, n_cells])

        for position, ivar in enumerate(var_name):
            print(f"Processing {ivar} ({position+1}/{n_vars})")

            pixels = ds[ivar].values.flatten()

            # Valid pixels are those that are neither zero nor NaN
            usable = (pixels != 0) & (~np.isnan(pixels))
            target_cells = self.inv_idx[usable]

            # Per-cell sum of the valid pixels and per-cell count of valid pixels
            sums = np.bincount(target_cells, weights=pixels[usable], minlength=n_cells)
            counts = np.bincount(target_cells, minlength=n_cells)

            # Mean where at least one pixel contributed, NaN elsewhere
            means = np.full(sums.shape, np.nan, dtype=float)
            populated = counts > 0
            means[populated] = sums[populated] / counts[populated]

            all_data[position] = means

        data_array = xr.DataArray(
            all_data,
            dims=("bands", "cells"),
            coords={
                "bands": var_name,
                self.cell_id_name: self.var_cell_ids,
            },
            name='Sentinel2',
            attrs={
                "description": "Sentinel-2 reflectance aggregated to HEALPix cells"
            },
        )

        ds_total = data_array.to_dataset()

        # Reset the coordinate attributes so that only the xdggs ones remain
        ds_total[self.cell_id_name].attrs = {
            "grid_name": "healpix",
            "indexing_scheme": "nested" if self.nest else "ring",
            "resolution": self.level,
        }

        # Chunk length = (cells on the full sphere at this level) / chunk_size, at least 1
        chunk_size_data = max(1, int((12 * (4**self.level)) / self.chunk_size))
        ds_total = ds_total.chunk({"cells": chunk_size_data})

        print(f"HEALPix conversion complete - Level {self.level}, {len(self.cell_ids):,} cells")

        return ds_total

In [8]:
def healpix_projection(dt, res="10m", chunk=True, level=19, chunk_size=4096,
                       chunk_y_idx=3, chunk_x_idx=2):
    """
    Project one resolution group of a Sentinel-2 product onto HEALPix cells.

    Steps: select the data (one chunk or the whole group), build the x/y grid,
    transform it to longitude/latitude, assign every pixel to a nested HEALPix cell,
    average the pixels of each cell with ``proj_odysea`` and decode the result
    with xdggs.

    Parameters:
    - dt: xarray.DataTree
        Product opened with ``xr.open_datatree``. ``dt.other_metadata["horizontal_CRS_code"]``
        is used as the source CRS.
    - res: str
        Resolution suffix (e.g. "10m").
    - chunk: bool
        If true, only one chunk of the group is loaded and projected; otherwise the
        whole resolution group is used.
    - level: int
        HEALPix resolution level (nside = 2**level).
    - chunk_size: int
        Forwarded to ``proj_odysea`` (controls the dask chunking of the output).
    - chunk_y_idx, chunk_x_idx: int
        Which chunk to project when ``chunk`` is true. The defaults (3, 2) reproduce
        the chunk that used to be hard-coded here.

    Returns:
    - xarray.Dataset with the 'Sentinel2' variable on ("bands", "cells"), decoded by xdggs.

    Raises:
    - IndexError if the requested chunk index lies outside the chunk grid.
    """
    import healpy as hp
    from pyproj import Transformer

    # 1. Extract data and grid
    if chunk:
        band_list = get_bands(data_tree=dt, res=res)
        chunk_size_y, chunk_size_x, nb_chunks_y, nb_chunks_x = get_chunk_info(
            data_tree=dt, band=band_list[0], res=res
        )
        # An out-of-range index would silently yield an empty selection and fail
        # later with an unrelated error, so reject it here.
        if not (0 <= chunk_y_idx < nb_chunks_y and 0 <= chunk_x_idx < nb_chunks_x):
            raise IndexError(
                f"chunk index (y={chunk_y_idx}, x={chunk_x_idx}) is outside the "
                f"chunk grid (y={nb_chunks_y}, x={nb_chunks_x})"
            )
        ds = get_chunk(
            data_tree=dt,
            res=res,
            chunk_size_y=chunk_size_y,
            chunk_size_x=chunk_size_x,
            chunk_y_idx=chunk_y_idx,
            chunk_x_idx=chunk_x_idx,
        ).load()
    else:
        ds = dt.measurements.reflectance[f"r{res}"]

    x = ds["x"].values
    y = ds["y"].values
    xx, yy = np.meshgrid(x, y)

    print(f"Coordinate grid shape: {xx.shape}")
    print(f"X range: {x.min():.0f} to {x.max():.0f}")
    print(f"Y range: {y.min():.0f} to {y.max():.0f}")

    # 2. Transform the product CRS to lon/lat (always_xy=True keeps the x, y -> lon, lat order)
    utm_crs = dt.other_metadata["horizontal_CRS_code"]
    transformer = Transformer.from_crs(utm_crs, "EPSG:4326", always_xy=True)
    lon, lat = transformer.transform(xx, yy)

    # 3. Generate HEALPix indices (nested scheme) and the pixel -> unique-cell lookup
    nside = 2 ** level
    idx = hp.ang2pix(nside, lon, lat, lonlat=True, nest=True)
    lidx, ilidx = np.unique(idx, return_inverse=True)

    print(f"HEALPix Level {level} → {len(lidx):,} unique cells")

    # 4. Aggregate the pixels of every cell
    proj = proj_odysea(level, lidx, ilidx, nest=True, chunk_size=chunk_size)
    ds_healpix = proj.eval(ds.to_dataset())

    # 5. Let xdggs decode the cell-id coordinate
    ds_healpix = ds_healpix.pipe(xdggs.decode)

    return ds_healpix

In [9]:
ds_healpix = healpix_projection(dt=dt, res="10m", chunk=True, level=19, chunk_size=4096)

Chunk size: y=1830, x=1830
Number of chunks: y=6, x=6
Coordinate grid shape: (1830, 1830)
X range: 736565 to 754855
Y range: 5126825 to 5145115
HEALPix Level 19 → 2,159,668 unique cells
Processing 4 variables: ['b02', 'b03', 'b04', 'b08']
Processing b02 (1/4)
Processing b03 (2/4)
Processing b04 (3/4)
Processing b08 (4/4)
HEALPix conversion complete - Level 19, 2,159,668 cells


In [10]:
import lonboard

def create_arrow_table(polygons, arr, coords=None):
    """
    Assemble an arro3 Table with one row per cell.

    Columns, in order: "geometry" (built from ``polygons``), "cell_ids", the data
    column (named ``arr.name``, or "data" when the array is unnamed), then one column
    per entry of ``coords`` that exists in ``arr.coords``.

    Parameters:
    - polygons: Arrow-compatible geometry array accepted by ``Array.from_arrow``.
    - arr: array-like object exposing ``name``, ``data`` and ``coords``.
    - coords: optional list of extra coordinate names; defaults to
        ["latitude", "longitude"].

    Returns:
    - arro3.core.Table
    """
    from arro3.core import Array, ChunkedArray, Schema, Table

    extra_coords = ["latitude", "longitude"] if coords is None else coords

    geometry = Array.from_arrow(polygons)
    value_column = arr.name or "data"

    columns = {
        "geometry": geometry,
        "cell_ids": ChunkedArray([Array.from_numpy(arr.coords["cell_ids"])]),
        value_column: ChunkedArray([Array.from_numpy(arr.data)]),
    }
    # Optional coordinates are appended only when the array actually carries them
    for coord_name in extra_coords:
        if coord_name in arr.coords:
            columns[coord_name] = ChunkedArray(
                [Array.from_numpy(arr.coords[coord_name].data)]
            )

    # Every field is renamed after the column it belongs to
    schema = Schema([column.field.with_name(label) for label, column in columns.items()])

    return Table.from_arrays(list(columns.values()), schema=schema)


def exploire_layer(
    arr,
    cell_dim="cells",
    cmap="viridis",
    center=None,
    alpha=None,
):
    """
    Build a lonboard SolidPolygonLayer that draws one filled polygon per cell.

    Parameters:
    - arr: one-dimensional array along ``cell_dim`` exposing the ``dggs`` accessor.
    - cell_dim: name of the cell dimension.
    - cmap: name of a matplotlib colormap.
    - center: accepted but currently unused (the normalisation step is disabled).
    - alpha: forwarded to ``apply_continuous_cmap``.

    Returns:
    - lonboard.SolidPolygonLayer

    Raises:
    - ValueError if ``arr`` is not one-dimensional along ``cell_dim``.
    """
    from lonboard import SolidPolygonLayer
    from lonboard.colormap import apply_continuous_cmap
    from matplotlib import colormaps

    has_only_cell_dim = len(arr.dims) == 1 and cell_dim in arr.dims
    if not has_only_cell_dim:
        raise ValueError(
            f"exploration only works with a single dimension ('{cell_dim}')"
        )

    # Cell outlines come from the grid description attached to the array
    ids = arr.dggs.coord.data
    boundaries = arr.dggs.grid_info.cell_boundaries(ids, backend="geoarrow")

    # Values are passed to the colormap as they are (no normalisation is applied)
    fill_colors = apply_continuous_cmap(arr.variable, colormaps[cmap], alpha=alpha)

    return SolidPolygonLayer(
        table=create_arrow_table(boundaries, arr),
        filled=True,
        get_fill_color=fill_colors,
    )

In [11]:
# # Use of tanh to concentrate the scale variation for the lower values
# lonboard.Map(
#     [
#         exploire_layer(
#             ds_healpix.Sentinel2.sel(bands=band_list[-1]).compute(),
#             alpha=0.80,
#             cmap='viridis'
#         )
#     ]
# )

In [12]:
ds_healpix.cell_ids.values

array([185483525589, 185483525590, 185483525591, ..., 186919324224,
       186919324288, 186919324289], shape=(2159668,))

In [13]:
import healpy as hp
def get_closest_neighbor(cell_id, level):
    """
    Return the neighbours of a HEALPix cell in the nested scheme.

    Parameters:
    - cell_id: cell id (or array of ids) in nested indexing.
    - level: int, HEALPix resolution level (nside = 2**level).

    Returns:
    - The result of ``hp.get_all_neighbours`` for that cell, unfiltered: ids are
      returned whether or not they are present in the dataset.
    """
    return hp.get_all_neighbours(2 ** level, cell_id, nest=True)

In [14]:
cell_id = ds_healpix.cell_ids.values[0]

In [15]:
# HEALPix resolution level; same value as the one passed to healpix_projection above
level = 19

# All neighbours of the first cell, whether or not they exist in the product
neighbors = get_closest_neighbor(cell_id, level)

# Convert ds_healpix.cell_ids to a set for fast lookup
available_cells = set(ds_healpix.cell_ids.values)

# Centre cell first, followed by only those neighbours that are present in the product
valid_neighbors = [cell_id] + [n for n in neighbors if n in available_cells]
selected_ds_healpix = ds_healpix.sel(cell_ids=valid_neighbors)

selected_ds_healpix

Unnamed: 0,Array,Chunk
Bytes,48 B,48 B
Shape,"(6,)","(6,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 48 B 48 B Shape (6,) (6,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",6  1,

Unnamed: 0,Array,Chunk
Bytes,48 B,48 B
Shape,"(6,)","(6,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,192 B,192 B
Shape,"(4, 6)","(4, 6)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 192 B 192 B Shape (4, 6) (4, 6) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",6  4,

Unnamed: 0,Array,Chunk
Bytes,192 B,192 B
Shape,"(4, 6)","(4, 6)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [16]:
neighbors

array([185483525588, 185483525590, 185483525591, 185484224642,
       185484224640, 185484224554, 185483525503, 185483525502])

- The HEALPix product here is generated at level 19.
- The first goal is to build a convolution layer that iterates over the product's `cell_ids` index.
- Call `get_closest_neighbor` to get the 8 nearest neighbours as a list and add the central cell_id to it. In total the list shall contain 9 cell_ids, which index into the product.
- With these indices we then select the data on which to apply a 2D convolution.
- The goal is to do this with a stride of 1 and to mitigate the border effect, as a standard PyTorch 2D convolution does through padding.


## Build neighbour index

The build_neighbor_index function constructs a consistent 9-cell neighbourhood for each HEALPix cell in a dataset at a given resolution level. For each input cell_id, it uses the HEALPix library (Healpy) to retrieve the 8 adjacent neighboring cells using nested indexing. These 8 neighbors, together with the central cell, form a 3×3 patch structure—mimicking the receptive field of a 2D convolutional kernel.

While the HEALPix sphere is globally continuous, a real-world dataset may only cover a limited subset of it. As a result, some of the 8 neighbors returned by Healpy may not be present in the dataset. To ensure consistent patch size across all cells, the function performs an intersection between the neighbor list and the available cells in the dataset. Any missing or invalid neighbors are replaced with the central cell ID, effectively simulating 'same' padding at the borders.

This approach guarantees that each cell is associated with a 9-element patch (center + 8 neighbors), making the data suitable for spatial operations such as spherical convolutions. The output is a 2D NumPy array of shape (N_cells, 9), where each row corresponds to one patch of cell indices ready to be used in downstream convolutional models.


### Information about padding system

In convolutional operations (like in a CNN), we often apply a filter (or "kernel") over a neighborhood of pixels (e.g., a 3×3 grid). But when the filter reaches the edges of the data, it lacks some neighboring values. To avoid losing spatial coverage or changing the output size, padding is applied — typically by adding zeros or repeating values around the edges.

In our HEALPix-based spherical data:

- Each cell has up to 8 neighboring cells.
- Cells near the "edges" (e.g., at poles or borders in the pixelization) might not have all 8 valid neighbors.
- HEALPix marks these missing neighbors with -1.


In [19]:
ds_healpix = healpix_projection(dt=dt, res="10m", chunk=True, level=19, chunk_size=4096)

Chunk size: y=1830, x=1830
Number of chunks: y=6, x=6
Coordinate grid shape: (1830, 1830)
X range: 736565 to 754855
Y range: 5126825 to 5145115
HEALPix Level 19 → 2,159,668 unique cells
Processing 4 variables: ['b02', 'b03', 'b04', 'b08']
Processing b02 (1/4)
Processing b03 (2/4)
Processing b04 (3/4)
Processing b08 (4/4)
HEALPix conversion complete - Level 19, 2,159,668 cells


In [20]:
def build_neighbor_index(cell_ids, level, available_cell_ids, stride=1):
    """
    Construct a 9-cell neighborhood (center + 8 neighbors) for a subset of HEALPix cell IDs
    using a configurable stride, keeping only neighbors that are actually present.

    The neighbours of all centre cells are requested from healpy in a single
    vectorised call instead of one call per cell. Any neighbour that healpy reports
    as missing (-1) or that is absent from ``available_cell_ids`` is replaced by the
    centre cell id, so every row always has exactly 9 entries.

    Parameters:
    -----------
    cell_ids : array-like
        Full list of HEALPix cell IDs (nested indexing).

    level : int
        HEALPix resolution level. Used to compute NSIDE (NSIDE = 2^level).

    available_cell_ids : set, list or np.ndarray
        HEALPix cell IDs that exist in the dataset.

    stride : int, optional
        Step size to sample center cells (default: 1, use all cell_ids).
        Higher values skip more cells, reducing the output size.

    Returns:
    --------
    np.ndarray
        Integer array of shape (N_patches, 9). Column 0 is the centre cell; columns
        1-8 follow the neighbour order returned by ``hp.get_all_neighbours``.
        An empty input yields an array of shape (0, 9).

    Raises:
    -------
    ValueError
        If ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    nside = 2 ** level

    # --- Apply stride to center cell list ---
    center_cells = np.asarray(cell_ids)[::stride]
    if center_cells.size == 0:
        return np.empty((0, 9), dtype=np.int64)

    # np.isin needs an array; a set cannot be passed to it directly
    if isinstance(available_cell_ids, np.ndarray):
        available = available_cell_ids.ravel()
    else:
        available = np.fromiter(available_cell_ids, dtype=np.int64)

    # healpy returns shape (8, N) for N input cells; transpose to one row per centre
    neighbors = np.asarray(hp.get_all_neighbours(nside, center_cells, nest=True)).T

    # Replace invalid (-1) or missing neighbours with the centre cell
    present = (neighbors != -1) & np.isin(neighbors, available)
    padded = np.where(present, neighbors, center_cells[:, None])

    # Center + 8 neighbors
    return np.concatenate([center_cells[:, None], padded], axis=1)

In [21]:
cell_ids = ds_healpix.cell_ids.values

In [22]:
# Cell ids of the projected product
cell_ids = ds_healpix.cell_ids.values
# Every cell of the product is a candidate neighbour
available_cells = set(cell_ids)
# One 9-element patch (centre + 8 neighbours) per cell; stride=1 keeps every cell as a centre
neighbor_index = build_neighbor_index(cell_ids, level=19, available_cell_ids=available_cells, stride=1)

In [204]:
neighbor_index.shape

(2159668, 9)

In [None]:
# Convert ds_healpix.cell_ids to a set for fast lookup
available_cells = set(ds_healpix.cell_ids.values)
# First patch of the neighbour index. Missing neighbours were already replaced by the
# centre id in build_neighbor_index, so the centre id can appear several times here.
valid_neighbors = [n for n in neighbor_index[0] if n in available_cells]
selected_ds_healpix = ds_healpix.sel(cell_ids=valid_neighbors)
# Values of the last band of band_list for the cells of that patch
selected_ds_healpix.Sentinel2.sel(bands=band_list[-1]).compute().values

array([0.3097    , 0.3097    , 0.367     , 0.3532    , 0.34216667,
       0.3431    , 0.22      , 0.3097    , 0.3097    ])

In [225]:
x = selected_ds_healpix.Sentinel2.sel(bands=band_list[-1]).compute().values
x

array([0.3097    , 0.3097    , 0.367     , 0.3532    , 0.34216667,
       0.3431    , 0.22      , 0.3097    , 0.3097    ])

In [31]:
selected_ds_healpix.Sentinel2.sel(bands=band_list[-1]).cell_ids.values

array([185483525589, 185483525590, 185483525591, 185484224642,
       185484224640, 185484224554])

In [None]:
import torch
from notebook.SphericConv import RegionalSphericalConv

In [None]:
def example_with_your_data():
    """
    Run one forward pass of RegionalSphericalConv on the projected product.

    Reads the notebook-level variables ``dt`` and ``ds_healpix``, stacks every band
    of the 10 m group into a [n_bands, n_cells] array, builds a layer with 64 output
    channels at level 19 and applies it without tracking gradients.

    Returns:
    - (conv_layer, x_tensor, output): the layer, the input tensor of shape
      [1, n_bands, n_cells] and the output tensor of shape [1, 64, n_patches].
    """
    print("=== Setting up Spherical Convolution ===")

    # Bands available at 10 m
    band_list = get_bands(data_tree=dt, res="10m")
    print(f"Available bands: {band_list}")

    # Cells present in the projected product
    available_cell_ids = ds_healpix.cell_ids.values
    print(f"Number of available HEALPix cells: {len(available_cell_ids)}")

    # One row per band: [n_bands, n_cells]
    per_band_values = [
        ds_healpix.Sentinel2.sel(bands=band).compute().values for band in band_list
    ]
    x_multi_band = np.stack(per_band_values, axis=0)
    print(f"Multi-band data shape: {x_multi_band.shape}")

    n_bands = len(band_list)
    conv_layer = RegionalSphericalConv(
        available_cell_ids=available_cell_ids,
        level=19,
        in_channels=n_bands,
        out_channels=64,
        stride=1,
    )

    print(f"Created conv layer with {n_bands} input channels, 64 output channels")
    print(f"Number of patches that will be generated: {conv_layer.n_patches}")

    # Add the batch dimension: [1, n_bands, n_cells]
    x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
    print(f"Input tensor shape: {x_tensor.shape}")

    # Demonstration only, so gradients are not tracked
    with torch.no_grad():
        output = conv_layer(x_tensor)
        print(f"Output tensor shape: {output.shape}")

    # Back to numpy, without the batch dimension: [64, n_patches]
    output_np = output.squeeze(0).numpy()
    print(f"Output as numpy array: {output_np.shape}")

    return conv_layer, x_tensor, output


# Example with different stride values
def example_with_stride():
    """
    Print how many patches RegionalSphericalConv generates for strides 1, 2, 4 and 8.

    Reads the notebook-level variables ``dt`` and ``ds_healpix``. A new layer
    (32 output channels, level 19) is built for every stride; nothing is returned.
    """
    available_cell_ids = ds_healpix.cell_ids.values
    band_list = get_bands(data_tree=dt, res="10m")
    n_cells = len(available_cell_ids)

    print("\n=== Stride Comparison ===")

    for stride in (1, 2, 4, 8):
        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cell_ids,
            level=19,
            in_channels=len(band_list),
            out_channels=32,
            stride=stride,
        )

        # How many times fewer patches than cells
        reduction_factor = n_cells / conv_layer.n_patches
        print(f"Stride {stride}: {conv_layer.n_patches} patches "
              f"({reduction_factor:.1f}x reduction from {n_cells} cells)")

In [38]:
import torch

conv_layer, input_tensor, output = example_with_your_data()

=== Setting up Spherical Convolution ===
Available bands: ['b02', 'b03', 'b04', 'b08']
Number of available HEALPix cells: 2159668
Multi-band data shape: (4, 2159668)
Created conv layer with 4 input channels, 64 output channels
Number of patches that will be generated: 2159668
Input tensor shape: torch.Size([1, 4, 2159668])
Output tensor shape: torch.Size([1, 64, 2159668])
Output as numpy array: (64, 2159668)


In [39]:
example_with_stride()


=== Stride Comparison ===
Stride 1: 2159668 patches (1.0x reduction from 2159668 cells)
Stride 2: 1079834 patches (2.0x reduction from 2159668 cells)
Stride 4: 539917 patches (4.0x reduction from 2159668 cells)
Stride 8: 269959 patches (8.0x reduction from 2159668 cells)


In [None]:
# Cells present in the projected product
available_cell_ids = ds_healpix.cell_ids.values

# `level` is the notebook-level variable assigned in an earlier cell (19);
# in_channels=4 matches the four bands listed for the 10 m group.
conv_layer = RegionalSphericalConv(
    available_cell_ids=available_cell_ids,
    level=level,
    in_channels=4,
    out_channels=64,
    stride=1
)

## **RegionalSphericalConv.__init__()** 
This is the constructor that sets up the spherical convolution layer for your regional HEALPix data. Unlike the original implementation that assumes a full sphere, this version works with only the HEALPix cells present in your dataset. It takes your `available_cell_ids` (the cells that actually contain data in your geographic region) and the HEALPix resolution level. The constructor builds the neighbor relationships, creates efficient lookup tables, and initializes the underlying 1D convolution layer that will process the 3×3 patches. The key innovation here is that it adapts to your specific data coverage rather than assuming global coverage.

## **_build_neighbor_index()** 
This function implements your exact strategy for creating 3×3 patches on the sphere. For each available cell in your dataset (optionally subsampled by stride), it uses HEALPix's `get_all_neighbours()` to find the 8 surrounding cells. However, since your data only covers a specific region, some of these neighbors might not exist in your dataset. The function handles this by replacing any missing neighbors with the center cell ID itself - effectively implementing "same padding" at the boundaries. This ensures every patch has exactly 9 elements (center + 8 neighbors), making it compatible with standard CNN operations while preserving the spherical topology.

## **_convert_to_data_indices()** 
This function performs a crucial optimization for computational efficiency. While `_build_neighbor_index()` works with HEALPix cell IDs (which can be large, sparse numbers), your actual data array uses sequential indices from 0 to N-1. This function creates a mapping that converts the HEALPix cell IDs in each patch to the corresponding positions in your data array. This pre-computation means that during the forward pass, you can directly index into your data tensor without expensive lookups, making the convolution much faster.

## **forward()** 
This is where the actual spherical convolution happens. The function takes your input tensor of shape [batch, channels, cells] and extracts all the 3×3 patches simultaneously using advanced indexing. It reshapes these patches into a format suitable for PyTorch's 1D convolution (which treats each 9-element patch as a "sequence"), applies the learned convolutional weights, and returns the feature maps. The beauty of this approach is that it maintains the spherical neighborhood relationships while leveraging standard CNN operations, allowing you to use existing deep learning frameworks efficiently.

## **example_with_your_data()** 
This comprehensive example demonstrates the complete workflow with your actual HEALPix data. It starts by extracting all spectral bands from your `ds_healpix` dataset, stacks them into a multi-channel tensor (just like RGB channels in regular images), creates the spherical convolution layer, and performs a forward pass. The example shows how your satellite data with multiple spectral bands gets transformed into feature maps that respect the spherical geometry. It also includes practical details like tensor shapes and data type conversions needed for PyTorch.

## **example_with_stride()** 
This function illustrates how the stride parameter affects computational efficiency and spatial resolution. By using different stride values (1, 2, 4, 8), you can control how many patches are generated - stride=1 creates a patch for every available cell, while stride=4 creates patches for every 4th cell, reducing computation by 4×. This is particularly useful for multi-scale processing or when you need to balance computational resources with spatial detail. The function shows the trade-off between spatial resolution and computational efficiency.

The key advantage of this approach is that it seamlessly integrates your regional HEALPix strategy with standard deep learning, allowing you to build sophisticated spherical CNNs that work efficiently on your satellite data while respecting the underlying spherical geometry of Earth's surface.

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
from collections import defaultdict

# First, let's include the RegionalSphericalConv class from the previous artifact
class RegionalSphericalConv(nn.Module):
    """
    Convolution over 9-cell HEALPix neighbourhoods for data covering only part of the sphere.

    For every centre cell (every ``stride``-th available cell) a patch of 9 cell ids is
    built: the centre followed by its 8 HEALPix neighbours. A neighbour that healpy
    reports as missing (-1) or that is not among the available cells is replaced by
    the centre cell. The patches are gathered from the input and fed to a ``Conv1d``
    with kernel size 9 and stride 9, so each patch produces one output position.

    Attributes set by the constructor:
    - available_cell_ids: np.ndarray of the cell ids, in input order.
    - available_cell_set: set built from the ids passed in.
    - neighbor_indices: np.ndarray (n_patches, 9) of cell ids.
    - n_patches: number of patches (output positions).
    - cell_to_data_idx: dict mapping cell id -> position in ``available_cell_ids``.
    - data_neighbor_indices: long tensor (n_patches, 9) of positions into the cell
        axis of the input. Registered as a non-persistent buffer so that it follows
        the module across devices without appearing in ``state_dict``.
    - conv: the underlying ``nn.Conv1d``.
    """

    def __init__(self, available_cell_ids, level, in_channels, out_channels, bias=True, nest=True, stride=1):
        """
        Regional Spherical Convolutional layer for HEALPix data covering a specific area.

        Parameters:
        - available_cell_ids: sequence of HEALPix cell ids present in the data; the
            order must match the cell axis of the tensors passed to ``forward``.
        - level: int, HEALPix resolution level (NSIDE = 2**level).
        - in_channels, out_channels: channel counts of the convolution.
        - bias: whether the convolution has a bias term (initialised to 0).
        - nest: True for nested indexing, False for ring.
        - stride: keep every ``stride``-th available cell as a patch centre.

        Raises:
        - ValueError if ``stride`` is smaller than 1.
        """
        super().__init__()

        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride}")

        self.level = level
        self.NSIDE = 2 ** level
        self.nest = nest
        self.stride = stride
        self.available_cell_ids = np.array(available_cell_ids)
        self.available_cell_set = set(available_cell_ids)

        # Patches expressed as HEALPix cell ids
        self.neighbor_indices = self._build_neighbor_index()
        self.n_patches = self.neighbor_indices.shape[0]

        # Create cell_id to data_index mapping
        self.cell_to_data_idx = {cell_id: i for i, cell_id in enumerate(self.available_cell_ids)}

        # Patches expressed as positions on the cell axis. A buffer (rather than a
        # plain attribute) moves with .to()/.cuda(); persistent=False keeps the
        # state_dict keys unchanged.
        self.register_buffer(
            "data_neighbor_indices", self._convert_to_data_indices(), persistent=False
        )

        # 1D convolution with kernel size 9 (3x3 patch flattened); stride 9 makes
        # each window cover exactly one patch
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=9, stride=9, bias=bias)

        # Initialize weights
        nn.init.kaiming_normal_(self.conv.weight)
        if bias:
            nn.init.constant_(self.conv.bias, 0.0)

    def _build_neighbor_index(self):
        """
        Build the (n_patches, 9) array of cell ids: centre + 8 neighbours.

        All neighbours are requested from healpy in one vectorised call. Missing
        (-1) or unavailable neighbours are replaced by the centre cell.
        """
        center_cells = self.available_cell_ids[::self.stride]
        if center_cells.size == 0:
            return np.empty((0, 9), dtype=np.int64)

        # healpy returns shape (8, N); transpose to one row per centre cell
        neighbors = np.asarray(
            hp.get_all_neighbours(self.NSIDE, center_cells, nest=self.nest)
        ).T

        present = (neighbors != -1) & np.isin(neighbors, self.available_cell_ids)
        padded = np.where(present, neighbors, center_cells[:, None])

        return np.concatenate([center_cells[:, None], padded], axis=1)

    def _convert_to_data_indices(self):
        """
        Convert the cell ids of ``neighbor_indices`` to positions in ``available_cell_ids``.

        Uses a sorted lookup instead of a per-element dictionary access. If an id
        occurs more than once in ``available_cell_ids``, its last position is used,
        which is what the ``cell_to_data_idx`` dictionary holds as well.

        Raises:
        - KeyError if a patch contains an id that is not in ``available_cell_ids``.
        """
        patches = np.asarray(self.neighbor_indices)
        if patches.size == 0:
            return torch.zeros(patches.shape, dtype=torch.long)

        # Stable sort + side="right" - 1 selects the last occurrence of each id
        order = np.argsort(self.available_cell_ids, kind="stable")
        sorted_ids = self.available_cell_ids[order]
        slots = np.searchsorted(sorted_ids, patches, side="right") - 1

        found = sorted_ids[slots]
        if not np.array_equal(found, patches):
            raise KeyError(patches[found != patches][0])

        return torch.tensor(order[slots], dtype=torch.long)

    def forward(self, x):
        """
        Apply the convolution.

        Parameters:
        - x: tensor of shape [batch, in_channels, n_cells], where n_cells equals the
            number of available cells given to the constructor.

        Returns:
        - tensor of shape [batch, out_channels, n_patches].
        """
        batch_size, n_channels, n_cells = x.shape

        # Ensure we have the right number of cells
        assert n_cells == len(self.available_cell_ids), \
            f"Expected {len(self.available_cell_ids)} cells, got {n_cells}"

        # Gather the patches: [B, C_in, N_patches, 9]
        patches = x[:, :, self.data_neighbor_indices]

        # Flatten to [B, C_in, N_patches * 9] for Conv1d; reshape also copes with
        # non-contiguous results
        patches_flat = patches.reshape(batch_size, n_channels, -1)

        return self.conv(patches_flat)

# Test Suite
class SphericalConvTestSuite:
    """Self-contained checks for ``RegionalSphericalConv``.

    Each ``test_*`` method builds a small layer, asserts a property and, on
    success, stores ``True`` in ``self.test_results`` under a short key.
    ``run_all_tests`` runs every check, records failures under the same short
    key, and prints a summary.

    Relies on ``RegionalSphericalConv``, ``hp`` (healpy) and ``torch`` being
    defined in the notebook namespace.
    """

    def __init__(self):
        # Maps a short test key (e.g. 'stride') to True (passed) / False (failed).
        self.test_results = {}
        # When False, log() prints nothing.
        self.verbose = True

    def log(self, message):
        """Print ``message`` if verbose output is enabled."""
        if self.verbose:
            print(message)

    def test_neighbor_index_logic(self):
        """Test 1: Verify neighbor index construction logic"""
        self.log("\n=== TEST 1: Neighbor Index Logic ===")

        # Create a small test case with known HEALPix cells
        level = 3  # Small for testing
        nside = 2 ** level

        # Create a small region of cells: one center plus its valid neighbours.
        center_cell = 100
        all_neighbors = hp.get_all_neighbours(nside, center_cell, nest=True)
        # Cast to plain ints so the logged list is readable.
        available_cells = [center_cell] + [int(n) for n in all_neighbors if n != -1]

        self.log(f"Test region: center={center_cell}, neighbors={all_neighbors}")
        self.log(f"Available cells: {available_cells}")

        # Create conv layer
        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cells,
            level=level,
            in_channels=3,
            out_channels=16,
            stride=1
        )

        # Check neighbor indices
        neighbor_indices = conv_layer.neighbor_indices
        self.log(f"Generated {len(neighbor_indices)} patches")

        # Verify each patch has 9 elements
        for i, patch in enumerate(neighbor_indices):
            assert len(patch) == 9, f"Patch {i} has {len(patch)} elements, expected 9"
            # Center should be first element
            center = patch[0]
            assert center in available_cells, f"Center {center} not in available cells"

        self.log("✅ Neighbor index logic test passed!")
        self.test_results['neighbor_index'] = True

    def test_boundary_padding(self):
        """Test 2: Verify boundary padding with missing neighbors"""
        self.log("\n=== TEST 2: Boundary Padding ===")

        level = 4

        # Create a sparse region where many neighbors will be missing
        sparse_cells = [1000, 1001, 1050]  # Scattered cells

        conv_layer = RegionalSphericalConv(
            available_cell_ids=sparse_cells,
            level=level,
            in_channels=1,
            out_channels=8,
            stride=1
        )

        # Check that missing neighbors are replaced with center
        neighbor_indices = conv_layer.neighbor_indices

        for i, patch in enumerate(neighbor_indices):
            center = patch[0]
            neighbors = patch[1:]

            # Count how many neighbors are the center (indicating padding)
            padding_count = sum(1 for n in neighbors if n == center)

            # For sparse data, we expect some padding
            self.log(f"Patch {i}: center={center}, padding={padding_count}/8 neighbors")

            # Verify all elements are valid
            for cell_id in patch:
                assert cell_id in sparse_cells, f"Invalid cell {cell_id} in patch"

        self.log("✅ Boundary padding test passed!")
        self.test_results['boundary_padding'] = True

    def test_stride_functionality(self):
        """Test 3: Verify stride reduces number of patches correctly"""
        self.log("\n=== TEST 3: Stride Functionality ===")

        level = 3
        # Create a larger region
        available_cells = list(range(100, 200))  # 100 cells

        for stride in [1, 2, 4, 5]:
            conv_layer = RegionalSphericalConv(
                available_cell_ids=available_cells,
                level=level,
                in_channels=2,
                out_channels=10,
                stride=stride
            )

            # Centers are selected with cells[::stride], which yields
            # ceil(n / stride) entries, so the count can be checked exactly.
            expected_patches = -(-len(available_cells) // stride)
            actual_patches = conv_layer.n_patches

            self.log(f"Stride {stride}: {actual_patches} patches (expected ~{expected_patches})")

            assert actual_patches == expected_patches, \
                f"Stride {stride}: got {actual_patches}, expected {expected_patches}"

        self.log("✅ Stride functionality test passed!")
        self.test_results['stride'] = True

    def test_forward_pass_shapes(self):
        """Test 4: Verify forward pass produces correct output shapes"""
        self.log("\n=== TEST 4: Forward Pass Shapes ===")

        level = 4
        available_cells = list(range(500, 600))  # 100 cells
        batch_size = 2
        in_channels = 5
        out_channels = 16

        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cells,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1
        )

        # Create input tensor
        n_cells = len(available_cells)
        input_tensor = torch.randn(batch_size, in_channels, n_cells)

        self.log(f"Input shape: {input_tensor.shape}")

        # Forward pass
        output = conv_layer(input_tensor)

        expected_output_shape = (batch_size, out_channels, conv_layer.n_patches)
        actual_output_shape = output.shape

        self.log(f"Expected output shape: {expected_output_shape}")
        self.log(f"Actual output shape: {actual_output_shape}")

        assert actual_output_shape == expected_output_shape, \
            f"Shape mismatch: got {actual_output_shape}, expected {expected_output_shape}"

        self.log("✅ Forward pass shapes test passed!")
        self.test_results['forward_shapes'] = True

    def test_patch_extraction_correctness(self):
        """Test 5: Verify patch extraction extracts correct values"""
        self.log("\n=== TEST 5: Patch Extraction Correctness ===")

        level = 3
        available_cells = [100, 101, 102, 103, 104]

        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cells,
            level=level,
            in_channels=1,
            out_channels=4,
            stride=1
        )

        # Create input with known values: the value stored at data index i is i,
        # so a gathered value directly reveals which index it came from.
        input_data = torch.arange(len(available_cells), dtype=torch.float32)
        input_tensor = input_data.unsqueeze(0).unsqueeze(0)  # [1, 1, 5]

        self.log(f"Input data: {input_data}")
        self.log(f"Cell to data mapping: {conv_layer.cell_to_data_idx}")

        # Extract patches manually to verify
        patches = input_tensor[:, :, conv_layer.data_neighbor_indices]

        self.log(f"Extracted patches shape: {patches.shape}")

        # Check first patch
        first_patch = patches[0, 0, 0, :].numpy()
        self.log(f"First patch values: {first_patch}")

        # Verify patch corresponds to correct neighbor indices
        first_patch_cell_ids = conv_layer.neighbor_indices[0]
        expected_values = [conv_layer.cell_to_data_idx[cell_id] for cell_id in first_patch_cell_ids]

        self.log(f"First patch cell IDs: {first_patch_cell_ids}")
        self.log(f"Expected data indices: {expected_values}")

        # The patch should contain the data values at those indices
        for i, expected_idx in enumerate(expected_values):
            assert first_patch[i] == expected_idx, \
                f"Patch position {i}: got {first_patch[i]}, expected {expected_idx}"

        self.log("✅ Patch extraction correctness test passed!")
        self.test_results['patch_extraction'] = True

    def test_gradient_flow(self):
        """Test 6: Verify gradients flow correctly through the layer"""
        self.log("\n=== TEST 6: Gradient Flow ===")

        level = 4
        available_cells = list(range(200, 250))

        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cells,
            level=level,
            in_channels=3,
            out_channels=8,
            stride=1
        )

        # Create input that requires gradients
        input_tensor = torch.randn(1, 3, len(available_cells), requires_grad=True)

        # Forward pass
        output = conv_layer(input_tensor)

        # Backward pass
        loss = output.sum()
        loss.backward()

        # Check that gradients exist
        assert input_tensor.grad is not None, "No gradients computed for input"
        assert conv_layer.conv.weight.grad is not None, "No gradients computed for conv weights"

        self.log(f"Input gradient shape: {input_tensor.grad.shape}")
        self.log(f"Conv weight gradient shape: {conv_layer.conv.weight.grad.shape}")

        # Check gradient magnitudes are reasonable
        input_grad_norm = input_tensor.grad.norm().item()
        weight_grad_norm = conv_layer.conv.weight.grad.norm().item()

        assert input_grad_norm > 0, "Input gradients are zero"
        assert weight_grad_norm > 0, "Weight gradients are zero"

        self.log(f"Input gradient norm: {input_grad_norm:.6f}")
        self.log(f"Weight gradient norm: {weight_grad_norm:.6f}")

        self.log("✅ Gradient flow test passed!")
        self.test_results['gradient_flow'] = True

    def test_memory_efficiency(self):
        """Test 7: Run a larger strided layer and report tensor sizes.

        Only the batch and channel dimensions of the output are asserted; the
        tensor sizes are logged for information.
        """
        self.log("\n=== TEST 7: Memory Efficiency ===")

        level = 6  # Larger level
        available_cells = list(range(1000, 3000))  # 2000 cells

        conv_layer = RegionalSphericalConv(
            available_cell_ids=available_cells,
            level=level,
            in_channels=10,
            out_channels=32,
            stride=2  # Use stride to reduce memory
        )

        # Create reasonably large input
        input_tensor = torch.randn(4, 10, len(available_cells))  # 4 batches

        # Use the tensor's real element size rather than assuming 4 bytes.
        input_mb = input_tensor.numel() * input_tensor.element_size() / 1024 / 1024
        self.log(f"Input tensor size: {input_mb:.2f} MB")

        # Forward pass
        output = conv_layer(input_tensor)

        output_mb = output.numel() * output.element_size() / 1024 / 1024
        self.log(f"Output tensor size: {output_mb:.2f} MB")
        self.log(f"Number of patches: {conv_layer.n_patches}")

        # Basic sanity check
        assert output.shape[0] == 4, "Batch dimension incorrect"
        assert output.shape[1] == 32, "Channel dimension incorrect"

        self.log("✅ Memory efficiency test passed!")
        self.test_results['memory_efficiency'] = True

    def run_all_tests(self):
        """Run all tests and provide summary.

        A test that raises is logged and recorded as ``False`` under the same
        short key the test itself uses on success, so ``test_results`` holds
        exactly one entry per test. Exceptions are not re-raised.
        """
        self.log("🚀 Starting Spherical Convolution Test Suite")

        # (result key, test callable) pairs; the key matches what each test
        # writes into self.test_results when it passes.
        test_plan = [
            ('neighbor_index', self.test_neighbor_index_logic),
            ('boundary_padding', self.test_boundary_padding),
            ('stride', self.test_stride_functionality),
            ('forward_shapes', self.test_forward_pass_shapes),
            ('patch_extraction', self.test_patch_extraction_correctness),
            ('gradient_flow', self.test_gradient_flow),
            ('memory_efficiency', self.test_memory_efficiency),
        ]

        for result_key, test_method in test_plan:
            try:
                test_method()
            except Exception as e:
                test_name = getattr(test_method, "__name__", result_key)
                self.log(f"❌ {test_name} FAILED: {str(e)}")
                self.test_results[result_key] = False

        self.print_summary()

    def print_summary(self):
        """Print test summary"""
        self.log("\n" + "="*50)
        self.log("TEST SUMMARY")
        self.log("="*50)

        passed = sum(1 for result in self.test_results.values() if result)
        total = len(self.test_results)

        for test_name, result in self.test_results.items():
            status = "✅ PASSED" if result else "❌ FAILED"
            self.log(f"{test_name}: {status}")

        self.log(f"\nOverall: {passed}/{total} tests passed")

        if passed == total:
            self.log("🎉 ALL TESTS PASSED! The spherical convolution implementation is working correctly.")
        else:
            self.log("⚠️  Some tests failed. Please review the implementation.")


# Build the suite and run every test; progress and the summary are printed
# as this cell's output.
test_suite = SphericalConvTestSuite()
test_suite.run_all_tests()

🚀 Starting Spherical Convolution Test Suite

=== TEST 1: Neighbor Index Logic ===
Test region: center=100, neighbors=[ 97  99 102 103 101  79  78  75]
Available cells: [100, np.int64(97), np.int64(99), np.int64(102), np.int64(103), np.int64(101), np.int64(79), np.int64(78), np.int64(75)]
Generated 9 patches
✅ Neighbor index logic test passed!

=== TEST 2: Boundary Padding ===
Patch 0: center=1000, padding=7/8 neighbors
Patch 1: center=1001, padding=7/8 neighbors
Patch 2: center=1050, padding=8/8 neighbors
✅ Boundary padding test passed!

=== TEST 3: Stride Functionality ===
Stride 1: 100 patches (expected ~100)
Stride 2: 50 patches (expected ~50)
Stride 4: 25 patches (expected ~25)
Stride 5: 20 patches (expected ~20)
✅ Stride functionality test passed!

=== TEST 4: Forward Pass Shapes ===
Input shape: torch.Size([2, 5, 100])
Expected output shape: (2, 16, 100)
Actual output shape: torch.Size([2, 16, 100])
✅ Forward pass shapes test passed!

=== TEST 5: Patch Extraction Correctness ===


[W718 08:11:01.210286836 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W718 08:11:01.224200675 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.


In [46]:
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
import torch
from typing import List, Tuple

class SphericalConvVisualTest:
    """Step-by-step, print-based checks of the spherical convolution logic.

    Relies on ``RegionalSphericalConv``, ``hp`` (healpy), ``np`` and ``torch``
    being available in the notebook namespace.
    """

    def __init__(self):
        self.level = 4  # Small level for visualization
        self.nside = 2 ** self.level

    def create_test_region(self, center_lat: float = 45.0, center_lon: float = 10.0,
                          radius_deg: float = 20.0) -> List[int]:
        """Return the sorted NESTED pixel IDs of a disc on the sphere.

        Parameters:
        - center_lat, center_lon: disc center in degrees.
        - radius_deg: disc radius in degrees.

        Returns:
        - Sorted list of pixel IDs returned by ``hp.query_disc`` at ``self.nside``.
        """
        # Get all pixels within radius of the center direction
        vec = hp.ang2vec(center_lon, center_lat, lonlat=True)
        pixels_in_region = hp.query_disc(self.nside, vec, np.radians(radius_deg), nest=True)

        return sorted(pixels_in_region.tolist())

    def visualize_neighbor_logic(self, test_cell_id: int, available_cells: List[int]):
        """Print how the 9-cell patch for ``test_cell_id`` is assembled.

        Neighbours that are ``-1`` or not in ``available_cells`` are replaced by
        the center cell, mirroring the padding rule used by the layer.

        Returns:
        - The patch as a list: center cell followed by its 8 (possibly padded)
          neighbours.
        """

        print(f"\n=== Neighbor Logic Test for Cell {test_cell_id} ===")

        # Get all neighbors using HEALPix
        neighbors = hp.get_all_neighbours(self.nside, test_cell_id, nest=True)
        print(f"Raw neighbors from HEALPix: {neighbors}")

        # Apply the padding strategy
        available_set = set(available_cells)
        valid_neighbors = [
            n if (n != -1 and n in available_set) else test_cell_id
            for n in neighbors
        ]

        patch = [test_cell_id] + valid_neighbors
        print(f"Final patch (center + 8 neighbors): {patch}")

        # Count padding
        padding_count = sum(1 for n in valid_neighbors if n == test_cell_id)
        print(f"Padding applied: {padding_count}/8 neighbors replaced with center")

        # Analyze neighbor availability
        neighbor_analysis = []
        for i, n in enumerate(neighbors):
            if n == -1:
                status = "INVALID (edge of sphere)"
            elif n in available_set:
                status = "AVAILABLE"
            else:
                status = "MISSING (outside region)"
            neighbor_analysis.append(f"  Neighbor {i}: {n} -> {status}")

        print("Neighbor Analysis:")
        for analysis in neighbor_analysis:
            print(analysis)

        return patch

    def test_patch_extraction_step_by_step(self):
        """Walk through patch extraction for the middle cell of a test region.

        Returns:
        - The list of mock data values gathered for that cell's patch.
        """

        print("\n=== Step-by-Step Patch Extraction Test ===")

        # Create small test region
        test_region = self.create_test_region(center_lat=45.0, center_lon=10.0, radius_deg=15.0)
        print(f"Test region has {len(test_region)} cells")

        # Test with a specific cell
        test_cell = test_region[len(test_region)//2]  # Middle cell
        print(f"Testing with cell: {test_cell}")

        # Show neighbor logic
        patch = self.visualize_neighbor_logic(test_cell, test_region)

        # Create mock data
        cell_to_data_idx = {cell_id: i for i, cell_id in enumerate(test_region)}
        mock_data = np.arange(len(test_region)) * 10  # Easy to identify values

        print(f"\nMock data shape: {mock_data.shape}")
        print(f"Mock data values: {mock_data}")

        # Extract patch values
        patch_values = []
        for cell_id in patch:
            data_idx = cell_to_data_idx[cell_id]
            value = mock_data[data_idx]
            patch_values.append(value)
            print(f"  Cell {cell_id} -> data_idx {data_idx} -> value {value}")

        print(f"Extracted patch values: {patch_values}")

        return patch_values

    def test_convolution_mathematics(self):
        """Check the layer output against a hand-computed sum.

        With all convolution weights set to 1 and bias 0, the output for the
        first patch must equal the sum of the input over every channel and
        every index of that patch.
        """

        print("\n=== Convolution Mathematics Test ===")

        # Create simple test case
        test_region = list(range(100, 120))  # 20 cells
        n_channels = 2


        conv_layer = RegionalSphericalConv(
            available_cell_ids=test_region,
            level=self.level,
            in_channels=n_channels,
            out_channels=1,  # Single output for easier verification
            stride=1
        )

        # Create simple input where each cell has its index as value
        input_data = torch.zeros(1, n_channels, len(test_region))
        for i in range(len(test_region)):
            input_data[0, 0, i] = i  # Channel 0: cell index
            input_data[0, 1, i] = i * 2  # Channel 1: 2x cell index

        print(f"Input data shape: {input_data.shape}")
        print(f"Channel 0 values: {input_data[0, 0, :].numpy()}")
        print(f"Channel 1 values: {input_data[0, 1, :].numpy()}")

        # Manually set conv weights for predictable output
        with torch.no_grad():
            # Set weight to sum all inputs in a patch
            conv_layer.conv.weight.fill_(1.0)  # All weights = 1
            if conv_layer.conv.bias is not None:
                conv_layer.conv.bias.fill_(0.0)  # No bias

        # Forward pass
        output = conv_layer(input_data)
        print(f"Output shape: {output.shape}")
        print(f"Output values: {output[0, 0, :].detach().numpy()}")

        # Verify first patch manually
        first_patch_indices = conv_layer.data_neighbor_indices[0]
        print(f"First patch data indices: {first_patch_indices}")

        # Calculate expected output for first patch
        expected_sum = 0
        for channel in range(n_channels):
            for idx in first_patch_indices:
                expected_sum += input_data[0, channel, idx].item()

        actual_output = output[0, 0, 0].item()
        print(f"Expected sum for first patch: {expected_sum}")
        print(f"Actual output for first patch: {actual_output}")

        # They should match (within floating point precision)
        assert abs(expected_sum - actual_output) < 1e-5, \
            f"Mathematics error: expected {expected_sum}, got {actual_output}"

        print("✅ Convolution mathematics verified!")

    def test_boundary_cases(self):
        """Test edge cases and boundary conditions.

        Runs three independent sub-tests (single cell, short chain, large
        stride). Each failure is printed immediately and the remaining
        sub-tests still run; afterwards, if any sub-test failed, a single
        AssertionError naming them is raised so the caller does not report
        overall success.

        Raises:
        - AssertionError: if at least one sub-test failed.
        """

        print("\n=== Boundary Cases Test ===")

        failed = []  # names of sub-tests that did not pass

        # Test 1: Single cell region
        single_cell = [1000]
        try:
            conv_single = RegionalSphericalConv(
                available_cell_ids=single_cell,
                level=self.level,
                in_channels=1,
                out_channels=4,
                stride=1
            )

            # All neighbors should be the center cell
            patch = conv_single.neighbor_indices[0]
            assert all(cell == 1000 for cell in patch), "Single cell test failed"
            print("✅ Single cell region test passed")

        except Exception as e:
            print(f"❌ Single cell test failed: {e}")
            failed.append("single cell")

        # Test 2: Linear chain of cells
        chain_cells = [2000, 2001, 2002]  # Sparse chain
        try:
            conv_chain = RegionalSphericalConv(
                available_cell_ids=chain_cells,
                level=self.level,
                in_channels=1,
                out_channels=4,
                stride=1
            )

            print(f"Chain test: {conv_chain.n_patches} patches created")

            # Check that patches are valid
            for i, patch in enumerate(conv_chain.neighbor_indices):
                assert all(cell in chain_cells for cell in patch), \
                    f"Invalid cell in chain patch {i}"

            print("✅ Chain cells test passed")

        except Exception as e:
            print(f"❌ Chain cells test failed: {e}")
            failed.append("chain cells")

        # Test 3: Very large stride
        try:
            # Use the last 100 valid pixel IDs at this resolution. A fixed
            # range such as 3000..3099 would exceed the 12 * nside**2 pixels
            # that exist at level 4 (3072) and feed invalid IDs to healpy.
            npix = hp.nside2npix(self.nside)
            large_region = list(range(npix - 100, npix))  # 100 cells
            stride = 10

            conv_large_stride = RegionalSphericalConv(
                available_cell_ids=large_region,
                level=self.level,
                in_channels=1,
                out_channels=4,
                stride=stride
            )

            # Strided selection keeps ceil(n / stride) center cells.
            expected_patches = -(-len(large_region) // stride)
            actual_patches = conv_large_stride.n_patches

            print(f"Large stride test: {actual_patches} patches (expected ~{expected_patches})")
            assert actual_patches == expected_patches, "Large stride test failed"
            print("✅ Large stride test passed")

        except Exception as e:
            print(f"❌ Large stride test failed: {e}")
            failed.append("large stride")

        if failed:
            raise AssertionError("Boundary sub-tests failed: " + ", ".join(failed))

    def compare_with_original_implementation(self):
        """Run the regional layer on a full sphere and sanity-check shapes.

        NOTE: no other implementation is invoked here; the check is that a
        region covering every pixel yields one output column per cell.
        """

        print("\n=== Comparison with Original Implementation ===")

        # For this test, we'll create a full sphere region to compare
        # This simulates what the original implementation expects

        level = 2  # Very small for full sphere test
        nside = 2 ** level
        all_cells = list(range(hp.nside2npix(nside)))

        print(f"Full sphere test with {len(all_cells)} cells")

        # Create our regional implementation with full sphere
        our_conv = RegionalSphericalConv(
            available_cell_ids=all_cells,
            level=level,
            in_channels=3,
            out_channels=8,
            stride=1
        )

        # Create test input
        input_tensor = torch.randn(1, 3, len(all_cells))

        # Forward pass
        our_output = our_conv(input_tensor)

        print(f"Our implementation output shape: {our_output.shape}")
        print(f"Number of patches: {our_conv.n_patches}")

        # Basic sanity checks
        assert our_output.shape[0] == 1, "Batch dimension wrong"
        assert our_output.shape[1] == 8, "Channel dimension wrong"
        assert our_output.shape[2] == len(all_cells), "Should have one patch per cell for full sphere"

        print("✅ Full sphere comparison test passed")

    def run_all_visual_tests(self):
        """Run all visual and logic tests; re-raise the first failure."""

        print("🔍 Starting Visual and Logic Tests")
        print("="*60)

        try:
            self.test_patch_extraction_step_by_step()
            self.test_convolution_mathematics()
            self.test_boundary_cases()
            self.compare_with_original_implementation()

            print("\n🎉 All visual and logic tests completed successfully!")

        except Exception as e:
            print(f"\n❌ Visual test failed: {e}")
            raise


# Run the comprehensive test suite first.
# Note: this rebinds the module-level name `test_suite` to a fresh instance.
print("Running comprehensive test suite...")
test_suite = SphericalConvTestSuite()
test_suite.run_all_tests()

# Then run the visual tests; run_all_visual_tests() re-raises on failure,
# so an error here stops the cell.
print("\n" + "="*60)
visual_test = SphericalConvVisualTest()
visual_test.run_all_visual_tests()

Running comprehensive test suite...
🚀 Starting Spherical Convolution Test Suite

=== TEST 1: Neighbor Index Logic ===
Test region: center=100, neighbors=[ 97  99 102 103 101  79  78  75]
Available cells: [100, np.int64(97), np.int64(99), np.int64(102), np.int64(103), np.int64(101), np.int64(79), np.int64(78), np.int64(75)]
Generated 9 patches
✅ Neighbor index logic test passed!

=== TEST 2: Boundary Padding ===
Patch 0: center=1000, padding=7/8 neighbors
Patch 1: center=1001, padding=7/8 neighbors
Patch 2: center=1050, padding=8/8 neighbors
✅ Boundary padding test passed!

=== TEST 3: Stride Functionality ===
Stride 1: 100 patches (expected ~100)
Stride 2: 50 patches (expected ~50)
Stride 4: 25 patches (expected ~25)
Stride 5: 20 patches (expected ~20)
✅ Stride functionality test passed!

=== TEST 4: Forward Pass Shapes ===
Input shape: torch.Size([2, 5, 100])
Expected output shape: (2, 16, 100)
Actual output shape: torch.Size([2, 16, 100])
✅ Forward pass shapes test passed!

=== TEST 

[W718 10:11:55.020261638 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W718 10:11:55.020915754 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W718 10:11:55.045348704 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W718 10:11:55.048977796 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
