In [1]:
import numpy as np
from typing import Callable

In [2]:
def reward(params):
    """Score a parameter vector; higher is better.

    The score is the negated weighted sum of squared offsets of the first
    three entries of ``params`` from the point (1, 2, -1), with weights
    1, 0.5 and 0.25 respectively. The best possible score is 0, reached
    exactly at that point. Only indices 0, 1 and 2 are read.
    """
    first, second, third = params[0], params[1], params[2]
    # One penalty term per coordinate, summed in a fixed left-to-right order.
    penalty_first = np.power(first - 1.0, 2)
    penalty_second = 0.5 * np.power(second - 2.0, 2)
    penalty_third = 0.25 * np.power(third + 1.0, 2)
    return -(penalty_first + penalty_second + penalty_third)

In [4]:
def es_fit(reward: Callable, num_params: int, generations: int, num_populations: int, learning_rate: float, std_dev: float, seed: int = None):
    """Maximise ``reward`` with a simple evolution-strategies hill climber.

    Starting from an all-zero parameter vector, every generation draws
    ``num_populations`` Gaussian samples centred on the current parameters,
    scores each one with ``reward``, and proposes a single step along the
    advantage-weighted combination of the samples. The proposal replaces the
    current parameters only if it scores strictly higher, so the reward of
    the returned parameters never drops below the reward of the start point.

    Args:
        reward: Callable that takes a 1-D array of length ``num_params`` and
            returns a scalar score (higher is better).
        num_params: Length of the parameter vector being optimised.
        generations: Number of sample/score/update iterations to run.
        num_populations: Number of samples drawn per generation.
        learning_rate: Step-size multiplier applied to the update direction.
        std_dev: Standard deviation of the Gaussian sampling noise.
        seed: If not None, passed to ``np.random.seed`` before sampling.
            Note that this reseeds NumPy's global legacy random state.

    Returns:
        1-D ``numpy.ndarray`` of length ``num_params`` holding the best
        parameters found.
    """
    if seed is not None:
        np.random.seed(seed)
    params = np.zeros(num_params)
    for _ in range(generations):
        candidates = np.random.normal(params, std_dev, (num_populations, num_params))
        rewards = np.array([reward(candidate) for candidate in candidates])
        reward_spread = np.std(rewards)
        if reward_spread == 0:
            # Every candidate scored the same, so there is no direction to
            # follow. Skipping avoids a 0/0 normalisation that would yield
            # NaN advantages (and a NaN proposal handed to ``reward``).
            continue
        advantages = (rewards - np.mean(rewards)) / reward_spread
        # The advantages are mean-centred, so they sum to (numerically) zero.
        # Weighting the raw candidates is therefore equivalent to weighting
        # only their noise component: the contribution of the shared centre
        # ``params`` cancels out of the dot product.
        step = learning_rate / (num_populations * std_dev**2) * np.dot(candidates.T, advantages)
        proposal = params + step
        # Greedy acceptance: keep the proposal only on strict improvement.
        if reward(proposal) > reward(params):
            params = proposal
    return params


In [5]:
# Run the evolution-strategies fit on the 3-parameter objective, naming each
# hyper-parameter explicitly, then score and report the resulting parameters.
params_es = es_fit(
    reward,
    num_params=3,
    generations=200,
    num_populations=64,
    learning_rate=0.01,
    std_dev=0.1,
    seed=42,
)
reward_es = reward(params_es)
print(f"Estimated params: {params_es}, reward: {reward_es}")

Generation 1 - Params: [ 0.04566448  0.06538532 -0.03684726], reward: -3.014039066072064
Generation 2 - Params: [ 0.10288932  0.14198564 -0.06712199], reward: -2.7484816008405013
Generation 3 - Params: [ 0.17313423  0.22485549 -0.0831433 ], reward: -2.469432572159323
Generation 4 - Params: [ 0.23292273  0.29916218 -0.11460732], reward: -2.230812233234545
Generation 5 - Params: [ 0.28794425  0.37031912 -0.13292347], reward: -2.0229087110758903
Generation 6 - Params: [ 0.3521767   0.43413433 -0.14389151], reward: -1.828873119428116
Generation 7 - Params: [ 0.396962    0.51205119 -0.1392606 ], reward: -1.6558687501162785
Generation 8 - Params: [ 0.44904523  0.58064013 -0.13860682], reward: -1.4963419289081512
Generation 9 - Params: [ 0.50218648  0.66057864 -0.17284738], reward: -1.3158884483265796
Generation 10 - Params: [ 0.55591233  0.73711911 -0.18990651], reward: -1.1587107876108809
Generation 11 - Params: [ 0.6147349   0.81275674 -0.20585403], reward: -1.0108694283994508
Generation 1