Skip to content

Commit

Permalink
Merge pull request #155 from zuoxingdong/add_ddpg2
Browse files Browse the repository at this point in the history
update VPG
  • Loading branch information
zuoxingdong committed May 6, 2019
2 parents bd32d9e + 6a5c37a commit 8da7574
Show file tree
Hide file tree
Showing 84 changed files with 22 additions and 22 deletions.
18 changes: 9 additions & 9 deletions baselines/vpg/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from lagom import BaseAgent
from lagom.utils import pickle_dump
+from lagom.utils import tensorify
+from lagom.utils import numpify
from lagom.envs import flatdim
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
Expand Down Expand Up @@ -63,8 +65,7 @@ def __init__(self, config, env, device, **kwargs):
self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)

     def choose_action(self, obs, **kwargs):
-        if not torch.is_tensor(obs):
-            obs = torch.from_numpy(np.asarray(obs)).float().to(self.device)
+        obs = tensorify(obs, self.device)
         out = {}
         features = self.feature_network(obs)

Expand All @@ -74,7 +75,7 @@ def choose_action(self, obs, **kwargs):

         action = action_dist.sample()
         out['action'] = action
-        out['raw_action'] = action.detach().cpu().numpy()
+        out['raw_action'] = numpify(action, 'float')
         out['action_logprob'] = action_dist.log_prob(action.detach())

V = self.V_head(features)
Expand All @@ -87,17 +88,17 @@ def learn(self, D, **kwargs):
entropies = [torch.cat(traj.get_all_info('entropy')) for traj in D]
Vs = [torch.cat(traj.get_all_info('V')) for traj in D]

-        last_observations = torch.from_numpy(np.concatenate([traj.last_observation for traj in D], 0)).float()
         with torch.no_grad():
-            last_Vs = self.V_head(self.feature_network(last_observations.to(self.device))).squeeze(-1)
+            last_observations = tensorify(np.concatenate([traj.last_observation for traj in D], 0), self.device)
+            last_Vs = self.V_head(self.feature_network(last_observations)).squeeze(-1)
Qs = [bootstrapped_returns(self.config['agent.gamma'], traj, last_V)
for traj, last_V in zip(D, last_Vs)]
As = [gae(self.config['agent.gamma'], self.config['agent.gae_lambda'], traj, V, last_V)
for traj, V, last_V in zip(D, Vs, last_Vs)]

# Metrics -> Tensor, device
logprobs, entropies, Vs = map(lambda x: torch.cat(x).squeeze(), [logprobs, entropies, Vs])
-        Qs, As = map(lambda x: torch.from_numpy(np.concatenate(x).copy()).to(self.device), [Qs, As])
+        Qs, As = map(lambda x: tensorify(np.concatenate(x).copy(), self.device), [Qs, As])
if self.config['agent.standardize_adv']:
As = (As - As.mean())/(As.std() + 1e-8)

Expand Down Expand Up @@ -128,9 +129,8 @@ def learn(self, D, **kwargs):
out['entropy_loss'] = entropy_loss.mean().item()
out['policy_entropy'] = -entropy_loss.mean().item()
out['value_loss'] = value_loss.mean().item()
-        Vs_numpy = Vs.detach().cpu().numpy().squeeze()
-        out['V'] = describe(Vs_numpy, axis=-1, repr_indent=1, repr_prefix='\n')
-        out['explained_variance'] = ev(y_true=Qs.detach().cpu().numpy(), y_pred=Vs.detach().cpu().numpy())
+        out['V'] = describe(numpify(Vs, 'float').squeeze(), axis=-1, repr_indent=1, repr_prefix='\n')
+        out['explained_variance'] = ev(y_true=numpify(Qs, 'float'), y_pred=numpify(Vs, 'float'))
return out

def checkpoint(self, logdir, num_iter):
Expand Down
Binary file modified baselines/vpg/logs/default/0/1500925526/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1500925526/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1500925526/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1500925526/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1500925526/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1500925526/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/1770966829/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/0/2054191100/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/agent_500.pth
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"initial_reset_timestamp": 1556978394.6095722, "timestamps": [1556978402.951567], "episode_lengths": [1000], "episode_rewards": [3248.4540254425083], "episode_types": ["t", "t"]}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"stats": "openaigym.episode_batch.0.347927.stats.json", "videos": [["openaigym.video.0.347927.video000000.mp4", "openaigym.video.0.347927.video000000.meta.json"], ["openaigym.video.0.347927.video000001.mp4", "openaigym.video.0.347927.video000001.meta.json"]], "env_info": {"gym_version": "0.12.1", "env_id": "Hopper-v3"}}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.4.4-0ubuntu0.18.04.1 Copyright (c) 2000-2018 the FFmpeg developers\\nbuilt with gcc 7 (Ubuntu 7.3.0-16ubuntu3)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.18.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-omx --enable-openal --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libopencv --enable-libx264 --enable-shared\\nlibavutil 55. 78.100 / 55. 78.100\\nlibavcodec 57.107.100 / 57.107.100\\nlibavformat 57. 83.100 / 57. 83.100\\nlibavdevice 57. 10.100 / 57. 10.100\\nlibavfilter 6.107.100 / 6.107.100\\nlibavresample 3. 7. 0 / 3. 7. 0\\nlibswscale 4. 8.100 / 4. 8.100\\nlibswresample 2. 9.100 / 2. 9.100\\nlibpostproc 54. 7.100 / 54. 
7.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "125", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/home/zuo/Code/tmp/lagom/baselines/vpg/logs/default/1/1500925526/anim/openaigym.video.0.3824624.video000000.mp4"]}}
{"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.4.4-0ubuntu0.18.04.1 Copyright (c) 2000-2018 the FFmpeg developers\\nbuilt with gcc 7 (Ubuntu 7.3.0-16ubuntu3)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.18.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-omx --enable-openal --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libopencv --enable-libx264 --enable-shared\\nlibavutil 55. 78.100 / 55. 78.100\\nlibavcodec 57.107.100 / 57.107.100\\nlibavformat 57. 83.100 / 57. 83.100\\nlibavdevice 57. 10.100 / 57. 10.100\\nlibavfilter 6.107.100 / 6.107.100\\nlibavresample 3. 7. 0 / 3. 7. 0\\nlibswscale 4. 8.100 / 4. 8.100\\nlibswresample 2. 9.100 / 2. 9.100\\nlibpostproc 54. 7.100 / 54. 
7.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "125", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/home/zuo/Code/tmp/lagom/baselines/vpg/logs/default/1/1500925526/anim/openaigym.video.0.347927.video000000.mp4"]}}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"episode_id": 1, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.4.4-0ubuntu0.18.04.1 Copyright (c) 2000-2018 the FFmpeg developers\\nbuilt with gcc 7 (Ubuntu 7.3.0-16ubuntu3)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.18.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-omx --enable-openal --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libopencv --enable-libx264 --enable-shared\\nlibavutil 55. 78.100 / 55. 78.100\\nlibavcodec 57.107.100 / 57.107.100\\nlibavformat 57. 83.100 / 57. 83.100\\nlibavdevice 57. 10.100 / 57. 10.100\\nlibavfilter 6.107.100 / 6.107.100\\nlibavresample 3. 7. 0 / 3. 7. 0\\nlibswscale 4. 8.100 / 4. 8.100\\nlibswresample 2. 9.100 / 2. 9.100\\nlibpostproc 54. 7.100 / 54. 
7.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "125", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/home/zuo/Code/tmp/lagom/baselines/vpg/logs/default/1/1500925526/anim/openaigym.video.0.3824624.video000001.mp4"]}}
{"episode_id": 1, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.4.4-0ubuntu0.18.04.1 Copyright (c) 2000-2018 the FFmpeg developers\\nbuilt with gcc 7 (Ubuntu 7.3.0-16ubuntu3)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.18.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-omx --enable-openal --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libopencv --enable-libx264 --enable-shared\\nlibavutil 55. 78.100 / 55. 78.100\\nlibavcodec 57.107.100 / 57.107.100\\nlibavformat 57. 83.100 / 57. 83.100\\nlibavdevice 57. 10.100 / 57. 10.100\\nlibavfilter 6.107.100 / 6.107.100\\nlibavresample 3. 7. 0 / 3. 7. 0\\nlibswscale 4. 8.100 / 4. 8.100\\nlibswresample 2. 9.100 / 2. 9.100\\nlibpostproc 54. 7.100 / 54. 
7.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "125", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/home/zuo/Code/tmp/lagom/baselines/vpg/logs/default/1/1500925526/anim/openaigym.video.0.347927.video000001.mp4"]}}
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1500925526/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/1770966829/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/1/2054191100/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1500925526/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/1770966829/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/2/2054191100/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1500925526/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/1770966829/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/agent_1.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/agent_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/agent_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/obs_moments_1000.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/obs_moments_500.pth
Binary file not shown.
Binary file modified baselines/vpg/logs/default/3/2054191100/train_logs.pkl
Binary file not shown.
Binary file modified baselines/vpg/logs/default/result.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 9 additions & 9 deletions baselines/vpg/logs/default/source_files/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from lagom import BaseAgent
from lagom.utils import pickle_dump
+from lagom.utils import tensorify
+from lagom.utils import numpify
from lagom.envs import flatdim
from lagom.envs.wrappers import get_wrapper
from lagom.networks import Module
Expand Down Expand Up @@ -63,8 +65,7 @@ def __init__(self, config, env, device, **kwargs):
self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)

     def choose_action(self, obs, **kwargs):
-        if not torch.is_tensor(obs):
-            obs = torch.from_numpy(np.asarray(obs)).float().to(self.device)
+        obs = tensorify(obs, self.device)
         out = {}
         features = self.feature_network(obs)

Expand All @@ -74,7 +75,7 @@ def choose_action(self, obs, **kwargs):

         action = action_dist.sample()
         out['action'] = action
-        out['raw_action'] = action.detach().cpu().numpy()
+        out['raw_action'] = numpify(action, 'float')
         out['action_logprob'] = action_dist.log_prob(action.detach())

V = self.V_head(features)
Expand All @@ -87,17 +88,17 @@ def learn(self, D, **kwargs):
entropies = [torch.cat(traj.get_all_info('entropy')) for traj in D]
Vs = [torch.cat(traj.get_all_info('V')) for traj in D]

-        last_observations = torch.from_numpy(np.concatenate([traj.last_observation for traj in D], 0)).float()
         with torch.no_grad():
-            last_Vs = self.V_head(self.feature_network(last_observations.to(self.device))).squeeze(-1)
+            last_observations = tensorify(np.concatenate([traj.last_observation for traj in D], 0), self.device)
+            last_Vs = self.V_head(self.feature_network(last_observations)).squeeze(-1)
Qs = [bootstrapped_returns(self.config['agent.gamma'], traj, last_V)
for traj, last_V in zip(D, last_Vs)]
As = [gae(self.config['agent.gamma'], self.config['agent.gae_lambda'], traj, V, last_V)
for traj, V, last_V in zip(D, Vs, last_Vs)]

# Metrics -> Tensor, device
logprobs, entropies, Vs = map(lambda x: torch.cat(x).squeeze(), [logprobs, entropies, Vs])
-        Qs, As = map(lambda x: torch.from_numpy(np.concatenate(x).copy()).to(self.device), [Qs, As])
+        Qs, As = map(lambda x: tensorify(np.concatenate(x).copy(), self.device), [Qs, As])
if self.config['agent.standardize_adv']:
As = (As - As.mean())/(As.std() + 1e-8)

Expand Down Expand Up @@ -128,9 +129,8 @@ def learn(self, D, **kwargs):
out['entropy_loss'] = entropy_loss.mean().item()
out['policy_entropy'] = -entropy_loss.mean().item()
out['value_loss'] = value_loss.mean().item()
-        Vs_numpy = Vs.detach().cpu().numpy().squeeze()
-        out['V'] = describe(Vs_numpy, axis=-1, repr_indent=1, repr_prefix='\n')
-        out['explained_variance'] = ev(y_true=Qs.detach().cpu().numpy(), y_pred=Vs.detach().cpu().numpy())
+        out['V'] = describe(numpify(Vs, 'float').squeeze(), axis=-1, repr_indent=1, repr_prefix='\n')
+        out['explained_variance'] = ev(y_true=numpify(Qs, 'float'), y_pred=numpify(Vs, 'float'))
return out

def checkpoint(self, logdir, num_iter):
Expand Down

0 comments on commit 8da7574

Please sign in to comment.