# In [None]:
# -*- coding: utf-8 -*-


import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    ``encoding='bytes'`` lets Python 3 read pickles written by Python 2; any
    Python 2 ``str`` objects (including dict keys) come back as ``bytes``,
    which is why callers index the result with keys like ``b'data'``.

    Args:
        file: path of the pickle file to read.

    Returns:
        The object stored in the file (a dict for CIFAR-10 batch files).

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only call this on trusted files such as the official CIFAR-10 download.
    """
    import pickle
    # Use a non-builtin-shadowing name; the file handle is closed by `with`.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


# Input geometry: square RGB images fed to the network as flattened rows.
height, width, channels = 32, 32, 3
n_inputs = height * width * channels

# First two convolutions: 3x3 kernels, stride 1, "SAME" padding.
conv1_fmaps, conv1_ksize, conv1_stride, conv1_pad = 32, 3, 1, "SAME"
conv2_fmaps, conv2_ksize, conv2_stride, conv2_pad = 64, 3, 1, "SAME"

# Max pooling does not change the channel count.
pool3_fmaps = conv2_fmaps

# Third convolution downsamples with stride 2.
conv4_fmaps, conv4_ksize, conv4_stride, conv4_pad = 64, 3, 2, "SAME"

pool5_fmaps = conv4_fmaps

# Dense head: one hidden layer, then one logit per class.
n_fc1, n_outputs = 64, 10

# Build the CNN in the TF1 default graph:
#   conv1 -> conv2 -> max-pool -> conv4 (stride 2) -> max-pool -> dense -> logits
# Start from an empty graph so re-running this cell does not duplicate ops/variables.
tf.reset_default_graph()

with tf.name_scope("inputs"):
    # One flattened image per row. The reshape below assumes each row is laid out
    # in height-width-channel order.
    # NOTE(review): raw CIFAR-10 rows are stored channel-major, so they must be
    # converted to this layout before feeding.
    X = tf.placeholder(tf.float32, shape=[None, n_inputs], name="X")
    X_reshaped = tf.reshape(X, shape=[-1, height, width, channels])
    # Integer class ids, one per row of X (must lie in [0, n_outputs) for the
    # sparse cross-entropy below).
    y = tf.placeholder(tf.int32, shape=[None], name="y")

# Two stride-1 "SAME" convolutions: spatial size stays height x width (32x32).
conv1 = tf.layers.conv2d(X_reshaped, filters=conv1_fmaps, kernel_size=conv1_ksize,
                         strides=conv1_stride, padding=conv1_pad,
                         activation=tf.nn.elu, name="conv1")
conv2 = tf.layers.conv2d(conv1, filters=conv2_fmaps, kernel_size=conv2_ksize,
                         strides=conv2_stride, padding=conv2_pad,
                         activation=tf.nn.elu, name="conv2")

# 2x2 max-pool with stride 2 halves each spatial dimension: 32x32 -> 16x16.
# (Unlike pool5, this op is not wrapped in its own name_scope.)
pool3 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

# Stride-2 "SAME" convolution halves the spatial size again: 16x16 -> 8x8.
conv4 = tf.layers.conv2d(pool3, filters=conv4_fmaps, kernel_size=conv4_ksize,
         strides=conv4_stride, padding=conv4_pad,
         activation=tf.nn.elu, name="conv4")

with tf.name_scope("pool5"):
    # 2x2 max-pool, stride 2: 8x8 -> 4x4, so each image flattens to
    # pool5_fmaps * 4 * 4 = pool5_fmaps * 16 values.
    # NOTE: the 16 is hard-coded for a 32x32 input and the strides set above;
    # update it if either changes.
    pool5 = tf.nn.max_pool(conv4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
    pool5_flat = tf.reshape(pool5, shape=[-1, pool5_fmaps * 16])


with tf.name_scope("fc1"):
    # Fully connected hidden layer (ReLU, while the conv layers use ELU).
    fc1 = tf.layers.dense(pool5_flat, n_fc1, activation=tf.nn.relu, name="fc1")

with tf.name_scope("output"):
    # Unnormalized class scores; Y_proba is their softmax (class probabilities).
    logits = tf.layers.dense(fc1, n_outputs, name="output")
    Y_proba = tf.nn.softmax(logits, name="Y_proba")

with tf.name_scope("train"):
    # Mean cross-entropy over the batch, minimized with Adam at default settings.
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(xentropy)
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # correct[i] is True when the label y[i] is the single highest-scoring class.
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

with tf.name_scope("init_and_save"):
    # Created last so the initializer also covers the optimizer's slot variables,
    # which minimize() adds above.
    # NOTE(review): saver is not used by the code visible in this script.
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
# load data
# Directory containing the extracted CIFAR-10 "python version" batch files.
# This is the only path to change when running on another machine.
cifar_dir = "C:/Users/David/Documents/College2/cst495/finalproject/datasets/cifar-10-python/cifar-10-batches-py"


def _to_hwc_rows(raw_rows):
    """Convert raw CIFAR-10 rows into float32 feature rows ready to feed ``X``.

    Raw rows store an image channel-major (all red values, then green, then
    blue). The graph reshapes each ``X`` row as ``[height, width, channels]``,
    so the rows are transposed to height-width-channel order, flattened back
    to ``n_inputs`` values, and scaled from [0, 255] to [0, 1].

    Args:
        raw_rows: array-like of shape (N, n_inputs), presumably uint8 pixels.

    Returns:
        C-contiguous float32 array of shape (N, n_inputs).
    """
    hwc = np.asarray(raw_rows).reshape(-1, channels, height, width).transpose(0, 2, 3, 1)
    # The reshape copies the transposed view into C order. Doing it before the
    # float conversion keeps that copy on the compact original dtype.
    rows = hwc.reshape(-1, n_inputs).astype(np.float32)
    rows /= 255.0
    return rows


# CIFAR-10 ships five training batches. They must be concatenated:
# dict.update would replace b'data'/b'labels' instead of appending to them.
train_batches = [unpickle(os.path.join(cifar_dir, "data_batch_{}".format(i)))
                 for i in range(1, 6)]
dat = {
    b'data': _to_hwc_rows(np.concatenate([batch[b'data'] for batch in train_batches])),
    b'labels': np.concatenate([batch[b'labels'] for batch in train_batches]).astype(np.int32),
}
del train_batches  # drop the raw copies; dat holds everything needed

test_batch = unpickle(os.path.join(cifar_dir, "test_batch"))
test = {
    b'data': _to_hwc_rows(test_batch[b'data']),
    b'labels': np.asarray(test_batch[b'labels'], dtype=np.int32),
}
del test_batch

num_examples = len(dat[b'data'])
# Image-shaped (N, height, width, channels) views of the same memory, for inspection.
dat_reshaped = dat[b'data'].reshape(-1, height, width, channels)
test_reshaped = test[b'data'].reshape(-1, height, width, channels)

n_epochs = 10
batch_size = 500

# Arrays (not lists) so a shuffled index array can select each mini-batch.
train_data = np.asarray(dat[b'data'])
train_labels = np.asarray(dat[b'labels'])
n_batches = num_examples // batch_size
if n_batches == 0:
    raise ValueError("batch_size ({}) exceeds the number of training examples ({})"
                     .format(batch_size, num_examples))


def _eval_accuracy(features, labels, chunk_size=batch_size):
    """Return the ``accuracy`` op's value over all rows, ``chunk_size`` rows at a time.

    Must be called while a session is the default (inside ``with tf.Session()``).
    Chunking bounds memory. Feeding the whole test set at once would
    materialize the conv activations of every image simultaneously.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)
    total = len(features)
    n_correct = 0.0
    for start in range(0, total, chunk_size):
        stop = min(start + chunk_size, total)
        chunk_acc = accuracy.eval(feed_dict={X: features[start:stop], y: labels[start:stop]})
        # Weight by chunk size so a short final chunk is not over-counted.
        n_correct += chunk_acc * (stop - start)
    return n_correct / total


with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        # Reshuffle every epoch. When num_examples is not a multiple of
        # batch_size, a different random remainder is left out each epoch.
        order = np.random.permutation(num_examples)
        for iteration in range(n_batches):
            # Positions [iteration*batch_size, (iteration+1)*batch_size) of the shuffled order.
            batch_idx = order[iteration * batch_size:(iteration + 1) * batch_size]
            X_batch = train_data[batch_idx]
            y_batch = train_labels[batch_idx]
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Train accuracy is measured on the last mini-batch only (cheap but noisy).
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_test = _eval_accuracy(test[b'data'], test[b'labels'])
        print(epoch, "Train accuracy:", acc_train)
        print(epoch, "Test accuracy:", acc_test)
        