In [1]:
import gym

class ActionSamplingEnvironment:
    """Wrap a gym environment so that ``step`` takes an action *distribution*.

    The agent hands ``step`` an object with a ``sample()`` method. The wrapper
    draws the action itself, forwards it to the wrapped environment, and
    returns the sampled action in front of the usual
    ``(observation, reward, done, info)`` tuple. The gym.Env attributes and
    methods defined below are forwarded to the wrapped environment.
    """

    def __init__(self, env: "gym.Env"):
        # String annotation: it is not evaluated when the class is defined.
        self._env = env

    @property
    def metadata(self):
        return self._env.metadata

    @property
    def reward_range(self):
        return self._env.reward_range

    @property
    def spec(self):
        return self._env.spec

    @property
    def action_space(self):
        return self._env.action_space

    @property
    def observation_space(self):
        return self._env.observation_space

    def step(self, action_dist):
        """Sample an action from ``action_dist`` and advance the env one timestep.

        When the end of an episode is reached, you are responsible for calling
        ``reset()`` to reset this environment's state.

        Args:
            action_dist (object): a distribution over actions, provided by the
                agent. It must have a ``sample()`` method whose result has a
                ``.numpy()`` method (e.g. a torch distribution).

        Returns:
            tuple: ``(action, observation, reward, done, info)`` where
                action (object): the action sampled from ``action_dist``,
                    exactly as returned by ``sample()``
                observation (object): agent's observation of the current environment
                reward (float): amount of reward returned after the action
                done (bool): whether the episode has ended, in which case further
                    step() calls will return undefined results
                info (dict): auxiliary diagnostic information (helpful for
                    debugging, and sometimes learning)
        """
        action = action_dist.sample()
        # The wrapped env receives a numpy value, not the sampled tensor.
        next_obs, reward, done, info = self._env.step(action.numpy())
        return action, next_obs, reward, done, info

    def reset(self):
        """Reset the wrapped environment and return its initial observation.

        Returns:
            observation (object): the initial observation.
        """
        return self._env.reset()

    def render(self, mode='human'):
        """Render the wrapped environment and return whatever it returns.

        Propagating the return value matters for modes such as ``'rgb_array'``,
        where the rendered frame is the result of the call.
        """
        return self._env.render(mode=mode)

    def close(self):
        """Close the wrapped environment.

        Subclasses that own extra resources should override this and call
        ``super().close()``. ``__exit__`` goes through this method, so such
        overrides also run when the wrapper is used in a ``with`` block.
        """
        self._env.close()

    def seed(self, seed=None):
        """Seed the wrapped environment's random number generator(s).

        Note:
            Some environments use multiple pseudorandom number generators.
            We want to capture all such seeds used in order to ensure that
            there aren't accidental correlations between multiple generators.

        Returns:
            list<bigint>: the list of seeds used in this env's random
              number generators, as returned by the wrapped env. The first
              value in the list should be the "main" seed, or the value which a
              reproducer should pass to 'seed'. Often, the main seed equals the
              provided 'seed', but this won't be true if seed=None, for example.
        """
        return self._env.seed(seed=seed)

    @property
    def unwrapped(self):
        """Completely unwrap this env.

        Returns:
            gym.Env: The base non-wrapped gym.Env instance
        """
        return self._env.unwrapped

    def __str__(self):
        return str(self._env)

    def __enter__(self):
        """Support the with-statement; binds the wrapper itself, not the inner env."""
        # Returning the inner env here would let callers bypass action sampling.
        return self

    def __exit__(self, *args):
        """Close this wrapper on leaving a with-block; never suppresses exceptions."""
        # Go through self.close() so subclass cleanup (e.g. open files) runs too.
        self.close()
        return False

In [4]:
# Python 3's `pickle` transparently uses the C accelerator (`_pickle`);
# `cPickle` only existed on Python 2, which this notebook (zero-arg super(),
# annotations) does not support. The old bare `except:` also hid real errors.
import pickle


class EnvironmentRecorder(ActionSamplingEnvironment):
    """ActionSamplingEnvironment that appends every transition to a pickle file.

    Each ``step`` writes one tuple
    ``(obs, action_dist, action, reward, next_obs, done)`` to
    ``experience_path``. The file is opened in ``'wb'`` mode, so an existing
    log at the same path is overwritten.

    NOTE(review): all records go through one shared ``pickle.Pickler``. Its
    memo keeps a reference to every object written so far, so memory grows
    with the number of steps. An object mutated in place after being written
    would be re-emitted as a back-reference to its first (stale)
    serialization. Records after the first should be read with one shared
    ``pickle.Unpickler`` (calling ``load()`` once per record), not a fresh
    ``pickle.load`` each time. Making records independent (e.g. via
    ``clear_memo()``) needs a matching change on the reading side.
    Only unpickle logs you produced yourself: loading a pickle can run
    arbitrary code.
    """

    def __init__(self, env: "gym.Env", experience_path):
        super().__init__(env)
        self._experience_path = experience_path
        self._fout = open(experience_path, 'wb')
        self._pickler = pickle.Pickler(self._fout)
        # Observation the next recorded transition starts from; set by reset().
        # Stays None (and is recorded as None) if step() is called before reset().
        self._obs = None

    def reset(self):
        """Reset the wrapped env and remember the initial observation for logging."""
        self._obs = super().reset()
        return self._obs

    def step(self, action_dist):
        """Step via the parent class, record the transition, and return it unchanged."""
        action, next_obs, reward, done, info = super().step(action_dist)
        self._pickler.dump((self._obs, action_dist, action, reward, next_obs, done))
        if done:
            # Push each completed episode to disk so the log is readable even
            # while this recorder is still alive (never closed or collected yet).
            self._fout.flush()
        self._obs = next_obs
        return action, next_obs, reward, done, info

    def close(self):
        """Close the wrapped env, then the experience file (file close is idempotent)."""
        super().close()
        if not self._fout.closed:
            self._fout.close()


In [5]:
from spinup import ppo_pytorch as ppo
import torch
import gym

# Policy and value networks: two hidden layers of 32 ReLU units each.
ac_kwargs = {'hidden_sizes': [32, 32], 'activation': torch.nn.ReLU}

# Spinning Up writes progress.txt / config.json for this run under output_dir.
logger_kwargs = {'output_dir': 'cartpole', 'exp_name': 'cartpole_ppo_learner'}

# Every transition the learner takes is pickled to cartpole_ppo_learner.pkl
# (kept as a lambda: spinup logs this callable's repr in the saved config).
env_fn = lambda: EnvironmentRecorder(gym.make('CartPole-v1'), 'cartpole_ppo_learner.pkl')

ppo(
    env_fn=env_fn,
    ac_kwargs=ac_kwargs,
    steps_per_epoch=500,
    epochs=100,
    logger_kwargs=logger_kwargs,
)

[32;1mLogging data to cartpole/progress.txt[0m
[36;1mSaving config:
[0m
{
    "ac_kwargs":	{
        "activation":	"ReLU",
        "hidden_sizes":	[
            32,
            32
        ]
    },
    "actor_critic":	"MLPActorCritic",
    "clip_ratio":	0.2,
    "env_fn":	"<function <lambda> at 0x7f6a0443e0d0>",
    "epochs":	100,
    "exp_name":	"cartpole_ppo_learner",
    "gamma":	0.99,
    "lam":	0.97,
    "logger":	{
        "<spinup.utils.logx.EpochLogger object at 0x7f69d384d5c0>":	{
            "epoch_dict":	{},
            "exp_name":	"cartpole_ppo_learner",
            "first_row":	true,
            "log_current_row":	{},
            "log_headers":	[],
            "output_dir":	"cartpole",
            "output_file":	{
                "<_io.TextIOWrapper name='cartpole/progress.txt' mode='w' encoding='UTF-8'>":	{
                    "mode":	"w"
                }
            }
        }
    },
    "logger_kwargs":	{
        "exp_name":	"cartpole_ppo_learner",
        "output_



---------------------------------------
|             Epoch |               0 |
|      AverageEpRet |            23.2 |
|          StdEpRet |            12.2 |
|          MaxEpRet |              57 |
|          MinEpRet |               9 |
|             EpLen |            23.2 |
|      AverageVVals |          0.0356 |
|          StdVVals |          0.0271 |
|          MaxVVals |           0.108 |
|          MinVVals |         -0.0608 |
| TotalEnvInteracts |             500 |
|            LossPi |          0.0178 |
|             LossV |             273 |
|       DeltaLossPi |         -0.0398 |
|        DeltaLossV |           -41.3 |
|           Entropy |            0.68 |
|                KL |          0.0132 |
|          ClipFrac |            0.29 |
|          StopIter |              79 |
|              Time |            1.63 |
---------------------------------------
[32;1mEarly stopping at step 57 due to reaching max kl.[0m
---------------------------------------
|             Epoch

In [6]:
# Stream the experience log written by EnvironmentRecorder. `fin` is left open
# on purpose: each later `unpickler.load()` reads the next transition tuple.
# The recorder writes every record with one shared Pickler, so later records may
# back-reference earlier ones; keep reading with this single Unpickler.
# Only unpickle files you produced yourself: loading a pickle can run arbitrary code.
fin = open('cartpole_ppo_learner.pkl', mode='rb')
unpickler = pickle.Unpickler(fin)

In [7]:
# Next record from the log; field order matches the tuple EnvironmentRecorder.step dumps.
(obs, action_dist, action, reward, next_obs, done) = unpickler.load()

In [8]:
# Observation this transition started from (presumably the reset() observation for the first record).
obs

array([-0.01171707, -0.03534962, -0.04228413,  0.03623856])

In [9]:
# The distribution object the agent passed to step(); its repr below shows a torch Categorical.
action_dist

Categorical(probs: torch.Size([2]), logits: torch.Size([2]))

In [10]:
# Action the wrapper sampled from action_dist, stored as returned by sample(); its .numpy() went to the env.
action

tensor(1)

In [11]:
# Reward the wrapped env returned for this step.
reward

1.0

In [12]:
# Observation the env returned after `action`; the recorder uses it as `obs` for the next record.
next_obs

array([-0.01242406,  0.16035239, -0.04155936, -0.26947989])

In [13]:
# Episode-termination flag the env returned for this step.
done

False