# Merge the results for fitted params into a single pandas file

Gather all results on fitted params 

- Author Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-11-26

- conda environment on my laptop : ``conda_jaxcpu_dsps_py310``
  
- last update : 2023-11-27 : fit done by november 2023 (LSST France @ CCIN2P3)
- last update : 2024-01-06 : fit done on January 6th 2024 to increase the parameter range
- last update : 2024-01-28 : fit done on January 28th 2024 to include metallicity in the fit parameters

In [None]:
import pandas as pd
import h5py
import pickle
import os
import glob
import re
import numpy as np
import jax.numpy as jnp
import itertools

In [None]:
import jax
# Switch on JAX's "jax_enable_x64" configuration flag (64-bit mode) before any
# JAX computation is performed in this notebook.
jax.config.update("jax_enable_x64", True)

## Pickup all available files and sort them

In [None]:
# Date tag of the fit campaign to merge. It selects the input directory and
# names the output files (an earlier campaign used "2024-01-06").
date = "2024-01-28"
# Glob pattern matching every pickle file of fitted parameters for that date
filenames_fitparams = f"fitssp_results_{date}/*.pickle"
# Regular expression capturing the number that follows "SPEC" in a file path.
# It is built from ``date`` so that changing the date above cannot leave the
# pattern pointing at the directory of another campaign.
regexp_pattern_numbers = f"^fitssp_results_{date}/fitparams_SPEC(.*)[.]pickle$"
# Output files that receive the merged table
filename_output_fittedresults_csv = f"fitssp_results_{date}.csv"
filename_output_fittedresults_hdf = f"fitssp_results_{date}.h5"

In [None]:
# Collect the path of every file matched by the glob pattern into a numpy
# array, so that the paths can later be reordered with an index array.
all_files = np.array(glob.glob(filenames_fitparams))
# Number of files found
N = all_files.shape[0]

In [None]:
def _captured_number(path):
    """Return, as an int, the first capture of ``regexp_pattern_numbers`` in *path*."""
    return int(re.findall(regexp_pattern_numbers, path)[0])


# One integer per file, in the same order as ``all_files``
fors2_nums = np.array([_captured_number(path) for path in all_files])

In [None]:
# Permutation that puts the extracted numbers in increasing order
sorted_indexes = fors2_nums.argsort()
# Apply the same permutation to the numbers and to the file paths so that
# both arrays stay aligned with each other
fors2_nums_sorted = fors2_nums[sorted_indexes]
all_sorted_files = all_files[sorted_indexes]

# Read all parameters

In [None]:
def _load_pickle(path):
    """Open *path* in binary mode and return the object unpickled from it."""
    with open(path, 'rb') as stream:
        return pickle.load(stream)


# One loaded object per file, following the order of ``all_sorted_files``
all_params_dicts = [_load_pickle(path) for path in all_sorted_files]

In [None]:
#all_params_dicts[0].keys()

In [None]:
#list(all_params_dicts[0].values())

## Write in pandas dataframe

In [None]:
# Start from an empty table whose column names are the keys of the first
# loaded dictionary; the rows are added afterwards.
column_names = list(all_params_dicts[0])
df = pd.DataFrame(columns=column_names)

In [None]:
# Number of columns of the table, used to validate the length of each row
NC = df.shape[1]

In [None]:
# Number of leading values of each dictionary that are copied unchanged into
# the row; every value after them is converted with ``.item()``.
N_UNCONVERTED_VALUES = 5

# Add one row per loaded dictionary, labelled by its position in
# ``all_params_dicts``.
for idx, params in enumerate(all_params_dicts):
    row = list(params.values())
    # Slicing (instead of indexing positions 0..4) lets a row that is too
    # short reach the size check below and be skipped, rather than raising
    # IndexError before the check is evaluated.
    row_flatten = row[:N_UNCONVERTED_VALUES] + [
        # convert the number from a jax Array into a float value
        value.item() for value in row[N_UNCONVERTED_VALUES:]
    ]
    # A row whose length differs from the number of columns cannot be stored
    if len(row_flatten) != NC:
        print("bad size row skipped for ", row_flatten)
        continue
    df.loc[idx] = row_flatten

## Drop NAN

In [None]:
# Discard every row that contains at least one missing value, then renumber
# the remaining rows 0..n-1.
df = df.dropna(axis="index", how="any", ignore_index=True)

## Write the dataframe to CSV and HDF5 files

In [None]:
# Save the table as a CSV text file
df.to_csv(filename_output_fittedresults_csv)
# Save the same table in an HDF5 file: the file is overwritten (mode 'w'),
# stored in 'table' format under the key 'fitssp_results', with every column
# declared as a data column and compression level 9.
df.to_hdf(
    filename_output_fittedresults_hdf,
    key='fitssp_results',
    mode='w',
    format='table',
    data_columns=True,
    complevel=9,
)