In [None]:
# Comparing scores for TIGGE, deterministic, parametric, and test-time dropout models

In [1]:
# Auto-reload imported modules before each cell executes, so edits to the
# local `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import xarray as xr
import xskillscore as xs
import matplotlib.pyplot as plt
from src.data_generator import *
from src.train import *
from src.utils import *
from src.networks import *
from src.score import *

In [3]:
# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# limit_mem()  # currently disabled

In [4]:
# If the model was trained under a mixed-precision policy, the same policy has
# to be active when loading it (see the verbose output of model.predict).
# NOTE(review): `mixed_precision` comes in through a star import; presumably
# tf.keras.mixed_precision.experimental -- confirm against the installed TF version.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

In [5]:
# Directory layout used throughout the notebook.
# Input fields (geopotential, temperature, ...):
datadir = '/home/garg/data/WeatherBench/5.625deg'
# Saved prediction files (*.nc):
pred_save_dir = '/home/garg/data/WeatherBench/predictions'
# Saved models and their mean/std files:
model_save_dir = '/home/garg/data/WeatherBench/predictions/saved_models'

In [6]:
# List the contents of the predictions directory.
!ls $pred_save_dir

01-resnet_baseline.nc		      80.1-resnet_d3_dr_0.05.nc
028_cat.npy			      81-resnet_d3_dr_0.1.nc
028_cat_truth.npy		      81-resnet_d3_dr_0.1_fixed.nc
100-resnet_d3_param.nc		      81.1-resnet_d3_dr_0.1.h5
120-resnet_d3_param_relu.h5	      81.1-resnet_d3_dr_0.1.nc
120-resnet_d3_param_relu.nc	      81.1-resnet_d3_dr_0.1_history.pkl
120-resnet_d3_param_relu_history.pkl  81.1-resnet_d3_dr_0.1_mean.nc
120-resnet_d3_param_relu_mean.nc      81.1-resnet_d3_dr_0.1_std.nc
120-resnet_d3_param_relu_std.nc       81.1-resnet_d3_dr_0.1_weights.h5
120-resnet_d3_param_relu_weights.h5   82-resnet_d3_dr_0.2.nc
138-resnet_prec.nc		      82.1-resnet_d3_dr_0.2.h5
79.1-resnet_d3_dr_0.0.h5	      82.1-resnet_d3_dr_0.2.nc
79.1-resnet_d3_dr_0.0_history.pkl     82.1-resnet_d3_dr_0.2_history.pkl
79.1-resnet_d3_dr_0.0_mean.nc	      82.1-resnet_d3_dr_0.2_mean.nc
79.1-resnet_d3_dr_0.0_std.nc	      82.1-resnet_d3_dr_0.2_std.nc
79.1-resnet_d3_dr_0.0_weights.h5      82.1-resnet_d3_dr_0.2_weights.h5
80-

## Data

In [7]:
# List the contents of the data directory.
!ls {datadir}

2017_2018_subset.zip	  geopotential	     temperature_850
6hr_precipitation	  geopotential_500   toa_incident_solar_radiation
backup_specific_humidity  specific_humidity  u_component_of_wind
constants		  temperature	     v_component_of_wind


In [8]:
# Load the two fields that the forecasts are scored against; `level` is
# dropped from each before they are merged into a single Dataset.
z500_valid = load_test_data(f'{datadir}/geopotential_500', 'z').drop('level')
t850_valid = load_test_data(f'{datadir}/temperature_850', 't').drop('level')
# TODO: add precipitation.
valid=xr.merge([z500_valid,t850_valid])
valid

Unnamed: 0,Array,Chunk
Bytes,143.52 MB,71.76 MB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Count,6 Tasks,2 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 143.52 MB 71.76 MB Shape (17520, 32, 64) (8760, 32, 64) Count 6 Tasks 2 Chunks Type float32 numpy.ndarray",64  32  17520,

Unnamed: 0,Array,Chunk
Bytes,143.52 MB,71.76 MB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Count,6 Tasks,2 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,143.52 MB,71.76 MB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Count,6 Tasks,2 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 143.52 MB 71.76 MB Shape (17520, 32, 64) (8760, 32, 64) Count 6 Tasks 2 Chunks Type float32 numpy.ndarray",64  32  17520,

Unnamed: 0,Array,Chunk
Bytes,143.52 MB,71.76 MB
Shape,"(17520, 32, 64)","(8760, 32, 64)"
Count,6 Tasks,2 Chunks
Type,float32,numpy.ndarray


In [None]:
# Deterministic

In [9]:
# Load the experiment configuration (dropout-0.1 ResNet, going by the file name).
args = load_args('../nn_configs/B/81.1-resnet_d3_dr_0.1.yml')
exp_id = args['exp_id']

# Set the directories in the config to the ones defined in this notebook.
args['model_save_dir'] = model_save_dir
args['datadir'] = datadir
args['pred_save_dir'] = pred_save_dir

# Per-experiment mean/std files -- presumably the normalisation statistics used
# in training -- are handed to the data loader through ext_mean / ext_std.
mean = xr.open_dataarray(f'{model_save_dir}/{exp_id}_mean.nc')
std = xr.open_dataarray(f'{model_save_dir}/{exp_id}_std.nc')
args['ext_mean'] = mean
args['ext_std'] = std

In [10]:
# Build the test data generator from the experiment config
# (only_test=True presumably skips the train/validation generators -- confirm).
dg_test=load_data(**args, only_test=True)

In [11]:
# Sanity check: shapes of the inputs and targets in the first batch.
x,y=dg_test[0]
print(x.shape, y.shape)

(32, 32, 64, 114) (32, 32, 64, 2)


In [12]:
# Load the trained network. PeriodicConv2D is a project layer; the custom
# 'lat_mse' loss name is mapped to plain Keras MSE (presumably sufficient here
# because the loss is not needed for predict()).
saved_model_path = f'{model_save_dir}/{exp_id}.h5'
custom_objects = {
    'PeriodicConv2D': PeriodicConv2D,
    'lat_mse': tf.keras.losses.mse,
}
model = tf.keras.models.load_model(saved_model_path, custom_objects=custom_objects)

# Deterministic forward pass over the test generator.
preds = model.predict(dg_test, verbose=1)

# Rescale the network outputs with the std/mean of the output levels only.
output_std = dg_test.std.isel(level=dg_test.output_idxs).values
output_mean = dg_test.mean.isel(level=dg_test.output_idxs).values
preds = preds * output_std + output_mean



In [48]:
# Shape check of the rescaled prediction array.
preds.shape

(8722, 32, 64, 2)

In [20]:
# Wrap the raw prediction array in an xarray Dataset, one variable per output.
grid_coords = {
    'time': dg_test.valid_time,
    'lat': dg_test.data.lat,
    'lon': dg_test.data.lon,
}
preds_d = xr.Dataset()
for i, var in enumerate(args['output_vars']):
    # The last axis of `preds` indexes the output variables.
    da = xr.DataArray(preds[..., i], coords=grid_coords, dims=['time', 'lat', 'lon'])
    preds_d[var] = da
# Add a 'member' dimension so the scores can be computed the same way as for
# the ensemble predictions.
preds_d = preds_d.expand_dims('member')

In [49]:
# Display the deterministic prediction Dataset.
preds_d

In [None]:
# test-time dropouts

In [51]:
# Open the saved prediction files of the three dropout experiments
# (rates 0.05, 0.1 and 0.2, going by the experiment names).
# NOTE(review): the loop rebinds the notebook-level `exp_id`; after this cell it
# holds the last id below, so any cell that reads `exp_id` afterwards no longer
# sees the value taken from the config.
dropout_preds = []
for exp_id in ('80.1-resnet_d3_dr_0.05',
               '81.1-resnet_d3_dr_0.1',
               '82.1-resnet_d3_dr_0.2'):
    dropout_preds.append(xr.open_dataset(f'{pred_save_dir}/{exp_id}.nc'))
preds_1, preds_2, preds_3 = dropout_preds

In [52]:
# Display the first dropout prediction set; note that it covers a smaller time frame.
preds_1

In [None]:
# Scores

In [None]:
# Mean spread and ensemble-mean RMSE per forecast, in the order:
# deterministic, then the three dropout prediction sets.
forecasts = [preds_d, preds_1, preds_2, preds_3]
mean_spread, mean_error = [], []

for ds in forecasts:
    # Re-chunk along time (500 steps per chunk); .load() below materialises each score.
    ds = ds.chunk({'time': 500})
    ds_spread = compute_weighted_meanspread(ds)
    ens_mean = ds.mean('member')
    ds_rmse = compute_weighted_rmse(ens_mean, valid)
    mean_spread.append(ds_spread.load())
    mean_error.append(ds_rmse.load())

In [None]:
# Bar charts of the z_500 RMSE, the z_500 mean spread and their ratio, one bar per model.
labels = ['deter', 'dr_0.05', 'dr_0.1', 'dr_0.2']
z500_rmse = [err.z_500_rmse for err in mean_error]
z500_spread = [spr.z_500_mean_spread for spr in mean_spread]
z500_ratio = [spr.z_500_mean_spread / err.z_500_rmse
              for spr, err in zip(mean_spread, mean_error)]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
ax1.bar(labels, z500_rmse)
ax1.set_title('Error')
ax2.bar(labels, z500_spread)
ax2.set_title('Spread')
ax3.bar(labels, z500_ratio)
ax3.set_title('Spread/skill');
# Open question: increasing the dropout rate increases the uncertainty, but how do we know whether it is the right amount of uncertainty? I think that at a fixed dropout value the uncertainty stays the same regardless of the data? If that is so, it simply tells us the uncertainty of the model and learns nothing about the inherent uncertainty in the data.
# Answer: no, it does learn from the data: it maintains the same error while increasing its uncertainty spread.