In [None]:
import os
import random
from collections import deque

import numpy as np
import tensorflow as tf

In [None]:
class Corridor:
    """Deterministic corridor environment with cells ``0 .. num_states - 1``.

    The agent starts in ``start_state``. Action 0 moves one cell left; action 1
    moves one cell right (a no-op in the rightmost cell). Reaching cell 0 ends
    the episode with reward 1.0 if the rightmost cell was entered at least
    ``required_visits`` times, otherwise 0.01. Otherwise the episode is cut off
    with reward 0.0 once ``max_steps`` steps have been taken. Observations are
    one-hot vectors of length ``num_states``.

    The defaults (7 cells, start at 3, 20 steps, 2 visits) reproduce the
    original hard-coded task.
    """

    def __init__(self, num_states=7, start_state=3, max_steps=20, required_visits=2):
        self.num_states = num_states
        self.start_state = start_state
        self.max_steps = max_steps
        self.required_visits = required_visits
        # Initialise the episode fields so step()/vectorize() work before an
        # explicit reset() call.
        self.reset()

    def reset(self):
        """Start a new episode and return the one-hot start observation."""
        self.current_state = self.start_state
        self.steps = 0
        # Number of entries into the rightmost cell (cell 6 with the defaults);
        # the attribute name is kept for backward compatibility.
        self.s6_visits = 0
        return self.vectorize(self.current_state)

    def step(self, action):
        """Apply ``action`` (0 = left, 1 = right).

        Returns:
            ``(observation, reward, done)``.
        """
        last = self.num_states - 1
        if action == 0:
            self.current_state -= 1
        elif action == 1 and self.current_state != last:
            # Count moves from the second-to-last cell into the last one.
            if self.current_state == last - 1:
                self.s6_visits += 1
            self.current_state += 1
        self.steps += 1

        observation = self.vectorize(self.current_state)
        if self.current_state == 0:
            reward = 1.0 if self.s6_visits >= self.required_visits else 0.01
            return observation, reward, True

        # ">=" (not "==") so that stepping past the limit keeps reporting done.
        if self.steps >= self.max_steps:
            return observation, 0.0, True

        return observation, 0.0, False

    def vectorize(self, state):
        """Return a one-hot float vector of length ``num_states`` for ``state``."""
        vector = np.zeros(self.num_states)
        vector[state] = 1
        return vector

In [None]:
class HDQN:
    """Meta-controller: a DQN over 7 options (target cells) of the corridor.

    Takes one-hot state vectors of length 7 and outputs one Q-value per option.
    It uses epsilon-greedy option selection, a bounded replay buffer and a
    soft-updated (Polyak, rate ``tau``) target network.
    """

    def __init__(self):
        # deque(maxlen) drops the oldest transition in O(1) once full, instead
        # of re-slicing a 100k-element list on every insert.
        self.memory = deque(maxlen=100000)
        self.epsilon = 1.0
        self.epsilon_decay = 0.9997
        self.batch_size = 64
        self.discount_rate = 0.99
        self.learning_rate = 0.001
        self.tau = 0.001
        self.model = self.build_model()
        self.compile_model(self.model)
        self.target_model = self.build_model()
        self.target_model.set_weights(self.model.get_weights())

    def build_model(self):
        """Return an uncompiled 7 -> 16 -> 32 -> 7 MLP producing option Q-values."""
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Dense(16, input_shape=(7, ), activation='relu'))
        model.add(tf.keras.layers.Dense(32, activation='relu'))
        model.add(tf.keras.layers.Dense(7, activation='linear'))
        return model

    def compile_model(self, model):
        """Compile ``model`` with Huber loss and gradient-clipped RMSprop, then print its summary."""
        # `learning_rate` replaces the deprecated `lr` alias, which newer Keras rejects.
        model.compile(loss="huber_loss",
                      optimizer=tf.keras.optimizers.RMSprop(learning_rate=self.learning_rate,
                                                            clipnorm=1.0))
        model.summary()

    def select_option(self, state):
        """Pick an option for one-hot ``state`` epsilon-greedily.

        The option equal to the current cell is never returned: the training
        loop would treat it as reached immediately, without any movement.
        """
        current = np.argmax(state)
        if np.random.rand() < self.epsilon:
            applicable = np.delete(np.arange(7), current)
            return np.random.choice(applicable)
        state = np.reshape(state, (1, 7))
        # verbose=0: newer TF prints a progress bar for every predict() call.
        pred = self.model.predict(state, verbose=0)[0]
        # -np.inf instead of np.NINF, which was removed in NumPy 2.0.
        pred[current] = -np.inf
        return np.argmax(pred)

    def store_transition(self, state, option, reward, next_state, done):
        """Append a transition; the deque evicts the oldest one beyond capacity."""
        self.memory.append((state, option, reward, next_state, done))

    def _td_targets(self, q_values, next_q_values, next_states, options, rewards, dones):
        """Return a copy of ``q_values`` with each taken option's entry set to its TD target.

        Args:
            q_values: online-network predictions for the sampled states, shape (B, 7).
            next_q_values: target-network predictions for the next states, shape (B, 7).
            next_states: one-hot next states, shape (B, 7).
            options: option index taken in each transition.
            rewards: reward accumulated while executing each option.
            dones: whether each transition ended the episode.
        """
        targets = np.array(q_values, dtype=float)
        next_q = np.array(next_q_values, dtype=float)
        rows = np.arange(len(targets))
        # select_option never picks the option equal to the agent's own cell, so
        # its Q-value is never trained; exclude it from the bootstrap max too.
        next_q[rows, np.argmax(next_states, axis=1)] = -np.inf
        bootstrap = np.max(next_q, axis=1)
        # NOTE(review): transitions cut off by the environment's step limit are
        # also stored with done=True, so they are not bootstrapped either.
        not_done = 1.0 - np.asarray(dones, dtype=float)
        targets[rows, np.asarray(options, dtype=int)] = (
            np.asarray(rewards, dtype=float) + self.discount_rate * not_done * bootstrap
        )
        return targets

    def replay(self):
        """Fit the online network on one random minibatch, then soft-update the target."""
        batch = random.sample(self.memory, self.batch_size)
        # Transition layout: (state, option, reward, next_state, done).
        states = np.reshape([t[0] for t in batch], (self.batch_size, 7))
        next_states = np.reshape([t[3] for t in batch], (self.batch_size, 7))
        q_values = self.model.predict(states, verbose=0)
        next_q_values = self.target_model.predict(next_states, verbose=0)

        targets = self._td_targets(
            q_values,
            next_q_values,
            next_states,
            [t[1] for t in batch],
            [t[2] for t in batch],
            [t[4] for t in batch],
        )
        self.model.fit(states, targets, epochs=1, verbose=0)
        self.update_target_model()

    def update_target_model(self):
        """Move each target weight a fraction ``tau`` toward the online weight."""
        self.target_model.set_weights([
            target + self.tau * (online - target)
            for online, target in zip(self.model.get_weights(),
                                      self.target_model.get_weights())
        ])

In [None]:
def select_action(state, option):
    """Return the primitive action that moves the agent toward cell ``option``.

    Returns 1 (right) while the agent's cell is left of ``option``; otherwise
    returns 0 (left), including when the agent is already on ``option``.
    """
    position = np.argmax(state)
    if position >= option:
        return 0
    return 1

In [None]:
def _run_option(env, state, option):
    """Step ``env`` toward cell ``option`` until it is reached or the episode ends.

    Returns:
        ``(final_state, accumulated_reward, done)``.
    """
    reached = np.argmax(state) == option
    done = False
    option_reward = 0
    while not done and not reached:
        action = select_action(state, option)
        state, reward, done = env.step(action)
        option_reward += reward
        reached = np.argmax(state) == option
    return state, option_reward, done


def train(start, runs, episodes=10000, save_dir="./drive/My Drive/corridor"):
    """Train independent HDQN meta-controllers on the Corridor task.

    The per-episode reward and the cumulative number of option selections
    within the run are written to ``hdqn_rewards.npy`` / ``hdqn_steps.npy`` in
    ``save_dir``. These files are saved every 100 episodes and at the end of
    each run. The online and target networks are saved as ``.h5`` files after
    each run.

    Args:
        start: index of the first run to execute. When non-zero, the arrays
            previously saved in ``save_dir`` are loaded so earlier runs are kept.
        runs: total number of runs (rows of the result arrays).
        episodes: episodes per run (columns of the result arrays).
        save_dir: existing directory for results and checkpoints.

    Raises:
        FileNotFoundError: if ``save_dir`` does not exist. This is checked up
            front rather than failing at the first save, 100 episodes in.
    """
    if not os.path.isdir(save_dir):
        raise FileNotFoundError(f"save directory does not exist: {save_dir}")
    rewards_path = os.path.join(save_dir, "hdqn_rewards.npy")
    steps_path = os.path.join(save_dir, "hdqn_steps.npy")

    if start == 0:
        rewards = np.zeros((runs, episodes))
        terminal_steps = np.zeros((runs, episodes))
    else:
        rewards = np.load(rewards_path)
        terminal_steps = np.load(steps_path)

    def save_results():
        np.save(rewards_path, rewards)
        np.save(steps_path, terminal_steps)

    for run in range(start, runs):
        print("\nRun " + str(run))

        env = Corridor()
        meta = HDQN()
        steps = 0  # option selections, cumulative over the whole run

        for episode in range(episodes):
            done = False
            meta_state = env.reset()
            episode_reward = 0

            while not done:
                option = meta.select_option(meta_state)
                steps += 1

                state, option_reward, done = _run_option(env, meta_state, option)
                episode_reward += option_reward

                meta.store_transition(meta_state, option, option_reward, state, done)
                meta_state = state

                # Learning and epsilon decay start once the buffer has warmed up.
                if len(meta.memory) >= 2000:
                    meta.replay()
                    if meta.epsilon > 0.01:
                        meta.epsilon *= meta.epsilon_decay

            rewards[run, episode] = episode_reward
            terminal_steps[run, episode] = steps

            if episode % 100 == 99:
                save_results()

        # Also save when `episodes` is not a multiple of 100.
        save_results()
        last_episode = episodes - 1
        meta.model.save(os.path.join(save_dir, f"hdqn_{run}_{last_episode}.h5"))
        meta.target_model.save(os.path.join(save_dir, f"hdqn_target_{run}_{last_episode}.h5"))

In [None]:
# Run ten independent training runs, starting fresh from run 0.
train(start=0, runs=10)