In [1]:
import tensorflow as tf
import os
import time
import alexnet_model
 
imageWidth = 227
imageHeight = 227
imageDepth = 3
batch_size = 128
resize_min = 256
 
# Parse TFRECORD and distort the image for train
def _parse_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    
    xmin = tf.expand_dims(parsed_features["bbox_xmin"].values, 0)
    xmax = tf.expand_dims(parsed_features["bbox_xmax"].values, 0)
    ymin = tf.expand_dims(parsed_features["bbox_ymin"].values, 0)
    ymax = tf.expand_dims(parsed_features["bbox_ymax"].values, 0)
    
    bbox = tf.concat(axis=0, values=[ymin, xmin, ymax, xmax])
    bbox = tf.expand_dims(bbox, 0)
    bbox = tf.transpose(bbox, [0, 2, 1])
    
    height = parsed_features["height"]
    width = parsed_features["width"]
    channels = parsed_features["channels"]
 
    bbox_begin, bbox_size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
        tf.concat(axis=0, values=[height, width, channels]),
        bounding_boxes=bbox,
        min_object_covered=0.1,
        use_image_if_no_bounding_boxes=True)
 
    # Reassemble the bounding box in the format the crop op requires.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.cast(tf.stack([offset_y, offset_x, target_height, target_width]), tf.int32)
    
    # Use the fused decode and crop op here, which is faster than each in series.
    cropped = tf.image.decode_and_crop_jpeg(parsed_features["image"], crop_window, channels=3)
 
    # Flip to add a little more random distortion in.
    cropped = tf.image.random_flip_left_right(cropped)
    
    image_train = tf.image.resize_images(cropped, [imageHeight, imageWidth], 
                                         method=tf.image.ResizeMethod.BILINEAR,align_corners=False)
    
    image_train = tf.cast(image_train, tf.uint8)
    image_train = tf.image.convert_image_dtype(image_train, tf.float32)
    return image_train, parsed_features["label"][0], parsed_features["text"], parsed_features["filename"]
 
with tf.device('/cpu:0'):
    train_files_names = os.listdir('/home/Irving/AI/Alex/Data/train_tf/')
    train_files = ['/home/Irving/AI/Alex/Data/train_tf/'+item for item in train_files_names]
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=6)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(batch_size)
    iterator = tf.data.Iterator.from_structure(dataset_train.output_types, dataset_train.output_shapes)
    next_images, next_labels, next_text, next_filenames = iterator.get_next()
    train_init_op = iterator.make_initializer(dataset_train)
 
def _parse_test_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height<width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64),tf.divide(resize_min,height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),tf.divide(resize_min,width)), tf.int32), resize_min))
    image_resized = tf.image.resize_images(image_decoded, [resized_height, resized_width])
    image_resized = tf.cast(image_resized, tf.uint8)
    image_resized = tf.image.convert_image_dtype(image_resized, tf.float32)
    
    # calculate how many to be center crop
    shape = tf.shape(image_resized)  
    height, width = shape[0], shape[1]
    amount_to_be_cropped_h = (height - imageHeight)
    crop_top = amount_to_be_cropped_h // 2
    amount_to_be_cropped_w = (width - imageWidth)
    crop_left = amount_to_be_cropped_w // 2
    image_valid = tf.slice(image_resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
    return image_valid, parsed_features["label"][0], parsed_features["text"], parsed_features["filename"]
 
with tf.device('/cpu:0'):
    valid_files_names = os.listdir('/home/Irving/AI/Alex/Data/valid_tf/')
    valid_files = ['/home/Irving/AI/Alex/Data/valid_tf/'+item for item in valid_files_names]
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=6)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(batch_size)
    iterator_valid = tf.data.Iterator.from_structure(dataset_valid.output_types, dataset_valid.output_shapes)
    next_valid_images, next_valid_labels, next_valid_text, next_valid_filenames = iterator_valid.get_next()
    valid_init_op = iterator_valid.make_initializer(dataset_valid)
 
global_step = tf.Variable(0, trainable=False)
epoch_steps = int(1281167/batch_size)
boundaries = [50000,80000,100000]
values = [0.00005,0.00001,0.000005,0.000001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
lr_summary = tf.summary.scalar('learning_rate', learning_rate)
 
result = alexnet_model.inference(next_images, dropout_rate=0.5, wd=0.00005)
output_result_scores = tf.nn.softmax(result)
output_result = tf.argmax(output_result_scores, 1)
 
#Calculate the cross entropy loss
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=next_labels, logits=result)
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
tf.add_to_collection('losses', cross_entropy_mean)
 
#Add the l2 weights to the loss
loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
loss_summary = tf.summary.scalar('loss', loss)
 
#Define the optimizer
#opt_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
opt_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step) 

 
#Get the inference logits by the model for the validation images
result_valid = alexnet_model.inference(next_valid_images, dropout_rate=0.5, wd=None)
output_valid_scores = tf.nn.softmax(result_valid)
output_valid_result = tf.argmax(output_valid_scores, 1)
accuracy_valid_batch = tf.reduce_mean(tf.cast(tf.equal(next_valid_labels, tf.argmax(output_valid_scores, 1)), tf.float32))
accuracy_valid_top_5 = tf.reduce_mean(tf.cast(tf.nn.in_top_k(output_valid_scores, next_valid_labels, k=5), tf.float32))
acc_1_summary = tf.summary.scalar('accuracy_valid_top_1', accuracy_valid_batch)
acc_2_summary = tf.summary.scalar('accuracy_valid_top_5', accuracy_valid_top_5)
 
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
 
with tf.Session() as sess:
    #saver.restore(sess, "model/model.ckpt-5000")
    sess.run(tf.global_variables_initializer())
    sess.run([global_step, train_init_op, valid_init_op])
    total_loss = 0.0
    epoch = 0
    starttime = time.time()
    while(True):
        try:
            loss_t, lr, step, _ = sess.run([loss, learning_rate, global_step, opt_op])
            total_loss += loss_t
            
            if step%100==0:
                print("step: %i, Learning_rate: %f, Time: %is Loss: %f"%(step, lr, int(time.time()-starttime), total_loss/100))
                total_loss = 0.0
                starttime = time.time()
            
            if step%5000==0:
                save_path = saver.save(sess, "/home/Irving/AI/Alex/Data/model/model.ckpt", global_step=global_step)
                truepredict = 0.0
                truepredict_top5 = 0.0
                valid_count = 0
                while(True):
                    try:
                        acc_valid_1, acc_valid_5, valid_result_t = sess.run([accuracy_valid_batch, accuracy_valid_top_5, output_valid_result])
                        truepredict += acc_valid_1
                        truepredict_top5 += acc_valid_5
                        valid_count += 1
                    except tf.errors.OutOfRangeError:
                        print("valid accuracy of top 1: %f" % (truepredict/valid_count))
                        print("valid accuracy of top 5: %f" % (truepredict_top5/valid_count))
                        break
                starttime = time.time()
                sess.run([valid_init_op])
          
        except tf.errors.OutOfRangeError:
            break


W0823 09:42:57.134424 140020895729408 deprecation.py:323] From <ipython-input-1-f8954d892f91>:47: sample_distorted_bounding_box (from tensorflow.python.ops.image_ops_impl) is deprecated and will be removed in a future version.
Instructions for updating:
`seed2` arg is deprecated.Use sample_distorted_bounding_box_v2 instead.
W0823 09:42:57.179473 140020895729408 deprecation.py:323] From <ipython-input-1-f8954d892f91>:75: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(dataset)`.
W0823 09:42:57.179893 140020895729408 deprecation.py:323] From <ipython-input-1-f8954d892f91>:75: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.
W0823 09:42:57.182478 140020895729408 deprecation.py:323] From /home/Irv

step: 100, Learning_rate: 0.000050, Time: 9s Loss: 13.847442


KeyboardInterrupt: 