In [2]:
import zarr
import xarray as xr
import torch
from SphericConv import RegionalSphericalConv
import numpy as np


def get_bands(data_tree, res):

    res_key = f"r{res}"
    bands = data_tree.measurements.reflectance[res_key]
    return list(bands.keys())

In [3]:
ds_healpix = xr.open_dataset("/home/ubuntu/project/sentinel-2-dggs-ai-processor/src/notebook/healpix.zarr")

  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)


In [4]:
print("=== Setting up Spherical Convolution ===")
# 1. Extract spectral data for all bands
band_list = ds_healpix.Sentinel2.bands.values
print(f"Available bands: {band_list}")

# 2. Get available cell IDs from your dataset
available_cell_ids = ds_healpix.cell_ids.values
print(f"Number of available HEALPix cells: {len(available_cell_ids)}")

# 3. Create input tensor with all spectral bands
# Shape: [n_bands, n_cells]
spectral_data = []
for band in band_list:
    band_data = ds_healpix.Sentinel2.sel(bands=band).compute().values
    spectral_data.append(band_data)

# Stack all bands: [n_bands, n_cells]
x_multi_band = np.stack(spectral_data, axis=0)
print(f"Multi-band data shape: {x_multi_band.shape}")

# 4. Create spherical conv layer
conv_layer = RegionalSphericalConv(
    available_cell_ids=available_cell_ids,
    level=19,
    in_channels=len(band_list),
    out_channels=64,
    stride=2
)

print(f"Created conv layer with {len(band_list)} input channels, 64 output channels")
print(f"Number of patches that will be generated: {conv_layer.n_patches}")

# 5. Convert to PyTorch tensor and add batch dimension
x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
print(f"Input tensor shape: {x_tensor.shape}")  # [1, n_bands, n_cells]

# 6. Forward pass
with torch.no_grad():
    output = conv_layer(x_tensor)
    print(f"Output tensor shape: {output.shape}")  # [1, 64, n_patches]

# 7. Optional: Convert back to numpy for further processing
output_np = output.squeeze(0).numpy()  # Remove batch dimension
print(f"Output as numpy array: {output_np.shape}")  # [64, n_patches]

=== Setting up Spherical Convolution ===
Available bands: ['b02' 'b03' 'b04' 'b08']
Number of available HEALPix cells: 2159668
Multi-band data shape: (4, 2159668)
Created conv layer with 4 input channels, 64 output channels
Number of patches that will be generated: 1079834
Input tensor shape: torch.Size([1, 4, 2159668])
Output tensor shape: torch.Size([1, 64, 1079834])
Output as numpy array: (64, 1079834)


In [None]:
from torch import nn


class Model(nn.Module):
    """
    U-Net architecture for semantic segmentation.

    The U-Net consists of an encoder, a bottleneck, and a decoder.

    Attributes:
    - conv1 to conv9: Double convolutional blocks for encoding and decoding.
    - pool1 to pool4: Max-pooling layers for downsampling in the encoder.
    - tconv1 to tconv4: Transpose convolutional blocks for upsampling in the decoder.
    - out: Single convolutional block for the final output.

    Methods:
    - forward(x): Forward pass through the U-Net.

    """
    def __init__(self):
        super().__init__()


        self.conv1 = RegionalSphericalConv(
            available_cell_ids=available_cell_ids,
            level=19,
            in_channels=len(band_list),
            out_channels=64,
            stride=2
        )
        self.conv2 = RegionalSphericalConv(
            available_cell_ids=available_cell_ids,
            level=19,
            in_channels=len(band_list),
            out_channels=64,
            stride=2
        )

    def forward(self, x):
            """
            Forward pass through the U-Net.

            Parameters:
            - x: Input tensor.

            Returns:
            - out: Output tensor after passing through the U-Net.
            """
            # Encoder
            x = self.conv1(x)
            print(x.shape)

In [11]:
model = Model()

In [12]:
x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
x_tensor.shape

torch.Size([1, 4, 2159668])

In [13]:
out = model(x_tensor)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import healpy as hp

# Your RegionalSphericalConv class (from the document)
class RegionalSphericalConv(nn.Module):
    def __init__(self, available_cell_ids, level, in_channels, out_channels, bias=True, nest=True, stride=1):
        """
        Regional Spherical Convolutional layer for HEALPix data covering a specific area.
        """
        super(RegionalSphericalConv, self).__init__()

        self.level = level
        self.NSIDE = 2 ** level
        self.nest = nest
        self.stride = stride
        self.available_cell_ids = np.array(available_cell_ids)
        self.available_cell_set = set(available_cell_ids)

        # Build neighbor index using your strategy
        self.neighbor_indices = self._build_neighbor_index()
        self.n_patches = self.neighbor_indices.shape[0]

        # Create cell_id to data_index mapping
        self.cell_to_data_idx = {cell_id: i for i, cell_id in enumerate(self.available_cell_ids)}

        # Convert neighbor indices to data indices for efficient lookup
        self.data_neighbor_indices = self._convert_to_data_indices()

        # 1D convolution with kernel size 9 (3x3 patch flattened)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=9, stride=9, bias=bias)

        # Initialize weights
        nn.init.kaiming_normal_(self.conv.weight)
        if bias:
            nn.init.constant_(self.conv.bias, 0.0)

    def _build_neighbor_index(self):
        """Build 9-cell neighborhood index"""
        available_cell_ids = set(self.available_cell_ids)
        neighbor_indices = []

        # Apply stride to center cell list
        center_cells = self.available_cell_ids[::self.stride]

        for cell_id in center_cells:
            neighbors = hp.get_all_neighbours(self.NSIDE, cell_id, nest=self.nest)

            # Validate each neighbor; replace invalid or missing with center
            valid_neighbors = [
                n if (n != -1 and n in available_cell_ids) else cell_id
                for n in neighbors
            ]

            patch = [cell_id] + valid_neighbors  # Center + 8 neighbors
            neighbor_indices.append(patch)

        return np.array(neighbor_indices)

    def _convert_to_data_indices(self):
        """Convert HEALPix cell IDs to data array indices"""
        data_indices = np.zeros_like(self.neighbor_indices)

        for i, patch in enumerate(self.neighbor_indices):
            for j, cell_id in enumerate(patch):
                data_indices[i, j] = self.cell_to_data_idx[cell_id]

        return torch.tensor(data_indices, dtype=torch.long)

    def forward(self, x):
        """Forward pass"""
        batch_size, n_channels, n_cells = x.shape

        # Ensure we have the right number of cells
        assert n_cells == len(self.available_cell_ids), \
            f"Expected {len(self.available_cell_ids)} cells, got {n_cells}"

        # Extract patches using the neighbor indices
        # Shape: [B, C_in, N_patches, 9]
        patches = x[:, :, self.data_neighbor_indices]

        # Reshape to [B, C_in, N_patches * 9] for Conv1d
        patches_flat = patches.view(batch_size, n_channels, -1)

        # Apply convolution
        output = self.conv(patches_flat)

        return output


class SphericalConvBlock(nn.Module):
    """Basic convolutional block with batch norm and ReLU"""

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super(SphericalConvBlock, self).__init__()

        self.conv = RegionalSphericalConv(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class SphericalDoubleConv(nn.Module):
    """Double convolution block (Conv -> BN -> ReLU -> Conv -> BN -> ReLU)"""

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super(SphericalDoubleConv, self).__init__()

        # First conv with specified stride
        self.conv1 = SphericalConvBlock(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride
        )

        # Second conv with stride=1 (operating on the output of first conv)
        # We need to determine the output size after first conv
        n_patches_after_first = len(available_cell_ids) // stride

        self.conv2 = SphericalConvBlock(
            available_cell_ids=available_cell_ids,  # This is approximate
            level=level,
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


# Simple model using just SphericalDoubleConv
class SimpleSphericalModel(nn.Module):
    """Simple model that just applies SphericalDoubleConv"""

    def __init__(self, available_cell_ids, level, in_channels, out_channels, stride=1):
        super(SimpleSphericalModel, self).__init__()

        self.double_conv = SphericalDoubleConv(
            available_cell_ids=available_cell_ids,
            level=level,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride
        )

    def forward(self, x):
        return self.double_conv(x)



In [None]:
# Extract available cell IDs from your dataset
available_cell_ids = ds_healpix.cell_ids.values
level = 19  # Your HEALPix level
in_channels = len(band_list)
stride = 1
print(f"Creating Spherical U-Net:")
print(f"  - Input cells: {len(available_cell_ids)}")
print(f"  - Input channels: {in_channels} (bands: {band_list})")
print(f"  - Output classes: {in_channels}")
print(f"  - HEALPix level: {level}")

Creating Spherical U-Net:
  - Input cells: 2159668
  - Input channels: 4 (bands: ['b02' 'b03' 'b04' 'b08'])
  - Output classes: 4
  - HEALPix level: 19


In [27]:
import torch
import numpy as np
import xarray as xr

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Your code with minimal fixes
ds_healpix = xr.open_dataset("/home/ubuntu/project/sentinel-2-dggs-ai-processor/src/notebook/healpix.zarr")

# 2. Get available cell IDs from your dataset
available_cell_ids = ds_healpix.cell_ids.values
print(f"Number of available HEALPix cells: {len(available_cell_ids)}")

model = SimpleSphericalModel(
        available_cell_ids=available_cell_ids,
        level=level,
        in_channels=in_channels,
        out_channels=in_channels,
        stride=stride
    )

# Calculate total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"  - Total parameters: {total_params:,}")

# 3. Create input tensor with all spectral bands
# Shape: [n_bands, n_cells]
spectral_data = []
for band in band_list:
    band_data = ds_healpix.Sentinel2.sel(bands=band).compute().values
    spectral_data.append(band_data)

# Stack all bands: [n_bands, n_cells]
x_multi_band = np.stack(spectral_data, axis=0)
print(f"Multi-band data shape: {x_multi_band.shape}")

# 5. Convert to PyTorch tensor and add batch dimension
x_tensor = torch.tensor(x_multi_band, dtype=torch.float32).unsqueeze(0)
print(f"Input tensor shape: {x_tensor.shape}")  # [1, n_bands, n_cells]

# Move input to GPU
x_tensor = x_tensor.to(device)

# 6. Forward pass
with torch.no_grad():
    output = model(x_tensor)  # Fixed: was SphericalUNet(x_tensor)
    print(f"Output tensor shape: {output.shape}")  # Fixed: was output.shape

Using device: cuda
Number of available HEALPix cells: 2159668


  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)
  engine = plugins.guess_engine(filename_or_obj)


  - Total parameters: 312
Multi-band data shape: (4, 2159668)
Input tensor shape: torch.Size([1, 4, 2159668])


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same