In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import seaborn as sns
import pickle
from src.score import *
from collections import OrderedDict

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
from src.train_nn import *

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"]=str(1)

In [5]:
datadir = '/data/stephan/WeatherBench/5.625deg/'

In [146]:
!ls $datadir

10m_u_component_of_wind  potential_vorticity	       total_cloud_cover
10m_v_component_of_wind  relative_humidity	       total_precipitation
2m_temperature		 specific_humidity	       u_component_of_wind
constants		 temperature		       v_component_of_wind
geopotential		 temperature_850	       vorticity
geopotential_500	 toa_incident_solar_radiation


In [215]:
var_dict = {
    'geopotential': ('z', [500]),
    '10m_u_component_of_wind': ('u10', None),
    'constants': ['lat2d', 'orography']
}

In [220]:
ds = [xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords') for var in var_dict.keys()]

In [221]:
ds = xr.merge(ds)

In [222]:
var_dict

{'geopotential': ('z', [500]),
 '10m_u_component_of_wind': ('u10', None),
 'constants': ['lat2d', 'orography']}

In [239]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True, 
                 mean=None, std=None, output_vars=None):
        """
        Data generator for WeatherBench data.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
        Args:
            ds: Dataset containing all variables
            var_dict: Dictionary of the form {'var': level}. Use None for level if data is of single level
            lead_time: Lead time in hours
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, datadet is loaded into RAM.
            mean: If None, compute mean from data.
            std: If None, compute standard deviation from data.
        """

        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time

        data = []
        level_names = []
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for long_var, params in var_dict.items():
            if long_var == 'constants': 
                for var in params:
                    data.append(ds[var].expand_dims(
                        {'level': generic_level, 'time': ds.time}, (1, 0)
                    ))
                    level_names.append(var)
            else:
                var, levels = params
                try:
                    data.append(ds[var].sel(level=levels))
                    level_names += [f'{var}_{level}' for level in levels]
                except ValueError:
                    data.append(ds[var].expand_dims({'level': generic_level}, 1))
                    level_names.append(var)

        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        self.data['level_names'] = xr.DataArray(
            level_names, dims=['level'], coords={'level': self.data.level})
        if output_vars is None:
            self.output_idxs = range(len(dg_valid.data.level))
        else:
            self.output_idxs = [i for i, l in enumerate(self.data.level_names.values) 
                                if any([bool(re.match(o, l)) for o in output_vars])]
        
        # Normalize
        self.mean = self.data.mean(('time', 'lat', 'lon')).compute() if mean is None else mean
#         self.std = self.data.std('time').mean(('lat', 'lon')).compute() if std is None else std
        self.std = self.data.std(('time', 'lat', 'lon')).compute() if std is None else std
        self.data = (self.data - self.mean) / self.std
        
        self.n_samples = self.data.isel(time=slice(0, -lead_time)).shape[0]
        self.init_time = self.data.isel(time=slice(None, -lead_time)).time
        self.valid_time = self.data.isel(time=slice(lead_time, None)).time

        self.on_epoch_end()

        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load: print('Loading data into RAM'); self.data.load()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        X = self.data.isel(time=idxs).values
        y = self.data.isel(time=idxs + self.lead_time, level=self.output_idxs).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle == True:
            np.random.shuffle(self.idxs)

In [240]:
bs=32
lead_time=3*24

In [241]:
ds_train = ds.sel(time=slice('2015', '2015'))
ds_valid = ds.sel(time=slice('2016', '2016'))
ds_test = ds.sel(time=slice('2017', '2018'))

In [242]:
output_vars = ['z_*', 'u10']

In [243]:
%%time
dg_train = DataGenerator(ds_train, var_dict, lead_time, batch_size=bs, load=True, 
                         output_vars=output_vars)
dg_valid = DataGenerator(ds_train, var_dict, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, 
                         shuffle=False, output_vars=output_vars)

Loading data into RAM
Loading data into RAM
CPU times: user 6.48 s, sys: 59.3 s, total: 1min 5s
Wall time: 29.1 s


In [244]:
dg_test = DataGenerator(ds_test, var_dict, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, 
                         shuffle=False, output_vars=output_vars)

Loading data into RAM


In [245]:
dg_train.data.level_names

<xarray.DataArray 'level_names' (level: 4)>
array(['z_500', 'u10', 'lat2d', 'orography'], dtype='<U9')
Coordinates:
  * level        (level) int64 500 1 1 1
    level_names  (level) <U9 'z_500' 'u10' 'lat2d' 'orography'

In [246]:
X, y = dg_train[0]; X.shape, y.shape

((32, 32, 64, 4), (32, 32, 64, 2))

In [247]:
dg_train.output_idxs

[0, 1]

In [248]:
dg_train.std

<xarray.DataArray 'z' (level: 4)>
array([3483.06093502,    5.62151846,   51.93614619,  859.8722486 ])
Coordinates:
  * level        (level) int64 500 1 1 1
    level_names  (level) <U9 'z_500' 'u10' 'lat2d' 'orography'

In [249]:
dg_train.mean

<xarray.DataArray 'z' (level: 4)>
array([ 5.41249140e+04, -3.31912328e-04,  0.00000000e+00,  3.79497583e+02])
Coordinates:
  * level        (level) int64 500 1 1 1
    level_names  (level) <U9 'z_500' 'u10' 'lat2d' 'orography'

In [250]:
cnn = build_cnn([64, 64, 64, 64, 2], [5, 5, 5, 5, 5], (32, 64, 4))

In [251]:
cnn.compile(keras.optimizers.Adam(1e-4), 'mse')

In [252]:
cnn.fit_generator(dg_train, epochs=1)

Epoch 1/1


<tensorflow.python.keras.callbacks.History at 0x7f8440197128>

In [253]:
preds = cnn.predict_generator(dg_train)

In [258]:
dg = dg_train

In [256]:
preds.shape

(8688, 32, 64, 2)

In [259]:
dg.data.isel(level=dg.output_idxs).level

<xarray.DataArray 'level' (level: 2)>
array([500,   1])
Coordinates:
  * level        (level) int64 500 1
    level_names  (level) <U9 'z_500' 'u10'

In [260]:
preds = xr.DataArray(
    preds,
    dims=['time', 'lat', 'lon', 'level'],
    coords={'time': dg.valid_time, 'lat': dg.data.lat, 'lon': dg.data.lon, 
            'level': dg.data.isel(level=dg.output_idxs).level,
            'level_names': dg.data.isel(level=dg.output_idxs).level_names
           },
)

In [262]:
preds.shape, preds.level, preds.level_names

((8688, 32, 64, 2), <xarray.DataArray 'level' (level: 2)>
 array([500,   1])
 Coordinates:
   * level        (level) int64 500 1
     level_names  (level) <U9 'z_500' 'u10', <xarray.DataArray 'level_names' (level: 2)>
 array(['z_500', 'u10'], dtype='<U9')
 Coordinates:
   * level        (level) int64 500 1
     level_names  (level) <U9 'z_500' 'u10')

In [266]:
dg.var_dict

{'geopotential': ('z', [500]),
 '10m_u_component_of_wind': ('u10', None),
 'constants': ['lat2d', 'orography']}

In [308]:
das = []
for long_var, params in dg.var_dict.items():
    if not long_var == 'constants':
        var, levels = params
        var_names = [var] if levels is None else [f'{var}_{level}' for level in levels]
        print(var_names, levels)
        var_idxs = [i for i, v in enumerate(preds.level_names) if v in var_names]
        print(var_idxs)
        da = preds.isel(level=var_idxs)
        if levels is not None: da = da.sel(level=levels).drop('level_names')
        else: da = da.squeeze().drop('level').drop('level_names')
        das.append({var: da})

['z_500'] [500]
[0]
['u10'] None
[1]


In [310]:
das = xr.merge(das)

In [311]:
das

<xarray.Dataset>
Dimensions:  (lat: 32, level: 1, lon: 64, time: 8688)
Coordinates:
  * time     (time) datetime64[ns] 2015-01-04 ... 2015-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level    (level) int64 500
Data variables:
    z        (time, lat, lon, level) float32 -1.6765449 ... -0.17071085
    u10      (time, lat, lon) float32 -0.71567136 -0.5935094 ... 0.08219392

In [114]:
def create_predictions(model, dg):
    """Create non-iterative predictions"""
    preds = xr.DataArray(
        model.predict_generator(dg),
        dims=['time', 'lat', 'lon', 'level'],
        coords={'time': dg.valid_time, 'lat': dg.data.lat, 'lon': dg.data.lon, 
                'level': dg.data.isel(level=dg.output_idxs).level,
                'level_names': dg.data.isel(level=dg.output_idxs).level_names
               },
    )
    # Unnormalize
    preds = (preds * dg.std.isel(level=dg.output_idxs).values + 
             dg.mean.isel(level=dg.output_idxs).values)
    
    das = []
    for long_var, params in dg.var_dict.items():
        if not long_var == 'constants':
            var, levels = params
            var_idxs = [i for i, v in enumerate(preds.level_names) if v == var]
            das.append({var: preds.isel(level=var_idxs).squeeze().drop('level_names')})
    return xr.merge(das)

In [125]:
preds = create_predictions(cnn, dg_test)

In [205]:
preds

<xarray.Dataset>
Dimensions:  (lat: 32, level: 2, lon: 64, time: 8688)
Coordinates:
  * time     (time) datetime64[ns] 2015-01-04 ... 2015-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level    (level) int64 500 1000
Data variables:
    z        (time, lat, lon, level) float32 -1.5707413 ... 0.6681948
    u10      (time, lat, lon) float32 -0.611318 -0.54330826 ... 0.1051542

In [133]:
z500_valid = load_test_data(f'{datadir}geopotential_500', 'z').drop('level')
t850_valid = load_test_data(f'{datadir}temperature_850', 't')
valid = xr.merge([z500_valid, t850_valid])

In [134]:
valid

<xarray.Dataset>
Dimensions:  (lat: 32, lon: 64, time: 17520)
Coordinates:
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * time     (time) datetime64[ns] 2017-01-01 ... 2018-12-31T23:00:00
    level    int32 850
Data variables:
    z        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>
    t        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>

In [128]:
!ls $datadir/geopotential_500

geopotential_500hPa_1979_5.625deg.nc  geopotential_500hPa_1999_5.625deg.nc
geopotential_500hPa_1980_5.625deg.nc  geopotential_500hPa_2000_5.625deg.nc
geopotential_500hPa_1981_5.625deg.nc  geopotential_500hPa_2001_5.625deg.nc
geopotential_500hPa_1982_5.625deg.nc  geopotential_500hPa_2002_5.625deg.nc
geopotential_500hPa_1983_5.625deg.nc  geopotential_500hPa_2003_5.625deg.nc
geopotential_500hPa_1984_5.625deg.nc  geopotential_500hPa_2004_5.625deg.nc
geopotential_500hPa_1985_5.625deg.nc  geopotential_500hPa_2005_5.625deg.nc
geopotential_500hPa_1986_5.625deg.nc  geopotential_500hPa_2006_5.625deg.nc
geopotential_500hPa_1987_5.625deg.nc  geopotential_500hPa_2007_5.625deg.nc
geopotential_500hPa_1988_5.625deg.nc  geopotential_500hPa_2008_5.625deg.nc
geopotential_500hPa_1989_5.625deg.nc  geopotential_500hPa_2009_5.625deg.nc
geopotential_500hPa_1990_5.625deg.nc  geopotential_500hPa_2010_5.625deg.nc
geopotential_500hPa_1991_5.625deg.nc  geopotential_500hPa_2011_5.625deg.nc
geopotential_500hPa_1992_

In [123]:
compute_weighted_rmse(preds.z.sel(level=500), ds_train.z.sel(level=500)).load()

<xarray.DataArray 'z_rmse' ()>
array(55647.16552956)
Coordinates:
    level    int64 500

In [6]:
ds = xr.open_dataset('/home/stephan/data/myWeatherBench/predictions/01-default.nc')

In [7]:
ds.squeeze()

<xarray.Dataset>
Dimensions:  (lat: 32, level: 2, lon: 64, time: 17448)
Coordinates:
  * level    (level) int64 500 850
  * time     (time) datetime64[ns] 2017-01-04 ... 2018-12-31T23:00:00
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
Data variables:
    z        (time, lat, lon, level) float32 ...
    t        (time, lat, lon, level) float32 ...

In [11]:
ds.z.sel(level=850).isel(time=0)

<xarray.DataArray 'z' (lat: 32, lon: 64)>
array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)
Coordinates:
    level    int64 850
    time     datetime64[ns] 2017-01-04
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4

In [12]:
ds.close()

In [13]:
a

NameError: name 'a' is not defined