In [10]:
import os

import matplotlib.pyplot as plt
from pprint import pprint
from time import sleep
import torch

import gymnasium as gym
import numpy as np
from pettingzoo.classic import connect_four_v3
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.data import Batch

env = connect_four_v3.env(render_mode='human')
# env = connect_four_v3.env()
env = PettingZooEnv(env)

In [11]:
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.utils.net.common import Net

In [12]:
# train_path = '../log/rps/dqn/policy.pth'
# train_path = '../log/rps/dqn/policy_v1.pth'
# train_path = '../log/rps/dqn/policy_v2.pth'
# train_path = '../log/rps/dqn/policy_v3.pth'  # DQN v.s DQN
train_path_p1 = '../log/rps/dqn/policy_p1_v3.pth'  # DQN v.s DQN
train_path_p2 = '../log/rps/dqn/policy_p2_v3.pth'  # DQN v.s DQN
wsl = True

In [13]:
if wsl:
    os.environ["SDL_VIDEODRIVER"]="x11"  # if on WSL. https://learn.microsoft.com/en-us/windows/wsl/tutorials/gui-apps

In [14]:
state_shape = env.observation_space["observation"].shape  # (6, 7, 2)
action_shape = env.action_space.n  # 7

# net = Net(state_shape, hidden_sizes=[16, 16], device="cpu")

net = Net(state_shape=state_shape,
          action_shape=action_shape,
          hidden_sizes=[128, 128, 128, 128],
             device="cuda" if torch.cuda.is_available() else "cpu",
            ).to("cuda" if torch.cuda.is_available() else "cpu")
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
policy_p1 = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )
policy_p2 = policy_p1

In [15]:
#### https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#save-load-policy

# Against Human
if not torch.cuda.is_available():
    policy_p1.load_state_dict(torch.load(train_path_p1, map_location=torch.device('cpu')))
    policy_p2.load_state_dict(torch.load(train_path_p2, map_location=torch.device('cpu')))
else:
    policy_p1.load_state_dict(torch.load(train_path_p1))
    policy_p2.load_state_dict(torch.load(train_path_p2))

In [16]:
from time import sleep

In [20]:
# Agent v.s. Agent

agents = env.agents
new_game = True
obs = env.reset()
observation = obs
done = False
info = {}

env.reset()
policy_p1.eval()
policy_p2.eval()
# First action.
action = env.action_space.sample()
observation, reward, done, truncated, info = env.step(action)
sleep(0.5)
while not done:
    # P2
    observation['obs'] = observation['obs'].reshape(-1, int(np.prod(state_shape)))  # Reshape observation    
    action = policy_p2(Batch(obs=observation, info=info)).act[0]
    observation, reward, done, truncated, info = env.step(action)
    sleep(0.5)
    
    if done:
        break

    # P1
    observation['obs'] = observation['obs'].reshape(-1, int(np.prod(state_shape)))  # Reshape observation    
    action = policy_p1(Batch(obs=observation, info=info)).act[0]
    observation, reward, done, truncated, info = env.step(action)
    sleep(0.5)

winner = agents[np.argmax(reward)]
print(f'Game ended. The winner is {winner}')    

Game ended. The winner is player_0


In [13]:
agents = env.agents
agent = agents[0]
new_game = True
obs = env.reset()
observation = obs
done = False
info = {}

env.reset()
policy.eval()
while not done:
    # determine agent's action.
    if agent == 'player_0' and new_game:
        action = env.action_space.sample()
    elif agent == 'player_1' and not new_game:
        observation['obs'] = observation['obs'].reshape(-1, int(np.prod(state_shape)))  # Reshape observation    
        action = policy(Batch(obs=observation)).act[0]
        print(f'action chosen by policy: column {action + 1}')

    if not new_game or agent == 'player_0':
        observation, reward, done, truncated, info = env.step(action)
        print('mask after agent moves:', observation['mask'])
    
    if not done:
        try:
            player_action = int(input('User input starts with 1 to 7: ')) - 1
        except:
            while not isinstance(player_action, int) and player_action < 7:
                print('invalid entry')
                player_action = int(input('User input starts with 1 to 7: ')) - 1
        observation, reward, done, truncated, info = env.step(player_action)
        observation['info'] = info

        print('mask after player moves:', observation['mask'])

    new_game = False

winner = agents[np.argmax(reward)]
print(f'Game ended. The winner is {winner}')    

mask after agent moves: [True, True, True, True, True, True, True]
User input starts with 1 to 7: 6
mask after player moves: [True, True, True, True, True, True, True]
mask after agent moves: [True, True, True, True, True, True, True]
User input starts with 1 to 7: 6
mask after player moves: [True, True, True, True, True, True, True]
mask after agent moves: [True, True, True, True, True, True, True]
User input starts with 1 to 7: 6
mask after player moves: [True, True, True, True, True, False, True]
obs['action_mask'] contains a mask of all legal moves that can be chosen.
mask after agent moves: [True, True, True, True, True, False, True]
Game ended. The winner is player_1


In [21]:
env.close()