In [28]:
import functools
import operator
from typing import Any, Optional, Sequence, Union
from flax import linen as nn
import jax.numpy as jnp
import jax
import numpy as np
from spherical_cnn import layers
from spherical_cnn import sphere_utils
from spherical_cnn import spin_spherical_harmonics

In [29]:
Array = Union[np.ndarray, jnp.ndarray]

In [30]:
class SpinSphericalClassifier(nn.Module):
    """Construct a spin-weighted spherical CNN for classification.

    Attributes:
      num_classes: Number of nodes in the final layer.
      resolutions: (n_layers,) list of resolutions at each layer. For consecutive
        resolutions a, b, we must have either a == b or a == 2*b. The latter
        triggers inclusion of a pooling layer.
      spins: A (n_layers,) list of (n_spins,) lists of spin weights per layer.
      widths: (n_layers,) list of width per layer (number of channels).
      spectral_pooling: When True, use spectral instead of spatial pooling.
      axis_name: Identifier for the mapped axis in parallel training.
      num_filter_params: (n_layers,) the number of filter parameters per layer.
      input_transformer: None, or SpinSphericalFourierTransformer
        instance. Will be computed automatically if None.
    """

    num_classes: int
    resolutions: Sequence[int]
    spins: Sequence[Sequence[int]]
    widths: Sequence[int]
    spectral_pooling: bool
    axis_name: Any
    num_filter_params: Optional[Sequence[int]] = None
    input_transformer: Optional[
        spin_spherical_harmonics.SpinSphericalFourierTransformer
    ] = None

    def setup(self):
        if self.input_transformer is None:
            # Flatten spins.
            all_spins = functools.reduce(operator.concat, self.spins)
            self.transformer = spin_spherical_harmonics.SpinSphericalFourierTransformer(
                resolutions=np.unique(self.resolutions), spins=np.unique(all_spins)
            )
        else:
            self.transformer = self.input_transformer

        num_layers = len(self.resolutions)
        if len(self.spins) != num_layers or len(self.widths) != num_layers:
            raise ValueError("resolutions, spins, and widths must be the same size!")
        model_layers = []
        for layer_id in range(num_layers - 1):
            resolution_in = self.resolutions[layer_id]
            resolution_out = self.resolutions[layer_id + 1]
            spins_in = self.spins[layer_id]
            spins_out = self.spins[layer_id + 1]
            if self.num_filter_params is None:
                num_filter_params = None
            else:
                num_filter_params = self.num_filter_params[layer_id + 1]

            num_channels = self.widths[layer_id + 1]

            # We pool before conv to avoid expensive increase of number of channels at
            # higher resolution.
            if resolution_out == resolution_in // 2:
                downsampling_factor = 2
            elif resolution_out != resolution_in:
                raise ValueError("Consecutive resolutions must be equal or halved.")
            else:
                downsampling_factor = 1

            model_layers.append(
                layers.SpinSphericalBlock(
                    num_channels=num_channels,
                    spins_in=spins_in,
                    spins_out=spins_out,
                    downsampling_factor=downsampling_factor,
                    spectral_pooling=self.spectral_pooling,
                    num_filter_params=num_filter_params,
                    axis_name=self.axis_name,
                    transformer=self.transformer,
                    name=f"spin_block_{layer_id}",
                )
            )

        self.layers = model_layers

        self.final_dense = nn.Dense(self.num_classes, name="final_dense")

    def __call__(self, inputs: Array, train: bool) -> jnp.ndarray:
        """Apply the network to `inputs`.

        Args:
          inputs: (batch_size, resolution, resolution, n_spins, n_channels) array of
            spin-weighted spherical functions (SWSF) with equiangular sampling.
          train: whether to run in training or inference mode.
        Returns:
          A (batch_size, num_classes) float32 array with per-class scores (logits).
        """
        resolution, num_spins, num_channels = inputs.shape[2:]
        if (
            resolution != self.resolutions[0]
            or num_spins != len(self.spins[0])
            or num_channels != self.widths[0]
        ):
            raise ValueError("Incorrect input dimensions!")

        feature_maps = inputs
        for layer in self.layers:
            feature_maps = layer(feature_maps, train=train)

        # Current feature maps are still spin spherical. Do final processing.
        # Global pooling is not equivariant for spin != 0, so me must take the
        # absolute values before.
        mean_abs = sphere_utils.spin_spherical_mean(jnp.abs(feature_maps))
        mean = sphere_utils.spin_spherical_mean(feature_maps).real
        spins = jnp.expand_dims(jnp.array(self.spins[-1]), [0, 2])
        feature_maps = jnp.where(spins == 0, mean, mean_abs)
        # Shape is now (batch, spins, channel).
        feature_maps = feature_maps.reshape((feature_maps.shape[0], -1))

        return self.final_dense(feature_maps)

In [31]:
_SMALL_CLASSIFIER_RESOLUTIONS = (64, 64, 64, 32, 32, 16, 16)
num_layers = len(_SMALL_CLASSIFIER_RESOLUTIONS)
widths = (1, 16, 16, 32, 32, 58, 58)
num_filter_params_per_layer = tuple([8] * num_layers)
# The difference between spherical and spin-weighted models is that spins are
# zero in every layer for the spherical.
spins = tuple([(0,)] * num_layers)
model = SpinSphericalClassifier(
    10,
    resolutions=_SMALL_CLASSIFIER_RESOLUTIONS,
    spins=spins,
    widths=widths,
    spectral_pooling=False,
    num_filter_params=num_filter_params_per_layer,
    axis_name='axis_1',
)

In [32]:
model

SpinSphericalClassifier(
    # attributes
    num_classes = 10
    resolutions = (64, 64, 64, 32, 32, 16, 16)
    spins = ((0,), (0,), (0,), (0,), (0,), (0,), (0,))
    widths = (1, 16, 16, 32, 32, 58, 58)
    spectral_pooling = False
    axis_name = 'axis_1'
    num_filter_params = (8, 8, 8, 8, 8, 8, 8)
    input_transformer = None
)

In [33]:
class SpinSphericalAutoencoder(nn.Module):
    """Spin-weighted spherical CNN with an encoder-decoder structure.

    Attributes:
      resolutions: (n_layers,) list of resolutions at each layer.
      spins: A (n_layers,) list of (n_spins,) lists of spin weights per layer.
      widths: (n_layers,) list of width per layer (number of channels).
      spectral_pooling: When True, use spectral instead of spatial pooling.
      axis_name: Identifier for the mapped axis in parallel training.
      num_filter_params: (n_layers,) the number of filter parameters per layer.
      input_transformer: None, or SpinSphericalFourierTransformer instance.
    """

    resolutions: Sequence[int]
    spins: Sequence[Sequence[int]]
    widths: Sequence[int]
    spectral_pooling: bool
    axis_name: Any
    num_filter_params: Optional[Sequence[int]] = None
    input_transformer: Optional[
        spin_spherical_harmonics.SpinSphericalFourierTransformer
    ] = None

    def setup(self):
        if self.input_transformer is None:
            all_spins = functools.reduce(operator.concat, self.spins)
            self.transformer = spin_spherical_harmonics.SpinSphericalFourierTransformer(
                resolutions=np.unique(self.resolutions), spins=np.unique(all_spins)
            )
        else:
            self.transformer = self.input_transformer

        num_layers = len(self.resolutions)
        if len(self.spins) != num_layers or len(self.widths) != num_layers:
            raise ValueError("resolutions, spins, and widths must be the same size!")

        self.encoder_layers = list()
        self.decoder_layers = list()

        # Encoder (downsampling)
        for layer_id in range(num_layers - 1):
            resolution_in = self.resolutions[layer_id]
            resolution_out = self.resolutions[layer_id + 1]
            spins_in = self.spins[layer_id]
            spins_out = self.spins[layer_id + 1]
            num_channels = self.widths[layer_id + 1]

            if self.num_filter_params is None:
                num_filter_params = None
            else:
                num_filter_params = self.num_filter_params[layer_id + 1]

            downsampling_factor = 2 if resolution_out == resolution_in // 2 else 1
            if resolution_out != resolution_in and resolution_out != resolution_in // 2:
                raise ValueError("Resolutions must be equal or halved.")

            print(self.encoder_layers)
            self.encoder_layers.append(
                layers.SpinSphericalBlock(
                    num_channels=num_channels,
                    spins_in=spins_in,
                    spins_out=spins_out,
                    downsampling_factor=downsampling_factor,
                    spectral_pooling=self.spectral_pooling,
                    num_filter_params=num_filter_params,
                    axis_name=self.axis_name,
                    transformer=self.transformer,
                    name=f"encoder_block_{layer_id}",
                )
            )

        # Decoder (upsampling, symmetric to encoder)
        for layer_id in reversed(range(num_layers - 1)):
            resolution_in = self.resolutions[layer_id + 1]
            resolution_out = self.resolutions[layer_id]
            spins_in = self.spins[layer_id + 1]
            spins_out = self.spins[layer_id]
            num_channels = self.widths[layer_id]

            upsampling_factor = 2 if resolution_out == resolution_in * 2 else 1
            if resolution_out != resolution_in and resolution_out != resolution_in * 2:
                raise ValueError("Resolutions must be equal or doubled.")

            self.decoder_layers.append(
                layers.SpinSphericalBlock(
                    num_channels=num_channels,
                    spins_in=spins_in,
                    spins_out=spins_out,
                    downsampling_factor=1 / upsampling_factor,
                    spectral_pooling=self.spectral_pooling,
                    num_filter_params=num_filter_params,
                    axis_name=self.axis_name,
                    transformer=self.transformer,
                    name=f"decoder_block_{layer_id}",
                )
            )

        self.final_conv = nn.Conv(
            features=self.widths[0],
            kernel_size=(1, 1),
            name="final_conv"
        )

    def __call__(self, inputs: jnp.ndarray, train: bool) -> jnp.ndarray:
        """Apply the autoencoder network to `inputs`.

        Args:
          inputs: (batch_size, resolution, resolution, n_spins, n_channels) array.
          train: whether to run in training or inference mode.

        Returns:
          Reconstructed tensor (same shape as input).
        """
        feature_maps = inputs
        skips = []

        # Encoding Path
        for layer in self.encoder_layers:
            feature_maps = layer(feature_maps, train=train)
            skips.append(feature_maps)  # Save for skip connections

        # Decoding Path
        for layer, skip in zip(self.decoder_layers, reversed(skips)):
            feature_maps = layer(feature_maps, train=train) + skip  # Skip connection

        return self.final_conv(feature_maps)

In [34]:
model = SpinSphericalAutoencoder(
    resolutions=_SMALL_CLASSIFIER_RESOLUTIONS,
    spins=spins,
    widths=widths,
    spectral_pooling=False,
    axis_name='axis_1'
)

In [35]:
model

SpinSphericalAutoencoder(
    # attributes
    resolutions = (64, 64, 64, 32, 32, 16, 16)
    spins = ((0,), (0,), (0,), (0,), (0,), (0,), (0,))
    widths = (1, 16, 16, 32, 32, 58, 58)
    spectral_pooling = False
    axis_name = 'axis_1'
    num_filter_params = None
    input_transformer = None
)

In [36]:
batch_size = 1  # For testing on a single slice, set batch size to 1
n_spins = 0  # Adjust based on your data (e.g., spin-weighted spherical harmonics)
n_channels = 3  # Adjust based on the number of channels

key = jax.random.PRNGKey(0)
# Example: Generating random input tensor (replace with your actual slice)
input_slice = jax.random.normal(key, (batch_size, 128, 110, n_spins, n_channels))

In [37]:
yhat = model.apply(input_slice, train=True)
# params = model.init(key, input_slice, train=False)

# # Get the model prediction for the 128x110 slice (which is input_slice here)
# predicted_output = model.apply(params, input_slice, train=False)

# # The predicted_output will now be the predicted slice(s) for all 141 shells
# print(predicted_output.shape)

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'items'