In [6]:
import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from typing import *

In [8]:
class LagEnv(gym.Env):
    """Environment whose observation is a Dict of ``n_lag`` lagged state vectors.

    The observation space has keys ``state_0`` ... ``state_{n_lag-1}``; each value is a
    float32 Box of length ``state_dim`` bounded to [-1, 1]. The action space is
    ``Discrete(action_dim)``.

    ``reset`` and ``step`` are not implemented yet and raise ``NotImplementedError``
    (consistent with ``render``), instead of silently returning ``None``.

    Args:
        state_dim: Length of each lagged state vector.
        action_dim: Number of discrete actions.
        n_lag: Number of lagged states contained in one observation.
    """

    def __init__(self, state_dim: int, action_dim: int, n_lag: int):
        super().__init__()
        self.action_space = gym.spaces.Discrete(action_dim)
        # One Box per lag key; the previous single-Box assignment was dead code because
        # it was overwritten by this Dict space.
        self.observation_space = gym.spaces.Dict(
            {
                f"state_{i}": gym.spaces.Box(
                    low=np.full(state_dim, -1, dtype=np.float32),
                    high=np.full(state_dim, 1, dtype=np.float32),
                )
                for i in range(n_lag)
            }
        )

    def reset(self) -> Dict[str, np.ndarray]:
        """Return the initial observation (one entry per ``state_<i>`` key). Not implemented yet."""
        raise NotImplementedError()

    def step(self, action_index: int) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
        """Apply ``action_index`` and return (observation, reward, done, info). Not implemented yet."""
        raise NotImplementedError()

    def render(self):
        raise NotImplementedError()

In [9]:
# Small LagEnv for experimentation: 5 lagged states of length 10 each, 5 discrete actions.
lagenv = LagEnv(
    state_dim=10,
    action_dim=5,
    n_lag=5,
)

- [x] gym.Env を継承した LagEnv（ラグ状態を Dict 観測空間で表す環境）を作成
- [ ] FeatureExtractorを作成
  - [x] FeatureAxis: MLPを時系列ごとにかける
  - [ ] TimeAxis: MLPを特徴量ごとにかける
  - [ ] CNN: いわゆる時系列CNN

### policy_kwargsの項目
- features_extractor_class
- features_extractor_kwargs（features_extractor_class のコンストラクタに渡す、observation_space 以外の引数）

In [25]:
import torch
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class FeatureAxisCombinedExtractor(BaseFeaturesExtractor):
    """Feature extractor that embeds each lagged state with its own MLP ("FeatureAxis").

    Every ``state_<i>`` key of the Dict observation gets a separate
    ``Linear(state_dim, emb_dim) -> ReLU`` layer. The per-key embeddings are
    concatenated along dim 1, so ``features_dim == emb_dim * number_of_keys``.

    Args:
        observation_space: Dict space whose keys all start with ``"state_"`` and whose
            sub-spaces are 1-D with a common length.
        emb_dim: Embedding size produced for each lagged state.

    Raises:
        ValueError: If a key does not start with ``"state_"``, a sub-space is not 1-D,
            the space has no keys, or the sub-spaces differ in length.
    """

    def __init__(self, observation_space: gym.spaces.Dict, emb_dim: int):
        # features_dim=1 is a placeholder: the real size is only known after the
        # sub-spaces are inspected, and is written to _features_dim below.
        super().__init__(observation_space, features_dim=1)
        self._check_state_dim(observation_space=observation_space)

        self.extractors = nn.ModuleDict(
            {
                key: nn.Sequential(nn.Linear(self.state_dim, emb_dim), nn.ReLU())
                for key in observation_space.spaces.keys()
            }
        )
        self._features_dim = emb_dim * len(self.extractors)

    def forward(self, observations: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed every lagged state with its own MLP and concatenate along dim 1.

        Args:
            observations: Mapping from each ``state_<i>`` key to a tensor; assumed to be
                batched as (batch, state_dim) by SB3 — TODO confirm.

        Returns:
            The per-key embeddings concatenated along dim 1, in ``self.extractors`` order.
        """
        embeds = [extractor(observations[key]) for key, extractor in self.extractors.items()]
        return torch.cat(embeds, dim=1)

    def _check_state_dim(self, observation_space: gym.spaces.Dict) -> None:
        """Validate the lag sub-spaces and store their common length in ``self.state_dim``."""
        state_dims = []
        for key, subspace in observation_space.spaces.items():
            if not key.startswith("state_"):
                raise ValueError("ラグ特徴量の状態を示さないkey")
            shape = subspace.shape
            # nn.Linear below is sized from shape[0], so only 1-D sub-spaces make sense.
            if shape is None or len(shape) != 1:
                raise ValueError("ラグ特徴量の状態が1次元ではない")
            state_dims.append(shape[0])

        # Without this, an empty space fails later with an opaque max() error.
        if not state_dims:
            raise ValueError("ラグ特徴量の状態が1つもない")
        if max(state_dims) != min(state_dims):
            raise ValueError("特徴量数が時系列で異なる")

        self.state_dim = state_dims[0]

In [26]:
# Build the extractor against LagEnv's Dict observation space; each lag is embedded to 16 dims.
fe = FeatureAxisCombinedExtractor(
    observation_space=lagenv.observation_space,
    emb_dim=16,
)

In [22]:
# Draw one random observation to inspect the Dict layout (keys state_0, state_1, ...).
lagenv.observation_space.sample()

OrderedDict([('state_0',
              array([-0.7439461 , -0.5942791 ,  0.5662168 , -0.23193447,  0.2599095 ,
                     -0.14923005, -0.1183605 , -0.2974712 ,  0.6337957 ,  0.9915399 ],
                    dtype=float32)),
             ('state_1',
              array([-0.7202005 ,  0.07508245, -0.9085906 , -0.9690788 , -0.5691145 ,
                      0.25228876,  0.06154174, -0.6212881 , -0.9810684 ,  0.669271  ],
                    dtype=float32)),
             ('state_2',
              array([-0.09554597,  0.6957237 ,  0.13710429,  0.5714343 ,  0.11541936,
                     -0.8194528 , -0.9398576 ,  0.39805675,  0.7849702 , -0.06726027],
                    dtype=float32)),
             ('state_3',
              array([-0.50279194,  0.15855195,  0.57658184,  0.23680189, -0.12748915,
                      0.8073213 , -0.1417726 , -0.4662511 , -0.16362762,  0.4707762 ],
                    dtype=float32)),
             ('state_4',
              array([-0.90996045, -

In [None]:
# NOTE(review): neither `env` nor `ppo_params` is defined in any visible cell, so this
# raises NameError on a fresh kernel. Presumably `env` should be `lagenv` — TODO confirm.
# NOTE(review): the observation space is a gym Dict; the custom extractor presumably has
# to be supplied through policy_kwargs, e.g.
#   policy_kwargs=dict(features_extractor_class=FeatureAxisCombinedExtractor,
#                      features_extractor_kwargs=dict(emb_dim=...))
# — TODO confirm (and make sure ppo_params does not also set policy_kwargs).
model = PPO(
    policy=ActorCriticPolicy,
    env=env,
    verbose=1,
    **ppo_params,
)