# Baseline is taken from https://github.com/alirezamika/bipedal-es

Нужно переписать на стратегию актор-критик

In [18]:
# rewritten

import numpy as np


class Model(object):

    def __init__(self):
        self.actor_weights = [np.zeros(shape=(24, 16)), np.zeros(shape=(16, 16)), np.zeros(shape=(16, 4))]
        self.critic_weights = [np.zeros(shape=(24, 12)), np.zeros(shape=(12, 4)), np.zeros(shape=(4, 1))]

    def predict(self, inp):
        #print(inp.shape)
        out = np.expand_dims(inp.flatten(), 0)
        out = out / np.linalg.norm(out)
        for layer in self.actor_weights:
            out = np.dot(out, layer)
#         noise = np.expand_dims(inp.flatten(), 0)
#         for layer in self.critic_weights:
#             noise = np.dot(noise, layer)
#         out = out + np.random.normal(size=out.shape) * noise
        return out[0]

    def get_weights(self):
        return np.stack([self.actor_weights, self.critic_weights], axis=0)

    def set_weights(self, weights):
        self.actor_weights = weights[0]
        self.critic_weights = weights[1]

In [19]:
np.stack([np.ones((1,)), np.zeros((1,))], axis=0)

array([[ 1.],
       [ 0.]])

In [23]:
import random
import cPickle as pickle
import numpy as np
from evostra import EvolutionStrategy
import gym


class Agent:

    AGENT_HISTORY_LENGTH = 1
    POPULATION_SIZE = 20
    EPS_AVG = 1
    SIGMA = 0.1
    LEARNING_RATE = 0.01
    INITIAL_EXPLORATION = 1.0
    FINAL_EXPLORATION = 0.0
    EXPLORATION_DEC_STEPS = 1000000

    def __init__(self):
        self.env = gym.make('BipedalWalker-v2')
        self.model = Model()
        self.es = EvolutionStrategy(self.model.get_weights(), self.get_reward, self.POPULATION_SIZE, self.SIGMA, self.LEARNING_RATE)
        self.exploration = self.INITIAL_EXPLORATION


    def get_predicted_action(self, sequence):
        prediction = self.model.predict(np.array(sequence))
        return prediction


    def load(self, filename='weights.pkl'):
        with open(filename,'rb') as fp:
            self.model.set_weights(pickle.load(fp))
        self.es.weights = self.model.get_weights()


    def save(self, filename='weights.pkl'):
        with open(filename, 'wb') as fp:
            pickle.dump(self.es.get_weights(), fp)


    def play(self, episodes, render=True):
        self.model.set_weights(self.es.weights)
        for episode in xrange(episodes):
            total_reward = 0
            observation = self.env.reset()
            done = False
            while not done:
                if render:
                    self.env.render()
                action = self.get_predicted_action(observation)
                observation, reward, done, _ = self.env.step(action)
                total_reward += reward
            print "total reward:", total_reward


    def train(self, iterations):
        self.es.run(iterations, print_step=1)


    def get_reward(self, weights):
        total_reward = 0.0
        self.model.set_weights(weights)

        for episode in xrange(self.EPS_AVG):
            observation = self.env.reset()
            done = False
            while not done:
                self.exploration = max(self.FINAL_EXPLORATION, self.exploration - self.INITIAL_EXPLORATION/self.EXPLORATION_DEC_STEPS)
                if random.random() < self.exploration:
                    action = self.env.action_space.sample()
                else:
                    action = self.get_predicted_action(observation)
                observation, reward, done, _ = self.env.step(action)
                total_reward += reward
                #print(sequence)

        return total_reward/self.EPS_AVG

In [24]:
agent = Agent()

In [25]:
agent.train(1000)

iter 0. reward: -98.749470
iter 1. reward: -76.182088
iter 2. reward: -101.655767
iter 3. reward: -108.785070
iter 4. reward: -112.002161
iter 5. reward: -112.182596
iter 6. reward: -124.277723
iter 7. reward: -97.477678
iter 8. reward: -71.123917
iter 9. reward: -67.063837
iter 10. reward: -102.752797
iter 11. reward: -101.662820
iter 12. reward: -102.387277
iter 13. reward: -117.643567
iter 14. reward: -71.019267
iter 15. reward: -101.076937
iter 16. reward: -64.083642
iter 17. reward: -100.204420
iter 18. reward: -99.687220
iter 19. reward: -99.165956
iter 20. reward: -104.743435
iter 21. reward: -63.409351
iter 22. reward: -106.851126
iter 23. reward: -105.094320
iter 24. reward: -54.989421
iter 25. reward: -53.593474
iter 26. reward: -98.138456
iter 27. reward: -109.923130
iter 28. reward: -102.149149
iter 29. reward: -99.059544
iter 30. reward: -96.836319
iter 31. reward: -99.764193
iter 32. reward: -118.369020
iter 33. reward: -114.254315
iter 34. reward: -115.902191
iter 35. re

iter 285. reward: -92.046187
iter 286. reward: -91.810150
iter 287. reward: -91.847225
iter 288. reward: -3.645253
iter 289. reward: -22.671050
iter 290. reward: -20.484860
iter 291. reward: -31.905810
iter 292. reward: -40.991029
iter 293. reward: -27.417733
iter 294. reward: -39.710099
iter 295. reward: -25.653333
iter 296. reward: -14.576693
iter 297. reward: -4.702676
iter 298. reward: -10.528379
iter 299. reward: -3.705623
iter 300. reward: -13.349223
iter 301. reward: -14.320683
iter 302. reward: -13.109111
iter 303. reward: -91.854850
iter 304. reward: -91.852677
iter 305. reward: -99.656133
iter 306. reward: -92.017838
iter 307. reward: -91.998660
iter 308. reward: -91.990054
iter 309. reward: -99.854183
iter 310. reward: -111.956172
iter 311. reward: -99.852021
iter 312. reward: -99.747368
iter 313. reward: -91.893680
iter 314. reward: -91.839363
iter 315. reward: -91.923892
iter 316. reward: -91.900967
iter 317. reward: -2.487101
iter 318. reward: -91.951752
iter 319. reward:

iter 567. reward: -91.660883
iter 568. reward: -98.946684
iter 569. reward: -111.659142
iter 570. reward: -91.826987
iter 571. reward: -91.978096
iter 572. reward: -92.005530
iter 573. reward: -3.381182
iter 574. reward: -91.691942
iter 575. reward: -3.695362
iter 576. reward: -91.965670
iter 577. reward: -92.015316
iter 578. reward: -91.856121
iter 579. reward: -91.871765
iter 580. reward: -99.905950
iter 581. reward: -112.298527
iter 582. reward: -93.763726
iter 583. reward: -91.914047
iter 584. reward: -100.308228
iter 585. reward: -111.723059
iter 586. reward: -101.050483
iter 587. reward: -92.146750
iter 588. reward: -93.382759
iter 589. reward: -91.992268
iter 590. reward: -92.061249
iter 591. reward: -99.726518
iter 592. reward: -91.863340
iter 593. reward: -93.603705
iter 594. reward: -93.762032
iter 595. reward: -92.057775
iter 596. reward: -91.979203
iter 597. reward: -92.103903
iter 598. reward: -91.790628
iter 599. reward: -91.991294
iter 600. reward: -91.907367
iter 601. r

KeyboardInterrupt: 

In [27]:
agent.play(1)

ArgumentError: argument 2: <type 'exceptions.TypeError'>: wrong type