### Import Libraries

In [1]:
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import mean_absolute_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from pytorch_tabnet.tab_model import TabNetRegressor

### Load Data

In [2]:
CUBE_PATH = "../sdss_cube_sfr/SDSS_cube_lean.h5"

with h5py.File(CUBE_PATH, "r") as cube:
    spectra = cube["dense_cube/0/ml_spectrum/spectral_1d_cube_zoom_0"][:]
    images = cube["dense_cube/0/ml_image/cutout_3d_cube_zoom_0"][:]
    metadata = cube["dense_cube/0/ml_spectrum/spec_metadata_refs"][:]

In [3]:
sfrs = pd.read_hdf(CUBE_PATH, key="fits_spectra_metadata_star_formation_rates")

valid_sfrs = sfrs[(sfrs["FLAG"] == 0.0) & (sfrs["Z"].notna())].copy()
metadata_indices = metadata[:, 0]["ds_slice_idx"]
mask = np.in1d(metadata_indices, valid_sfrs.index)

In [4]:
# print(", ".join(sfrs.columns))

### Process Data

In [5]:
spectra_filtered = spectra[mask]
images_filtered = images[mask]
metadata_indices_filtered = metadata_indices[mask]

In [6]:
labels = valid_sfrs.loc[metadata_indices_filtered, "TARGETTYPE"]
redshifts = valid_sfrs.loc[metadata_indices_filtered, "Z"].values

In [7]:
def simplify_targettype(val):
    if isinstance(val, bytes):
        val = val.decode("utf-8").strip()
    else:
        val = str(val).strip()
    return "GALAXY" if "GALAXY" in val else "OTHER"

In [8]:
labels_simplified = pd.Series(labels).apply(simplify_targettype).values

In [9]:
galaxy_mask = labels_simplified == "GALAXY"
spectra_filtered = spectra_filtered[galaxy_mask]
images_filtered = images_filtered[galaxy_mask]
redshifts = redshifts[galaxy_mask]

nan_mask = ~np.isnan(spectra_filtered).any(axis=1)
spectra_filtered = spectra_filtered[nan_mask]
images_filtered = images_filtered[nan_mask]
redshifts = redshifts[nan_mask]

In [10]:
def normalize_spectra(spec):
    min_val = np.min(spec, axis=1, keepdims=True)
    max_val = np.max(spec, axis=1, keepdims=True)
    return 2 * (spec - min_val) / (max_val - min_val) - 1

spectra_normalized = normalize_spectra(spectra_filtered)

In [11]:
filters = ["u", "g", "r", "i", "z"]

def plot_images_and_spectrum(images, spectra, labels, redshifts, class_name, obj_num):
    class_index = np.where(labels == class_name)[0]
    if len(class_index) > 0:
        idx = class_index[obj_num]
        img_channels = images[idx]
        spectrum = spectra[idx]
        redshift = redshifts[idx] if not np.isnan(redshifts[idx]) else "N/A"
        
        fig = plt.figure(figsize=(15, 8))
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 0.5])
        gs_images = gs[0].subgridspec(1, 5, wspace=0.1)
        
        for i in range(5):
            ax = fig.add_subplot(gs_images[0, i])
            ax.imshow(img_channels[i], cmap="viridis")
            ax.set_title(f"{class_name} - {filters[i]}", fontsize=9)
            ax.axis("off")
        
        ax_spec = fig.add_subplot(gs[1])
        wavelengths = np.linspace(3800, 9200, len(spectrum))
        ax_spec.plot(wavelengths, spectrum, color="blue")
        ax_spec.set_title(f"{class_name} - Spectrum (z={redshift:.3f})", fontsize=9)
        ax_spec.set_xlabel("Wavelength (Å)", fontsize=9)
        ax_spec.set_ylabel("Flux", fontsize=9)
        
        plt.tight_layout()
        plt.show()
    else:
        print(f"No images found for class: {class_name}")

# Machine Learning for SFR

In [12]:
y = valid_sfrs.loc[metadata_indices_filtered, ["AVG", "ENTROPY", "MEDIAN", "MODE", "P16", "P2P5", "P84", "P97P5"]].values
y = y[galaxy_mask]
y = y[nan_mask]

In [13]:
indices = np.arange(y.shape[0])
train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)

In [14]:
y_train, y_test = y[train_idx], y[test_idx]
X_spec_train, X_spec_test = spectra_normalized[train_idx], spectra_normalized[test_idx]
X_photo = images_filtered.reshape(images_filtered.shape[0], -1)
X_photo_train, X_photo_test = X_photo[train_idx], X_photo[test_idx]
X_comb_train = np.concatenate([X_spec_train, X_photo_train], axis=1)
X_comb_test = np.concatenate([X_spec_test, X_photo_test], axis=1)

In [16]:
scaler = StandardScaler()
X_spec_train_scaled = scaler.fit_transform(X_spec_train)
X_spec_test_scaled = scaler.transform(X_spec_test)
X_photo_train_scaled = scaler.fit_transform(X_photo_train)
X_photo_test_scaled = scaler.transform(X_photo_test)
X_comb_train_scaled = scaler.fit_transform(X_comb_train)
X_comb_test_scaled = scaler.transform(X_comb_test)

In [17]:
def train_tabnet(X_train, y_train, X_test):
    model = TabNetRegressor(
        n_d=64, n_a=64, n_steps=7,
        gamma=1.5, lambda_sparse=0.0001,
        momentum=0.7, mask_type="entmax"
    )
    model.fit(X_train, y_train, max_epochs=200, patience=30, batch_size=256, virtual_batch_size=128)
    return model.predict(X_test)

y_pred_spec_tab = train_tabnet(X_spec_train_scaled, y_train, X_spec_test_scaled)
y_pred_photo_tab = train_tabnet(X_photo_train_scaled, y_train, X_photo_test_scaled)
y_pred_comb_tab = train_tabnet(X_comb_train_scaled, y_train, X_comb_test_scaled)



epoch 0  | loss: 5.5399  |  0:00:05s
epoch 1  | loss: 1.08571 |  0:00:06s
epoch 2  | loss: 0.82952 |  0:00:08s
epoch 3  | loss: 0.6221  |  0:00:09s
epoch 4  | loss: 0.591   |  0:00:11s
epoch 5  | loss: 0.58049 |  0:00:13s
epoch 6  | loss: 0.57796 |  0:00:14s
epoch 7  | loss: 0.57637 |  0:00:16s
epoch 8  | loss: 0.57949 |  0:00:17s
epoch 9  | loss: 0.57569 |  0:00:19s
epoch 10 | loss: 0.58217 |  0:00:20s
epoch 11 | loss: 0.57658 |  0:00:22s
epoch 12 | loss: 0.57717 |  0:00:23s
epoch 13 | loss: 0.57057 |  0:00:25s
epoch 14 | loss: 0.56798 |  0:00:26s
epoch 15 | loss: 0.57084 |  0:00:28s
epoch 16 | loss: 0.57182 |  0:00:30s
epoch 17 | loss: 0.56829 |  0:00:31s
epoch 18 | loss: 0.56961 |  0:00:33s
epoch 19 | loss: 0.56944 |  0:00:34s
epoch 20 | loss: 0.57231 |  0:00:36s
epoch 21 | loss: 0.56995 |  0:00:37s
epoch 22 | loss: 0.57526 |  0:00:39s
epoch 23 | loss: 0.57216 |  0:00:40s
epoch 24 | loss: 0.56962 |  0:00:42s
epoch 25 | loss: 0.57112 |  0:00:43s
epoch 26 | loss: 0.56943 |  0:00:45s
e



epoch 0  | loss: 10.51246|  0:00:08s
epoch 1  | loss: 4.46478 |  0:00:17s
epoch 2  | loss: 2.67589 |  0:00:25s
epoch 3  | loss: 1.21878 |  0:00:34s
epoch 4  | loss: 1.02759 |  0:00:42s
epoch 5  | loss: 0.69071 |  0:00:51s
epoch 6  | loss: 0.64338 |  0:00:59s
epoch 7  | loss: 0.60085 |  0:01:08s
epoch 8  | loss: 0.58933 |  0:01:16s
epoch 9  | loss: 0.58554 |  0:01:25s
epoch 10 | loss: 0.58508 |  0:01:33s
epoch 11 | loss: 0.58192 |  0:01:42s
epoch 12 | loss: 0.5841  |  0:01:50s
epoch 13 | loss: 0.58322 |  0:01:59s
epoch 14 | loss: 0.57872 |  0:02:07s
epoch 15 | loss: 0.58732 |  0:02:16s
epoch 16 | loss: 0.57988 |  0:02:25s
epoch 17 | loss: 0.58118 |  0:02:33s
epoch 18 | loss: 0.58505 |  0:02:42s
epoch 19 | loss: 0.58196 |  0:02:50s
epoch 20 | loss: 0.58288 |  0:02:59s
epoch 21 | loss: 0.57974 |  0:03:07s
epoch 22 | loss: 0.58444 |  0:03:16s
epoch 23 | loss: 0.57854 |  0:03:24s
epoch 24 | loss: 0.57728 |  0:03:33s
epoch 25 | loss: 0.57837 |  0:03:41s
epoch 26 | loss: 0.57797 |  0:03:50s
e



epoch 0  | loss: 7.99162 |  0:00:12s
epoch 1  | loss: 2.61452 |  0:00:24s
epoch 2  | loss: 1.41326 |  0:00:36s
epoch 3  | loss: 0.91298 |  0:00:48s
epoch 4  | loss: 0.7389  |  0:01:00s
epoch 5  | loss: 0.6656  |  0:01:12s
epoch 6  | loss: 0.65535 |  0:01:24s
epoch 7  | loss: 0.62572 |  0:01:37s
epoch 8  | loss: 0.5991  |  0:01:49s
epoch 9  | loss: 0.58203 |  0:02:01s
epoch 10 | loss: 0.58761 |  0:02:13s
epoch 11 | loss: 0.58224 |  0:02:25s
epoch 12 | loss: 0.5729  |  0:02:37s
epoch 13 | loss: 0.57386 |  0:02:50s
epoch 14 | loss: 0.58062 |  0:03:02s
epoch 15 | loss: 0.57927 |  0:03:14s
epoch 16 | loss: 0.57831 |  0:03:26s
epoch 17 | loss: 0.58124 |  0:03:39s
epoch 18 | loss: 0.57593 |  0:03:51s
epoch 19 | loss: 0.57421 |  0:04:03s
epoch 20 | loss: 0.57038 |  0:04:15s
epoch 21 | loss: 0.56691 |  0:04:27s
epoch 22 | loss: 0.56625 |  0:04:40s
epoch 23 | loss: 0.56664 |  0:04:52s
epoch 24 | loss: 0.56686 |  0:05:04s
epoch 25 | loss: 0.56985 |  0:05:16s
epoch 26 | loss: 0.56356 |  0:05:28s
e

In [20]:
y_mean = np.full_like(y_test, np.mean(y_train, axis=0))
mae_mean = mean_absolute_error(y_test, y_mean)

mae_spec_tab = mean_absolute_error(y_test, y_pred_spec_tab)
mae_photo_tab = mean_absolute_error(y_test, y_pred_photo_tab)
mae_comb_tab = mean_absolute_error(y_test, y_pred_comb_tab)

print("Mean model MAE:", mae_mean)
print("Spectrum model MAE (TabNet):", mae_spec_tab)
print("Photo model MAE (TabNet):", mae_photo_tab)
print("Combination model MAE (TabNet):", mae_comb_tab)

Mean model MAE: 0.6163717097520727
Spectrum model MAE (TabNet): 0.6071666175166874
Photo model MAE (TabNet): 0.5036490425747494
Combination model MAE (TabNet): 0.5891615369827254


In [19]:
df_results = pd.DataFrame({
    "y_test": list(y_test),
    "y_pred_spec_tab": list(y_pred_spec_tab),
    "y_pred_photo_tab": list(y_pred_photo_tab),
    "y_pred_comb_tab": list(y_pred_comb_tab)
})
print(df_results.head(10))

                                              y_test  \
0  [-0.4525999128818512, -5.7508357269963755, -0....   
1  [-0.2147999107837677, -5.758284017535481, -0.1...   
2  [0.37796008586883545, -4.434794977749806, 0.34...   
3  [-1.3138599395751953, -5.47998471566957, -1.26...   
4  [0.4476400911808014, -4.724946891551867, 0.424...   
5  [0.13043010234832764, -4.487824326584547, 0.10...   
6  [0.7521800994873047, -3.3737913659710284, 0.70...   
7  [-1.4386299848556519, -5.80033062012948, -1.34...   
8  [0.07596009224653244, -4.853448751780046, 0.04...   
9  [0.2064300924539566, -4.926453766562226, 0.186...   

                                     y_pred_spec_tab  \
0  [-0.26803556, -5.15075, -0.2480962, -0.1462267...   
1  [-0.2959844, -5.3948207, -0.26265803, -0.14058...   
2  [-0.33333322, -5.44071, -0.30643123, -0.200913...   
3  [-0.12509161, -5.0243163, -0.11539796, -0.0381...   
4  [-0.28625694, -5.30781, -0.26523438, -0.165015...   
5  [-0.37707242, -5.36875, -0.34114948, -0.2134