In [0]:
import csv
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, GlobalAveragePooling2D
from keras.datasets import cifar10, mnist, fashion_mnist
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint
from utils import StateSpace, Controller


# Seed the global RNGs so controller sampling and weight init are repeatable.
# tf.set_random_seed sets the seed of the current default graph (the one
# p_session and the controller use).
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Session that owns the controller's graph; Keras is pointed at it by default.
p_session = tf.Session()
K.set_session(p_session)

# Per-layer search space: convolution kernel size and number of filters.
state_space = StateSpace()
state_space.add_state(name='kernel', values=[1, 3])
state_space.add_state(name='filters', values=[32, 48, 64])

NUM_CLASSES = 10

# Fashion-MNIST: scale pixels to [0, 1] and one-hot encode the labels.
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Add the single grayscale channel on whichever side Keras expects.
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
X_train = np.expand_dims(X_train, axis=channel_axis)
X_test = np.expand_dims(X_test, axis=channel_axis)
input_shape = X_train.shape[1:]

total_reward = 0.0

class Agent:
    """Trains a sampled child CNN and converts its validation accuracy into a reward.

    The reward is the child's accuracy minus an exponential moving average (EMA)
    of previous accuracies, clipped to ``[-reward_clip, reward_clip]``.
    """

    def __init__(self, data, epochs=50, batch_size=256, ema_decay=0.8, reward_clip=0.1):
        """
        Args:
            data: ``[X_train, y_train, X_val, y_val]``. ``X_train.shape[1:]`` is used
                as the child network's input shape and ``y_train.shape[-1]`` as the
                number of output classes (labels are expected one-hot).
            epochs: training epochs per child network.
            batch_size: batch size used for both ``fit`` and ``evaluate``.
            ema_decay: decay of the accuracy-baseline EMA; must be < 1.
            reward_clip: rewards are clipped to ``[-reward_clip, reward_clip]``.
        """
        self.data = data
        self.epochs = epochs
        self.batch_size = batch_size
        self.ema_decay = ema_decay
        self.reward_clip = reward_clip
        # Bias-correction factor for the first EMA update (the baseline starts at 0).
        self.beta_bias = ema_decay
        self.moving_acc = 0.0

    def _update_reward(self, acc):
        """Turn an accuracy into a clipped reward and fold it into the EMA baseline."""
        reward = acc - self.moving_acc
        self.moving_acc = self.ema_decay * self.moving_acc + (1 - self.ema_decay) * acc
        # The baseline starts at 0, so dividing the first update by (1 - beta_bias)
        # makes it equal to acc; afterwards beta_bias is 0 and the divide is a no-op.
        self.moving_acc = self.moving_acc / (1 - self.beta_bias)
        self.beta_bias = 0
        return np.clip(reward, -self.reward_clip, self.reward_clip)

    def get_rewards(self, Network, actions):
        """Build, train and evaluate one child network, then return its reward.

        Args:
            Network: unused; kept so existing call sites keep working.
            actions: flat sequence ``[k1, f1, k2, f2, k3, f3, k4, f4]`` of kernel
                sizes and filter counts for the four conv layers.

        Returns:
            ``(reward, acc)``: the clipped reward and the validation accuracy.
        """
        X_train, y_train, X_val, y_val = self.data
        kernel_1, filters_1, kernel_2, filters_2, kernel_3, filters_3, kernel_4, filters_4 = actions

        # Each child gets its own graph/session so trials don't accumulate ops.
        # The ``with`` block closes the session on exit; Keras is left pointing at
        # it, so callers must re-set their own session before using Keras again.
        with tf.Session(graph=tf.Graph()) as session:
            K.set_session(session)
            # Derive the input shape from the data so channels_first layouts work too.
            inp = Input(shape=X_train.shape[1:])
            x = Conv2D(filters_1, (kernel_1, kernel_1), strides=(2, 2), padding='same', activation='relu')(inp)
            x = Conv2D(filters_2, (kernel_2, kernel_2), strides=(1, 1), padding='same', activation='relu')(x)
            x = Conv2D(filters_3, (kernel_3, kernel_3), strides=(2, 2), padding='same', activation='relu')(x)
            x = Conv2D(filters_4, (kernel_4, kernel_4), strides=(1, 1), padding='same', activation='relu')(x)
            x = GlobalAveragePooling2D()(x)
            x = Dense(y_train.shape[-1], activation='softmax')(x)
            model = Model(inp, x)
            model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])
            model.fit(X_train, y_train, batch_size=self.batch_size, epochs=self.epochs,
                      verbose=1, validation_data=(X_val, y_val))
            _, acc = model.evaluate(X_val, y_val, batch_size=self.batch_size)

        return self._update_reward(acc), acc


# Number of conv layers the controller designs; must match the four
# (kernel, filters) pairs Agent.get_rewards unpacks.
NUM_LAYERS = 4
NUM_TRIALS = 20

with p_session.as_default():
    controller = Controller(p_session, NUM_LAYERS, state_space, reg_param=1e-2, exploration=0.9,
                            controller_cells=32, embedding_dim=20)

# NOTE(review): the test split doubles as the validation set the reward is
# computed on, so the reported accuracy is not an unbiased test estimate.
manager = Agent([X_train, y_train, X_test, y_test])

state = state_space.get_random_state_space(NUM_LAYERS)
print("Initial Random State : ", state_space.parse_state_space_list(state))


accuracy_plot = []
reward_plot = []
loss_plot = []

for trial in range(NUM_TRIALS):
    with p_session.as_default():
        K.set_session(p_session)
        actions = controller.get_action(state)

    print("Predicted actions : ", state_space.parse_state_space_list(actions))
    # get_rewards ignores its first argument (previously an undefined name `Network`).
    reward, acc = manager.get_rewards(None, state_space.parse_state_space_list(actions))
    print("Rewards : ", reward, "Accuracy : ", acc)
    reward_plot.append(reward)
    accuracy_plot.append(acc)
    # get_rewards switches Keras to the child's session; switch back before
    # touching the controller again.
    with p_session.as_default():
        K.set_session(p_session)
        total_reward += reward
        print("Total reward : ", total_reward)
        state = actions
        controller.store_rollout(state, reward)
        loss = controller.train_step()
        loss_plot.append(loss)
        print("Trial %d: Controller loss : %0.6f" % (trial + 1, loss))
print("Total Reward : ", total_reward)

In [0]:
import matplotlib.pyplot as plt
import os
# exist_ok so re-running the cell doesn't fail on an existing directory.
os.makedirs('plots', exist_ok=True)


def save_trial_plot(values, ylabel, path):
    """Plot a per-trial metric against its 1-based trial number on a fresh figure and save it to ``path``."""
    trials = range(1, len(values) + 1)
    fig, ax = plt.subplots()
    ax.plot(trials, values, marker='o')
    ax.set_xticks(range(1, len(values) + 1, 2))
    ax.set(xlabel='Trial', ylabel=ylabel, title=ylabel + ' per trial')
    fig.savefig(path)
    return fig


save_trial_plot(accuracy_plot, 'Validation accuracy', 'plots/accuracy.png')
save_trial_plot(loss_plot, 'Controller loss', 'plots/loss.png')
save_trial_plot(reward_plot, 'Reward', 'plots/reward.png');