Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gymfcpp/black_jack_env.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "gymfcpp/black_jack_env.h"
#include "gymfcpp/config.h"
#include "gymfcpp/gymfcpp_config.h"

#ifdef GYMFCPP_DEBUG
#include <cassert>
Expand Down
2 changes: 1 addition & 1 deletion src/gymfcpp/cliff_world_env.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "gymfcpp/cliff_world_env.h"
#include "gymfcpp/names_generator.h"
#include "gymfcpp/config.h"
#include "gymfcpp/gymfcpp_config.h"

#ifdef GYMFCPP_DEBUG
#include <cassert>
Expand Down
6 changes: 0 additions & 6 deletions src/gymfcpp/config.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/gymfcpp/frozen_lake_env.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "gymfcpp/frozen_lake_env.h"
#include "gymfcpp/names_generator.h"
#include "gymfcpp/config.h"
#include "gymfcpp/gymfcpp_config.h"

#include <boost/python.hpp>

Expand Down
169 changes: 104 additions & 65 deletions src/gymfcpp/serial_vector_env_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,43 +32,24 @@ struct SerialVectorEnvWrapperConfig
uint_t seed{42};
};

namespace {

///
///
///
template<typename EnvType, typename StateAdaptor=void>
class SerialVectorEnvWrapper: private boost::noncopyable
template<typename EnvType>
class serial_vector_env_wrapper_base: private boost::noncopyable
{
public:

typedef VectorTimeStep<typename EnvType::state_type, StateAdaptor> time_step_type;

///
/// \brief VectorEnvWrapper
/// \param n_copies
/// \param version
///
SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace);

/// \brief serial_vector_env_wrapper_base
/// \param config
/// \param p_namespace
///
/// \brief ~VectorEnvWrapper
///
~SerialVectorEnvWrapper();

///
/// \brief make
///
void make();

///
///
///
time_step_type reset();
serial_vector_env_wrapper_base(SerialVectorEnvWrapperConfig config, obj_t p_namespace);

///
///
///
time_step_type step(const std::vector<typename EnvType::action_type>& actions);
~serial_vector_env_wrapper_base();

///
/// \brief close. Close down the environment
Expand All @@ -81,12 +62,7 @@ class SerialVectorEnvWrapper: private boost::noncopyable
///
uint_t n_copies()const noexcept{return config_.n_copies;}

private:

///
/// \brief current_state
///
time_step_type current_state_;
protected:

///
/// \brief p_namespace_ The main Python namespace
Expand All @@ -110,11 +86,9 @@ class SerialVectorEnvWrapper: private boost::noncopyable

};


template<typename EnvType, typename StateAdaptor>
SerialVectorEnvWrapper<EnvType, StateAdaptor>::SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace)
template<typename EnvType>
serial_vector_env_wrapper_base<EnvType>::serial_vector_env_wrapper_base(SerialVectorEnvWrapperConfig config, obj_t p_namespace)
:
current_state_(),
p_namespace_(p_namespace),
envs_(),
config_(config),
Expand All @@ -128,81 +102,146 @@ SerialVectorEnvWrapper<EnvType, StateAdaptor>::SerialVectorEnvWrapper(SerialVect

}

template<typename EnvType, typename StateAdaptor>
SerialVectorEnvWrapper<EnvType, StateAdaptor>::~SerialVectorEnvWrapper(){
///
/// \brief Destructor. Delegates the shutdown work to close().
///
template<typename EnvType>
serial_vector_env_wrapper_base<EnvType>::~serial_vector_env_wrapper_base()
{
    this->close();
}

template<typename EnvType, typename StateAdaptor>
///
/// \brief close. Close down every environment copy currently stored in envs_.
///
/// The loop walks the environments that are actually present instead of
/// counting up to config_.n_copies. envs_ is constructed empty and is only
/// filled when the copies are made, while this function is also called
/// from the destructor; indexing by n_copies would therefore read past
/// the end of envs_ for a wrapper whose environments were never created.
///
template<typename EnvType>
void
serial_vector_env_wrapper_base<EnvType>::close(){

    for(auto& env : envs_){
        env->close();
    }
}

template<typename EnvType, typename StateAdaptor>
typename SerialVectorEnvWrapper<EnvType, StateAdaptor>::time_step_type
SerialVectorEnvWrapper<EnvType, StateAdaptor>::reset(){
}



///
/// \brief Wrapper that drives several copies of EnvType one after
/// another. Storage of the copies, the configuration and the Python
/// namespace live in serial_vector_env_wrapper_base; this class adds
/// creation (make) and the reset/step operations on top of it.
///
template<typename EnvType>
class SerialVectorEnvWrapper: protected serial_vector_env_wrapper_base<EnvType>
{
public:

    ///
    /// \brief The aggregated time step type produced by reset() and step()
    ///
    using time_step_type = VectorTimeStep<typename EnvType::state_type>;

    ///
    /// \brief SerialVectorEnvWrapper
    /// \param config The wrapper configuration, forwarded to the base class
    /// \param p_namespace The Python namespace, forwarded to the base class
    ///
    SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace);

    ///
    /// \brief ~SerialVectorEnvWrapper. Defaulted; the base class
    /// destructor runs afterwards.
    ///
    ~SerialVectorEnvWrapper()=default;

    // The base class is inherited as protected, so the parts of its
    // interface that callers need are re-exported explicitly.
    using serial_vector_env_wrapper_base<EnvType>::n_copies;
    using serial_vector_env_wrapper_base<EnvType>::close;

    ///
    /// \brief make. Create the environment copies. Does nothing when
    /// the copies have already been created.
    ///
    void make();

    ///
    /// \brief reset. Reset every environment copy
    /// \return The time steps returned by the individual resets
    ///
    time_step_type reset();

    ///
    /// \brief step. Step every environment copy with its own action
    /// \param actions One action per environment copy; actions[i] is
    /// applied to the i-th copy
    /// \return The time steps returned by the individual steps
    ///
    time_step_type step(const std::vector<typename EnvType::action_type>& actions);

private:

    ///
    /// \brief current_state_ The time steps collected by the most
    /// recent reset() or step() call
    ///
    time_step_type current_state_;

};

///
/// \brief Constructor. Hands config and p_namespace to the base class
/// constructor and value-initialises current_state_.
/// \param config The wrapper configuration
/// \param p_namespace The Python namespace
///
template<typename EnvType>
SerialVectorEnvWrapper<EnvType>::SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config,
                                                        obj_t p_namespace)
    :
    serial_vector_env_wrapper_base<EnvType>(config, p_namespace),
    current_state_()
{
}

///
/// \brief reset. Reset every environment copy and collect the time
/// steps the resets return.
///
/// current_state_ is emptied first; the environments are then reset in
/// the order in which they are stored and each resulting time step is
/// appended. In GYMFCPP_DEBUG builds an assertion checks is_created_.
///
/// \return A copy of current_state_ after all copies have been reset
///
template<typename EnvType>
typename SerialVectorEnvWrapper<EnvType>::time_step_type
SerialVectorEnvWrapper<EnvType>::reset(){

#ifdef GYMFCPP_DEBUG
    // members of the dependent base class are only found through this->
    assert(this->is_created_ && "Environment has not been created");
#endif

    current_state_.clear();
    current_state_.reserve(this->config_.n_copies);

    for(auto& env : this->envs_){
        auto time_step = env->reset();
        current_state_.add_time_step(time_step);
    }

    return current_state_;
}

template<typename EnvType, typename StateAdaptor>
typename SerialVectorEnvWrapper<EnvType, StateAdaptor>::time_step_type
SerialVectorEnvWrapper<EnvType, StateAdaptor>::step(const std::vector<typename EnvType::action_type>& actions){
///
/// \brief step. Step every environment copy with its own action and
/// collect the resulting time steps.
///
/// An environment whose time step reports last() is reset right away so
/// that it can be stepped again on the next call. The time step produced
/// by that reset is discarded: the entry stored for that environment is
/// the terminal time step.
///
/// \param actions One action per environment copy; actions[i] is applied
/// to the i-th copy. The size is only verified by an assertion in
/// GYMFCPP_DEBUG builds, so callers must pass exactly one action per copy.
/// \return A copy of current_state_ after all copies have been stepped
///
template<typename EnvType>
typename SerialVectorEnvWrapper<EnvType>::time_step_type
SerialVectorEnvWrapper<EnvType>::step(const std::vector<typename EnvType::action_type>& actions){

#ifdef GYMFCPP_DEBUG
    // members of the dependent base class are only found through this->
    assert(this->is_created_ && "Environment has not been created");
    assert(actions.size() == this->envs_.size() && "Invalid number of actions. Number of actions does not equal number of environments");
#endif

    current_state_.clear();
    current_state_.reserve(this->config_.n_copies);

    for(uint_t idx=0; idx<this->envs_.size(); ++idx){

        auto& env = this->envs_[idx];
        auto time_step = env->step(actions[idx]);
        current_state_.add_time_step(time_step);

        // if the current environment finished
        // reset it
        if(time_step.last()){
            env->reset();
        }
    }

    return current_state_;
}

template<typename EnvType, typename StateAdaptor>

///
/// \brief make. Create the environment copies.
///
/// Returns immediately when is_created_ is already set. Otherwise
/// config_.n_copies instances of EnvType are constructed from
/// config_.env_id and the Python namespace, make() is called on each of
/// them as soon as it is stored, and is_created_ is set once all of
/// them exist.
/// NOTE(review): the third constructor argument is always false; its
/// meaning is defined by EnvType -- confirm against the environment classes.
///
template<typename EnvType>
void
SerialVectorEnvWrapper<EnvType>::make(){

    // members of the dependent base class are only found through this->
    if(this->is_created_){
        return;
    }

    const auto n_envs = this->config_.n_copies;

    this->envs_.reserve(n_envs);
    for(uint_t env=0; env < n_envs; ++env){
        this->envs_.push_back(std::make_shared<EnvType>(this->config_.env_id, this->p_namespace_, false));
        this->envs_.back()->make();
    }

    // reserve space for time steps
    current_state_.reserve(n_envs);
    this->is_created_ = true;
}


}

#endif // VECTOR_ENV_WRAPPER_H
12 changes: 12 additions & 0 deletions src/gymfcpp/time_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ class TimeStep
template<typename T>
const T& get_extra(std::string name)const;

///
/// \brief info Read-only access to the extra information of the time step
/// \return A const reference to the string-keyed map held in extra_
///
const std::unordered_map<std::string, std::any>& info()const noexcept
{
    return extra_;
}

///
/// \brief info Read/write access to the extra information of the time step
/// \return A mutable reference to the string-keyed map held in extra_
///
std::unordered_map<std::string, std::any>& info()noexcept
{
    return extra_;
}

private:

///
Expand Down
37 changes: 37 additions & 0 deletions src/gymfcpp/torch_state_adaptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "gymfcpp/gymfcpp_config.h"
#include "gymfcpp/torch_state_adaptor.h"

#ifdef USE_PYTORCH

namespace rlenvs{
namespace torch_utils{


///
/// \brief Adapt a single scalar. The value is placed in a one-element
/// vector and handed to the std::vector<real_t> overload.
/// \param value The scalar to adapt
/// \return The tensor produced by the std::vector<real_t> overload
///
torch_tensor_t
TorchStateAdaptor::operator()(real_t value)const{
    const std::vector<real_t> single_value(1, value);
    return (*this)(single_value);
}

///
/// \brief Adapt a vector of real_t values by passing it to torch::tensor
/// \param data The values to adapt
/// \return The tensor returned by torch::tensor(data)
///
torch_tensor_t
TorchStateAdaptor::operator()(const std::vector<real_t>& data)const{
    torch_tensor_t tensor = torch::tensor(data);
    return tensor;
}

///
/// \brief Adapt a vector of int values by passing it to torch::tensor
/// \param data The values to adapt
/// \return The tensor returned by torch::tensor(data)
///
torch_tensor_t
TorchStateAdaptor::operator()(const std::vector<int>& data)const{

    return torch::tensor(data);
}

///
/// \brief Combine several adapted values into a single one by calling
/// torch::stack with 0 as the dimension argument
/// \param values The tensors to combine
/// \return The tensor returned by torch::stack
///
TorchStateAdaptor::value_type
TorchStateAdaptor::stack(const std::vector<value_type>& values)const{

    value_type stacked = torch::stack(values, 0);
    return stacked;
}

}

}
#endif


35 changes: 35 additions & 0 deletions src/gymfcpp/torch_state_adaptor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef TORCH_STATE_ADAPTOR_H
#define TORCH_STATE_ADAPTOR_H

#include "gymfcpp/gymfcpp_config.h"

#ifdef USE_PYTORCH

#include "gymfcpp/gymfcpp_types.h"
#include <vector>

namespace rlenvs{
namespace torch_utils {

using namespace gymfcpp;

///
/// \brief Function object that turns scalars and std::vector data
/// into torch_tensor_t values and can stack several such values.
///
struct TorchStateAdaptor{

    ///
    /// \brief The type produced by the adaptor
    ///
    using value_type = torch_tensor_t;

    ///
    /// \brief The state type exposed by the adaptor; the same type as value_type
    ///
    using state_type = torch_tensor_t;

    ///
    /// \brief Adapt a single scalar value
    ///
    torch_tensor_t operator()(real_t value)const;

    ///
    /// \brief Adapt a vector of real_t values
    ///
    torch_tensor_t operator()(const std::vector<real_t>& data)const;

    ///
    /// \brief Adapt a vector of int values
    ///
    torch_tensor_t operator()(const std::vector<int>& data)const;

    ///
    /// \brief Combine several adapted values into one value_type
    ///
    value_type stack(const std::vector<value_type>& values)const;
};



}
}

#endif
#endif // TORCH_STATE_ADAPTOR_H
Loading