In [8]:
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.nn import MSELoss

from pyrl.agents import Agent, EpsilonGreedyAgent
from pyrl.environments import TicTacToe
from pyrl.algorithms.value import QLearning

from pyrl.experiment import Experiment

In [9]:
class DQN(nn.Module):
    """Small fully-connected Q-network: Linear -> Tanh -> Linear -> Tanh.

    The defaults (9 inputs, 32 hidden units, 9 outputs) reproduce the original
    hard-coded architecture. For TicTacToe, the 9 inputs/outputs presumably
    correspond to the 9 board cells (one Q-value per cell/action) — confirm
    against the pyrl environment's state encoding.

    The final ``Tanh`` squashes every output into (-1, 1), so the network can
    only represent Q-values in that range.
    NOTE(review): this is only appropriate if returns stay within [-1, 1];
    with gamma=1, confirm the environment's rewards satisfy that.

    Args:
        n_inputs: size of the input feature vector (default 9).
        n_hidden: width of the single hidden layer (default 32).
        n_actions: number of Q-values produced per input (default 9).
    """

    def __init__(self, n_inputs: int = 9, n_hidden: int = 32, n_actions: int = 9):
        super().__init__()

        # Attribute name and layer order are kept identical to the original so
        # previously saved state_dicts ("linear.0.*", "linear.2.*") still load.
        self.linear = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_actions),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-value estimates for ``x``; the last dimension of ``x`` must be ``n_inputs``."""
        return self.linear(x)

In [10]:
# Seed torch so network weight initialisation (and any torch-based sampling)
# is reproducible under Restart & Run All.
# NOTE(review): if pyrl draws exploration randomness from `random` or numpy,
# those generators need seeding too — confirm against pyrl.
SEED = 0
torch.manual_seed(SEED)

dqn = DQN()

# Q-learning hyper-parameters (named so they are easy to find and tune).
GAMMA = 1             # discount factor; 1 means future rewards are not discounted
LEARNING_RATE = 1e-2  # Adam step size

# create algorithm
q_learning = QLearning(network=dqn,
                       gamma=GAMMA,
                       optimizer=Adam(params=dqn.parameters(), lr=LEARNING_RATE),
                       loss_func=MSELoss())

agent = Agent(q_learning)

# first=True presumably makes the agent take the first move — confirm in pyrl.
env = TicTacToe(player=agent, first=True)

exp = Experiment(agent=agent, environment=env)

In [None]:
# Training run. At the ~197 it/s shown by the progress bar, 1,000,000 epochs
# take roughly 85 minutes; lower "n_epochs" for quick iteration.
params = {
    "n_epochs": 1_000_000,
    "n_iter": 1_000,
    "batch_size": 10,
    # Presumably epsilon is annealed from eps_start to eps_end over the run —
    # confirm against pyrl's Experiment.explore.
    "eps_start": 1.0,
    "eps_end": 0.001,
    "n_target_train": 100
}

agent, rewards = exp.explore(params=params, progress=True, visualize=False)

# NOTE(review): the x-axis label assumes explore returns one reward per epoch — confirm.
fig, ax = plt.subplots()
ax.plot(rewards)
ax.set(title="Training rewards", xlabel="epoch", ylabel="reward")
plt.show()

  3%|▎         | 26778/1000000 [02:24<1:22:10, 197.40it/s]

In [None]:
# Greedy evaluation: play a few games with no exploration and render them.
# A separate dict is built instead of mutating `params`, so the training
# configuration above stays intact for inspection and re-runs.
# eps_end is zeroed as well as eps_start. If pyrl moves epsilon from
# eps_start toward eps_end during a run (presumably it does — confirm), then
# leaving eps_end=0.001 would let random moves creep back into evaluation.
# NOTE(review): explore may keep updating the network during this run — confirm.
eval_params = {**params, "eps_start": 0, "eps_end": 0, "n_epochs": 10}

_, rewards = exp.explore(params=eval_params, progress=False, visualize=True)