# CosmoMap SIMBA

In [5]:
# PyTorch imports,
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import os

## Load Data

In [6]:
training_data = torch.load("train.pt")
test_data = torch.load("test.pt")
validation_data = torch.load("validation.pt")

# Architecture

In [7]:
class Net(nn.Module):
    """Class for the neural network architecture using upsampling instead of deconvolution."""

    def __init__(self):
        """Constructor method which builds the layers of the neural network."""
        super(Net, self).__init__()

        # Replication padding,
        self.pad = nn.ReplicationPad3d(1)

        # Feature extraction layers
        self.conv1 = nn.Conv3d(1, 8, kernel_size=3, padding=0) 
        self.conv2 = nn.Conv3d(8, 16, kernel_size=3, padding=0) 
        self.conv3 = nn.Conv3d(16, 32, kernel_size=3, padding=0) 
        self.conv4 = nn.Conv3d(32, 16, kernel_size=3, padding=0) 
        self.conv5 = nn.Conv3d(16, 8, kernel_size=3, padding=0) 
        self.conv6 = nn.Conv3d(8, 1, kernel_size=1)

    def forward(self, x):
        """Forward pass of the neural network."""

        x = self.pad(x)
        x = F.relu(self.conv1(x))
        x = self.pad(x)
        x = F.relu(self.conv2(x))
        x = self.pad(x)
        x = F.relu(self.conv3(x))
        x = self.pad(x)
        x = F.relu(self.conv4(x))
        x = self.pad(x)
        x = F.relu(self.conv5(x))
        x = self.conv6(x)

        return F.sigmoid(x)

def initialise_weights(model):
    """Applies custom weight initialization to Conv3D layers."""
    if isinstance(model, nn.Conv3d):
        torch.nn.init.kaiming_uniform_(model.weight, nonlinearity='relu')
        if model.bias is not None:
            torch.nn.init.constant_(model.bias, 0)  # Set biases to zero

class WeightedMSELoss(nn.Module):
    """Custom loss function to deal with data sparsity issue."""
    def __init__(self, alpha=1.0):
        super(WeightedMSELoss, self).__init__()
        self.alpha = alpha

    def forward(self, preds, targets):
        weights = torch.abs(targets) ** self.alpha
        loss = weights * (preds - targets) ** 2
        return loss.mean()

def compute_validation_loss(validation_data):

    # Switching model to eval mode,
    model.eval()

    # Loading validation data set,
    input_grids = validation_data[0]
    target_grids = validation_data[1]
    
    with torch.no_grad():
        input_grids = input_grids.to(device)
        output_grids = model(input_grids)

    validation_loss = loss_function(output_grids, target_grids)

    return validation_loss

def train_model(training_data, epochs, lr, current_epoch):

    """TRAINING LOOP"""

    # Single training iteration,
    for batch_idx, (input_data, target) in enumerate(training_data):
        input_data, target = input_data.to(device), target.to(device)

        # Forward pass,
        output = model(input_data)
    
        # Calculating loss,
        loss = loss_function(output, target)

        # Backpropagation,
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # Compute validation loss,
    val_loss = compute_validation_loss(validation_data)

    # Printing,
    print(f"Epoch: {current_epoch}, Training Loss: {loss.item()}, Validation Loss: {val_loss.item()}")                                                                                                                                                                                                                                                                                    

In [12]:
# HYPERPARAMETERS
LEARNING_RATE = 1e-3
BATCH_SIZE = 4
EPOCHS = 25
if __name__ == "__main__":

    # Load model and optimiser,
    model = Net()
    optimiser = optim.Adam(model.parameters(), lr = LEARNING_RATE)

    # Weight initialisation,
    model.apply(initialise_weights)

    # Checking for CUDA,
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("")
    print(f"Device: {device}")
    model.to(device)

    # Creating loss function,
    loss_function = WeightedMSELoss(alpha=2.5)

    # Load data,
    training_data_loaded = DataLoader(training_data,
                                      batch_size = BATCH_SIZE,
                                      shuffle = True,
                                      num_workers = 1)

    """TRAINING LOOP"""

    # Switching model into training mode,
    model.train()

    # Begin training,
    for epoch in range(1, (EPOCHS + 1)):
        train_model(training_data = training_data_loaded, epochs = EPOCHS, lr = LEARNING_RATE, current_epoch = epoch)

        # Creating checkpoint,
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimiser.state_dict(),
            'loss': loss_function
        }

        # Saving checkpoint,
        torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pth")


Device: cpu
Epoch: 1, Training Loss: 0.001539055840112269, Validation Loss: 5.362833803701506e-07
Epoch: 2, Training Loss: 0.0010723627638071775, Validation Loss: 5.33860259110952e-07
Epoch: 3, Training Loss: 0.001084979623556137, Validation Loss: 5.283640689413005e-07
Epoch: 4, Training Loss: 0.0010407083900645375, Validation Loss: 5.234020932221028e-07
Epoch: 5, Training Loss: 0.0010175611823797226, Validation Loss: 5.194180516809865e-07
Epoch: 6, Training Loss: 0.0007241687853820622, Validation Loss: 5.169706014385156e-07
Epoch: 7, Training Loss: 0.0009893898386508226, Validation Loss: 5.139526706443576e-07
Epoch: 8, Training Loss: 0.001012943685054779, Validation Loss: 5.105713398734224e-07
Epoch: 9, Training Loss: 0.0005665794597007334, Validation Loss: 5.076565798844968e-07
Epoch: 10, Training Loss: 0.0010219445684924722, Validation Loss: 5.052699521002069e-07
Epoch: 11, Training Loss: 0.0009881663136184216, Validation Loss: 5.033318188907288e-07
Epoch: 12, Training Loss: 0.0005

In [13]:
# Put model into evaluation mode
model.eval()

# Loading grid,
i = 6

# We must unsqueeze (batch_size = 1),
input_grid = test_data[0][i].unsqueeze(dim = 0)

with torch.no_grad():
    input_grid = input_grid.to(device)
    output_grid = model(input_grid)

# Squeezing the data and converting to NumPy arrays,
data_input = input_grid.squeeze(dim = 0).squeeze(dim = 0).detach().numpy()
prediction = output_grid.squeeze(dim = 0).squeeze(dim = 0).detach().numpy()
target = test_data[1][i].squeeze(dim = 0).numpy()

In [14]:
def plot_slice(i, grid0, grid1, grid2):
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))  # Two side-by-side plots

    # Plot zeroth grid,
    img0 = axes[0].imshow(grid0[i], cmap="rainbow")
    axes[0].set_title("Input")
    axes[0].set_xlabel("(Mpc/h)")
    axes[0].set_ylabel("(Mpc/h)")

    # Plot first grid,
    img1 = axes[1].imshow(grid1[i], cmap="rainbow")
    axes[1].set_title("Prediction")
    axes[1].set_xlabel("(Mpc/h)")
    axes[1].set_ylabel("(Mpc/h)")

    # Plot second grid,
    img2 = axes[2].imshow(grid2[i], cmap="rainbow")
    axes[2].set_title("Target")
    axes[2].set_xlabel("(Mpc/h)")
    axes[2].set_ylabel("(Mpc/h)")

    # Custom tick formatter function
    def custom_formatter(x, pos):
        return f'{x*25}'

    formatter = FuncFormatter(custom_formatter)
    for ax in axes:
        ax.xaxis.set_major_formatter(formatter)
        ax.yaxis.set_major_formatter(formatter)

    # Add color bars
    fig.colorbar(img0, ax=axes[0], orientation="vertical")
    fig.colorbar(img1, ax=axes[1], orientation="vertical")
    fig.colorbar(img2, ax=axes[2], orientation="vertical")

    plt.tight_layout()
    plt.show()

# Creating slider widget,
grid0 = data_input
grid1 = prediction
grid2 = target  

grid_zlength = (grid1.shape[2] - 1)
slider = widgets.IntSlider(min=0, max=grid_zlength, value=0, description="Slice")

# Link slider to function,
output = widgets.interactive_output(lambda i: plot_slice(i, grid0, grid1, grid2), {'i': slider})

# Display slider and plot,
display(slider, output)

IntSlider(value=0, description='Slice', max=31)

Output()