In [1]:
import os,sys
import numpy as np
import time

import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import CosDataset, train_tfm, test_tfm
from model import CNN_cosmo

import matplotlib.pyplot as plt

# Pick the compute device once; modules and batches below are moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- model -----------------------------------------------------------------
mode = "cosmo"  # NOTE(review): not read anywhere in this notebook

# ---- dataset split ---------------------------------------------------------
# NOTE(review): presumably these must match the training run so that the
# held-out test slice is the one the checkpoints never saw — confirm.
rng_seed = 114514
batch_size = 32
training_ratio, valid_ratio, test_ratio = 0.7, 0.2, 0.1

# ---- training hyper-parameters (not used by this evaluation notebook) ------
n_epoch, lr = 1000, 0.002

# Two networks, built with CNN_cosmo's "cosmo" and "all" modes respectively.
# Module.to() returns the module itself, and is a no-op on CPU.
net_cos = CNN_cosmo("cosmo").to(device)
net_all = CNN_cosmo("all").to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Restore the trained weights for both networks.
# - map_location=device: lets checkpoints saved from a GPU run be restored on a
#   CPU-only machine (without it torch.load tries to recreate CUDA tensors).
# - weights_only=True: only tensors/containers are unpickled, not arbitrary
#   objects; these files are fed straight into load_state_dict, so they are
#   expected to hold a state_dict only.
COS_CKPT = os.path.join("./CNN_cosmo_params", "CNN_cosmo_epoch-300_model_params.pkl")
ALL_CKPT = os.path.join("./CNN_all_params", "CNN_all_epoch-020_model_params.pkl")

net_cos.load_state_dict(torch.load(COS_CKPT, map_location=device, weights_only=True))
net_all.load_state_dict(torch.load(ALL_CKPT, map_location=device, weights_only=True))

  net_cos.load_state_dict(torch.load(os.path.join("./CNN_cosmo_params", "CNN_cosmo_epoch-300_model_params.pkl")))
  net_all.load_state_dict(torch.load(os.path.join("./CNN_all_params", "CNN_all_epoch-020_model_params.pkl")))


<All keys matched successfully>

In [None]:
# Compiled multifield images and their label vectors.
# NOTE(review): absolute, user-specific path — consider moving to a config cell / env var.
DATA_DIR = "/home/chenze/data_gpfs02/CAMELS_multifield/dataset"

img_list = torch.FloatTensor(np.load(os.path.join(DATA_DIR, "compiled_img_TNG.npy")))
lab_list = torch.FloatTensor(np.load(os.path.join(DATA_DIR, "compiled_params_TNG.npy")))

# Re-create the split permutation with the *legacy global* NumPy RNG and the same
# seed. Do not switch this to np.random.default_rng: it yields a different
# permutation, so the test slice would no longer be the held-out one
# (assumes training used this same recipe — confirm).
np.random.seed(rng_seed)
shuffle = np.arange(img_list.shape[0])
np.random.shuffle(shuffle)

img_list = img_list[shuffle]
lab_list = lab_list[shuffle]

# Split boundaries: [0, len_training) / [len_training, len_valid) / [len_valid, end).
# Only the last slice (the test split) is used in this notebook.
len_training = int(len(img_list) * training_ratio)
len_valid = len_training + int(len(img_list) * valid_ratio)
assert len_valid < len(img_list), "test split is empty"

test_img, test_lab = img_list[len_valid:], lab_list[len_valid:]

# Evaluation loaders: no shuffling, so the first batch taken below contains the
# same samples on every run. pin_memory only helps when copying to a CUDA device
# (and warns when no accelerator is available).
EVAL_BATCH_SIZE = 1500
pin = device == 'cuda'

test_set_cos = CosDataset(test_img, test_lab, tfm=test_tfm, mode="cosmo")
test_loader_cos = DataLoader(test_set_cos, batch_size=EVAL_BATCH_SIZE, shuffle=False, pin_memory=pin)

test_set_all = CosDataset(test_img, test_lab, tfm=test_tfm, mode="all")
test_loader_all = DataLoader(test_set_all, batch_size=EVAL_BATCH_SIZE, shuffle=False, pin_memory=pin)

In [None]:
# Predict on the first test batch with the "cosmo" network.
# eval() puts dropout / batch-norm layers (if CNN_cosmo has any) into inference
# mode; no_grad() avoids building an autograd graph for the whole batch.
net_cos.eval()

# next(iter(...)) fetches only the first batch, instead of list(loader) which
# loads and transforms every batch just to discard all but one.
batch_data, batch_targ = next(iter(test_loader_cos))
with torch.no_grad():
    pred_cos = net_cos(batch_data.to(device)).cpu().numpy()
trut_cos = np.array(batch_targ)

In [None]:
# Predict on the first test batch with the "all" network.
# eval() puts dropout / batch-norm layers (if CNN_cosmo has any) into inference
# mode; no_grad() avoids building an autograd graph for the whole batch.
net_all.eval()

# next(iter(...)) fetches only the first batch, instead of list(loader) which
# loads and transforms every batch just to discard all but one.
batch_data, batch_targ = next(iter(test_loader_all))
with torch.no_grad():
    pred_all = net_all(batch_data.to(device)).cpu().numpy()
trut_all = np.array(batch_targ)

In [None]:
# Predicted vs. true values for label columns 0 and 1 ("cosmo" network).
# Raw strings keep the LaTeX backslashes from being parsed as (invalid) escapes.
COSMO_PANELS = (
    # (column, title, axis range)
    (0, r"$\Omega_m$", (0.05, 0.6)),
    (1, r"$\sigma_8$", (0.55, 1.1)),
)

fig, axes = plt.subplots(1, len(COSMO_PANELS), figsize=(10, 5))
for ax, (col, title, (lo, hi)) in zip(axes, COSMO_PANELS):
    ax.scatter(trut_cos[:, col], pred_cos[:, col], s=1)
    # y = x reference line spanning the whole visible range (a fixed [0, 1]
    # segment would stop short of the sigma_8 axis, which extends to 1.1).
    ax.plot([lo, hi], [lo, hi], c='black')
    ax.set(title=title, xlim=(lo, hi), ylim=(lo, hi), xlabel="Truth", ylabel="Pred")
plt.show()


In [None]:
# Predicted vs. true values for the "all" network.

# Figure 1: label columns 0 and 1, same layout as the "cosmo" plot.
# Raw strings keep the LaTeX backslashes from being parsed as (invalid) escapes.
COSMO_PANELS = (
    # (column, title, axis range)
    (0, r"$\Omega_m$", (0.05, 0.6)),
    (1, r"$\sigma_8$", (0.55, 1.1)),
)

fig, axes = plt.subplots(1, len(COSMO_PANELS), figsize=(10, 5))
for ax, (col, title, (lo, hi)) in zip(axes, COSMO_PANELS):
    ax.scatter(trut_all[:, col], pred_all[:, col], s=1)
    ax.plot([lo, hi], [lo, hi], c='black')  # y = x reference over the visible range
    ax.set(title=title, xlim=(lo, hi), ylim=(lo, hi), xlabel="Truth", ylabel="Pred")
plt.show()

# Figure 2: label columns 3 and 4. These are NOT Omega_m / sigma_8 (those are
# columns 0 and 1 above) — the previous titles were copy-pasted.
# TODO: confirm the column -> parameter-name mapping of compiled_params_TNG.npy
# and replace the generic titles with the real parameter names.
EXTRA_COLS = (3, 4)

fig, axes = plt.subplots(1, len(EXTRA_COLS), figsize=(5 * len(EXTRA_COLS), 5))
for ax, col in zip(axes, EXTRA_COLS):
    truth, pred = trut_all[:, col], pred_all[:, col]
    ax.scatter(truth, pred, s=1)
    # Reference line over the data's own span, so autoscaling is not forced
    # out to an arbitrary fixed range.
    lo = min(truth.min(), pred.min())
    hi = max(truth.max(), pred.max())
    ax.plot([lo, hi], [lo, hi], c='black')
    ax.set(title=f"label column {col}", xlabel="Truth", ylabel="Pred")
plt.show()

In [None]:
# Quick look at label column 5 (not plotted above). Show its range and a few
# values instead of dumping the whole column into the notebook output.
col5 = trut_all[:, 5]
col5.shape, col5.min(), col5.max(), col5[:10]