In [None]:
import os

# Path to one CIFAR-10 binary batch file. The historical location stays the
# default; set the CIFAR10_BATCH_FILE environment variable to point somewhere
# else instead of editing the notebook.
file = os.environ.get(
    'CIFAR10_BATCH_FILE',
    '/Users/liuyouru/Downloads/cifar-10-batches-py/cifar-10-batches-bin/data_batch_1.bin')

In [None]:
import os
import tensorflow as tf
import numpy as np

In [None]:
# Short alias for the TF-Slim layer library used to build the network.
# NOTE(review): tf.contrib is a TensorFlow 1.x namespace — confirm the runtime version.
slim = tf.contrib.slim

In [None]:
# Number of iterations executed by the training loop.
MAX_STEP = 50000
# Buffer size handed to dataset.shuffle().
BUFFER_SIZE = 256

# Spatial size every image is resized to before it is fed to the network.
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224

# Batch size handed to dataset.batch().
BATCH_SIZE = 8

# Step size of the gradient-descent optimizer.
LEARNING_RATE = 0.001

In [None]:
# Record layout read below: one label byte followed by a channel-major
# 3x32x32 uint8 image.
LABEL_BYTES = 1
RAW_DEPTH, RAW_HEIGHT, RAW_WIDTH = 3, 32, 32
RECORD_BYTES = LABEL_BYTES + RAW_DEPTH * RAW_HEIGHT * RAW_WIDTH

filename = [file]

dataset = tf.data.FixedLengthRecordDataset(filename, RECORD_BYTES)

def parse(bin_example):
    """Decode one fixed-length binary record into an (image, label) pair.

    Args:
        bin_example: Scalar string tensor holding one raw record of
            RECORD_BYTES bytes.

    Returns:
        A tuple `(x, y)` where `x` is the float image resized to
        IMAGE_HEIGHT x IMAGE_WIDTH with channels last, and `y` is the int32
        label taken from the first byte of the record.
    """
    decoded = tf.decode_raw(bin_example, out_type=tf.uint8)

    # The image bytes are stored channel-major (C, H, W); move channels last
    # because the image ops and the conv layers are fed (H, W, C).
    x = tf.reshape(decoded[LABEL_BYTES:], shape=(RAW_DEPTH, RAW_HEIGHT, RAW_WIDTH))
    x = tf.transpose(x, perm=[1, 2, 0])
    x = tf.image.resize_images(x, [IMAGE_HEIGHT, IMAGE_WIDTH])
    x = tf.to_float(x)

    y = decoded[0]
    y = tf.to_int32(y)

    return x, y

# Shuffle individual raw records *before* decoding and batching. Shuffling
# after batch() only permutes whole batches (the same examples always end up
# together) and makes the shuffle buffer hold BUFFER_SIZE fully resized float
# batches instead of BUFFER_SIZE small raw records.
dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.map(parse)
dataset = dataset.batch(BATCH_SIZE)
# -1 repeats the dataset indefinitely; the training loop decides when to stop.
dataset = dataset.repeat(-1)

iterator = dataset.make_one_shot_iterator()

next_batch = iterator.get_next()

In [None]:
# Split the iterator output into the image tensor and the label tensor.
x_op = next_batch[0]
y_op = next_batch[1]

In [None]:
def vgg_16(inputs,num_classes,keep_prob = 0.5,is_training = True):
    """Build a VGG-16 style classifier with TF-Slim.

    Five convolutional blocks (2, 2, 3, 3, 3 stacked 3x3 convolutions with
    64, 128, 256, 512, 512 filters, each block followed by 2x2 max pooling)
    feed two 4096-unit fully connected layers with dropout and a final linear
    layer.

    Args:
        inputs: Image batch tensor fed to the first convolution.
        num_classes: Number of output units of the final `fc8` layer.
        keep_prob: Probability of keeping a unit in the two dropout layers.
        is_training: Forwarded to `slim.dropout`. The default (True) keeps
            dropout active, which matches the previous behaviour; pass False
            when building an evaluation/inference graph so dropout is disabled.

    Returns:
        A tuple `(logits, end_point)`: the un-activated `fc8` output and a dict
        mapping 'block1'..'block5', 'fc6', 'fc7' and 'fc8' to the corresponding
        activations ('fc6' and 'fc7' are recorded before dropout is applied).
    """
    end_point = {}

    # (number of stacked conv layers, number of filters) for block1..block5.
    conv_blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]

    with slim.arg_scope([slim.conv2d,slim.fully_connected],activation_fn=tf.nn.relu):

        net = inputs
        for index, (num_layers, depth) in enumerate(conv_blocks, start=1):
            # Scope names stay 'conv1'..'conv5' / 'pool1'..'pool5' so variable
            # names are unchanged.
            net = slim.repeat(net,num_layers,slim.conv2d,depth,[3,3],scope='conv%d' % index)
            net = slim.max_pool2d(net,[2,2],scope='pool%d' % index)
            end_point['block%d' % index] = net

        net = slim.flatten(net,scope='flatten')

        net = slim.fully_connected(net,4096,scope='fc6')
        end_point['fc6'] = net
        net = slim.dropout(net,keep_prob=keep_prob,is_training=is_training,scope='fc6_drop')

        net = slim.fully_connected(net,4096,scope='fc7')
        end_point['fc7'] = net
        net = slim.dropout(net,keep_prob=keep_prob,is_training=is_training,scope='fc7_drop')

        # Final layer emits raw logits: no activation, softmax is applied by the loss.
        net = slim.fully_connected(net,num_classes,activation_fn=None, scope='fc8')

        end_point['fc8'] = net

        return net,end_point

In [None]:
# Build the network on the input batch with 10 output classes.
# NOTE: the spelling `logtis` is kept because the later cells reference it.
logtis, end_points = vgg_16(inputs=x_op, num_classes=10)

In [None]:
# Inspect the logits tensor returned by vgg_16.
logtis

In [None]:
# Inspect the dict of intermediate activations returned by vgg_16.
end_points

In [None]:
# The labels are cast once and shared by the loss and the accuracy computation.
labels_int64 = tf.to_int64(y_op)

# Per-example cross entropy on the raw logits, averaged into a scalar loss.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logtis,
    labels=labels_int64,
    name='cross_entropy')
loss = tf.reduce_mean(losses)
loss_summary = tf.summary.scalar('LOSS', loss)

# Fraction of examples in the batch whose arg-max prediction equals the label.
predicted_class = tf.argmax(logtis, axis=1)
correct_pred = tf.equal(predicted_class, labels_int64)
accuracy_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
accuracy_summary = tf.summary.scalar('ACCURACY', accuracy_op)

In [None]:
# Plain gradient descent with a fixed step size; running `train_op` applies
# one update that minimises `loss`.
optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
train_op = optimizer.minimize(loss)

In [None]:
# Merge every summary registered in the default SUMMARIES collection.
summary_op = tf.summary.merge_all()

# Objects the training session needs: checkpointing and variable initialisation.
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

In [None]:
# Output locations for TensorBoard event files and model checkpoints.
LOG_DIR = '/Users/liuyouru/Downloads/cifar-10-batches-py/logs'
CHECKPOINT_PATH = '/Users/liuyouru/Downloads/cifar-10-batches-py/logs/model.ckpt'

# How often (in steps) to write summaries / print metrics and to checkpoint.
LOG_EVERY = 10
SAVE_EVERY = 1000

with tf.Session() as sess:

    summary_writer = tf.summary.FileWriter(LOG_DIR, tf.get_default_graph())

    # try/finally guarantees the event file is flushed and closed even when
    # training is interrupted or raises.
    try:
        sess.run(init_op)

        for step in range(MAX_STEP):
            should_log = step % LOG_EVERY == 0

            # Fetch the summary in the SAME run call as the training step.
            # A separate sess.run(summary_op) would pull another batch from
            # the input iterator, so the logged values would describe a
            # different batch than the printed ones and data would be skipped.
            fetches = [train_op, loss, accuracy_op]
            if should_log:
                fetches.append(summary_op)

            results = sess.run(fetches)
            cur_loss, cur_accuracy = results[1], results[2]

            if should_log:
                summary_writer.add_summary(results[3], step)
                print('step = ',step,'loss = ',cur_loss,'accuracy',cur_accuracy)

            if step % SAVE_EVERY == 0:
                saver.save(sess, CHECKPOINT_PATH)

        # Persist the final weights too: the in-loop checkpoint only fires on
        # multiples of SAVE_EVERY, so the last steps would otherwise be lost.
        if MAX_STEP > 0 and (MAX_STEP - 1) % SAVE_EVERY != 0:
            saver.save(sess, CHECKPOINT_PATH)
    finally:
        summary_writer.close()
        