
Sampling entire trajectories for a single env from RolloutStorage #4

Closed
offline-rl-neurips opened this issue Oct 13, 2020 · 2 comments

Comments

@offline-rl-neurips

Hi, if I need to sample entire trajectories for a single env from the RolloutStorage, what would be an easy way to do so?
For example, would it be easier to adapt the recurrent_generator? Any hints would be really appreciated (I haven't dealt much with PPO, so this may be a really stupid question)!

Context: I want to randomly sample 2 trajectories (possibly from different envs) and add an auxiliary loss which depends on these two trajectories.

@agarwl

agarwl commented Oct 13, 2020

Would this work?

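import numpy as np
import torch


class TrajStorage(object):
    """Splits a RolloutStorage into per-env (obs, actions) trajectories."""

    def __init__(self, rollouts):
        trajs = []
        num_processes = rollouts.obs.shape[1]
        for env_index in range(num_processes):
            # masks[t] == 0 means obs[t] is the first observation of a new episode
            env_masks = rollouts.masks[:, env_index, 0]
            env_obs = rollouts.obs[:, env_index]
            env_actions = rollouts.actions[:, env_index]

            # Step indices where a new episode begins, as a flat list of ints
            # (RolloutStorage buffers are torch tensors)
            indices = torch.nonzero(env_masks == 0).flatten().tolist()
            prev_index = 0
            for index in indices:
                obs = env_obs[prev_index:index]
                actions = env_actions[prev_index:index]
                prev_index = index
                if len(obs) > 0:  # skip empty slices, e.g. a reset at step 0
                    trajs.append((obs, actions))

        self.trajs = trajs
        self.num_trajs = len(trajs)

    def sample_trajs(self):
        # Draw two trajectory indices uniformly at random (with replacement)
        idx1, idx2 = np.random.randint(0, self.num_trajs, 2)
        return self.trajs[idx1], self.trajs[idx2]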
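To make the splitting step easier to check without a full RolloutStorage, here is a minimal framework-free sketch on a single env's mask column. It assumes the pytorch-a2c-ppo-style convention where masks[t] == 0 marks obs[t] as the first observation of a new episode; the function name `episode_bounds` is hypothetical. Like the loop above, the first range starts at step 0 (so it may be partial) and steps after the last reset are not returned.

```python
import numpy as np


def episode_bounds(masks):
    """Return (start, end) step ranges between consecutive resets.

    Hypothetical helper. Assumes masks[t] == 0 marks obs[t] as the first
    observation of a new episode. The first range starts at step 0, so it
    may be a partial episode; steps after the last reset are not returned.
    """
    # np.where returns a tuple with one index array per dimension
    resets = np.where(np.asarray(masks) == 0)[0]
    bounds = []
    prev = 0
    for idx in resets:
        if idx > prev:  # skip empty ranges, e.g. a reset at step 0
            bounds.append((prev, int(idx)))
        prev = int(idx)
    return bounds


# Mask column for one env over 8 steps (num_steps + 1 = 9 entries)
masks = [1, 1, 0, 1, 1, 1, 0, 1, 1]
print(episode_bounds(masks))  # [(0, 2), (2, 6)]
```

Each (start, end) pair is meant to be used as obs[start:end] and actions[start:end] for that env.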

@rraileanu
Owner

Hi! Yes, the above looks right to me!

The only issue is that you might end up with some partial trajectories, since the first observation in rollouts might be a continuation of a trajectory collected during the previous update (so it's not necessarily the initial observation of an episode). But I think this can be fixed by dropping the first trajectory that you add, i.e. starting from prev_index = indices[0] rather than prev_index = 0.
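A minimal sketch of that fix, again on a single env's mask column and assuming the same masks[t] == 0 convention; `complete_episode_bounds` is a hypothetical name. Starting the scan at the first reset drops the leading partial episode, and an env with no reset in the rollout yields nothing instead of failing on indices[0].

```python
import numpy as np


def complete_episode_bounds(masks):
    """Return (start, end) ranges for episodes fully contained in the rollout.

    Hypothetical helper. Assumes masks[t] == 0 marks obs[t] as the first
    observation of a new episode. Scanning starts at the first reset, so the
    leading partial episode is dropped; steps after the last reset are also
    dropped because that episode has not finished yet.
    """
    resets = np.where(np.asarray(masks) == 0)[0]
    if len(resets) == 0:
        # No reset in this rollout, so no complete episode for this env
        return []
    bounds = []
    prev = int(resets[0])  # i.e. prev_index = indices[0]
    for idx in resets[1:]:
        bounds.append((prev, int(idx)))
        prev = int(idx)
    return bounds


masks = [1, 1, 0, 1, 1, 1, 0, 1, 1]
print(complete_episode_bounds(masks))  # [(2, 6)]
```

Because reset indices are strictly increasing, every returned range is non-empty, so no extra empty-slice check is needed here.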

3 participants