In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from read_data import *

# Number of target classes for each supported classification problem.
# `problem_type` selects one of these keys, which in turn sets CLASSES.
classes_dict = dict(
    by_country=4,
    by_style=7,
    by_product=2,
)

## Set the classification problem

In [None]:
# Classification problem to run; must be a key of `classes_dict`
# ('by_country', 'by_style' or 'by_product').
problem_type = 'by_style'

## Adjust hyperparameters

In [None]:
# Epoch limit and batch size handed to the training input queues.
EPOCHS, BATCH_SIZE = 1, 8
# Number of validation examples fed per step through the placeholders.
VALIDATION_BATCH = 16

# Images are square: IMG_SIZE x IMG_SIZE.
IMG_SIZE = 150
# Number of output classes for the selected problem (KeyError if unknown).
CLASSES = classes_dict[problem_type]

## Build the input pipeline and place it on the CPU

In [None]:
# Build the input part of the graph, pinned to the CPU.
with tf.device('/cpu:0'):

    # Boolean switch fed at run time: True selects the training queue batch,
    # False selects the validation batch fed through the placeholders below.
    is_training = tf.placeholder(tf.bool, name="is_training")

    # Training batches of BATCH_SIZE come from the input queues built by
    # read_data.inputs(); EPOCHS is presumably the epoch limit after which the
    # queue raises OutOfRangeError — confirm in read_data.
    images_trn, labels_trn = inputs(problem_type, BATCH_SIZE, EPOCHS)

    # Validation data is fed through placeholders: VALIDATION_BATCH uint8
    # 3-channel images of IMG_SIZE x IMG_SIZE and their integer labels.
    images_val = tf.placeholder(tf.uint8, shape=[VALIDATION_BATCH, IMG_SIZE, IMG_SIZE, 3])
    labels_val = tf.placeholder(tf.int32, shape=[VALIDATION_BATCH,])

    # Select the input source according to is_training.
    # NOTE(review): images_trn/labels_trn and the placeholders are created
    # outside the cond lambdas, so per the tf.cond docs they are evaluated on
    # every run regardless of the branch taken: each validation step still
    # dequeues a training batch, and a training step still requires the
    # validation placeholders to be fed — confirm.
    images = tf.cond(is_training, lambda: images_trn, lambda: images_val)
    labels = tf.cond(is_training, lambda: labels_trn, lambda: labels_val)

    # Cast to float32 and divide by 255 (maps uint8 pixels into [0, 1];
    # assumes the training images are also 0-255 — confirm in read_data).
    images = (tf.cast(images, tf.float32) / 255.0)
    # One-hot encode the integer labels into CLASSES columns.
    labels = tf.one_hot(labels, CLASSES)

## Create the session and start the threads for input queues

In [None]:
# Single op running both the global and the local variable initializers.
# NOTE(review): local variables are presumably required by the epoch limit
# of the read_data input queues — confirm in read_data.inputs().
init_op = tf.group(
    tf.global_variables_initializer(),
    tf.local_variables_initializer(),
)

# Session that executes the graph; it is closed at the end of the loop cell.
sess = tf.Session()
sess.run(init_op)  # initialize variables before starting the queue runners

# Launch the enqueue threads registered by the input pipeline, all managed
# by one coordinator so they can be stopped and joined together.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

## Load validation data

In [None]:
# Validation images and labels stored under data_<problem_type>/.
# NOTE(review): if these .dat files were written with ndarray.dump() (pickle),
# recent NumPy needs allow_pickle=True here — only enable it for trusted files.
validation_dir = 'data_' + problem_type
images_validation = np.load(os.path.join(validation_dir, 'testing_data.dat'))
labels_validation = np.load(os.path.join(validation_dir, 'testing_labels.dat'))

## Training loop (currently only previews validation batches)

In [None]:
try:
    step = 0

    # The validation feed does not change between steps, so build it once.
    # Slice with VALIDATION_BATCH so the arrays always match the placeholder
    # shapes, even if the constant is changed in the hyperparameter cell.
    validation_feed = {
        is_training: False,
        images_val: images_validation[:VALIDATION_BATCH],
        labels_val: labels_validation[:VALIDATION_BATCH],
    }

    # Run until the input queues raise OutOfRangeError (epoch limit) or
    # another thread asks the coordinator to stop.
    while not coord.should_stop():
        step += 1
        print("Step " + str(step), flush=True)

        # NOTE(review): the training dequeue is built outside the tf.cond
        # lambdas, so it presumably still runs (consuming a training batch)
        # on this validation step — confirm.
        img, label = sess.run([images, labels], feed_dict=validation_feed)

        # Show the first image of the batch titled with its class index;
        # labels are one-hot encoded, so argmax recovers the index.
        plt.imshow(img[0])
        plt.title('class {}'.format(int(np.argmax(label[0]))))
        plt.show()

except tf.errors.OutOfRangeError:
    print('\nDone training -- epoch limit reached\n')

finally:
    # Ask the queue-runner threads to stop and wait for them. join() can
    # re-raise an exception from a worker thread, so close the session in a
    # nested finally to avoid leaking it.
    coord.request_stop()
    try:
        coord.join(threads)
    finally:
        sess.close()