# Test-time dropout script

The goal is to develop a function with a command line interface that takes a trained model with dropout and returns an ensemble prediction, so I imagine something like:

```
python create_dropout_ensemble.py --exp_id 44-resnet_deeper2 --members 100 ...
```

The script should return and save an xarray dataset just like `create_prediction` but with an added dimension `ens_member`.

You basically already did the work in the starter exercise I gave you. You can also check out my solution. Now it's just a matter of creating a convenient script. For examples of command line scripts I wrote, check out `src/extract_level.py` using `argparse` or `scripts/download_tigge.py` using Google's `fire`. Also, check whether your method or mine of implementing test-time dropout is more convenient, and use whichever requires fewer changes to the rest of the code (probably yours).
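
A minimal `argparse` sketch of such a command line interface could look like this. Only `--exp_id` and `--members` come from the example call above; `--datadir`, `--pred_save_dir` and the function name `build_parser` are placeholders I made up for illustration.

```python
import argparse


def build_parser():
    """Build the argument parser for the (hypothetical) ensemble script."""
    parser = argparse.ArgumentParser(
        description="Create a test-time dropout ensemble prediction."
    )
    # Flags taken from the example call above
    parser.add_argument("--exp_id", required=True, help="Experiment ID, e.g. 44-resnet_deeper2")
    parser.add_argument("--members", type=int, default=100, help="Number of ensemble members")
    # Hypothetical extra flags, not part of the original example
    parser.add_argument("--datadir", default=None, help="Directory with the input data")
    parser.add_argument("--pred_save_dir", default=None, help="Where to save the prediction")
    return parser


# Parse an explicit argument list so the sketch also runs inside a notebook
cli_args = build_parser().parse_args(["--exp_id", "44-resnet_deeper2", "--members", "100"])
print(cli_args.exp_id, cli_args.members)
```

In the real script you would call `parse_args()` without an argument list so that it reads from `sys.argv`.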

As mentioned in the WeatherBench paper, testing is done using the years 2017 and 2018. This means the ensemble predictions also have to be created for these two years. The data can be downloaded here: https://mediatum.ub.tum.de/1524895. However, the files, which contain all years, are quite large, so you probably don't want to download them to your laptop. I uploaded just the last two years for each variable here: To come...

Next, you need a trained model. I number my experiments (see Dropbox document). You can find two different models in the link above. 

As mentioned in the Dropbox document, I would suggest developing the main function in the notebook. Once that works, you can create a CLI around it and save the script. 

Also, let's use `tensorflow>=2.0`.

#NOTE: This notebook is just for testing. Script saved as create_dropout_ensemble.py

ToDo:
- Make it work for all networks. Differences to handle: `custom_objects` (can be done with an `if` condition around `load_model()`), `output_vars`, `test_years`, `lead_time`? Anything else?
- Load the full data instead of batches, so that the output covers the full size of X.
- Pass optional arguments such as `is_normalized`, `start_date`, `end_date`, `test_years`.
- Solve the eager-execution problem.

In [1]:
# Here is a useful tip: Using autoreload allows you to make changes to an imported module
# which are then automatically updated in this notebook. This is how I start all my notebooks.
%load_ext autoreload
# Mode 2 reloads all imported modules before each cell is executed (it is not a time interval)
%autoreload 2

In [2]:
import os
import fire
from fire import Fire
import xarray as xr
import numpy as np
from src.data_generator import *
from src.train import *
from src.networks import *
from src.utils import *
from tensorflow.keras import backend as K

In [3]:
# You only need this if you are using a GPU
os.environ["CUDA_VISIBLE_DEVICES"]=str(0)
limit_mem()

In [None]:
#Final Working Script
# exp_id_path='/home/garg/WeatherBench/nn_configs/B/63-resnet_d3_best.yml'
# model_save_dir='/home/garg/data/WeatherBench/predictions/saved_models'
# datadir='/home/garg/data/WeatherBench/5.625deg'
# pred_save_dir='/home/garg/data/WeatherBench/predictions'

# !python create_dropout_ensemble.py 5 {exp_id_path} {datadir} {model_save_dir} {pred_save_dir}

#Everything from below is just for practice. CAN IGNORE!

In [None]:
#use conda-forge
#!conda uninstall tensorflow --y
#!conda install -c conda-forge tensorflow-gpu=2.0.0
#check CUDA compatibility: https://www.tensorflow.org/install/source#tested_build_configurations

In [None]:
# from tensorflow.python.client import device_lib
# print(device_lib.list_local_devices())

In [4]:
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

Num GPUs Available:  1


In [5]:
tf.compat.v1.disable_eager_execution()  # currently needed, see the eager-execution item in the ToDo list
tf.__version__
# tf.debugging.set_log_device_placement(True)

'2.0.0'

In [6]:
exp_id_path='/home/garg/WeatherBench/nn_configs/B/80.1-resnet_d3_dr_0.05.yml'
!ls {exp_id_path}

/home/garg/WeatherBench/nn_configs/B/80.1-resnet_d3_dr_0.05.yml


In [7]:
args=load_args(exp_id_path)
exp_id=args['exp_id']
var_dict=args['var_dict']
batch_size=args['batch_size']
output_vars=args['output_vars']

# Question: how to optionally input data_subsample, norm_subsample, nt_in, dt_in, test_years?
data_subsample=args['data_subsample']
norm_subsample=args['norm_subsample']
nt_in=args['nt_in']
#nt_in=args['nt']
dt_in=args['dt_in']
test_years=args['test_years']
lead_time=args['lead_time']
# Changing paths
model_save_dir='/home/garg/data/WeatherBench/predictions/saved_models'
datadir='/home/garg/data/WeatherBench/5.625deg'
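
One possible answer to the question in the cell above: read optional keys with a fallback, and let an explicitly passed value win over the config file. Here is a sketch with a plain dict standing in for the result of `load_args`. The helper name `get_option` and the default values are my own assumptions and not part of `src`.

```python
def get_option(config, key, default=None, override=None):
    """Return `override` if given, else config[key], else `default`."""
    if override is not None:
        return override
    value = config.get(key)
    return default if value is None else value


# Plain dict standing in for the loaded config (values are illustrative)
config = {"nt_in": 3, "data_subsample": 2}

opt_nt_in = get_option(config, "nt_in", default=1)
opt_dt_in = get_option(config, "dt_in", default=1)  # missing in config, falls back to default
opt_subsample = get_option(config, "data_subsample", default=1, override=4)
print(opt_nt_in, opt_dt_in, opt_subsample)
```

This way a config file without `dt_in` does not raise a `KeyError`, and a command line argument can still override what is stored in the file.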

In [8]:
var_dict

{'geopotential': ('z', [50, 250, 500, 600, 700, 850, 925]),
 'temperature': ('t', [50, 250, 500, 600, 700, 850, 925]),
 'u_component_of_wind': ('u', [50, 250, 500, 600, 700, 850, 925]),
 'v_component_of_wind': ('v', [50, 250, 500, 600, 700, 850, 925]),
 'specific_humidity': ('q', [50, 250, 500, 600, 700, 850, 925]),
 'constants': ['lsm', 'orography', 'lat2d']}

In [9]:
nt_in  # Question: what is the difference between nt and nt_in?
# --> A: nt = number of time steps corresponding to forecast lead time
# nt_in = number of time steps in the input

3

In [10]:
ds = xr.merge([xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords') for var in var_dict.keys()])
mean = xr.open_dataarray(f'{model_save_dir}/{exp_id}_mean.nc') 
std = xr.open_dataarray(f'{model_save_dir}/{exp_id}_std.nc')

In [11]:
data_subsample

2

In [12]:
start_date='2017-01-01';end_date='2017-12-31'
#start_date=None;end_date=None

In [13]:
# Question: shuffle should be False since this is testing?  --> Correct
# Question: Should we pass data_subsample, norm_subsample, nt_in, dt_in?
# For instance, dt_in is not always provided in the config file.
# --> nt_in and data_subsample are needed.
# Predictions are made for every time step. norm_subsample doesn't matter since we pass an external mean/std file.
if start_date is not None and end_date is not None:
    ds_test = ds.sel(time=slice(start_date, end_date))
else:
    ds_test = ds.sel(time=slice(test_years[0], test_years[-1]))
dg_test = DataGenerator(ds_test, var_dict, lead_time, batch_size=batch_size, shuffle=False, load=True,
                        mean=mean, std=std, output_vars=output_vars, nt_in=nt_in, dt_in=dt_in,
                        data_subsample=data_subsample)
# dg_test = DataGenerator(
#     ds_test, var_dict, lead_time, batch_size=batch_size, mean=mean, std=std,
#     shuffle=False, output_vars=output_vars)

In [14]:
# NOT a good idea to load the whole data set at once. Rather load one batch, make a prediction, and loop over the batches.

# X,y=dg_test[0]
# for i in range(1, len(dg_test)):
#     X_batch,y_batch=dg_test[i]
#     X=np.append(X,X_batch,axis=0)
#     y=np.append(y,y_batch,axis=0)
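
The batch-wise loop suggested above can be sketched with NumPy alone. `predict_batch` is a stand-in for calling the model on a single batch and `toy_batches` stands in for the data generator; both are made up for this example.

```python
import numpy as np


def predict_batch(X_batch):
    """Stand-in for a model: average over the channel axis, keep one output channel."""
    return X_batch.mean(axis=-1, keepdims=True)


def predict_in_batches(batches):
    """Predict batch by batch and concatenate once along the sample axis."""
    outputs = [predict_batch(X_batch) for X_batch in batches]
    return np.concatenate(outputs, axis=0)


# Three batches of shape (batch, lat, lon, channels); the last one is smaller
rng = np.random.default_rng(0)
toy_batches = [rng.normal(size=(n, 4, 8, 3)) for n in (5, 5, 2)]
batch_preds = predict_in_batches(toy_batches)
print(batch_preds.shape)
```

Collecting the outputs in a list and concatenating once avoids the repeated copying that `np.append` does inside a loop.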

In [15]:
# Number of time steps in the data set
dg_test.data.time.shape

(4380,)

In [16]:
# Number of time steps to forecast
dg_test.nt

36

In [17]:
# Number of samples (because we need a y for every x)
# But this isn't the actual number of samples (yeah, legacy code...)
dg_test.n_samples, dg_test.data.time.shape[0] - dg_test.nt

(4344, 4344)

In [18]:
# For the actual number of samples you also have to subtract the number of input time steps (-1) = nt_offset
# Yeah, this could probably be cleaned up.
len(dg_test.idxs), dg_test.data.time.shape[0] - dg_test.nt - dg_test.nt_offset

(4342, 4342)
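
The arithmetic behind these numbers, written as a small function. The formula `nt_offset = (nt_in - 1) * dt_in` is my reading of the comment and the printed values above, not a copy of the `DataGenerator` code, and `dt_in = 1` is an assumption.

```python
def count_samples(n_time, nt, nt_in, dt_in=1):
    """Number of valid forecast start times in a data set with n_time steps."""
    nt_offset = (nt_in - 1) * dt_in  # input history needed before the first start time
    return n_time - nt - nt_offset


# Values printed in the cells above: 4380 time steps, nt = 36, nt_in = 3
print(count_samples(n_time=4380, nt=36, nt_in=3))  # 4342
```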

In [19]:
# dg_test.data.time.isel(time=slice(None,X.shape[0])) #would work for any size of x

In [20]:
PeriodicConv2D, tf.keras.losses.mse

(src.networks.PeriodicConv2D,
 <function tensorflow.python.keras.losses.mean_squared_error(y_true, y_pred)>)

In [None]:
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_policy(policy)

In [None]:
# #ToDo: add other loss functions to custom_objects (it doesn't matter if they are not used in the model itself; they only need to be there so that load_model() doesn't break)
# #Since we don't build the model again, we don't need to pass model parameters like kernel, filters, activation, dropout, loss and other details to the network?
# saved_model_path=f'{model_save_dir}/{exp_id}.h5'
# substr=['resnet','unet_google','unet']
# assert any(x in exp_id for x in substr)
# # model=tf.keras.models.load_model(saved_model_path, custom_objects={'PeriodicConv2D':PeriodicConv2D})
# model=tf.keras.models.load_model(saved_model_path,
#                                 custom_objects={'PeriodicConv2D':PeriodicConv2D,'lat_mse': tf.keras.losses.mse})

In [None]:
# %debug

In [21]:
dg_test.shape

(32, 64, 114)

In [22]:
def convblock(inputs, filters, kernel=3, stride=1, bn_position=None, l2=0,
              use_bias=True, dropout=0, activation='relu', test_dropout=False):
    x = inputs
    if bn_position == 'pre': x = BatchNormalization()(x)
    x = PeriodicConv2D(
        filters, kernel, conv_kwargs={
            'kernel_regularizer': regularizers.l2(l2),
            'use_bias': use_bias
        }
    )(x)
    if bn_position == 'mid': x = BatchNormalization()(x)
    x = LeakyReLU()(x) if activation == 'leakyrelu' else Activation(activation)(x) 
    if bn_position == 'post': x = BatchNormalization()(x)
    if dropout > 0: x = Dropout(dropout)(x, training=test_dropout)
    return x

def resblock(inputs, filters, kernel, bn_position=None, l2=0, use_bias=True,
             dropout=0, skip=True, activation='relu', down=False, up=False, test_dropout=False):
    x = inputs
    if down:
        x = MaxPooling2D()(x)
    for i in range(2):
        x = convblock(
            x, filters, kernel, bn_position=bn_position, l2=l2, use_bias=use_bias,
            dropout=dropout, activation=activation, test_dropout=test_dropout
        )
    if down or up:
        inputs = PeriodicConv2D(
            filters, kernel, conv_kwargs={
                'kernel_regularizer': regularizers.l2(l2),
                'use_bias': use_bias,
                'strides': 2 if down else 1
            }
        )(inputs)
    if skip: x = Add()([inputs, x])
    return x

def build_resnet(filters, kernels, input_shape, bn_position=None, use_bias=True, l2=0,
                 skip=True, dropout=0, activation='relu', test_dropout=False, **kwargs):
    x = input = Input(shape=input_shape)

    # First conv block to get up to shape
    x = convblock(
        x, filters[0], kernels[0], bn_position=bn_position, l2=l2, use_bias=use_bias,
        dropout=dropout, activation=activation, test_dropout=test_dropout
    )

    # Resblocks
    for f, k in zip(filters[1:-1], kernels[1:-1]):
        x = resblock(x, f, k, bn_position=bn_position, l2=l2, use_bias=use_bias,
                dropout=dropout, skip=skip, activation=activation, test_dropout=test_dropout)

    # Final convolution
    output = PeriodicConv2D(
        filters[-1], kernels[-1],
        conv_kwargs={'kernel_regularizer': regularizers.l2(l2)},
    )(x)
    output = Activation('linear', dtype='float32')(output)
    return keras.models.Model(input, output)
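
The `training=test_dropout` argument in `convblock` is what keeps dropout active at prediction time: a Keras `Dropout` layer does nothing at inference unless it is called with `training=True`. The following NumPy sketch illustrates the idea only. It is not the Keras implementation, and the function name `dropout` and the toy input are made up.

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: only active when `training` is True."""
    if not training or rate == 0:
        return x
    keep = rng.random(x.shape) >= rate  # keep each unit with probability 1 - rate
    return x * keep / (1.0 - rate)      # rescale so the expected value is unchanged


rng = np.random.default_rng(42)
x = np.ones(1000)

off = dropout(x, rate=0.2, training=False, rng=rng)
on_a = dropout(x, rate=0.2, training=True, rng=rng)
on_b = dropout(x, rate=0.2, training=True, rng=rng)
print(np.array_equal(off, x), np.array_equal(on_a, on_b))
```

With `training=False` the input passes through unchanged, while two calls with `training=True` draw different masks. That randomness is what produces different ensemble members from the same weights.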

In [23]:
model=build_resnet(args['filters'],
    args['kernels'],
    dg_test.shape,
    bn_position=args['bn_position'],
    use_bias=args['use_bias'],
    l2=args['l2'],
    skip=args['skip'],
    dropout=args['dropout'],
    activation=args['activation'],
    test_dropout=True)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [24]:
model.load_weights(f'{model_save_dir}/{exp_id}_weights.h5')

In [25]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 64, 114) 0                                            
__________________________________________________________________________________________________
periodic_conv2d (PeriodicConv2D (None, 32, 64, 128)  715136      input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 32, 64, 128)  0           periodic_conv2d[0][0]            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 32, 64, 128)  512         leaky_re_lu[0][0]                
______________________________________________________________________________________________

In [26]:
exp_id

'80.1-resnet_d3_dr_0.05'

In [None]:
# exp_id='82-resnet_d3_dr_0.2'
# saved_model_path=f'{model_save_dir}/{exp_id}.h5'
# substr=['resnet','unet_google','unet']
# assert any(x in exp_id for x in substr)

# model1=tf.keras.models.load_model(saved_model_path,
#                                  custom_objects={'PeriodicConv2D':PeriodicConv2D,'lat_mse': tf.keras.losses.mse})

# model1.summary()

# pred1=model1.predict(dg_test, verbose=1)

In [27]:
ensemble_size = 100 # 50
preds = []
for _ in tqdm(range(ensemble_size)):
    preds.append(model.predict(dg_test))
    
pred_ensemble = np.array(preds)
pred_ensemble=(pred_ensemble * dg_test.std.isel(level=dg_test.output_idxs).values +
                   dg_test.mean.isel(level=dg_test.output_idxs).values)

preds = xr.Dataset()
for i,var in enumerate(output_vars):
    da= xr.DataArray(pred_ensemble[...,i], 
                         coords={'member': np.arange(ensemble_size),
                                 'time': dg_test.valid_time,
                                 'lat': dg_test.data.lat, 'lon': dg_test.data.lon,}, 
                         dims=['member', 'time','lat', 'lon'])
    preds[var]=da


In [28]:
exp_id

'80.1-resnet_d3_dr_0.05'

In [None]:
preds.to_netcdf(f'../../data/WeatherBench/predictions/{exp_id}.nc')

In [None]:
preds

In [None]:
#OLD

In [None]:
preds = model.predict(dg_test, verbose=1) #deterministic
preds=preds* dg_test.std.isel(level=dg_test.output_idxs).values+dg_test.mean.isel(level=dg_test.output_idxs).values

In [None]:
#numpy -->xarray
preds_d = xr.Dataset()
for i,var in enumerate(output_vars):
    da= xr.DataArray(preds[...,i], 
                     coords={
                             'time': dg_test.valid_time,
                             'lat': dg_test.data.lat, 'lon': dg_test.data.lon,}, 
                     dims=['time','lat', 'lon'])
    preds_d[var]=da

In [None]:
preds_d=preds_d.expand_dims('member')

In [None]:
preds_d.t_850

In [None]:
print(len(dg_test))
X,y=dg_test[len(dg_test)-1]
X.shape, y.shape

In [None]:
X, y = dg_test[0]

In [None]:
import tqdm

In [None]:
%%time
p = model.predict(dg_test, verbose=1) #deterministic

In [None]:
p

In [None]:
# func = K.function(model.inputs + [K.learning_phase()], model.outputs) #slow method

# preds = []
# for X, y in tqdm.tqdm(dg_test): 
#     preds.append(np.asarray(func([X] + [1.]), dtype=np.float32).squeeze())

In [None]:
# So unfortunately this is much slower. Not entirely sure why but I think this means that we do not want to use K.function after all. 
# Below is a workaround that allows us to load the model and then change the training attribute afterwards. 
# Super ugly but I really can't think of a better way.

In [None]:
# model=tf.keras.models.load_model(saved_model_path,
#                                  custom_objects={'PeriodicConv2D':PeriodicConv2D,'lat_mse': tf.keras.losses.mse})

In [None]:
[model.predict(X[:1])[0, 0, 0,0] for _ in range(3)]   # Always the same output, no test time dropout

In [None]:
# c = model.get_config()

In [None]:
# c

In [None]:
# for l in c['layers']:
#     if l['class_name'] == 'Dropout':
#         l['inbound_nodes'][0][0][-1] = {'training': True}
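
To see what the commented-out loop above does, here it is applied to a hand-written mock config. The nested layout of `inbound_nodes` simply mirrors what the loop indexes into; I am assuming that this matches the config produced by `model.get_config()` in the TensorFlow version used here. The function name `enable_test_time_dropout` is made up.

```python
def enable_test_time_dropout(config):
    """Set training=True in the call kwargs of every Dropout layer (in place)."""
    n_patched = 0
    for layer in config["layers"]:
        if layer["class_name"] == "Dropout":
            # The last entry of the first inbound node holds the call kwargs
            layer["inbound_nodes"][0][0][-1] = {"training": True}
            n_patched += 1
    return n_patched


# Hand-written mock of a functional-model config (not produced by Keras)
mock_config = {
    "layers": [
        {"class_name": "Activation", "inbound_nodes": [[["conv", 0, 0, {}]]]},
        {"class_name": "Dropout", "inbound_nodes": [[["act", 0, 0, {}]]]},
    ]
}
n_patched = enable_test_time_dropout(mock_config)
print(n_patched, mock_config["layers"][1]["inbound_nodes"])
```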

In [None]:
# model=keras.models.Model.from_config(c, custom_objects={'PeriodicConv2D':PeriodicConv2D,'lat_mse': tf.keras.losses.mse})

In [None]:
# model2 = keras.models.Model.from_config(c, custom_objects={'PeriodicConv2D':PeriodicConv2D,'lat_mse': tf.keras.losses.mse})

# model2.set_weights(model.get_weights())

# model=model2

In [None]:
# [model.predict(X[:1])[0, 0, 0,1] for _ in range(3)]   # Different output every time = dropout on :)

In [None]:
%%time
p = model.predict(dg_test, verbose=1)   # Maybe slightly slower because of dropout

In [None]:
p.shape

In [None]:
from tqdm import tqdm

In [None]:
# Here is a new version without K.function. Much easier thankfully
ensemble_size = 4 # 50
preds = []
for _ in tqdm(range(ensemble_size)):
    preds.append(model.predict(dg_test))

In [None]:
pred_ensemble = np.array(preds)
pred_ensemble.shape   # No transposing necessary
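
Why no transposing is necessary: `np.array` on a list of equally shaped arrays puts the list index on a new leading axis, which is exactly the member axis. Here is a toy example with random numbers instead of model output. The shapes and variable names are made up, and the ensemble mean and spread are only there to show how the member axis is used.

```python
import numpy as np

rng = np.random.default_rng(0)
n_members, n_time, n_lat, n_lon, n_vars = 4, 10, 32, 64, 2

# One array per ensemble member, as returned by repeated predict calls
member_preds = [rng.normal(size=(n_time, n_lat, n_lon, n_vars)) for _ in range(n_members)]

stacked = np.array(member_preds)  # new leading axis = ensemble member
ens_mean = stacked.mean(axis=0)   # ensemble mean, shape (time, lat, lon, var)
ens_spread = stacked.std(axis=0)  # spread = standard deviation across members
print(stacked.shape, ens_mean.shape)
```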

In [None]:
#observation y
z500_valid = load_test_data(f'{datadir}/geopotential_500', 'z')
t850_valid=load_test_data(f'{datadir}/temperature_850','t')

observation=xr.Dataset()
observation['z_500']=z500_valid
observation['t_850']=t850_valid
observation=observation.astype('float64')
observation=observation.drop_vars('level')

observation

In [None]:
dg_test.valid_time

In [None]:
z500_valid.sel(time=dg_test.valid_time) #dg_test.valid_time is what you need.

In [None]:
pred_ensemble.shape

In [None]:
pred_ensemble_reserve=pred_ensemble
#observation_reserve=y
#observation=y

In [None]:
#unnormalize
pred_ensemble=pred_ensemble* dg_test.std.isel(level=dg_test.output_idxs).values+dg_test.mean.isel(level=dg_test.output_idxs).values
#observation=observation* dg_test.std.isel(level=dg_test.output_idxs).values+dg_test.mean.isel(level=dg_test.output_idxs).values
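
The unnormalization works because the per-variable mean and standard deviation broadcast against the last axis of the ensemble array. Here is a toy version with made-up numbers; the statistics below are illustrative and not the real `z_500` or `t_850` values.

```python
import numpy as np

# Normalized ensemble with axes (member, time, lat, lon, var)
pred_norm = np.zeros((4, 10, 32, 64, 2))
pred_norm[..., 0] = 1.0   # variable 0: one standard deviation above its mean
pred_norm[..., 1] = -2.0  # variable 1: two standard deviations below its mean

# Per-variable statistics with shape (var,), illustrative numbers only
var_mean = np.array([54000.0, 275.0])
var_std = np.array([3000.0, 15.0])

pred_phys = pred_norm * var_std + var_mean  # broadcasts over the last axis
print(pred_phys[0, 0, 0, 0])
```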

In [None]:
pred_ensemble.shape[1]

In [None]:
number_of_forecasts=pred_ensemble.shape[0]  # number of ensemble members (4 here)

In [None]:
dg_test.data.time.sel(time=dg_test.valid_time)

In [None]:
preds = xr.Dataset()
for i,var in enumerate(output_vars):
    da= xr.DataArray(pred_ensemble[...,i], 
                     coords={'member': np.arange(number_of_forecasts),
                             'time': dg_test.data.time.sel(time=dg_test.valid_time),
                             'lat': dg_test.data.lat, 'lon': dg_test.data.lon,}, 
                     dims=['member', 'time','lat', 'lon'])
    preds[var]=da
    print(i)
    print(var)

In [None]:
preds

In [None]:
#pred_dataset

In [None]:
#xr.Dataset.equals(pred_dataset,preds)

In [None]:
#observation

In [None]:
#preds.t850.isel(time=0,forecast_number=0,lat=0,lon=0).values

In [None]:
exp_id='82.1-resnet_d3_dr_0.2'

In [None]:
preds.to_netcdf(f'../../data/WeatherBench/predictions/{exp_id}.nc')

In [None]:
preds.t_850

In [None]:
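import xskillscore as xs

def compute_weighted_crps(da_fc, da_true, mean_dims=xr.ALL_DIMS):
    """Latitude-weighted CRPS of an ensemble forecast with a 'member' dimension."""
    # cos(lat) weights, normalized to a mean of 1
    weights_lat = np.cos(np.deg2rad(da_fc.lat))
    weights_lat /= weights_lat.mean()
    crps = xs.crps_ensemble(da_true, da_fc)
    crps = (crps * weights_lat).mean(mean_dims)
    return crps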
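
For intuition, here is the CRPS of an ensemble written out in NumPy, using the standard estimator `E|X - y| - 0.5 * E|X - X'|` for the empirical ensemble distribution. I assume this is the same quantity that `xs.crps_ensemble` returns, but I have not checked its implementation. The function names and the toy data are made up.

```python
import numpy as np


def crps_ensemble_np(obs, ens):
    """CRPS of the empirical ensemble distribution; members are on axis 0 of `ens`."""
    abs_error = np.abs(ens - obs).mean(axis=0)                      # E|X - y|
    spread = np.abs(ens[:, None] - ens[None, :]).mean(axis=(0, 1))  # E|X - X'|
    return abs_error - 0.5 * spread


def lat_weighted_mean(field, lat):
    """Average a (lat, lon) field with cos(lat) weights normalized to a mean of 1."""
    weights = np.cos(np.deg2rad(lat))
    weights = weights / weights.mean()
    return (field * weights[:, None]).mean()


toy_lat = np.array([-60.0, 0.0, 60.0])
toy_obs = np.ones((3, 2))                                    # (lat, lon), observation = 1
toy_ens = np.stack([np.zeros((3, 2)), 2 * np.ones((3, 2))])  # two members: 0 and 2
crps_field = crps_ensemble_np(toy_obs, toy_ens)
print(crps_field[0, 0], lat_weighted_mean(crps_field, toy_lat))
```

For a single-member ensemble the spread term vanishes and the CRPS reduces to the absolute error.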

In [None]:
crps=compute_weighted_crps(preds, observation.sel(time=preds.time))

In [None]:
crps.z_500.values, crps.t_850.values

In [None]:
import matplotlib.pyplot as plt
from ranky import rankz

obs = np.asarray(observation.to_array(), dtype=np.float32).squeeze()
obs_z500 = obs[0, ...].squeeze()
obs_t850 = obs[1, ...].squeeze()

pred = np.asarray(preds.to_array(), dtype=np.float32).squeeze()
pred_z500 = pred[0, ...].squeeze()
pred_t850 = pred[1, ...].squeeze()

mask = np.ones(obs_z500.shape)  # all-ones mask, no masking intended
# feed into rankz function
result = rankz(obs_z500, pred_z500, mask)
# plot histogram (one bar per rank, members + 1 ranks in total)
plt.bar(range(1, pred_z500.shape[0] + 2), result[0])
# view histogram
plt.show()  # result: overconfident (underdispersive)
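
What a rank histogram counts can be sketched in a few lines of NumPy. This is not the `ranky` implementation, only the basic idea without any special handling of ties. The function name and the synthetic data are made up, and the ensemble is deliberately too narrow.

```python
import numpy as np


def rank_histogram(obs, ens):
    """Count how often the observation falls at each rank among the members.

    obs has shape (n,), ens has shape (members, n).
    """
    n_members = ens.shape[0]
    ranks = (ens < obs).sum(axis=0)  # number of members below the observation, 0..n_members
    return np.bincount(ranks, minlength=n_members + 1)


rng = np.random.default_rng(1)
toy_ens = rng.normal(scale=0.5, size=(9, 5000))  # ensemble spread too small on purpose
toy_obs = rng.normal(scale=1.0, size=5000)
hist = rank_histogram(toy_obs, toy_ens)
print(hist)
```

An underdispersive ensemble puts the observation outside the ensemble range too often, so the outermost bins are overpopulated and the histogram becomes U-shaped.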