# Demo: Mapping SIMP 0136 with NIRSpec

In this tutorial, we will demonstrate how to use `spectralmap` to analyze time-series observations. We will use a sequence of NIRSpec observations of the brown dwarf SIMP 0136 to recover a spectroscopic map of its surface.

**We will cover:**
1. Loading and preprocessing the lightcurve data.
2. Inverting the lightcurves to obtain surface maps.
3. Clustering the map pixels to identify distinct spectral regions.
4. Extracting and analyzing the spectra of these regions.

Let's get started!

In [None]:
from __future__ import annotations

from dataclasses import dataclass
import jax
import numpy as np
import jax.numpy as jnp
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
from jaxoplanet.starry.core.basis import A1, A2_inv, U
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.core.rotation import left_project, sky_projection_axis_angle, dot_rotation_matrix
from jaxoplanet.starry.core.solution import rT, solution_vector
from jaxoplanet.starry.surface import Surface
from jaxoplanet.starry.system_observable import system_observable


from spectralmap.bayesian_linalg import optimize_alpha_fixed_beta, solve_posterior


In [None]:
import numpy as np
import jax
import jax.numpy as jnp

def build_design_matrix_operator(ydeg, inc=90):
    """
    Factory returning a JIT-compiled function to compute the design matrix.

    The phase-independent pieces are built once here, so each call of the
    returned function only evaluates the phase rotation. These pieces are:
    the integration vector ``rT``, the spherical-harmonic -> polynomial basis
    change ``A1``, the inclination rotation ``Rinc`` and the alignment
    rotation ``Rx``.

    Args:
        ydeg: Maximum spherical-harmonic degree of the map.
        inc: Inclination of the rotation axis in degrees (90 = edge-on).
            It is converted to radians before being passed to jaxoplanet's
            ``sky_projection_axis_angle``, which works in radians.

    Returns:
        A jitted callable ``design_matrix(theta_batch)``. It takes rotational
        phases in radians (a scalar or a 1-D array) and returns the design
        matrix with shape (N_phases, N_SH).
    """
    # 1. Static basis transform (SH -> polynomial) & integration vector.
    # rT @ A1 gives the map integration operator in the SH basis.
    rT_arr = jnp.array(rT(ydeg))
    A1_val = A1(ydeg)
    # A1 may be returned as a sparse matrix; densify it for JAX.
    A1_arr = jnp.array(A1_val.toarray() if hasattr(A1_val, "toarray") else A1_val)

    # 2. Static rotations. Applying each rotation to the identity
    # materializes it as a dense matrix.
    n_sph = A1_arr.shape[1]
    I = np.eye(n_sph)

    # Rx: Standard orientation alignment (-pi/2 about x-axis)
    Rx = jnp.array(dot_rotation_matrix(ydeg, 1.0, 0.0, 0.0, -0.5 * np.pi)(I))

    # Rinc: Inclination rotation. `inc` is given in degrees but
    # sky_projection_axis_angle expects radians, so convert first.
    inc_axis_angle = sky_projection_axis_angle(np.deg2rad(inc), 0.0)  # (x, y, z, angle)
    Rinc = jnp.array(dot_rotation_matrix(ydeg, *inc_axis_angle)(I))

    # 3. Combine static parts: Projector = (rT @ A1) @ Rinc
    # The full operation is: rT @ A1 @ Rinc @ Ry(theta) @ Rx
    # We pre-multiply everything to the left of Ry.
    Projector = rT_arr @ A1_arr @ Rinc
    I_jax = jnp.eye(n_sph)

    @jax.jit
    def design_matrix(theta_batch):
        """Computes design matrix (N_phases, N_SH) given phases theta in radians."""
        theta_batch = jnp.atleast_1d(theta_batch)

        def compute_row(th):
            # Ry: Dynamic phase rotation (-th about z-axis)
            Ry = dot_rotation_matrix(ydeg, 0.0, 0.0, 1.0, -th)(I_jax)

            # Result = (Projector) @ Ry @ Rx
            return Projector @ Ry @ Rx

        return jax.vmap(compute_row)(theta_batch)

    return design_matrix

In [None]:
ydeg = 8
theta = np.linspace(0, 360, 100)
# Compute design-matrix ingredients for observation angles theta (scratch cell).
inc = 90.0  # inclination in degrees (edge-on)
rT_deg = rT(ydeg)
design_matrix_p = rT_deg  # n_pol
# surface map, no limb darkening
A1_val = A1(ydeg)  # n_pol x n_sph
n_sph = A1_val.shape[1]
A = np.zeros((len(theta), n_sph))


# jaxoplanet's sky_projection_axis_angle works in radians, so convert the
# inclination (given above in degrees) before building the rotation.
axis_x, axis_y, axis_z, angle = sky_projection_axis_angle(np.deg2rad(inc), 0)
Rx = dot_rotation_matrix(ydeg, 1.0, None, None, -0.5 * jnp.pi)
Rinc = dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)
I = jnp.identity(n_sph)

In [None]:
# Build the operator once. The static SH transforms are precomputed here,
# and the returned function is JIT-compiled on its first call.
design_matrix_op = build_design_matrix_operator(ydeg=12, inc=80)

# Evaluate on a dense grid of rotational phases (radians), then wait for
# JAX's asynchronous dispatch to finish before inspecting the result.
theta_grid = jnp.linspace(0, 2 * jnp.pi, 300)
A = design_matrix_op(theta_grid)
A.block_until_ready()

print(f"Design Matrix Shape: {A.shape}")
print("First row sample:", A[0, :5])

In [None]:
import starry

In [None]:
import numpy as np
# Restore the ``np.bool`` alias only when this NumPy lacks it. It was removed
# in NumPy 1.24 and reinstated (as the NumPy bool type) in 2.0. The
# assignment presumably exists because pymc3 still references the alias.
# Overwriting an existing ``np.bool`` would silently change its meaning.
if not hasattr(np, "bool"):
    np.bool = bool
import pymc3
import starry

starry.config.lazy = False
# NOTE: this name shadows the builtin ``map`` for the rest of the notebook.
map = starry.Map(ydeg=4)

In [None]:
starry_phases_deg = np.linspace(0, 360, 10)  # starry's default angle unit is degrees
map.design_matrix(theta=starry_phases_deg)

In [None]:
import matplotlib.pyplot as plt
# Each column of the design matrix A is the disk-integrated lightcurve
# contributed by one spherical-harmonic basis map, evaluated on theta_grid.
# These are basis lightcurves, not map coefficients.
n_show = min(10, A.shape[1])  # guard against maps with fewer than 10 terms
for i in range(n_show):
    plt.plot(theta_grid, A[:, i], label=f'SH basis term {i}')
plt.legend()
plt.xlabel('Theta (radians)')
plt.ylabel('Design matrix value')
plt.title('Design-matrix columns (basis lightcurves) vs phase')
plt.show()

In [None]:
import os
import matplotlib.pyplot as plt

import spectralmap
print("spectralmap version:", spectralmap.__version__)  # report the installed package version

## 1. Load Data
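We load the NIRSpec lightcurve files. Each file corresponds to a specific rotation phase (angle), encoded in its filename. We read them in and sort them by phase. We then normalize the flux at each wavelength by its mean over time and subtract 1, so that the lightcurves are centered on zero.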

In [None]:
import pandas as pd
from glob import glob
import numpy as np

import os

# 1. Load the data files.
# Each filename encodes the rotation angle (phase) of the observation,
# e.g. "..._45degrees_2025.dat".
path = "NIRSPEC/SIMP0136_NIRSpec_Flambda_*degrees_2025.dat"
files = sorted(glob(path))
if not files:
    # Fail early with a clear message instead of an IndexError further down.
    raise FileNotFoundError(f"No lightcurve files match {path!r}")

# Parse the angle from the basename only, so underscores in the directory
# part of the path cannot shift the split position.
theta = np.array(
    [int(os.path.basename(f).split("_")[-2].replace("degrees", "")) for f in files]
)

# Sort by angle; `order` is reused below to reorder the file contents.
order = np.argsort(theta)
theta = theta[order]

# Read each space-delimited file into an array and stack them.
# Structure: [n_files, n_wavelengths, columns]
# Columns: 0=wavelength, 1=flux, 2=noise
# NOTE(review): read_csv treats the first non-comment line as a header row.
# Confirm the files have one; otherwise the first sample is silently dropped.
dfs = np.array([pd.read_csv(f, delimiter=" ", comment="#").to_numpy() for f in files])
dfs = dfs[order]  # re-order to match the sorted theta

# 2. Extract relevant arrays
wl_B = dfs[0, :, 0]
LC_B = dfs[:, :, 1].T  # shape: (n_wavelengths, n_times)
noise_B = dfs[:, :, 2].T

# 3. Normalize the lightcurves by their mean over time at each wavelength
amplitudes = np.nanmean(LC_B, axis=1)
LC_B_norm = LC_B / amplitudes[:, None]
flux_err = noise_B / amplitudes[:, None]
flux = LC_B_norm - 1  # Center around 0 for spectralmap

print(f"Loaded data with shape: {flux.shape} (wavelengths, time points)")

## 2. Inversion with SpectralMap

First, we organize the data into a `LightCurveData` object. We must specify the inclination of the object (here assumed to be 80 degrees).

In [None]:
# Restore the ``np.bool`` alias only when this NumPy lacks it. It was removed
# in NumPy 1.24 and reinstated in 2.0, so unconditionally overwriting it would
# change its meaning on newer NumPy. Presumably needed by a legacy dependency
# of spectralmap — TODO confirm which one.
if not hasattr(np, "bool"):
    np.bool = bool
from spectralmap.mapping import LightCurveData

# Package the normalized lightcurves; inclination is given in degrees here.
data = LightCurveData(theta=theta, flux=flux, flux_err=flux_err, inc=80)

### Find Optimal Map Complexity
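We now solve for the maps. Since we don't know the complexity of the surface beforehand, we test a range of spherical harmonic degrees, from `ydeg_min` to `ydeg_max`. To keep the runtime short, the cell below searches only `ydeg=2` to `ydeg=3`; raise `ydeg_max` (e.g. to 10) for a broader search. For each wavelength, the algorithm selects the degree with the lowest Bayesian Information Criterion (BIC), which balances goodness of fit against model complexity to avoid overfitting.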

In [None]:
from spectralmap.mapping import solve_posterior, best_ydeg_maps, Map

# For every wavelength bin, fit maps over a range of spherical-harmonic
# degrees and keep the one preferred by the BIC (guards against overfitting).
ydeg_best, I_all_wl, I_cov_all_wl = best_ydeg_maps(
    data,
    ydeg_min=2,
    ydeg_max=3,
)

print("Best Spherical Harmonic Degree per wavelength bin:")
print(ydeg_best)

## 3. Visualize the Maps
Let's look at the recovered maps for the wavelength bins whose selected degree is at most 5, to see how the surface features change with wavelength.

In [None]:
# Each map is stored as a flattened square pixel grid. Recover the side
# length from the pixel count instead of hard-coding it.
n_map_pix = int(np.prod(np.shape(I_all_wl)[1:]))
map_side = int(round(np.sqrt(n_map_pix)))
if map_side * map_side != n_map_pix:
    raise ValueError(f"Cannot reshape {n_map_pix} pixels into a square map")

# NOTE: every wavelength whose selected degree is <= 5 gets its own figure.
for i, ydeg_i in enumerate(ydeg_best):
    if ydeg_i <= 5:
        plt.figure()
        plt.imshow(I_all_wl[i].reshape((map_side, map_side)), origin='lower')
        plt.title(f"Wavelength {wl_B[i]: .2f}: best ydeg = {ydeg_i}")


## 4. Spectral Extraction
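We convert the retrieved intensity maps ($I$) back into physical flux units ($F$) and calculate the associated variances. Because each lightcurve was normalized by its mean amplitude $A_\lambda$, we undo this with $F = \pi A_\lambda I$. The covariance scales by $(\pi A_\lambda)^2$, and its diagonal gives the per-pixel variances. This gives us the spatially-resolved spectra.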

In [None]:
# Undo the per-wavelength normalization (multiply by the mean amplitude) and
# apply the factor of pi. The covariance scales with the square of the same
# factor, and its diagonal holds the per-pixel variances.
F_all_wl = (I_all_wl * amplitudes[:, None]) * np.pi
F_cov_all_wl = I_cov_all_wl * (np.pi * amplitudes[:, None, None]) ** 2
F_var_all_wl = F_cov_all_wl.diagonal(axis1=1, axis2=2)

In [None]:
import matplotlib.pyplot as plt
# Spectrum of a single map pixel with its 1-sigma band.
i_grid = 300
pixel_flux = F_all_wl[:, i_grid]
pixel_sigma = np.sqrt(F_var_all_wl[:, i_grid])
plt.plot(wl_B[:], pixel_flux)
plt.fill_between(wl_B[:], pixel_flux - pixel_sigma, pixel_flux + pixel_sigma, alpha=0.5)

## 5. Clustering and Regional Identification
To make sense of the map, we group pixels that show similar spectral behavior using a clustering algorithm. This helps us identify distinct "regions" or features on the object's surface.

In [None]:
from spectralmap.cluster import find_clusters
# Group pixels by spectral similarity. The outputs are used below as
# per-region mean spectra, their uncertainties, and one label per pixel.
F_regionals, F_regional_errs, labels = find_clusters(
    F_all_wl,
    F_cov_all_wl,
    n_neighbors=50,
)
N = len(F_regionals)  # number of regions (presumably including the background group)

## 6. Results
Finally, we visualize the identified clusters on the map and plot the mean spectrum for each region.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# --- 1. SETUP ---
# `labels` holds one cluster label per map pixel, stored as a flattened square
# grid, so recover the grid side length from the label count. The last axis
# of the pixel covariance is (presumably) the pixel count, not the side
# length, so it cannot be used for this reshape.
n_label_pix = labels.size
map_res = int(round(np.sqrt(n_label_pix)))
if map_res * map_res != n_label_pix:
    raise ValueError(f"Cannot reshape {n_label_pix} cluster labels into a square map")

# Per-pixel labels as a 2-D image (transposed here and back again at imshow).
labels_assigned = labels.reshape(map_res, map_res).T.astype(np.float64)

# --- 2. STYLE DEFINITIONS ---
# One colour and display name per region; index 0 is the background group.
colors_list = ['#E8E8E8', '#D62728', '#1F77B4', '#9467BD', "#20AB1E"]
cluster_names = ["Background", "Region 1", "Region 2", "Region 3", "Region 4"]
if N > len(colors_list):
    raise ValueError(
        f"find_clusters returned {N} regions but only {len(colors_list)} "
        "colours/names are defined; extend colors_list and cluster_names."
    )

# Use exactly N colours so each colour spans one integer label value.
cmap = mcolors.ListedColormap(colors_list[:N])

# --- 3. PLOTTING ---
fig, ax = plt.subplots(figsize=(3.5, 3.5 * (180/360)), dpi=300, constrained_layout=True)

# Recovered region labels as a filled longitude-latitude image.
im = ax.imshow(
    labels_assigned.T,
    origin='lower',
    cmap=cmap,
    vmin=-1.5, vmax=N-1.5,  # assumes labels run from -1 to N-2: one colour per integer
    extent=[-180, 180, -90, 90],
    alpha=0.8,
    aspect='auto'
)

# Add grid
ax.grid(True, linestyle='--', color='gray', alpha=0.6)

# Axis labels
ax.set_ylabel("Latitude (deg)", fontsize=8)
ax.set_xlabel("Longitude (deg)", fontsize=8)
ax.tick_params(axis='both', which='major', labelsize=7)

# --- COLORBAR ---
cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.03, aspect=30)
cbar.set_ticks(np.arange(N) - 1)  # one tick at the centre of each colour
cbar.ax.set_yticklabels(cluster_names[:N], fontsize=7)  # exactly one label per tick
cbar.ax.yaxis.set_tick_params(length=0)
cbar.outline.set_edgecolor('black')

output_dir = "paper_plots"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "S0136_recovered_regions.pdf")
plt.savefig(output_path, dpi=300)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import xarray as xr
import os

fig, ax = plt.subplots(figsize=(7, 2.5))

# Recovered mean spectrum of each region, with its uncertainty band.
for i in range(N):
    mean_flux = F_regionals[i]
    error_flux = F_regional_errs[i]
    # Fall back to matplotlib's colour cycle / a generic name if there are
    # more regions than predefined styles.
    color = colors_list[i] if i < len(colors_list) else None
    label = cluster_names[i] if i < len(cluster_names) else f"Region {i}"

    (line,) = ax.plot(wl_B, mean_flux, label=f"{label}", color=color, linewidth=1.5)
    # Reuse the line's colour so the band matches even when `color` is None.
    ax.fill_between(wl_B,
                    mean_flux - error_flux,
                    mean_flux + error_flux,
                    alpha=0.3, color=line.get_color())

# Overlay the range of the observed time-series variability: per-wavelength
# min and max over time (LC_B is (n_wavelengths, n_times)). nanmin/nanmax
# ignore missing samples. Sorting would push NaNs to the end and make the
# upper envelope NaN wherever any sample is missing.
obs_min = np.nanmin(LC_B, axis=1)
obs_max = np.nanmax(LC_B, axis=1)
ax.fill_between(wl_B, obs_min, obs_max,
                color='black', alpha=0.1, zorder=0, label="Observed Variability")

# Formatting
ax.set_xlabel(r"Wavelength ($\mu$m)")
ax.set_ylabel("Flux (normalized units)")
ax.set_title("Recovered Regional Spectra", fontsize=10)
ax.legend(fontsize=8, loc='upper right', ncol=2)

plt.tight_layout()

output_dir = "paper_plots"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "S0136_recovered_spectra.pdf")
plt.savefig(output_path, dpi=300)
print(f"Plot saved to {output_path}")

plt.show()