In [1]:
import torch
from torch import nn
import numpy as np
import gym 
import matplotlib.pyplot as plt

In [8]:
class CrossEntropyMethod(nn.Module):
    def __init__(self, name, state_dim, action_n, lr=0.01):
        super().__init__()
        self.name = name
        self.state_dim = state_dim
        self.action_n = action_n
        self.lr = lr
        
        self.network = nn.Sequential(
            nn.Linear(self.state_dim, 100),
            nn.ReLU(),
            nn.Linear(100, self.action_n)
        )
        
        self.softmax = nn.Softmax()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        self.loss = nn.CrossEntropyLoss()
#         self.policy = np.ones((self.state_n, self.action_n)) / self.action_n
    
    def forward(self, _input):
        return self.network(_input)
    
    def get_action(self, state):
        state = torch.FloatTensor(state)
        logits = self.network(state)
        action_prob = self.softmax(logits).detach().numpy()
        action = np.random.choice(self.action_n, p=action_prob)
        return action

    def update_policy(self, elite_trajectories):
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            elite_states.extend(trajectory['states'])
            elite_actions.extend(trajectory['actions'])

        elite_states = torch.FloatTensor(elite_states)
        elite_actions = torch.LongTensor(elite_actions)
        
        loss = self.loss(self.forward(elite_states), elite_actions)
        loss.backward()
        
        self.optimizer.step()
        self.optimizer.zero_grad()

    def save_policy(self):
        np.save(f'{env.spec.id}-{self.name}', self.policy)
    def load_policy(self):
        fname = f'{env.spec.id}-{self.name}.npy'
        if os.path.exists(fname):
            self.policy = np.load(fname)

In [19]:
def get_trajectory(env, agent, trajectory_len, visualize=False):
    trajectory = {
        'states':[], 
        'actions':[],
        'total_reward': 0}
    state = env.reset()
    trajectory['states'].append(state)
    for _ in range(trajectory_len):
        action = agent.get_action(state)
        trajectory['actions'].append(action)
        state, reward, done, _ = env.step(action)
        trajectory['total_reward'] += reward
        if done:
            break

        if visualize:
            env.render()

        trajectory['states'].append(state)
    return trajectory

In [15]:
def get_elite_trajectories(trajectories, q_param):
    total_rewards = [trajectory['total_reward'] for trajectory in trajectories]
    quantile = np.quantile(total_rewards, q=q_param) 
    return [trajectory for trajectory in trajectories if trajectory['total_reward'] > quantile]

In [16]:
env = gym.make('CartPole-v1')

In [17]:
state_dim = 4
action_n = 2

agent = CrossEntropyMethod('test', state_dim, action_n=action_n)
episode_n = 100
trajectory_n = 20
trajectory_len = 500
q_param = 0.8

In [18]:
for _ in range(episode_n):
    trajectories = [get_trajectory(env, agent, trajectory_len) for _ in range(trajectory_n)]
    
    mean_total_reward = np.mean([trajectory['total_reward'] for trajectory in trajectories])
    print(mean_total_reward)
    
    elite_trajectories = get_elite_trajectories(trajectories, q_param)
    
    if len(elite_trajectories)>0:
        agent.update_policy(elite_trajectories)



20.25
24.4
27.75
26.65
34.15
35.25
45.0
43.85
41.1
59.75
50.85
47.05
46.15
64.65
51.85
61.05
60.35
59.65
54.6
52.55
68.8
65.1
76.4
72.4
89.65
82.1
68.1
71.4
65.2
74.75
74.95
73.6
76.15
94.0
88.5
93.5
107.4
96.6
102.8
84.4
88.6
92.5
99.95
95.65
93.85
103.0
110.55
118.65
147.9
153.1
176.4
231.25
206.6
220.15
192.5
234.3
252.7
279.95
297.0
294.05
357.4
379.6
350.55
333.7
341.2
332.45
385.7
337.95
425.8
384.4
401.75
374.85
370.85
384.05
386.65
403.35
434.15
361.95
404.1
384.8
388.55
378.05
399.3
419.2
358.45
390.25
410.9
346.3
420.7
392.3
418.0
431.65
388.2
336.65
390.15
399.25
399.9
386.2
450.05
378.75


In [23]:
tjs = get_trajectory(env, agent, trajectory_len, trajectory_n)

If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  "You are calling render method, "


In [22]:

tjs

{'states': [array([ 0.02262811,  0.04225926, -0.00265484,  0.00147396], dtype=float32),
  array([ 0.02347329,  0.23741919, -0.00262536, -0.2920454 ], dtype=float32),
  array([ 0.02822168,  0.04233476, -0.00846627, -0.00019165], dtype=float32),
  array([ 0.02906837, -0.15266475, -0.0084701 ,  0.2898081 ], dtype=float32),
  array([ 0.02601507,  0.04257695, -0.00267394, -0.00553413], dtype=float32),
  array([ 0.02686661,  0.23773715, -0.00278462, -0.2990595 ], dtype=float32),
  array([ 0.03162136,  0.04265499, -0.00876581, -0.00725611], dtype=float32),
  array([ 0.03247446, -0.15234016, -0.00891093,  0.28264827], dtype=float32),
  array([ 0.02942766,  0.04290776, -0.00325797, -0.01283176], dtype=float32),
  array([ 0.03028581, -0.15216732, -0.0035146 ,  0.27882147], dtype=float32),
  array([ 0.02724246,  0.04300459,  0.00206183, -0.01496789], dtype=float32),
  array([ 0.02810255, -0.15214686,  0.00176247,  0.27836487], dtype=float32),
  array([ 0.02505962,  0.0429499 ,  0.00732976, -0.013