In [11]:
import os
import argparse
import yaml
import time
import numpy as np
import pandas as pd
from multiprocessing import Pool

import torch 
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

from ili.dataloaders import StaticNumpyLoader, TorchLoader, NumpyLoader
from ili.dataloaders import TorchLoader
from ili.inference import InferenceRunner
from ili.validation import ValidationRunner


from sbi.analysis import pairplot
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

from sbi.utils import posterior_nn

from CASBI.utils.create_dataframe import rescale


In [2]:

# Load the dataframe and undo the stored standardisation (inverse=True);
# presumably `rescale` maps values back using mean_and_std.parquet — TODO confirm.
data = pd.read_parquet('../../../../data/dataframe/dataframe.parquet')
data = rescale(data, mean_and_std_path='../../../../data/preprocess/mean_and_std.parquet', scale_observations=True, scale_parameter=True, inverse=True)
# Drop columns that are not needed downstream.
data = data.drop(['gas_log10mass', 'a', 'redshift', 'mean_metallicity', 'std_metallicity', 'mean_FeMassFrac', 'std_FeMassFrac', 'mean_OMassFrac', 'std_OMassFrac'], axis=1)

# Global histogram bounds shared by every simulated observation, so all
# histograms use identical bin edges. Series.min/max are vectorised and skip
# NaN, unlike the builtin min/max, whose result is order-dependent with NaN.
min_feh, max_feh = data['feh'].min(), data['feh'].max()
min_ofe, max_ofe = data['ofe'].min(), data['ofe'].max()


N_test = 1_000

def preprocess_testset(i):
    """Build the i-th held-out test observation.

    Draws N_subhalos in [2, 100), picks that many distinct galaxy names from
    `data`, and histograms the pooled 'feh'/'ofe' values of those galaxies on a
    64x64 grid spanning the global [min_feh, max_feh] x [min_ofe, max_ofe] range.

    Parameters
    ----------
    i : int
        Task index, also used as the seed so test element i is reproducible.

    Returns
    -------
    N_subhalos : int
        Number of galaxies combined.
    sim_data : np.ndarray
        log10(counts + 1 + 1e-6) with shape (1, 64, 64).
    galaxies : set
        Names of the galaxies that were combined.
    """
    # One generator per task. The legacy global np.random state is inherited
    # identically by forked Pool workers (fork start method, Linux default),
    # so drawing from it repeats the same N_subhalos sequence in every worker.
    rng = np.random.default_rng(i)
    N_subhalos = int(rng.integers(2, 100))  # upper bound is exclusive
    unique_galaxies = data['Galaxy_name'].drop_duplicates().to_numpy()
    galaxies = set(rng.choice(unique_galaxies, size=N_subhalos, replace=False))
    # Select the two observables by name instead of relying on column order.
    galaxy_data = data.loc[data['Galaxy_name'].isin(galaxies), ['feh', 'ofe']].to_numpy(dtype=float)
    histogram_galaxy, _, _ = np.histogram2d(galaxy_data[:, 0], galaxy_data[:, 1], bins=64, range=[[min_feh, max_feh], [min_ofe, max_ofe]])
    # +1 keeps empty bins finite under log10; leading axis is the channel dim.
    sim_data = np.expand_dims(np.log10(histogram_galaxy + 1e-6 + 1), axis=0)
    return N_subhalos, sim_data, galaxies

# Build the held-out test set in parallel; task index i is passed to preprocess_testset.
with Pool() as pool:
    results = pool.map(preprocess_testset, range(N_test))

# results is a list of (N_subhalos, sim_data, galaxies) triples -> three parallel tuples.
N_subhalos_test, x_test, galaxies_test = zip(*results)

# The first test element is held out as the reference observation (x_0, theta_0).
galaxies_0 = galaxies_test[0]
# Keep the DataFrame itself (to_parquet returns None) and write it out for plotting.
data_to_plot_halos = data[data['Galaxy_name'].isin(galaxies_0)]
data_to_plot_halos.to_parquet('./halos_0.parquet')
N_subhalos_0 = N_subhalos_test[0]
x_0 = x_test[0]

N = 10_000

# Held-out test subsets as a set of frozensets: O(1) membership per draw
# instead of comparing against every test element in turn.
TEST_GALAXY_SETS = frozenset(frozenset(g) for g in galaxies_test)
# Second seed word for training draws; each task seeds default_rng([i, TRAIN_SEED])
# so its stream is reproducible and distinct from plain integer seeds.
TRAIN_SEED = 2024
# Upper bound on redraws so a collision can never hang a worker.
MAX_RESAMPLE_TRIES = 1_000

def process_sample(i):
    """Simulate the i-th training observation from a random subset of galaxies.

    Draws N_subhalos in [2, 100), samples that many distinct galaxy names
    (redrawing while the subset equals one of the test-set subsets) and
    histograms the pooled 'feh'/'ofe' values on a 64x64 grid spanning the
    global [min_feh, max_feh] x [min_ofe, max_ofe] range.

    Parameters
    ----------
    i : int
        Task index; combined with TRAIN_SEED to seed this task's generator.

    Returns
    -------
    N_subhalos : int
        Number of galaxies combined.
    sim_data : np.ndarray
        log10(counts + 1 + 1e-6) with shape (1, 64, 64).

    Raises
    ------
    RuntimeError
        If MAX_RESAMPLE_TRIES draws in a row all coincide with a test subset.
    """
    # A per-task generator makes each sample reproducible, lets every retry
    # draw a fresh subset, and avoids the global np.random state, which forked
    # Pool workers inherit identically.
    rng = np.random.default_rng([i, TRAIN_SEED])
    N_subhalos = int(rng.integers(2, 100))  # upper bound is exclusive
    unique_galaxies = data['Galaxy_name'].drop_duplicates().to_numpy()
    galaxies = rng.choice(unique_galaxies, size=N_subhalos, replace=False)
    n_tries = 1
    while frozenset(galaxies) in TEST_GALAXY_SETS:
        if n_tries >= MAX_RESAMPLE_TRIES:
            raise RuntimeError(f'process_sample({i}): {MAX_RESAMPLE_TRIES} draws of {N_subhalos} galaxies all matched the test set')
        print('matched galaxies, try again')
        galaxies = rng.choice(unique_galaxies, size=N_subhalos, replace=False)
        n_tries += 1
    # Select the two observables by name instead of relying on column order.
    galaxy_data = data.loc[data['Galaxy_name'].isin(galaxies), ['feh', 'ofe']].to_numpy(dtype=float)
    histogram_galaxy, _, _ = np.histogram2d(galaxy_data[:, 0], galaxy_data[:, 1], bins=64, range=[[min_feh, max_feh], [min_ofe, max_ofe]])
    # +1 keeps empty bins finite under log10; leading axis is the channel dim.
    sim_data = np.expand_dims(np.log10(histogram_galaxy + 1e-6 + 1), axis=0)
    return N_subhalos, sim_data

# Simulate the training set in parallel; task index i is handed to process_sample.
with Pool() as pool:
    results = pool.map(process_sample, range(N))

# results is a list of (N_subhalos, sim_data) pairs -> two parallel tuples.
N_subhalos, x = zip(*results)

# Persist everything as .npy files. Element 0 of the test set is left out of
# x_test / N_subhalos_test because it is saved on its own as x_0 / N_subhalos_0.
path = '../../../../../../data/vgiusepp/'
np.save(os.path.join(path, 'x_test.npy'), x_test[1:])
np.save(os.path.join(path, 'N_subhalos_test.npy'), N_subhalos_test[1:])
np.save(os.path.join(path, 'x_0.npy'), x_0)
np.save(os.path.join(path, 'N_subhalos_0.npy'), N_subhalos_0)
np.save(os.path.join(path, 'x.npy'), x)
np.save(os.path.join(path, 'N_subhalos.npy'), N_subhalos)
print('finish prepare the data')

finish prepare the data


In [24]:
# Convert the simulation outputs to float32 torch tensors for sbi; theta gets an
# explicit trailing parameter axis, giving shape (len(N_subhalos), 1).
N_subhalos = torch.as_tensor(np.array(N_subhalos), dtype=torch.float32).reshape(len(N_subhalos), 1)
x = torch.as_tensor(np.array(x), dtype=torch.float32)

N_subhalos_0 = torch.as_tensor(np.array(N_subhalos_0), dtype=torch.float32)
x_0 = torch.as_tensor(np.array(x_0), dtype=torch.float32)

In [25]:
# Sanity check: theta should be (number of simulations, 1).
N_subhalos.size()

torch.Size([10000, 1])

In [26]:
from CNN import ConvNet

# Prior bounds for N_subhalos, mirroring the [2, 100) range it is drawn from
# when simulating. Float literals give float32 bound tensors, matching the
# float32 theta passed to append_simulations (the original used int64 tensors).
# NOTE(review): N_subhalos is integer-valued while this prior and the flow are
# continuous — consider dequantisation or a discrete treatment.
N_SUBHALOS_LOW, N_SUBHALOS_HIGH = 2.0, 100.0
# Seed torch's global RNG before building the networks so weight initialisation
# (and, presumably, sbi's torch-based train/validation shuffling) is reproducible.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)

prior = BoxUniform(low=torch.tensor([N_SUBHALOS_LOW]), high=torch.tensor([N_SUBHALOS_HIGH]))
# Single-channel CNN embedding of each histogram, producing output_dim=10 features.
embedding_net = ConvNet(input_channel=1, output_dim=10)

# Neural spline flow conditioned on the embedded observation.
neural_posterior = posterior_nn(model="nsf", embedding_net=embedding_net)

# Set up the inference procedure with the SNPE-C procedure.
inferer = SNPE(prior=prior, density_estimator=neural_posterior)

# Train the density estimator on (theta, x) pairs, then wrap it as a posterior.
density_estimator = inferer.append_simulations(N_subhalos, x).train(training_batch_size=256)
posterior = inferer.build_posterior(density_estimator)

 Training neural network. Epochs trained: 2

KeyboardInterrupt: 