In [61]:
from Environment import FoolGame, GameState, GameTreeNode

import numpy as np
import random
import matplotlib.pyplot as plt

import logging

In [62]:
# Notebook-wide logger: INFO-and-above records go to "<module name>.log".
infoLogger = logging.getLogger(__name__)
infoLogger.setLevel(logging.INFO)

# logging.getLogger() returns the same logger object on every call, so re-running
# this cell would otherwise stack one more FileHandler per run (each record written
# several times, earlier file handles left open). Detach and close file handlers
# left over from a previous run before attaching the new one.
for stale_handler in list(infoLogger.handlers):
    if isinstance(stale_handler, logging.FileHandler):
        infoLogger.removeHandler(stale_handler)
        stale_handler.close()

# mode='w' truncates the log file every time this cell is executed.
py_handler = logging.FileHandler(f"{__name__}.log", mode='w')
py_formatter = logging.Formatter("%(name)s %(asctime)s %(levelname)s %(message)s")

py_handler.setFormatter(py_formatter)
infoLogger.addHandler(py_handler)

In [63]:
# Rank index -> printable rank name, ordered from the lowest rank to the highest.
names = dict(enumerate(('6', '7', '8', '9', '10', "jack", "queen", "king", "ace")))

# NOTE(review): the meaning of the positional arguments (36, 4, 6, 3) is defined by
# FoolGame in the Environment module — presumably deck size, suits, hand size and
# number of players; confirm against that class.
game = FoolGame(36, 4, 6, 3, names=names)
game.redo()

In [64]:
# Dump the state dict returned for history entry -1 (presumably the most recent
# snapshot — confirm against the state_history implementation).
print(game.state_history.restore(-1).get_state())

{'players_banks': [[(np.int32(23), {'power': 'jack', 'type': 2}, {'power': 5, 'type': 2}, array([1, 0, 0, 1, 1, 0])), (np.int32(0), {'power': '6', 'type': 0}, {'power': 0, 'type': 0}, array([0, 0, 0, 0, 0, 1])), (np.int32(11), {'power': '8', 'type': 1}, {'power': 2, 'type': 1}, array([0, 1, 0, 0, 1, 1])), (np.int32(30), {'power': '9', 'type': 3}, {'power': 3, 'type': 3}, array([1, 1, 0, 1, 0, 0])), (np.int32(4), {'power': '10', 'type': 0}, {'power': 4, 'type': 0}, array([0, 0, 0, 1, 0, 1])), (np.int32(26), {'power': 'ace', 'type': 2}, {'power': 8, 'type': 2}, array([1, 0, 1, 0, 0, 1]))], [(np.int32(14), {'power': 'jack', 'type': 1}, {'power': 5, 'type': 1}, array([0, 1, 0, 1, 1, 0])), (np.int32(34), {'power': 'king', 'type': 3}, {'power': 7, 'type': 3}, array([1, 1, 1, 0, 0, 0])), (np.int32(29), {'power': '8', 'type': 3}, {'power': 2, 'type': 3}, array([1, 1, 0, 0, 1, 1])), (np.int32(2), {'power': '8', 'type': 0}, {'power': 2, 'type': 0}, array([0, 0, 0, 0, 1, 1])), (np.int32(33), {'po

In [65]:
def create_observations(state: dict, trump_type=None) -> tuple[np.ndarray, np.ndarray]:
    """Encode a game state as a stack of 5x10 planes plus the action-coordinate grid.

    Cards are addressed by their id ``card[0]`` in 0..35; id ``c`` lands in cell
    ``[c // 9, c % 9]`` of a plane. Cell ``[4, 9]`` of a plane is a separate
    flag cell that is not tied to any card.

    Layout of the first returned array, shape ``(num_players + 6, 5, 10)``:
        res[0]                  - cards held by the acting player,
        res[1:num_players]      - known cards of the other players, ordered by
                                  seat offset from the acting player,
        res[num_players]        - cards in ``state['table']``,
        res[num_players + 1]    - cards in ``state['bita']``,
        res[num_players + 2]    - cards in ``state['target']``,
        res[num_players + 3]    - scalar "power" value of every card (larger for
                                  the trump suit),
        res[num_players + 4]    - cells the attacker is allowed to use,
        res[num_players + 5]    - cells the defender is allowed to use.

    Args:
        state: dict with keys 'round', 'role', 'players_banks', 'players_info',
            'table', 'bita' and 'target'. Every card is indexable with
            ``card[0]`` = card id and ``card[2]`` = ``{'power': int, 'type': int}``.
        trump_type: suit index of the trump. When omitted, the value is read
            from the module-level ``game`` object (``game.cosir[2]['type']``),
            which is what this function always did before the parameter
            existed. Pass it explicitly when encoding a state that does not
            belong to that global game.

    Returns:
        ``(planes, grid)`` where ``grid`` has shape ``(37, 3)``: one
        ``(x, y, role)`` row per card cell (x in 1..4, y in 1..9) followed by a
        final ``(0, 0, role)`` row for the flag cell.
    """
    num_players = len(state['players_banks'])
    player_num = state["round"] % num_players

    if trump_type is None:
        # Backward-compatible fallback to the notebook-global game object.
        trump_type = game.cosir[2]['type']

    # Card-indexed scratch table: one row per card id, one column per plane that
    # is later copied into the 4x9 card area of `result`.
    observations = np.zeros((36, num_players + 4))
    for card in state['players_banks'][player_num]:
        observations[card[0], 0] = 1
    for i, bank in enumerate(state['players_info']):
        if i == player_num:
            continue
        # Seat offset relative to the acting player, always in 1..num_players-1.
        k = (i - player_num) % num_players
        for card in bank:
            observations[card[0], k] = 1

    result = np.zeros((num_players + 6, 5, 10))
    # The flag cell starts at 1 on every plane except the power plane.
    result[:num_players + 3, 4, 9] = 1
    result[num_players + 4, 4, 9] = 1
    result[num_players + 5, 4, 9] = 1

    for card in state['table']:
        observations[card[0], num_players] = 1
        # Ranks already on the table may be used by the attacker, in any suit.
        result[num_players + 4, :4, card[2]['power']] = 1
    for card in state['bita']:
        observations[card[0], num_players + 1] = 1
    for card in state['target']:
        observations[card[0], num_players + 2] = 1
        # Defender may answer with strictly higher ranks of the same suit.
        # NOTE(review): trump cards are not added here as possible answers —
        # confirm this matches the rules implemented by FoolGame.
        result[num_players + 5, card[2]['type'], card[2]['power'] + 1:-1] = 1
        result[num_players + 4, :4, card[2]['power']] = 1
    if len(state['target']) == 0 and len(state['table']) == 0 and state['role'] == 0:
        # Nothing has been played yet and the acting role is 0: every card cell
        # is allowed on the attacker plane and its flag cell is cleared.
        result[num_players + 4, :4, :9] = 1
        result[num_players + 4, 4, 9] = 0

    # Power plane, evaluated at suit centres x = 0.5..3.5 and ranks y = 1..9 in
    # the same suit-major order as the card ids.
    suit_centres = np.array([0.5, 1.5, 2.5, 3.5])
    ranks = np.arange(1, 10, dtype=float)
    x_grid = np.repeat(suit_centres, ranks.size)
    y_grid = np.tile(ranks, suit_centres.size)
    suit_weight = 0.5 * np.cos(x_grid * np.pi + np.pi / 2) ** 2
    is_trump = (trump_type < x_grid) & (trump_type + 1 > x_grid)
    u_vals = np.where(
        is_trump,
        10 + (-0.5 * (-y_grid) + 10) * suit_weight,
        10 - (-0.5 * y_grid + 10) * suit_weight,
    )
    observations[:, num_players + 3] = u_vals

    # Card id c = suit * 9 + rank, so each 36-long column folds into a 4x9 block.
    result[:num_players + 4, :4, :9] = observations.T.reshape(num_players + 4, 4, 9)

    # Kept from the original implementation; both cells were already set to 1 above.
    if state['role'] == 0:
        result[0, 4, 9] = 1
    else:
        result[num_players + 2, 4, 9] = 1

    x_vals = np.array([1, 2, 3, 4])
    y_vals = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    role = state["role"]
    grid_points = np.array([[x, y, role] for x in x_vals for y in y_vals] + [[0, 0, role]])
    return result, grid_points

# Sanity check on the demo game: print the raw state dict, encode it with
# create_observations, then print the encoded planes.
print(game.state_history.restore(-1).get_state())
obs, grid = create_observations(game.state_history.restore(-1).get_state())
print(obs)

{'players_banks': [[(np.int32(23), {'power': 'jack', 'type': 2}, {'power': 5, 'type': 2}, array([1, 0, 0, 1, 1, 0])), (np.int32(0), {'power': '6', 'type': 0}, {'power': 0, 'type': 0}, array([0, 0, 0, 0, 0, 1])), (np.int32(11), {'power': '8', 'type': 1}, {'power': 2, 'type': 1}, array([0, 1, 0, 0, 1, 1])), (np.int32(30), {'power': '9', 'type': 3}, {'power': 3, 'type': 3}, array([1, 1, 0, 1, 0, 0])), (np.int32(4), {'power': '10', 'type': 0}, {'power': 4, 'type': 0}, array([0, 0, 0, 1, 0, 1])), (np.int32(26), {'power': 'ace', 'type': 2}, {'power': 8, 'type': 2}, array([1, 0, 1, 0, 0, 1]))], [(np.int32(14), {'power': 'jack', 'type': 1}, {'power': 5, 'type': 1}, array([0, 1, 0, 1, 1, 0])), (np.int32(34), {'power': 'king', 'type': 3}, {'power': 7, 'type': 3}, array([1, 1, 1, 0, 0, 0])), (np.int32(29), {'power': '8', 'type': 3}, {'power': 2, 'type': 3}, array([1, 1, 0, 0, 1, 1])), (np.int32(2), {'power': '8', 'type': 0}, {'power': 2, 'type': 0}, array([0, 0, 0, 0, 1, 1])), (np.int32(33), {'po

In [66]:
# Shape of the encoded planes, then the (x, y, role) action grid, one per line.
print(obs.shape, grid, sep="\n")

(9, 5, 10)
[[1 1 0]
 [1 2 0]
 [1 3 0]
 [1 4 0]
 [1 5 0]
 [1 6 0]
 [1 7 0]
 [1 8 0]
 [1 9 0]
 [2 1 0]
 [2 2 0]
 [2 3 0]
 [2 4 0]
 [2 5 0]
 [2 6 0]
 [2 7 0]
 [2 8 0]
 [2 9 0]
 [3 1 0]
 [3 2 0]
 [3 3 0]
 [3 4 0]
 [3 5 0]
 [3 6 0]
 [3 7 0]
 [3 8 0]
 [3 9 0]
 [4 1 0]
 [4 2 0]
 [4 3 0]
 [4 4 0]
 [4 5 0]
 [4 6 0]
 [4 7 0]
 [4 8 0]
 [4 9 0]
 [0 0 0]]


In [67]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class MIONet3d(nn.Module):
    """Branch/trunk Q-network that scores only the currently legal actions.

    A convolutional branch net turns the stack of observation planes into one
    latent vector per sample; an MLP trunk net turns each (x, y, role) action
    coordinate into a latent vector; the Q-value of an action is the dot
    product of the two. Actions outside the legality mask keep a Q-value of 0.

    Args:
        branch_dim: number of input planes (channels) of the observation.
        trunk_input_dim: size of one action coordinate, 3 for (x, y, role).
        latent_dim: size of the shared latent space.
    """

    def __init__(self, branch_dim, trunk_input_dim=3, latent_dim=64):
        super().__init__()
        self.latent_dim = latent_dim

        # Shared branch network over the (H, W) card grid.
        self.branch_general = nn.Sequential(
            nn.Conv2d(branch_dim, 32, kernel_size=3, padding=1), nn.SELU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, latent_dim)
        )

        # Trunk network over action coordinates.
        self.trunk_net = nn.Sequential(
            nn.Linear(trunk_input_dim, 32), nn.ReLU(),
            nn.Linear(32, latent_dim)
        )

    def target_positions(self, lin_input):
        """Build the boolean action mask from a stack of 0/1 planes.

        Args:
            lin_input: tensor (B, C, H, W); an action is legal where the product
                over the C planes is non-zero.

        Returns:
            Bool tensor (B, (H-1)*(W-1) + 1): the top-left (H-1) x (W-1) cells in
            row-major order, followed by the bottom-right cell [-1, -1].
        """
        pos = torch.prod(lin_input, dim=1)  # (B, H, W)
        batch, height, width = pos.shape
        card_cells = pos[:, :-1, :-1].reshape(batch, (height - 1) * (width - 1))
        flag_cell = pos[:, -1, -1].unsqueeze(1)
        return torch.cat([card_cells, flag_cell], dim=1).to(torch.bool)

    def forward(self, nonlinear_func, coords):
        """
        nonlinear_func: (B, C1, H, W) observation planes
        coords: (B, N, 3) — (x, y, role) action coordinates
        Returns:
            Q-values: (B, N), zero wherever the mask is False
            mask: (B, N) bool, True for legal actions

        The role is read from ``coords[0, 0, 2]`` only, so every sample in the
        batch is assumed to share the same role.
        """
        B, N, _ = coords.shape

        # --- Branch: one latent vector per sample, repeated for every action ---
        nonlinear_out = self.branch_general(nonlinear_func)  # (B, latent_dim)
        nonlinear_out = nonlinear_out.unsqueeze(1).expand(-1, N, -1)  # (B, N, latent_dim)

        # --- Masking ---
        # The limitation planes are the last two channels (second to last for
        # role 0, last otherwise). Deriving the index from the channel count
        # gives 7 / 8 for a 9-plane input and stays correct for other sizes.
        num_planes = nonlinear_func.shape[1]
        limit_plane = num_planes - 2 if coords[0, 0, 2] == 0 else num_planes - 1
        mask = self.target_positions(torch.cat(
            [nonlinear_func[:, 0:1, :, :], nonlinear_func[:, limit_plane:limit_plane + 1, :, :]],
            dim=1,
        ))

        coords_masked = coords[mask]  # (K, 3), K = number of legal actions in the batch
        branch_masked = nonlinear_out[mask]  # (K, latent_dim)

        # --- Trunk features ---
        trunk_feat = self.trunk_net(coords_masked)  # (K, latent_dim)

        # --- Dot product of branch and trunk features ---
        q_masked = torch.sum(branch_masked * trunk_feat, dim=-1)  # (K,)

        # --- Scatter back into a dense tensor, zeros at illegal actions ---
        q_vals = torch.zeros(B, N, device=coords.device, dtype=q_masked.dtype)
        q_vals[mask] = q_masked

        return q_vals, mask

In [68]:
# Show the full stack of observation planes produced by create_observations.
print(obs)

[[[ 1.    0.    0.    0.    1.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    1.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    1.    0.    0.    1.    0.  ]
  [ 0.    0.    0.    1.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    1.  ]]

 [[ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    1.  ]]

 [[ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
  [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    1.  ]]

 [[ 0.    0.    0.    0.    0.    

In [69]:
# Turn the numpy observation and action grid into float32 tensors and prepend a
# batch axis of size 1.
input_func = torch.tensor(obs, dtype=torch.float32)[None]
coord = torch.tensor(grid, dtype=torch.float32)[None]

# Freshly initialised network taking 9 input planes.
model = MIONet3d(9)

# Shapes of both inputs, followed by the observation tensor itself.
print(input_func.shape, coord.shape, input_func, sep="\n")

torch.Size([1, 9, 5, 10])
torch.Size([1, 37, 3])
tensor([[[[ 1.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,
            0.0000,  1.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  1.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000

In [70]:
# Forward pass on the single demo state.
q, mask = model(input_func, coord)
print(q)

# Each row of masked_indices is a (batch index, action index) pair of a True mask entry.
masked_indices = torch.nonzero(mask)  # (K, 2), K = number of True entries
# Position of the best Q-value among the masked (legal) entries only.
relative_idx = q[mask].argmax()
# Map that position back to the action index in the full N-long action list.
action_idx = int(masked_indices[relative_idx, 1])
print(action_idx)

torch.Size([1, 37])
tensor([[ 4.1053e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0229e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -6.2383e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -1.6104e+00,  0.0000e+00,
          0.0000e+00, -2.2765e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0407e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], grad_fn=<IndexPutBackward0>)
0


In [71]:
def _masked_q_values(game, pred_model):
    """Encode the latest state of `game` and return `pred_model`'s (Q_values, mask) for it."""
    state = game.state_history.restore(-1).get_state()
    obs, grid = create_observations(state)

    # Run on whatever device the model lives on (CPU when it has no parameters).
    first_param = next(iter(pred_model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    nonlinear_input = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    coords_input = torch.tensor(grid, dtype=torch.float32, device=device).unsqueeze(0)

    # The values are only used for action selection and as a scalar target,
    # so no autograd graph is needed.
    with torch.no_grad():
        return pred_model(nonlinear_input, coords_input)


def find_q_approx(game: FoolGame, player_num, pred_model: torch.nn.Module, verbose=False) -> tuple[float, bool]:
    """Play `game` forward greedily until it is `player_num`'s turn and return that state's best Q.

    While ``game.round % game.num_players`` differs from `player_num`, the
    current player's move is chosen as the legal action with the highest
    Q-value predicted by `pred_model`, and applied with ``game.step``. The
    game object is advanced in place.

    Args:
        game: game to advance; mutated by this call.
        player_num: index of the player whose next decision point is wanted.
        pred_model: model called as ``pred_model(planes, coords)`` and returning
            ``(Q_values, mask)`` with ``mask`` marking legal actions.
        verbose: when True, progress is written to ``infoLogger``.

    Returns:
        ``(q_max, True)`` with the largest Q-value among legal actions once it
        is `player_num`'s turn, or ``(-100, False)`` if ``game.step`` reported
        that the game ended before that turn was reached.
    """
    if verbose:
        infoLogger.info(f"find_q_approx {player_num} player, role: {game.role}")

    step_count = 0
    while game.round % game.num_players != player_num:
        step_count += 1

        Q_values, mask = _masked_q_values(game, pred_model)
        masked_indices = mask.nonzero(as_tuple=False)  # (K, 2), K = number of legal actions
        # Pick the best action among the legal ones. This must use the Q-values
        # computed for the current state; the previous code indexed a stale
        # notebook global `q` here.
        relative_idx = torch.argmax(Q_values[mask])
        action = masked_indices[relative_idx][1].item()

        _, not_terminal, _ = game.step(action)
        if not not_terminal:
            if verbose:
                infoLogger.info(f"[Step {step_count}] Game ended.")
            return -100, False

    # It is now the requested player's turn: evaluate the state without moving.
    Q_values, mask = _masked_q_values(game, pred_model)
    q_max = torch.max(Q_values[mask]).item()

    return q_max, True

In [72]:
# Smoke test on the demo game with the freshly initialised model, asking for
# player 0's next decision point. find_q_approx may advance `game` in place
# through game.step.
find_q_approx(game, 0, model)

torch.Size([1, 37])


(0.00041053444147109985, True)

In [None]:
# Self-play training of two Q-networks: model A for states with role == 0
# (logged as "attacker") and model D for the other role (logged as "defender").
#
# NOTE(review): this cell uses MIONet2d, create_nonlin_tensor and
# create_lin_tensor, none of which are defined in the visible part of the
# notebook, and it calls find_q_approx with two models although the definition
# visible in this notebook takes a single `pred_model`. Confirm which versions
# of those helpers this cell is meant to run against before executing it.
B, N = 256, 37  # B: number of parallel games; N: Q-values per state
Q_EPS = 1e-7    # |Q| at or below this is treated as "action not available"

# Use the GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Online networks plus frozen copies ("*_eval") used for the Bellman targets.
modelA = MIONet2d(7, 2).to(device)
weights = modelA.state_dict()
modelA_eval = MIONet2d(7, 2).to(device)
modelA_eval.load_state_dict(weights)
modelD = MIONet2d(7, 2).to(device)
weights = modelD.state_dict()
modelD_eval = MIONet2d(7, 2).to(device)
modelD_eval.load_state_dict(weights)

optimizerA = optim.Adam(modelA.parameters(), lr=0.0001)
optimizerD = optim.Adam(modelD.parameters(), lr=0.0001)
criterion = nn.MSELoss()
epochs = 500
discounting = 0.98
lossesA = []
lossesD = []
games = [FoolGame(36, 4, 6, 3, names=names) for _ in range(B)]
for fresh_game in games:
    fresh_game.redo()

modelA.train()
modelD.train()
modelA_eval.eval()
modelD_eval.eval()


def _fill_bellman_targets(game, q_row, targets, row, verbose):
    """Write the one-step Bellman target of every available action into targets[row].

    An action counts as available when its predicted Q-value is non-zero.
    Each one is tried on a copy of `game`; its target is the immediate reward
    plus, if the game goes on, the discounted best Q-value at the same player's
    next decision point. Unavailable actions get a target of 0.
    """
    available = (q_row.detach().abs() > Q_EPS).tolist()
    for j, is_available in enumerate(available):
        if not is_available:
            targets[row, j] = 0
            continue
        new_game: FoolGame = game.copy()
        rew, cont, player_not_ended = new_game.step(j)
        if player_not_ended and cont:
            Q_next, cont = find_q_approx(new_game, game.round % game.num_players, modelA_eval, modelD_eval, verbose)
            if cont:
                rew += Q_next * discounting
        targets[row, j] = rew


def _epsilon_greedy_action(q_row, epsilon):
    """Pick an action index among the non-zero entries of q_row, epsilon-greedily.

    The greedy choice is restricted to non-zero entries: a plain argmax over
    the whole row would select a masked (zero) entry whenever every available
    Q-value is negative.
    """
    possible = torch.nonzero(q_row).flatten()
    if possible.numel() == 0:
        # Nothing is marked as available; fall back to the plain argmax.
        return torch.argmax(q_row).item()
    if random.random() < epsilon:
        return random.choice(possible.tolist())
    return possible[torch.argmax(q_row[possible])].item()


for epoch in range(epochs):
    infoLogger.info(f"epoch={epoch}")
    # Refresh the frozen target networks after epoch 1 and then every 5 epochs.
    if (epoch % 5 == 0 and epoch != 0) or epoch == 1:
        modelA_eval.load_state_dict(modelA.state_dict())
        modelD_eval.load_state_dict(modelD.state_dict())
        infoLogger.info(f"updating eval model weights")

    state_nonlinearA = []
    state_linearA = []

    state_nonlinearD = []
    state_linearD = []

    coords_input = None

    D_idxs = []
    A_idxs = []

    # Capturing observables from games
    infoLogger.info(f"capturing observables")
    for i, game in enumerate(games):
        state = game.state_history.restore(-1).get_state()
        obs, grid = create_observations(state)

        nonlinear_input = create_nonlin_tensor(obs, state['role']).unsqueeze(0)
        linear_input = create_lin_tensor(obs, state['role']).unsqueeze(0)
        # NOTE(review): the coordinate grid (including its role column) is taken
        # from the first game only and reused for both roles — confirm intended.
        if coords_input is None:
            coords_input = torch.tensor(grid, dtype=torch.float32).unsqueeze(0)

        if state['role'] == 0:
            state_nonlinearA.append(nonlinear_input)
            state_linearA.append(linear_input)
            A_idxs.append(i)
        else:
            state_nonlinearD.append(nonlinear_input)
            state_linearD.append(linear_input)
            D_idxs.append(i)
        if i % 50 == 0:
            infoLogger.info(f"game={i}")

    # Game index -> row of that game inside the per-role batch.
    a_row_of = {game_idx: row for row, game_idx in enumerate(A_idxs)}
    d_row_of = {game_idx: row for row, game_idx in enumerate(D_idxs)}

    # Approximating Q-values
    if len(state_nonlinearA) > 0:
        t_state_nonlinearA = torch.cat(state_nonlinearA).to(device)
        t_state_linearA = torch.cat(state_linearA).to(device)
        t_coords_inputA = coords_input.repeat(t_state_linearA.shape[0], 1, 1).to(device)

        Q_valuesA = modelA(t_state_nonlinearA, t_state_linearA, t_coords_inputA)
        infoLogger.info(f"Approximating attacker Q values")

    if len(state_linearD) > 0:
        t_state_nonlinearD = torch.cat(state_nonlinearD).to(device)
        t_state_linearD = torch.cat(state_linearD).to(device)
        t_coords_inputD = coords_input.repeat(t_state_linearD.shape[0], 1, 1).to(device)

        Q_valuesD = modelD(t_state_nonlinearD, t_state_linearD, t_coords_inputD)
        infoLogger.info(f"Approximating defender Q values")

    # Applying Bellman formula
    rewD = torch.zeros(len(state_linearD), N, device=device)
    rewA = torch.zeros(len(state_linearA), N, device=device)
    infoLogger.info(f"Applying Bellman formula")
    for i, game in enumerate(games):
        verbose = i % 50 == 0
        if verbose:
            infoLogger.info(f"Bellman game={i}")
        if i in d_row_of:
            row = d_row_of[i]
            _fill_bellman_targets(game, Q_valuesD[row, :], rewD, row, verbose)
        if i in a_row_of:
            row = a_row_of[i]
            _fill_bellman_targets(game, Q_valuesA[row, :], rewA, row, verbose)

    # Optimizing process (the first 5 epochs only collect experience)
    infoLogger.info(f"Optimizing Q functions")
    if epoch >= 5:
        if len(state_linearA) > 0:
            optimizerA.zero_grad()
            loss = criterion(Q_valuesA, rewA)
            loss.backward()
            optimizerA.step()
            lossesA.append(loss.item())

        if len(state_linearD) > 0:
            optimizerD.zero_grad()
            loss = criterion(Q_valuesD, rewD)
            loss.backward()
            optimizerD.step()
            lossesD.append(loss.item())

    # Making eps -greedy moves
    infoLogger.info("Making \\eps -greedy moves")
    epsilon = 0.15
    new_games = []
    for i, game in enumerate(games):
        infoLogger.info(f"game={i}")
        if i in d_row_of:
            action = _epsilon_greedy_action(Q_valuesD[d_row_of[i], :], epsilon)
        else:
            action = _epsilon_greedy_action(Q_valuesA[a_row_of[i], :], epsilon)
        _, cont, _ = game.step(action)
        # Finished games are dropped and replaced by fresh ones below.
        if cont:
            new_games.append(game)
    games = new_games
    added_games = [FoolGame(36, 4, 6, 3, names=names) for _ in range(B - len(games))]
    for fresh_game in added_games:
        fresh_game.redo()
    games += added_games

    print("epoch:", epoch)
    if len(lossesA) > 0:
        print("lossA:", lossesA[-1])
    if len(lossesD) > 0:
        print("lossD:", lossesD[-1])

# Loss curves, one figure per network.
plt.plot(np.arange(1, len(lossesA) + 1), lossesA)
plt.title("Attacker model loss")
plt.xlabel("optimisation step")
plt.ylabel("MSE loss")
plt.show()
plt.plot(np.arange(1, len(lossesD) + 1), lossesD)
plt.title("Defender model loss")
plt.xlabel("optimisation step")
plt.ylabel("MSE loss")
plt.show()