# Galaxy Classification, Anomaly Detection, Autoencoder & VAE (Keras)
This notebook uses the Galaxy10 DECaLS dataset (astronomical galaxy images) for:

1. Data loading & splitting, defining an anomaly class  
2. CNN classifier + ROC curves & confusion matrices  
3. Convolutional autoencoder for anomaly detection  
4. Variational autoencoder (VAE) + galaxy generation

Dataset: Galaxy10 DECaLS (HDF5).

In [4]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf

# Report the devices TensorFlow can see, queried and printed one at a time:
# every physical device, then the logical GPUs, then the GPU device name.
device_probes = (
    ("Physical devices:", tf.config.list_physical_devices),
    ("Logical GPU devices:", lambda: tf.config.list_logical_devices('GPU')),
    ("GPU device name:", tf.test.gpu_device_name),
)
for probe_label, probe in device_probes:
    print(probe_label, probe())

from tensorflow import keras # (changed) sticking to tf.keras
from tensorflow.keras import layers, ops, Model, random, regularizers, optimizers, models

import numpy as np
import h5py
import matplotlib.pyplot as plt
import random as py_random

SEED = 42  # single seed shared by all random number generators, for reproducibility

# Seed NumPy, TensorFlow and Python's built-in `random` module, in that order.
for seed_rng in (np.random.seed, tf.random.set_seed, py_random.seed):
    seed_rng(SEED)

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc

Physical devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
Logical GPU devices: []
GPU device name: 


## 1. Download the dataset, inspect classes, create anomaly split

Steps:
- Download `Galaxy10_DECals_64.h5`
- Load images and labels with `h5py`
- Inspect shape and class distribution
- Remove class 4 (Cigar Shaped Smooth Galaxies) and store it as the anomaly dataset
- Split remaining (standard) data into train (50%), val (25%), test (25%)

In [8]:
# Download Galaxy10 DECaLS (HDF5)
! rm ./Galaxy10_DECals_64.h5
! wget https://cernbox.cern.ch/remote.php/dav/public-files/RyWK8CBk2yqKR0b/Galaxy10_DECals_64.h5
data_path = "./Galaxy10_DECals_64.h5"

print("Downloaded to:", data_path)

--2025-12-09 16:41:19--  https://cernbox.cern.ch/remote.php/dav/public-files/RyWK8CBk2yqKR0b/Galaxy10_DECals_64.h5
Resolving cernbox.cern.ch (cernbox.cern.ch)... 128.142.53.35, 128.142.53.28, 137.138.120.151, ...
Connecting to cernbox.cern.ch (cernbox.cern.ch)|128.142.53.35|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 159140118 (152M) [application/octet-stream]
Saving to: ‘Galaxy10_DECals_64.h5’

Galaxy10_DECals_64.   5%[>                   ]   8.27M  95.0KB/s    eta 3m 38s ^C
Downloaded to: ./Galaxy10_DECals_64.h5


In [7]:
# Read the image cube and the class labels from the HDF5 file into memory.
with h5py.File(data_path, "r") as f:
    images = f["images"][()]   # per the original note: shape (17736, 64, 64, 3)
    labels = f["ans"][()]      # per the original note: shape (17736,)

# Rescale pixels by 1/255 and keep them as float32 (lighter than float64 for Keras).
# NOTE(review): this lands in [0, 1] assuming the stored pixels are 8-bit -- confirm.
images = images.astype(np.float32) / 255.0


print("Images shape:", images.shape)
print("Labels shape:", labels.shape, "dtype:", labels.dtype)


OSError: Unable to synchronously open file (truncated file: eof = 1801408, sblock->base_addr = 0, stored_eof = 159140118)

In [None]:
# Class names from the Galaxy10 DECaLS documentation, keyed by class id (0-9).
class_names = dict(enumerate([
    "Disturbed Galaxies",
    "Merging Galaxies",
    "Round Smooth Galaxies",
    "In-between Round Smooth Galaxies",
    "Cigar Shaped Smooth Galaxies",  # will be 'anomaly'
    "Barred Spiral Galaxies",
    "Unbarred Tight Spiral Galaxies",
    "Unbarred Loose Spiral Galaxies",
    "Edge-on Galaxies without Bulge",
    "Edge-on Galaxies with Bulge",
]))

# Count how many samples carry each label value.
unique, counts = np.unique(labels, return_counts=True)
print("Class distribution:")
for class_id, n_samples in zip(unique, counts):
    print(f"Class {class_id} ({class_names[class_id]}): {n_samples}")

In [None]:
# Show a grid of sample images from different classes
def show_examples(images, labels, class_names, n_rows=2, n_cols=5):
    """Plot a grid of randomly chosen images, each titled with its class.

    Images are drawn without replacement using NumPy's global random state,
    so no image appears twice in one grid.

    Args:
        images: indexable collection of images accepted by ``plt.imshow``.
        labels: indexable collection of class ids, aligned with ``images``.
        class_names: mapping from integer class id to a display name.
        n_rows: number of grid rows.
        n_cols: number of grid columns.

    If fewer than ``n_rows * n_cols`` images are available, only that many
    cells are filled instead of raising ``ValueError`` from the sampler.
    """
    # Sampling without replacement cannot return more items than exist.
    n_shown = min(n_rows * n_cols, len(images))

    plt.figure(figsize=(3*n_cols, 3*n_rows))
    indices = np.random.choice(len(images), size=n_shown, replace=False)
    for i, idx in enumerate(indices):
        plt.subplot(n_rows, n_cols, i+1)
        plt.imshow(images[idx])
        plt.title(f"Class {labels[idx]}:\n{class_names[int(labels[idx])]}", fontsize=8)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Preview a random grid of galaxies (default 2 x 5), each titled with its class name.
show_examples(images, labels, class_names)

In [None]:
# Hold out one class as the "anomaly" set; every other class is "standard" data.
ANOMALY_CLASS = 4

anomaly_mask = np.equal(labels, ANOMALY_CLASS)
standard_mask = np.logical_not(anomaly_mask)

# Standard subset: labels still use the original class ids at this point.
std_images, std_labels_original = images[standard_mask], labels[standard_mask]

# Anomaly subset: its labels all equal ANOMALY_CLASS; kept only for bookkeeping.
anom_images, anom_labels = images[anomaly_mask], labels[anomaly_mask]

print("Standard images:", std_images.shape)
print("Anomaly images:", anom_images.shape)

In [None]:
# For training, remap the remaining original class ids onto consecutive
# indices 0..n_classes-1 (e.g. {0,1,2,3,5,6,7,8,9} -> {0,...,8} with class 4 held out).
unique_std_classes = sorted(np.unique(std_labels_original))
n_classes = len(unique_std_classes)
print("Standard classes (original indices):", unique_std_classes)

# Forward (original -> new) and inverse (new -> original) lookup tables.
label_map = dict(zip(unique_std_classes, range(n_classes)))
inv_label_map = dict(zip(label_map.values(), label_map.keys()))

print("Label map (original -> new):", label_map)

# Apply the forward table element-wise to every standard label.
remap_label = np.vectorize(label_map.get)
std_labels = remap_label(std_labels_original)
print("Remapped standard labels min/max:", std_labels.min(), std_labels.max())

In [None]:
# Split standard data: 50% train, 25% val, 25% test (stratified by class).

# Both splits cut their input in half and use the notebook-wide SEED, so
# changing SEED in the setup cell changes every source of randomness together
# (previously 42 was hard-coded here a second and third time).
split_kwargs = dict(test_size=0.5, random_state=SEED)

# First split: train (50%) and temp (50%)
X_train_std, X_temp_std, y_train_std, y_temp_std = train_test_split(
    std_images,
    std_labels,
    stratify=std_labels,
    **split_kwargs,
)

# Second split: temp into val (25%) and test (25%) of full standard data
X_val_std, X_test_std, y_val_std, y_test_std = train_test_split(
    X_temp_std,
    y_temp_std,
    stratify=y_temp_std,
    **split_kwargs,
)

# Sanity check: the three subsets must partition the standard data exactly.
assert len(X_train_std) + len(X_val_std) + len(X_test_std) == len(std_images)

print("Train standard:", X_train_std.shape, y_train_std.shape)
print("Val standard:  ", X_val_std.shape, y_val_std.shape)
print("Test standard: ", X_test_std.shape, y_test_std.shape)

# Free the intermediates; re-run the anomaly-split and label-remap cells
# before re-running this one, since it needs std_images / std_labels.
del std_images, std_labels, X_temp_std, y_temp_std

**EDA**

In [None]:
# Reload the arrays together with the per-galaxy metadata for exploratory analysis.
# NOTE(review): this rebinds `images` and `labels` to the values stored in the
# file, i.e. without the /255 float32 rescaling applied when they were first loaded.
eda_keys = ("images", "ans", "ra", "dec", "redshift", "pxscale")
with h5py.File(data_path, "r") as f:
    # Each dataset is read fully into memory while the file is still open.
    images, labels, ra, dec, redshift, pxscale = (f[key][()] for key in eda_keys)

Dataset structure & basic stats

In [None]:
# Basic dataset overview: unpack (count, height, width, channels) from the image array.
N, H, W, C = images.shape
print(
    f"Number of images: {N}",
    f"Image size: {H}x{W}, channels: {C}",
    f"Images dtype: {images.dtype}, range: [{images.min():.3f}, {images.max():.3f}]",
    f"Labels shape: {labels.shape}, dtype: {labels.dtype}",
    sep="\n",
)

# Per-channel statistics: flatten every pixel into one row per pixel, one column per channel.
pixels_by_channel = images.reshape(-1, C)
channel_means = pixels_by_channel.mean(axis=0)
channel_stds = pixels_by_channel.std(axis=0)
print("\nPer-channel mean (R,G,B):", channel_means)
print("Per-channel std  (R,G,B):", channel_stds)

# How many redshift entries are missing (stored as NaN)?
nan_z = np.isnan(redshift).sum()
print(f"\nRedshift NaNs: {nan_z} / {N}")


Class distribution

In [None]:
# Class names, keyed by class id (0-9).
class_names = dict(enumerate([
    "Disturbed Galaxies",
    "Merging Galaxies",
    "Round Smooth Galaxies",
    "In-between Round Smooth Galaxies",
    "Cigar Shaped Smooth Galaxies",  # anomaly for me
    "Barred Spiral Galaxies",
    "Unbarred Tight Spiral Galaxies",
    "Unbarred Loose Spiral Galaxies",
    "Edge-on Galaxies without Bulge",
    "Edge-on Galaxies with Bulge",
]))

unique, counts = np.unique(labels, return_counts=True)

# Text summary: absolute count and share of the whole dataset per class.
print("=== Class distribution ===")
for class_id, n_samples in zip(unique, counts):
    percent = n_samples / N * 100
    print(f"Class {class_id} ({class_names[class_id]}): {n_samples} samples ({percent:.2f}%)")

# Bar plot of the same counts, one bar per class id.
fig, ax = plt.subplots(figsize=(6,4))
ax.bar(unique, counts)
ax.set_xticks(unique)
ax.set_xlabel("Class label")
ax.set_ylabel("Count")
ax.set_title("Class distribution")
fig.tight_layout()
plt.show()


Metadata distributions

In [None]:
def _plot_count_histogram(values, bins, xlabel, title):
    """Draw one count histogram of `values` in its own 6x4 figure and show it."""
    plt.figure(figsize=(6,4))
    plt.hist(values, bins=bins)
    plt.xlabel(xlabel)
    plt.ylabel("Count")
    plt.title(title)
    plt.tight_layout()
    plt.show()


# Redshift distribution (valid values only): NaN entries are dropped first.
valid_z = redshift[~np.isnan(redshift)]
_plot_count_histogram(valid_z, 40, "Redshift z", "Redshift distribution (valid entries)")

# Pixel scale distribution
_plot_count_histogram(pxscale, 30, "Pixel scale [arcsec/pixel]", "Pixel scale distribution")

# Sky coverage: one small marker per galaxy at its (RA, Dec) position.
plt.figure(figsize=(6,4))
plt.scatter(ra, dec, s=2)
plt.xlabel("RA [deg]")
plt.ylabel("Dec [deg]")
plt.title("Sky coverage (RA vs Dec)")
plt.tight_layout()
plt.show()


## 2. CNN classifier + ROC curves + confusion matrices

Steps:
- Build a CNN classifier on the **standard** dataset (9 classes)
- Train on train set, validate on val set
- Plot training history (loss & accuracy)
- Compute ROC curves (one-vs-rest) on the standard test set
- Compute confusion matrix for standard test

Build a CNN classifier on the standard dataset (9 classes)

In [None]:
def build_simple_cnn(input_shape, n_classes):
    """Build and compile a small convolutional classifier.

    Architecture: three (3x3 conv + ReLU, 2x2 max-pool) blocks with 32, 64 and
    128 filters, then global average pooling, dropout (rate 0.3) and a softmax
    layer with `n_classes` outputs.

    Args:
        input_shape: shape of a single input image, passed to `keras.Input`.
        n_classes: number of output classes.

    Returns:
        A Keras model compiled with Adam (learning rate 1e-3), sparse
        categorical cross-entropy loss and an accuracy metric.
    """
    inputs = keras.Input(shape=input_shape)

    # Each block halves the spatial size (64 -> 32 -> 16 -> 8 for a 64x64 input).
    features = inputs
    for n_filters in (32, 64, 128):
        features = layers.Conv2D(n_filters, (3, 3), padding="same", activation="relu")(features)
        features = layers.MaxPooling2D((2, 2))(features)

    # Classification head: one value per feature map, dropout, class probabilities.
    pooled = layers.GlobalAveragePooling2D()(features)
    regularized = layers.Dropout(0.3)(pooled)
    outputs = layers.Dense(n_classes, activation="softmax")(regularized)

    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss="sparse_categorical_crossentropy",  # labels are integer class indices
        metrics=["accuracy"],
    )
    return model


# Per-image input shape: the training array's shape without its leading sample axis.
input_shape = X_train_std.shape[1:]  # (64, 64, 3)
# Build the compiled classifier with one output per remapped standard class, then list its layers.
model = build_simple_cnn(input_shape, n_classes)
model.summary()


Train on train set, validate on val set

In [None]:
# Train the CNN
batch_size = 64
epochs = 50  # upper bound; early stopping may end training sooner

# Stop once val_loss has gone 5 epochs without improving and restore the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
# Halve the learning rate after 3 epochs without val_loss improvement, never below 1e-5.
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=3, min_lr=1e-5, verbose=1
)
callbacks = [early_stopping, lr_on_plateau]


# Fit on the standard training split; the validation split drives both callbacks.
history = model.fit(
    X_train_std,
    y_train_std,
    validation_data=(X_val_std, y_val_std),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    verbose=1,
)


Plot training history (loss & accuracy)

In [None]:
# Plot training history: loss (left panel) and accuracy (right panel), train vs. val.
history_dict = history.history

# (history key, y-axis label, panel title) for each panel, left to right.
history_panels = (
    ("loss", "Loss", "Training & validation loss"),
    ("accuracy", "Accuracy", "Training & validation accuracy"),
)

plt.figure(figsize=(10,4))
for panel_no, (metric, y_label, panel_title) in enumerate(history_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(history_dict[metric], label="train")
    plt.plot(history_dict["val_" + metric], label="val")  # Keras prefixes validation metrics with "val_"
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.show()

Predictions on the standard test set

In [None]:
# Softmax outputs for every standard test image (one column per remapped class).
y_proba_test = model.predict(X_test_std)
# Predicted class = column with the highest probability.
y_pred_test = y_proba_test.argmax(axis=1)

# Fraction of test images whose predicted class matches the true one.
test_accuracy = (y_pred_test == y_test_std).mean()
print("Test accuracy (overall):", test_accuracy)

Compute ROC curves (one-vs-rest) on the standard test set

In [None]:
# One ROC curve per class, treating that class as positive and all others as negative.
classes = np.unique(y_test_std)
plt.figure(figsize=(7,6))

for class_idx in classes:
    # Binary ground truth (1 = this class) and the model's probability for this class.
    is_class = (y_test_std == class_idx).astype(int)
    class_scores = y_proba_test[:, class_idx]

    class_fpr, class_tpr, _ = roc_curve(is_class, class_scores)
    class_auc = auc(class_fpr, class_tpr)

    # Legend entries use the original class id and name, not the remapped index.
    original_label = inv_label_map[int(class_idx)]
    plt.plot(
        class_fpr,
        class_tpr,
        lw=1.5,
        label=f"class {original_label} ({class_names[original_label]}) AUC={class_auc:.3f}",
    )

# Diagonal = reference line where TPR equals FPR.
plt.plot([0, 1], [0, 1], "k--", lw=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate (Sensitivity)")
plt.title("One-vs-rest ROC curves (standard test set)")
plt.legend(fontsize=7)
plt.tight_layout()
plt.show()


Compute confusion matrix for standard test

In [None]:
# Confusion matrix on the standard test set.
# `labels=` pins rows/columns to every remapped class index 0..n_classes-1, so
# the matrix shape does not depend on which classes appear in the predictions.
cm = confusion_matrix(y_test_std, y_pred_test, labels=np.arange(n_classes))

# Tick labels show the ORIGINAL class ids (as the ROC legend does), not the
# remapped training indices, which are shifted for classes above the anomaly class.
original_class_ids = [int(inv_label_map[new_idx]) for new_idx in range(n_classes)]

fig, ax = plt.subplots(figsize=(6,6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=original_class_ids)
disp.plot(ax=ax, cmap="Blues", colorbar=False)
ax.set_xlabel("Predicted label (original class id)")
ax.set_ylabel("True label (original class id)")
plt.title("Confusion matrix (standard test set)")
plt.show()


## 3. Convolutional Autoencoder for anomaly detection

Steps:
- Train a convolutional autoencoder on **standard** train set only
- Plot training history
- Compute reconstruction loss (MSE) per image for:
  - Standard test set
  - Anomaly dataset (class 4)
- Plot histograms of reconstruction losses
- Use reconstruction loss as anomaly score and build ROC curve & confusion matrix (adapt provided code)

Train a convolutional autoencoder on standard train set only

In [None]:
def build_conv_autoencoder(input_shape=(64, 64, 3)):
    """Build and compile a fully convolutional autoencoder.

    Encoder: three (3x3 conv + ReLU, 2x2 max-pool) stages with 16, 32 and 64
    filters; the last pooling layer is named "latent". Decoder: the mirror
    image, three (3x3 conv + ReLU, 2x upsampling) stages with 64, 32 and 16
    filters, followed by a 3-channel sigmoid convolution named "reconstruction".

    Args:
        input_shape: shape of a single input image, passed to `keras.Input`.

    Returns:
        A Keras model named "conv_autoencoder_simple", compiled with Adam
        (learning rate 1e-3) and mean-squared-error loss.
    """
    inputs = keras.Input(shape=input_shape)

    def conv_relu(n_filters, tensor):
        # 3x3 "same" convolution with ReLU, shared by encoder and decoder stages.
        return layers.Conv2D(n_filters, (3, 3), padding="same", activation="relu")(tensor)

    # ----- Encoder (all conv, no Dense): 64 -> 32 -> 16 -> 8 for the default input -----
    encoder_filters = (16, 32, 64)
    x = inputs
    for stage, n_filters in enumerate(encoder_filters, start=1):
        x = conv_relu(n_filters, x)
        # Only the final pooling layer gets an explicit name.
        pool_name = {"name": "latent"} if stage == len(encoder_filters) else {}
        x = layers.MaxPooling2D((2, 2), padding="same", **pool_name)(x)

    # ----- Decoder (mirror of encoder): 8 -> 16 -> 32 -> 64 for the default input -----
    for n_filters in reversed(encoder_filters):
        x = conv_relu(n_filters, x)
        x = layers.UpSampling2D((2, 2))(x)

    # Sigmoid keeps reconstructed pixels in [0, 1], matching the rescaled inputs.
    outputs = layers.Conv2D(
        3, (3, 3),
        padding="same",
        activation="sigmoid",
        name="reconstruction",
    )(x)

    autoencoder = keras.Model(inputs, outputs, name="conv_autoencoder_simple")
    autoencoder.compile(optimizer=keras.optimizers.Adam(1e-3), loss="mse")
    return autoencoder

# Build model
# Per-image input shape: the training array's shape without its leading sample axis.
input_shape = X_train_std.shape[1:]  # (64, 64, 3)
# Instantiate the compiled autoencoder for that shape and list its layers.
autoencoder = build_conv_autoencoder(input_shape)
autoencoder.summary()


Train the autoencoder and plot its training history


In [None]:
# Train autoencoder on *standard* train set only
batch_size = 64
epochs = 25  # upper bound; early stopping may end training sooner

# Stop once val_loss has gone 3 epochs without improving and restore the best weights.
ae_callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
]

# The network learns to reproduce its own input, so inputs double as targets.
history_ae = autoencoder.fit(
    X_train_std,
    X_train_std,
    validation_data=(X_val_std, X_val_std),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=ae_callbacks,
    verbose=1,
)

# Plot training history (loss only, since autoencoder has no accuracy)
hist = history_ae.history

plt.figure(figsize=(6,4))
for series_key, series_label in (("loss", "train"), ("val_loss", "val")):
    plt.plot(hist[series_key], label=series_label)
plt.xlabel("Epoch")
plt.ylabel("MSE loss")
plt.title("Autoencoder training & validation loss")
plt.legend()
plt.tight_layout()
plt.show()

Compute reconstruction loss (MSE) per image for:
- Standard test set
- Anomaly dataset (class 4)

In [None]:
def _reconstruct_and_score(model, batch):
    """Return (reconstructions, per-image MSE) of `model` on `batch`.

    The squared error is averaged over axes 1-3, i.e. over everything except
    the leading sample axis, giving one score per image.
    """
    reconstruction = model.predict(batch, batch_size=64, verbose=1)
    per_image_mse = np.mean(np.square(reconstruction - batch), axis=(1, 2, 3))
    return reconstruction, per_image_mse


# Reconstruction on standard test set
X_test_std_recon, recon_loss_std = _reconstruct_and_score(autoencoder, X_test_std)

# Reconstruction on anomaly set (class 4) – note: never seen during training
anom_recon, recon_loss_anom = _reconstruct_and_score(autoencoder, anom_images)

print("Reconstruction loss (standard test):")
print("  mean =", recon_loss_std.mean(), "std =", recon_loss_std.std())
print("Reconstruction loss (anomaly):")
print("  mean =", recon_loss_anom.mean(), "std =", recon_loss_anom.std())


Plot histograms of reconstruction losses

In [None]:
# Overlay the per-image reconstruction-error histograms of both groups.
loss_groups = (
    (recon_loss_std, "Standard test"),
    (recon_loss_anom, "Anomaly (class 4)"),
)

plt.figure(figsize=(7,5))
for group_losses, group_label in loss_groups:
    # Semi-transparent bars so the overlapping region stays visible.
    plt.hist(group_losses, bins=50, alpha=0.6, label=group_label)
plt.xlabel("Reconstruction MSE")
plt.ylabel("Count")
plt.title("Reconstruction loss distributions")
plt.legend()
plt.tight_layout()
plt.show()

Use reconstruction loss as anomaly score and build ROC curve & confusion matrix (adapt provided code)

In [None]:
# Build labels and scores
y_true_ae = np.concatenate([
    np.zeros_like(recon_loss_std, dtype=int),   # 0 = standard
    np.ones_like(recon_loss_anom, dtype=int),   # 1 = anomaly
])
scores_ae = np.concatenate([recon_loss_std, recon_loss_anom])

# ROC curve
fpr, tpr, thresholds = roc_curve(y_true_ae, scores_ae)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6,5))
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}")
plt.plot([0,1], [0,1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate (Recall for anomaly)")
plt.title("ROC curve (anomaly detection via AE reconstruction error)")
plt.legend()
plt.tight_layout()
plt.show()




In [None]:
# Choose threshold using Youden's J statistic (J = TPR - FPR): take the ROC
# point where J is largest.
j_scores = tpr - fpr
best_idx = np.argmax(j_scores)
best_thresh = thresholds[best_idx]
print("Best threshold (Youden's J):", best_thresh)

# Binary predictions: 1 = anomaly if recon loss >= threshold
y_pred_ae = np.where(scores_ae >= best_thresh, 1, 0)

# Rows = true class, columns = predicted class (0 = standard, 1 = anomaly).
cm_ae = confusion_matrix(y_true_ae, y_pred_ae)
print("Confusion matrix (AE anomaly detection):")
print(cm_ae)

fig, ax = plt.subplots(figsize=(4,4))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_ae, display_labels=["Standard", "Anomaly"])
disp.plot(ax=ax, cmap="Blues", colorbar=False)
plt.title("Confusion matrix (AE anomaly detection)")
plt.tight_layout()
plt.show()


## 4. Variational Autoencoder (VAE) on standard data

Steps:
- Build a VAE with convolutional encoder/decoder (trained on standard train set)
- Plot training history
- Generate new galaxy images from the VAE
- Visualize some generated images