In [6]:
"""
Taken from:
    https://gist.github.com/karpathy/77fbb6a8dac5395f1b73e7a89300318d
    
A bare bones examples of optimizing a black-box function (f) using
Natural Evolution Strategies (NES), where the parameter distribution is a 
gaussian of fixed standard deviation.
"""

import numpy as np
np.random.seed(0)

# the function we want to optimize
def f(w):
    # here we would normally:
    # ... 1) create a neural network with weights w
    # ... 2) run the neural network on the environment for some time
    # ... 3) sum up and return the total reward

    # but for the purposes of an example, lets try to minimize
    # the L2 distance to a specific solution vector. So the highest reward
    # we can achieve is 0, when the vector w is exactly equal to solution
    reward = -np.sum(np.square(solution - w))
    
    def rastrigin(*X, **kwargs):
        A = kwargs.get('A', 10)
        return A + sum([(x**2 - A * np.cos(2 * np.pi * x)) for x in X])
    
    return -rastrigin(w[0], w[1])

# hyperparameters
npop = 50 # population size
sigma = 0.51 # noise standard deviation
alpha = 0.001 # learning rate

# start the optimization
solution = np.array([0.5, 0.5])
n = len(solution)

w = np.random.randn(n) # our initial guess is random
for i in range(30000):

    # print current fitness of the most likely parameter setting
    if i % 20 == 0:
        print('iter %d. w: %s, solution: %s, reward: %f' % 
              (i, str(w), str(solution), f(w)))

    # initialize memory for a population of w's, and their rewards
    N = np.random.randn(npop, n) # samples from a normal distribution N(0,1)
    R = np.zeros(npop)
    for j in range(npop):
        w_try = w + sigma*N[j] # jitter w using gaussian of sigma 0.1
        R[j] = f(w_try) # evaluate the jittered version

    # standardize the rewards to have a gaussian distribution
    A = (R - np.mean(R)) / np.std(R)
    # perform the parameter update. The matrix multiply below
    # is just an efficient way to sum up all the rows of the noise matrix N,
    # where each row N[j] is weighted by A[j]
    w = w + alpha/(npop*sigma) * np.dot(N.T, A)

iter 0. w: [1.76405235 0.40015721], solution: [0.5 0.5], reward: -20.486190
iter 20. w: [1.75796799 0.39908414], solution: [0.5 0.5], reward: -20.805497
iter 40. w: [1.75025114 0.39658209], solution: [0.5 0.5], reward: -21.166962
iter 60. w: [1.74311885 0.39545125], solution: [0.5 0.5], reward: -21.545962
iter 80. w: [1.73601381 0.39373736], solution: [0.5 0.5], reward: -21.899098
iter 100. w: [1.72875583 0.39263115], solution: [0.5 0.5], reward: -22.283060
iter 120. w: [1.72028275 0.39006122], solution: [0.5 0.5], reward: -22.675464
iter 140. w: [1.71474041 0.38838605], solution: [0.5 0.5], reward: -22.928623
iter 160. w: [1.71040455 0.38657544], solution: [0.5 0.5], reward: -23.103399
iter 180. w: [1.70438556 0.38354407], solution: [0.5 0.5], reward: -23.319302
iter 200. w: [1.69931436 0.38229348], solution: [0.5 0.5], reward: -23.552502
iter 220. w: [1.69576049 0.38222818], solution: [0.5 0.5], reward: -23.748890
iter 240. w: [1.68947797 0.37995079], solution: [0.5 0.5], reward: -23

iter 2260. w: [1.13214347 0.19439763], solution: [0.5 0.5], reward: -1.149897
iter 2280. w: [1.12634489 0.1928934 ], solution: [0.5 0.5], reward: -0.783181
iter 2300. w: [1.12325389 0.1917828 ], solution: [0.5 0.5], reward: -0.573395
iter 2320. w: [1.11830317 0.19112628], solution: [0.5 0.5], reward: -0.309520
iter 2340. w: [1.11308943 0.18787481], solution: [0.5 0.5], reward: 0.110757
iter 2360. w: [1.1093677 0.1852446], solution: [0.5 0.5], reward: 0.422753
iter 2380. w: [1.10489368 0.18294312], solution: [0.5 0.5], reward: 0.741137
iter 2400. w: [1.09691044 0.18136225], solution: [0.5 0.5], reward: 1.146827
iter 2420. w: [1.09283402 0.1810643 ], solution: [0.5 0.5], reward: 1.316653
iter 2440. w: [1.08948081 0.17985557], solution: [0.5 0.5], reward: 1.507400
iter 2460. w: [1.08458944 0.18048325], solution: [0.5 0.5], reward: 1.641914
iter 2480. w: [1.0792283  0.18056911], solution: [0.5 0.5], reward: 1.814394
iter 2500. w: [1.07706904 0.17926808], solution: [0.5 0.5], reward: 1.9574

iter 4440. w: [0.75213339 0.0740895 ], solution: [0.5 0.5], reward: -1.501262
iter 4460. w: [0.75084848 0.07137572], solution: [0.5 0.5], reward: -1.504428
iter 4480. w: [0.74837316 0.06961754], solution: [0.5 0.5], reward: -1.608648
iter 4500. w: [0.74684547 0.069033  ], solution: [0.5 0.5], reward: -1.686763
iter 4520. w: [0.74487408 0.06806888], solution: [0.5 0.5], reward: -1.782221
iter 4540. w: [0.74561034 0.06868724], solution: [0.5 0.5], reward: -1.753346
iter 4560. w: [0.74258438 0.06914102], solution: [0.5 0.5], reward: -1.950862
iter 4580. w: [0.73938526 0.06907027], solution: [0.5 0.5], reward: -2.144922
iter 4600. w: [0.73893815 0.06864239], solution: [0.5 0.5], reward: -2.160958
iter 4620. w: [0.73507777 0.06872241], solution: [0.5 0.5], reward: -2.399123
iter 4640. w: [0.73376898 0.06828195], solution: [0.5 0.5], reward: -2.467433
iter 4660. w: [0.73340303 0.06759215], solution: [0.5 0.5], reward: -2.471729
iter 4680. w: [0.73177386 0.06643117], solution: [0.5 0.5], rewa

iter 6620. w: [0.55976778 0.01591179], solution: [0.5 0.5], reward: -9.666654
iter 6640. w: [0.55967365 0.01471447], solution: [0.5 0.5], reward: -9.661453
iter 6660. w: [0.5585709  0.01640621], solution: [0.5 0.5], reward: -9.695799
iter 6680. w: [0.55745022 0.01722196], solution: [0.5 0.5], reward: -9.725081
iter 6700. w: [0.55632726 0.01881674], solution: [0.5 0.5], reward: -9.759895
iter 6720. w: [0.5524191  0.01959367], solution: [0.5 0.5], reward: -9.843735
iter 6740. w: [0.55180898 0.01970166], solution: [0.5 0.5], reward: -9.856231
iter 6760. w: [0.54948073 0.02051561], solution: [0.5 0.5], reward: -9.905912
iter 6780. w: [0.54961168 0.01917606], solution: [0.5 0.5], reward: -9.893015
iter 6800. w: [0.54746627 0.01698998], solution: [0.5 0.5], reward: -9.915486
iter 6820. w: [0.54700558 0.01671626], solution: [0.5 0.5], reward: -9.921620
iter 6840. w: [0.54690488 0.01646992], solution: [0.5 0.5], reward: -9.921731
iter 6860. w: [0.54752668 0.01465556], solution: [0.5 0.5], rewa

iter 8740. w: [0.36695407 0.00308816], solution: [0.5 0.5], reward: -6.841262
iter 8760. w: [0.36516466 0.0020803 ], solution: [0.5 0.5], reward: -6.755080
iter 8780. w: [0.36147734 0.00107778], solution: [0.5 0.5], reward: -6.576383
iter 8800. w: [0.35900892 0.00072216], solution: [0.5 0.5], reward: -6.455127
iter 8820. w: [0.35866262 0.00117594], solution: [0.5 0.5], reward: -6.438183
iter 8840. w: [ 3.55079049e-01 -1.64053665e-04], solution: [0.5 0.5], reward: -6.259081
iter 8860. w: [0.35485495 0.00048349], solution: [0.5 0.5], reward: -6.247835
iter 8880. w: [0.35130454 0.00040348], solution: [0.5 0.5], reward: -6.067414
iter 8900. w: [0.35116235 0.00153644], solution: [0.5 0.5], reward: -6.060563
iter 8920. w: [0.34942523 0.00173186], solution: [0.5 0.5], reward: -5.971291
iter 8940. w: [0.34805903 0.00337104], solution: [0.5 0.5], reward: -5.902154
iter 8960. w: [0.34573275 0.0035856 ], solution: [0.5 0.5], reward: -5.780935
iter 8980. w: [0.34450438 0.00458709], solution: [0.5 

iter 10820. w: [0.1898057  0.00817773], solution: [0.5 0.5], reward: 3.643303
iter 10840. w: [0.18728549 0.00831221], solution: [0.5 0.5], reward: 3.790503
iter 10860. w: [0.18583857 0.0059559 ], solution: [0.5 0.5], reward: 3.881496
iter 10880. w: [0.18481849 0.00653915], solution: [0.5 0.5], reward: 3.939303
iter 10900. w: [0.18441116 0.00627133], solution: [0.5 0.5], reward: 3.963597
iter 10920. w: [0.18244746 0.00598333], solution: [0.5 0.5], reward: 4.077763
iter 10940. w: [0.18200972 0.00629857], solution: [0.5 0.5], reward: 4.102203
iter 10960. w: [0.1799698  0.00555561], solution: [0.5 0.5], reward: 4.220998
iter 10980. w: [0.17578393 0.00445556], solution: [0.5 0.5], reward: 4.461125
iter 11000. w: [0.17587208 0.00478848], solution: [0.5 0.5], reward: 4.455535
iter 11020. w: [0.17398529 0.00563024], solution: [0.5 0.5], reward: 4.560060
iter 11040. w: [0.17115087 0.00326868], solution: [0.5 0.5], reward: 4.722632
iter 11060. w: [0.16971283 0.00360652], solution: [0.5 0.5], rew

iter 12940. w: [ 0.0883776  -0.00399326], solution: [0.5 0.5], reward: 8.486487
iter 12960. w: [ 0.08750579 -0.0045504 ], solution: [0.5 0.5], reward: 8.514447
iter 12980. w: [ 0.08683487 -0.00269968], solution: [0.5 0.5], reward: 8.539177
iter 13000. w: [ 0.0849989  -0.00225306], solution: [0.5 0.5], reward: 8.599223
iter 13020. w: [ 0.08335776 -0.00327303], solution: [0.5 0.5], reward: 8.650413
iter 13040. w: [ 0.08400778 -0.00342915], solution: [0.5 0.5], reward: 8.629598
iter 13060. w: [ 0.08539724 -0.00397513], solution: [0.5 0.5], reward: 8.584261
iter 13080. w: [ 0.08586155 -0.00761291], solution: [0.5 0.5], reward: 8.560870
iter 13100. w: [ 0.08477402 -0.00775913], solution: [0.5 0.5], reward: 8.595511
iter 13120. w: [ 0.08647385 -0.00753439], solution: [0.5 0.5], reward: 8.541174
iter 13140. w: [ 0.08457245 -0.00657437], solution: [0.5 0.5], reward: 8.605338
iter 13160. w: [ 0.08331526 -0.00843627], solution: [0.5 0.5], reward: 8.639764
iter 13180. w: [ 0.08278337 -0.00622448]

iter 15120. w: [0.04263208 0.01364297], solution: [0.5 0.5], reward: 9.604659
iter 15140. w: [0.04075723 0.0136541 ], solution: [0.5 0.5], reward: 9.635264
iter 15160. w: [0.03981101 0.01476367], solution: [0.5 0.5], reward: 9.643981
iter 15180. w: [0.04093327 0.01465509], solution: [0.5 0.5], reward: 9.626828
iter 15200. w: [0.03916764 0.01529692], solution: [0.5 0.5], reward: 9.650784
iter 15220. w: [0.03742057 0.01242201], solution: [0.5 0.5], reward: 9.692865
iter 15240. w: [0.03756519 0.01200478], solution: [0.5 0.5], reward: 9.692753
iter 15260. w: [0.03754614 0.01292961], solution: [0.5 0.5], reward: 9.688464
iter 15280. w: [0.03846884 0.01315203], solution: [0.5 0.5], reward: 9.673531
iter 15300. w: [0.03852665 0.01214521], solution: [0.5 0.5], reward: 9.677704
iter 15320. w: [0.03858901 0.01266101], solution: [0.5 0.5], reward: 9.674223
iter 15340. w: [0.03667552 0.01121802], solution: [0.5 0.5], reward: 9.709361
iter 15360. w: [0.0352993  0.01078993], solution: [0.5 0.5], rew

iter 17260. w: [ 0.02111767 -0.00759956], solution: [0.5 0.5], reward: 9.900199
iter 17280. w: [ 0.01820879 -0.00751303], solution: [0.5 0.5], reward: 9.923096
iter 17300. w: [ 0.01874635 -0.00829401], solution: [0.5 0.5], reward: 9.916716
iter 17320. w: [ 0.01625581 -0.00696767], solution: [0.5 0.5], reward: 9.937990
iter 17340. w: [ 0.01818415 -0.00770769], solution: [0.5 0.5], reward: 9.922686
iter 17360. w: [ 0.01616302 -0.00805894], solution: [0.5 0.5], reward: 9.935334
iter 17380. w: [ 0.01605987 -0.00879991], solution: [0.5 0.5], reward: 9.933515
iter 17400. w: [ 0.01488893 -0.01056998], solution: [0.5 0.5], reward: 9.933895
iter 17420. w: [ 0.01517246 -0.01200924], solution: [0.5 0.5], reward: 9.925765
iter 17440. w: [ 0.01484367 -0.01261843], solution: [0.5 0.5], reward: 9.924746
iter 17460. w: [ 0.0153886  -0.01161112], solution: [0.5 0.5], reward: 9.926320
iter 17480. w: [ 0.01377646 -0.01281766], solution: [0.5 0.5], reward: 9.929794
iter 17500. w: [ 0.01012383 -0.01266989]

iter 19460. w: [ 0.00656672 -0.00312453], solution: [0.5 0.5], reward: 9.989509
iter 19480. w: [ 0.00857192 -0.00279673], solution: [0.5 0.5], reward: 9.983874
iter 19500. w: [ 0.01034469 -0.00174265], solution: [0.5 0.5], reward: 9.978174
iter 19520. w: [ 0.01088955 -0.00188295], solution: [0.5 0.5], reward: 9.975780
iter 19540. w: [ 0.01141895 -0.00284471], solution: [0.5 0.5], reward: 9.972537
iter 19560. w: [ 0.01171483 -0.00340599], solution: [0.5 0.5], reward: 9.970484
iter 19580. w: [ 0.00975855 -0.00471266], solution: [0.5 0.5], reward: 9.976707
iter 19600. w: [ 0.00960744 -0.00577252], solution: [0.5 0.5], reward: 9.975083
iter 19620. w: [ 0.01015101 -0.00588587], solution: [0.5 0.5], reward: 9.972692
iter 19640. w: [ 0.0125069  -0.00478737], solution: [0.5 0.5], reward: 9.964436
iter 19660. w: [ 0.01233322 -0.00250288], solution: [0.5 0.5], reward: 9.968595
iter 19680. w: [ 0.01213971 -0.00495666], solution: [0.5 0.5], reward: 9.965903
iter 19700. w: [ 0.01158656 -0.00699831]

iter 21620. w: [0.01972825 0.00347437], solution: [0.5 0.5], reward: 9.920489
iter 21640. w: [0.02033167 0.00359507], solution: [0.5 0.5], reward: 9.915536
iter 21660. w: [0.02008151 0.00054847], solution: [0.5 0.5], reward: 9.920041
iter 21680. w: [ 0.02025183 -0.00034851], solution: [0.5 0.5], reward: 9.918717
iter 21700. w: [0.01766447 0.0002306 ], solution: [0.5 0.5], reward: 9.938148
iter 21720. w: [ 1.47917295e-02 -3.63006366e-05], solution: [0.5 0.5], reward: 9.956624
iter 21740. w: [ 0.01332667 -0.00115426], solution: [0.5 0.5], reward: 9.964522
iter 21760. w: [ 0.01350029 -0.00106705], solution: [0.5 0.5], reward: 9.963637
iter 21780. w: [ 0.01426755 -0.00082058], solution: [0.5 0.5], reward: 9.959508
iter 21800. w: [0.01310242 0.00139199], solution: [0.5 0.5], reward: 9.965576
iter 21820. w: [0.01262903 0.00067282], solution: [0.5 0.5], reward: 9.968285
iter 21840. w: [1.38602013e-02 7.77107833e-06], solution: [0.5 0.5], reward: 9.961912
iter 21860. w: [ 0.01299922 -0.0011794

iter 23700. w: [0.01223495 0.00256456], solution: [0.5 0.5], reward: 9.969012
iter 23720. w: [0.0128377  0.00229038], solution: [0.5 0.5], reward: 9.966281
iter 23740. w: [ 0.01413395 -0.00086533], solution: [0.5 0.5], reward: 9.960245
iter 23760. w: [ 0.01409828 -0.00096748], solution: [0.5 0.5], reward: 9.960407
iter 23780. w: [ 0.01410176 -0.00175242], solution: [0.5 0.5], reward: 9.959964
iter 23800. w: [ 0.0132984  -0.00218683], solution: [0.5 0.5], reward: 9.963986
iter 23820. w: [ 0.01460929 -0.00156978], solution: [0.5 0.5], reward: 9.957198
iter 23840. w: [ 0.0155204  -0.00044731], solution: [0.5 0.5], reward: 9.952209
iter 23860. w: [0.0173189  0.00094581], solution: [0.5 0.5], reward: 9.940374
iter 23880. w: [ 0.01886739 -0.00021744], solution: [0.5 0.5], reward: 9.929450
iter 23900. w: [ 0.01782889 -0.00053458], solution: [0.5 0.5], reward: 9.936946
iter 23920. w: [ 0.01832299 -0.00147396], solution: [0.5 0.5], reward: 9.933036
iter 23940. w: [ 0.01796942 -0.00303709], solu

iter 26000. w: [0.00282911 0.02251999], solution: [0.5 0.5], reward: 9.897965
iter 26020. w: [0.00204022 0.02346976], solution: [0.5 0.5], reward: 9.890091
iter 26040. w: [0.0016328 0.0206189], solution: [0.5 0.5], reward: 9.915244
iter 26060. w: [0.00041194 0.02172453], solution: [0.5 0.5], reward: 9.906479
iter 26080. w: [0.0011684  0.02211583], solution: [0.5 0.5], reward: 9.902849
iter 26100. w: [0.00121562 0.02198833], solution: [0.5 0.5], reward: 9.903939
iter 26120. w: [0.0008761  0.02164308], solution: [0.5 0.5], reward: 9.907059
iter 26140. w: [-0.00185135  0.02350386], solution: [0.5 0.5], reward: 9.889920
iter 26160. w: [-0.00257008  0.02294499], solution: [0.5 0.5], reward: 9.894421
iter 26180. w: [-0.00299103  0.02234739], solution: [0.5 0.5], reward: 9.899309
iter 26200. w: [-0.00343573  0.02042208], solution: [0.5 0.5], reward: 9.915029
iter 26220. w: [-0.00340351  0.02120283], solution: [0.5 0.5], reward: 9.908644
iter 26240. w: [-0.00334165  0.02062787], solution: [0.5

iter 28080. w: [-0.00109773  0.01485196], solution: [0.5 0.5], reward: 9.956031
iter 28100. w: [-0.00100473  0.012476  ], solution: [0.5 0.5], reward: 9.968936
iter 28120. w: [0.00065022 0.01199532], solution: [0.5 0.5], reward: 9.971383
iter 28140. w: [-0.00061581  0.01277876], solution: [0.5 0.5], reward: 9.967545
iter 28160. w: [-0.00213181  0.0130451 ], solution: [0.5 0.5], reward: 9.965356
iter 28180. w: [-5.68342039e-05  1.27706885e-02], solution: [0.5 0.5], reward: 9.967661
iter 28200. w: [0.00163403 0.01118358], solution: [0.5 0.5], reward: 9.974667
iter 28220. w: [0.00196296 0.01124713], solution: [0.5 0.5], reward: 9.974150
iter 28240. w: [-0.00016697  0.01254492], solution: [0.5 0.5], reward: 9.968789
iter 28260. w: [-0.00151422  0.01266297], solution: [0.5 0.5], reward: 9.967749
iter 28280. w: [-0.00150085  0.01300509], solution: [0.5 0.5], reward: 9.966017
iter 28300. w: [0.00146337 0.00987016], solution: [0.5 0.5], reward: 9.980254
iter 28320. w: [0.00251119 0.0109707 ], 