In [1]:
from env.env import BigtwoEnv
import numpy as np
import torch
import random

from torch import nn
from env.game import Bigtwo
torch.manual_seed(0)
torch.cuda.manual_seed(0)

class BigtwoModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self.dense1 = nn.Linear(52 + 52 + 52, 256)
        self.dense2 = nn.Linear(256, 128)
        self.dense3 = nn.Linear(128, 1)
        self.to(self.device)

    def forward(self, x):
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        x = torch.relu(x)
        x = self.dense3(x)
        return x


def observe(game: Bigtwo):
    players = game.players
    player_to_act = game.player_to_act
    player = players[player_to_act]
    legal_actions = player.legal_actions
    legal_actions = torch.tensor(np.array([a.code for a in legal_actions])).to("cuda:0")
    holding = torch.tensor(player.holding).to("cuda:0")
    other_indices = [
        (i + player_to_act) % 4
        for i in range(4)
        if (i + player_to_act) % 4 != player_to_act
    ]
    others_holding = [players[i].holding for i in other_indices]
    others_holding = np.bitwise_or.reduce(others_holding, axis=0)
    others_holding = torch.tensor(others_holding).to("cuda:0")
    x = torch.cat([holding, others_holding], dim=0)
    x_batch = x.repeat(len(legal_actions), 1)
    x_batch = torch.cat([x_batch, legal_actions], dim=1).float()
    return dict(
        x_batch=x_batch,
    )


env = BigtwoEnv()
model0 = BigtwoModel("cuda:0")
optimizer0 = torch.optim.Adam(model0.parameters(), lr=0.001)
model1 = BigtwoModel("cuda:0")
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
game = Bigtwo()

In [2]:
for _ in range(1000):
    game_num = 100
    record0 = []
    result0 = []
    record1 = []
    result1 = []

    for i in range(game_num):
        row0 = []
        row1 = []
        game.reset()
        while game.winner == None:
            if game.player_to_act == 0:
                obs = observe(game)
                with torch.no_grad():
                    output = model0(obs["x_batch"])
                action_index = torch.argmax(output, dim=0)[0]
                action = game.players[0].legal_actions[action_index]
                row0.append(obs["x_batch"][action_index])
            if game.player_to_act == 1:
                obs = observe(game)
                with torch.no_grad():
                    output = model1(obs["x_batch"])
                action_index = torch.argmax(output, dim=0)[0]
                action = game.players[1].legal_actions[action_index]
                row1.append(obs["x_batch"][action_index])
            else:
                index = game.np_random.choice(
                    len(game.players[game.player_to_act].legal_actions)
                )
                action = game.players[game.player_to_act].legal_actions[index]
            game.step(action)
        result0.append(game.winner == 0)
        record0.append(row0)
        result1.append(game.winner == 1)
        record1.append(row1)
    print(np.mean(result0), np.mean(result1))
    for index, row0 in enumerate(record0):
        optimizer0.zero_grad()
        x_batch = torch.stack(row0)
        output = model0(x_batch)
        y_batch = torch.ones(x_batch.shape[0], 1).to("cuda:0") * result0[index]
        loss = torch.nn.functional.mse_loss(output, y_batch)
        loss.backward()
        optimizer0.step()
    for index, row1 in enumerate(record1):
        optimizer1.zero_grad()
        x_batch = torch.stack(row1)
        output = model1(x_batch)
        y_batch = torch.ones(x_batch.shape[0], 1).to("cuda:0") * result1[index]
        loss = torch.nn.functional.mse_loss(output, y_batch)
        loss.backward()
        optimizer1.step()

0.32 0.09


KeyboardInterrupt: 

In [5]:
from env.game import Bigtwo
game = Bigtwo()

game.player_to_act

2

In [3]:
game.winner

0

In [3]:
from utils.checkpoint import checkpoint


checkpoint(model1, optimizer1, "bigtwo_model3")

Model and optimizer saved as checkpoints/bigtwo_model3_20240923_004149.pth
