In [1]:
import os
import time
import numpy
import tensorflow as tf
import input_data
import vw_c3d_newnetwork
import vw_c3d_tools
import math
import numpy as np

# ---- Input geometry -------------------------------------------------------
NUM_FRAMES_PER_CLIP = 16        # frames sampled per video clip
CROP_SIZE = 112                 # spatial crop (height == width), in pixels
CHANNELS = 3                    # colour channels per frame
N_CLASSE = 9                    # number of target classes

# ---- Optimisation ---------------------------------------------------------
BATCH_SIZE = 16                 # clips per batch, per GPU
gpu_num = 1                     # multiplier applied to BATCH_SIZE when batching
MAX_EPOCHS = 100                # total passes over the training list
INIT_LEARNINGRATE = 3e-3        # learning rate before any decay is applied

# ---- Output locations -----------------------------------------------------
train_log_dir = './logs//train//'   # TensorBoard events for training
val_log_dir = './logs//val//'       # TensorBoard events for validation
model_dir = './models/'             # directory that receives checkpoints
model_filename = './models/'        # path handed to the saver when fine-tuning
is_finetune = False                 # True -> restore weights before training

def placeholder_inputs(batch_size):
    """Create the feedable input tensors for one batch of clips.

    Args:
        batch_size: Number of clips per batch; used as the leading dimension
            of both placeholders.

    Returns:
        A tuple ``(images_placeholder, labels_placeholder)``:
          * images_placeholder: ``tf.float32`` placeholder of shape
            ``(batch_size, NUM_FRAMES_PER_CLIP, CROP_SIZE, CROP_SIZE, CHANNELS)``.
          * labels_placeholder: ``tf.int64`` placeholder of shape
            ``(batch_size, N_CLASSE)``.
    """
    # The leading dimension comes from the argument, not the BATCH_SIZE global,
    # so callers that scale the batch (e.g. BATCH_SIZE * gpu_num) get
    # placeholders that match what they will feed.
    images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                           NUM_FRAMES_PER_CLIP,
                                                           CROP_SIZE,
                                                           CROP_SIZE,
                                                           CHANNELS))
    labels_placeholder = tf.placeholder(tf.int64, shape=[batch_size, N_CLASSE])
    return images_placeholder, labels_placeholder

def train():
    """Build the C3D graph and run the training loop with periodic validation.

    The input pipelines read 'list/train.list' and 'list/test.list'. Each step
    pulls one batch from a pipeline as numpy arrays and feeds it to the model
    through placeholders, so the same graph scores training and validation
    batches.

    Side effects:
        * Creates ``model_dir`` if missing and writes checkpoints to it.
        * Writes TensorBoard summaries to ``train_log_dir`` and ``val_log_dir``.
        * Prints progress to stdout.

    Raises:
        ValueError: if ``is_finetune`` is set and ``model_filename`` is a
            directory that contains no checkpoint.
    """
    # Effective batch size; used for the pipelines, the placeholders and the
    # steps-per-epoch computation so that all three agree.
    batch_size = BATCH_SIZE * gpu_num

    with tf.Graph().as_default():
        # Input pipelines: each sess.run on these tensors yields the next batch.
        train_images_batch, train_labels_batch, _, _, _, train_total_num = input_data.read_clip_and_label(
            filename='list/train.list',
            batch_size=batch_size,
            num_frames_per_clip=NUM_FRAMES_PER_CLIP,
            crop_size=CROP_SIZE,
            shuffle=True
        )
        val_images_batch, val_labels_batch, _, _, _, val_total_num = input_data.read_clip_and_label(
            filename='list/test.list',
            batch_size=batch_size,
            num_frames_per_clip=NUM_FRAMES_PER_CLIP,
            crop_size=CROP_SIZE,
            shuffle=True
        )

        # Learning-rate schedule: staircase decay by 0.90 once per epoch.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        initial_learning_rate = INIT_LEARNINGRATE
        learning_rate_decay_rate = 0.90
        # Number of steps needed to see every training sample once.
        step_of_epoch = int(math.floor(train_total_num / batch_size))
        learning_rate_decay_steps = step_of_epoch
        learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                                   global_step,
                                                   learning_rate_decay_steps,
                                                   learning_rate_decay_rate,
                                                   staircase=True)

        # Cast images to float32 and one-hot encode the labels.
        train_images_batch = tf.cast(train_images_batch, dtype=tf.float32)
        train_labels_batch = tf.one_hot(train_labels_batch, depth=N_CLASSE)
        train_labels_batch = tf.cast(train_labels_batch, dtype=tf.int32)
        val_images_batch = tf.cast(val_images_batch, dtype=tf.float32)
        val_labels_batch = tf.one_hot(val_labels_batch, depth=N_CLASSE)
        val_labels_batch = tf.cast(val_labels_batch, dtype=tf.int32)

        # Placeholders are the model's actual inputs. Building the network on
        # them (rather than on the training pipeline tensors) is what makes
        # feed_dict take effect, so validation batches are really evaluated.
        images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
        # Hand the loss/accuracy helpers int32 labels, the dtype the pipeline
        # labels are cast to above.
        model_labels = tf.cast(labels_placeholder, dtype=tf.int32)

        # logits, loss, accuracy and the training op.
        # NOTE(review): if C3D_MODEL applies dropout unconditionally it is
        # also active during validation -- confirm in vw_c3d_newnetwork.
        logits = vw_c3d_newnetwork.C3D_MODEL(images_placeholder, N_CLASSE)
        loss = vw_c3d_tools.loss(logits, model_labels)
        accuracy = vw_c3d_tools.accuracy(logits, model_labels)
        train_op = vw_c3d_tools.optimize(loss, learning_rate, global_step)

        # Checkpointing.
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        saver = tf.train.Saver(tf.global_variables())

        # Merged summary op for TensorBoard.
        summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        # The context manager closes the session even if training is
        # interrupted (e.g. KeyboardInterrupt from the notebook).
        with tf.Session() as sess:
            sess.run(init)
            if is_finetune:
                # Saver.restore needs a checkpoint prefix; if a directory was
                # configured, resolve it to the newest checkpoint inside it.
                restore_path = model_filename
                if os.path.isdir(restore_path):
                    restore_path = tf.train.latest_checkpoint(restore_path)
                if restore_path is None:
                    raise ValueError('No checkpoint found in %s' % model_filename)
                saver.restore(sess, restore_path)

            # TensorBoard writers.
            tra_summary_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
            val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)

            try:
                # Basic run information.
                MAX_STEPS = int(MAX_EPOCHS * step_of_epoch)
                print("MAX_STEPS: ", MAX_STEPS)
                print("MAX_EPOCHS: ", MAX_EPOCHS)
                print("step_of_epoch", step_of_epoch)
                print("Train samples: ", train_total_num)
                print("Val samples", val_total_num)

                for step in range(MAX_STEPS):
                    is_last_step = (step + 1) == MAX_STEPS
                    log_train = step % 10 == 0 or is_last_step

                    start_time = time.time()
                    tra_images, tra_labels = sess.run([train_images_batch, train_labels_batch])
                    train_feed = {images_placeholder: tra_images,
                                  labels_placeholder: tra_labels}
                    # Fetch the summary in the same run so it describes the
                    # batch that was trained on, without an extra forward pass.
                    fetches = [train_op, loss, accuracy]
                    if log_train:
                        fetches.append(summary_op)
                    results = sess.run(fetches, feed_dict=train_feed)
                    train_loss, train_accuracy = results[1], results[2]
                    duration = time.time() - start_time
                    print('Step %d: %.3f sec' % (step, duration))

                    if log_train:
                        print ('Step: %d, train_loss: %.4f, train_accuracy: %.4f%%' % (step, train_loss, train_accuracy))
                        tra_summary_writer.add_summary(results[3], step)

                    if step % 50 == 0 or is_last_step:
                        val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                        val_loss, val_acc, summary_str = sess.run(
                            [loss, accuracy, summary_op],
                            feed_dict={images_placeholder: val_images,
                                       labels_placeholder: val_labels})
                        val_summary_writer.add_summary(summary_str, step)
                        print('**  Step %d, val_loss = %.2f, val_accuracy = %.2f%%  **' %(step, val_loss, val_acc))

                    # Checkpoint every three epochs and on the final step.
                    if step % (step_of_epoch*3) == 0 or is_last_step:
                        checkpoint_path = os.path.join(model_dir, 'model')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                # Flush buffered events to disk, including on interruption.
                tra_summary_writer.close()
                val_summary_writer.close()
        print("Done!")

In [2]:
# Run the training loop defined in the previous cell; progress is printed below.
train()

MAX_STEPS:  7800
MAX_EPOCHS:  100
step_of_epoch 78
Train samples:  1260
Val samples 538
Step 0: 2.527 sec
Step: 0, train_loss: 5.5642, train_accuracy: 6.2500%
**  Step 0, val_loss = 5.21, val_accuracy = 6.25%  **
Step 1: 1.124 sec
Step 2: 1.128 sec
Step 3: 0.981 sec
Step 4: 1.006 sec
Step 5: 1.001 sec
Step 6: 1.067 sec
Step 7: 1.045 sec
Step 8: 0.991 sec
Step 9: 0.992 sec
Step 10: 0.992 sec
Step: 10, train_loss: 6.7070, train_accuracy: 43.7500%
Step 11: 1.022 sec
Step 12: 1.018 sec
Step 13: 1.153 sec
Step 14: 1.017 sec
Step 15: 1.027 sec
Step 16: 1.078 sec
Step 17: 1.040 sec
Step 18: 1.016 sec
Step 19: 1.127 sec
Step 20: 1.019 sec
Step: 20, train_loss: 7.7721, train_accuracy: 56.2500%
Step 21: 1.057 sec
Step 22: 1.014 sec
Step 23: 1.191 sec
Step 24: 1.133 sec
Step 25: 1.193 sec
Step 26: 1.004 sec
Step 27: 1.158 sec
Step 28: 1.082 sec
Step 29: 0.993 sec
Step 30: 1.053 sec
Step: 30, train_loss: 7.6158, train_accuracy: 75.0000%
Step 31: 1.061 sec
Step 32: 1.031 sec
Step 33: 1.197 sec
Step

KeyboardInterrupt: 