# Main performance experiment for the CNN model using the new ClimateBench metrics
See Sections 4 and 5 of the manuscript

In [None]:
# Colab-only setup: mount Google Drive, then expose the project folder on the
# import path so that the notebook's local helper modules can be imported.
import sys

from google.colab import drive

drive.mount('/content/drive')

_PROJECT_DIR = '/content/drive/My Drive/Colab Notebooks/USC Random NN'
sys.path.append(_PROJECT_DIR)

Mounted at /content/drive


In [None]:
# Install the `eofs` package quietly. It is not imported directly in this cell;
# presumably it is required by the local helper modules -- TODO confirm.
!pip install eofs --quiet
import numpy as np
import xarray as xr
import pandas as pd
# NOTE: star imports. Names used later in the notebook without a visible
# definition (e.g. `results_path` on the next statement) are expected to come
# from one of these two project modules.
from utils import *
from utils_cnn import *
# Redirect all outputs of this notebook to the experiment's sub-folder.
# `results_path` on the right-hand side is re-bound by the star imports above
# each time this cell runs, so re-running the cell does not append twice.
results_path = results_path + 'cnn/new_metrics_experiment/'

import keras
from keras import Sequential
# NOTE: star import of every Keras layer class; it comes after the project star
# imports, so any identically named symbol from `utils`/`utils_cnn` is shadowed.
from keras.layers import *
from keras.callbacks import EarlyStopping
import tensorflow as tf
from tensorflow.keras.utils import plot_model
# Only show TensorFlow log messages of level ERROR and above.
tf.get_logger().setLevel('ERROR')
from IPython.display import display

# Fix the global random seed (21) so that runs are repeatable. Per the Keras
# documentation this seeds Python's `random`, NumPy and TensorFlow together.
tf.keras.utils.set_random_seed(21)

[?25l[K     |▎                               | 10 kB 34.4 MB/s eta 0:00:01[K     |▋                               | 20 kB 6.8 MB/s eta 0:00:01[K     |█                               | 30 kB 9.7 MB/s eta 0:00:01[K     |█▎                              | 40 kB 4.1 MB/s eta 0:00:01[K     |█▋                              | 51 kB 4.3 MB/s eta 0:00:01[K     |██                              | 61 kB 5.2 MB/s eta 0:00:01[K     |██▎                             | 71 kB 5.3 MB/s eta 0:00:01[K     |██▌                             | 81 kB 5.9 MB/s eta 0:00:01[K     |██▉                             | 92 kB 4.7 MB/s eta 0:00:01[K     |███▏                            | 102 kB 5.1 MB/s eta 0:00:01[K     |███▌                            | 112 kB 5.1 MB/s eta 0:00:01[K     |███▉                            | 122 kB 5.1 MB/s eta 0:00:01[K     |████▏                           | 133 kB 5.1 MB/s eta 0:00:01[K     |████▌                           | 143 kB 5.1 MB/s eta 0:00:01[K    

In [None]:
# Target variables to emulate and the simulations used for training.
vars_to_predict = ['tas', 'diurnal_temperature_range', 'pr', 'pr90']
simus = ['historical', 'hist-GHG', 'hist-aer', 'ssp126', 'ssp370', 'ssp585',]

# Validation rows: the first two years of every decade, listed per simulation.
# Each (start, stop) pair is the arange window for the first year of a decade;
# shifting the window by one selects the second year of the same decades.
_val_windows = [(0, 161), (165, 326), (330, 491),
                (500, 571), (586, 657), (672, 743)]
val_idx = np.concatenate([np.arange(start + shift, stop + shift, 10)
                          for start, stop in _val_windows
                          for shift in (0, 1)])

X_train_dict, Y_train_dict = {}, {}
X_val_dict, Y_val_dict = {}, {}


def _split_by_index(data, held_out_idx):
  """Return (kept, held_out) along axis 0 for the given row indices."""
  kept = np.delete(data, held_out_idx, axis=0)
  held_out = np.take(data, held_out_idx, axis=0)
  return kept, held_out


# Build one train/validation split per target variable.
for var in vars_to_predict:
  X, Y, meanstd_inputs = create_training_data(simus, var_to_predict=var)

  X_train, X_val = _split_by_index(X, val_idx)
  Y_train, Y_val = _split_by_index(Y, val_idx)

  X_train_dict[var], X_val_dict[var] = X_train, X_val
  Y_train_dict[var], Y_val_dict[var] = Y_train, Y_val

# Test set: historical + ssp245 inputs, evaluated against the ssp245 outputs.
test_input_files = [data_path + 'inputs_historical.nc',
                    data_path + 'inputs_ssp245.nc']
X_test = xr.open_mfdataset(test_input_files).compute()
Y_test = create_predictdand_data(['ssp245'])

# Normalize each forcing with the statistics returned for the training data
# (`meanstd_inputs` holds the value from the last loop iteration above).
for input_var in ['CO2', 'CH4', 'SO2', 'BC']:
  var_dims = X_test[input_var].dims
  normalized = normalize(X_test[input_var].data, input_var, meanstd_inputs)
  X_test = X_test.assign({input_var: (var_dims, normalized)})

X_test_np = input_for_training(X_test)

# Display the train/validation array shapes for one variable as a sanity check.
tuple(split['tas'].shape
      for split in (X_train_dict, Y_train_dict, X_val_dict, Y_val_dict))

((603, 96, 144, 4), (603, 96, 144), (150, 96, 144, 4), (150, 96, 144))

In [None]:
# Parameter budgets to explore, mapped to the label used in output paths/logs.
param_lims = {1_000_000: '1M', 10_000_000: '10M'}
# Hidden-layer counts to sweep, used as range(*layer_range): 2 through 10.
layer_range = (2, 11)
# Number of randomly sized models trained for every layer count.
num_models = 50
# Flat, append-only logs of every RMSE computed by the experiment loop.
raw_rmse_data_spatial, raw_rmse_data_global, raw_rmse_data_total = [], [], []

In [None]:
# Main experiment loop.
#
# For every parameter budget in `param_lims` and every hidden-layer count in
# `layer_range`, sample `num_models` random architectures whose dense-stack
# parameter count lies within +/-10% of the budget. Each architecture is saved
# (config + diagram), trained once per target variable, used to predict the
# test scenario, and scored with the spatial / global / total RMSE metrics.
# Per-layer-count summaries are printed and the RMSE tables are saved per budget.

# Output grid of the emulator (latitude x longitude).
n_lat, n_lon = 96, 144
# Number of Conv2D filters; after global average pooling this is also the
# input width of the dense stack, hence the first entry of `units_list`.
n_conv_filters = 20
# Short label used in output folders and printed summaries.
var_labels = {'diurnal_temperature_range': 'dtr'}
# Most layer widths would overshoot the budget, so widths are sampled from a
# narrower range per budget to speed up the search. An unknown budget raises a
# KeyError instead of silently reusing a stale `units_list`.
layer_width_ranges = {1000000: (20, 200),
                      10000000: (200, 1000)}

for param_lim in param_lims.keys():

  size_label = param_lims[param_lim]
  # param_lim +/- 10% (upper bound exclusive)
  param_range = (int(param_lim-0.1*param_lim), int(param_lim+0.1*param_lim))
  rmse_data_spatial = []
  rmse_data_global = []
  rmse_data_total = []

  for num_layers in range(*layer_range):

    for i in range(num_models):

      param_count = float('inf')

      # Rejection-sample a common hidden width until the parameter count fits.
      # A plain bounds comparison is used instead of `in range(...)`, which
      # degrades to a linear scan for non-int values such as inf or NumPy ints.
      while not (param_range[0] <= param_count < param_range[1]):
        width = np.random.randint(*layer_width_ranges[param_lim])
        units_list = [width]*num_layers
        units_list = np.insert(units_list, 0, n_conv_filters)
        units_list = np.append(units_list, n_lat*n_lon)
        param_count = count_ffnn_params(units_list)

      keras.backend.clear_session()

      # Architecture template. It is never trained itself; every variable gets
      # its own freshly initialised clone below.
      untrained_model = Sequential()
      untrained_model.add(Input(shape=(n_lat, n_lon, 4,)))
      untrained_model.add(Conv2D(n_conv_filters, (3,3), padding='same', activation='relu', kernel_regularizer='l2'))
      untrained_model.add(AveragePooling2D(2))
      untrained_model.add(GlobalAveragePooling2D())
      for units in units_list[1:-1]:
        untrained_model.add(Dense(units, activation='relu'))
      untrained_model.add(Dense(n_lat*n_lon))
      untrained_model.add(Reshape((1, n_lat, n_lon)))

      # The config dict is stored as a pickled object array, so loading it
      # back requires np.load(..., allow_pickle=True).
      np.save(results_path+f'{size_label}/models/{num_layers}_layer_model_{i}', untrained_model.get_config())
      plot_model(untrained_model, results_path+f'{size_label}/images/{num_layers}_layer_model_{i}.png', show_shapes=True)

      for var in vars_to_predict:

        # Get train/val data. Targets are reshaped to the model's output shape
        # (N, 1, lat, lon). If left as (N, lat, lon), the MSE loss broadcasts
        # prediction and target to (batch, batch, lat, lon) and compares every
        # prediction with every target in the batch.
        X_train = X_train_dict[var]
        Y_train = np.reshape(Y_train_dict[var], (-1, 1, n_lat, n_lon))
        X_val = X_val_dict[var]
        Y_val = np.reshape(Y_val_dict[var], (-1, 1, n_lat, n_lon))

        # Fresh weights for every variable. Re-using `untrained_model` directly
        # would continue training from the weights fitted on the previous
        # variable, since assignment only aliases the same Keras object.
        model = keras.models.clone_model(untrained_model)
        model.compile(optimizer="adam", loss="mse", metrics=["mse"])
        hist = model.fit(X_train,
                         Y_train,
                         batch_size=16,
                         epochs=100,
                         verbose=0,
                         validation_data=(X_val, Y_val),
                         callbacks=[EarlyStopping(patience=10, restore_best_weights=True)],
                        )

        # Make predictions using trained model and drop the singleton axis
        m_pred = model.predict(X_test_np)
        m_pred = m_pred.reshape(m_pred.shape[0], m_pred.shape[2], m_pred.shape[3])
        m_pred = xr.DataArray(m_pred, dims=['time','lat','lon'], coords=[X_test.time.data, X_test.latitude.data, X_test.longitude.data])
        m_pred = m_pred.transpose('lat','lon','time').sel(time=slice(2015,2101)).to_dataset(name=var)

        # Save prediction as .nc
        m_pred.to_netcdf(results_path+f'{size_label}/predictions/{var_labels.get(var, var)}/{num_layers}_layer_model_{i}.nc', 'w')

        # Calculate RMSE from time index 65 onward for both truth and prediction
        var_truth = Y_test[var]
        m_var_pred = m_pred.transpose('time','lat','lon')[var]
        rmse_spatial = get_rmse_spatial(var_truth[65:], m_var_pred[65:])
        raw_rmse_data_spatial.append(rmse_spatial)
        rmse_global = get_rmse_global(var_truth[65:], m_var_pred[65:])
        raw_rmse_data_global.append(rmse_global)
        rmse_total = rmse_spatial + 5*rmse_global
        raw_rmse_data_total.append(rmse_total)

    # The raw logs are ordered model-major / variable-minor, so a column-major
    # reshape of the latest entries yields a (variable, model) table.
    n_recent = num_models*len(vars_to_predict)
    table_shape = (len(vars_to_predict), num_models)
    n_layer_rmse_data_spatial = np.reshape(raw_rmse_data_spatial[-n_recent:], table_shape, order='F')
    n_layer_rmse_data_global = np.reshape(raw_rmse_data_global[-n_recent:], table_shape, order='F')
    n_layer_rmse_data_total = np.reshape(raw_rmse_data_total[-n_recent:], table_shape, order='F')

    # Print the mean and the minimum RMSE over the sampled models, per variable.
    for stat_name, stat_fn in (('Mean', np.mean), ('Min', min)):
      print(f'{num_layers} Layer {size_label} {stat_name} RMSE:')
      for var_idx, var in enumerate(vars_to_predict):
        print(f'\t{var_labels.get(var, var)}:')
        print(f'\t\tSpatial: {round(stat_fn(n_layer_rmse_data_spatial[var_idx]),4)}')
        print(f'\t\tGlobal: {round(stat_fn(n_layer_rmse_data_global[var_idx]),4)}')
        print(f'\t\tTotal: {round(stat_fn(n_layer_rmse_data_total[var_idx]),4)}')

    rmse_data_spatial.append(n_layer_rmse_data_spatial)
    rmse_data_global.append(n_layer_rmse_data_global)
    rmse_data_total.append(n_layer_rmse_data_total)

  # One (layer count, variable, model) table per metric for this budget.
  np.save(results_path+f'{size_label}/rmse_data_spatial', rmse_data_spatial)
  np.save(results_path+f'{size_label}/rmse_data_global', rmse_data_global)
  np.save(results_path+f'{size_label}/rmse_data_total', rmse_data_total)

2 Layer 1M Mean RMSE:
	tas:
		Spatial: 0.8237
		Global: 0.7189
		Total: 4.4183
	dtr:
		Spatial: 17.0375
		Global: 1.4195
		Total: 24.1348
	pr:
		Spatial: 5.3176
		Global: 0.8663
		Total: 9.6491
	pr90:
		Spatial: 6.3206
		Global: 0.899
		Total: 10.8157
2 Layer 1M Min RMSE:
	tas:
		Spatial: 0.6983
		Global: 0.6098
		Total: 3.7474
	dtr:
		Spatial: 16.6856
		Global: 1.3314
		Total: 23.7625
	pr:
		Spatial: 5.1317
		Global: 0.8294
		Total: 9.345
	pr90:
		Spatial: 6.1925
		Global: 0.8726
		Total: 10.5923
3 Layer 1M Mean RMSE:
	tas:
		Spatial: 0.8216
		Global: 0.7145
		Total: 4.3941
	dtr:
		Spatial: 17.0502
		Global: 1.4228
		Total: 24.1642
	pr:
		Spatial: 5.3975
		Global: 0.8726
		Total: 9.7607
	pr90:
		Spatial: 6.3302
		Global: 0.9
		Total: 10.83
3 Layer 1M Min RMSE:
	tas:
		Spatial: 0.7912
		Global: 0.6871
		Total: 4.2282
	dtr:
		Spatial: 16.7751
		Global: 1.3097
		Total: 23.6877
	pr:
		Spatial: 5.2207
		Global: 0.8263
		Total: 9.3743
	pr90:
		Spatial: 6.2119
		Global: 0.8758
		Total: 10.59