<a href="https://colab.research.google.com/github/skozh/RL/blob/master/ES_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Simple example: Minimize a quadratic around some solution point
# Code from https://gist.github.com/karpathy/77fbb6a8dac5395f1b73e7a89300318d
# Explanation at https://openai.com/blog/evolution-strategies/
import numpy as np

In [2]:
solution = np.array([0.5, 0.1, -0.3])

In [3]:
def fun(w):
  return -np.sum((w - solution)**2)

In [4]:
npop = 50                                # population size
sigma = 0.1                             # noise standard deviation
alpha = 0.001                         # learning rate
w = np.random.randn(3)    # Initial guess (random)

In [5]:
for i in range(300):
  N = np.random.randn(npop, 3)                                  # Random noise initialization
  R = np.zeros(npop)                                                        # Reward initialization
  for j in range(npop):
    w_try = w + sigma* N[j]                                              # Add noise to the guessed parameter
    R[j] = fun(w_try)
  if i % 50 == 0:
    print ('iter %d. w: %s, solution: %s, reward: %f' % 
           (i, str(w), str(solution), R[j]))
  A = (R - np.mean(R))/np.std(R)                                    # Standardize Rewards to Gaussian Distribution
  w = w + alpha/(npop * sigma) * np.dot(N.T, A)      # Update Parameter. 

iter 0. w: [ 0.35219536 -0.01436819 -0.70463822], solution: [ 0.5  0.1 -0.3], reward: -0.328839
iter 50. w: [ 0.49516293  0.08896741 -0.34004539], solution: [ 0.5  0.1 -0.3], reward: -0.032393
iter 100. w: [ 0.50034543  0.10111892 -0.30188364], solution: [ 0.5  0.1 -0.3], reward: -0.088718
iter 150. w: [ 0.50226112  0.09958874 -0.30119221], solution: [ 0.5  0.1 -0.3], reward: -0.049371
iter 200. w: [ 0.50394589  0.09434979 -0.2976531 ], solution: [ 0.5  0.1 -0.3], reward: -0.032263
iter 250. w: [ 0.50237112  0.09294307 -0.30613214], solution: [ 0.5  0.1 -0.3], reward: -0.006015


In [6]:
w

array([ 0.5105021 ,  0.09298527, -0.29661409])