Dataset based Trainer #54

Open
redknightlois opened this issue Apr 29, 2019 · 5 comments

Comments

@redknightlois
Contributor

redknightlois commented Apr 29, 2019

This example dataset-based trainer also does expert signal collection, which is why I didn't open a PR; I'll leave it to you to decide which parts make sense for rlkit.

import abc

import gtimer as gt
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import rlkit.torch.pytorch_util as ptu
from rlkit.core.rl_algorithm import BaseRLAlgorithm
from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.samplers.data_collector import PathCollector

# Note: `printout` and `AtariPathCollectorWithEmbedder` below come from my own
# code, not from rlkit; the rlkit import paths follow the current repo layout.


class OptimizedBatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            max_num_steps_before_training=1e5,
            expert_data_collector: PathCollector = None,
    ):
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )

        assert isinstance(replay_buffer, Dataset), "The replay buffer must be compatible with the PyTorch Dataset interface to use this version."

        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.max_num_steps_before_training = max_num_steps_before_training
        self.expert_data_collector = expert_data_collector

    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )

            self.replay_buffer.add_paths(init_expl_paths)

            if self.expert_data_collector is not None:
                self.expert_data_collector.end_epoch(-1)
            self.expl_data_collector.end_epoch(-1)

        if self.expert_data_collector is not None:
            # Seed the replay buffer with expert paths, up to half the buffer
            # capacity (capped by max_num_steps_before_training).
            new_expl_paths = self.expert_data_collector.collect_new_paths(
                self.max_path_length,
                min(int(self.replay_buffer.max_buffer_size * 0.5), int(self.max_num_steps_before_training)),
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(new_expl_paths)

        # num_workers=0 keeps batch loading in the main process, so every epoch
        # the loader reads from the live (still growing) replay buffer.
        dataset_loader = torch.utils.data.DataLoader(self.replay_buffer, pin_memory=True, batch_size=self.batch_size, num_workers=0)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            printout('Evaluation sampling')
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):

                printout('Exploration sampling')
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)

                # Keep cycling through the DataLoader until we have done exactly
                # num_trains_per_train_loop gradient steps.
                i = 0
                with tqdm(total=self.num_trains_per_train_loop) as pbar:
                    while True:
                        for data in dataset_loader:
                            if i >= self.num_trains_per_train_loop:
                                break  # We are done

                            # Batch layout must match the Dataset's __getitem__:
                            # (obs, action, reward, terminal, next_obs, env_info).
                            observations = data[0].to(ptu.device)
                            actions = data[1].to(ptu.device)
                            rewards = data[2].to(ptu.device)
                            terminals = data[3].to(ptu.device).float()
                            next_observations = data[4].to(ptu.device)
                            env_infos = data[5]

                            train_data = dict(
                                observations=observations,
                                actions=actions,
                                rewards=rewards,
                                terminals=terminals,
                                next_observations=next_observations,
                            )

                            for key in env_infos.keys():
                                train_data[key] = env_infos[key]

                            self.trainer.train(train_data)
                            pbar.update(1)
                            i += 1

                        if i >= self.num_trains_per_train_loop:
                            break

                gt.stamp('training', unique=False)
                self.training_mode(False)

                if isinstance(self.expl_data_collector, AtariPathCollectorWithEmbedder):
                    eval_policy = self.eval_data_collector.get_snapshot()['policy']
                    self.expl_data_collector.evaluate(eval_policy)

            self._end_epoch(epoch)

    def to(self, device):
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        for net in self.trainer.networks:
            net.train(mode)
@nm-narasimha

Thanks for defining this class. Can you share an example of how to use this trainer class along with DDPG and SAC?

@redknightlois
Contributor Author

The standard examples show how to do that; there is no difference between the current algorithm and this one. I use #52 for dataset-size reasons, but the rest is pretty straightforward.
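
For concreteness, here is a rough, untested sketch of the wiring. It assumes policy, trainer (e.g. a SACTrainer configured as in the standard examples), expl_env, eval_env and a Dataset-compatible replay_buffer (see #52) already exist; the only change from the stock examples is the algorithm class, and the hyperparameter values are just placeholders.

# Sketch only: `policy`, `trainer`, `expl_env`, `eval_env` and `replay_buffer`
# (Dataset-compatible, see #52) are assumed to be set up as in the standard
# rlkit examples.
import rlkit.torch.pytorch_util as ptu
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.sac.policies import MakeDeterministic

expl_path_collector = MdpPathCollector(expl_env, policy)
eval_path_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))

algorithm = OptimizedBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=expl_env,
    evaluation_env=eval_env,
    exploration_data_collector=expl_path_collector,
    evaluation_data_collector=eval_path_collector,
    replay_buffer=replay_buffer,
    batch_size=256,
    max_path_length=1000,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)
algorithm.to(ptu.device)
algorithm.train()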

@nm-narasimha

Thanks, @redknightlois. Do you have a sample replay_buffer compatible with the PyTorch Dataset class?
Is env_replay_buffer or any other class in rlkit.data_management compatible?

Thanks,
Narasimha

@redknightlois
Contributor Author

#52 is a PyTorch Dataset class.
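
Just to illustrate what the interface needs (a minimal sketch, not the actual #52 implementation): the buffer has to provide __len__ and a __getitem__ that returns (obs, action, reward, terminal, next_obs, env_info) per transition, in the order the trainer above unpacks them, plus the add_paths and max_buffer_size members it uses. A real version would also implement the rest of rlkit's ReplayBuffer interface (end_epoch, diagnostics, etc.).

# Minimal sketch of a Dataset-compatible buffer (not the actual #52 code).
import numpy as np
from torch.utils.data import Dataset


class DatasetReplayBuffer(Dataset):
    def __init__(self, max_buffer_size, obs_dim, action_dim):
        self.max_buffer_size = max_buffer_size
        self._obs = np.zeros((max_buffer_size, obs_dim), dtype=np.float32)
        self._actions = np.zeros((max_buffer_size, action_dim), dtype=np.float32)
        self._rewards = np.zeros((max_buffer_size, 1), dtype=np.float32)
        self._terminals = np.zeros((max_buffer_size, 1), dtype=np.float32)
        self._next_obs = np.zeros((max_buffer_size, obs_dim), dtype=np.float32)
        self._top = 0
        self._size = 0

    def add_paths(self, paths):
        # Store transitions in a ring buffer, overwriting the oldest ones.
        for path in paths:
            for obs, action, reward, terminal, next_obs in zip(
                    path['observations'], path['actions'], path['rewards'],
                    path['terminals'], path['next_observations']):
                self._obs[self._top] = obs
                self._actions[self._top] = action
                self._rewards[self._top] = reward
                self._terminals[self._top] = terminal
                self._next_obs[self._top] = next_obs
                self._top = (self._top + 1) % self.max_buffer_size
                self._size = min(self._size + 1, self.max_buffer_size)

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        # Order matters: it must match how the trainer unpacks each batch.
        return (self._obs[idx], self._actions[idx], self._rewards[idx],
                self._terminals[idx], self._next_obs[idx], {})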

@vitchyr
Collaborator

vitchyr commented Jun 11, 2019

Hmmm, so it looks like the main difference is the addition of expert_data_collector. Is that correct? In that case, I'm not sure if we need to create an entirely new class for this. One option would be to add that data to the replay buffer before passing the replay buffer to the algorithm. What do you think of that? It would help separate out the algorithm from the pretraining phase.
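
Roughly (placeholder names, just to sketch the idea):

# Run the expert / pretraining collection up front, seed the buffer, and then
# construct the usual batch RL algorithm with the pre-filled buffer.
expert_paths = expert_data_collector.collect_new_paths(
    max_path_length,
    num_expert_steps,
    discard_incomplete_paths=False,
)
replay_buffer.add_paths(expert_paths)
# ...then build the algorithm with this replay_buffer, as in the standard setup.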
