In [None]:
import numpy as np
import datetime
import pytz
import glob
import os

from sklearn.preprocessing import RobustScaler, MinMaxScaler, StandardScaler
from dataset_creation_tools import generate_dataset_from_config, split_safes

# Pin every timestamp of this notebook to Paris time explicitly: the system
# timezone configured on GPU machines can be a different one.
_timezone_name = 'Europe/Paris'
paris_timezone = pytz.timezone(_timezone_name)

# **1 - SAFEs splitting**

In [None]:
# Load the SAFE names to split (presumably one name per line in the listing).
# ndmin=1 keeps `safes` 1-D even when the listing holds a single SAFE:
# np.loadtxt would otherwise squeeze it to a 0-d array, which is not a sequence
# of names (len() raises TypeError on it) and which split_safes presumably
# does not expect.
listing = 'dirname/listing_name.txt'
safes = np.loadtxt(listing, dtype=str, ndmin=1)

In [None]:
# Gather the test SAFEs (used neither for training nor for validation): every
# *.txt listing under case_studies/safes/ contributes its SAFE names.
# The listings are sorted because glob returns paths in arbitrary,
# filesystem-dependent order; sorting makes test_safes (and the saved
# test_safes.txt) reproducible.
test_listings = sorted(glob.glob('case_studies/safes/*.txt'))
if not test_listings:
    # np.concatenate([]) would fail with an opaque "need at least one array"
    # error. An empty match most likely means the notebook is not running
    # from the expected working directory.
    raise FileNotFoundError("no test SAFE listing matches 'case_studies/safes/*.txt'")
# ndmin=1: a listing holding a single SAFE would otherwise load as a 0-d array,
# which np.concatenate rejects ("zero-dimensional arrays cannot be concatenated").
test_safes = np.concatenate([np.loadtxt(f, dtype=str, ndmin=1) for f in test_listings])

In [None]:
# Split the SAFEs between a training and a validation subset; every SAFE
# listed in test_safes is excluded from both.
# NOTE(review): 0.10 is presumably the validation fraction — confirm against
# split_safes. If the split is random, consider seeding it for reproducibility.
split_ratio = 0.10
train_safes, val_safes = split_safes(
    safes,
    split_ratio,
    test_safes=test_safes,
)

In [None]:
# Directory receiving the artefacts of this split, stamped with the current
# Paris time at minute resolution (e.g. savedir/2024-01-31_14h05).
date = datetime.datetime.now(paris_timezone)
stamp = date.strftime("%Y-%m-%d_%Hh%M")
root_path = f'savedir/{stamp}'

In [None]:
# Persist the split, one text file per subset (each entry written with '%s'),
# so that it can be reused: the dataset configuration further down points at
# these three files.
os.makedirs(root_path, exist_ok=True)
for file_name, subset in (
    ('train_safes.txt', train_safes),
    ('val_safes.txt', val_safes),
    ('test_safes.txt', test_safes),
):
    np.savetxt(os.path.join(root_path, file_name), subset, fmt='%s')

# **2 - Create dataset**

In [None]:
# Scalers that can be used (only sklearn scalers for now), keyed by class name.
# Presumably these keys are the values accepted in config['scaler']['name'].
# NOTE(review): this mapping is not referenced by the visible cells — confirm
# where it is consumed.
scaler_types = dict(
    RobustScaler=RobustScaler,
    MinMaxScaler=MinMaxScaler,
    StandardScaler=StandardScaler,
)

In [None]:
# Set up the dataset-generation configuration passed to
# generate_dataset_from_config. The train/val/test listings point at the files
# written by the splitting cells under `root_path`.
date = datetime.datetime.now(paris_timezone)

# The dataset is meant to be saved in a directory distinct from the splitting
# one. Both are stamped to the minute, so when the notebook runs top to bottom
# quickly this cell can execute in the same minute as the splitting cell and
# the two directories would coincide; in that case add seconds to keep them apart.
save_directory = f'savedir/{date.strftime("%Y-%m-%d_%Hh%M")}'
if os.path.normpath(save_directory) == os.path.normpath(root_path):
    save_directory = f'savedir/{date.strftime("%Y-%m-%d_%Hh%Mm%S")}'

# CWAVE parameter columns: outer loop over phi_hf = 1..5, inner loop over
# k_gp = 1..4 (20 columns).
cwave_columns = [
    f'cwave_params_k_gp={k_gp}_and_phi_hf={phi_hf}'
    for phi_hf in range(1, 6)
    for k_gp in range(1, 5)
]

config = {
    'raw_csv': 'raw.csv',
    # Row filter expression; presumably evaluated with pandas query syntax by
    # generate_dataset_from_config — TODO confirm.
    'filter': 'normalized_variance_filt < 2 and 0 <= azimuth_cutoff <= 500',
    # Reuse the listing that was actually split instead of repeating its path,
    # so the recorded listing cannot drift from it.
    'safes_listing': listing,
    # Built exactly like the paths the splitting cell wrote to.
    'train_safes': os.path.join(root_path, 'train_safes.txt'),
    'val_safes': os.path.join(root_path, 'val_safes.txt'),
    'test_safes': os.path.join(root_path, 'test_safes.txt'),
    'bin_width': 0.1,
    # Presumably must name one of the scalers in `scaler_types`, with kwargs
    # forwarded to its constructor — TODO confirm.
    'scaler': {'name': 'RobustScaler', 'kwargs': {'quantile_range': (10, 90)}},
    'kept_columns': [
        'hs', 'phs0', 't0m1',
        'sigma0_filt', 'normalized_variance_filt', 'incidence', 'azimuth_cutoff',
        *cwave_columns,
        'file_path', 'safe',
    ],
    'target_columns': ['hs', 'phs0', 't0m1'],
    # Distinct from the splitting directory (see the collision guard above).
    'save_directory': save_directory,
    'date': date.strftime('%d/%m/%Y %H:%M'),
    'additional_informations': 'The raw input data was filtered such as the input normalized variance (filt) is inferior to 2 and the azimuth_cutoff is between 0 and 500m'
    }

In [None]:
# Build the dataset described by `config`. What is written, and where, is
# decided by generate_dataset_from_config (presumably under
# config['save_directory'] — TODO confirm).
generate_dataset_from_config(config)