In [1]:
import torch
from agents.bigtwo312 import Agent312
from agents.bigtwo312fs import Agent312fs
from agents.random import RandomAgent
from env.game import Bigtwo

# Pin the CPU and CUDA random generators before anything is constructed.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)
torch.cuda.manual_seed(TORCH_SEED)

# Shared game environment (the argument is presumably an RNG seed - confirm in env.game).
game = Bigtwo(42)

# Four learning agents of the same class, one per seat, all on the first GPU.
# (A `RandomAgent()` seat was tried here at some point and is currently disabled.)
DEVICE = "cuda:0"
agents = [Agent312fs(DEVICE) for _ in range(4)]

# Resume every seat from its own checkpoint: network weights and optimizer state.
for i, agent in enumerate(agents):
    ckp = torch.load(f"checkpoints/bigtwo312fs-{i}_20240925_194858.pth")
    agent.model.load_state_dict(ckp["model_state_dict"])
    agent.optimizer.load_state_dict(ckp["optimizer_state_dict"])

  ckp = torch.load(f"checkpoints/bigtwo312fs-{i}_20240925_194858.pth")


In [3]:
import numpy as np
from agents.random import RandomAgent
from agents.bigtwo312 import Agent312

# Fixed opponent used by the evaluation function below: an Agent312 whose
# network weights come from a saved checkpoint (optimizer state is not restored).
BASELINE_DEVICE = "cuda:0"
BASELINE_CHECKPOINT = "checkpoints/bigtwo312-0_20240924_091506.pth"

agent312 = Agent312(BASELINE_DEVICE)
checkpoint = torch.load(BASELINE_CHECKPOINT)
agent312.model.load_state_dict(checkpoint["model_state_dict"])


def evalueate(agent, n_games=500, seed=56, opponents=None, game=None):
    """Score ``agent`` (seat 0) against a set of opponents over several games.

    After every finished game each seat's remaining holding is summed and
    passed through a tiered penalty (x1 below 10, x2 for 10-12, x3 from 13).
    A seat whose penalty is 0 gains the sum of all penalties; every other
    seat loses its own penalty.

    Args:
        agent: The agent under test; it is always placed in seat 0.
        n_games: Number of games to play (default 500, the original value).
        seed: Value passed to ``Bigtwo`` when no ``game`` is supplied.
        opponents: Agents for the remaining seats. Defaults to three
            references to the module-level ``agent312``.
        game: Optional ready-made game environment. When omitted a new
            ``Bigtwo(seed)`` is created.

    Returns:
        A list with one cumulative score per seat, in seating order.
    """
    def punish(p):
        # Tiered penalty on the number of cards left in hand.
        if p < 10:
            return p
        elif p < 13:
            return 2 * p
        else:
            return 3 * p

    if game is None:
        game = Bigtwo(seed)
    if opponents is None:
        opponents = [agent312, agent312, agent312]

    players = [agent, *opponents]
    points = [0] * len(players)
    for _ in range(n_games):
        game.reset()
        # `is None` rather than `== None`: identity is the correct test for the
        # "no winner yet" sentinel and cannot be affected by a custom __eq__.
        while game.winner is None:
            players[game.player_to_act].act(game, training=False)

        holdings = [punish(np.sum(player.holding, axis=0)) for player in game.players]
        pot = np.sum(holdings)
        # Index-based loop: the original reused the name `agent` here, which
        # shadowed the function parameter.
        for index in range(len(players)):
            if holdings[index] == 0:
                points[index] += pot
            else:
                points[index] -= holdings[index]
    return points

  checkpoint = torch.load("checkpoints/bigtwo312-0_20240924_091506.pth")


In [4]:
from utils.logger import append_to_log


# Self-play training loop.
EPOCHS = 100000          # upper bound on the number of learn() rounds
GAMES_PER_EPOCH = 500    # games played between two learn() calls
CHECKPOINT_EVERY = 100   # epochs between checkpoint + evaluation passes

for j in range(EPOCHS):
    # Wins per seat during this epoch.
    points = [0] * len(agents)
    for i in range(GAMES_PER_EPOCH):
        game.reset()
        # Identity check against the "no winner yet" sentinel (PEP 8).
        while game.winner is None:
            agents[game.player_to_act].act(game)
        # Once the game is over every agent is handed the final game and its seat.
        for index, agent in enumerate(agents):
            agent.get_reward(game, index)
        points[game.winner] += 1

    # One learn() call per agent per epoch; its return value is reported as the loss.
    losses = [agent.learn() for agent in agents]

    # Epoch 0 is skipped so nothing is saved before at least 100 rounds have run.
    if j % CHECKPOINT_EVERY == 0 and j > 0:
        for index, agent in enumerate(agents):
            agent.save(index)
        append_to_log("log.txt", f"points: {evalueate(agents[0])}")

    print(points)
    print(losses)

[133, 119, 126, 122]
[0.3588086664676666, 0.38901883363723755, 0.37987762689590454, 0.40074795484542847]
[126, 110, 143, 121]
[0.49298134446144104, 0.3316503167152405, 0.441376656293869, 0.39979955554008484]
[118, 127, 129, 126]
[0.4227329194545746, 0.39192119240760803, 0.40811407566070557, 0.44108328223228455]
[146, 118, 113, 123]
[0.463591605424881, 0.44679316878318787, 0.3895183205604553, 0.39919760823249817]
[119, 105, 140, 136]
[0.47609031200408936, 0.3083038330078125, 0.40941357612609863, 0.42140424251556396]
[137, 110, 126, 127]
[0.4278121888637543, 0.35599541664123535, 0.464704304933548, 0.40004876255989075]
[134, 125, 124, 117]
[0.49238598346710205, 0.43353649973869324, 0.4092267155647278, 0.44789209961891174]
[139, 129, 122, 110]
[0.4390909969806671, 0.34409740567207336, 0.4394603371620178, 0.3505018949508667]
[134, 125, 117, 124]
[0.3729669451713562, 0.39919424057006836, 0.3458734452724457, 0.44538331031799316]
[134, 135, 107, 124]
[0.4560740888118744, 0.42416447401046753, 0

KeyboardInterrupt: 

In [1]:
from agents.bigtwo312fs import Agent312fs
from utils.round_weights import round_weights
from agents.bigtwo156 import Agent156
from agents.bigtwo312 import Agent312
from agents.bigtwo312fs import Agent312fs
import torch

# Checkpoints to compare: the plain 312 agent and the 312fs variant.
AGENT312_CKPT = "checkpoints/bigtwo312-0_20240924_041837.pth"
AGENT312FS_CKPT = "checkpoints/bigtwo312fs-0_20240925_191930.pth"

# Plain agent: weights are loaded unchanged.
agent312 = Agent312("cuda:0")
state_dict = torch.load(AGENT312_CKPT)["model_state_dict"]
agent312.model.load_state_dict(state_dict)

# 312fs agent: weights go through round_weights(..., 3) before loading
# (presumably rounding to 3 decimals - confirm in utils.round_weights).
# `state_dict` is left pointing at the 312fs checkpoint for later cells.
agent312fs = Agent312fs("cuda:0")
state_dict = torch.load(AGENT312FS_CKPT)["model_state_dict"]
agent312fs.model.load_state_dict(round_weights(state_dict, 3))


  state_dict = torch.load("checkpoints/bigtwo312-0_20240924_041837.pth")["model_state_dict"]
  state_dict = torch.load("checkpoints/bigtwo312fs-0_20240925_191930.pth")["model_state_dict"]


<All keys matched successfully>

In [None]:
import importlib
from agents.bigtwo156 import Agent156


# Agent156 restored from its checkpoint; note that this rebinds the shared
# `state_dict` name used by other cells.
AGENT156_CKPT = "checkpoints/bigtwo156-3_20240923_105233.pth"

agent156 = Agent156("cuda:0")
state_dict = torch.load(AGENT156_CKPT)["model_state_dict"]
agent156.model.load_state_dict(state_dict)

In [7]:
import numpy as np
from env.game import Bigtwo


def punish(p):
    """Return the tiered penalty for a count ``p``.

    Counts of 13 or more are tripled, counts from 10 up to 12 are doubled,
    and anything below 10 is returned unchanged.
    """
    # Check the most expensive tier first so each guard can return directly.
    if p >= 13:
        return 3 * p
    if p >= 10:
        return 2 * p
    return p


# Fresh environment for the head-to-head evaluation
# (the argument is presumably an RNG seed - confirm in env.game).
EVAL_GAME_ARG = 162
game = Bigtwo(EVAL_GAME_ARG)

# Seating order: two agents from the `agents` list alternate with the
# `agent312` baseline. Both names must already exist in the kernel.
players = [agents[3], agent312, agents[1], agent312]

In [8]:
# Head-to-head evaluation: play 1000 games between the seats in `players`
# and keep a running penalty-based score per seat.
ranks = [[0, 0, 0, 0] for _ in range(4)]
points = [0, 0, 0, 0]
for i in range(1000):
    game.reset()
    # Identity check against the "no winner yet" sentinel (PEP 8).
    while game.winner is None:
        seat = game.player_to_act
        if seat == 0:
            # Only seat 0 is run with training=False; the other seats use
            # act()'s default mode.
            players[seat].act(game, training=False)
        else:
            players[seat].act(game)

    # Cards left per seat, converted to tiered penalties.
    holdings = [punish(np.sum(player.holding, axis=0)) for player in game.players]
    pot = np.sum(holdings)
    for index in range(len(players)):
        if holdings[index] == 0:
            # A seat with nothing left collects everybody's penalties.
            points[index] += pot
        else:
            points[index] -= holdings[index]

    # NOTE(review): rank tracking is disabled, so `ranks` is printed below
    # with its initial zeros.
    # for i in range(4):
    #     rank = np.where(np.argsort(points) == i)[0][0]
    #     ranks[i][rank] += 1

print(ranks)
print(points)

[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
[-17, 535, 22, -540]


In [67]:
# Re-display the per-seat score totals currently held in `points`; the values
# depend on which evaluation run last filled the list in this kernel.
print(points)

[1796, 2517, 3872, -8185]


In [25]:
import json

from utils.round_weights import round_list


# Export whichever `state_dict` was loaded last in this kernel.
# round_weights(..., 2) presumably rounds each tensor to 2 decimals - confirm
# in utils.round_weights.
state_dict0 = round_weights(state_dict, 2)

# Tensors cannot be dumped as JSON directly, so each one is converted to
# nested Python lists and passed through round_list(..., 2).
state_dict_serializable = dict(
    (key, round_list(value.tolist(), 2)) for key, value in state_dict0.items()
)

# Compact separators: no spaces after "," and ":" in the written file.
with open("state_dict.json", "w") as f:
    f.write(json.dumps(state_dict_serializable, separators=(",", ":")))

In [25]:
import torch

# The integers 1..10 as a 1-D int64 tensor.
x = torch.arange(1, 11)

# Reinterpret the same storage as rows of two; -1 lets torch infer the row count.
x_f = x.view(-1, 2)
x_f

tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10]])