Commit 1203bec (1 parent: 23482e6). Showing 5 changed files with 99 additions and 143 deletions.
```diff
@@ -1,92 +1,60 @@
-from time import perf_counter
+import time
 from itertools import count
 
 import numpy as np
 import torch
 
 from lagom import Logger
 from lagom import BaseEngine
 from lagom.transform import describe
 from lagom.utils import color_str
-from lagom.envs.wrappers import get_wrapper
 
 
 class Engine(BaseEngine):
     def train(self, n=None, **kwargs):
-        train_logs = []
-        eval_logs = []
-        eval_togo = 0
-        dump_togo = 0
-        num_episode = 0
+        train_logs, eval_logs = [], []
         checkpoint_count = 0
-        observation, _ = self.env.reset()
-        for i in count():
-            if i >= self.config['train.timestep']:
+        for iteration in count():
+            if self.agent.total_timestep >= self.config['train.timestep']:
                 break
-            if i < self.config['replay.init_size']:
-                action = [self.env.action_space.sample()]
+            t0 = time.perf_counter()
+
+            if iteration < self.config['replay.init_trial']:
+                [traj] = self.runner(self.random_agent, self.env, 1)
             else:
-                action = self.agent.choose_action(observation, mode='stochastic')['action']
-            next_observation, reward, step_info = self.env.step(action)
-            eval_togo += 1
-            dump_togo += 1
-            if step_info[0].last:  # [0] due to single environment
-                start_time = perf_counter()
-                self.replay.add(observation[0], action[0], reward[0], step_info[0]['last_observation'], step_info[0].terminal)
-                # updates in the end of episode, for each time step
-                out_agent = self.agent.learn(D=None, replay=self.replay, episode_length=step_info[0]['episode']['horizon'])
-                num_episode += 1
-                if (i+1) >= int(self.config['train.timestep']*(checkpoint_count/(self.config['checkpoint.num'] - 1))):
-                    self.agent.checkpoint(self.logdir, num_episode)
-                    checkpoint_count += 1
-                logger = Logger()
-                logger('num_seconds', round(perf_counter() - start_time, 1))
-                logger('accumulated_trained_timesteps', i + 1)
-                logger('accumulated_trained_episodes', num_episode)
-                [logger(key, value) for key, value in out_agent.items()]
-                logger('episode_return', step_info[0]['episode']['return'])
-                logger('episode_horizon', step_info[0]['episode']['horizon'])
-                train_logs.append(logger.logs)
-                if dump_togo >= self.config['log.freq']:
-                    dump_togo %= self.config['log.freq']
-                    logger.dump(keys=None, index=0, indent=0, border='-'*50)
-                if eval_togo >= self.config['eval.freq']:
-                    eval_togo %= self.config['eval.freq']
-                    eval_logs.append(self.eval(accumulated_trained_timesteps=(i+1),
-                                               accumulated_trained_episodes=num_episode))
-            else:
-                self.replay.add(observation[0], action[0], reward[0], next_observation[0], step_info[0].terminal)
-            observation = next_observation
-        if checkpoint_count < self.config['checkpoint.num']:
-            self.agent.checkpoint(self.logdir, num_episode)
-            checkpoint_count += 1
+                [traj] = self.runner(self.agent, self.env, 1, mode='train')
+            self.replay.add(traj)
+            # Number of gradient updates = collected episode length
+            out_agent = self.agent.learn(D=None, replay=self.replay, T=traj.T)
+
+            logger = Logger()
+            logger('train_iteration', iteration+1)
+            logger('num_seconds', round(time.perf_counter() - t0, 1))
+            [logger(key, value) for key, value in out_agent.items()]
+            logger('episode_return', sum(traj.rewards))
+            logger('episode_horizon', traj.T)
+            logger('accumulated_trained_timesteps', self.agent.total_timestep)
+            train_logs.append(logger.logs)
+            if iteration == 0 or (iteration+1) % self.config['log.freq'] == 0:
+                logger.dump(keys=None, index=0, indent=0, border='-'*50)
+            if self.agent.total_timestep >= int(self.config['train.timestep']*(checkpoint_count/(self.config['checkpoint.num'] - 1))):
+                self.agent.checkpoint(self.logdir, iteration + 1)
+                checkpoint_count += 1
+            if self.agent.total_timestep >= int(self.config['train.timestep']*(len(eval_logs)/(self.config['eval.num'] - 1))):
+                eval_logs.append(self.eval(n=len(eval_logs)))
         return train_logs, eval_logs
 
     def eval(self, n=None, **kwargs):
-        start_time = perf_counter()
-        returns = []
-        horizons = []
-        for _ in range(self.config['eval.num_episode']):
-            observation = self.eval_env.reset()
-            for _ in range(self.eval_env.spec.max_episode_steps):
-                with torch.no_grad():
-                    action = self.agent.choose_action(observation, mode='eval')['action']
-                next_observation, reward, done, info = self.eval_env.step(action)
-                if done[0]:  # [0] single environment
-                    returns.append(info[0]['episode']['return'])
-                    horizons.append(info[0]['episode']['horizon'])
-                    break
-                observation = next_observation
-        logger = Logger()
-        logger('num_seconds', round(perf_counter() - start_time, 1))
-        logger('accumulated_trained_timesteps', kwargs['accumulated_trained_timesteps'])
-        logger('accumulated_trained_episodes', kwargs['accumulated_trained_episodes'])
-        logger('online_return', describe(returns, axis=-1, repr_indent=1, repr_prefix='\n'))
-        logger('online_horizon', describe(horizons, axis=-1, repr_indent=1, repr_prefix='\n'))
-        monitor_env = get_wrapper(self.eval_env, 'VecMonitor')
-        logger('running_return', describe(monitor_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
-        logger('running_horizon', describe(monitor_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
+        t0 = time.perf_counter()
+        with torch.no_grad():
+            D = self.runner(self.agent, self.eval_env, 10, mode='eval')
+
+        logger = Logger()
+        logger('eval_iteration', n+1)
+        logger('num_seconds', round(time.perf_counter() - t0, 1))
+        logger('accumulated_trained_timesteps', self.agent.total_timestep)
+        logger('online_return', describe([sum(traj.rewards) for traj in D], axis=-1, repr_indent=1, repr_prefix='\n'))
+        logger('online_horizon', describe([traj.T for traj in D], axis=-1, repr_indent=1, repr_prefix='\n'))
+        logger('running_return', describe(self.eval_env.return_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
+        logger('running_horizon', describe(self.eval_env.horizon_queue, axis=-1, repr_indent=1, repr_prefix='\n'))
         logger.dump(keys=None, index=0, indent=0, border=color_str('+'*50, color='green'))
         return logger.logs
```
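
The checkpoint and eval triggers in the rewritten `train()` use the same even-spacing rule: the k-th event fires once `agent.total_timestep` crosses `int(config['train.timestep'] * k / (N - 1))`, so N events land evenly between the first timestep and the end of training. A minimal sketch of that arithmetic, with illustrative config values (the numbers below are assumptions, not from this commit):

```python
# Sketch of the even-spacing rule used by the new train() loop for
# checkpointing and evaluation. Config values are illustrative only.
total_timesteps = 1_000_000   # stands in for config['train.timestep']
num_checkpoints = 5           # stands in for config['checkpoint.num']

# The k-th checkpoint fires once total_timestep >= int(total * k/(N-1)).
thresholds = [int(total_timesteps * k / (num_checkpoints - 1))
              for k in range(num_checkpoints)]
print(thresholds)  # [0, 250000, 500000, 750000, 1000000]

# The same comparison as written in the loop: the first threshold is 0,
# so the first checkpoint is taken on the very first iteration.
def should_checkpoint(total_timestep, checkpoint_count):
    return total_timestep >= int(
        total_timesteps * (checkpoint_count / (num_checkpoints - 1)))
```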
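The refactor also narrows what the engine needs from a trajectory: `self.runner(...)` returns a list of trajectory objects, and `train()`/`eval()` only ever read `traj.rewards` and `traj.T`. The stub below is a hypothetical stand-in for that contract, not lagom's actual Trajectory class, just to make the assumed interface concrete:

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-in for the trajectory interface the engine relies on;
# lagom's real Trajectory class carries more than this.
@dataclass
class Traj:
    rewards: List[float] = field(default_factory=list)

    @property
    def T(self) -> int:
        # Episode horizon: number of collected transitions.
        return len(self.rewards)

traj = Traj(rewards=[1.0, 0.5, 2.0])
assert traj.T == 3                 # what train() logs as 'episode_horizon'
assert sum(traj.rewards) == 3.5    # what train() logs as 'episode_return'
```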