In [None]:
import gym
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np

In [None]:
%matplotlib inline

In [None]:
# Shared CartPole-v1 environment; the run functions in this notebook read it as the global ENV.
ENV = gym.make('CartPole-v1')
# The value returned by this reset is not stored; the run functions call ENV.reset() again themselves.
ENV.reset()

In [None]:
def initialize_random_weights(mean, std, size=4):
    """Draw an initial weight vector from a normal distribution.

    Args:
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
        size: number of weights to draw. Defaults to 4, the value that was
            previously hard-coded, so existing two-argument calls are unchanged.

    Returns:
        A numpy array with `size` samples drawn via np.random.normal.
    """
    return np.random.normal(mean, std, size)

In [None]:
def sigmoid(weights, observation):
    """Logistic function of the weighted sum of the observation.

    Only the first len(weights) entries of `observation` take part in the sum.

    Returns:
        1 / (1 + exp(-sum_i weights[i] * observation[i])).
    """
    activation = 0
    for idx in range(len(weights)):
        activation += weights[idx] * observation[idx]
    return 1.0 / (1 + np.exp(-activation))

def grad_log_sigmoid(weights, observation, action):
    """Gradient, with respect to the weights, of the log-probability of `action`.

    The probability of action 1 is sigmoid(weights, observation); any other
    action value is treated as the complementary outcome.

    Returns:
        observation * (1 - p) when action == 1, otherwise -observation * p,
        where p = sigmoid(weights, observation).
    """
    p_one = sigmoid(weights, observation)
    if action != 1:
        return -observation * p_one
    return observation * (1 - p_one)

def get_action(weights, observation):
    """Sample a binary action: 1 with probability sigmoid(weights, observation), else 0."""
    p_one = sigmoid(weights, observation)
    draw = np.random.random()
    return 1 if draw <= p_one else 0

In [None]:
def one_cartpole_run(weights, max_steps=1000):
    """Play one episode on the global ENV and accumulate the log-probability gradient.

    Args:
        weights: policy weights passed to get_action / grad_log_sigmoid.
        max_steps: upper bound on the number of steps (default 1000, the
            previously hard-coded limit).

    Returns:
        (cum_reward, grad_log_sum): the sum of the rewards returned by ENV.step
        and the sum over steps of grad_log_sigmoid for the sampled actions.
    """
    observation = ENV.reset()
    cum_reward = 0
    grad_log_sum = np.zeros(len(weights))
    for _ in range(max_steps):
        action = get_action(weights, observation)
        # The gradient must be taken at the observation the action was sampled
        # from, so accumulate it before ENV.step replaces `observation` with
        # the next one.
        grad_log_sum += grad_log_sigmoid(weights, observation, action)
        observation, reward, done, _info = ENV.step(action)
        cum_reward += reward
        if done:
            break
    return cum_reward, grad_log_sum

In [None]:
def record_cartpole_run(weights, max_steps=1000):
    """Play one episode on the global ENV, recording observations and actions.

    Args:
        weights: policy weights passed to get_action.
        max_steps: upper bound on the number of steps (default 1000, the
            previously hard-coded limit).

    Returns:
        (observations, actions): `observations` is a float array with one row
        per step holding the observation the action was sampled from;
        `actions` is the list of sampled actions, aligned row for row. The
        observation returned by the last ENV.step is not included.
    """
    observation = ENV.reset()

    visited = []
    all_actions = []
    for _ in range(max_steps):
        # Store a float copy of the observation seen *before* acting; collecting
        # rows in a list avoids re-allocating the whole array on every step.
        visited.append(np.array(observation, dtype=float))

        action = get_action(weights, observation)
        all_actions.append(action)

        observation, _reward, done, _info = ENV.step(action)
        if done:
            break

    return np.vstack(visited), all_actions

In [None]:
def get_grad_reward(weights, obs, actions):
    """Sum the per-step log-probability gradients, each weighted by the steps remaining.

    Step i contributes grad_log_sigmoid(weights, obs[i, :], actions[i]) scaled
    by len(actions) - i (presumably the reward-to-go, assuming a reward of 1
    per step -- TODO confirm against the environment).

    Args:
        weights: policy weights.
        obs: 2-D array indexed as obs[i, :], one row per step.
        actions: sequence of actions, aligned with the rows of `obs`.

    Returns:
        A float array with len(weights) entries.
    """
    n_steps = len(actions)
    grad_reward = np.zeros(len(weights))
    # `range` (not `xrange`) matches the rest of the notebook and also exists on Python 3.
    for i in range(n_steps):
        grad_reward += grad_log_sigmoid(weights, obs[i, :], actions[i]) * (n_steps - i)
    return grad_reward

In [None]:
# Training configuration.
batch_n = 10 ** 2            # number of weight updates
grad_sample = 100            # episodes averaged per update
weights = initialize_random_weights(0, 1)
learning_rate = 10 ** (-2)
beta = 1                     # coefficient of the weight-shrinkage term in the update
current_score = [0] * (batch_n * grad_sample)  # episode length of every episode played

for i in range(batch_n):
    if i % 10 == 0:
        # Single-argument print with explicit formatting produces the same
        # text on Python 2 and Python 3.
        print("%s %s" % (i, weights))

    # Average the gradient estimate over `grad_sample` recorded episodes.
    avg_grad_log_sum = np.zeros_like(weights)
    for k in range(grad_sample):
        obs, actions = record_cartpole_run(weights)
        avg_grad_log_sum += get_grad_reward(weights, obs, actions)
        current_score[i * grad_sample + k] = len(actions)

    avg_grad_log_sum /= grad_sample

    # Step along the averaged gradient, minus beta * weights.
    weights += learning_rate * (avg_grad_log_sum - beta * weights)

print("%s %s" % (weights, current_score[-1]))

In [None]:
# Mean length of the last 100 episodes; dividing by a float keeps the average
# from being floored by integer division on Python 2.
sum(current_score[-100:]) / 100.0

In [None]:
# Episode length of every recorded episode, in the order they were played.
# (The curve was previously drawn twice by two identical plt.plot calls.)
plt.plot(range(len(current_score)), current_score)
plt.xlabel('episode')
plt.ylabel('episode length (steps)')
plt.show()

## Troubleshooting gradient estimation

In [None]:
def log_prob_actions(weights, obs, actions):
    """Total log-probability of the recorded actions under the sigmoid policy.

    For step i the probability of action 1 is sigmoid(weights, obs[i, :]);
    any other action value is scored with the complementary probability.

    Returns:
        The sum of the per-step log-probabilities (0 when `actions` is empty).
    """
    def _step_log_prob(step):
        p_one = sigmoid(weights, obs[step, :])
        return np.log(p_one if actions[step] == 1 else 1 - p_one)

    return sum(_step_log_prob(step) for step in range(len(actions)))

def grad_log_prob_actions(weights, obs, actions):
    """Sum of grad_log_sigmoid over all recorded steps.

    Row obs[step, :] is paired with actions[step].

    Returns:
        A length-4 float array.
    """
    total = np.zeros(4)
    for step, action in enumerate(actions):
        total += grad_log_sigmoid(weights, obs[step, :], action)
    return total

In [None]:
# Fixed all-ones weights (a plain list; this rebinds the name `weights`).
weights = [1, 1, 1, 1]
# Record one episode under these weights: the per-step observations and the actions sampled for them.
obs, actions = record_cartpole_run(weights)

In [None]:
# Compare the analytic gradient of the episode log-probability with forward
# finite differences, one weight at a time.
t = log_prob_actions(weights, obs, actions)  # baseline log-probability
print(t)
d = 0.0001  # finite-difference step
print(grad_log_prob_actions(weights, obs, actions))
for idx in range(len(weights)):
    bumped = list(weights)
    bumped[idx] += d
    # Parenthesised so the quotient is what gets printed on both Python 2 and 3.
    print((log_prob_actions(bumped, obs, actions) - t) / d)