In [1]:

import os
import numpy as np
import pandas as pd
from astropy.io import fits
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [18]:
# Hyperparameters
BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-3
WEIGHT_DECAY = 1e-5
K_FOLDS = 5

IMG_SIZE = (301, 301)         # height, width of images
IN_CHANNELS = 2               # two bands
N_LAYERS = 3                  # number of conv layers
CONV_CHANNELS = 32             # channels in conv layers
KERNEL_SIZE = 3
DROPOUT = 0.2
BATCH_NORM = True

DATA_DIR = "practice_data"
LABELS_CSV = "channel_prob_data.csv"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [19]:
class FITSDataset(Dataset):
    def __init__(self, csv_file, data_dir, target_size=(301,301)):
        self.df = pd.read_csv(csv_file)
        self.data_dir = data_dir
        self.target_size = target_size

        # Optional: filter rows to only files that exist
        def files_exist(row):
            file1 = os.path.join(data_dir, row['f475_image_string'])
            file2 = os.path.join(data_dir, row['f814_image_string'])
            return os.path.isfile(file1) and os.path.isfile(file2)
        self.df = self.df[self.df.apply(files_exist, axis=1)].reset_index(drop=True)
        print(f"Filtered dataset length: {len(self.df)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        band1 = self.load_fits(os.path.join(self.data_dir, row['f475_image_string']))
        band2 = self.load_fits(os.path.join(self.data_dir, row['f814_image_string']))

        stacked = np.stack([band1, band2], axis=0)
        label = np.float32(row['prob'])

        return torch.tensor(stacked, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

    def load_fits(self, filepath):
        with fits.open(filepath) as hdul:
            data = hdul[0].data.astype(np.float32)

        data = np.squeeze(data)

        # Replace NaNs with median
        median_val = np.nanmedian(data)
        data = np.nan_to_num(data, nan=median_val)

        # Center crop or pad to target_size
        H, W = data.shape
        target_H, target_W = self.target_size

        # Crop if bigger than target
        start_H = max((H - target_H) // 2, 0)
        start_W = max((W - target_W) // 2, 0)
        cropped = data[start_H:start_H+target_H, start_W:start_W+target_W]

        # Pad if smaller than target
        pad_H = max(target_H - cropped.shape[0], 0)
        pad_W = max(target_W - cropped.shape[1], 0)
        if pad_H > 0 or pad_W > 0:
            cropped = np.pad(cropped, ((0, pad_H), (0, pad_W)), mode='constant', constant_values=median_val)

        return cropped


In [20]:
import matplotlib.pyplot as plt
import torch

# Initialize dataset
test_dataset = FITSDataset("channel_prob_data.csv", "practice_data/")

print(f"Number of image pairs: {len(test_dataset)}\n")

# Check each image
for i in range(len(test_dataset)):
    image_tensor, label = test_dataset[i]
    print(f"Image {i}: shape={image_tensor.shape}, label={label}")
    print(f"Min/Max values: {image_tensor.min()}, {image_tensor.max()}\n")

# Optional: visualize a few random images to check cropping
for i in range(min(0, len(test_dataset))):
    image_tensor, label = test_dataset[i]
    band1 = image_tensor[0].numpy()
    band2 = image_tensor[1].numpy()

    plt.figure(figsize=(10,4))
    plt.subplot(1,2,1)
    plt.imshow(band1, cmap='gray')
    plt.title("Band 1")
    plt.subplot(1,2,2)
    plt.imshow(band2, cmap='gray')
    plt.title("Band 2")
    plt.suptitle(f"Label: {label}")
    plt.show()


Filtered dataset length: 20
Number of image pairs: 20

Image 0: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.34768664836883545, 21.175325393676758

Image 1: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.3239341974258423, 21.22384262084961

Image 2: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.2799783945083618, 18.657468795776367

Image 3: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.39322853088378906, 18.384241104125977

Image 4: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.3829287886619568, 18.384241104125977

Image 5: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.362567275762558, 20.14598846435547

Image 6: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.34225305914878845, 17.620946884155273

Image 7: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.2941800653934479, 17.429018020629883

Image 8: shape=torch.Size([2, 301, 301]), label=0.0
Min/Max values: 0.29418006539

In [26]:
class FlexibleCNN(nn.Module):
    def __init__(self, 
                 in_channels=IN_CHANNELS,      # use hyperparameter
                 n_layers=N_LAYERS,           # use hyperparameter
                 conv_channels=CONV_CHANNELS, # use hyperparameter
                 kernel_size=KERNEL_SIZE,     # use hyperparameter
                 img_size=IMG_SIZE,           # use hyperparameter
                 dropout=DROPOUT,             # use hyperparameter
                 batch_norm=BATCH_NORM):      # use hyperparameter
        super().__init__()

        layers = []

        # First conv layer
        layers.append(nn.Conv2d(in_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2))
        layers.append(nn.BatchNorm2d(conv_channels) if batch_norm else nn.Identity())
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(2))
        layers.append(nn.Dropout2d(dropout) if dropout > 0 else nn.Identity())

        # Additional conv layers
        for i in range(1, n_layers):
            layers.append(nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2))
            layers.append(nn.BatchNorm2d(conv_channels) if batch_norm else nn.Identity())
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2))
            layers.append(nn.Dropout2d(dropout) if dropout > 0 else nn.Identity())

        self.conv_model = nn.Sequential(*layers)

        # Flattened size
        H, W = img_size
        H //= 2 ** n_layers
        W //= 2 ** n_layers
        flattened_size = conv_channels * H * W

        # Fully connected output
        self.fc_model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 1),
            nn.Sigmoid()
        )

        # Loss function
        self.loss = nn.BCELoss()

    def forward(self, x):
        x = self.conv_model(x)
        x = self.fc_model(x)
        return x.squeeze(1)

    def configure_optimizers(self, learning_rate=LR, weight_decay=WEIGHT_DECAY):  # use hyperparameter
        return optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)


In [28]:
from sklearn.model_selection import KFold
from torch.utils.data import Subset, DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, brier_score_loss
import torch

dataset = FITSDataset(LABELS_CSV, DATA_DIR)
kf = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

best_val_losses = []  # store best validation loss for each fold

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    print(f"\nFold {fold+1}/{K_FOLDS}")

    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model
    model = FlexibleCNN(in_channels=IN_CHANNELS, n_layers=N_LAYERS, conv_channels=CONV_CHANNELS, 
                        kernel_size=KERNEL_SIZE, img_size=IMG_SIZE, dropout=DROPOUT, batch_norm=BATCH_NORM)
    model = model.to(DEVICE)

    criterion = nn.BCELoss()
    optimizer = model.configure_optimizers(learning_rate=LR, weight_decay=WEIGHT_DECAY)

    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        # Training
        model.train()
        running_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        train_loss = running_loss / len(train_subset)

        # Validation
        model.eval()
        val_loss_total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss_total += loss.item() * images.size(0)

                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss = val_loss_total / len(val_subset)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save best model for this fold
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"best_model_fold{fold+1}.pt")
            print(f"Saved best model for fold {fold+1}")

    best_val_losses.append(best_val_loss)

    # Compute continuous metrics for this fold
    mse = mean_squared_error(all_labels, all_preds)
    mae = mean_absolute_error(all_labels, all_preds)
    r2 = r2_score(all_labels, all_preds)
    # brier = brier_score_loss(all_labels, all_preds)

    print(f"Fold {fold+1} Metrics:")
    print(f"MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

# Average validation loss across all folds
avg_val_loss = sum(best_val_losses) / len(best_val_losses)
print(f"\nAverage best validation loss across all folds: {avg_val_loss:.4f}")


Filtered dataset length: 20

Fold 1/5
Epoch 1, Train Loss: 0.7383, Val Loss: 0.6381
Saved best model for fold 1
Epoch 2, Train Loss: 2.2347, Val Loss: 1.2712
Epoch 3, Train Loss: 2.9418, Val Loss: 1.1307
Epoch 4, Train Loss: 1.9222, Val Loss: 0.8076
Epoch 5, Train Loss: 1.0049, Val Loss: 0.7877
Fold 1 Metrics:
MSE: 0.1817, MAE: 0.3769, R²: -0.7372

Fold 2/5
Epoch 1, Train Loss: 0.6701, Val Loss: 0.6129
Saved best model for fold 2
Epoch 2, Train Loss: 5.3028, Val Loss: 2.3291
Epoch 3, Train Loss: 1.6007, Val Loss: 4.0271
Epoch 4, Train Loss: 7.2496, Val Loss: 4.7108
Epoch 5, Train Loss: 1.6497, Val Loss: 4.6772
Fold 2 Metrics:
MSE: 0.7563, MAE: 0.7995, R²: -5.9829

Fold 3/5
Epoch 1, Train Loss: 0.7773, Val Loss: 1.1757
Saved best model for fold 3
Epoch 2, Train Loss: 2.8234, Val Loss: 0.7033
Saved best model for fold 3
Epoch 3, Train Loss: 0.5359, Val Loss: 0.9169
Epoch 4, Train Loss: 2.3977, Val Loss: 1.1196
Epoch 5, Train Loss: 2.8926, Val Loss: 1.0576
Fold 3 Metrics:
MSE: 0.2154, MAE

In [None]:
# # After cross-validation, train final model on all data
# full_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# final_model = FlexibleCNN(in_channels=IN_CHANNELS, n_layers=N_LAYERS, conv_channels=CONV_CHANNELS,
#                           kernel_size=KERNEL_SIZE, img_size=IMG_SIZE, dropout=DROPOUT, batch_norm=BATCH_NORM)
# final_model = final_model.to(DEVICE)

# criterion = nn.BCELoss()
# optimizer = final_model.configure_optimizers(learning_rate=LR, weight_decay=WEIGHT_DECAY)

# for epoch in range(EPOCHS):
#     final_model.train()
#     running_loss = 0
#     for images, labels in full_loader:
#         images, labels = images.to(DEVICE), labels.to(DEVICE)
#         optimizer.zero_grad()
#         outputs = final_model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item() * images.size(0)
#     print(f"Final Model Epoch {epoch+1}, Loss: {running_loss/len(dataset):.4f}")

# # Save final model
# torch.save(final_model.state_dict(), "final_model.pt")
# print("Final model saved to final_model.pt")
