In [None]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf

import json
import os

from sea_parameters import SeaParametersModel, custom_loss, cumulative_constraint_generalized, custom_categorical_crossentropy
from validation_dataset import ValidationXarrayDataset, read_json

# **1 - Load data and build the model**

In [None]:
# Location of the training data handed to SeaParametersModel; a relative path,
# presumably resolved against the notebook's working directory — confirm.
DATA_FOLDER = 'data_folder'
data_folder = DATA_FOLDER

In [None]:
# Training hyperparameters, named here so they are visible and tunable in one place
# instead of being buried as literals in the calls below.
# NOTE(review): 8092 is not a power of two and may be a typo for 8192 — confirm before changing.
BATCH_SIZE = 8092
EPOCHS = 50
PATIENCE = 5  # passed as `patience`; presumably early-stopping patience in epochs — confirm
LOSS_H = 0.01  # `h` argument of custom_loss; its meaning is defined in sea_parameters

model_wrapper = SeaParametersModel(data_folder=data_folder, batch_size=BATCH_SIZE,
                                   epochs=EPOCHS, patience=PATIENCE)
# Build the loss from the wrapper's own attributes so it agrees with the model configuration.
loss = custom_loss(model_wrapper.nb_classes, model_wrapper.batch_size, h=LOSS_H)
model_wrapper.build_model(loss)

# **2 - Training**

In [None]:
# Fit the compiled model; presumably uses the epochs/patience given to the constructor.
train = model_wrapper.train_model
train()

In [None]:
# Persist the trained model under this root. save_model returns the directory it
# actually wrote to, and later cells read 'model.keras' and 'config.json' from it.
output_root = 'trained_models'
save_directory = model_wrapper.save_model(save_directory=output_root)

# **3 - Validation: reload the saved model and export predictions**

In [None]:
# Reload the saved model from disk. Keras needs the custom loss supplied through
# custom_objects. Reuse the exact `loss` object built for training instead of calling
# custom_loss again with duplicated literal hyperparameters, which could silently
# drift from the values actually used for training.
# NOTE(review): the 'loss' key must match the name the loss was serialized under
# (presumably the __name__ of the function returned by custom_loss) — confirm.
model = tf.keras.models.load_model(os.path.join(save_directory, 'model.keras'),
                                   custom_objects={'loss': loss})

In [None]:
# Take the data location from the config saved next to the model (its
# 'training_data' entry), so validation reads the folder recorded for this model.
config_path = os.path.join(save_directory, 'config.json')
saved_config = read_json(config_path)
data_folder = saved_config['training_data']

In [None]:
# Run the reloaded model over the validation data, then attach extra input
# variables to the resulting dataset for later analysis.
val = ValidationXarrayDataset(model, data_folder)
extra_variables = ['sigma0_filt', 'normalized_variance_filt', 'incidence', 'azimuth_cutoff']
ds = val.add_variables(val.generate_validation_dataset(), extra_variables)

In [None]:
# Write the validation dataset (with predictions) next to the saved model.
predictions_path = os.path.join(save_directory, 'validation_predictions.nc')
ds.to_netcdf(predictions_path)