In [69]:
import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import Softmax, LogSoftmax, Sigmoid

class Option(object):
    """A single option: a tabular softmax intra-option policy together with
    a tabular sigmoid termination function.

    Attributes:
        n_states: number of discrete states.
        n_actions: number of discrete actions.
        theta: policy parameters, shape (n_states, n_actions), requires grad.
        upsilon: termination parameters, shape (n_states,), requires grad.
    """

    # Human-readable action names, indexed by action index
    _ACTION_NAMES = ("left", "up", "right", "down")

    def __init__(self, n_states, n_actions):
        self.n_states = n_states
        self.n_actions = n_actions

        # Policy parameters
        self.theta = Variable(torch.Tensor(
            np.random.rand(n_states, n_actions)), requires_grad=True)
        # Termination parameters
        self.upsilon = Variable(torch.Tensor(
            np.random.rand(n_states)), requires_grad=True)

    # Input : index in [0, n_states - 1], softmax temperature T
    # Return : pi(.|state) as *probabilities* (not log-probabilities),
    #          variable of shape (1, n_actions)
    def pi(self, state_index, T=0.1):
        state_var = self._varFromStateIndex(state_index)
        logits = torch.matmul(state_var, self.theta) / T
        # logits has shape (1, n_actions): normalise over the action axis
        return Softmax(dim=1)(logits)

    # Input : index in [0, n_states - 1]
    # Return : beta(state), variable of shape (1)
    def beta(self, state_index):
        state_var = self._varFromStateIndex(state_index)
        return Sigmoid()(torch.matmul(state_var, self.upsilon))

    # Input : index in [0, n_states - 1]
    # Return : one of "left", "up", "right", or "down",
    #          index of action chosen in [0, n_actions - 1]
    #          and one-hot variable of shape (1, n_actions)
    def pickAction(self, state_index):
        probs = self.pi(state_index).data.numpy().reshape(-1)
        # The softmax is computed in float32; renormalise in float64 so that
        # np.random.choice's "probabilities sum to 1" check cannot fail on
        # rounding error.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        action_index = int(
            np.random.choice(self.n_actions, size=1, p=probs)[0])
        action, action_one_hot = self._actionFromActionIndex(action_index)
        return action, action_index, action_one_hot

    # Input : index in [0, n_states - 1]
    # Return : one-hot variable of shape (1, n_states)
    def _varFromStateIndex(self, state_index):
        s = np.zeros(self.n_states)
        s[state_index] = 1
        return Variable(torch.Tensor(s)).view(1, -1)

    # Input : index in [0, n_actions - 1]
    # Return : one of "left", "up", "right", or "down"
    #          and one-hot variable of shape (1, n_actions)
    # Raises : ValueError if the index has no associated action name
    def _actionFromActionIndex(self, action_index):
        if not 0 <= action_index < len(self._ACTION_NAMES):
            raise ValueError(
                "action_index must be in [0, {}], got {}".format(
                    len(self._ACTION_NAMES) - 1, action_index))
        action = self._ACTION_NAMES[action_index]
        a = np.zeros(self.n_actions)
        a[action_index] = 1
        return action, Variable(torch.Tensor(a)).view(1, -1)

In [None]:
import random

""" Agent planning using Option-Critic architecture """
class OptionCritic(object):
    """Agent planning with the Option-Critic architecture.

    The critic is tabular (numpy arrays Q_U, Q and V); the actors are the
    per-option policy and termination parameters held by each Option, which
    are updated by manual gradient steps.

    Args:
        gamma: discount factor.
        alpha_critic: learning rate of the critic (Q_U).
        alpha_theta: learning rate of the intra-option policies.
        alpha_upsilon: learning rate of the termination functions.
        n_options: number of options to learn.
        grid_size: side length of the square grid; states are (row, col)
            tuples flattened to row * grid_size + col.
    """

    def __init__(self, gamma=0.99, alpha_critic=0.5, alpha_theta=0.25,
                 alpha_upsilon=0.25, n_options=4, grid_size=13):
        self.gamma = gamma                 # Discount factor
        self.alpha_critic = alpha_critic   # Critic lr
        self.alpha_theta = alpha_theta     # Intra-option policies lr
        self.alpha_upsilon = alpha_upsilon # Termination functions lr
        self.grid_size = grid_size         # Side length of the square grid

        n_states = grid_size * grid_size
        n_actions = 4
        # range (not xrange) so this runs on both Python 2 and Python 3
        self.options = [Option(n_states, n_actions)
                        for _ in range(n_options)]

        self.current_option = None
        # Keep track of one hot var and index of last action taken
        self.last_action_one_hot = None
        self.last_action_index = None

        # Action values in the context of (state, option) pairs
        self.Q_U = np.zeros((n_states, n_options, n_actions))
        # Option values (computed from Q_U)
        self.Q = np.zeros((n_states, n_options))
        # State values (computed from Q)
        self.V = np.zeros(n_states)

    def epsilonGreedyPolicy(self, state_tuple, epsilon=0.01):
        """Return the action name chosen in `state_tuple`.

        If no option is active, one is selected epsilon-greedily w.r.t. Q.
        The action is then sampled from the active option's policy, and its
        index and one-hot encoding are recorded for the next update.
        """
        state_index = self._sIdx(state_tuple)
        # If current option is None, pick a new one epsilon greedily
        if self.current_option is None:
            # Pick greedy option with probability (1 - epsilon)
            if random.uniform(0, 1) > epsilon:
                best_option_idx = np.argmax(self.Q[state_index])
                self.current_option = self.options[best_option_idx]
            # Pick random option with probability epsilon
            else:
                self.current_option = random.choice(self.options)

        # Pick action according to current option
        action, action_index, action_one_hot = \
            self.current_option.pickAction(state_index)
        # Record one hot var and index of last action taken
        self.last_action_one_hot = action_one_hot
        self.last_action_index = action_index
        return action

    def recordTransition(self, state, reward, next_state):
        """Learn from one transition taken under the current option.

        Runs the critic update, then the actor update, then samples whether
        the current option terminates in `next_state`.
        """
        pi = self.current_option.pi(self._sIdx(state))
        beta = self.current_option.beta(self._sIdx(next_state))

        # 1) Critic improvement
        # Update estimate of Q_U[state, current_option, last_action]
        self._evaluateOption(state, reward, next_state, pi, beta)

        # 2) Actor improvement
        # Take a gradient step for policy and termination parameters
        # of current option
        self._improveOption(state, next_state, pi, beta)

        # If current option ends, set current option to None.
        # beta is re-evaluated so that it reflects the updated upsilon.
        beta = float(self.current_option.beta(self._sIdx(next_state)).data[0])
        if random.uniform(0, 1) < beta:
            self.current_option = None

    def _evaluateOption(self, state, reward, next_state, pi, beta):
        """Critic step: update Q_U, then refresh Q and V for `state`.

        `pi` is the current option's action distribution in `state` and
        `beta` its termination probability in `next_state` (torch variables).
        """
        s1 = self._sIdx(state)
        s2 = self._sIdx(next_state)
        o = self._oIdx(self.current_option)
        a = self.last_action_index

        # Update Q_U
        # float() makes beta a plain Python scalar regardless of whether
        # indexing a tensor yields a float or a 0-dim tensor.
        beta = float(beta.data[0])
        target = reward + self.gamma * (1 - beta) * self.Q[s2, o] \
            + self.gamma * beta * np.max(self.Q[s2])
        self.Q_U[s1, o, a] += \
            self.alpha_critic * (target - self.Q_U[s1, o, a])

        # Update Q since Q_U has changed
        self.Q[s1, o] = pi.data.numpy().reshape(-1).dot(self.Q_U[s1, o])

        # Update V since Q has changed
        # This update is only valid if the policy over options is greedy:
        # V(s) is the maximum over *all* options of Q(s, option).
        self.V[s1] = np.max(self.Q[s1])

    def _improveOption(self, state, next_state, pi, beta):
        """Actor step: one gradient step on the current option's policy
        parameters (theta) and termination parameters (upsilon).
        """
        s1 = self._sIdx(state)
        s2 = self._sIdx(next_state)
        o = self._oIdx(self.current_option)
        a = self.last_action_index

        # 1) Policy update
        # Compute log pi(last_action_taken | state)
        logprobs = torch.log(pi)
        logprob = torch.sum(logprobs * self.last_action_one_hot)
        # Compute gradient of this quantity w.r.t. theta
        logprob.backward()
        theta = self.current_option.theta
        # Take a gradient ascent step, weighted by Q_U(s, o, a).
        # The weight is converted to a Python float before touching tensors.
        policy_step = float(self.alpha_theta * self.Q_U[s1, o, a])
        theta.data += policy_step * theta.grad.data
        # Zero gradient
        theta.grad.data.zero_()

        # 2) Termination function update
        # Compute gradient of beta(next_state) w.r.t. upsilon
        beta.backward()
        upsilon = self.current_option.upsilon
        # Advantage of continuing with the current option in next_state
        advantage = float(self.Q[s2, o] - self.V[s2])
        # Option-critic termination update:
        #   upsilon <- upsilon - alpha * d(beta)/d(upsilon) * advantage
        # i.e. terminate more often where the option is worse than the best
        # available choice (negative advantage).
        upsilon.data -= self.alpha_upsilon * advantage * upsilon.grad.data
        # Zero gradient
        upsilon.grad.data.zero_()

    def _sIdx(self, state):
        """Flatten a (row, col) state tuple into a state index."""
        return state[0] * self.grid_size + state[1]

    def _oIdx(self, option):
        """Return the position of `option` in self.options."""
        return self.options.index(option)

In [None]:
from four_rooms import FourRoomsEnvironment
import matplotlib.pyplot as plt

# Shared environment used by run_episode and the visualisation helpers.
# NOTE(review): start_loc="random" presumably randomises the start cell on
# each reset -- confirm in four_rooms.
env = FourRoomsEnvironment(start_loc="random")

def run_episode(verbose=False):
    """Run the global `agent` in the global `env` until `env.step` reports
    done, learning from every transition.

    If an attempt exceeds 1000 steps without finishing, the environment is
    reset and the step counter restarts from zero, so the returned value is
    the number of steps of the attempt that actually finished.

    Args:
        verbose: when true, print state, current option and action each step.

    Returns:
        Number of steps taken in the finishing attempt.
    """
    state = env.reset()
    n_steps = 0
    done = False
    while not done:
        n_steps += 1
        action = agent.epsilonGreedyPolicy(state)
        if verbose:
            message = "State = {}, Option = {}, Action = {}".format(
                state, agent.current_option, action)
            print(message)
        next_state, reward, done = env.step(action)
        agent.recordTransition(state, reward, next_state)
        state = next_state
        # Give up on attempts that take more than 1000 steps: start over
        if not done and n_steps > 1000:
            n_steps = 0
            state = env.reset()
    return n_steps
        
# Experiment: train a fresh agent n_repetitions times for n_episodes each and
# plot the episode length averaged over repetitions.
n_repetitions = 5
n_episodes = 301

# One list of episode lengths per repetition
len_episodes_per_repetition = []

# range / print(...) are used so the cell runs on both Python 2 and Python 3
for i in range(n_repetitions):
    print(i)
    # run_episode reads this module-level `agent`
    agent = OptionCritic()
    len_episodes = []

    for j in range(n_episodes):
        print(j)
        n_steps = run_episode()
        len_episodes.append(n_steps)

    len_episodes_per_repetition.append(len_episodes)

# Average over repetitions -> one value per episode index
average_len_episodes = np.array(len_episodes_per_repetition).mean(axis=0)

plt.plot(range(n_episodes), average_len_episodes)
plt.xlabel("Episodes")
plt.ylabel("Steps per episode")
#plt.savefig("plots/option-critic_4options.png")
plt.show()

0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276


In [None]:
def visualizeTermination(option):
    """Show `option`'s termination probability for every cell of the 13x13
    grid as a heat map.

    Uses the global `agent` for state indexing and the global `env` for the
    grid mask: cells where env.grid < 0 are displayed as 0.
    NOTE(review): negative grid entries presumably mark walls -- confirm in
    four_rooms.
    """
    beta = []
    for i in range(13):
        row = []
        for j in range(13):
            s = agent._sIdx((i, j))
            # float() yields a plain scalar whether indexing the tensor
            # returns a Python float or a 0-dim tensor.
            row.append(float(option.beta(s).data[0]))
        beta.append(row)
    beta = np.array(beta) * (env.grid >= 0)
    plt.imshow(beta)
    plt.title("Termination Probability")
    plt.colorbar()
    plt.show()

def visualizePolicy(option):
    """Show `option`'s most probable action for every cell of the 13x13 grid.

    Actions are displayed as 1=left, 2=up, 3=right, 4=down; cells where
    env.grid < 0 are displayed as 0. Uses the global `agent` for state
    indexing and the global `env` for the grid mask.
    """
    pi = []
    for i in range(13):
        row = []
        for j in range(13):
            s = agent._sIdx((i, j))
            # Detach to a numpy array first: the policy output tracks
            # gradients and cannot be handed to numpy directly.
            probs = option.pi(s).data.numpy()
            row.append(int(np.argmax(probs)))
        pi.append(row)
    # Shift indices by one so that masked cells (0) stay distinguishable
    pi = (np.array(pi) + 1) * (env.grid >= 0)
    plt.imshow(pi)
    plt.title("Argmax Policy (1=left, 2=up, 3=right, 4=down)")
    plt.colorbar()
    plt.show()

# Plot the termination map and the argmax policy of every learned option
for learned_option in agent.options:
    visualizeTermination(learned_option)
    visualizePolicy(learned_option)