In [None]:
library(transformeR)
library(climate4R.datasets)
library(downscaleR.keras)

In [None]:
# Load the illustrative example datasets.
# NOTE(review): the original comment attributes these to transformeR; with
# climate4R.datasets also attached they may come from there -- confirm.
#
# Predictand: VALUE_Iberia_tas.
# Predictors: three NCEP fields (hus850, psl, ta850), merged below into a
# single multi-variable grid with transformeR::makeMultiGrid.
data("VALUE_Iberia_tas")
data("NCEP_Iberia_hus850", "NCEP_Iberia_psl", "NCEP_Iberia_ta850")
y <- VALUE_Iberia_tas
x <- makeMultiGrid(
  NCEP_Iberia_hus850,
  NCEP_Iberia_psl,
  NCEP_Iberia_ta850
)

In [None]:
# Standardize the predictor grid with transformeR::scaleGrid
# (type = "standardize"), overwriting `x` in place.
x <- x %>% scaleGrid(type = "standardize")

In [None]:
# Arrange predictors (x) and predictand (y) into the structure expected by
# downscaleR.keras. It is configured for a network whose first layer is
# convolutional and whose last layer is dense, with a channels-last layout.
# NOTE(review): `data` masks utils::data() as a variable name; later cells
# rely on this name, so it is kept.
data <- prepareData.keras(
  x = x,
  y = y,
  first.connection = "conv",
  last.connection = "dense",
  channels = "last"
)

In [None]:
# Define the keras model with the functional API.
# The three hidden layers consist of two 3x3 convolutions (25 and 10 filters,
# ReLU), then a flatten step and a 20-unit dense layer (ReLU). The output layer
# is dense with linear activation and one unit per predictand column.
#
# `[-1]` drops the leading (sample) dimension, which keras does not include in
# the input shape. Presumably the remaining dims are lat, lon, variable
# (channels last) -- confirm.
input_shape <- dim(data$x.global)[-1]
# Number of columns in the predictand array (presumably one per station).
output_shape <- dim(data$y$Data)[2]

inputs <- layer_input(shape = input_shape)
hidden <- inputs %>%
  layer_conv_2d(filters = 25, kernel_size = c(3, 3), activation = "relu") %>%
  layer_conv_2d(filters = 10, kernel_size = c(3, 3), activation = "relu") %>%
  layer_flatten() %>%
  layer_dense(units = 20, activation = "relu")
outputs <- hidden %>% layer_dense(units = output_shape)
model <- keras_model(inputs = inputs, outputs = outputs)

# Print the layer-by-layer configuration of the model.
summary(model)

In [None]:
# Train the network with MSE loss and Adam (learning rate 0.01).
# Two callbacks are used:
#   * early stopping with a patience of 30 epochs;
#   * a checkpoint that writes the best model so far (lowest val_loss) to
#     model.h5 in the current working directory.
# 10% of the training samples are held out for validation.
# NOTE(review): newer keras releases deprecate `lr` in favour of
# `learning_rate`; `lr` is kept here for the keras version this notebook
# targets -- confirm before upgrading.
downscaleTrain.keras(
  data,
  model = model,
  compile.args = list(
    loss = "mse",
    optimizer = optimizer_adam(lr = 0.01)
  ),
  fit.args = list(
    epochs = 100,
    batch_size = 100,
    validation_split = 0.1,
    verbose = 0,
    callbacks = list(
      callback_early_stopping(patience = 30),
      callback_model_checkpoint(
        filepath = file.path(getwd(), "model.h5"),
        monitor = "val_loss",
        save_best_only = TRUE
      )
    )
  ),
  clear.session = TRUE
)