# Used for making predictions on all the data with the UiT model

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import trange
DEVICE_TYPE = 'cuda'  # the model and every tensor below live on the default CUDA device
device = torch.device(DEVICE_TYPE)

In [2]:
class DoubleConv(nn.Module):
    """Two stacked (Conv3d -> BatchNorm3d -> ReLU) stages with circular padding.

    Each convolution uses kernel 3, padding 1 and stride 1, so the spatial
    size is preserved; only the channel count changes, from ``in_channels``
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage reads in_channels, second stage reads the first stage's output.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1, padding_mode='circular'),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name and layer order define the state_dict keys
        # (double_conv.0.weight, ...), which saved checkpoints depend on.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class UNet3D(nn.Module):
    """3-D U-Net conditioned on a per-sample parameter vector.

    Maps an input volume ``x`` of shape (n, in_channels, D, H, W) plus
    ``params`` (n, param_dim) to a volume of shape (n, out_channels, D, H, W).
    The parameters are broadcast over the bottleneck grid and concatenated to
    the deepest encoder features.

    Module names, creation order and channel widths are unchanged, so existing
    checkpoints (state_dict keys and tensor shapes) keep loading.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels of the predicted volume.
        param_dim: length of each sample's parameter vector.
        base_filters: channel width ``B`` of the first encoder stage; the other
            stages use multiples of it.
    """

    def __init__(self, in_channels=1, out_channels=1, param_dim=3, base_filters=32):
        super(UNet3D, self).__init__()

        # Encoder: widths B, B, 2B, 4B.
        self.enc1 = DoubleConv(in_channels, base_filters)
        self.enc2 = DoubleConv(base_filters, base_filters * 1)
        self.enc3 = DoubleConv(base_filters * 1, base_filters * 2)
        self.enc4 = DoubleConv(base_filters * 2, base_filters * 4)

        self.pool = nn.MaxPool3d(2)

        # Bottleneck input = enc4 features (4B) + broadcast params (param_dim).
        # This was a hard-coded 131, which is only right for the defaults
        # (B=32, param_dim=3). Any other value of either argument was silently
        # ignored and produced a broken network.
        self.bottleneck = DoubleConv(base_filters * 4 + param_dim, base_filters * 16)

        # Decoder: upsample 2x, concatenate the matching encoder skip, then DoubleConv.
        self.upconv4 = nn.ConvTranspose3d(base_filters * 16, base_filters * 8, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(base_filters * 10, base_filters * 8)  # 8B upsampled + 2B from enc3

        self.upconv3 = nn.ConvTranspose3d(base_filters * 8, base_filters * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_filters * 5, base_filters * 4)  # 4B upsampled + B from enc2

        self.upconv2 = nn.ConvTranspose3d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_filters * 3, base_filters * 2)  # 2B upsampled + B from enc1

        # The last stage has no skip connection. It upsamples to twice the input
        # size, and the stride-2 final_conv brings the output back to (D, H, W).
        self.upconv1 = nn.ConvTranspose3d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_filters * 1, base_filters)

        self.final_conv = nn.Conv3d(base_filters, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, params):
        """Run the network.

        Args:
            x: input volume, (n, in_channels, D, H, W). D, H and W must be
                divisible by 8 so the three 2x poolings line up with the skip
                connections (this notebook uses 48).
            params: n parameter vectors. They are reshaped to
                (n, param_dim, 1, 1, 1), so the tensor must hold
                n * param_dim elements.

        Returns:
            Tensor of shape (n, out_channels, D, H, W).
        """
        # Shape comments assume the defaults (B=32) and a 48^3 input.
        # Encoder
        enc1 = self.enc1(x)                # (n, 32, 48, 48, 48)
        enc2 = self.enc2(self.pool(enc1))  # (n, 32, 24, 24, 24)
        enc3 = self.enc3(self.pool(enc2))  # (n, 64, 12, 12, 12)
        enc4 = self.enc4(self.pool(enc3))  # (n, 128, 6, 6, 6)

        # Broadcast params over the bottleneck grid. expand() returns a view,
        # so the only copy is the one torch.cat makes.
        n, _, d, h, w = enc4.shape
        params = params.reshape(n, -1, 1, 1, 1).expand(-1, -1, d, h, w)
        bottleneck = self.bottleneck(torch.cat([enc4, params], dim=1))  # (n, 512, 6, 6, 6)

        # Decoder
        dec4 = self.upconv4(bottleneck)                   # (n, 256, 12, 12, 12)
        dec4 = self.dec4(torch.cat([dec4, enc3], dim=1))  # (n, 256, 12, 12, 12)

        dec3 = self.upconv3(dec4)                         # (n, 128, 24, 24, 24)
        dec3 = self.dec3(torch.cat([dec3, enc2], dim=1))  # (n, 128, 24, 24, 24)

        dec2 = self.upconv2(dec3)                         # (n, 64, 48, 48, 48)
        dec2 = self.dec2(torch.cat([dec2, enc1], dim=1))  # (n, 64, 48, 48, 48)

        dec1 = self.dec1(self.upconv1(dec2))              # (n, 32, 96, 96, 96)

        return self.final_conv(dec1)                      # (n, 1, 48, 48, 48)

# Instantiate the network on `device` and restore the trained weights from
# checkpoint MidRun55. weights_only=True stops torch.load from unpickling
# arbitrary objects.
checkpoint_path = '/media/disk1/prasad/codes/RefinedModelOutputs/CosmoUNet100/MidRun55.pth'
model = UNet3D().to(device)
state_dict = torch.load(checkpoint_path, weights_only=True)
model.load_state_dict(state_dict)
del state_dict  # load_state_dict copied the values into the model; drop the loaded tensors
# Switch layers such as BatchNorm to inference behaviour (running statistics).
model.eval()

UNet3D(
  (enc1): DoubleConv(
    (double_conv): Sequential(
      (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), padding_mode=circular)
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), padding_mode=circular)
      (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc2): DoubleConv(
    (double_conv): Sequential(
      (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), padding_mode=circular)
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), padding_mode=circular)
      (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running

In [3]:
# Input halo field (already normalized). Only the first entry of Halo.pt is
# used; unsqueeze(0) adds a leading batch axis before the move to `device`.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only load trusted files this way.
halo_path = '/media/disk2/prasad/TensorData/InputFields/Halo.pt'
all_halos = torch.load(halo_path, weights_only=False)
input_halo = all_halos[0].unsqueeze(0).to(device)
del all_halos  # release the reference to the full file contents; only entry 0 is kept
print("input halo shape: ", input_halo.shape)

# The dark-matter field is not used yet (halo-only run):
# input_dm = torch.load('/media/disk2/prasad/TensorData/InputFields/DM.pt', weights_only=False).to(device)
# print("input dm shape: ", input_dm.shape)

input halo shape:  torch.Size([1, 1, 48, 48, 48])


In [4]:
import numpy as np
import torch
# import torch.nn.functional as F
from tqdm import trange
import gc
from sklearn.metrics import r2_score

# Evaluate the model on every sample. For each index, predict the xHI volume
# from the shared input halo plus that sample's parameter row, then score it
# against the stored ground truth with MSE and R^2.
N_SAMPLES = 7204  # indices 0 .. N_SAMPLES-1 are evaluated
PARAMS_PATH = '/media/disk2/prasad/Params.npy'
TRUE_OUTPUT_PATTERN = '/media/disk2/prasad/ReducedData48/xHI{}.npy'
# Set to a float (e.g. 0.642) to also collect and save the indices whose R^2
# falls below it. None disables this, matching the previous behaviour.
R2_THRESHOLD = None

# Read the parameter table once. The loop previously re-read the whole file
# for every sample.
all_params = np.load(PARAMS_PATH)

MSEs = []
R2s = []
indices = []

model.eval()  # make sure BatchNorm uses its running statistics
# Inference only: no_grad stops autograd from recording a graph on every
# forward pass. That graph is what the old per-iteration
# detach()/del/gc.collect() was working around.
with torch.no_grad():
    for index1 in trange(N_SAMPLES):
        # Ground truth as float32 (the dtype the model produces). The host copy
        # is kept for R^2, so it no longer round-trips through the GPU.
        true_np = np.load(TRUE_OUTPUT_PATTERN.format(index1)).astype(np.float32)
        # Add batch and channel axes to match the prediction's dimensions
        # (assumes each file holds a 3-D (D, H, W) volume -- TODO confirm).
        true_output1 = torch.from_numpy(true_np[np.newaxis, np.newaxis]).to(device)

        # This sample's parameter row with a leading batch axis.
        input_params1 = torch.tensor(
            np.expand_dims(all_params[index1], axis=0), dtype=torch.float32, device=device
        )

        prediction1 = model(input_halo, input_params1)

        mse1 = F.mse_loss(prediction1, true_output1).item()
        r2_1 = r2_score(true_np.ravel(), prediction1.cpu().numpy().ravel())

        MSEs.append(mse1)
        R2s.append(r2_1)
        if R2_THRESHOLD is not None and r2_1 < R2_THRESHOLD:
            indices.append(index1)

# Save the final arrays
np.save('MSE55.npy', arr=np.array(MSEs))
np.save('R255.npy', arr=np.array(R2s))
if R2_THRESHOLD is not None:
    np.save('TroublingIndices.npy', arr=np.array(indices))

  0%|          | 0/7203 [00:00<?, ?it/s]
