# xarray use case: Neural network training


**tl;dr**

1. This notebook is an example of reading from a climate model netCDF file to train a neural network. Neural networks (for use in parameterization research) require random columns of several stacked variables at a time. 

2. Experiments in this notebook show:
    1. Reading from raw climate model output files is super slow (1s per batch... need speeds on the order of ms)
    2. open_mfdataset is half as fast as opening the same dataset with open_dataset
    3. Pure h5py is much faster than reading the same dataset using xarray (even using the h5 backend)

3. Currently, I revert to preformatting the dataset (flatten time, lat, lon). This gets the reading speed down to milliseconds per batch.

**Conclusions**

Reading straight from the raw netCDF files (with all dimensions intact) is handy and might be necessary for later applications (using continuous time slices or lat-lon regions for RNNs or CNNs).

However, at the moment this is many orders of magnitude too slow. Preprocessing seems required.

What would be a good way of speeding this up without too extensive post processing?


In [1]:
import xarray as xr
import numpy as np

In [2]:
xr.__version__

'0.11.2'

## Load an example dataset

I uploaded a sample dataset here: http://doi.org/10.5281/zenodo.2559313

The files are around 1GB large. Let's download it.

NOTE: I have all my data on an SSD

In [47]:
# Modify this path!
DATADIR = '/local/S.Rasp/tmp/'

In [48]:
!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_1.nc
!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_2.nc
!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_concat.nc

--2019-02-07 13:08:52--  https://zenodo.org/record/2559183/files/sample_SPCAM_1.nc
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 923498891 (881M) [application/octet-stream]
Saving to: ‘/local/S.Rasp/tmp/sample_SPCAM_1.nc’


2019-02-07 13:10:42 (8.09 MB/s) - ‘/local/S.Rasp/tmp/sample_SPCAM_1.nc’ saved [923498891/923498891]

--2019-02-07 13:10:42--  https://zenodo.org/record/2559183/files/sample_SPCAM_2.nc
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 923498891 (881M) [application/octet-stream]
Saving to: ‘/local/S.Rasp/tmp/sample_SPCAM_2.nc’


2019-02-07 13:12:09 (10.1 MB/s) - ‘/local/S.Rasp/tmp/sample_SPCAM_2.nc’ saved [923498891/923498891]

--2019-02-07 13:12:09--  https://zenodo.org/record/2559183/files/sample_SPCAM_concat.nc


In [49]:
!ls -lh $DATADIR/sample_SPCAM*

-rw-r--r-- 1 S.Rasp ls-craig 881M Feb  7 13:00 /local/S.Rasp/tmp//sample_SPCAM_1.nc
-rw-r--r-- 1 S.Rasp ls-craig 881M Feb  7 13:00 /local/S.Rasp/tmp//sample_SPCAM_2.nc
-rw-r--r-- 1 S.Rasp ls-craig 1.8G Feb  7 13:00 /local/S.Rasp/tmp//sample_SPCAM_concat.nc


The files are typical climate model output files. `sample_SPCAM_1.nc` and `sample_SPCAM_2.nc` are two contiguous output files. `sample_SPCAM_concat.nc` is the concatenated version of the two files.

In [53]:
%%time
ds = xr.open_mfdataset(DATADIR + 'sample_SPCAM_1.nc')

CPU times: user 56 ms, sys: 0 ns, total: 56 ms
Wall time: 54.7 ms


In [54]:
ds

<xarray.Dataset>
Dimensions:       (crm_x: 32, crm_y: 1, crm_z: 28, ilev: 31, isccp_prs: 7, isccp_prstau: 49, isccp_tau: 7, lat: 64, lev: 30, lon: 128, tbnd: 2, time: 48)
Coordinates:
  * lat           (lat) float64 -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86
  * lon           (lon) float64 0.0 2.812 5.625 8.438 ... 351.6 354.4 357.2
  * crm_x         (crm_x) float64 0.0 4.0 8.0 12.0 ... 112.0 116.0 120.0 124.0
  * crm_y         (crm_y) float64 0.0
  * crm_z         (crm_z) float64 992.6 976.3 957.5 936.2 ... 38.27 24.61 14.36
  * lev           (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev          (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * isccp_prs     (isccp_prs) float64 90.0 245.0 375.0 500.0 620.0 740.0 900.0
  * isccp_tau     (isccp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * isccp_prstau  (isccp_prstau) float64 90.0 90.0 90.0 ... 900.0 900.0 900.2
  * time          (time) object 0000-01-01 00:00:00 ... 0000-01-01 23:29:59
D

## Random columns for machine learning parameterizations

For the work on ML parameterizations that a few of us are doing now, we would like to work one column at a time. One simple example would be predicting the temperature and humidity tendencies (TPHYSTND and PHQ) from the temperature and humidity profiles (TAP and QAP). 

This means we would like to give the neural network a stacked vector containing the inputs (2 x 30 levels) and ask it to predict the outputs (also 2 x 30 levels).

In NN training, we usually train on a batch of data at a time. Batches typically have a few hundred samples (columns in our case). It is really important that the samples in a batch are not correlated but rather represent a random sample of the entire dataset.

To achieve this we will write a data generator that loads the batches by randomly selecting along the time, lat and lon dimensions.

In [57]:
class DataGenerator(object):
    """
    Data generator that randomly (if shuffle = True) picks columns from the dataset and returns them in 
    batches. For each column the input variables and output variables will be stacked.
    """
    def __init__(self, fn_or_ds, batch_size=128, input_vars=['TAP', 'QAP'], output_vars=['TPHYSTND', 'PHQ'], 
                 shuffle=True, engine='netcdf4'):
        self.ds = xr.open_mfdataset(fn_or_ds, engine=engine) if type(fn_or_ds) is str else fn_or_ds
        self.batch_size = batch_size
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.ntime, self.nlat, self.nlon = self.ds.time.size, self.ds.lat.size, self.ds.lon.size
        self.ntot = self.ntime * self.nlat * self.ntime
        self.n_batches = self.ntot // batch_size
        self.indices = np.arange(self.ntot)
        if shuffle:
            self.indices = np.random.permutation(self.indices)
    def __getitem__(self, index):
        time_indices, lat_indices, lon_indices = np.unravel_index(
            self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, self.nlat, self.nlon)
        )
        
        X, Y = [], []
        for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):
            X.append(
                np.concatenate(
                    [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.input_vars]
                )
            )
            Y.append(
                np.concatenate(
                    [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.output_vars]
                )
            )

        return np.array(X), np.array(Y)

### Multi-file dataset

Let's start by using the split dataset `sample_SPCAM_1.nc` and `sample_SPCAM_2.nc`.

In [58]:
gen = DataGenerator(DATADIR + 'sample_SPCAM_[1-2].nc')

In [59]:
# This is how we get one batch of inputs and corresponding outputs
x, y = gen[0]

In [60]:
x.shape, y.shape

((128, 60), (128, 60))

In [61]:
# A little test function to check the timing.
def test(g, n):
    for i in range(n):
        x, y = g[i]

In [64]:
%%time
test(gen, 10)

CPU times: user 13.3 s, sys: 1.34 s, total: 14.6 s
Wall time: 14.3 s


In [65]:
# does shuffling make a big difference
gen = DataGenerator(DATADIR + 'sample_SPCAM_[1-2].nc', shuffle=True)
%time test(gen, 10)

CPU times: user 12.5 s, sys: 1.28 s, total: 13.8 s
Wall time: 13.5 s


So it takes more than one second to read one batch. This is way too slow to train a neural network in a reasonable amount of time. Shuffling doesn't seem to be a huge problem, but even without shuffling I am probably accessing the data in a different order than saved on disc. 

Let's check what actually takes that long.

In [66]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [67]:
%lprun -f gen.__getitem__ test(gen, 10)

Output:

```
Timer unit: 1e-06 s

Total time: 24.5229 s
File: <ipython-input-57-78b9d254df3b>
Function: __getitem__ at line 18

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    18                                               def __getitem__(self, index):
    19        10         17.0      1.7      0.0          time_indices, lat_indices, lon_indices = np.unravel_index(
    20        10        267.0     26.7      0.0              self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, self.nlat, self.nlon)
    21                                                   )
    22                                                   
    23        10         10.0      1.0      0.0          X, Y = [], []
    24      1290       4642.0      3.6      0.0          for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):
    25      1280       1399.0      1.1      0.0              X.append(
    26      1280       1721.0      1.3      0.0                  np.concatenate(
    27      1280   12256070.0   9575.1     50.0                      [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.input_vars]
    28                                                           )
    29                                                       )
    30      1280       2393.0      1.9      0.0              Y.append(
    31      1280       1750.0      1.4      0.0                  np.concatenate(
    32      1280   12253415.0   9573.0     50.0                      [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.output_vars]
    33                                                           )
    34                                                       )
    35                                           
    36        10       1218.0    121.8      0.0          return np.array(X), np.array(Y)
```

### Using the concatenated dataset

Let's see whether it makes a difference to use the pre-concatenated dataset.

In [74]:
ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc')
gen = DataGenerator(ds, shuffle=True)
%time test(gen, 10)

CPU times: user 5.93 s, sys: 984 ms, total: 6.91 s
Wall time: 6.91 s


In [76]:
ds = xr.open_mfdataset(f'{DATADIR}sample_SPCAM_concat.nc')
gen = DataGenerator(ds, shuffle=True)
%time test(gen, 10)

CPU times: user 11.5 s, sys: 1.25 s, total: 12.8 s
Wall time: 12.5 s


So yes, it approximately halves the time but only if the single dataset is NOT opened with `open_mfdataset`.

### With h5py engine

Let's see whether using the h5py backend makes a difference

In [77]:
import h5netcdf

In [79]:
ds.close()

In [80]:
ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc', engine='h5netcdf')
gen = DataGenerator(ds)

In [81]:
%%time
test(gen, 10)

CPU times: user 6.97 s, sys: 972 ms, total: 7.94 s
Wall time: 7.8 s


Doesn't seem to speed it up

In [83]:
ds.close()

### Using plain h5py

Let's write a version of the data generator that uses plain h5py for data loading.

In [82]:
class DataGeneratorH5(object):
    def __init__(self, fn, batch_size=128, input_vars=['TAP', 'QAP'], output_vars=['TPHYSTND', 'PHQ'], shuffle=True):
        self.ds = xr.open_dataset(fn)
        self.batch_size = batch_size
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.ntime, self.nlat, self.nlon = self.ds.time.size, self.ds.lat.size, self.ds.lon.size
        self.ntot = self.ntime * self.nlat * self.ntime
        self.n_batches = self.ntot // batch_size
        self.indices = np.arange(self.ntot)
        if shuffle:
            self.indices = np.random.permutation(self.indices)
            
        # Close xarray dataset and open h5py object
        self.ds.close()
        self.ds = h5py.File(fn, 'r')
        
    def __getitem__(self, index):
        time_indices, lat_indices, lon_indices = np.unravel_index(
            self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, self.nlat, self.nlon)
        )
        
        X, Y = [], []
        for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):
            X.append(
                np.concatenate(
                    [self.ds[v][itime, :, ilat, ilon] for v in self.input_vars]
                )
            )
            Y.append(
                np.concatenate(
                    [self.ds[v][itime, :, ilat, ilon] for v in self.output_vars]
                )
            )

        return np.array(X), np.array(Y)

In [84]:
gen = DataGeneratorH5(f'{DATADIR}sample_SPCAM_concat.nc')

In [85]:
%%time
test(gen, 10)

CPU times: user 1.78 s, sys: 860 ms, total: 2.64 s
Wall time: 2.61 s


In [96]:
gen.ds.close()

So this is significantly faster than xarray.

## Use in a simple neural network

How would we actually use this data generator for network training...

Note that this neural network will not actually learn much because we didn't normalize the input data. But we only care about computational performance here, right?

In [87]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential

In [88]:
tf.keras.__version__

'2.1.6-tf'

In [89]:
model = Sequential([
    Dense(128, input_shape=(60,), activation='relu'),
    Dense(60),
])

In [90]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 128)               7808      
_________________________________________________________________
dense_1 (Dense)              (None, 60)                7740      
Total params: 15,548
Trainable params: 15,548
Non-trainable params: 0
_________________________________________________________________


In [91]:
model.compile('adam', 'mse')

In [99]:
# Load the xarray version using the concatenated dataset
ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc')
gen = DataGenerator(ds, shuffle=True)

In [101]:
model.fit_generator(iter(gen), steps_per_epoch=gen.n_batches)

Epoch 1/1
  37/4608 [..............................] - ETA: 1:04:11 - loss: 1733.6299

KeyboardInterrupt: 

So as you can see, it would take around 1 hour to go through one epoch (i.e. the entire dataset once). This is crazy slow since we only used 2 days of data. The full dataset contains a year of data...

## Pre-processing the dataset

What I have resorted to to solve this issue is to prestack the data, preshuffle the data and save it to disc conveniently.

These files contain the exactly same information for the input (features) and output (targets) variables required.

The files only have two dimensions: sample, which is the shuffled, flattened time, lat and lon dimensions and lev which is the stacked vertical coordinate.

The preprocessing for these two files only takes a few seconds but for an entire year of data, the preprocessing alone can take around an hour.


In [163]:
!wget -P $DATADIR https://zenodo.org/record/2559313/files/preproc_features.nc
!wget -P $DATADIR https://zenodo.org/record/2559313/files/preproc_targets.nc

--2019-02-07 15:42:32--  https://zenodo.org/record/2559313/files/preproc_features.nc
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 205465847 (196M) [application/octet-stream]
Saving to: ‘/local/S.Rasp/tmp/preproc_features.nc.2’


2019-02-07 15:42:48 (13.0 MB/s) - ‘/local/S.Rasp/tmp/preproc_features.nc.2’ saved [205465847/205465847]

--2019-02-07 15:42:48--  https://zenodo.org/record/2559313/files/preproc_targets.nc
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 205465846 (196M) [application/octet-stream]
Saving to: ‘/local/S.Rasp/tmp/preproc_targets.nc.1’


2019-02-07 15:42:58 (20.6 MB/s) - ‘/local/S.Rasp/tmp/preproc_targets.nc.1’ saved [205465846/205465846]



In [104]:
!ls -lh $DATADIR/preproc*

-rw-r--r-- 1 S.Rasp ls-craig 196M Feb  7 13:57 /local/S.Rasp/tmp//preproc_features.nc
-rw-r--r-- 1 S.Rasp ls-craig 196M Feb  7 13:57 /local/S.Rasp/tmp//preproc_targets.nc


In [105]:
ds = xr.open_dataset(f'{DATADIR}preproc_features.nc')

In [106]:
ds

<xarray.Dataset>
Dimensions:        (feature_lev: 60, sample: 778240)
Coordinates:
  * feature_lev    (feature_lev) int64 0 1 2 3 4 5 6 7 ... 53 54 55 56 57 58 59
    time           (sample) int64 ...
    lat            (sample) float64 ...
    lon            (sample) float64 ...
    feature_names  (feature_lev) object ...
Dimensions without coordinates: sample
Data variables:
    features       (sample, feature_lev) float32 ...
Attributes:
    log:      \n    Time: 2019-02-07T13:57:24\n\n    Executed command:\n\n   ...

In [129]:
# Write a new data generator
class DataGeneratorPreproc(object):
    """
    Data generator that randomly (if shuffle = True) picks columns from the dataset and returns them in 
    batches. For each column the input variables and output variables will be stacked.
    """
    def __init__(self, feature_fn, target_fn, batch_size=128, shuffle=True, engine='netcdf4'):
        self.feature_ds = xr.open_dataset(feature_fn, engine=engine)
        self.target_ds = xr.open_dataset(target_fn, engine=engine)
        self.batch_size = batch_size
        self.ntot = self.feature_ds.sample.size
        self.n_batches = self.ntot // batch_size
        self.indices = np.arange(self.ntot)
        if shuffle:
            self.indices = np.random.permutation(self.indices)
    def __getitem__(self, index):
        batch_indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        
        X = self.feature_ds.features.isel(sample=batch_indices)
        Y = self.target_ds.targets.isel(sample=batch_indices)

        return X, Y

In [130]:
gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc')

In [131]:
x, y = gen[0]

In [132]:
x.shape, y.shape

((128, 60), (128, 60))

In [133]:
%%time
test(gen, 10)

CPU times: user 84 ms, sys: 0 ns, total: 84 ms
Wall time: 81.6 ms


In [134]:
gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc', shuffle=False)

In [135]:
%%time
test(gen, 10)

CPU times: user 84 ms, sys: 0 ns, total: 84 ms
Wall time: 83.9 ms


In [152]:
gen.feature_ds.close(); gen.target_ds.close()

In [139]:
gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc', engine='h5netcdf')

In [140]:
%%time
test(gen, 10)

CPU times: user 84 ms, sys: 0 ns, total: 84 ms
Wall time: 80.6 ms


So these are the sort of times that are required for training a neural network.

### Pure h5py version

In [158]:
class DataGeneratorPreprocH5(object):
    """
    Data generator that randomly (if shuffle = True) picks columns from the dataset and returns them in 
    batches. For each column the input variables and output variables will be stacked.
    """
    def __init__(self, feature_fn, target_fn, batch_size=128):
        self.feature_ds = xr.open_dataset(feature_fn)
        self.target_ds = xr.open_dataset(target_fn)
        self.batch_size = batch_size
        self.ntot = self.feature_ds.sample.size
        self.n_batches = self.ntot // batch_size
        
        # Close xarray dataset and open h5py object
        self.feature_ds.close()
        self.feature_ds = h5py.File(feature_fn, 'r')
        self.target_ds.close()
        self.target_ds = h5py.File(target_fn, 'r')
        
    def __getitem__(self, index):
        
        X = self.feature_ds['features'][index*self.batch_size:(index+1)*self.batch_size, :]
        Y = self.target_ds['targets'][index*self.batch_size:(index+1)*self.batch_size, :]

        return X, Y

In [159]:
gen.feature_ds.close(); gen.target_ds.close()

In [160]:
gen = DataGeneratorPreprocH5(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc')

In [161]:
%%time
test(gen, 10)

CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 6.61 ms


So again, the pure h5py version is an order of magnitude faster than the xarray version.

## End