In [None]:
import gymnasium as gym
import tensorflow as tf

import stock

In [None]:
# Training environment. With render_mode="human", Gymnasium draws a window on
# every step() and throttles to the env's render FPS, which makes training run
# at real-time speed. Rendering is therefore off by default here.
TRAIN_RENDER_MODE = None  # set to "human" to watch training (much slower)
env = gym.make("LunarLander-v2", render_mode=TRAIN_RENDER_MODE)

# Hyperparameters (previously inline magic numbers).
LEARNING_RATE = 1e-3
EPSILON = 0.2  # epsilon passed to the epsilon-greedy training policy
REPLAY_CAPACITY = 10_000
MAX_STEPS = 1000
BATCH_SIZE = 1000

# Set up the agent.
# NOTE(review): both networks are built with identical FCModel arguments
# (env.action_space.n, 32, 2); presumably output size, hidden width and
# depth — confirm against stock.rl.network.FCModel.
action_network = stock.rl.network.FCModel(env.action_space.n, 32, 2)
value_network = stock.rl.network.FCModel(env.action_space.n, 32, 2)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
training_policy = stock.rl.agent.EpsilonGreedyPolicy(epsilon=EPSILON)
agent = stock.rl.agent.DDQNAgent(
    env.action_space,
    env.observation_space,
    action_network=action_network,
    value_network=value_network,
    optimizer=optimizer,
    training_policy=training_policy,
)

# Set up the trainer.
replay_buffer = stock.rl.replay_buffer.ReplayBuffer(capacity=REPLAY_CAPACITY)
callbacks = [
    stock.rl.callback.TrainLogger(),
]
params = stock.rl.Trainer.Params(max_steps=MAX_STEPS, batch_size=BATCH_SIZE)
trainer = stock.rl.Trainer(
    env=env,
    agent=agent,
    replay_buffer=replay_buffer,
    callbacks=callbacks,
    params=params,
)

In [None]:
# Run training with the env, agent, replay buffer, callbacks and params wired
# into `trainer` in the setup cell above.
trainer.train()

In [None]:
class BestActionPolicy(stock.rl.agent.training_policy.BasePolicy):
    """Greedy policy that always picks the highest-scoring action.

    Calling the policy with a score tensor returns, as a NumPy array, the
    index of the maximum entry along the last axis.
    """

    def __call__(self, score):
        best_action = tf.math.argmax(score, axis=-1)
        return best_action.numpy()

In [None]:

# Roll out one episode with the trained agent in a separate, rendered env so
# that `env` (the env passed to `trainer` in the setup cell) is not shadowed.
# NOTE(review): BestActionPolicy is defined above but not used here;
# presumably it was meant for greedy evaluation — confirm how agent.policy
# selects actions.
eval_env = gym.make("LunarLander-v2", render_mode="human")
total_reward = 0.0
try:
    observation, info = eval_env.reset()
    while True:
        # The policy output is indexed with [0]; presumably it returns a batch
        # of actions — TODO confirm against stock.rl.agent.
        action = trainer.agent.policy(observation)[0]
        observation, reward, terminated, truncated, info = eval_env.step(action)
        total_reward += reward
        if terminated or truncated:
            break
finally:
    # Release the render window even if the rollout raises.
    eval_env.close()

# Display the episode return (previously computed but never shown).
total_reward