In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import sklearn
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

2.3.1
sys.version_info(major=3, minor=6, micro=7, releaselevel='final', serial=0)
matplotlib 3.3.2
numpy 1.18.4
pandas 1.1.4
sklearn 0.23.2
tensorflow 2.3.1
tensorflow.keras 2.4.0


## Read data from CSV files

In [5]:
source_dir = "./generate_csv"
print(os.listdir(source_dir))

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

['test_00.csv', 'test_01.csv', 'test_02.csv', 'test_03.csv', 'test_04.csv', 'test_05.csv', 'test_06.csv', 'test_07.csv', 'test_08.csv', 'test_09.csv', 'train_00.csv', 'train_01.csv', 'train_02.csv', 'train_03.csv', 'train_04.csv', 'train_05.csv', 'train_06.csv', 'train_07.csv', 'train_08.csv', 'train_09.csv', 'train_10.csv', 'train_11.csv', 'train_12.csv', 'train_13.csv', 'train_14.csv', 'train_15.csv', 'train_16.csv', 'train_17.csv', 'train_18.csv', 'train_19.csv', 'valid_00.csv', 'valid_01.csv', 'valid_02.csv', 'valid_03.csv', 'valid_04.csv', 'valid_05.csv', 'valid_06.csv', 'valid_07.csv', 'valid_08.csv', 'valid_09.csv']


In [9]:
def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y

def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size = batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size = batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size = batch_size)

In [11]:
def serialize_examples(x, y):
    """converts x, y to tf.train.Example an serialize"""
    input_features = tf.train.FloatList(value = x)
    label = tf.train.FloatList(value = y)
    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(float_list=input_features),
            "label": tf.train.Feature(float_list=label)
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializePartialToString()

def csv_dataset_to_tfrecords(
    base_filename, dataset, n_shards, steps_per_shard, compression_type=None):
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_examples(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames

### Write tfrecords

In [14]:
n_shards = 20   # 分成几个    一共有多少个 batch   一个 shard 上有多少个 batch
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)

### Write zipped tfrecords 

In [15]:
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type = "GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type = "GZIP")
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type = "GZIP")

## Read tfrecords

In [16]:
# 要想解析一个 example, 需要定义一个每个 field 的字典
expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

# 要定义一个 map 函数
def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)
    return example["input_features"], example["label"]

# 读取文件
def tfrecords_reader_dataset(filenames, n_readers=5, batch_size=32, 
                             n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames, batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(10):
    print(x_batch)
    print(y_batch)

tf.Tensor(
[[-1.0775077  -0.4487407  -0.5680568  -0.14269263 -0.09666677  0.12326469
  -0.31448638 -0.4818959 ]
 [ 0.4240821   0.91296333 -0.04437482 -0.15297213 -0.24727628 -0.10539167
   0.86126745 -1.335779  ]
 [ 0.4240821   0.91296333 -0.04437482 -0.15297213 -0.24727628 -0.10539167
   0.86126745 -1.335779  ]], shape=(3, 8), dtype=float32)
tf.Tensor(
[[0.978]
 [3.955]
 [3.955]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[ 0.63636464 -1.0895426   0.09260903 -0.20538124  1.2025671  -0.03630123
  -0.6784102   0.18223535]
 [ 0.8015443   0.27216142 -0.11624393 -0.20231152 -0.5430516  -0.02103962
  -0.5897621  -0.08241846]
 [ 2.5150437   1.0731637   0.5574401  -0.17273512 -0.6129126  -0.01909157
  -0.5710993  -0.02749031]], shape=(3, 8), dtype=float32)
tf.Tensor(
[[2.429  ]
 [3.226  ]
 [5.00001]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[-1.0591781   1.3935647  -0.02633197 -0.1100676  -0.6138199  -0.09695935
   0.3247131  -0.03747724]
 [-1.1157656   0.99306357 -0.334192   -0.06535219 -0.3

In [17]:
# 生成训练集
batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_fielnames, batch_size = batch_size)