# Convolutional Neural Networks with the Galaxy10 DECaLS dataset

## Introduction to the physics

One long-standing challenge in observational cosmology is cataloging the huge number of sources discovered by space-based and ground-based survey telescopes. In recent years, physicists have turned to machine learning as a way of classifying large numbers of galaxies.

In this notebook, we explore the Galaxy10 DECaLS dataset and train a convolutional neural network (CNN) to classify galaxies from their observed images.

The images are labeled as belonging to one of ten morphology classes (labeled manually by [Galaxy Zoo](https://www.zooniverse.org/projects/zookeeper/galaxy-zoo) volunteers!).

The ten galaxy classes are:
- 0: Disturbed
- 1: Merging  
- 2: Round Smooth
- 3: In-between Smooth
- 4: Cigar Smooth
- 5: Barred Spiral
- 6: Unbarred Tight Spiral
- 7: Unbarred Loose Spiral
- 8: Edge-on (no bulge)
- 9: Edge-on (with bulge)


In [None]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import h5py
import numpy as np
from pathlib import Path
import urllib.request
from PIL import Image

# Use the Galaxy10 DECaLS dataset
def download_galaxy10(data_dir='./data',
                      url="https://zenodo.org/records/10845026/files/Galaxy10_DECals.h5"):
    """Download the Galaxy10 DECaLS dataset (~2.5GB) unless it is already present.

    The data is first written to ``Galaxy10_DECals.h5.part`` and only renamed to
    ``Galaxy10_DECals.h5`` once the transfer has finished. An interrupted or
    failed download is therefore never mistaken for a complete dataset on the
    next call.

    Args:
        data_dir: Directory in which to store the file (created if missing).
        url: Source URL of the HDF5 file (defaults to the Zenodo record).

    Returns:
        Path to ``Galaxy10_DECals.h5`` inside ``data_dir``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on failure (e.g.
        ``urllib.error.URLError``); the partial ``.part`` file is removed first.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    file_path = data_dir / 'Galaxy10_DECals.h5'

    if file_path.exists():
        print(f"Dataset already exists at {file_path}")
        return file_path

    tmp_path = file_path.with_name(file_path.name + '.part')

    print("Downloading Galaxy10 DECaLS dataset (2.5GB)...")
    print("This may take a few minutes...")

    try:
        urllib.request.urlretrieve(url, tmp_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the incomplete file so the next
        # call retries instead of trusting it, then propagate the error.
        tmp_path.unlink(missing_ok=True)
        raise
    # Only a fully downloaded file ever appears under the final name.
    tmp_path.replace(file_path)

    print(f"Download complete! Saved to {file_path}")
    return file_path


In [None]:
class Galaxy10Dataset(Dataset):
    """
    Galaxy10 DECaLS images and morphology labels read from an HDF5 file.

    Classes:
    0: Disturbed
    1: Merging
    2: Round Smooth
    3: In-between Smooth
    4: Cigar Smooth
    5: Barred Spiral
    6: Unbarred Tight Spiral
    7: Unbarred Loose Spiral
    8: Edge-on (no bulge)
    9: Edge-on (with bulge)

    Args:
        h5_file: Path to the HDF5 file; it must contain ``images`` and ``ans``
            datasets.
        indices: Optional index array selecting a subset of samples (e.g. one
            train/val/test split). ``None`` keeps every sample.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, h5_file, indices=None, transform=None):
        self.h5_file = h5_file
        self.transform = transform

        # Read everything into memory once; __getitem__ then only indexes
        # in-memory NumPy arrays and keeps no HDF5 handle open.
        with h5py.File(h5_file, 'r') as f:
            self.images = np.array(f['images'])
            # nn.CrossEntropyLoss needs int64 class indices; the stored labels
            # may use a narrower integer dtype, so normalise them here.
            self.labels = np.array(f['ans']).astype(np.int64)

        if indices is not None:
            self.images = self.images[indices]
            self.labels = self.labels[indices]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The label is a plain Python ``int`` so the default DataLoader collate
        turns a batch of labels into an int64 tensor.
        """
        image = Image.fromarray(self.images[idx].astype('uint8'))
        label = int(self.labels[idx])

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
def create_dataloaders(h5_file, batch_size=32, num_workers=2, seed=42,
                       train_frac=0.7, val_frac=0.15):
    """Split the Galaxy10 HDF5 file into train/val/test DataLoaders.

    Sample indices are shuffled with a private, seeded NumPy RNG and cut into
    ``train_frac`` / ``val_frac`` / remainder. Only the training split gets data
    augmentation; all splits are resized to 128x128 and normalised.

    Args:
        h5_file: Path to the Galaxy10 DECaLS HDF5 file.
        batch_size: Mini-batch size for all three loaders.
        num_workers: DataLoader worker processes per loader.
        seed: Seed for the split permutation. The default 42 reproduces the
            split made by the original ``np.random.seed(42)`` version.
        train_frac: Fraction of samples used for training.
        val_frac: Fraction of samples used for validation; the rest is test.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If the fractions are out of range.
    """
    if not (0.0 < train_frac and 0.0 <= val_frac and train_frac + val_frac <= 1.0):
        raise ValueError(
            "need 0 < train_frac, 0 <= val_frac and train_frac + val_frac <= 1; "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )

    # Standard ImageNet channel statistics (not recomputed for Galaxy10).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Augment the training split only. Flips and arbitrary rotations are used
    # because a galaxy's morphology class should not depend on its orientation.
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(180),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        normalize,
    ])

    with h5py.File(h5_file, 'r') as f:
        total_size = len(f['ans'])

    # A private RandomState gives the same permutation as seeding the global
    # generator would, without changing global NumPy random state.
    indices = np.arange(total_size)
    np.random.RandomState(seed).shuffle(indices)

    train_size = int(train_frac * total_size)
    val_size = int(val_frac * total_size)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    train_dataset = Galaxy10Dataset(h5_file, train_indices, train_transform)
    val_dataset = Galaxy10Dataset(h5_file, val_indices, test_transform)
    test_dataset = Galaxy10Dataset(h5_file, test_indices, test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

    return train_loader, val_loader, test_loader


In [None]:
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import global_mean_pool, global_max_pool

class GalaxyCNN(nn.Module):
    """
    Minimal CNN baseline for galaxy classification.

    Input: float tensor of shape (batch, 3, H, W). The pipeline in this
        notebook resizes images to 128x128.
    Output: (batch, num_classes) unnormalised logits.

    The global average pool reduces the 32 feature maps to 32 numbers. The
    classifier head is therefore independent of H and W.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Single convolutional block, 3 -> 32 channels. A 5x5 kernel with
        # padding=2 keeps the spatial size; the 4x4 max-pool then divides it by 4.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(4, 4),  # 128x128 -> 32x32
        )

        # Average each of the 32 feature maps down to a single value.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Small MLP head: 32 -> 64 -> num_classes.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        x = self.conv1(x)
        x = self.global_pool(x)
        return self.fc(x)

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use ``reduction='mean'`` (the
            default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device to move each batch to.

    Returns:
        ``(epoch_loss, epoch_acc)``. The loss is the per-sample mean: each
        batch is weighted by its size, so a short final batch is not
        over-counted. The accuracy is in percent.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        # CrossEntropyLoss expects int64 class indices.
        labels = labels.to(device).long()

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics: undo the batch mean so the epoch mean is per-sample.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    return running_loss / total, 100. * correct / total


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use ``reduction='mean'``.
        device: Device to move each batch to.

    Returns:
        ``(val_loss, val_acc)``. The loss is the per-sample mean, with batches
        weighted by size. The accuracy is in percent.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            # CrossEntropyLoss expects int64 class indices.
            labels = labels.to(device).long()

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("val_loader yielded no samples")

    return running_loss / total, 100. * correct / total


In [None]:
def train_model(num_epochs=10, batch_size=32, learning_rate=0.001):
    """Complete Galaxy10 training pipeline.

    The pipeline:
    - downloads the dataset if needed;
    - trains ``GalaxyCNN`` with Adam and a ReduceLROnPlateau schedule on the
      validation loss;
    - checkpoints the weights with the best validation accuracy to
      ``best_galaxy_model_no_dropout.pth``;
    - restores those weights before evaluating on the test split.

    The reported test metrics therefore describe the saved checkpoint rather
    than whichever epoch ran last.

    Args:
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size for all DataLoaders.
        learning_rate: Initial Adam learning rate.

    Returns:
        ``(model, (train_losses, val_losses, train_accs, val_accs))``. The model
        holds the best-validation weights when at least one epoch improved on
        0% validation accuracy; otherwise it holds the final weights.
    """
    import copy  # only needed here, for snapshotting the best weights

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    print("\n" + "="*60)
    print("Loading Dataset")
    print("="*60)
    h5_file = download_galaxy10(data_dir='./data')
    train_loader, val_loader, test_loader = create_dataloaders(
        h5_file,
        batch_size=batch_size,
        num_workers=2
    )

    # Create model
    print("\n" + "="*60)
    print("Creating Model")
    print("="*60)
    model = GalaxyCNN(num_classes=10).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=2)

    # Training history
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    # Training loop
    print("\n" + "="*60)
    print("Training")
    print("="*60)

    best_val_acc = 0.0
    best_state = None

    for epoch in range(num_epochs):
        # Train
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Learning rate scheduling
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Save history
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
              f"LR: {current_lr:.6f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Deep copy: state_dict() tensors share storage with the live
            # parameters, which later optimizer steps keep modifying.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, 'best_galaxy_model_no_dropout.pth')
            print(f"  -> New best model saved! (Val Acc: {val_acc:.2f}%)")

    # Evaluate the best-validation checkpoint, not just the last epoch.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nRestored best model weights (Val Acc: {best_val_acc:.2f}%)")

    # Test
    print("\n" + "="*60)
    print("Testing")
    print("="*60)
    test_loss, test_acc = validate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

    # Plot training curves
    plot_training_curves(train_losses, val_losses, train_accs, val_accs)

    return model, (train_losses, val_losses, train_accs, val_accs)


In [None]:
def plot_training_curves(train_losses, val_losses, train_accs, val_accs,
                         save_path='training_curves_no_dropout.png'):
    """Plot per-epoch loss and accuracy side by side and save the figure.

    Args:
        train_losses: Training loss for each epoch.
        val_losses: Validation loss for each epoch.
        train_accs: Training accuracy (percent) for each epoch.
        val_accs: Validation accuracy (percent) for each epoch.
        save_path: Output image path. The default keeps the original file name.
    """
    # Imported locally: ``plt`` is not imported in any import cell of this
    # notebook, so the function previously failed with NameError.
    import matplotlib.pyplot as plt

    def _epochs(values):
        # 1-based x-axis, matching the "Epoch [k/N]" numbering printed in training.
        return range(1, len(values) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Loss
    ax1.plot(_epochs(train_losses), train_losses, label='Train Loss', marker='o')
    ax1.plot(_epochs(val_losses), val_losses, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Accuracy
    ax2.plot(_epochs(train_accs), train_accs, label='Train Acc', marker='o')
    ax2.plot(_epochs(val_accs), val_accs, label='Val Acc', marker='s')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    print(f"\nTraining curves saved to '{save_path}'")
    plt.close(fig)


In [None]:
print("="*60)
print("Galaxy10 CNN Training")
print("- No dropout (using BatchNorm for regularization)")
print("- torch_geometric compatible structure")
print("="*60)

# Train GalaxyCNN end to end. train_model downloads the data if needed, trains,
# saves the best-validation checkpoint, evaluates on the test split and plots
# the learning curves. Run this cell before the GNN cells below: they redefine
# both `train_model` and `model`.
model, history = train_model(
    num_epochs=15,
    batch_size=32,
    learning_rate=0.001,
)

print("\n" + "="*60)
print("Training Complete!")
print("="*60)
print("Model saved to: best_galaxy_model_no_dropout.pth")
print("Training curves saved to: training_curves_no_dropout.png")

# Graph Neural Networks to classify particle decays

# Introduction to the physics

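The Z boson typically decays into 2 particles (for example, an electron-positron pair), while the Higgs boson can decay into 4 particles via two intermediate Z bosons (H → ZZ → 4 leptons; one of the Z bosons is off-shell, since the Higgs mass is less than twice the Z mass).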
One challenge is to separate the two kinds of decays so that we can measure Higgs boson production.

In particle physics, decay events naturally form graphs:

- Nodes represent particles
- Edges represent relationships (like coming from the same parent particle)

## Generating particle decay data

Although there are particle physics datasets that we could use for general-purpose studies, it's convenient to simply generate the particle decays of the Z boson and Higgs boson here.
This allows us to control the mix of decays and get exactly the right format.

Typical simulations use random numbers to generate particle momenta; here we simply draw random momenta for the decay products of each boson.
We will write the simulation and the training models as functions so that we can mix them up later.

In [None]:
import torch
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import numpy as np

def create_simulated_decay_data(num_events=1000, decay_type='Z'):
    """
    Generate toy Z or Higgs decay events as fully connected particle graphs.

    Physics background:
    - Z boson (mass ~91 GeV) decays into 2 particles (e.g., e+ e-)
    - Higgs boson (mass ~125 GeV) often decays into 4 particles (e.g., H -> ZZ -> 4 leptons)

    This is a toy generator, not a physical event simulation: momentum
    components are Gaussian random numbers. The decay products are treated as
    massless, so each particle's energy is the magnitude of its 3-momentum
    (E = |p|). Every 4-vector is therefore physical (E >= |p|).

    Args:
        num_events: Number of decay events to generate
        decay_type: 'Z' (2 particles) or 'Higgs' (4 particles)

    Returns:
        List of PyTorch Geometric Data objects with
        - x: [num_particles, 4] float tensor of (px, py, pz, energy)
        - edge_index: [2, num_particles * (num_particles - 1)] long tensor
          containing every ordered pair (i, j) with i != j
        - y: [1] long tensor, 0 for Z and 1 for Higgs

    Raises:
        ValueError: If decay_type is neither 'Z' nor 'Higgs'.
    """
    if decay_type == 'Z':
        num_particles, label = 2, 0
    elif decay_type == 'Higgs':
        num_particles, label = 4, 1
    else:
        # Previously any other value was silently generated and labelled as Higgs.
        raise ValueError(f"decay_type must be 'Z' or 'Higgs', got {decay_type!r}")

    # Fully connected graph without self-loops. For the Higgs we do not assume
    # the correct pairing of leptons into Z bosons. The topology is the same
    # for every event, so build it once.
    edge_template = torch.tensor(
        [[i, j] for i in range(num_particles) for j in range(num_particles) if i != j],
        dtype=torch.long,
    ).t().contiguous()

    data_list = []

    for _ in range(num_events):
        if decay_type == 'Z':
            # Z decay: 2 particles back-to-back. The products recoil against
            # each other, so their 3-momenta are equal and opposite.
            p = np.random.randn(3) * 50 + 45  # component offset from mZ/2 ~ 45 GeV
            momenta = np.stack([p, -p])
        else:
            # Higgs decay: 4 particles with lower individual momenta
            # (each Z from the Higgs decays in turn to two leptons).
            momenta = np.random.randn(num_particles, 3) * 30 + 30

        # Massless decay products: E = |p| (always non-negative).
        energies = np.linalg.norm(momenta, axis=1, keepdims=True)
        x = torch.tensor(np.hstack([momenta, energies]), dtype=torch.float)

        # Clone so no two events share (and could mutate) one edge tensor.
        data = Data(x=x, edge_index=edge_template.clone(),
                    y=torch.tensor([label], dtype=torch.long))
        data_list.append(data)

    return data_list



## Implement the Graph Neural Network

We avoided [the PyTorch Geometric (`torch_geometric`) library](https://pytorch-geometric.readthedocs.io/en/latest/) for our very simple spring-mass system, but here we use it as an example of how to leverage its pre-built features.

Some key ideas for `torch_geometric`:
- `forward` defines how data passes through the network
- message passing is simplified in this example: each message is just the neighboring particle's features
- aggregation is a simple sum of the incoming messages at each node



In [None]:
!pip install torch-geometric

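One naming quirk to watch for: `torch.nn` provides ReLU as the module class `ReLU` (used as `nn.ReLU()`, e.g. inside `nn.Sequential`), while `torch.nn.functional` provides it as the lowercase function `relu` (called as `F.relu(x)`). Both forms appear below.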


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool

class SimpleParticleGNN(MessagePassing):
    """Two rounds of sum-aggregated message passing, then a graph-level classifier.

    The MessagePassing base class takes care of gathering neighbour features
    and aggregating them once ``message`` and ``update`` are defined.

    Args:
        node_features: Input features per particle (4 for (px, py, pz, energy)).
        hidden_dim: Width of the hidden node and graph representations.
        num_classes: Number of output classes (2 for Z vs. Higgs).
    """

    def __init__(self, node_features, hidden_dim, num_classes):
        super().__init__(aggr='add')  # sum the messages arriving at each node

        # Per-particle encoder: node_features -> hidden_dim.
        self.node_encoder = nn.Linear(node_features, hidden_dim)

        # Linear layers applied after each message-passing round to update
        # particle representations before classification.
        self.mp_layer1 = nn.Linear(hidden_dim, hidden_dim)
        self.mp_layer2 = nn.Linear(hidden_dim, hidden_dim)

        # Graph-level classification head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, edge_index, batch=None):
        """
        Run the model on a (possibly batched) set of particle graphs.

        Args:
            x: Node features [num_nodes, node_features] - particle 4-momenta
            edge_index: Graph connectivity [2, num_edges] - which particles are connected
            batch: Graph index of each node [num_nodes]. ``None`` treats all
                nodes as one graph, which is convenient for a single event.

        Returns:
            Logits of shape [num_graphs, num_classes].
        """
        # Step 1: encode each particle -> [num_nodes, hidden_dim]
        x = F.relu(self.node_encoder(x))

        # Step 2: message passing round 1. Each particle's new features are the
        # sum of its neighbours' features (``message`` returns x_j and
        # aggr='add'). Its own features are not included unless edge_index
        # contains self-loops.
        x = self.propagate(edge_index, x=x)
        x = F.relu(self.mp_layer1(x))  # still [num_nodes, hidden_dim]

        # Step 3: message passing round 2 - information now reaches 2 hops.
        x = self.propagate(edge_index, x=x)
        x = F.relu(self.mp_layer2(x))  # still [num_nodes, hidden_dim]

        # Step 4: average the particles of each graph -> [num_graphs, hidden_dim]
        x = global_mean_pool(x, batch)

        # Step 5: classification -> [num_graphs, num_classes] logits
        return self.classifier(x)

    def message(self, x_j):
        """
        Define how messages are created from neighbor particles.
        x_j contains features of neighbor particles.
        """
        return x_j  # simply pass neighbor features without changes

    def update(self, aggr_out):
        """
        Define how aggregated messages update particle features.
        aggr_out (the sum of all neighbor messages) is returned unchanged.
        """
        return aggr_out


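Next we define the training and testing functions (note that this `train_model` replaces the Galaxy10 `train_model` defined earlier):
- `train_model` runs a single training epoch: forward pass, cross-entropy loss, backward pass, and optimizer step
- `test_model` runs the forward pass without gradients and returns the accuracy of the predictions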

In [None]:
def train_model(model, train_loader, optimizer, device):
    """Train a graph classifier for exactly one epoch.

    NOTE: this reuses the name ``train_model`` from the Galaxy10 section, so
    once this cell has run the earlier CNN pipeline function is shadowed.

    Args:
        model: Callable as ``model(x, edge_index, batch)`` returning logits.
        train_loader: Iterable of batched graphs exposing ``x``,
            ``edge_index``, ``batch``, ``y`` and ``num_graphs``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged per graph and the
        accuracy is a fraction in [0, 1].
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for graphs in train_loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()

        logits = model(graphs.x, graphs.edge_index, graphs.batch)
        loss = F.cross_entropy(logits, graphs.y)
        loss.backward()
        optimizer.step()

        # F.cross_entropy averages over the batch; re-weight by graph count.
        n_graphs = graphs.num_graphs
        loss_sum += loss.item() * n_graphs
        n_correct += logits.argmax(dim=1).eq(graphs.y).sum().item()
        n_seen += n_graphs

    return loss_sum / n_seen, n_correct / n_seen


def test_model(model, test_loader, device):

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs

    return correct / total

Finally, with all of the functions defined, we are ready to run them all together.



In [None]:
# Seed both RNGs so the generated events, the train/test split and the initial
# model weights are the same on every run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fraction of events used for training; the rest is the test set.
TRAIN_FRACTION = 0.8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

print("Generating synthetic decay data...")
z_decays = create_simulated_decay_data(num_events=500, decay_type='Z')
higgs_decays = create_simulated_decay_data(num_events=500, decay_type='Higgs')

# Combine and split into train/test
all_data = z_decays + higgs_decays
np.random.shuffle(all_data)

# Derive the split from the data size so changing num_events above keeps an
# 80/20 split (800/200 for the default 1000 events).
n_train = int(TRAIN_FRACTION * len(all_data))
train_data = all_data[:n_train]
test_data = all_data[n_train:]

print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print(f"Example Z decay: {z_decays[0].x.shape[0]} particles")
print(f"Example Higgs decay: {higgs_decays[0].x.shape[0]} particles\n")

# Create data loaders (batches multiple graphs together)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# Initialize model
model = SimpleParticleGNN(
    node_features=4,      # [px, py, pz, energy]
    hidden_dim=64,        # Hidden representation size
    num_classes=2         # Z or Higgs
).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

print("Training model...")

# Training loop; test accuracy is reported every 10 epochs.
for epoch in range(50):
    train_loss, train_acc = train_model(model, train_loader, optimizer, device)
    test_acc = test_model(model, test_loader, device)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Test Acc: {test_acc:.4f}")

In [None]:
def _label_name(label):
    """Map a class index to its display name: 0 -> "Z", anything else -> "Higgs"."""
    return "Z" if label == 0 else "Higgs"


# Run the trained model on the first few test events, one event at a time.
print("Example Predictions:")
print("="*60)  # horizontal rule
model.eval()
with torch.no_grad():
    for i, data in enumerate(test_data[:5]):
        data = data.to(device)
        # A lone event: every node belongs to graph 0.
        single_graph = torch.zeros(data.num_nodes, dtype=torch.long, device=device)
        out = model(data.x, data.edge_index, single_graph)

        true_label = data.y.item()
        pred = out.argmax(dim=1).item()
        confidence = F.softmax(out, dim=1)[0]
        decay_type, pred_type = _label_name(true_label), _label_name(pred)
        verdict = '✓ Correct' if pred == true_label else '✗ Incorrect'

        print("\n".join([
            f"\nEvent {i+1}:",
            f"  Particles: {data.num_nodes}",
            f"  True label: {decay_type}",
            f"  Predicted: {pred_type}",
            f"  Confidence: Z={confidence[0]:.3f}, Higgs={confidence[1]:.3f}",
            f"  {verdict}",
        ]))

The performance of this model is extremely good.
It is probably just learning to count the number of particles instead of learning the momentum correlations between particles.

You can try generating a different dataset, perhaps with Z bosons vs. W bosons, to get more of a challenge: both bosons decay into two particles, so counting particles is no longer enough to tell them apart.