## 存储



In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='2'
import tensorflow as tf
from tensorflow.python_io import TFRecordWriter
from tqdm import tqdm
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

output_dir = './data'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)


In [2]:
def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def feature_to_example(img, label):
    img = img.tostring() # numpy to string
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'img': bytes_feature(img),
                'label': int_feature(label)
            }
        )
    )

def create_tfrecords(output_dir, split='train'):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() # return numpy arrays
    x, y = (x_train, y_train) if split=='train' else (x_test, y_test)
    
    output_path = os.path.join(output_dir, "mnist_{}.tfrecord".format(split))
    with TFRecordWriter(output_path) as writer:
        for (img, label) in tqdm(zip(x, y)):
            mnist_example = feature_to_example(img, label)
            writer.write(mnist_example.SerializeToString())
            
create_tfrecords(output_dir, split='train')
create_tfrecords(output_dir, split='val')

60000it [00:02, 21008.05it/s]
10000it [00:00, 20997.06it/s]


## 读取

In [3]:
def parse_example(example_proto):
    feature_dict = {
        'img': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    
    features = tf.parse_single_example(example_proto, features=feature_dict)
    
    # origional format is uint8
    img = tf.decode_raw(features['img'], out_type=tf.uint8)
    img = tf.cast(img, tf.float32)
    
    label = features['label']
    label = tf.cast(label, tf.int32)
    
    return (img, label)

def load_data(data_dir, split='train', batch_size=64, epochs=1):
    data_path = os.path.join(data_dir, "mnist_{}.tfrecord".format(split))
    dataset = tf.data.TFRecordDataset(data_path)
    
    if split == 'train':
        dataset = dataset.shuffle(60000, reshuffle_each_iteration=True)
    else:
        dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
        
    dataset.repeat(epochs)
    
    dataset = dataset.map(parse_example)
    
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    
    #dataset = dataset.prefetch(1) # not allowed in eager mode

    return dataset

在使用TensorFlow过程中可能会遇到的其他数据格式，建议参考PSCAL数据集。该数据集中包含了几乎所有常见的数据形式  
[https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_pascal_tf_record.py](https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_pascal_tf_record.py)

```python
example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example
```
其中`dataset_util`相关代买如下
```python
def int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

```


## Eager Execution

eager 模式中不需要iterator

In [4]:
train_ds = load_data(output_dir, split='train')
test_ds = load_data(output_dir, split='val')

#iterator = dataset.make_one_shot_iterator()

In [5]:
class MNISTModel(tf.keras.Model):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self._input_shape = [-1, 28, 28, 1]
        self.conv1 = tf.layers.Conv2D(32, 5,
                                  padding='same',
                                  activation=tf.nn.relu)
        self.max_pool2d = tf.layers.MaxPooling2D((2, 2), (2, 2), padding='same')
        self.conv2 = tf.layers.Conv2D(64, 5,
                                      padding='same',
                                      activation=tf.nn.relu)
        self.fc1 = tf.layers.Dense(750, activation=tf.nn.relu)
        self.dropout = tf.layers.Dropout(0.5)
        self.fc2 = tf.layers.Dense(10)
    
    def call(self, x):
        x = tf.reshape(x, self._input_shape)
        x = self.max_pool2d(self.conv1(x))
        x = self.max_pool2d(self.conv2(x))
        x = tf.layers.flatten(x)
        x = self.dropout(self.fc1(x))
        return self.fc2(x)

In [6]:
def loss_fn(model, x, y):
    return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=model(x), labels=y))

def get_accuracy(model, x, y_true):
    logits = model(x)
    prediction = tf.argmax(logits, 1)
    equality = tf.equal(prediction, tf.cast(y_true, tf.int64))
    accuracy = tf.reduce_mean(tf.cast(equality, tf.float32))
    return accuracy

In [7]:
model = MNISTModel()
optimizer = tf.train.AdamOptimizer()
epochs = 1000
for (batch, (images, labels)) in enumerate(train_ds):
    with tfe.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())
    if batch % 10 == 0:
        acc = get_accuracy(model, images, labels).numpy()
        print("Iteration {}, loss: {:.3f}, train accuracy: {:.2f}%".format(batch, loss_fn(model, images, labels).numpy(), acc*100))
    if batch > epochs:
        break 

Iteration 0, loss: 83.327, train accuracy: 17.19%
Iteration 10, loss: 1.934, train accuracy: 53.12%
Iteration 20, loss: 0.938, train accuracy: 75.00%
Iteration 30, loss: 0.478, train accuracy: 82.81%
Iteration 40, loss: 0.215, train accuracy: 93.75%
Iteration 50, loss: 0.275, train accuracy: 93.75%
Iteration 60, loss: 0.274, train accuracy: 89.06%
Iteration 70, loss: 0.441, train accuracy: 84.38%
Iteration 80, loss: 0.305, train accuracy: 92.19%
Iteration 90, loss: 0.067, train accuracy: 98.44%
Iteration 100, loss: 0.187, train accuracy: 93.75%
Iteration 110, loss: 0.139, train accuracy: 98.44%
Iteration 120, loss: 0.159, train accuracy: 96.88%
Iteration 130, loss: 0.183, train accuracy: 93.75%
Iteration 140, loss: 0.193, train accuracy: 96.88%
Iteration 150, loss: 0.103, train accuracy: 98.44%
Iteration 160, loss: 0.111, train accuracy: 95.31%
Iteration 170, loss: 0.051, train accuracy: 98.44%
Iteration 180, loss: 0.181, train accuracy: 93.75%
Iteration 190, loss: 0.158, train accurac

In [8]:
avg_acc = 0
test_epochs = 20
for (batch, (images, labels)) in enumerate(test_ds):
    avg_acc += get_accuracy(model, images, labels).numpy()
    if batch % 100 == 0 and batch != 0:
        print("Iteration:{}, Average test accuracy: {:.2f}%".format(batch, (avg_acc/batch)*100))
print("Final test accuracy: {:.2f}%".format(avg_acc/batch * 100))


Iteration:100, Average test accuracy: 98.92%
Final test accuracy: 98.56%
