In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from ipypb import irange

from IPython.display import clear_output

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class DVN(nn.Module):
    """State-value network: an MLP that maps one flattened game state to a scalar.

    The input width is (n+1)*34. Three ReLU hidden layers of widths (n+1)*30,
    (n+1)*26 and (n+1)*22 feed a single bias-free linear output unit.
    """

    def __init__(self, n=5):
        super().__init__()
        rows = n + 1
        width_in, width_h1, width_h2, width_h3 = (rows * w for w in (34, 30, 26, 22))

        # The attribute names ds1..ds4 become the state_dict keys saved in checkpoints.
        self.ds1 = nn.Linear(width_in, width_h1)
        self.ds2 = nn.Linear(width_h1, width_h2)
        self.ds3 = nn.Linear(width_h2, width_h3)
        self.ds4 = nn.Linear(width_h3, 1, bias=False)

        # Re-draw every weight matrix with Xavier-normal init (biases keep the default init).
        for layer in (self.ds1, self.ds2, self.ds3, self.ds4):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        for hidden in (self.ds1, self.ds2, self.ds3):
            x = F.relu(hidden(x))
        return self.ds4(x)

In [4]:
# Game state layout used throughout (gym, mutate, get_score, to_tensor):
#   an (N+1) x 34 integer array.
#   rows:    N player rows (row 0 is the player to act) + 1 merged row that
#            holds the sum of all player rows
#   columns: token count (1) + one-hot card ownership for cards 3..35 (33)
# A pot is a single length-34 vector with the same column layout.

N = 5  # number of players; also fixes the value-net input width (N+1)*34

def mutate(state, pot, action):
    """Apply one decision to copies of `state` and `pot` and return ``(state, pot)``.

    action truthy -> "no thanks": the player at row 0 pays one token into the
                     pot, and the merged last row's token count drops as well.
    action falsy  -> take: the pot is added to row 0 and to the merged row,
                     and the returned pot is emptied.
    A player with no tokens (state[0, 0] == 0) is always forced to take.
    The input arrays are never modified.
    """
    next_state = state.copy()
    next_pot = pot.copy()
    pays_token = state[0, 0] != 0 and bool(action)
    if not pays_token:
        for row in (0, -1):
            next_state[row] += next_pot
        next_pot *= 0
        return next_state, next_pot
    next_state[[0, -1], 0] -= 1
    next_pot[0] += 1
    return next_state, next_pot
        
def next_turn(state):
    """Pass the turn: move player row 0 behind the other players.

    The merged last row stays in place. A new array is returned.
    """
    head, others, merged = state[:1], state[1:-1], state[-1:]
    return np.concatenate((others, head, merged), axis=0)

def to_tensor(state):
    """Flatten one state (or a stack of states) into float32 rows of width (N+1)*34 on `device`."""
    width = (N + 1) * 34
    rows = state.reshape(-1, width)
    return torch.from_numpy(rows).float().to(device)

def evaluate(net, state):
    """Return the value `net` assigns to a single game state.

    Args:
        net: value network whose input width matches to_tensor's output.
        state: numpy array holding one state; it is flattened by to_tensor.

    Returns:
        The network output for the first (only) row, as a numpy scalar.

    The forward pass runs under ``torch.no_grad()``, so no autograd graph is
    built. The net is put in eval mode only for this call and then restored to
    its previous mode. Callers inside a training loop (e.g. train) therefore do
    not silently leave the model in eval mode.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            return net(to_tensor(state)).cpu().numpy()[0, 0]
    finally:
        net.train(was_training)

def get_score(state):
    """Score of the player at row 0 (higher is better).

    Each remaining token counts +1. Each run of consecutive held cards costs
    only its lowest card. Column i (1..33) of a row stands for card i+2,
    i.e. cards 3..35.
    """
    row = state[0]
    held = row[1:34] == 1
    # A held card opens a run when the column just below it is empty (0).
    # Card 3 (column 1) has no lower neighbour, so it always opens one.
    opens_run = np.r_[True, row[1:33] == 0]
    run_heads = np.arange(3, 36)[held & opens_run]
    return row[0] - int(run_heads.sum())
    
def spawn_player(value_net):
    """Return a greedy one-step-lookahead policy driven by `value_net`.

    The returned ``player(state, pot)`` answers True ("no thanks", pay a token)
    when the net rates the post-pay state strictly higher than the post-take
    state, and False otherwise. A player with no tokens always takes.
    """
    def player(state, pot):
        if state[0, 0] == 0:
            return False
        after_pass, _ = mutate(state, pot, True)
        after_take, _ = mutate(state, pot, False)
        return evaluate(value_net, after_pass) > evaluate(value_net, after_take)

    return player

In [5]:
# Self-play: every seat uses a greedy policy backed by the same shared value net.
vn = DVN(N).to(device)
# Uncomment to resume from the checkpoint the training loop saves to 'm.mdl'.
# NOTE(review): torch.load may need map_location=device if it was saved on another device.
# vn.load_state_dict(torch.load('m.mdl'))
players = [spawn_player(vn) for _ in range(N)]

# Disabled alternative: an independent value net per seat.
# vns = [DVN(N).to(device) for _ in range(N)]
# players = [spawn_player(vn) for vn in vns]

In [6]:
def _final_records(state, turn, n):
    """Build one terminal replay tuple per player from the final `state`.

    Player ``(turn + k) % n`` sits at row 0 after k rotations of `state`.
    """
    scores = []
    rotated = state
    for _ in range(n):
        scores.append(get_score(rotated))
        rotated = next_turn(rotated)
    # n rotations bring `rotated` back to the original seating.
    ranking = sorted(scores, reverse=True)  # best score first; ties share the first index
    records = []
    for k, score in enumerate(scores):
        records.append((True, (turn + k) % n, rotated, score, ranking.index(score)))
        rotated = next_turn(rotated)
    return records


def gym(players, tokens=11, deck_size=24):
    """Play one complete game of No Thanks! and return its replay log.

    Args:
        players: list of policies ``player(state, pot)``. A truthy result means
            "no thanks" (pay a token); a falsy result means take the card and pot.
        tokens: starting tokens per player (11 by default).
        deck_size: how many of the 33 cards (3..35) are actually played (24 by default).

    Returns:
        A list of replay tuples:
          non-terminal: (False, player_id, state_before, state_after, reward)
          terminal:     (True, player_id, final_state, score, place)
        ``reward`` is the change in the acting player's score. Whenever the player
        could have chosen the other action, that counterfactual is logged as an
        extra non-terminal entry. ``place`` is 0-based, and tied scores share the
        better place.
    """
    card = np.arange(33) + 3

    deck = card.copy()
    np.random.shuffle(deck)
    deck = list(deck[:deck_size])

    def draw():
        # Pot vector: [0 tokens, one-hot of the revealed card].
        return np.r_[0, (card == deck.pop()).astype(int)]

    n = len(players)
    state = np.r_[[np.r_[tokens, np.zeros((33,), int)] for _ in range(n)]]
    state = np.r_[state, state.sum(axis=0).reshape((1, -1))]

    # `turn` is always the id of the player currently at row 0 of `state`.
    turn = np.random.randint(n)
    replay = []

    while deck:
        pot = draw()
        while True:
            can_pay = state[0, 0] > 0
            # mutate() forces a take when the player has no tokens. Mirror that here
            # so the turn/pot bookkeeping below follows the action actually applied.
            nothanks = bool(players[turn](state, pot)) and can_pay

            state_, pot_ = mutate(state, pot, nothanks)
            base = get_score(state)
            replay.append((False, turn, state, state_, get_score(state_) - base))
            if can_pay:
                # Also log the action not taken, so the value net sees both branches.
                state_f, _ = mutate(state, pot, not nothanks)
                replay.append((False, turn, state, state_f, get_score(state_f) - base))

            if nothanks:
                turn = (turn + 1) % n
                state = next_turn(state_)
                pot = pot_
            else:
                # The taker keeps the turn and faces the next card.
                state = state_
                break

    replay.extend(_final_records(state, turn, n))
    return replay

In [7]:
def _td_targets(vn, batch, gamma, place_value):
    """Return a (len(batch), 1) numpy column of regression targets for `batch`.

    Terminal entries target ``score + place_value[place]``. Non-terminal entries
    target ``reward + gamma * V(state_after)``, with V evaluated in a single
    no-grad forward pass. The net's train/eval mode is restored afterwards.
    """
    targets = np.zeros(len(batch))
    pending = []  # positions of non-terminal entries
    for k, (is_terminal, _, _, a, b) in enumerate(batch):
        if is_terminal:
            targets[k] = a + place_value[b]
        else:
            pending.append(k)
    if pending:
        after = to_tensor(np.stack([batch[k][3].flatten() for k in pending]))
        was_training = vn.training
        vn.eval()
        try:
            with torch.no_grad():
                values = vn(after).cpu().numpy()[:, 0]
        finally:
            vn.train(was_training)
        rewards = np.array([batch[k][4] for k in pending], dtype=float)
        targets[pending] = rewards + gamma * values
    return targets.reshape(-1, 1)


def train(vn, replay, gamma=0.95, batch_size=64, place_value=(0, 0, 0, 0, 0), epoch=10,
          optimizer=None):
    """Fit `vn` to one-step TD targets drawn from `replay`.

    Args:
        vn: value network, optimised in place.
        replay: list of tuples as produced by gym().
        gamma: discount applied to V(state_after) for non-terminal entries.
        batch_size: minibatch size; a trailing partial batch is skipped.
        place_value: bonus added to a terminal score, indexed by the 0-based place.
        epoch: number of shuffled passes over `replay`.
        optimizer: optimizer to step. Defaults to a fresh Adam over vn's parameters.
            Pass one in to keep Adam's moment estimates across calls.

    Returns:
        The detached loss of the last minibatch. If no full batch was available
        (fewer than `batch_size` entries, or ``epoch == 0``), returns a NaN tensor.

    NOTE(review): terminal targets are absolute scores, while non-terminal rewards
    are score differences. Confirm this value scale is intended.
    """
    vn.train()
    criterion = nn.SmoothL1Loss()
    if optimizer is None:
        optimizer = optim.Adam(vn.parameters())

    loss = torch.tensor(float('nan'))
    n_full = len(replay) // batch_size * batch_size  # skip the trailing partial batch
    for _ in range(epoch):
        order = np.random.permutation(len(replay))
        for i in irange(0, n_full, batch_size):
            batch = [replay[k] for k in order[i:i + batch_size]]
            x = to_tensor(np.stack([entry[2].flatten() for entry in batch]))
            y = torch.from_numpy(_td_targets(vn, batch, gamma, place_value)).float().to(x.device)
            loss = criterion(vn(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return loss.detach()

In [8]:
# Self-play training. Each lineage plays 8 games with the shared net, keeps the
# newest 32*256 replay entries, trains one epoch and checkpoints to m.mdl.
# It then prints one freshly played game.
replay = []
for lineage in range(1024):
    for _ in irange(8):
        replay += gym(players)
    replay = replay[-32 * 256:]  # bounded replay buffer: newest entries only
    loss = train(vn, replay, batch_size=64, epoch=1)
    torch.save(vn.state_dict(), 'm.mdl')

    clear_output(wait=True)
    total = 0
    print(f'lineage: {lineage}')
    # A game's replay ends with one terminal record per player.
    for _, idx, state, score, place in gym(players)[-len(players):]:
        total += score
        cards = ", ".join(str(i + 2) for i in range(1, 34) if state[0, i] == 1)
        print(f'  player {idx}: #{place + 1} ({score}): {state[0, 0]} - [{cards}]')
    print(f'total: {total}')
    print('loss:', loss.item())

lineage: 1023
  player 4: #3 (-35): 1 - [15, 21]
  player 0: #1 (0): 0 - []
  player 1: #5 (-87): 54 - [3, 4, 5, 9, 11, 12, 13, 16, 20, 23, 24, 25, 26, 28, 29, 31, 32, 33]
  player 2: #2 (-34): 0 - [7, 10, 17]
  player 3: #3 (-35): 0 - [35]
total: -191
loss: 6.3656134605407715


In [9]:
# Continue self-play training, reusing the `replay` buffer from the previous cell.
for lineage in range(1024):
    for _ in irange(8):
        replay += gym(players)
    replay = replay[-32 * 256:]  # bounded replay buffer: newest entries only
    loss = train(vn, replay, batch_size=64, epoch=1)
    torch.save(vn.state_dict(), 'm.mdl')

    clear_output(wait=True)
    total = 0
    print(f'lineage: {lineage}')
    # A game's replay ends with one terminal record per player.
    for _, idx, state, score, place in gym(players)[-len(players):]:
        total += score
        cards = ", ".join(str(i + 2) for i in range(1, 34) if state[0, i] == 1)
        print(f'  player {idx}: #{place + 1} ({score}): {state[0, 0]} - [{cards}]')
    print(f'total: {total}')
    print('loss:', loss.item())

lineage: 1023
  player 1: #1 (-14): 1 - [15]
  player 2: #5 (-121): 52 - [4, 5, 6, 9, 11, 12, 16, 20, 24, 26, 27, 29, 34]
  player 3: #4 (-114): 1 - [23, 25, 32, 35]
  player 4: #2 (-29): 1 - [3, 10, 17]
  player 0: #3 (-48): 0 - [8, 18, 22]
total: -326
loss: 4.929439067840576


In [10]:
# Continue self-play training, reusing the `replay` buffer from earlier cells.
for lineage in range(1024):
    for _ in irange(8):
        replay += gym(players)
    replay = replay[-32 * 256:]  # bounded replay buffer: newest entries only
    loss = train(vn, replay, batch_size=64, epoch=1)
    torch.save(vn.state_dict(), 'm.mdl')

    clear_output(wait=True)
    total = 0
    print(f'lineage: {lineage}')
    # A game's replay ends with one terminal record per player.
    for _, idx, state, score, place in gym(players)[-len(players):]:
        total += score
        cards = ", ".join(str(i + 2) for i in range(1, 34) if state[0, i] == 1)
        print(f'  player {idx}: #{place + 1} ({score}): {state[0, 0]} - [{cards}]')
    print(f'total: {total}')
    print('loss:', loss.item())

lineage: 1023
  player 4: #2 (4): 14 - [10]
  player 0: #3 (-10): 12 - [22]
  player 1: #1 (10): 10 - []
  player 2: #5 (-198): 9 - [3, 4, 7, 12, 13, 15, 20, 26, 28, 30, 32, 34, 35]
  player 3: #4 (-155): 10 - [6, 8, 11, 14, 18, 21, 23, 31, 33]
total: -349
loss: 6.5375471115112305


In [11]:
# Long continuation run (1024*64 lineages), reusing the `replay` buffer from earlier cells.
for lineage in range(1024 * 64):
    for _ in irange(8):
        replay += gym(players)
    replay = replay[-32 * 256:]  # bounded replay buffer: newest entries only
    loss = train(vn, replay, batch_size=64, epoch=1)
    torch.save(vn.state_dict(), 'm.mdl')

    clear_output(wait=True)
    total = 0
    print(f'lineage: {lineage}')
    # A game's replay ends with one terminal record per player.
    for _, idx, state, score, place in gym(players)[-len(players):]:
        total += score
        cards = ", ".join(str(i + 2) for i in range(1, 34) if state[0, i] == 1)
        print(f'  player {idx}: #{place + 1} ({score}): {state[0, 0]} - [{cards}]')
    print(f'total: {total}')
    print('loss:', loss.item())

lineage: 10394
  player 1: #3 (-42): 3 - [21, 24, 25, 26]
  player 2: #1 (-12): 1 - [5, 8]
  player 3: #5 (-91): 3 - [12, 19, 28, 35]
  player 4: #2 (-39): 47 - [4, 6, 7, 17, 27, 32, 33]
  player 0: #4 (-56): 1 - [3, 9, 10, 11, 16, 29, 30]
total: -240
loss: 6.625244140625


KeyboardInterrupt: 