In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import re, os
import numpy as np
import xarray as xr
import tensorflow.keras as keras
import datetime
import pdb
import matplotlib.pyplot as plt
from src.utils import *
from src.score import *
from src.data_generator import *
from src.networks import *
from src.train import *
import cartopy.crs as ccrs
import seaborn as sns

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]=str(3)
limit_mem()

In [4]:
args = load_args('../nn_configs/B/81-resnet_d3_dr_0.1.yml')
args['exp_id'] = '81.1-resnet_d3_dr_0.1.yml'

In [5]:
ds = xr.merge(
    [xr.open_mfdataset(f'{args["datadir"]}/{var}/*.nc', combine='by_coords')
     for var in args["var_dict"].keys()],
    fill_value=0  # For the 'tisr' NaNs
)

In [6]:
ds_train = ds.sel(time=slice('2015', '2015'))

In [7]:
dg_train = DataGenerator(
    ds_train, args['var_dict'], args['lead_time'], batch_size=args['batch_size'], output_vars=args['output_vars'],
    data_subsample=1, norm_subsample=10, nt_in=args['nt_in'], dt_in=args['dt_in'], load=False
)

## Reference tests

In [7]:
%%time
# Default engine
X, y = dg_train[np.random.randint(len(dg_train))]

CPU times: user 7.73 s, sys: 19.4 s, total: 27.1 s
Wall time: 12.1 s


In [8]:
dg_train.shuffle = False; dg_train.on_epoch_end()  # Maybe doesn't matter because I am just reading in one year

In [9]:
%%time
X, y = dg_train[np.random.randint(len(dg_train))]

CPU times: user 7.62 s, sys: 17.1 s, total: 24.8 s
Wall time: 10 s


In [10]:
%%time
_ = dg_train.data.isel(time=777, level=0).load();

CPU times: user 404 ms, sys: 1.19 s, total: 1.59 s
Wall time: 1.59 s


In [24]:
%%time
_ = ds_train.isel(time=777, level=0).load();

CPU times: user 16.6 ms, sys: 1.21 ms, total: 17.8 ms
Wall time: 14.7 ms


Ok, this doesn't make any sense because it's accessing the same data. Is it because I've concatenated the levels?

In [26]:
var_dict = args['var_dict']
data = []
level_names = []
generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
for long_var, params in var_dict.items():
    if long_var == 'constants':
        for var in params:
            data.append(ds_train[var].expand_dims(
                {'level': generic_level, 'time': ds_train.time}, (1, 0)
            ))
            level_names.append(var)
    else:
        var, levels = params
        try:
            data.append(ds_train[var].sel(level=levels))
            level_names += [f'{var}_{level}' for level in levels]
        except ValueError:
            data.append(ds_train[var].expand_dims({'level': generic_level}, 1))
            level_names.append(var)

data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
data['level_names'] = xr.DataArray(
    level_names, dims=['level'], coords={'level': data.level})

In [27]:
%%time
_ = data.isel(time=777, level=0).load();

CPU times: user 163 ms, sys: 643 ms, total: 806 ms
Wall time: 804 ms


Ok, weird. Somewhere inbetween but unacceptably slow. What about the normalization?

In [28]:
data = data.isel(time=slice(0, None, args['data_subsample']))

In [29]:
%%time
_ = data.isel(time=777, level=0).load();

CPU times: user 175 ms, sys: 551 ms, total: 726 ms
Wall time: 723 ms


In [30]:
data = (data - dg_train.mean) / dg_train.std

In [31]:
%%time
_ = data.isel(time=777, level=0).load();

CPU times: user 262 ms, sys: 832 ms, total: 1.09 s
Wall time: 1.09 s


Ok, so the concatenation makes it unbearably slow. Before we try to fix it, let's investigate accessing random time indices.

In [34]:
rand = np.random.randint(len(ds_train.time), size=10); rand

array([7871, 6466, 4342, 1703, 8173, 4710, 1983, 8113, 1876, 6970])

In [35]:
%%time
_ = dg_train.data.isel(time=rand, level=0).load();

CPU times: user 424 ms, sys: 925 ms, total: 1.35 s
Wall time: 1.34 s


In [36]:
%%time
_ = ds_train.isel(time=rand, level=0).load();

CPU times: user 21.1 ms, sys: 1.54 s, total: 1.56 s
Wall time: 1.55 s


In [37]:
%%time
_ = ds_train.isel(time=list(rand), level=0).load();

CPU times: user 13.2 ms, sys: 1.56 s, total: 1.57 s
Wall time: 1.56 s


In [38]:
%%time
_ = ds.isel(time=np.random.randint(len(ds.time), size=10), level=0).load();

CPU times: user 302 ms, sys: 15 s, total: 15.3 s
Wall time: 49.1 s


Ok, so this still takes forever. Especially across multiple files. So fixing the concatenation is kind of pointless. TFRecords it is...

In [39]:
%%time
# 461G-->463G
dg_train.data.load()

CPU times: user 5.8 s, sys: 6.2 s, total: 12 s
Wall time: 7.11 s


In [40]:
%%time
X, y = dg_train[np.random.randint(len(dg_train))]

CPU times: user 39.1 ms, sys: 18.5 ms, total: 57.6 ms
Wall time: 56.4 ms


## TFRecords

In [259]:
var_dict = args['var_dict']
data = []
level_names = []
generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
for long_var, params in var_dict.items():
    if long_var == 'constants':
        for var in params:
            data.append(ds_train[var].expand_dims(
                {'level': generic_level, 'time': ds_train.time}, (1, 0)
            ))
            level_names.append(var)
    else:
        var, levels = params
        try:
            data.append(ds_train[var].sel(level=levels))
            level_names += [f'{var}_{level}' for level in levels]
        except ValueError:
            data.append(ds_train[var].expand_dims({'level': generic_level}, 1))
            level_names.append(var)

data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
data['level_names'] = xr.DataArray(
    level_names, dims=['level'], coords={'level': data.level})

In [260]:
ds = data

In [264]:
def _tensor_feature(value):
    """Returns a float_list from a float / double."""
    value = tf.io.serialize_tensor(value).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [265]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [266]:
def serialize_example(time, data):
    feature = {
        'time': _bytes_feature(time),
        'data': _tensor_feature(data)
    }
    feature = tf.train.Features(feature=feature)
    example_proto = tf.train.Example(features=feature)
    return example_proto.SerializeToString()

In [257]:
type(time.encode('utf-8'))

bytes

In [267]:
filename = 'test4.tfrecord'
writer = tf.io.TFRecordWriter(filename)
for i in range(10):
    time = str(ds.isel(time=i).time.values).encode('utf-8')
    data = ds.isel(time=i).values.astype('float32')
    serialized_example = serialize_example(time, data)
    writer.write(serialized_example)

In [268]:
raw_ds = tf.data.TFRecordDataset([filename])

In [269]:
features = {
    'time': tf.io.FixedLenFeature([], tf.string),
    'data': tf.io.FixedLenFeature([], tf.string)
}

In [270]:
def _parse(example_proto):
    return tf.io.parse_single_example(example_proto, features)

In [None]:
def decode(example_proto):
    dic = _parse(example_proto)
    # Get the data and normalize!

In [271]:
parsed_ds = raw_ds.map(_parse)

In [272]:
parsed_ds

<MapDataset shapes: {data: (), time: ()}, types: {data: tf.string, time: tf.string}>

In [290]:
parsed_ds = parsed_ds.shuffle(5)

In [291]:
for p in parsed_ds.take(9):
    print(p['time'])

tf.Tensor(b'2015-01-01T00:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T05:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T04:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T06:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T01:00:00.000000000', shape=(), dtype=string)


DataLossError: truncated record at 2802609

In [282]:
o = next(iter(parsed_ds))

In [283]:
o['time']

<tf.Tensor: shape=(), dtype=string, numpy=b'2015-01-01T00:00:00.000000000'>

## Shuffle across files

In [8]:
def _tensor_feature(value):
    """Returns a float_list from a float / double."""
    value = tf.io.serialize_tensor(value).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [9]:
def serialize_example(data):
    feature = {
        'data': _tensor_feature(data)
    }
    feature = tf.train.Features(feature=feature)
    example_proto = tf.train.Example(features=feature)
    return example_proto.SerializeToString()

In [10]:
n = 10

In [11]:
data = np.arange(n*10).reshape(n, -1).astype('float32')

In [12]:
for i in range(n):
    filename = f'ttest_data{i}.tfrecord'
    writer = tf.io.TFRecordWriter(filename)
    for j in range(10):
        serialized_example = serialize_example(data[i, j])
        writer.write(serialized_example)
    writer.close()

In [13]:
filenames = sorted(glob('ttest*'))

In [62]:
filenames

['ttest_data0.tfrecord',
 'ttest_data1.tfrecord',
 'ttest_data2.tfrecord',
 'ttest_data3.tfrecord',
 'ttest_data4.tfrecord',
 'ttest_data5.tfrecord',
 'ttest_data6.tfrecord',
 'ttest_data7.tfrecord',
 'ttest_data8.tfrecord',
 'ttest_data9.tfrecord']

### Current implementation

In [14]:
def _parse(example_proto):
    return tf.io.parse_single_example(example_proto, {'data': tf.io.FixedLenFeature([], tf.string)})

In [15]:
def decode(example_proto):
    dic = _parse(example_proto)
    data = dic['data']
    return tf.io.parse_tensor(data, np.float32)

In [16]:
dataset = tf.data.TFRecordDataset(filenames).map(decode)

In [17]:
batches = dataset.shuffle(20).batch(10).as_numpy_iterator()

In [18]:
bb = [b for b in batches]

In [19]:
bb

[array([ 7.,  1., 18., 14., 21., 10., 13.,  6.,  2., 15.], dtype=float32),
 array([12., 25., 27., 20., 33., 28., 19., 16., 22.,  3.], dtype=float32),
 array([23., 24., 32., 29.,  5., 39., 11., 37., 43., 44.], dtype=float32),
 array([40., 34., 46., 52., 41., 26., 55.,  9., 30.,  4.], dtype=float32),
 array([50., 45., 36., 61., 51., 62., 35., 65., 60., 59.], dtype=float32),
 array([68., 48., 56.,  8., 58., 53., 74., 49., 66., 78.], dtype=float32),
 array([17., 64., 67., 72., 80., 42., 57., 85., 63., 83.], dtype=float32),
 array([88., 89., 91., 81.,  0., 87., 71., 96., 38., 93.], dtype=float32),
 array([99., 47., 82., 94., 84., 86., 79., 73., 95., 70.], dtype=float32),
 array([90., 77., 69., 97., 54., 98., 92., 75., 31., 76.], dtype=float32)]

## Try with interleave

In [20]:
datasets = [tf.data.TFRecordDataset(fn).map(decode).shuffle(20).batch(2) for fn in filenames]

In [21]:
dataset = tf.data.Dataset.zip(tuple(datasets))

In [22]:
batches = dataset.as_numpy_iterator()

In [23]:
b = list(batches)[0]

In [25]:
np.array(b)

array([[ 4.,  7.],
       [16., 10.],
       [27., 22.],
       [37., 33.],
       [43., 44.],
       [57., 50.],
       [69., 61.],
       [71., 77.],
       [80., 88.],
       [97., 99.]], dtype=float32)

In [87]:
# batches = dataset.shuffle(20).batch(10).as_numpy_iterator()

In [79]:
bb

[array([ 0., 11., 21., 30., 24.,  2., 13.,  1., 10., 25.], dtype=float32),
 array([ 6., 17., 22.,  4., 20., 16., 32., 18.,  9., 29.], dtype=float32),
 array([ 8.,  5., 19., 50., 26.,  3., 27., 12., 70., 31.], dtype=float32),
 array([38., 42., 52., 40., 37.,  7., 53., 61., 71., 44.], dtype=float32),
 array([60., 23., 34., 39., 74., 28., 54., 55., 56., 73.], dtype=float32),
 array([51., 76., 43., 47., 58., 67., 45., 68., 15., 48.], dtype=float32),
 array([35., 57., 36., 49., 81., 41., 91., 66., 62., 46.], dtype=float32),
 array([75., 65., 59., 64., 84., 82., 72., 69., 14., 97.], dtype=float32),
 array([79., 94., 95., 63., 80., 88., 99., 87., 78., 92.], dtype=float32),
 array([93., 96., 77., 33., 90., 86., 85., 98., 83., 89.], dtype=float32)]

In [110]:
dataset = tf.data.Dataset.from_tensor_slices(filenames)

In [34]:
dataset

<TensorSliceDataset shapes: (), types: tf.string>

In [268]:
raw_ds = tf.data.TFRecordDataset([filename])

In [269]:
features = {
    'time': tf.io.FixedLenFeature([], tf.string),
    'data': tf.io.FixedLenFeature([], tf.string)
}

In [270]:
def _parse(example_proto):
    return tf.io.parse_single_example(example_proto, features)

In [None]:
def decode(example_proto):
    dic = _parse(example_proto)
    # Get the data and normalize!

In [271]:
parsed_ds = raw_ds.map(_parse)

In [272]:
parsed_ds

<MapDataset shapes: {data: (), time: ()}, types: {data: tf.string, time: tf.string}>

In [290]:
parsed_ds = parsed_ds.shuffle(5)

In [291]:
for p in parsed_ds.take(9):
    print(p['time'])

tf.Tensor(b'2015-01-01T00:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T05:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T04:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T06:00:00.000000000', shape=(), dtype=string)
tf.Tensor(b'2015-01-01T01:00:00.000000000', shape=(), dtype=string)


DataLossError: truncated record at 2802609

In [282]:
o = next(iter(parsed_ds))

In [283]:
o['time']

<tf.Tensor: shape=(), dtype=string, numpy=b'2015-01-01T00:00:00.000000000'>

## Implement from data generator

In [8]:
dg_train = DataGenerator(
    ds_train, args['var_dict'], args['lead_time'], batch_size=args['batch_size'], output_vars=args['output_vars'],
    data_subsample=1, norm_subsample=10, nt_in=1, dt_in=args['dt_in'], load=True, normalize=False, 
    shuffle=False
)

In [9]:
def _tensor_feature(value):
    value = tf.io.serialize_tensor(value).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [10]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [11]:
def serialize_example(time, data):
    feature = {
        'time': _bytes_feature(time),
        'data': _tensor_feature(data)
    }
    feature = tf.train.Features(feature=feature)
    example_proto = tf.train.Example(features=feature)
    return example_proto.SerializeToString()

In [12]:
savedir = '/data/stephan/WeatherBench/tfrecords/'
!mkdir -p $savedir

In [13]:
filename = f'{savedir}test1.tfrecord'

In [322]:
!rm $filename

In [323]:
writer = tf.io.TFRecordWriter(filename)

In [324]:
for t in tqdm(dg_train.data.time):
    time = str(t.values).encode('utf-8')
    data = dg_train.data.sel(time=t).values.astype('float32')
    serialized_example = serialize_example(time, data)
    writer.write(serialized_example)

HBox(children=(FloatProgress(value=0.0, max=8760.0), HTML(value='')))




In [325]:
writer.close()

In [14]:
!ls -lh $savedir

total 3,0G
-rw-rw-r-- 1 rasp rasp 221M Jun  7 15:07 month_01.tfrecord
-rw-rw-r-- 1 rasp rasp 200M Jun  7 16:06 month_02.tfrecord
-rw-rw-r-- 1 rasp rasp 2,6G Jun  7 15:01 test1.tfrecord


Too large. 1 year = 2.6G. Should be around 100-200G. What about one month? 2.6/12 = 216M. Let's try reading it to check whether the different lengths would be a problem.

In [56]:
np.unique(ds.time.dt.year)

[autoreload of src.data_generator failed: Traceback (most recent call last):
  File "/home/rasp/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/rasp/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/home/rasp/miniconda3/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/home/rasp/miniconda3/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 724, in exec_module
  File "<frozen importlib._bootstrap_external>", line 860, in get_code
  File "<frozen importlib._bootstrap_external>", line 791, in source_to_code
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/rasp/repositor

'[1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992\n 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006\n 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018]'

In [330]:
for m in tqdm(range(1, 3)):
    data_slice = dg_train.data.sel(time=f'2015-{str(m).zfill(2)}')
    filename = f'{savedir}month_{str(m).zfill(2)}.tfrecord'
    writer = tf.io.TFRecordWriter(filename)
    for t in tqdm(data_slice.time):
        time = str(t.values).encode('utf-8')
        data = data_slice.sel(time=t).values.astype('float32')
        serialized_example = serialize_example(time, data)
        writer.write(serialized_example)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=744.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=672.0), HTML(value='')))





In [331]:
write.close()

In [332]:
!ls -lh $savedir

total 3,0G
-rw-rw-r-- 1 rasp rasp 221M Jun  7 15:07 month_01.tfrecord
-rw-rw-r-- 1 rasp rasp 200M Jun  7 15:07 month_02.tfrecord
-rw-rw-r-- 1 rasp rasp 2,6G Jun  7 15:01 test1.tfrecord


### Now read the data

In [22]:
features = {
    'time': tf.io.FixedLenFeature([], tf.string),
    'data': tf.io.FixedLenFeature([], tf.string)
}

In [15]:
def _parse(example_proto):
    return tf.io.parse_single_example(example_proto, features)
def decode(example_proto):
    dic = _parse(example_proto)
    time = dic['time']
    data = dic['data']
    # Normalize I guess
    return tf.io.parse_tensor(data, np.float32)
def decode_time(example_proto):
    dic = _parse(example_proto)
    time = dic['time']
    data = dic['data']
    # Normalize I guess
    return time

In [16]:
lead_time=72

In [17]:
filenames = [f'{savedir}month_{str(m).zfill(2)}.tfrecord' for m in range(1, 3)]

In [78]:
raw_ds = tf.data.TFRecordDataset(filenames)

In [79]:
parsed_ds = raw_ds.map(decode_time)

In [80]:
pp = [p for p in parsed_ds]

In [81]:
len(pp)

1416

In [82]:
parsed_ds._

TypeError: 'MapDataset' object is not subscriptable

In [35]:
batch = parsed_ds

In [36]:
batch2 = parsed_ds.skip(lead_time)

In [26]:
bb = [b for b in batch]

In [27]:
len(bb)

44

In [37]:
combined = tf.data.Dataset.zip((batch, batch2))

In [47]:
cc = combined.batch(32)

In [85]:
ccc = cc.__iter__()

In [86]:
next(ccc)

(<tf.Tensor: shape=(32,), dtype=string, numpy=
 array([b'2015-01-01T00:00:00.000000000', b'2015-01-01T01:00:00.000000000',
        b'2015-01-01T02:00:00.000000000', b'2015-01-01T03:00:00.000000000',
        b'2015-01-01T04:00:00.000000000', b'2015-01-01T05:00:00.000000000',
        b'2015-01-01T06:00:00.000000000', b'2015-01-01T07:00:00.000000000',
        b'2015-01-01T08:00:00.000000000', b'2015-01-01T09:00:00.000000000',
        b'2015-01-01T10:00:00.000000000', b'2015-01-01T11:00:00.000000000',
        b'2015-01-01T12:00:00.000000000', b'2015-01-01T13:00:00.000000000',
        b'2015-01-01T14:00:00.000000000', b'2015-01-01T15:00:00.000000000',
        b'2015-01-01T16:00:00.000000000', b'2015-01-01T17:00:00.000000000',
        b'2015-01-01T18:00:00.000000000', b'2015-01-01T19:00:00.000000000',
        b'2015-01-01T20:00:00.000000000', b'2015-01-01T21:00:00.000000000',
        b'2015-01-01T22:00:00.000000000', b'2015-01-01T23:00:00.000000000',
        b'2015-01-02T00:00:00.000000000',

In [48]:
bb = [b for b in cc]

In [49]:
len(bb)

42

In [50]:
bb[0], bb[-1]

((<tf.Tensor: shape=(32,), dtype=string, numpy=
  array([b'2015-01-01T00:00:00.000000000', b'2015-01-01T01:00:00.000000000',
         b'2015-01-01T02:00:00.000000000', b'2015-01-01T03:00:00.000000000',
         b'2015-01-01T04:00:00.000000000', b'2015-01-01T05:00:00.000000000',
         b'2015-01-01T06:00:00.000000000', b'2015-01-01T07:00:00.000000000',
         b'2015-01-01T08:00:00.000000000', b'2015-01-01T09:00:00.000000000',
         b'2015-01-01T10:00:00.000000000', b'2015-01-01T11:00:00.000000000',
         b'2015-01-01T12:00:00.000000000', b'2015-01-01T13:00:00.000000000',
         b'2015-01-01T14:00:00.000000000', b'2015-01-01T15:00:00.000000000',
         b'2015-01-01T16:00:00.000000000', b'2015-01-01T17:00:00.000000000',
         b'2015-01-01T18:00:00.000000000', b'2015-01-01T19:00:00.000000000',
         b'2015-01-01T20:00:00.000000000', b'2015-01-01T21:00:00.000000000',
         b'2015-01-01T22:00:00.000000000', b'2015-01-01T23:00:00.000000000',
         b'2015-01-02T00:00:

In [399]:
for b, c in cc.take():
    print(b, c)

tf.Tensor(
[[b'2015-01-02T14:00:00.000000000' b'2015-01-02T15:00:00.000000000']
 [b'2015-01-04T14:00:00.000000000' b'2015-01-04T15:00:00.000000000']
 [b'2015-01-03T14:00:00.000000000' b'2015-01-03T15:00:00.000000000']
 [b'2015-01-09T04:00:00.000000000' b'2015-01-09T05:00:00.000000000']
 [b'2015-01-07T20:00:00.000000000' b'2015-01-07T21:00:00.000000000']
 [b'2015-01-09T06:00:00.000000000' b'2015-01-09T07:00:00.000000000']
 [b'2015-01-05T08:00:00.000000000' b'2015-01-05T09:00:00.000000000']
 [b'2015-01-04T06:00:00.000000000' b'2015-01-04T07:00:00.000000000']
 [b'2015-01-07T16:00:00.000000000' b'2015-01-07T17:00:00.000000000']
 [b'2015-01-02T12:00:00.000000000' b'2015-01-02T13:00:00.000000000']
 [b'2015-01-03T02:00:00.000000000' b'2015-01-03T03:00:00.000000000']
 [b'2015-01-04T22:00:00.000000000' b'2015-01-04T23:00:00.000000000']
 [b'2015-01-02T08:00:00.000000000' b'2015-01-02T09:00:00.000000000']
 [b'2015-01-08T02:00:00.000000000' b'2015-01-08T03:00:00.000000000']
 [b'2015-01-07T00:00:00

DataLossError: truncated record at 208950071

In [382]:
for b in batch.take(1):
    pass

In [383]:
b

<tf.Tensor: shape=(2,), dtype=string, numpy=
array([b'2015-01-01T00:00:00.000000000', b'2015-01-01T01:00:00.000000000'],
      dtype=object)>

In [384]:
X = next(iter(parsed_ds))

In [387]:
X.numpy()

b'2015-01-01T00:00:00.000000000'

## Now the full implementation

In [6]:
!rm /data/stephan/WeatherBench/tfrecords/ERA/*

In [7]:
ds_train = ds.sel(time=slice('2015', '2015'))

In [8]:
dg_train = DataGenerator(
    ds_train, args['var_dict'], args['lead_time'], batch_size=args['batch_size'], output_vars=args['output_vars'],
    data_subsample=1, norm_subsample=10, nt_in=1, dt_in=args['dt_in'], load=True, normalize=False, shuffle=False
)

In [9]:
savedir = '/data/stephan/WeatherBench/tfrecords/ERA/'
!mkdir -p $savedir

In [10]:
dg_train.to_tfrecord(savedir)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [12]:
dg_train.mean.to_netcdf(f'{savedir}mean.nc')
dg_train.std.to_netcdf(f'{savedir}std.nc')

In [480]:
fns = sorted(glob('/data/stephan/WeatherBench/tfrecords/ERA/2015-*.tfrecord'))

In [481]:
features = {
    'time': tf.io.FixedLenFeature([], tf.string),
    'data': tf.io.FixedLenFeature([], tf.string)
}

In [482]:
def _parse(example_proto):
    return tf.io.parse_single_example(example_proto, features)

In [596]:
def decode(example_proto):
    dic = _parse(example_proto)
    time = dic['time']
    data = dic['data']
#     return tf.io.parse_tensor(data, np.float32)
    return time

In [214]:
len(fns)

12

In [485]:
nt_in = 3
dt_in = 2

In [486]:
lead_time=72

In [487]:
fns 

['/data/stephan/WeatherBench/tfrecords/ERA/2015-01.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-02.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-03.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-04.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-05.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-06.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-07.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-08.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-09.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-10.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-11.tfrecord',
 '/data/stephan/WeatherBench/tfrecords/ERA/2015-12.tfrecord']

In [488]:
files_per_dataset = 5

In [489]:
fnss = [fns[i*files_per_dataset:i*files_per_dataset+files_per_dataset] for i in range(len(fns)//files_per_dataset)]

In [490]:
fnss

[['/data/stephan/WeatherBench/tfrecords/ERA/2015-01.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-02.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-03.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-04.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-05.tfrecord'],
 ['/data/stephan/WeatherBench/tfrecords/ERA/2015-06.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-07.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-08.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-09.tfrecord',
  '/data/stephan/WeatherBench/tfrecords/ERA/2015-10.tfrecord']]

In [597]:
nt_in = 3
dt_in = 2
nt_offset = (nt_in - 1) * dt_in

In [598]:
lead_time=72
cont_time=True

In [608]:
def fnc(fn):
    window_size = lead_time + nt_offset + 1
    d = tf.data.TFRecordDataset(fn)
    d = d.map(decode).window(window_size, shift=1, drop_remainder=True).flat_map(
        lambda window: window.batch(window_size))
    if cont_time:
        y_slice = slice(nt_offset+1, None)
    else:
        y_slice = -1
    if nt_in > 1:
        d = d.map(lambda window: ([window[n*dt_in] for n in range(nt_in)], window[y_slice]))
    else:
        d = d.map(lambda window: (window[0], window[y_slice]))
    return d

In [609]:
dataset = tf.data.Dataset.from_tensor_slices(fnss)

In [610]:
dataset = dataset.interleave(fnc, cycle_length=4, block_length=1, num_parallel_calls=1)

In [611]:
dataset = dataset.repeat().batch(2)
# dataset = dataset.shuffle(1000).repeat().batch(10)

In [612]:
for e in dataset.take(1):
    print(e)
    print('\n')

(<tf.Tensor: shape=(2, 3), dtype=string, numpy=
array([[b'2015-01-01T00:00:00.000000000',
        b'2015-01-01T02:00:00.000000000',
        b'2015-01-01T04:00:00.000000000'],
       [b'2015-06-01T00:00:00.000000000',
        b'2015-06-01T02:00:00.000000000',
        b'2015-06-01T04:00:00.000000000']], dtype=object)>, <tf.Tensor: shape=(2, 72), dtype=string, numpy=
array([[b'2015-01-01T05:00:00.000000000',
        b'2015-01-01T06:00:00.000000000',
        b'2015-01-01T07:00:00.000000000',
        b'2015-01-01T08:00:00.000000000',
        b'2015-01-01T09:00:00.000000000',
        b'2015-01-01T10:00:00.000000000',
        b'2015-01-01T11:00:00.000000000',
        b'2015-01-01T12:00:00.000000000',
        b'2015-01-01T13:00:00.000000000',
        b'2015-01-01T14:00:00.000000000',
        b'2015-01-01T15:00:00.000000000',
        b'2015-01-01T16:00:00.000000000',
        b'2015-01-01T17:00:00.000000000',
        b'2015-01-01T18:00:00.000000000',
        b'2015-01-01T19:00:00.000000000',
   

In [613]:
dataset = dataset.as_numpy_iterator()

In [614]:
X, y = next(dataset)

In [627]:
nt = np.random.randint(1, lead_time + 1, 2)

In [628]:
y.shape

(2, 72)

In [629]:
nt

array([23, 24])

In [630]:
y[np.arange(len(y)), nt]

array([b'2015-01-02T04:00:00.000000000', b'2015-06-02T05:00:00.000000000'],
      dtype=object)

In [623]:
y[]

IndexError: index 65 is out of bounds for axis 0 with size 2

In [620]:
X

array([[b'2015-01-01T00:00:00.000000000',
        b'2015-01-01T02:00:00.000000000',
        b'2015-01-01T04:00:00.000000000'],
       [b'2015-06-01T00:00:00.000000000',
        b'2015-06-01T02:00:00.000000000',
        b'2015-06-01T04:00:00.000000000']], dtype=object)

In [619]:
y[:, nt]

array([[b'2015-01-03T22:00:00.000000000',
        b'2015-01-02T14:00:00.000000000'],
       [b'2015-06-03T22:00:00.000000000',
        b'2015-06-02T14:00:00.000000000']], dtype=object)

In [563]:
ds_train = ds.sel(time=slice('2015', '2015'))

In [569]:
dg_train = DataGenerator(
    ds_train, args['var_dict'], args['lead_time'], batch_size=args['batch_size'], output_vars=args['output_vars'],
    data_subsample=1, norm_subsample=10, nt_in=3, dt_in=2, load=True, normalize=False, shuffle=False,
    mean=xr.open_dataarray(f'{savedir}mean.nc'), std=xr.open_dataarray(f'{savedir}std.nc')
)

In [570]:
Xref, yref = dg_train[0]

In [571]:
Xref.shape, yref.shape

((32, 32, 64, 114), (32, 32, 64, 2))

In [586]:
Xref[:5, 0, 0, np.array([0, 38, 38*2])]

array([[201331.  , 201344.16, 201337.6 ],
       [201314.56, 201314.56, 201331.  ],
       [201344.16, 201337.6 , 201373.75],
       [201314.56, 201331.  , 201383.61],
       [201337.6 , 201373.75, 201423.06]], dtype=float32)

In [573]:
yref[:5, 0, 0, 0]

array([50566.88 , 50570.164, 50583.316, 50599.754, 50593.176],
      dtype=float32)

### Ok here comes the hard part

Reading in the data in exactly the same way

In [574]:
tfrecord_files = '/data/stephan/WeatherBench/tfrecords/ERA/2015*'

In [671]:
dg_tfr = DataGenerator(
    ds_train, args['var_dict'], args['lead_time'], batch_size=args['batch_size'], output_vars=args['output_vars'],
    data_subsample=1, norm_subsample=10, nt_in=3, dt_in=2, load=False, normalize=False, shuffle=False,
    mean=xr.open_dataarray(f'{savedir}mean.nc'), std=xr.open_dataarray(f'{savedir}std.nc'),
    tfrecord_files=tfrecord_files, tfr_return_time=False, tfr_cycle_length=6, tfr_num_parallel_calls=6,
    tfr_buffer_size=100, tfr_fpds=2, cont_time=True
)

In [672]:
%lprun -f dg_tfr._get_tfrecord_item dg_tfr[0]

Timer unit: 1e-06 s

Total time: 6.94695 s
File: /home/rasp/repositories/myWeatherBench/devlog/src/data_generator.py
Function: _get_tfrecord_item at line 292

Line #      Hits         Time  Per Hit   % Time  Line Contents
   292                                               def _get_tfrecord_item(self, i):
   293         1    6853931.0 6853931.0     98.7          X, y = next(self.tfr_dataset)
   294                                           
   295         1          5.0      5.0      0.0          if self.cont_time:
   296                                                       # y will have lead_time as second dimension
   297         1          2.0      2.0      0.0              if not self.fixed_time:
   298         1          1.0      1.0      0.0                  if self.min_lead_time is None:
   299         1          1.0      1.0      0.0                      min_nt = 0
   300                                                           else:
   301                                   

In [None]:
%%time
Xtf, ytf = next(iter(dg_tfr))

In [583]:
Xtf.shape, ytf.shape

((32, 32, 64, 114), (32, 32, 64, 2))

In [588]:
Xtf[:5, 0, 0, np.array([0, 38, 38*2])]

array([[201331.  , 201344.16, 201337.6 ],
       [201314.56, 201314.56, 201331.  ],
       [201344.16, 201337.6 , 201373.75],
       [201314.56, 201331.  , 201383.61],
       [201337.6 , 201373.75, 201423.06]], dtype=float32)

In [585]:
ytf[:5, 0, 0, 0]

array([50566.88 , 50570.164, 50583.316, 50599.754, 50593.176],
      dtype=float32)

In [580]:
%%timeit 
next(iter(dg_tfr))

334 ms ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Now try training a model

In [37]:
model_ref = build_resnet(
    [64, 64, 2], [7, 3, 3], input_shape=dg_tfr.shape,
    bn_position=args['bn_position'], use_bias=args['use_bias'], l2=args['l2'], skip=args['skip'],
    dropout=args['dropout'], activation=args['activation']
)

In [38]:
model_ref.compile('adam', 'mse')

In [39]:
model_ref.fit(dg_tfr)

  ...
    to  
  ['...']
Train for 4656 steps


<tensorflow.python.keras.callbacks.History at 0x7f17da719bd0>

In [239]:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(20))
dataset = dataset.shuffle(buffer_size=20)

In [240]:
dataset = dataset.batch(10)

In [241]:
bb = [b for b in dataset]

In [242]:
bb

[<tf.Tensor: shape=(10,), dtype=int64, numpy=array([17, 10,  6,  9,  7,  5, 18, 11,  0,  4])>,
 <tf.Tensor: shape=(10,), dtype=int64, numpy=array([12,  3, 15, 14, 16, 13, 19,  8,  2,  1])>]