# Main training function

This notebook contains the training and testing flows.

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import tensorflow as tf
import input_data
import model
import model_layer
import subprocess

In [None]:
# Number of output classes (nail, scratch, smear).
N_CLASSES = 3

# Network input resolution. Images are resized to this size; larger inputs
# make training much slower.
IMG_W, IMG_H = 224, 224

# Input-pipeline settings.
BATCH_SIZE = 64
CAPACITY = 2000

# Optimisation schedule.
MAX_STEP = 7000
learning_rate = 0.001

## Define training flow

Every 50 steps, the current training loss and accuracy are printed to the terminal.
Every 2000 steps (and after the final step), the current weights and parameters are saved as a checkpoint.

In [None]:
def run_training(train_dir='../pattern/2500/00627train', logs_train_dir='logs'):
    """Train the defect classifier and periodically checkpoint it.

    Builds the input pipeline with ``input_data``, the network with
    ``model_layer.inference`` and the loss / training / accuracy ops with
    ``model``, then runs up to ``MAX_STEP`` optimisation steps. Loss and
    accuracy are printed every 50 steps; a checkpoint ``model.ckpt-<step>`` is
    written every 2000 steps and after the final step.

    Args:
        train_dir: Directory passed to ``input_data.get_files`` to enumerate
            the training images and labels.
        logs_train_dir: Directory that receives the checkpoints. It is created
            if it does not exist.
    """
    train, train_label = input_data.get_files(train_dir)
    train_batch, train_label_batch = input_data.get_batch(
        train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

    # The last argument is 1 here and 0 in the test flow; presumably a
    # training-mode switch -- TODO confirm against model_layer.inference.
    train_logits = model_layer.inference(train_batch, BATCH_SIZE, N_CLASSES, 1)
    train_loss = model.losses(train_logits, train_label_batch)
    train_op = model.trainning(train_loss, learning_rate)
    train_acc = model.evaluation(train_logits, train_label_batch)

    # Saver.save cannot write into a directory that does not exist yet.
    if not os.path.isdir(logs_train_dir):
        os.makedirs(logs_train_dir)
    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')

    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

            if step % 50 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%'
                      % (step, tra_loss, tra_acc * 100.0))

            if step % 2000 == 0 or step + 1 == MAX_STEP:
                saver.save(sess, checkpoint_path, global_step=step)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Stop the queue-runner threads and release the session on every exit
        # path, not only when training finishes normally.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()


## Define testing flow

Test all images in `test_dir` with the saved model.

1. While testing, misclassified images are copied into a directory named after the error. For example, if the ground truth (GT) is 'nail' and the model predicts 'smear', the image is copied to 'result/error/nail_error/smear_nail/'.
2. Note that reloading the model for every test image is inefficient; a much better approach is to load the model once and test all images with it. (The reload-per-image approach came from another script this code was copied and adapted from, and there was not enough time to improve it during that modification, so it was left for someone to improve.)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
# Class name -> index of the matching entry in the classifier's output.
LABELS = dict(nail=0, scratch=1, smear=2)
# Label index of the image currently being evaluated; -1 means "unknown".
GROUNDTRUTH = -1

In [None]:
def test_all_img_in_dirs(test_dir='../pattern/2500/00627test', logs_test_dir='logs'):
    """Classify every image in *test_dir* with the latest checkpoint and report accuracy.

    The graph is built and the checkpoint restored once; every image is then
    fed through the same session via a placeholder. Each image's ground-truth
    class comes from its path: the last key of ``LABELS`` that occurs as a
    substring of the path wins. Images whose path matches no key are reported
    and excluded from scoring. Misclassified images are copied with
    ``cp_wrong_sample`` into ``result/error/<truth>_error/<pred>_<truth>/``.

    Args:
        test_dir: Directory passed to ``input_data.get_files`` to enumerate
            the test images.
        logs_test_dir: Directory searched for the checkpoint to restore.

    Side effects:
        Sets the module-level ``GROUNDTRUTH`` to the label of the image being
        evaluated (-1 when the path matches no label).
    """
    global GROUNDTRUTH

    # Index -> class name, derived from LABELS so the two cannot drift apart.
    class_names = sorted(LABELS, key=LABELS.get)

    test, _ = input_data.get_files(test_dir)
    if not test:
        print('No test images found in %s' % test_dir)
        return

    count = 0      # correctly classified images
    labeled = 0    # images whose path contains a known class name
    correct = [0] * len(class_names)
    errors = [0] * len(class_names)

    with tf.Graph().as_default():
        # Same preprocessing as the per-image version, but driven by a
        # placeholder so the graph and weights are loaded only once.
        image_input = tf.placeholder(tf.float32, shape=[None, None, 3])
        image = tf.image.resize_image_with_crop_or_pad(image_input, IMG_H, IMG_W)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, IMG_H, IMG_W, 3])
        logit = model_layer.inference(image, 1, N_CLASSES, 0)
        probabilities = tf.nn.softmax(logit)
        saver = tf.train.Saver()

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_test_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Without restored weights every sess.run() would fail on
                # uninitialised variables, so there is nothing to evaluate.
                print('No checkpoint file found')
                return
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            for img_dir in test:
                print("#----------------------------")
                print('image: ' + img_dir)

                # Reset for every image so a path without a class name cannot
                # inherit the previous image's label.
                GROUNDTRUTH = -1
                for key, value in LABELS.items():
                    if key in img_dir:
                        GROUNDTRUTH = value

                with Image.open(img_dir) as img:
                    # convert('RGB') gives grayscale / palette / RGBA files the
                    # three channels the placeholder expects.
                    image_array = np.asarray(img.convert('RGB'), dtype=np.float32)

                prediction = sess.run(probabilities, feed_dict={image_input: image_array})
                max_index = int(np.argmax(prediction))
                print(' '.join('%.6f' % p for p in prediction[0]))
                print('This is a %s defect with possibility %.6f'
                      % (class_names[max_index], prediction[0, max_index]))

                if GROUNDTRUTH < 0:
                    print('WARNING: no known class name in path; image not scored')
                    continue

                labeled += 1
                if max_index == GROUNDTRUTH:
                    count += 1
                    correct[max_index] += 1
                else:
                    errors[GROUNDTRUTH] += 1
                    cp_wrong_sample(class_names[max_index], class_names[GROUNDTRUTH], img_dir)

    if labeled:
        print("accuracy: %f" % (float(count) / labeled))
    else:
        print("accuracy: n/a (no labeled images)")
    for idx, name in enumerate(class_names):
        # %-8s reproduces the aligned "nail    " / "scratch " / "smear   " columns.
        print("%-8s---error: %d, ---correct: %d" % (name, errors[idx], correct[idx]))


def cp_wrong_sample(err, cor, img_file):
    """Copy a misclassified image into the error-analysis directory tree.

    The destination is ``result/error/<cor>_error/<err>_<cor>/``, relative to
    the current working directory. It is created if needed. The copy is done
    in-process rather than through a shell, so paths containing spaces or
    shell metacharacters are handled safely.

    Args:
        err: Class name the model predicted.
        cor: Ground-truth class name.
        img_file: Path of the image to copy.

    Copying is best-effort: a failure is printed but does not abort the
    evaluation run.
    """
    import shutil

    dest_dir = os.path.join('result', 'error', cor + '_error', err + '_' + cor)
    try:
        if not os.path.isdir(dest_dir):
            os.makedirs(dest_dir)
        shutil.copy(img_file, dest_dir)
    except (IOError, OSError) as exc:
        print('Could not copy %s to %s: %s' % (img_file, dest_dir, exc))


## Run training and inference

In [None]:
# Guarded so that importing this module does not start a full training run.
# In a notebook kernel __name__ is '__main__', so the cell still executes.
if __name__ == '__main__':
    run_training()
    test_all_img_in_dirs()