In [1]:
from lab12_util import *

# Binary (not Python-pickle) release of CIFAR-10 and where it is extracted.
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
DEST_DIRECTORY = 'dataset/cifar10'
# Extracted folder holding data_batch_1..5.bin (train) and test_batch.bin (test).
DATA_DIRECTORY = DEST_DIRECTORY + '/cifar-10-batches-bin'

# Geometry of the raw images and the size they are cropped to for the CNN.
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
IMAGE_DEPTH = 3
IMAGE_SIZE_CROPPED = 24

BATCH_SIZE = 128
NUM_CLASSES = 10

# Each fixed-length binary record is LABEL_BYTES of label followed by the raw
# image bytes; derive the image size from the geometry so the two stay in sync.
LABEL_BYTES = 1
IMAGE_BYTES = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# Download and extract the archive into DEST_DIRECTORY
# (presumably a no-op when it is already present - see lab12_util).
maybe_download_and_extract(DEST_DIRECTORY, DATA_URL)

>> Done


In [2]:
from tensorflow.contrib.data import FixedLengthRecordDataset, Iterator

def cifar10_record_distort_parser(record):
    ''' Parse a CIFAR-10 binary record into its label and a randomly
    cropped and distorted image (used by the training input pipeline).
    -----
    Args:
        record:
            a string tensor holding one raw record: LABEL_BYTES of label
            followed by IMAGE_BYTES of image stored depth-major as
            [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH].
    Returns:
        label:
            int32 tensor of shape [LABEL_BYTES] with the class label.
        image:
            float32 tensor of shape
            [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, IMAGE_DEPTH]: a random
            crop, randomly flipped and brightness/contrast-jittered, then
            standardized per image.
    '''
    # Printed once, when dataset.map builds the graph for this function.
    print("distort parser")
    height = IMAGE_SIZE_CROPPED
    width = IMAGE_SIZE_CROPPED

    record_uint8 = tf.decode_raw(record, tf.uint8)
    # The first LABEL_BYTES bytes are the class label.
    label = tf.cast(
        tf.strided_slice(record_uint8, [0], [LABEL_BYTES]), tf.int32)

    # The remaining bytes are the image in (channel, row, column) order.
    depth_major = tf.reshape(
        tf.strided_slice(record_uint8, [LABEL_BYTES],
                         [LABEL_BYTES + IMAGE_BYTES]),
        [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])

    # Move channels last: [height, width, depth], the layout tf.image expects.
    image = tf.transpose(depth_major, [1, 2, 0])
    float_image = tf.cast(image, tf.float32)

    # Random data augmentation: crop, horizontal flip, brightness/contrast.
    distorted_image = tf.random_crop(float_image, [height, width, IMAGE_DEPTH])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
    distorted_image = tf.image.random_contrast(
        distorted_image, lower=0.2, upper=1.8)
    # Standardization: subtract the mean and divide by the (adjusted)
    # standard deviation of the image's pixels.
    distorted_image = tf.image.per_image_standardization(distorted_image)
    # Make the static shape explicit for downstream batching/reshaping.
    distorted_image.set_shape([height, width, IMAGE_DEPTH])

    return label, distorted_image

In [3]:
def cifar10_record_crop_parser(record):
    ''' Parse a CIFAR-10 binary record into its label and a deterministically
    cropped image (no random distortion; used by the evaluation pipeline).
    -----
    Args:
        record:
            a string tensor holding one raw record: LABEL_BYTES of label
            followed by IMAGE_BYTES of image stored depth-major as
            [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH].
    Returns:
        label:
            int32 tensor of shape [LABEL_BYTES] with the class label.
        image:
            float32 tensor of shape
            [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, IMAGE_DEPTH]: the central
            crop of the image, standardized per image.
    '''
    # Printed once, when dataset.map builds the graph for this function.
    print("crop parser")
    height = IMAGE_SIZE_CROPPED
    width = IMAGE_SIZE_CROPPED

    record_uint8 = tf.decode_raw(record, tf.uint8)
    # The first LABEL_BYTES bytes are the class label.
    label = tf.cast(
        tf.strided_slice(record_uint8, [0], [LABEL_BYTES]), tf.int32)
    # The remaining bytes are the image in (channel, row, column) order.
    depth_major = tf.reshape(
        tf.strided_slice(record_uint8, [LABEL_BYTES],
                         [LABEL_BYTES + IMAGE_BYTES]),
        [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])
    # Move channels last: [height, width, depth], the layout tf.image expects.
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    # Central crop to the same size the training parser produces, so the
    # model sees identically shaped inputs at train and test time.
    image = tf.image.resize_image_with_crop_or_pad(image, height, width)
    # Standardization: subtract the mean and divide by the (adjusted)
    # standard deviation of the image's pixels.
    image = tf.image.per_image_standardization(image)

    image.set_shape([height, width, IMAGE_DEPTH])
    return label, image

In [4]:
def _static_filenames(filenames):
    '''Return `filenames` as a list of str when its value is known while the
    graph is being built; return [] when it is not (e.g. a placeholder).'''
    if isinstance(filenames, (str, bytes)):
        values = [filenames]
    elif isinstance(filenames, tf.Tensor):
        value = tf.contrib.util.constant_value(filenames)
        if value is None:
            # Value only known at run time: nothing to check here.
            return []
        values = list(value.flat)  # numpy array (possibly 0-d) of paths
    else:
        values = list(filenames)
    return [v.decode('utf-8') if isinstance(v, bytes) else str(v)
            for v in values]


def cifar10_iterator(filenames, batch_size, cifar10_record_parser,
                     shuffle_buffer_size=None):
    ''' Create a dataset and return a tf.contrib.data.Iterator
    which provides a way to extract elements from this dataset.
    -----
    Args:
        filenames:
            the CIFAR-10 binary files to read: a string tensor, a list/tuple
            of paths, or a single path.
        batch_size:
            batch size.
        cifar10_record_parser:
            function mapping one raw record to a (label, image) pair.
        shuffle_buffer_size:
            if truthy, records are shuffled through a buffer of this many
            elements before parsing; None (the default) keeps file order.
    Returns:
        iterator:
            an initializable Iterator over the batched dataset; run
            `iterator.initializer` before each pass over the data.
        output_types:
            the output types of the created dataset.
        output_shapes:
            the output shapes of the created dataset.
    Raises:
        ValueError: if a filename known at graph-build time does not exist.
    '''
    # Check the files *this* dataset reads (not some unrelated global list).
    for f in _static_filenames(filenames):
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    record_bytes = LABEL_BYTES + IMAGE_BYTES
    dataset = FixedLengthRecordDataset(filenames, record_bytes)
    if shuffle_buffer_size:
        # Mix records so consecutive batches are not always in file order.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.map(cifar10_record_parser, num_threads=16)
    # The last batch is smaller than batch_size when the record count is not
    # a multiple of it; callers that reshape to a fixed batch size must stop
    # before requesting it.
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    return iterator, dataset.output_types, dataset.output_shapes

In [5]:
# Build the whole graph once: both input pipelines, a feedable iterator that
# switches between them, the CNN, the training op and a top-1 accuracy op.
tf.reset_default_graph()

training_files = [
    os.path.join(DATA_DIRECTORY, 'data_batch_%d.bin' % i) for i in range(1, 6)]
testing_files = [os.path.join(DATA_DIRECTORY, 'test_batch.bin')]

filenames_train = tf.constant(training_files)
filenames_test = tf.constant(testing_files)

# Training batches are randomly distorted; test batches are cropped only.
# The training iterator's types/shapes are reused for the shared iterator
# below, so both parsers must produce identical output types and shapes.
iterator_train, types, shapes = cifar10_iterator(filenames_train, BATCH_SIZE,
                                                 cifar10_record_distort_parser)
iterator_test, _, _ = cifar10_iterator(filenames_test, BATCH_SIZE,
                                       cifar10_record_crop_parser)


print("handle")
# use to handle training and testing
# Feeding `handle` with the string handle of iterator_train or iterator_test
# selects which dataset `iterator.get_next()` reads from.
iter_train_handle = iterator_train.string_handle()
iter_test_handle = iterator_test.string_handle()
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(handle, types, shapes)
labels_images_pairs = iterator.get_next()

print("CNN")
# CNN model
# NOTE(review): CNN_Model comes from lab12_util; the meaning of each argument
# is inferred from its name only - confirm there.
model = CNN_Model(
    batch_size=BATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_training_example=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN,
    num_epoch_per_decay=350.0,
    init_lr=0.1,
    moving_average_decay=0.9999)

# NOTE(review): only these reshapes are pinned to /gpu:0 (the model itself is
# built outside this scope); without a GPU this placement fails unless the
# session allows soft placement - confirm intent.
with tf.device('/gpu:0'):
    labels, images = labels_images_pairs
    # Fix the batch dimension to BATCH_SIZE; a smaller final batch would fail
    # here, so the run loops must only request full batches.
    labels = tf.reshape(labels, [BATCH_SIZE])
    images = tf.reshape(
      images, [BATCH_SIZE, IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, IMAGE_DEPTH])
    
with tf.variable_scope('model'):
    logits = model.inference(images)
# train
# NOTE(review): model.train presumably also advances global_step - confirm
# in lab12_util.
global_step = tf.contrib.framework.get_or_create_global_step()
total_loss = model.loss(logits, labels)
train_op = model.train(total_loss, global_step)
# test
# Boolean per example: is the true label the model's top-1 prediction?
top_k_op = tf.nn.in_top_k(logits, labels, 1)

distort parser
crop parser
handle
CNN


In [6]:
%%time
# TODO4:
# 1. train the CNN model 10 epochs
# 2. show the loss per epoch
# 3. get the accuracy of this 10-epoch model
# 4. measure the time using '%%time' instruction
# tips:
# use placeholder handle to determine if training or testing. 
NUM_EPOCH = 10
NUM_BATCH_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN // BATCH_SIZE
print("TRAIN")
# train
saver = tf.train.Saver()
with tf.Session() as sess:
    
    handle_train = sess.run(iter_train_handle)
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    loss = []iter_train_handle
    
    for i in range(NUM_EPOCH):
        
        _loss = []
        sess.run(iterator_train.initializer)
        for _ in range(NUM_BATCH_PER_EPOCH):
            l, _ = sess.run([total_loss, train_op], feed_dict={handle:handle_train})
            _loss.append(l)
                
        loss_this_epoch = np.sum(_loss)
        gs = global_step.eval()
        print('loss of epoch %d: %f' % (gs / NUM_BATCH_PER_EPOCH, loss_this_epoch))
        loss.append(loss_this_epoch)
    
    coord.request_stop()
    coord.join(threads)
    save_path = saver.save(sess ,"checkpoint/model.ckpt" )

variables_to_restore = model.ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    # Restore variables from disk.
    ckpt = tf.train.get_checkpoint_state('checkpoint')
    saver.restore(sess, ckpt.model_checkpoint_path)
    iterator_test_handle = sess.run(iterator_test.string_handle())
    sess.run(iterator_test.initializer)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_iter = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL // BATCH_SIZE
    total_sample_count = num_iter * BATCH_SIZE
    print(num_iter, total_sample_count)
    true_count = 0
    for _ in range(num_iter):
        predictions = sess.run([top_k_op], feed_dict={handle: iterator_test_handle})
        true_count += np.sum(predictions)
    print('Accurarcy: %d/%d = %f' % (true_count, total_sample_count,
                                     true_count / total_sample_count))
    coord.request_stop()
    coord.join(threads)

TRAIN
loss of epoch 1: 1519.314087
loss of epoch 2: 1190.297974
loss of epoch 3: 967.729614
loss of epoch 4: 810.423828
loss of epoch 5: 696.959595
loss of epoch 6: 616.739319
loss of epoch 7: 554.492676
loss of epoch 8: 509.921326
loss of epoch 9: 475.377686
loss of epoch 10: 449.348511
INFO:tensorflow:Restoring parameters from checkpoint\model.ckpt
78 9984
Accurarcy: 7561/9984 = 0.757312
Wall time: 1min 43s
