In [1]:
import multiprocessing as mp
import numpy as np
import gym
env = gym.make('Marvin-v0')

In [2]:
class NN:
    def __init__(self, layer_sizes, seed=None):
        if seed is not None:
            np.random.seed(seed)
        self.weights = [np.zeros((m, n)) * 1e-3 for m, n in zip(layer_sizes[1:], layer_sizes)]
    
    def predict(self, X):
        out = X
        for W in self.weights:
            Z = out @ W.T
            out = np.tanh(Z)
        if out.shape[0] == 1 and len(out.shape) == 1:
            return out.item()
        return out

    def set_weights(self, weights, copy=False):
        if copy:
            self.weights = [np.copy(l) for l in weights]
        else:
            self.weights = weights
        
    def get_weights(self, copy=False):
        if copy:
            return [np.copy(l) for l in self.weights]
        return self.weights
    
    def sample_like(self, sigma=1.0):
        return [np.random.randn(*l.shape) * sigma for l in self.weights]

In [21]:
def sample_like(weights, sigma=1):
    """
    Create a sample of the same shapes as the input
    @param weights: list of np.arrays
    """
    return [np.random.randn(*l.shape) * sigma for l in weights]

def combine_weights(params, delta_params, sigma):
    return [W + dW * sigma for W, dW in zip(params, delta_params)]
    
def update_params(params, population, rewards, lr=0.05, sigma=0.1):
    """
    Inplace update of parameters
    """
    n = len(population)
    for i in range(len(params)):
        W = params[i]
        dW_accum = np.zeros_like(W)

        for candidate, reward in zip(population, rewards):
            dW = candidate[0][i]
            dW_accum += reward * dW

        W_new = W + lr / (n * sigma) * dW_accum
        params[i] = W_new

    return params

def evaluate_weights(weights):
    global nn, env
    
    nn.set_weights(weights)

    observation = env.reset()
    done = False
    i = 0
    r_sum = 0
    while not done and i < 1500:
        action = nn.predict(observation)
        observation, reward, done, _ = env.step(action)
        i += 1
        r_sum += reward
    return r_sum

class ESMPSolver:
    def __init__(self, model, environment, population_size=30, max_episode_len=1500,
                 lr=0.05, lr_decay=0.999, sigma=0.1, verbose=False):
        self.model = model
        self.env = environment
        self.population_size = population_size
        self.max_episode_len = max_episode_len
        self.lr = lr
        self.lr_decay = lr_decay
        self.sigma = sigma
        self.verbose = verbose
        self.pool = mp.Pool(mp.cpu_count())
    
    def solve(self, weights=None, fitness_fn=None, n_generations=100, seed=None):
        """
        If weights is none, simple MLP is assumed, otherwise this should be the list of weights matrices from some model
        """
        if weights is None:
            weights = self.model.get_weights(copy=True)
        if fitness_fn is None:
            fitness_fn = self.evaluate_model

        if seed is not None:
            np.random.seed(seed)

        lr = self.lr
        for generation in range(n_generations):
    
            population = [[sample_like(weights)] for _ in range(self.population_size)]
            rewards = self.pool.starmap(evaluate_weights, population)

            rewards = np.array(rewards)
            r_mean, r_std = rewards.mean(), rewards.std()
            rewards = (rewards - r_mean) / r_std
            
            update_params(weights, population, rewards, lr=lr, sigma=self.sigma)
        
        
            lr = lr * self.lr_decay
            if self.verbose and (generation % int(self.verbose) == 0):
                print(f'[{generation}]: E[R]={r_mean:.4f}, std(R)={r_std:.4f} | lr={lr:.4f}')
        return weights


In [24]:
np.random.seed(42)
nn = NN([24, 24, 4])
es = ESMPSolver(nn, env, population_size=50, verbose=5)
weights = es.solve(n_generations=40)

[0]: E[R]=-123.8335, std(R)=20.7160 | lr=0.0500
[5]: E[R]=-113.8330, std(R)=36.1704 | lr=0.0497
[10]: E[R]=-123.9507, std(R)=22.4088 | lr=0.0495
[15]: E[R]=-121.3969, std(R)=16.5678 | lr=0.0492
[20]: E[R]=-123.3158, std(R)=18.7304 | lr=0.0490
[25]: E[R]=-127.4941, std(R)=24.0309 | lr=0.0487
[30]: E[R]=-129.4559, std(R)=24.4919 | lr=0.0485
[35]: E[R]=-128.5757, std(R)=21.2876 | lr=0.0482
