In [None]:
import gym
import torch
import torch.nn as nn

from torch_bwim.dataset.TorchDataUtils import TorchDataUtils
from torch_bwim.loss_functions.reinforcement_learning.ReinforcementLoss import ReinforcementLoss
from torch_bwim.lr_schedulers.modules.NullScheduler import NullScheduler
from torch_bwim.nets.NnModuleUtils import NnModuleUtils
from torch_bwim.nets.modules.reinforce.DiscrActorCriticNet import DiscrActorCriticNet
from torch_bwim.optimizers.modules.AdamFactory import AdamFactory
from torch_bwim.trainers.modules.ActorCriticNetTrainer import ActorCriticNetTrainer

%load_ext autoreload
%autoreload 2

In [None]:
class GymEnvWrapper(gym.Env):
    """Adapter between a gym environment and code that works with tensors.

    Actions are expected to expose ``.item()`` (e.g. a one-element tensor) and
    are unwrapped to a plain Python scalar before being forwarded. Observations
    returned by the wrapped environment are passed through
    ``NnModuleUtils.from_array`` before being handed back to the caller.
    """

    def __init__(self, env: gym.Env):
        """Keep a reference to the environment every call is delegated to."""
        super().__init__()
        self.env = env

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Object exposing ``.item()``; the scalar it yields is what
                the wrapped environment receives.

        Returns:
            ``(state, reward, done, extra_info)`` where ``state`` is the
            observation converted by ``NnModuleUtils.from_array`` and the
            remaining three values are passed through untouched.
        """
        state, reward, done, extra_info = self.env.step(action=action.item())
        return NnModuleUtils.from_array(state), reward, done, extra_info

    def reset(self, seed=None, return_info=False, options=None):
        """Reset the wrapped environment and convert the initial observation.

        Args:
            seed: Forwarded unchanged to the wrapped ``reset``.
            return_info: Forwarded unchanged; also selects the return shape.
            options: Forwarded unchanged to the wrapped ``reset``.

        Returns:
            The converted observation when ``return_info`` is false, otherwise
            a ``(converted_observation, info)`` pair.
        """
        result = self.env.reset(seed=seed, return_info=return_info, options=options)
        if return_info:
            # With return_info=True the gym reset API (0.22-0.25 signature used
            # here) yields (observation, info). Only the observation may go
            # through from_array; previously the whole tuple was converted.
            state, info = result
            return NnModuleUtils.from_array(state), info
        return NnModuleUtils.from_array(result)

    def close(self):
        """Release the wrapped environment's resources."""
        self.env.close()


In [None]:
# Create CartPole-v1 and wrap it in one go, so `env` is only ever bound to the
# GymEnvWrapper that the later cells interact with.
env = GymEnvWrapper(env=gym.make('CartPole-v1'))

In [None]:
# Start a new episode through the wrapper (default arguments, so a single
# converted observation comes back) and inspect it.
state = env.reset()
print(state)

In [None]:
# Take one step with action index 1. The action is wrapped in a tensor because
# GymEnvWrapper.step unwraps it with `.item()`. Note that `state` is overwritten
# with the post-step observation; the fourth return value (extra info) is dropped.
state, reward, done, _ = env.step(action=torch.tensor([1]))
print(state, reward, done)

In [None]:
class Policy(DiscrActorCriticNet):
    """Two-headed network with a single shared hidden layer.

    One linear layer followed by ReLU feeds both an action head (softmax over
    the last dimension) and a value head (raw linear output).
    """

    class Config(DiscrActorCriticNet.Config):
        """Layer sizes for ``Policy``."""

        def __init__(self, in_size, hidden_size, action_out, value_out):
            """Record the input, hidden, action-head and value-head widths."""
            super().__init__()
            self.in_size, self.hidden_size = in_size, hidden_size
            self.action_out, self.value_out = action_out, value_out

    def __init__(self, config: Config):
        """Build the shared layer and both heads from ``config``."""
        super().__init__(config=config)
        self.config = config
        # Construction order is kept as l1 -> action_head -> value_head so that
        # parameter registration and random initialisation order are unchanged.
        self.l1 = nn.Linear(config.in_size, config.hidden_size)
        self.action_head = nn.Linear(config.hidden_size, config.action_out)
        self.value_head = nn.Linear(config.hidden_size, config.value_out)

    def forward(self, x):
        """Return ``(action_probabilities, state_values)`` for input ``x``.

        The first element is the softmax of the action head's output along the
        last dimension; the second is the value head's unnormalised output.
        """
        features = torch.relu(self.l1(x))
        action_probs = torch.softmax(self.action_head(features), dim=-1)
        return action_probs, self.value_head(features)


In [None]:
# Instantiate the policy: 4 input features, 128 hidden units, 2 action outputs
# and 1 value output.
net = Policy(
    config=Policy.Config(in_size=4, hidden_size=128, action_out=2, value_out=1),
)

In [None]:
# Shape check: call `forward` directly on one random 4-feature input and print
# the shapes of both outputs. Despite the name, `action_scores` holds the
# softmax output of Policy.forward, not raw scores.
action_scores, baseline = net.forward(torch.randn((1, 4)))
print(action_scores.shape)
print(baseline.shape)

In [None]:
# Call the network with the `state` keyword on the latest observation and
# inspect what comes back.
# NOTE(review): Policy.forward takes `x`, so this relies on how
# DiscrActorCriticNet handles a `state=` call — presumably it picks an action
# from the forward output; confirm against the base class.
action = net(state=state)
print(action)
print(action.shape)

In [None]:
# Feed the network's action into the wrapped environment; as the cell's last
# expression, the returned step tuple is shown as the cell output.
env.step(action)

In [None]:
# Build a new Policy instance for training instead of reusing the one that was
# probed in the cells above.
net = Policy(
    config=Policy.Config(in_size=4, hidden_size=128, action_out=2, value_out=1),
)

# Trainer with its default configuration; progress messages go through `print`.
trainer = ActorCriticNetTrainer(train_config=ActorCriticNetTrainer.Config(), logger=print)

# Wire the network, environment, loss (gamma=0.99), no-op LR scheduler and Adam
# optimizer settings into the trainer; training is set up with cuda disabled.
trainer.initialize(
    net=net,
    env=env,
    loss_function=ReinforcementLoss(config=ReinforcementLoss.Config(gamma=0.99)),
    scheduler_config=NullScheduler.Config(),
    optimizer_config=AdamFactory.Config(learning_rate=1e-3, weight_decay=1e-4),
    cuda=False,
)

trainer.train(episode_num=1000, max_iter_in_episode=200)