diff --git a/src/gymfcpp/black_jack_env.cpp b/src/gymfcpp/black_jack_env.cpp index 0a04a348..f9f03ef7 100644 --- a/src/gymfcpp/black_jack_env.cpp +++ b/src/gymfcpp/black_jack_env.cpp @@ -1,5 +1,5 @@ #include "gymfcpp/black_jack_env.h" -#include "gymfcpp/config.h" +#include "gymfcpp/gymfcpp_config.h" #ifdef GYMFCPP_DEBUG #include diff --git a/src/gymfcpp/cliff_world_env.cpp b/src/gymfcpp/cliff_world_env.cpp index b2868109..0a8009fd 100644 --- a/src/gymfcpp/cliff_world_env.cpp +++ b/src/gymfcpp/cliff_world_env.cpp @@ -1,6 +1,6 @@ #include "gymfcpp/cliff_world_env.h" #include "gymfcpp/names_generator.h" -#include "gymfcpp/config.h" +#include "gymfcpp/gymfcpp_config.h" #ifdef GYMFCPP_DEBUG #include diff --git a/src/gymfcpp/config.h b/src/gymfcpp/config.h deleted file mode 100644 index fec994e4..00000000 --- a/src/gymfcpp/config.h +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef CONFIG_H -#define CONFIG_H - -#define GYMFCPP_DEBUG - -#endif // CONFIG_H diff --git a/src/gymfcpp/frozen_lake_env.cpp b/src/gymfcpp/frozen_lake_env.cpp index d3dde617..409eb267 100644 --- a/src/gymfcpp/frozen_lake_env.cpp +++ b/src/gymfcpp/frozen_lake_env.cpp @@ -1,6 +1,6 @@ #include "gymfcpp/frozen_lake_env.h" #include "gymfcpp/names_generator.h" -#include "gymfcpp/config.h" +#include "gymfcpp/gymfcpp_config.h" #include diff --git a/src/gymfcpp/serial_vector_env_wrapper.h b/src/gymfcpp/serial_vector_env_wrapper.h index fdb1b7a7..c11d615a 100644 --- a/src/gymfcpp/serial_vector_env_wrapper.h +++ b/src/gymfcpp/serial_vector_env_wrapper.h @@ -32,43 +32,24 @@ struct SerialVectorEnvWrapperConfig uint_t seed{42}; }; +namespace { -/// -/// -/// -template -class SerialVectorEnvWrapper: private boost::noncopyable +template +class serial_vector_env_wrapper_base: private boost::noncopyable { public: - typedef VectorTimeStep time_step_type; - - /// - /// \brief VectorEnvWrapper - /// \param n_copies - /// \param version /// - SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace); - + /// \brief serial_vector_env_wrapper_base + /// \param config + /// \param p_namespace /// - /// \brief ~VectorEnvWrapper - /// - ~SerialVectorEnvWrapper(); - - /// - /// \brief make - /// - void make(); - - /// - /// - /// - time_step_type reset(); + serial_vector_env_wrapper_base(SerialVectorEnvWrapperConfig config, obj_t p_namespace); /// /// /// - time_step_type step(const std::vector& actions); + ~serial_vector_env_wrapper_base(); /// /// \brief close. Close down the environment @@ -81,12 +62,7 @@ class SerialVectorEnvWrapper: private boost::noncopyable /// uint_t n_copies()const noexcept{return config_.n_copies;} -private: - - /// - /// \brief current_state - /// - time_step_type current_state_; +protected: /// /// \brief p_namespace_ The main Python namespace @@ -110,11 +86,9 @@ class SerialVectorEnvWrapper: private boost::noncopyable }; - -template -SerialVectorEnvWrapper::SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace) +template +serial_vector_env_wrapper_base::serial_vector_env_wrapper_base(SerialVectorEnvWrapperConfig config, obj_t p_namespace) : - current_state_(), p_namespace_(p_namespace), envs_(), config_(config), @@ -128,81 +102,146 @@ SerialVectorEnvWrapper::SerialVectorEnvWrapper(SerialVect } -template -SerialVectorEnvWrapper::~SerialVectorEnvWrapper(){ +template +serial_vector_env_wrapper_base::~serial_vector_env_wrapper_base(){ close(); } -template +template void -SerialVectorEnvWrapper::close(){ +serial_vector_env_wrapper_base::close(){ for(uint_t t=0; t < config_.n_copies; ++t){ envs_[t]->close(); } } -template -typename SerialVectorEnvWrapper::time_step_type -SerialVectorEnvWrapper::reset(){ +} + + + +template +class SerialVectorEnvWrapper: protected serial_vector_env_wrapper_base +{ +public: + + typedef VectorTimeStep time_step_type; + + /// + /// \brief VectorEnvWrapper + /// \param n_copies + /// \param version + /// + SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace); + + /// + /// \brief ~VectorEnvWrapper + /// + ~SerialVectorEnvWrapper()=default; + + using serial_vector_env_wrapper_base::n_copies; + using serial_vector_env_wrapper_base::close; + + /// + /// \brief make + /// + void make(); + + /// + /// + /// + time_step_type reset(); + + /// + /// + /// + time_step_type step(const std::vector& actions); + +private: + + /// + /// \brief current_state + /// + time_step_type current_state_; + +}; + +template +SerialVectorEnvWrapper::SerialVectorEnvWrapper(SerialVectorEnvWrapperConfig config, obj_t p_namespace) + : + serial_vector_env_wrapper_base(config, p_namespace), + current_state_() +{} + +template +typename SerialVectorEnvWrapper::time_step_type +SerialVectorEnvWrapper::reset(){ #ifdef GYMFCPP_DEBUG - assert(is_created_ && "Environment has not been created"); + assert(this->is_created_ && "Environment has not been created"); #endif current_state_.clear(); - current_state_.reserve(config_.n_copies); + current_state_.reserve(this->config_.n_copies); - for(uint_t env=0; envreset(); + for(uint_t env=0; envenvs_.size(); ++env){ + auto time_step = this->envs_[env]->reset(); current_state_.add_time_step(time_step); } return current_state_; } -template -typename SerialVectorEnvWrapper::time_step_type -SerialVectorEnvWrapper::step(const std::vector& actions){ +template +typename SerialVectorEnvWrapper::time_step_type +SerialVectorEnvWrapper::step(const std::vector& actions){ #ifdef GYMFCPP_DEBUG - assert(is_created_ && "Environment has not been created"); - assert(actions.size() == envs_.size() && "Invalid number of actions. Number of actions does not equal number of environments"); + assert(this->is_created_ && "Environment has not been created"); + assert(actions.size() == this->envs_.size() && "Invalid number of actions. Number of actions does not equal number of environments"); #endif current_state_.clear(); - current_state_.reserve(config_.n_copies); + current_state_.reserve(this->config_.n_copies); - for(uint_t env=0; envstep(actions[env]); + for(uint_t env=0; envenvs_.size(); ++env){ + auto time_step = this->envs_[env]->step(actions[env]); current_state_.add_time_step(time_step); + + // if the current environment finished + // reset it + if(time_step.last()){ + this->envs_[env]->reset(); + } } return current_state_; } -template + +template void -SerialVectorEnvWrapper::make(){ +SerialVectorEnvWrapper::make(){ - if(is_created_){ + if(this->is_created_){ return; } - envs_.reserve(config_.n_copies); - for(uint_t env=0; env < config_.n_copies; ++env){ - envs_.push_back(std::make_shared(config_.env_id, p_namespace_, false)); - envs_.back()->make(); + this->envs_.reserve(this->config_.n_copies); + for(uint_t env=0; env < this->config_.n_copies; ++env){ + this->envs_.push_back(std::make_shared(this->config_.env_id, this->p_namespace_, false)); + this->envs_.back()->make(); } // reserve space for time steps - current_state_.reserve(config_.n_copies); - is_created_ = true; + current_state_.reserve(this->config_.n_copies); + this->is_created_ = true; } + } #endif // VECTOR_ENV_WRAPPER_H diff --git a/src/gymfcpp/time_step.h b/src/gymfcpp/time_step.h index 6c8bcd7c..ca083633 100644 --- a/src/gymfcpp/time_step.h +++ b/src/gymfcpp/time_step.h @@ -114,6 +114,18 @@ class TimeStep template const T& get_extra(std::string name)const; + /// + /// \brief info + /// \return + /// + const std::unordered_map& info()const noexcept{return extra_;} + + /// + /// \brief info + /// \return + /// + std::unordered_map& info()noexcept{return extra_;} + private: /// diff --git a/src/gymfcpp/torch_state_adaptor.cpp b/src/gymfcpp/torch_state_adaptor.cpp new file mode 100644 index 00000000..0dcb6b2b --- /dev/null +++ b/src/gymfcpp/torch_state_adaptor.cpp @@ -0,0 +1,37 @@ +#include "gymfcpp/gymfcpp_config.h" +#include "gymfcpp/torch_state_adaptor.h" + +#ifdef USE_PYTORCH + +namespace rlenvs{ +namespace torch_utils{ + + +torch_tensor_t +TorchStateAdaptor::operator()(real_t value)const{ + return this->operator()(std::vector(1, value)); +} + +torch_tensor_t +TorchStateAdaptor::operator()(const std::vector& data)const{ + return torch::tensor(data); +} + +torch_tensor_t +TorchStateAdaptor::operator()(const std::vector& data)const{ + + return torch::tensor(data); +}; + +TorchStateAdaptor::value_type +TorchStateAdaptor::stack(const std::vector& values)const{ + + return torch::stack(values, 0); +} + +} + +} +#endif + + diff --git a/src/gymfcpp/torch_state_adaptor.h b/src/gymfcpp/torch_state_adaptor.h new file mode 100644 index 00000000..4592b6b5 --- /dev/null +++ b/src/gymfcpp/torch_state_adaptor.h @@ -0,0 +1,35 @@ +#ifndef TORCH_STATE_ADAPTOR_H +#define TORCH_STATE_ADAPTOR_H + +#include "gymfcpp/gymfcpp_config.h" + +#ifdef USE_PYTORCH + +#include "gymfcpp/gymfcpp_types.h" +#include + +namespace rlenvs{ +namespace torch_utils { + +using namespace gymfcpp; + +struct TorchStateAdaptor{ + + + typedef torch_tensor_t value_type; + typedef torch_tensor_t state_type; + + torch_tensor_t operator()(real_t value)const; + torch_tensor_t operator()(const std::vector& data)const; + torch_tensor_t operator()(const std::vector& data)const; + + value_type stack(const std::vector& values)const; +}; + + + +} +} + +#endif +#endif // TORCH_STATE_ADAPTOR_H diff --git a/src/gymfcpp/vector_time_step.h b/src/gymfcpp/vector_time_step.h index fb31d1bf..8ef781e5 100644 --- a/src/gymfcpp/vector_time_step.h +++ b/src/gymfcpp/vector_time_step.h @@ -3,6 +3,7 @@ #include "gymfcpp/gymfcpp_types.h" #include "gymfcpp/time_step.h" +#include "gymfcpp/time_step_type.h" #include @@ -10,7 +11,7 @@ namespace rlenvs{ using namespace gymfcpp; -template +template class VectorTimeStep { @@ -47,18 +48,98 @@ class VectorTimeStep /// bool empty()const noexcept{return time_steps_.empty();} + /// + /// + /// + template + typename AdaptorType::value_type stack_states()const; + + /// + /// + /// + template + typename AdaptorType::value_type stack_rewards()const; + + /// + /// \brief stack_time_step_types + /// \return + /// + std::vector stack_time_step_types()const; + private: + /// + /// \brief time_steps_ + /// std::vector> time_steps_; }; -template +template void -VectorTimeStep::add_time_step(const TimeStep& step){ +VectorTimeStep::add_time_step(const TimeStep& step){ time_steps_.push_back(step); } +template +template +typename AdaptorType::value_type +VectorTimeStep::stack_states()const{ + + std::vector states; + states.reserve(time_steps_.size()); + + AdaptorType adaptor; + + std::for_each(time_steps_.begin(), time_steps_.end(), + [&](const auto& step){ + + states.push_back(adaptor(step.observation())); + + }); + + return adaptor.stack(states); +} + +template +template +typename AdaptorType::value_type +VectorTimeStep::stack_rewards()const{ + + std::vector rewards; + rewards.reserve(time_steps_.size()); + + AdaptorType adaptor; + + std::for_each(time_steps_.begin(), time_steps_.end(), + [&](const auto& step){ + + rewards.push_back(step.reward()); + + }); + + return adaptor(rewards); +} + +template +std::vector +VectorTimeStep::stack_time_step_types()const{ + + std::vector step_types; + step_types.reserve(time_steps_.size()); + + + + std::for_each(time_steps_.begin(), time_steps_.end(), + [&](const auto& step){ + + step_types.push_back(step.type()); + + }); + + return step_types; +} + } #endif // VECTOR_TIME_STEP_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9f04b1ca..36648df5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -40,4 +40,5 @@ ADD_SUBDIRECTORY(test_tiled_cart_pole) ADD_SUBDIRECTORY(test_numpy_cpp_utils) ADD_SUBDIRECTORY(test_grid_world) ADD_SUBDIRECTORY(test_vector_env_wrapper) +ADD_SUBDIRECTORY(test_vector_time_step) diff --git a/tests/test_vector_env_wrapper/test_vector_env_wrapper.cpp b/tests/test_vector_env_wrapper/test_vector_env_wrapper.cpp index b29f9d53..c39c703d 100644 --- a/tests/test_vector_env_wrapper/test_vector_env_wrapper.cpp +++ b/tests/test_vector_env_wrapper/test_vector_env_wrapper.cpp @@ -1,5 +1,12 @@ + #include "gymfcpp/serial_vector_env_wrapper.h" +#include "gymfcpp/gymfcpp_config.h" + +#ifdef USE_PYTORCH +#include "gymfcpp/torch_state_adaptor.h" +#endif + #include "gymfcpp/cart_pole_env.h" #include "gymfcpp/time_step.h" #include "gymfcpp/time_step_type.h" @@ -43,7 +50,6 @@ TEST(SerialVectorEnvWrapper, Constructor) { PyErr_Print(); FAIL(); } - } @@ -105,54 +111,8 @@ TEST(TestStateAggregationCartPole, TestStep) } } -/*TEST(TestStateAggregationCartPole, TestStep) -{ - - try{ - Py_Initialize(); - boost::python::numpy::initialize(); - auto main_module = boost::python::import("__main__"); - auto main_namespace = main_module.attr("__dict__"); - - gymfcpp::StateAggregationCartPole env("v0", main_namespace, 10); - env.make(); - env.reset(); - - auto step_result = env.step(0); - ASSERT_TRUE(step_result.mid()); - } - catch(const boost::python::error_already_set&) - { - PyErr_Print(); - FAIL()<<"Error could not step in the environment"; - } -}*/ -/*TEST(TestStateAggregationCartPole, TestRender) -{ - - try{ - - Py_Initialize(); - boost::python::numpy::initialize(); - auto main_module = boost::python::import("__main__"); - auto main_namespace = main_module.attr("__dict__"); - - gymfcpp::StateAggregationCartPole env("v0", main_namespace, 10); - env.make(); - env.reset(); - - env.render(gymfcpp::RenderModeType::human); - - } - catch(const boost::python::error_already_set&) - { - PyErr_Print(); - FAIL()<<"Error could not step in the environment"; - } -}*/ - diff --git a/tests/test_vector_time_step/CMakeLists.txt b/tests/test_vector_time_step/CMakeLists.txt new file mode 100644 index 00000000..326fca9a --- /dev/null +++ b/tests/test_vector_time_step/CMakeLists.txt @@ -0,0 +1,20 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 3.6) + +SET(EXECUTABLE test_vector_time_step) +SET(SOURCE ${EXECUTABLE}.cpp) + +ADD_EXECUTABLE(${EXECUTABLE} ${SOURCE}) + +TARGET_LINK_LIBRARIES(${EXECUTABLE} gymfcpplib) + +IF( USE_PYTORCH ) +TARGET_LINK_LIBRARIES(${EXECUTABLE} ${TORCH_LIBRARIES}) +ENDIF() + +TARGET_LINK_LIBRARIES(${EXECUTABLE} python3.8) +TARGET_LINK_LIBRARIES(${EXECUTABLE} boost_python38) +TARGET_LINK_LIBRARIES(${EXECUTABLE} boost_system) +TARGET_LINK_LIBRARIES(${EXECUTABLE} gtest) +TARGET_LINK_LIBRARIES(${EXECUTABLE} gtest_main) # so that tests dont need to have a main +TARGET_LINK_LIBRARIES(${EXECUTABLE} pthread) + diff --git a/tests/test_vector_time_step/test_vector_time_step.cpp b/tests/test_vector_time_step/test_vector_time_step.cpp new file mode 100644 index 00000000..861883b8 --- /dev/null +++ b/tests/test_vector_time_step/test_vector_time_step.cpp @@ -0,0 +1,112 @@ + +#include "gymfcpp/vector_time_step.h" +#include "gymfcpp/time_step_type.h" +#include "gymfcpp/torch_state_adaptor.h" +#include "gymfcpp/gymfcpp_config.h" + +#ifdef USE_PYTORCH +#include "gymfcpp/torch_state_adaptor.h" +#endif + +#include "gymfcpp/cart_pole_env.h" +#include "gymfcpp/time_step.h" +#include "gymfcpp/time_step_type.h" +#include "gymfcpp/gymfcpp_types.h" +#include "gymfcpp/render_mode_enum.h" + +#include +#include + +#include +#include + +namespace{ + +using gymfcpp::uint_t; +using gymfcpp::real_t; +using gymfcpp::CartPole; +using rlenvs::VectorTimeStep; +using rlenvs::torch_utils::TorchStateAdaptor; +using gymfcpp::TimeStep; +using gymfcpp::TimeStepTp; +} + + +TEST(TestVectorTimeStep, Constructor) { + + + VectorTimeStep> vec_step; + + ASSERT_TRUE(vec_step.empty()); + ASSERT_EQ(vec_step.size(), static_cast(0)); +} + + +TEST(TestVectorTimeStep, AddTimeStep) +{ + + VectorTimeStep> vec_step; + + ASSERT_TRUE(vec_step.empty()); + ASSERT_EQ(vec_step.size(), static_cast(0)); + + TimeStep> step(TimeStepTp::FIRST, 1.0, std::vector(2, 2.0)); + + vec_step.add_time_step(step); + ASSERT_EQ(vec_step.size(), static_cast(1)); + +} + + +TEST(TestVectorTimeStep, StackState) +{ + + VectorTimeStep> vec_step; + + ASSERT_TRUE(vec_step.empty()); + ASSERT_EQ(vec_step.size(), static_cast(0)); + + TimeStep> step1(TimeStepTp::FIRST, 1.0, std::vector(2, 2.0)); + + vec_step.add_time_step(step1); + + TimeStep> step2(TimeStepTp::FIRST, 2.0, std::vector(2, 3.0)); + vec_step.add_time_step(step2); + + ASSERT_EQ(vec_step.size(), static_cast(2)); + + auto tensor = vec_step.stack_states(); + + ASSERT_EQ(tensor.size(0), static_cast(2)); + ASSERT_EQ(tensor.size(1), static_cast(2)); +} + + +TEST(TestVectorTimeStep, StackReward) +{ + + VectorTimeStep> vec_step; + + ASSERT_TRUE(vec_step.empty()); + ASSERT_EQ(vec_step.size(), static_cast(0)); + + TimeStep> step1(TimeStepTp::FIRST, 1.0, std::vector(2, 2.0)); + + vec_step.add_time_step(step1); + + TimeStep> step2(TimeStepTp::FIRST, 2.0, std::vector(2, 3.0)); + vec_step.add_time_step(step2); + + ASSERT_EQ(vec_step.size(), static_cast(2)); + + auto tensor = vec_step.stack_rewards(); + + ASSERT_EQ(tensor.size(0), static_cast(2)); + //ASSERT_EQ(tensor.size(1), static_cast(2)); +} + + + + + +