In [1]:
from multilayer_perceptron.mlp.nn import *
import gym
env = gym.make('CartPole-v1')

In [2]:
class NN:
    def __init__(self, layer_sizes, seed=None):
        if seed is not None:
            np.random.seed(seed)
        self.l = [(np.random.randn(m, n), np.zeros((m, 1))) for m, n in zip(layer_sizes[1:], layer_sizes)]
    
    def predict(self, X):
        out = X
        for W, b in self.l:
            Z = X @ W.T + b.T
            out = np.tanh(Z)
        if out.shape[0] == 1:
            return out.item()
        return out

    def set_params(self, params):
        self.l = params

In [3]:
nn = NN([4, 4, 1])

In [4]:
nn.predict(env.observation_space.sample().reshape(1, -1))

1.0

In [5]:
env.action_space

Discrete(2)

In [6]:
env.action_space.sample()

1

In [7]:
def deep_copy(params):
    return [(np.copy(W), np.copy(b)) for (W, b) in params]

In [8]:
def params_perturbation(params, sigma=0.1, seed=None):
    """
    Obtain weights perturbation for the whole network architecture
    """
    if seed is not None:
        np.random.seed(42)
    return [(np.random.randn(*W.shape) * sigma, np.random.randn(*b.shape) * sigma) for W, b in params]

In [9]:
def combine_weights(params, delta_params):
    return [(W + dW, b + db) for ((W, b), (dW, db)) in zip(params, delta_params)]

In [10]:
def evaluate_model(nn, env, max_iter=1000, verbose=False):
    observation = env.reset()
    done = False
    i = 0
    r_sum = 0
    while not done and i < max_iter:
        observation, reward, done, _ = env.step(int(nn.predict(observation) > 0))
        i += 1
        r_sum += reward
    if verbose:
        print(f"Episode end after {i} iterations with reward = {r_sum} and done status {done}")
    return r_sum

In [11]:
n = 20
lr = 0.03
sigma = 0.1

In [12]:
def update_params(params, population, rewards, learning_rate=0.05, sigma=0.1):
    """
    Inplace update of parameters
    """
    n = len(population)
    for candidate, reward in zip(population, rewards):
        for i in range(len(params)):
            W, b = params[i]
            dW, db = candidate[i]
            W_new =  W + learning_rate / (n * sigma) * reward * dW
#             b_new = b + learning_rate / (n * sigma) * reward * db
            b_new = b
            params[i] = (W_new, b_new)

In [13]:
def generation_update(model, environment, sigma=0.1, lr=0.01, population_size=10, seed=None, normalize_rewards=True):
    original_params = deep_copy(model.l)
    
    if seed is not None:
        np.random.seed(seed)

    population = []
    rewards = []
    for i in range(population_size):
        candidate = params_perturbation(original_params, sigma=sigma)
        modified_params = combine_weights(original_params, candidate)
        model.set_params(modified_params)
        reward = evaluate_model(model, environment)
        population.append(candidate)
        rewards.append(reward)

    rewards = np.array(rewards)
    r_mean, r_std = rewards.mean(), rewards.std()
    if normalize_rewards:
        rewards = (rewards - r_mean) / (r_std + 1e-9)
        
    update_params(original_params, population, rewards, learning_rate=lr, sigma=sigma)
    model.set_params(deep_copy(original_params))
    return r_mean, r_std

In [14]:
def render_env(model, env, max_iter=None, verbose=True):
    observation = env.reset()
    done = False
    i = 0
    r_sum = 0
    while not done or (max_iter is not None and i < max_iter):
        env.render()
        observation, reward, done, _ = env.step(int(nn.predict(observation) > 0))
        i += 1
        r_sum += reward
    if verbose:
        print(f"Episode end after {i} iterations with reward = {r_sum} and done status {done}")

In [17]:
for i in range(100):
    mean_rewards, std_rewards = generation_update(nn, env, population_size=20, sigma=0.1, lr=0.03, normalize_rewards=True)
    if i % 10 == 0:
        print(i, mean_rewards, std_rewards)
    if mean_rewards > 1000:
        break

0 455.55 169.73994079178885
10 635.2 203.71808952569725
20 621.6 252.29137916306217
30 621.8 272.0070954956874
40 793.6 236.95092318874808
50 748.5 289.7128751022295
60 898.7 178.03906874616027
70 971.65 96.54805798150474
80 950.9 128.34597773206607
90 975.0 95.5557428938732


In [18]:
render_env(nn, env)

KeyboardInterrupt: 