In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from ipypb import irange

In [2]:
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class DVN(nn.Module):
    """Fully connected value network that maps a flattened state to one scalar.

    The input is a batch of vectors with ``(n + 1) * 34`` features. Three
    Linear -> BatchNorm1d(eps=1) -> ReLU stages (widths ``in``, ``2 * in``,
    ``in``) are followed by a bias-free Linear layer producing one value per
    row. All Linear weights use Xavier-normal initialisation.

    Args:
        n: sizes the input layer as ``(n + 1) * 34`` features.
    """

    def __init__(self, n=5):
        super().__init__()
        width = 34 * (n + 1)

        # Attribute names and creation order are deliberately fixed: they
        # determine the state_dict keys and the order in which the random
        # initial weights are drawn.
        self.ds1 = nn.Linear(width, width)
        self.bn1 = nn.BatchNorm1d(width, eps=1)
        self.ds2 = nn.Linear(width, 2 * width)
        self.bn2 = nn.BatchNorm1d(2 * width, eps=1)
        self.ds3 = nn.Linear(2 * width, width)
        self.bn3 = nn.BatchNorm1d(width, eps=1)
        self.ds4 = nn.Linear(width, 1, bias=False)

        for dense in (self.ds1, self.ds2, self.ds3, self.ds4):
            nn.init.xavier_normal_(dense.weight)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of value estimates for ``x``."""
        hidden = x
        stages = ((self.ds1, self.bn1), (self.ds2, self.bn2), (self.ds3, self.bn3))
        for dense, norm in stages:
            hidden = F.relu(norm(dense(hidden)))
        return self.ds4(hidden)

In [4]:
# State layout: an (N + 1) x 34 integer array.
#   rows    : one per player (N rows, row 0 is the player to move) followed by
#             one "merged" row holding the column-wise total of all players
#   columns : column 0 is the token count, columns 1..33 flag cards 3..35

N = 5  # number of players

def mutate(state, pot, action):
    """Apply the current player's decision and return new ``(state, pot)``.

    Neither input is modified; both results are fresh copies.

    Args:
        state: game state; row 0 is the acting player, the last row is the
            merged total.
        pot: the card on offer plus the tokens placed on it (same column
            layout as a state row).
        action: truthy to decline ("no thanks") and pay one token, falsy to
            take the pot. A player whose token count is 0 always takes.

    Returns:
        Tuple ``(state, pot)`` after the move.
    """
    next_state = state.copy()
    next_pot = pot.copy()

    # Declining costs a token, so it is only possible with tokens in hand.
    if action and next_state[0, 0] != 0:
        next_state[[0, -1], 0] -= 1
        next_pot[0] += 1
        return next_state, next_pot

    # Taking: the whole pot goes to the player and to the merged row.
    for row in (0, -1):
        next_state[row] += next_pot
    next_pot *= 0
    return next_state, next_pot
        
def next_turn(state):
    """Rotate the player rows so the next player becomes row 0.

    The former row 0 moves to the end of the player rows; the last (merged)
    row stays last. Returns a new array.
    """
    upcoming = state[1:-1]
    just_played = state[:1]
    merged = state[-1:]
    return np.concatenate((upcoming, just_played, merged), axis=0)

def to_tensor(state):
    """Convert a NumPy state (or batch of states) to a float tensor on ``device``.

    The array is reshaped to ``(-1, (N + 1) * 34)``, i.e. one flattened state
    per row, before conversion.
    """
    n_features = 34 * (N + 1)
    rows = state.reshape((-1, n_features))
    as_float = torch.from_numpy(rows).float()
    return as_float.to(device)

def evaluate(net, state):
    """Return the network's value estimate for ``state`` as a NumPy scalar.

    The forward pass runs in eval mode (so BatchNorm uses its running
    statistics and a single state can be scored) and without gradient
    tracking. The network's previous train/eval mode is restored afterwards,
    so calling this in the middle of a training loop does not silently leave
    the network in eval mode.

    Args:
        net: the value network.
        state: a single game state accepted by ``to_tensor``.

    Returns:
        Element ``[0, 0]`` of the network output.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            value = net(to_tensor(state))
    finally:
        # Put the caller's mode back even if the forward pass raised.
        net.train(was_training)
    return value.cpu().numpy()[0, 0]

def score(state):
    """Score of the player in row 0: tokens minus the cost of held cards.

    Column ``c`` (1..33) flags card ``c + 2``. A held card only costs its face
    value when the card directly below it is not held, so each consecutive
    run is charged for its lowest card only.
    """
    hand = state[0]
    total = hand[0]
    for col in range(1, 34):
        # Card 3 (column 1) has no predecessor and therefore always starts a run.
        starts_run = col == 1 or hand[col - 1] == 0
        if hand[col] == 1 and starts_run:
            total -= col + 2
    return total
    
def spawn_player(value_net):
    """Build a greedy policy backed by ``value_net``.

    The returned ``player(state, pot)`` answers True to decline the card
    ("no thanks") and False to take it. With no tokens left it always takes;
    otherwise it scores the state that results from each choice and declines
    only when that outcome is valued strictly higher.
    """
    def player(state, pot):
        out_of_tokens = state[0, 0] == 0
        if out_of_tokens:
            return False
        after_pass, after_take = (
            mutate(state, pot, choice)[0] for choice in (True, False)
        )
        return evaluate(value_net, after_pass) > evaluate(value_net, after_take)
    return player

In [5]:
# One independent value network per seat, each wrapped in its own greedy player.
vns = []
for _ in range(N):
    vns.append(DVN(N).to(device))
players = list(map(spawn_player, vns))

In [6]:
def gym(players):
    """Play one complete game between ``players`` and return its replay.

    24 of the 33 cards (values 3..35) are dealt in random order, every player
    starts with 11 tokens and the starting player is chosen at random. Each
    ``players[i](state, pot)`` is asked whether to decline the card; a player
    who takes it also receives the next card first.

    A player holding no tokens cannot decline. That rule is enforced here as
    well as in ``mutate`` so that turn order and pot stay consistent even if
    a policy answers "decline" with an empty hand.

    Args:
        players: list of callables ``player(state, pot) -> truthy to decline``.

    Returns:
        List of replay records, in play order:
          non-terminal: ``(False, player_id, state_before, state_after, reward)``
          terminal:     ``(True, player_id, state, score, place)``
        The last ``len(players)`` records are the terminal ones. States are
        always seen from the recorded player's seat (that player is row 0),
        and ``place`` is 0 for the best score.
    """
    card = np.arange(33) + 3

    deck = card.copy()
    np.random.shuffle(deck)
    deck = list(deck[:24])

    def draw():
        # A fresh pot: no tokens yet, plus a one-hot flag for the drawn card.
        return np.r_[0, (card == deck.pop()).astype(int)]

    n = len(players)
    # n player rows with 11 tokens each, then the merged row with their total.
    state = np.zeros((n + 1, 34), int)
    state[:n, 0] = 11
    state[n, 0] = 11 * n

    turn = np.random.randint(n)
    replay = []

    while len(deck) > 0:
        pot = draw()
        while True:
            wants_pass = players[turn](state, pot)
            # Mirror mutate(): without tokens the card is taken regardless.
            nothanks = bool(wants_pass) and bool(state[0, 0] != 0)
            state_, pot_ = mutate(state, pot, nothanks)
            reward = score(state_) - score(state)
            replay.append((False, turn, state, state_, reward))
            if nothanks:
                turn = (turn + 1) % n
                state = next_turn(state_)
                pot = pot_
            else:
                # The taker keeps the turn and is offered the next card.
                state = state_
                break

    # Rank the final scores: negated and sorted, so index 0 is the winner and
    # tied players share the better place.
    scores = []
    for _ in range(n):
        scores.append(-score(state))
        state = next_turn(state)
    scores.sort()

    for _ in range(n):
        scr = score(state)
        replay.append((True, turn, state, scr, scores.index(-scr)))
        state = next_turn(state)
        turn = (turn + 1) % n

    return replay

# Smoke test: play a single game with the players built above and keep its
# replay (in top-to-bottom cell order no training has happened yet).
replay = gym(players)

In [7]:
def train(vn, replay, gamma=0.95, batch_size=32, place_value=[120, 80, 60, 40, 0], epoch=10):
    """Fit ``vn`` to value targets built from ``replay``.

    Each replay record is ``(is_terminal, player_id, state, a, b)``:
      * terminal:     ``a`` is the final score, ``b`` the place;
        target = ``a + place_value[b]``
      * non-terminal: ``a`` is the state after the move, ``b`` the reward;
        target = ``b + gamma * evaluate(vn, a)``

    Args:
        vn: value network to optimise in place.
        replay: list of replay records as described above.
        gamma: discount applied to the bootstrapped value.
        batch_size: records per optimisation step; only full batches are
            used, a trailing partial batch is skipped.
        place_value: bonus added to the final score, indexed by place.
        epoch: number of shuffled passes over ``replay``.

    A fresh Adam optimiser and a SmoothL1 loss are created on every call.
    """
    vn.train()
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(vn.parameters())

    n_replay = len(replay)
    for _ in irange(epoch):
        argseq = np.arange(n_replay)
        np.random.shuffle(argseq)
        for start in range(0, n_replay - batch_size + 1, batch_size):
            x = []
            y = []
            for idx in argseq[start:start + batch_size]:
                is_terminal, _, state, a, b = replay[idx]
                x.append(state.flatten())
                if is_terminal:
                    y.append(a + place_value[b])
                else:
                    y.append(b + evaluate(vn, a) * gamma)

            x = to_tensor(np.stack(x))
            y = torch.from_numpy(
                np.asarray(y, dtype=np.float32).reshape((-1, 1))
            ).to(x.device)

            # evaluate() scores states in eval mode and may leave the network
            # there; switch back so BatchNorm uses (and updates) batch
            # statistics during the optimisation step.
            vn.train()
            prediction = vn(x)
            loss = criterion(prediction, target=y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [13]:
# Self-play training: 20 rounds, each collecting the replays of 64 games and
# then fitting every player's network on that shared replay. A checkpoint is
# written per network after each round (m0.mdl, m1.mdl, ...).
for _ in irange(20):
    replay = []
    for _ in irange(64):
        replay.extend(gym(players))
    # Iterate over the networks themselves instead of a hard-coded count, so
    # the loop stays correct when the number of players changes.
    for i, vn in enumerate(vns):
        train(vn, replay)
        torch.save(vn.state_dict(), f'm{i}.mdl')

KeyboardInterrupt: 

In [18]:
# Play one more game and report the final standings; the last N replay
# records are the terminal ones (one per player).
final_records = gym(players)[-N:]
for _, idx, _, scr, place in final_records:
    print(f'player {idx}: #{place+1} ({scr})')

player 0: #5 (-158)
player 1: #1 (4)
player 2: #3 (-47)
player 3: #2 (-28)
player 4: #4 (-48)
