In [1]:
%load_ext autoreload
%autoreload 2

In [22]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import seaborn as sns
import pickle
import re
from src.score import *
from collections import OrderedDict

In [8]:
from src.networks import *
from src.utils import *

In [9]:
os.environ["CUDA_VISIBLE_DEVICES"]=str(1)

In [10]:
datadir = '/data/stephan/WeatherBench/5.625deg/'

In [11]:
!ls $datadir

10m_u_component_of_wind  potential_vorticity	       total_cloud_cover
10m_v_component_of_wind  relative_humidity	       total_precipitation
2m_temperature		 specific_humidity	       u_component_of_wind
constants		 temperature		       v_component_of_wind
geopotential		 temperature_850	       vorticity
geopotential_500	 toa_incident_solar_radiation


In [13]:
var_dict = {
    'geopotential': ('z', [500, 850]),
    'toa_incident_solar_radiation': ('tisr', None),
    'constants': ['lat2d', 'orography']
}

In [14]:
ds = [xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords') for var in var_dict.keys()]

In [15]:
ds = xr.merge(ds)

In [16]:
var_dict

{'geopotential': ('z', [500, 850]),
 'toa_incident_solar_radiation': ('tisr', None),
 'constants': ['lat2d', 'orography']}

In [17]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True, 
                 mean=None, std=None, output_vars=None):
        """
        Data generator for WeatherBench data.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
        Args:
            ds: Dataset containing all variables
            var_dict: Dictionary of the form {'var': level}. Use None for level if data is of single level
            lead_time: Lead time in hours
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, datadet is loaded into RAM.
            mean: If None, compute mean from data.
            std: If None, compute standard deviation from data.
        """

        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time

        data = []
        level_names = []
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for long_var, params in var_dict.items():
            if long_var == 'constants': 
                for var in params:
                    data.append(ds[var].expand_dims(
                        {'level': generic_level, 'time': ds.time}, (1, 0)
                    ))
                    level_names.append(var)
            else:
                var, levels = params
                try:
                    data.append(ds[var].sel(level=levels))
                    level_names += [f'{var}_{level}' for level in levels]
                except ValueError:
                    data.append(ds[var].expand_dims({'level': generic_level}, 1))
                    level_names.append(var)

        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        self.data['level_names'] = xr.DataArray(
            level_names, dims=['level'], coords={'level': self.data.level})
        if output_vars is None:
            self.output_idxs = range(len(dg_valid.data.level))
        else:
            self.output_idxs = [i for i, l in enumerate(self.data.level_names.values) 
                                if any([bool(re.match(o, l)) for o in output_vars])]
        
        # Normalize
        self.mean = self.data.mean(('time', 'lat', 'lon')).compute() if mean is None else mean
#         self.std = self.data.std('time').mean(('lat', 'lon')).compute() if std is None else std
        self.std = self.data.std(('time', 'lat', 'lon')).compute() if std is None else std
        self.data = (self.data - self.mean) / self.std
        
        self.n_samples = self.data.isel(time=slice(0, -lead_time)).shape[0]
        self.init_time = self.data.isel(time=slice(None, -lead_time)).time
        self.valid_time = self.data.isel(time=slice(lead_time, None)).time

        self.on_epoch_end()

        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load: print('Loading data into RAM'); self.data.load()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        X = self.data.isel(time=idxs).values
        y = self.data.isel(time=idxs + self.lead_time, level=self.output_idxs).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle == True:
            np.random.shuffle(self.idxs)

In [25]:
from src.data_generator import *

In [26]:
bs=32
lead_time=3*24

In [27]:
ds_train = ds.sel(time=slice('2015', '2015'))
ds_valid = ds.sel(time=slice('2016', '2016'))
ds_test = ds.sel(time=slice('2017', '2018'))

In [98]:
output_vars = ['z_500', 'z_850', 'tisr']

In [99]:
DataGenerator

src.data_generator.DataGenerator

In [100]:
%%time
dg_train = DataGenerator(ds_train, var_dict, lead_time, batch_size=bs, load=True, 
                         output_vars=output_vars)
dg_valid = DataGenerator(ds_train, var_dict, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, 
                         shuffle=False, output_vars=output_vars)

DG start 10:52:54.778011
DG normalize 10:52:54.805262
DG load 10:52:56.258504
Loading data into RAM
DG done 10:52:57.190268
DG start 10:52:57.191848
DG normalize 10:52:57.204134
DG load 10:52:57.210734
Loading data into RAM
DG done 10:52:58.120746
CPU times: user 5.76 s, sys: 9.33 s, total: 15.1 s
Wall time: 3.34 s


In [101]:
dg_test = DataGenerator(ds_test, var_dict, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, 
                         shuffle=False, output_vars=output_vars)

DG start 10:52:58.140861
DG normalize 10:52:58.153552
DG load 10:52:58.159698
Loading data into RAM
DG done 10:52:59.862208


In [32]:
dg_train.data.level_names

<xarray.DataArray 'level_names' (level: 5)>
array(['z_500', 'z_850', 'tisr', 'lat2d', 'orography'], dtype='<U9')
Coordinates:
  * level        (level) int64 500 850 1 1 1
    level_names  (level) <U9 'z_500' 'z_850' 'tisr' 'lat2d' 'orography'

In [33]:
X, y = dg_train[0]; X.shape, y.shape

((32, 32, 64, 5), (32, 32, 64, 1))

In [34]:
dg_train.output_idxs

[0]

In [35]:
dg_train.std

<xarray.DataArray 'z' (level: 5)>
array([3.48306094e+03, 1.55964335e+03, 1.44036266e+06, 5.19361462e+01,
       8.59872249e+02])
Coordinates:
  * level        (level) int64 500 850 1 1 1
    level_names  (level) <U9 'z_500' 'z_850' 'tisr' 'lat2d' 'orography'

In [36]:
dg_train.mean

<xarray.DataArray 'z' (level: 5)>
array([5.41249140e+04, 1.37205813e+04, 1.07486679e+06, 0.00000000e+00,
       3.79497583e+02])
Coordinates:
  * level        (level) int64 500 850 1 1 1
    level_names  (level) <U9 'z_500' 'z_850' 'tisr' 'lat2d' 'orography'

In [102]:
cnn = build_cnn([64, 64, 64, 64, 3], [5, 5, 5, 5, 5], (32, 64, 5))

In [103]:
cnn.compile(keras.optimizers.Adam(1e-4), 'mse')

In [104]:
cnn.fit_generator(dg_train, epochs=1)

Epoch 1/1


<tensorflow.python.keras.callbacks.History at 0x7f40dd68b780>

In [253]:
preds = cnn.predict_generator(dg_train)

In [258]:
dg = dg_train

In [256]:
preds.shape

(8688, 32, 64, 2)

In [259]:
dg.data.isel(level=dg.output_idxs).level

<xarray.DataArray 'level' (level: 2)>
array([500,   1])
Coordinates:
  * level        (level) int64 500 1
    level_names  (level) <U9 'z_500' 'u10'

In [260]:
preds = xr.DataArray(
    preds,
    dims=['time', 'lat', 'lon', 'level'],
    coords={'time': dg.valid_time, 'lat': dg.data.lat, 'lon': dg.data.lon, 
            'level': dg.data.isel(level=dg.output_idxs).level,
            'level_names': dg.data.isel(level=dg.output_idxs).level_names
           },
)

In [262]:
preds.shape, preds.level, preds.level_names

((8688, 32, 64, 2), <xarray.DataArray 'level' (level: 2)>
 array([500,   1])
 Coordinates:
   * level        (level) int64 500 1
     level_names  (level) <U9 'z_500' 'u10', <xarray.DataArray 'level_names' (level: 2)>
 array(['z_500', 'u10'], dtype='<U9')
 Coordinates:
   * level        (level) int64 500 1
     level_names  (level) <U9 'z_500' 'u10')

In [266]:
dg.var_dict

{'geopotential': ('z', [500]),
 '10m_u_component_of_wind': ('u10', None),
 'constants': ['lat2d', 'orography']}

In [308]:
das = []
for long_var, params in dg.var_dict.items():
    if not long_var == 'constants':
        var, levels = params
        var_names = [var] if levels is None else [f'{var}_{level}' for level in levels]
        print(var_names, levels)
        var_idxs = [i for i, v in enumerate(preds.level_names) if v in var_names]
        print(var_idxs)
        da = preds.isel(level=var_idxs)
        if levels is not None: da = da.sel(level=levels).drop('level_names')
        else: da = da.squeeze().drop('level').drop('level_names')
        das.append({var: da})

['z_500'] [500]
[0]
['u10'] None
[1]


In [310]:
das = xr.merge(das)

In [311]:
das

<xarray.Dataset>
Dimensions:  (lat: 32, level: 1, lon: 64, time: 8688)
Coordinates:
  * time     (time) datetime64[ns] 2015-01-04 ... 2015-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level    (level) int64 500
Data variables:
    z        (time, lat, lon, level) float32 -1.6765449 ... -0.17071085
    u10      (time, lat, lon) float32 -0.71567136 -0.5935094 ... 0.08219392

In [67]:
'tisr'.split('_')

['tisr']

In [146]:
def create_predictions(model, dg):
    """Create non-iterative predictions"""
    preds = xr.DataArray(
        model.predict_generator(dg),
        dims=['time', 'lat', 'lon', 'level'],
        coords={'time': dg.valid_time, 'lat': dg.data.lat, 'lon': dg.data.lon, 
                'level': dg.data.isel(level=dg.output_idxs).level,
                'level_names': dg.data.isel(level=dg.output_idxs).level_names
               },
    )
    # Unnormalize
    preds = (preds * dg.std.isel(level=dg.output_idxs).values + 
             dg.mean.isel(level=dg.output_idxs).values)
    unique_vars = list(set([l.split('_')[0] for l in preds.level_names.values])); unique_vars
    
    das = []
    for v in unique_vars:
        idxs = [i for i, vv in enumerate(preds.level_names.values) if vv.split('_')[0] in v]
        print(v, idxs)
        da = preds.isel(level=idxs).squeeze().drop('level_names')
        if not 'level' in da.dims: da.drop('level')
        das.append({v: da})
    return xr.merge(das)

In [147]:
create_predictions

<function __main__.create_predictions(model, dg)>

In [148]:
preds = create_predictions(cnn, dg_test)

tisr [2]
z [0, 1]


In [128]:
unique_vars = list(set([l.split('_')[0] for l in preds.level_names.values])); unique_vars

['tisr', 'z']

In [144]:
das = []
for v in unique_vars:
    idxs = [i for i, vv in enumerate(preds.level_names.values) if vv.split('_')[0] in v]
    print(v, idxs)
    da = preds.isel(level=idxs).squeeze().drop('level_names')
    if not 'level' in da.dims: da.drop('level')
    das.append({v: da})

tisr [2]
z [0, 1]


In [145]:
xr.merge(das)

<xarray.Dataset>
Dimensions:  (lat: 32, level: 2, lon: 64, time: 17448)
Coordinates:
  * time     (time) datetime64[ns] 2017-01-04 ... 2018-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level    (level) int64 500 850
Data variables:
    tisr     (time, lat, lon) float64 1.089e+06 1.072e+06 ... 9.867e+03
    z        (time, lat, lon, level) float64 5.012e+04 1.236e+04 ... 1.224e+04

In [141]:
a = das[0]['tisr']

In [142]:
a.dims

('time', 'lat', 'lon')

In [108]:
%debug

> [0;32m/home/stephan/miniconda3/lib/python3.6/site-packages/xarray/core/merge.py[0m(135)[0;36munique_variable[0;34m()[0m
[0;32m    133 [0;31m        raise MergeError(
[0m[0;32m    134 [0;31m            [0;34m"conflicting values for variable {!r} on objects to be combined. "[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m            [0;34m"You can skip this check by specifying compat='override'."[0m[0;34m.[0m[0mformat[0m[0;34m([0m[0mname[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    136 [0;31m        )
[0m[0;32m    137 [0;31m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/stephan/miniconda3/lib/python3.6/site-packages/xarray/core/merge.py[0m(217)[0;36mmerge_collected[0;34m()[0m
[0;32m    215 [0;31m                [0mvariables[0m [0;34m=[0m [0;34m[[0m[0mvariable[0m [0;32mfor[0m [0mvariable[0m[0;34m,[0m [0m_[0m [0;32min[0m [0melements_list[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    216 [0;31m                [0;32mtry[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 217 [0;31m                    [0mmerged_vars[0m[0;34m[[0m[0mname[0m[0;34m][0m [0;34m=[0m [0munique_variable[0m[0;34m([0m[0mname[0m[0;34m,[0m [0mvariables[0m[0;34m,[0m [0mcompat[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    218 [0;31m                [0;32mexcept[0m [0mMergeError[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    219 [0;31m                    [0;32mif[0m [0mcompat[0m [0;34m!=[0m [0;34m"minimal"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/stephan/miniconda3/lib/python3.6/site-packages/xarray/core/merge.py[0m(544)[0;36mmerge_core[0;34m()[0m
[0;32m    542 [0;31m[0;34m[0m[0m
[0m[0;32m    543 [0;31m    [0mprioritized[0m [0;34m=[0m [0m_get_priority_vars_and_indexes[0m[0;34m([0m[0maligned[0m[0;34m,[0m [0mpriority_arg[0m[0;34m,[0m [0mcompat[0m[0;34m=[0m[0mcompat[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 544 [0;31m    [0mvariables[0m[0;34m,[0m [0mout_indexes[0m [0;34m=[0m [0mmerge_collected[0m[0;34m([0m[0mcollected[0m[0;34m,[0m [0mprioritized[0m[0;34m,[0m [0mcompat[0m[0;34m=[0m[0mcompat[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    545 [0;31m    [0massert_unique_multiindex_level_names[0m[0;34m([0m[0mvariables[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    546 [0;31m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/stephan/miniconda3/lib/python3.6/site-packages/xarray/core/merge.py[0m(782)[0;36mmerge[0;34m()[0m
[0;32m    780 [0;31m        [0mdict_like_objects[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mobj[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    781 [0;31m[0;34m[0m[0m
[0m[0;32m--> 782 [0;31m    [0mmerge_result[0m [0;34m=[0m [0mmerge_core[0m[0;34m([0m[0mdict_like_objects[0m[0;34m,[0m [0mcompat[0m[0;34m,[0m [0mjoin[0m[0;34m,[0m [0mfill_value[0m[0;34m=[0m[0mfill_value[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    783 [0;31m    [0mmerged[0m [0;34m=[0m [0mDataset[0m[0;34m.[0m[0m_construct_direct[0m[0;34m([0m[0;34m**[0m[0mmerge_result[0m[0;34m.[0m[0m_asdict[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    784 [0;31m    [0;32mreturn[0m [0mmerged[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m<ipython-input-105-6ff341516a97>[0m(28)[0;36mcreate_predictions[0;34m()[0m
[0;32m     24 [0;31m[0;31m#             var_idxs = [i for i, v in enumerate(preds.level_names) if v == var][0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m[0;31m#             das.append({var: preds.isel(level=var_idxs).squeeze().drop('level_names')})[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;31m# #             das.append({var: preds.isel(level=var_idxs)})[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0mprint[0m[0;34m([0m[0mvar[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m    [0;32mreturn[0m [0mxr[0m[0;34m.[0m[0mmerge[0m[0;34m([0m[0mdas[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  len(das)


3


ipdb>  das.keys()


*** AttributeError: 'list' object has no attribute 'keys'


ipdb>  xr.merge(das[:2])


*** xarray.core.merge.MergeError: conflicting values for variable 'level' on objects to be combined. You can skip this check by specifying compat='override'.


ipdb>  xr.merge(das[:2], compat='override')


<xarray.Dataset>
Dimensions:      (lat: 32, lon: 64, time: 17448)
Coordinates:
  * time         (time) datetime64[ns] 2017-01-04 ... 2018-12-31T23:00:00
  * lat          (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon          (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
    level        int64 500
    level_names  <U9 'z_500'
Data variables:
    z            (time, lat, lon) float64 5.012e+04 5.028e+04 ... 4.942e+04


ipdb>  q


In [61]:
preds

<xarray.Dataset>
Dimensions:      (lat: 32, level: 0, lon: 64, time: 17448)
Coordinates:
  * time         (time) datetime64[ns] 2017-01-04 ... 2018-12-31T23:00:00
  * lat          (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon          (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level        (level) int64 
    level_names  (level) <U9 
Data variables:
    z            (time, lat, lon, level) float64 
    tisr         (time, lat, lon, level) float64 

In [49]:
z500_valid = load_test_data(f'{datadir}geopotential_500', 'z').drop('level')
t850_valid = load_test_data(f'{datadir}temperature_850', 't')
valid = xr.merge([z500_valid, t850_valid])

In [50]:
valid

<xarray.Dataset>
Dimensions:  (lat: 32, lon: 64, time: 17520)
Coordinates:
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * time     (time) datetime64[ns] 2017-01-01 ... 2018-12-31T23:00:00
    level    int32 850
Data variables:
    z        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>
    t        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>

In [150]:
compute_weighted_rmse(preds.z.sel(level=500), ds_test.z.sel(level=500)).load()

<xarray.DataArray 'z_rmse' ()>
array(760.22060395)
Coordinates:
    level    int64 500

In [6]:
ds = xr.open_dataset('/home/stephan/data/myWeatherBench/predictions/01-default.nc')

In [7]:
ds.squeeze()

<xarray.Dataset>
Dimensions:  (lat: 32, level: 2, lon: 64, time: 17448)
Coordinates:
  * level    (level) int64 500 850
  * time     (time) datetime64[ns] 2017-01-04 ... 2018-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
Data variables:
    z        (time, lat, lon, level) float32 ...
    t        (time, lat, lon, level) float32 ...

In [11]:
ds.z.sel(level=850).isel(time=0)

<xarray.DataArray 'z' (lat: 32, lon: 64)>
array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)
Coordinates:
    level    int64 850
    time     datetime64[ns] 2017-01-04
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4

In [12]:
ds.close()

In [13]:
a

NameError: name 'a' is not defined