In [1]:
import time
import warnings
import gymnasium as gym
import pandas as pd
import matplotlib.pyplot as plt

from src.environment.env import ShogiEnv
from src.agent.shogi_agent import ShogiAgent

warnings.filterwarnings("ignore")

gym.register(id="Shogi-v0", entry_point="env:ShogiEnv")
env: ShogiEnv = gym.make("Shogi-v0")
agent = ShogiAgent()

ModuleNotFoundError: No module named 'env'

In [None]:
def play_game(environment: ShogiEnv, player: ShogiAgent) -> (float, bool, bool):
    losses = []
    rewards = []
    terminated = False
    truncated = False
    environment.reset()
    agent.reset()

    while not terminated and not truncated:
        current_state = env.get_observation()

        # Take action
        action, mask_index = player.select_action(environment)
        state, reward, terminated, truncated, _ = environment.step(action)

        # Update the player
        player.adaptive_e_greedy()
        new_state = env.get_observation()
        loss = player.train_model(
            mask_index,
            reward,
            (terminated or truncated),
            current_state,
            new_state,
        )

        rewards.append(reward)
        losses.append(loss)

    return rewards, terminated, truncated, losses

## Train the Agent

In [None]:
progress = []

In [None]:
for i in range(20):
    start = time.time()
    reward_list, _terminated, _truncated, loss_list = play_game(env, agent)
    end = time.time()
    print(f"{i}: {end - start}")

    progress.append(
        {
            "reward": sum(reward_list),
            "loss": sum(loss_list),
            "terminated": _terminated,
            "truncated": _truncated,
            "duration": end - start,
        }
    )

In [None]:
df = pd.DataFrame.from_records(progress)

fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.plot(df.index, df["reward"], linewidth=0.1)
ax1.set_title("Total rewards by game")

ax2.plot(df.index, df["loss"], linewidth=0.1)
ax2.set_title("Total loss by game")

# Show the plot
plt.show()

In [None]:
df.head(30)

In [None]:
agent.save_model("models/test_1.pth")