# Test-time dropout script

The goal is to develop a function with a command line interface that takes a trained model with dropout and returns an ensemble prediction, so I imagine something like:

```
python create_dropout_ensemble.py --exp_id 44-resnet_deeper2 --members 100 ...
```

The script should return and save an xarray dataset, just like `create_prediction`, but with an added dimension `ens_member`.

You basically already did the work in the starter exercise I gave you. You can also check out my solution. Now it's just a matter of creating a convenient script. For examples of command line scripts I wrote, check out `src/extract_level.py`, which uses `argparse`, or `scripts/download_tigge.py`, which uses Google's `fire`. Also, see whether your method or mine of implementing the test-time dropout is more convenient: pick whichever requires fewer changes to the rest of the code (probably yours).

As mentioned in the WeatherBench paper, testing is done using the years 2017 and 2018. This means the ensemble predictions also have to be created for these two years. The data can be downloaded here: https://mediatum.ub.tum.de/1524895. However, the files, which contain all years, are quite large, so you probably don't want to download them to your laptop. I uploaded just the last two years for each variable here: To come...

Next, you need a trained model. I number my experiments (see Dropbox document). You can find two different models in the link above. 

As mentioned in the Dropbox document, I would suggest developing the main function in the notebook. Once that works, you can create a CLI around it and save the script. 
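
A minimal sketch of what such a CLI wrapper could look like with `argparse`. The `--exp_id` and `--members` flags follow the example command above; the function name `create_dropout_ensemble`, its return value, and the `--output_dir` flag are placeholders assumed for illustration, not the final interface.

```python
import argparse


def create_dropout_ensemble(exp_id, members, output_dir="."):
    """Placeholder for the main function developed in the notebook."""
    # Hypothetical body: load the model, run `members` stochastic forward
    # passes, and save the resulting dataset under `output_dir`.
    return f"{output_dir}/{exp_id}_ens{members}.nc"


def build_parser():
    parser = argparse.ArgumentParser(description="Create a test-time dropout ensemble.")
    parser.add_argument("--exp_id", required=True, help="Experiment ID, e.g. 44-resnet_deeper2")
    parser.add_argument("--members", type=int, default=100, help="Number of ensemble members")
    # Assumed flag, not part of the example command line.
    parser.add_argument("--output_dir", default=".", help="Directory for the saved dataset")
    return parser


def main(argv=None):
    # argv=None makes argparse read sys.argv; a list can be passed for testing.
    args = build_parser().parse_args(argv)
    return create_dropout_ensemble(args.exp_id, args.members, args.output_dir)
```

In the saved script, `main()` would be called under the usual `if __name__ == "__main__":` guard. Keeping the parser separate from the main function means the same function can still be imported and called from the notebook.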

Also, let's use `tensorflow>=2.0`.

# This notebook is just for testing. The script is saved as `create_dropout_ensemble.py`.

ToDo:
- make it work for all networks. (Differences: custom_objects, which can be handled with an if condition on load_model(); output_vars; test_years; lead_time?; anything else?)
- load full data instead of batches. output for full size of X.
- pass optional arguments, like is_normalized, start_date, end_date, test_years
- what to do if output vars are different? numpy-->xarray won't work
- solve eager_execution problem
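
For the second item, one option is to loop over all batches of the generator and concatenate the results along the time axis. The sketch below uses NumPy stand-ins only: it assumes the generator can be iterated batch by batch and yields `(X, y)` pairs, and `predict_stochastic` is a hypothetical placeholder for one forward pass with dropout active.

```python
import numpy as np


def predict_stochastic(X, rng):
    """Hypothetical stand-in for one forward pass with dropout active."""
    # Keep two output channels and add noise to mimic dropout randomness.
    out = X[..., :2]
    return out + rng.normal(scale=0.1, size=out.shape)


def ensemble_over_batches(batches, members, seed=0):
    """Return an array of shape (members, time, lat, lon, n_outputs)."""
    rng = np.random.default_rng(seed)
    per_batch = []
    for X, _ in batches:
        # Stack the members for this batch: (members, batch, lat, lon, n_outputs).
        per_batch.append(np.stack([predict_stochastic(X, rng) for _ in range(members)]))
    # Join the batches along the time axis (axis 1); members stay on axis 0.
    return np.concatenate(per_batch, axis=1)


# Toy batches shaped like the generator output: (batch, lat, lon, channels).
toy_batches = [(np.zeros((4, 32, 64, 15)), None), (np.zeros((3, 32, 64, 15)), None)]
ens = ensemble_over_batches(toy_batches, members=5)
print(ens.shape)  # (5, 7, 32, 64, 2)
```

Processing batch by batch keeps memory use bounded, at the cost of one Python loop over the generator.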

In [70]:
# Here is a useful tip: Using autoreload allows you to make changes to an imported module
# which are then automatically updated in this notebook. This is how I start all my notebooks.
%load_ext autoreload
# Mode 2 reloads all imported modules every time before code is executed.
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [71]:
from fire import Fire
import numpy as np
import tensorflow as tf  # explicit import; tf is used directly below
import xarray as xr
from src.data_generator import *
from src.train import *
from src.networks import *
from src.utils import *
from tensorflow.keras import backend as K

In [72]:
tf.compat.v1.disable_eager_execution() #needed?
tf.__version__

'2.1.0'

In [73]:
exp_id_path='../nn_configs/B/13.1-resnet_bn_dr_0.1.yml'

In [74]:
!ls {exp_id_path}

../nn_configs/B/13.1-resnet_bn_dr_0.1.yml


In [75]:
# Get the input arguments from the experiment config
args = load_args(exp_id_path)
var_dict=args['var_dict']
lead_time=args['lead_time']
batch_size=args['batch_size']
output_vars=args['output_vars'] #ToDo: need to change numpy-->xarray function if this changes.
# #Question: Should we input data_subsample, norm_subsample, nt_in, dt_in, test_year?
# # data_subsample=args['data_subsample']
# # norm_subsample=args['norm_subsample']
# # nt_in=args['nt_in']
# # dt_in=args['dt_in']
# # test_years=args['test_years']

In [76]:
ds = xr.merge([xr.open_mfdataset(f'../../data/WeatherBench/5.625deg/{var}/*.nc', combine='by_coords') for var in var_dict.keys()])

In [77]:
#Question: How will this be available?
mean = xr.open_dataarray('../../data/WeatherBench/5.625deg/13-mean.nc') #for year 2018??
std = xr.open_dataarray('../../data/WeatherBench/5.625deg/13-std.nc')

In [78]:
#ds_test

In [79]:
start_time = '2017-01-01'  # want to use these as optional arguments
end_time = '2018-12-31'
ds_test = ds.sel(time=slice(start_time, end_time))

#Question: Should we input data_subsample, norm_subsample, nt_in, dt_in?
# dg_test = DataGenerator(ds_test, var_dict, lead_time, batch_size=32, shuffle=True, load=True,
#                  mean=None, std=None, output_vars=None, data_subsample=1, norm_subsample=1,
#                  nt_in=1, dt_in=1 )
dg_test = DataGenerator(
    ds_test, var_dict, lead_time, batch_size=batch_size, mean=mean, std=std,
    shuffle=False, output_vars=output_vars
)

DG start 17:22:06.279497
DG normalize 17:22:06.292763
DG load 17:22:06.299067
Loading data into RAM
DG done 17:22:09.993878


In [84]:
X,y=dg_test[0]
X.shape, y.shape

((64, 32, 64, 15), (64, 32, 64, 2))

In [80]:
!pwd

/home/garg/WeatherBench/nbs_probabilistic


In [81]:
exp_id=args['exp_id']
saved_model_path=f'../../data/WeatherBench/predictions/saved_models/{exp_id}.h5'

In [82]:
!ls {saved_model_path}

../../data/WeatherBench/predictions/saved_models/13.1-resnet_bn_dr_0.1.h5


In [83]:
mymodel = tf.keras.models.load_model(
    saved_model_path,
    custom_objects={'PeriodicConv2D': PeriodicConv2D, 'lat_mse': tf.keras.losses.mse},
)

In [69]:
#provide details of model- dropout rate, optimiser, loss fn.
#mymodel.summary()
#mymodel.optimizer

In [None]:
#ds

In [85]:
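# K was already imported above. The learning-phase flag is added as an extra input,
# so passing 1 runs the model in training mode, which keeps dropout active.
func = K.function(mymodel.inputs + [K.learning_phase()], mymodel.outputs)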

In [86]:
number_of_forecasts = 5  # ensemble members per input time
time = 2  # number of input times; keep it low while testing the code, since this takes time
# Each call with learning phase 1 uses a new dropout mask and gives one ensemble member.
pred_ensemble = np.array([
    np.asarray(func([X[:time], 1.]), dtype=np.float32).squeeze()
    for _ in range(number_of_forecasts)
])



In [87]:
pred_ensemble.shape

(5, 2, 32, 64, 2)

In [88]:
# Un-normalize using the mean and std of the output variables
pred_ensemble = (
    pred_ensemble * dg_test.std.isel(level=dg_test.output_idxs).values
    + dg_test.mean.isel(level=dg_test.output_idxs).values
)
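
The un-normalization relies on NumPy broadcasting: the per-variable `std` and `mean` vectors line up with the last axis of the ensemble array. A small sketch with made-up numbers (the statistics below are illustrative assumptions, not the real values):

```python
import numpy as np

# Normalized ensemble: (members, time, lat, lon, n_outputs).
pred_norm = np.zeros((5, 2, 32, 64, 2))
pred_norm[..., 0] = 1.0   # first output variable, one std above the mean
pred_norm[..., 1] = -1.0  # second output variable, one std below the mean

# Illustrative per-variable statistics, shape (n_outputs,).
std = np.array([3000.0, 15.0])
mean = np.array([54000.0, 275.0])

# Shapes (5, 2, 32, 64, 2) and (2,) broadcast over the trailing axis.
pred = pred_norm * std + mean
print(pred[0, 0, 0, 0].tolist())  # [57000.0, 260.0]
```

This only works as long as the variable axis is the last one and the statistics are selected in the same order as the model outputs.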



In [89]:
# Convert from numpy to xarray
preds = xr.Dataset({
    'z500': xr.DataArray(
        pred_ensemble[..., 0],
        dims=['forecast_number', 'time', 'lat', 'lon'],
        coords={'forecast_number': np.arange(number_of_forecasts), 'time': np.arange(time),
                'lat': dg_test.data.lat, 'lon': dg_test.data.lon},
    ),
    't850': xr.DataArray(
        pred_ensemble[..., 1],
        dims=['forecast_number', 'time', 'lat', 'lon'],
        coords={'forecast_number': np.arange(number_of_forecasts), 'time': np.arange(time),
                'lat': dg_test.data.lat, 'lon': dg_test.data.lon},
    ),
})

observation = xr.Dataset({
    'z500': xr.DataArray(
        y[:time, :, :, 0],
        dims=['time', 'lat', 'lon'],
        coords={'time': np.arange(time), 'lat': dg_test.data.lat, 'lon': dg_test.data.lon},
    ),
    't850': xr.DataArray(
        y[:time, :, :, 1],
        dims=['time', 'lat', 'lon'],
        coords={'time': np.arange(time), 'lat': dg_test.data.lat, 'lon': dg_test.data.lon},
    ),
})
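
To avoid hard-coding `z500` and `t850` (the ToDo item about differing output variables), the dataset could be built in a loop over the output variable names. This is a sketch with toy data, assuming the last array axis follows the order of the names. The helper `to_ensemble_dataset` and its arguments are hypothetical, and the dimension is called `ens_member` as in the task description rather than `forecast_number`.

```python
import numpy as np
import xarray as xr


def to_ensemble_dataset(pred, var_names, time, lat, lon):
    """Wrap an array of shape (member, time, lat, lon, var) in an xarray Dataset."""
    dims = ["ens_member", "time", "lat", "lon"]
    coords = {"ens_member": np.arange(pred.shape[0]), "time": time, "lat": lat, "lon": lon}
    # One DataArray per output variable, sliced from the last axis.
    return xr.Dataset(
        {name: xr.DataArray(pred[..., i], dims=dims, coords=coords)
         for i, name in enumerate(var_names)}
    )


# Toy inputs; the real coordinates would come from the data generator.
toy_pred = np.random.rand(5, 2, 32, 64, 2)
ds_ens = to_ensemble_dataset(
    toy_pred, ["z500", "t850"],
    time=np.arange(2),
    lat=np.linspace(-87.1875, 87.1875, 32),
    lon=np.arange(0, 360, 5.625),
)
print(dict(ds_ens.sizes))
```

With this shape of helper, the list of names could be derived from `output_vars` in the experiment config instead of being typed out.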

In [90]:
preds.t850.isel(time=0,forecast_number=0,lat=0,lon=0).values

array(254.80960183)

In [97]:
preds.to_netcdf(f'../../data/WeatherBench/predictions/{exp_id}.nc')

In [95]:
preds