# Is it worth the effort to use TFRecords?

What are the potential benefits? 

- Speed
- No need for pre-stacking vectors --> More flexible during training

In [2]:
import xarray as xr
import numpy as np
import tensorflow as tf

## Load some example data

In [1]:
DATADIR = '/local/S.Rasp/sp32fbp_andkua/'

In [3]:
aqua = xr.open_dataset(DATADIR + 'AndKua_aqua_SPCAM3.0_sp_fbp32.cam2.h1.0000-01-02-00000.nc', decode_times=False)

In [4]:
aqua

<xarray.Dataset>
Dimensions:       (crm_x: 32, crm_y: 1, crm_z: 28, ilev: 31, isccp_prs: 7, isccp_prstau: 49, isccp_tau: 7, lat: 64, lev: 30, lon: 128, tbnd: 2, time: 48)
Coordinates:
  * lat           (lat) float64 -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86
  * lon           (lon) float64 0.0 2.812 5.625 8.438 ... 351.6 354.4 357.2
  * crm_x         (crm_x) float64 0.0 4.0 8.0 12.0 ... 112.0 116.0 120.0 124.0
  * crm_y         (crm_y) float64 0.0
  * crm_z         (crm_z) float64 992.6 976.3 957.5 936.2 ... 38.27 24.61 14.36
  * lev           (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev          (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * isccp_prs     (isccp_prs) float64 90.0 245.0 375.0 500.0 620.0 740.0 900.0
  * isccp_tau     (isccp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * isccp_prstau  (isccp_prstau) float64 90.0 90.0 90.0 ... 900.0 900.0 900.2
  * time          (time) float64 1.0 1.021 1.042 1.062 ... 1.938 1.958 1.979


## Write to TFRecords file

In [189]:
tfr_fn = '/local/S.Rasp/tmp/test.tfrecords'

In [190]:
!rm /local/S.Rasp/tmp/test.tfrecords

In [191]:
writer = tf.python_io.TFRecordWriter(tfr_fn)

In [192]:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

In [193]:
vars = [
    'TAP', 'QAP', 'VAP', 'PS', 'SOLIN', 'SHFLX', 'LHFLX', 
    'TPHYSTND', 'PHQ', 'FSNT', 'FSNS', 'FLNT', 'FLNS', 'PRECT',
    'time', 'lat', 'lon'
]

In [194]:
from tqdm import tqdm_notebook as tqdm

In [195]:
import pdb

In [196]:
def write_sample(it, ilat, ilon):
    column = aqua.isel(time=it, lat=ilat, lon=ilon)
    feature = {v:_bytes_feature(tf.compat.as_bytes(column[v].values.tostring())) for v in vars}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

In [199]:
%load_ext line_profiler

In [203]:
%%time
aqua.load()

CPU times: user 348 ms, sys: 560 ms, total: 908 ms
Wall time: 1.27 s


<xarray.Dataset>
Dimensions:       (crm_x: 32, crm_y: 1, crm_z: 28, ilev: 31, isccp_prs: 7, isccp_prstau: 49, isccp_tau: 7, lat: 64, lev: 30, lon: 128, tbnd: 2, time: 48)
Coordinates:
  * lat           (lat) float64 -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86
  * lon           (lon) float64 0.0 2.812 5.625 8.438 ... 351.6 354.4 357.2
  * crm_x         (crm_x) float64 0.0 4.0 8.0 12.0 ... 112.0 116.0 120.0 124.0
  * crm_y         (crm_y) float64 0.0
  * crm_z         (crm_z) float64 992.6 976.3 957.5 936.2 ... 38.27 24.61 14.36
  * lev           (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev          (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * isccp_prs     (isccp_prs) float64 90.0 245.0 375.0 500.0 620.0 740.0 900.0
  * isccp_tau     (isccp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * isccp_prstau  (isccp_prstau) float64 90.0 90.0 90.0 ... 900.0 900.0 900.2
  * time          (time) float64 1.0 1.021 1.042 1.062 ... 1.938 1.958 1.979


In [208]:
def write_all():
    for it, t in tqdm(enumerate(aqua.time[:1])):
        for ilat, lat in enumerate(aqua.lat):
            for ilon, lon in enumerate(aqua.lon):
                #print(it, ilat, ilon)
                write_sample(it, ilat, ilon)

In [209]:
%lprun -f write_sample write_all()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [154]:
writer.close()

In [155]:
!ls -lh $tfr_fn

-rw-r--r-- 1 S.Rasp ls-craig 7.3M Jan 22 16:51 /local/S.Rasp/tmp/test.tfrecords


In [156]:
feature.keys()

dict_keys(['TAP', 'QAP', 'VAP', 'PS', 'SOLIN', 'SHFLX', 'LHFLX', 'TPHYSTND', 'PHQ', 'FSNT', 'FSNS', 'FLNT', 'FLNS', 'PRECT', 'time', 'lat', 'lon'])

In [77]:
feature = {
    'TAP': _bytes_feature(tf.compat.as_bytes(one_sample.TAP.values.tostring())),
    'QAP': _bytes_feature(tf.compat.as_bytes(one_sample.QAP.values.tostring()))
}

In [78]:
example = tf.train.Example(features=tf.train.Features(feature=feature))

In [79]:
example

features {
  feature {
    key: "QAP"
    value {
      bytes_list {
        value: "]\201\3775d\301\3755\370m\3525R!\3525t\263\3505$\327\3405\357~\3315.\204\3265\007U\3265Dr\3265\223\352\3315\026\321\3675>\213.6T\340\2226cv\0317(\335\2117\227\222D8\354\242\0329\226\"\3009!\201/:\027\022P:~\346]:\253T\246:\321\375\333:1\341\020;\023\223\';=\2338;\2712G;\202\274K;}\355[;"
      }
    }
  }
  feature {
    key: "TAP"
    value {
      bytes_list {
        value: "\327\310vC\225&hC\2264YC\203&nC0\324gCZ\177eC\2328eC\256%dC\304\376bCZ\321aC \230`C3\022^Ca\320[C\337=YC\274AXC\'\034ZCl\255_C\035\375hC}OrC\240\277yC\2241\177Cz\210\201C\343\355\202C+U\204C\007\024\205C\204\313\205CF\223\206C\364`\207Cs\030\210C\220\305\210C"
      }
    }
  }
}

In [80]:
for i in range(100): writer.write(example.SerializeToString())

In [81]:
writer.close()

## Read from TFRecords file

In [49]:
tf.enable_eager_execution()

In [50]:
tf.executing_eagerly() 

True

In [167]:
tfr_ds = tf.data.TFRecordDataset(tfr_fn)

In [168]:
tfr_ds.output_shapes

TensorShape([])

In [169]:
ds = tfr_ds.repeat(10)

In [170]:
def _read_from_tfrecord(example_proto):
    feature = {v: tf.FixedLenFeature([], tf.string) for v in vars}

    features = tf.parse_example([example_proto], features=feature)

    out = [tf.decode_raw(features[v], tf.float32) for v in vars]
    return out

In [171]:
ds = ds.map(_read_from_tfrecord)

In [172]:
ds.shuffle(1000)

<ShuffleDataset shapes: ((1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?), (1, ?)), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32)>

In [174]:
bs = ds.batch(10)

In [175]:
x = next(iter(bs))

In [179]:
len(x), len(vars)

(17, 17)

In [181]:
vars

['TAP',
 'QAP',
 'VAP',
 'PS',
 'SOLIN',
 'SHFLX',
 'LHFLX',
 'TPHYSTND',
 'PHQ',
 'FSNT',
 'FSNS',
 'FLNT',
 'FLNS',
 'PRECT',
 'time',
 'lat',
 'lon']

In [187]:
x[-0].numpy()

array([[[246.67896, 231.9186 , 217.06544, 238.33183, 231.94052,
         229.52925, 229.34657, 228.38527, 227.2268 , 226.18579,
         225.0098 , 222.70345, 220.65616, 217.71031, 216.96951,
         218.21031, 223.66469, 231.7498 , 240.48862, 248.43716,
         254.49712, 258.68607, 261.989  , 264.54556, 265.72986,
         267.48657, 269.2063 , 270.7589 , 272.1433 , 273.44547]],

       [[246.68661, 231.91966, 217.0796 , 238.3068 , 231.94002,
         229.52675, 229.33952, 228.3665 , 227.20094, 226.17729,
         224.99347, 222.68636, 220.66388, 217.73404, 217.03984,
         218.28249, 223.6974 , 231.72603, 240.45279, 248.42583,
         254.49748, 258.67657, 261.9907 , 264.51477, 265.71988,
         267.4587 , 269.19016, 270.7647 , 272.1693 , 273.461  ]],

       [[246.694  , 231.9225 , 217.09715, 238.2795 , 231.9397 ,
         229.52489, 229.33363, 228.34792, 227.17583, 226.16989,
         224.9778 , 222.67046, 220.6744 , 217.76143, 217.11226,
         218.35725, 223.7323 , 231

In [137]:
x.shape, y.shape

(TensorShape([Dimension(10), Dimension(1), Dimension(30)]),
 TensorShape([Dimension(10), Dimension(1), Dimension(30)]))

In [93]:
x.shape

TensorShape([Dimension(10), Dimension(1), Dimension(30)])

In [67]:
bs = tfr_ds.map(_read_from_tfrecord).repeat().batch(1)

In [69]:
bs

<BatchDataset shapes: ((?, 1, ?), (?, 1, ?)), types: (tf.float32, tf.float32)>

## Quick conclusion

So I guess it theoretically works. I am confident I would figure out the shuffling, dimensions. BUT: Converting one time slice takes around 100 seconds. The means converting an entire year would