# tfrecord的学习使用

In [1]:
import matplotlib as mlt
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import pandas as pd

import sys
import os
import time

import tensorflow as tf
from tensorflow import keras
print(tf.__version__)

2.0.0


tfrecord的文件格式
tfrecord中存储的是:   
-> tf.train.Example   
  -> tf.train.Features -> {"key": tf.train.feature}    
    -> tf.train.feature -> tf.train.ByteList(存储文本)/tf.train.FloatList(存储浮点数)/tf.train.Int64List(存储整型)

In [3]:
favorite_books = [name.encode("utf-8") for name in 
                  ["deep learning", "machine learning"]]
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value=[1.2, 9.4, 42.7, 16.8])
print(hours_floatlist)

age_list = tf.train.Int64List(value=[16])
print(age_list)

features = tf.train.Features(feature = {
    "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
    "hour": tf.train.Feature(float_list=hours_floatlist),
    "age": tf.train.Feature(int64_list=age_list)
})
print(features)

value: "deep learning"
value: "machine learning"

value: 1.2000000476837158
value: 9.399999618530273
value: 42.70000076293945
value: 16.799999237060547

value: 16

feature {
  key: "age"
  value {
    int64_list {
      value: 16
    }
  }
}
feature {
  key: "favorite_books"
  value {
    bytes_list {
      value: "deep learning"
      value: "machine learning"
    }
  }
}
feature {
  key: "hour"
  value {
    float_list {
      value: 1.2000000476837158
      value: 9.399999618530273
      value: 42.70000076293945
      value: 16.799999237060547
    }
  }
}



In [4]:
example = tf.train.Example(features=features)
print(example)

serialized_example = example.SerializeToString()
print(serialized_example)

features {
  feature {
    key: "age"
    value {
      int64_list {
        value: 16
      }
    }
  }
  feature {
    key: "favorite_books"
    value {
      bytes_list {
        value: "deep learning"
        value: "machine learning"
      }
    }
  }
  feature {
    key: "hour"
    value {
      float_list {
        value: 1.2000000476837158
        value: 9.399999618530273
        value: 42.70000076293945
        value: 16.799999237060547
      }
    }
  }
}

b'\nc\n5\n\x0efavorite_books\x12#\n!\n\rdeep learning\n\x10machine learning\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01\x10\n\x1c\n\x04hour\x12\x14\x12\x12\n\x10\x9a\x99\x99?ff\x16A\xcd\xcc*Bff\x86A'


In [5]:
output_dir = "tfrecord-basic"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
file_name = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, file_name)
with tf.io.TFRecordWriter(filename_fullpath) as writter:
    for i in range(3):
        writter.write(serialized_example)

利用tf.data api读取tfrecord文件

In [6]:
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialize_example_tensor in dataset:
    print(serialize_example_tensor)

tf.Tensor(b'\nc\n5\n\x0efavorite_books\x12#\n!\n\rdeep learning\n\x10machine learning\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01\x10\n\x1c\n\x04hour\x12\x14\x12\x12\n\x10\x9a\x99\x99?ff\x16A\xcd\xcc*Bff\x86A', shape=(), dtype=string)
tf.Tensor(b'\nc\n5\n\x0efavorite_books\x12#\n!\n\rdeep learning\n\x10machine learning\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01\x10\n\x1c\n\x04hour\x12\x14\x12\x12\n\x10\x9a\x99\x99?ff\x16A\xcd\xcc*Bff\x86A', shape=(), dtype=string)
tf.Tensor(b'\nc\n5\n\x0efavorite_books\x12#\n!\n\rdeep learning\n\x10machine learning\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01\x10\n\x1c\n\x04hour\x12\x14\x12\x12\n\x10\x9a\x99\x99?ff\x16A\xcd\xcc*Bff\x86A', shape=(), dtype=string)


从csv文件中生成tfrecord

In [7]:
source_dir = 'generate csv'
def get_data_by_prefix(source_dir, prefix):
    results = []
    for file in os.listdir(source_dir):
        if file.startswith(prefix):
            results.append(os.path.join(source_dir, file))
    return results

train_filenames = get_data_by_prefix(source_dir, prefix="train")
val_filenames = get_data_by_prefix(source_dir, prefix="val")
test_filenames = get_data_by_prefix(source_dir, prefix="test")

print("train:", train_filenames)

train: ['generate csv\\train_00.csv', 'generate csv\\train_01.csv', 'generate csv\\train_02.csv', 'generate csv\\train_03.csv', 'generate csv\\train_04.csv', 'generate csv\\train_05.csv', 'generate csv\\train_06.csv', 'generate csv\\train_07.csv', 'generate csv\\train_08.csv', 'generate csv\\train_09.csv', 'generate csv\\train_10.csv', 'generate csv\\train_11.csv', 'generate csv\\train_12.csv', 'generate csv\\train_13.csv', 'generate csv\\train_14.csv', 'generate csv\\train_15.csv', 'generate csv\\train_16.csv', 'generate csv\\train_17.csv', 'generate csv\\train_18.csv', 'generate csv\\train_19.csv']


In [10]:
 def parse_csv_line(line, n_fields=9):
        defs = [tf.constant(np.nan)] * n_fields
        parse_line = tf.io.decode_csv(line, record_defaults=defs)
        x = tf.stack(parse_line[:-1])
        y = tf.stack(parse_line[-1:])
        return x, y
    
def csv_reader_dataset(filenames, n_readers=5, batch_size=32, n_parse_thread=5,
                      shuffler_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
                                cycle_length=n_readers)
    dataset.shuffle(shuffler_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_thread)
    dataset = dataset.batch(batch_size)
    return dataset

train_set = csv_reader_dataset(train_filenames)
val_set = csv_reader_dataset(val_filenames)
test_set = csv_reader_dataset(test_filenames)

In [13]:
def serialize_example(x, y):
    """将从csv文件中读取的样本转化成tf.data.Example格式并且将其序列化"""
    input_features = tf.train.FloatList(value=x)
    label = tf.train.FloatList(value=y)
    features = tf.train.Features(feature={
        "input_features": tf.train.Feature(float_list=input_features),
        "label": tf.train.Feature(float_list=label) 
    })
    example = tf.train.Example(features=features)
    return example.SerializeToString()

def csv_dataset_to_tfrecord(base_filenames, dataset, n_shards, step_per_shard,
                           compression_type=None):
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_flienames = []
    for shard in range(n_shards):
        filename_fullpath = "{}{:05d}-of-{:05d}".format(base_filenames, shard, n_shards)
        with tf.io.TFRecordWriter(os.path.join(filename_fullpath), options) as writter:
            for x_batch, y_batch in dataset.take(step_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.writer(serialize_example(x_example, y_example))
        all_flienames.append(filename_fullpath)
    return all_flienames

ERROR! Session/line number was not unique in database. History logging moved to new session 121


In [None]:
n_shards = 20
batch_size = 32
train_step_per_shard = 11610 // batch_size // n_shards
val_step_per_shard = 3880 // batch_size // n_shards
test_step_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
if os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
val_basename = os.path.join(output_dir, "val")
test_basename = os.path.join(output_dir, "test")

train_tgrecords_filenames = csv_dataset_to_tfrecord(base_filenames=train_basename,
                                                    dataset=train_set,
                                                    n_shards=n_shards,
                                                    step_per_shard=train_step_per_shard,
                                                    compression_type=None)
val_tgrecords_filenames = csv_dataset_to_tfrecord(base_filenames=val_basename,
                                                    dataset=val_set,
                                                    n_shards=n_shards,
                                                    step_per_shard=val_step_per_shard,
                                                    compression_type=None)
test_tgrecords_filenames = csv_dataset_to_tfrecord(base_filenames=test_filenames,
                                                    dataset=test_set,
                                                    n_shards=n_shards,
                                                    step_per_shard=test_step_per_shard,
                                                    compression_type=None)

In [9]:
expret_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.serialize_single_sparse(serialize_example, expret_features)
    return example["input_features"], example["label"]

def tfrecord_reader_dataset(filenames, n_readers=5, batch_size=32, n_parse_thread=5,
                      shuffler_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
                                cycle_length=n_readers)
    dataset.shuffle(shuffler_buffer_size)
    dataset = dataset.map(parse_exampl, num_parallel_calls=n_parse_thread)
    dataset = dataset.batch(batch_size)
    return dataset

ERROR! Session/line number was not unique in database. History logging moved to new session 115


In [9]:
train_set = tfrecord-tfrecord_reader_dataset(train_tgrecords_filenames)
val_set = tfrecord_reader_dataset(val_tgrecords_filenames)
test_set = tfrecord_reader_dataset(test_tgrecords_filenames)
model = keras.models.Sequential([
    keras.layers.Input(shape=[8]),
    keras.layers.Dense(300, activation='relu'),
    keras.layers.Dense(300, activation="relu"),
    keras.layers.Dense(1)
])
model.compile("adam",loss='mse')
history = model.fit(train_set,
                epochs=10,
                steps_per_epoch = 11160 // 32,
                validation_steps = 3870 // 32,
                validation_data=val_set)

NameError: name 'train_files' is not defined

ERROR! Session/line number was not unique in database. History logging moved to new session 118
