# Imports

In [1]:
# IPython autoreload extension; mode 2 re-imports modules before running code so that
# edits to project modules (e.g. utils.dataset_utils) are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [2]:
import pickle
import os
import numpy as np

from time import time

from utils.dataset_utils import shuffle_six_arrays
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

# Data Structure

### data

Each `dataset_<name>.pkl` file written below contains a dict with the following keys:

* X_train
* X_test
* y_class_train
* y_class_test
* y_reg_train
* y_reg_test
* y_reg_norm_train
* y_reg_norm_test
* resc_factor_train
* resc_factor_test
* div_info_train
* div_info_test

# Load the simulated datasets from pickle files, split them into train/test sets and save the splits

In [4]:
pickle_path_base = "/workspace/coniferas(1)/data_inference/pickles/simulations_no_fossil/"
pickle_files = ["raw_87_10k.pkl", "raw_489_10k.pkl", "raw_674_10k.pkl"]
test_perc = 0.1  # fraction of the (shuffled) samples held out as the test split
# Alternative configuration: the TreePar dataset, used entirely as a test set.
# pickle_path_base = "/workspace/coniferas(1)/data_inference/pickles/treepar_dataset/"
# pickle_files = ["raw_87.pkl", "raw_489.pkl", "raw_674.pkl"]
# test_perc = 1


def _regression_targets(dataset, i, div_scenario):
    """Return ``(raw, normalized)`` regression targets for sample ``i``.

    The predicted parameters depend on the diversification scenario, so the
    target vectors have different lengths for different scenarios.

    Raises:
        ValueError: if ``div_scenario`` is not a known scenario. Skipping the
            sample instead would make the target lists shorter than
            ``dataset.X_vec`` and misalign every following sample.
    """
    if div_scenario in ('BD', 'HE'):
        return ([dataset.r0[i], dataset.a0[i]],
                [dataset.norm_r0[i], dataset.norm_a0[i]])
    if div_scenario == 'ME':
        # NOTE(review): BD/HE use norm_a0 in the normalized targets, while ME
        # (and SR/WW below) use the raw a0/a1 -- confirm this is intentional.
        return ([dataset.r0[i], dataset.a0[i], dataset.time[i], dataset.frac1[i]],
                [dataset.norm_r0[i], dataset.a0[i], dataset.norm_time[i], dataset.frac1[i]])
    if div_scenario in ('SR', 'WW'):
        return ([dataset.r0[i], dataset.r1[i], dataset.a0[i], dataset.a1[i],
                 dataset.time[i]],
                [dataset.norm_r0[i], dataset.norm_r1[i], dataset.a0[i], dataset.a1[i],
                 dataset.norm_time[i]])
    if div_scenario == 'SAT':
        return [dataset.r0[i]], [dataset.norm_r0[i]]
    raise ValueError(f"Unknown diversification scenario {div_scenario!r} for sample {i}")


for file in pickle_files:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(pickle_path_base, file), 'rb') as f:
        dataset = pickle.load(f)

    # Diversification scenario info: last two path components of each sample's label name.
    div_info = ["/".join(dataset.label_names[int(elem)].split('/')[-2:]) for elem in dataset.label]

    # One-hot classification targets.
    y_class = to_categorical(dataset.label, num_classes=int(np.max(dataset.label) + 1))

    # Scenario-specific regression targets (raw and normalized), one entry per sample.
    y_reg = []
    y_reg_norm = []
    for i, label in enumerate(dataset.label):
        div_scenario = os.path.basename(dataset.label_names[int(label)]).split('_')[0]
        targets, targets_norm = _regression_targets(dataset, i, div_scenario)
        y_reg.append(targets)
        y_reg_norm.append(targets_norm)

    X_vec, y_class, y_reg, y_reg_norm, resc_factor, div_info = shuffle_six_arrays(
        dataset.X_vec, y_class, y_reg, y_reg_norm, dataset.resc_factor, div_info)

    # The first n_test shuffled samples form the test split; the rest form the training split.
    n_test = int(test_perc * len(X_vec))
    arrays = {
        'X': X_vec,
        'y_class': y_class,
        'y_reg': y_reg,
        'y_reg_norm': y_reg_norm,
        'resc_factor': resc_factor,
        'div_info': div_info,
    }
    data = dict()
    for name, values in arrays.items():
        data[name + '_train'] = values[n_test:]
        data[name + '_test'] = values[:n_test]

    print('\n', file[4:])
    for key, value in data.items():
        print(f'{key}:', np.shape(value))

    # Output file name drops the "raw_" prefix, e.g. raw_87_10k.pkl -> dataset_87_10k.pkl.
    with open(os.path.join(pickle_path_base, "dataset_" + file[4:]), 'wb') as f:
        pickle.dump(data, f)


 87_10k.pkl
X_train: (54000, 87)
X_test: (6000, 87)
y_class_train: (54000, 6)
y_class_test: (6000, 6)
y_reg_train: (54000,)
y_reg_test: (6000,)
y_reg_norm_train: (54000,)
y_reg_norm_test: (6000,)
resc_factor_train: (54000,)
resc_factor_test: (6000,)
div_info_train: (54000,)
div_info_test: (6000,)

 489_10k.pkl
X_train: (54000, 489)
X_test: (6000, 489)
y_class_train: (54000, 6)
y_class_test: (6000, 6)
y_reg_train: (54000,)
y_reg_test: (6000,)
y_reg_norm_train: (54000,)
y_reg_norm_test: (6000,)
resc_factor_train: (54000,)
resc_factor_test: (6000,)
div_info_train: (54000,)
div_info_test: (6000,)

 674_10k.pkl
X_train: (54000, 674)
X_test: (6000, 674)
y_class_train: (54000, 6)
y_class_test: (6000, 6)
y_reg_train: (54000,)
y_reg_test: (6000,)
y_reg_norm_train: (54000,)
y_reg_norm_test: (6000,)
resc_factor_train: (54000,)
resc_factor_test: (6000,)
div_info_train: (54000,)
div_info_test: (6000,)
