In [49]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from collections import OrderedDict


## 读入(这段读的是aigame作业的下发json)

In [50]:
all_json = []

aigame_botid = ["6048fc6b81fb3b738e911e3b",
               "6048fcf381fb3b738e912cb8",
               "6048fd3781fb3b738e9138ac",
               "6048fd7981fb3b738e9140fc",
               "6048fda981fb3b738e914488"]

def get_all_json(cwd):
    get_dir = os.listdir(cwd)  
    for i in get_dir:          
        sub_dir = os.path.join(cwd,i)  
        if os.path.isdir(sub_dir):     
            get_all_json(sub_dir)
        else:
            if i[-5:] == ".json":
                all_json.append(cwd + "/" + i)
                
get_all_json("data/")

## 读入(这段读的是botzone上下载的对局数据，nb_bots是天梯上排名前30的botid)

In [51]:
all_matches = []

def get_all_matches(cwd):
    get_dir = os.listdir(cwd)  
    for i in get_dir:          
        sub_dir = os.path.join(cwd,i)  
        if os.path.isdir(sub_dir):
            get_all_matches(sub_dir)
        else:
            if i[-8:] == ".matches":
                all_matches.append(cwd + "/" + i)

get_all_matches("data/")
nb_bots = open("nbbot.txt", "r+").read().split('\n')

## 预处理 Combo

In [91]:
def getCardId(card):
    # 求一张牌的 id
    if card < 52:
        return card // 4
    else:
        return card - 39

def getCombo(cards):
    # a combo is represented as a tuple(k, l, r, w)
    # 表示有 k * [l, r] 即 k 张 [l, r] 中的牌（作为主体）w \in [0, 1, 2] 表示带的是啥类型
    if len(cards) == 0:
        return (0, 0, 0, 0)
    tmp = np.zeros(15, dtype = int)
    for card in cards:
        tmp[getCardId(card)] += 1
    k = np.max(tmp)
    l = np.min(np.where(tmp == k))
    r = np.max(np.where(tmp == k))
    w = 0
    if k == 3:
        w = len(cards) // (r - l + 1) - 3
    if k == 4:
        w = (len(cards) // (r - l + 1) - 4) // 2
    return (k, l, r, w)

combo_dict = {}
combo_list = []
combo_cnt = 0

def initCombo():
    global combo_dict, combo_list, combo_cnt
    combo_dict = {}
    combo_list = []
    combo_cnt = 0
    def addCombo(combo):
        global combo_dict, combo_list, combo_cnt
        combo_list.append(combo)
        combo_dict[combo] = combo_cnt
        combo_cnt += 1

    minLength = [0, 5, 3, 2, 2]
    maxWings = [0, 1, 1, 3, 3]
    fold = [0, 0, 0, 1, 2]
    for k in range(1, 5):
        for x in range(13):
            for w in range(maxWings[k]):
                addCombo((k, x, x, w))
        for l in range(12):
            for r in range(l + minLength[k] - 1, 12):
                for w in range(maxWings[k]):
                    if (r - l + 1) * (k + w * fold[k]) <= 20:
                        addCombo((k, l, r, w))
    addCombo((1, 13, 13, 0))
    addCombo((1, 14, 14, 0))
    addCombo((1, 13, 14, 0))
    addCombo((0, 0, 0, 0))
    
initCombo()

def getPartition(cards):
    # 把一次出牌的编号集合划分成 mainbody 和 bywings
    # 其中 mainbody 是一个 list ，bywings 中每个 wing 是一个 list ，也就是一个 list 的 list
    combo = getCombo(cards)
    tmp = [[] for i in range(15)]
    for card in cards:
        tmp[getCardId(card)].append(card)
    mainbody, bywings = [], []
    for i in range(15):
        if len(tmp[i]) > 0:
            if combo[1] <= i and i <= combo[2]:
                mainbody.extend(tmp[i])
            else:
                bywings.append(tmp[i])
    return mainbody, bywings

def getComboMask(combo):
    # 给出一个 combo ，返回可以接在其后面牌型 mask 
    mask = np.zeros(combo_cnt)
    if combo == (0, 0, 0, 0):
        mask = np.ones(combo_cnt)
        mask[combo_dict[(0, 0, 0, 0)]] = 0
        return mask
    mask[combo_dict[(0, 0, 0, 0)]] = 1

    if combo == (1, 13, 14, 0):
        return mask
    mask[combo_dict[(1, 13, 14, 0)]] = 1

    if combo[0] == 4 and combo[1] == combo[2] and combo[3] == 0:
        for i in range(combo[1] + 1, 13):
            mask[combo_dict[(4, i, i, 0)]] = 1
        return mask
    for i in range(13):
        mask[combo_dict[(4, i, i, 0)]] = 1

    for cb in combo_list:
        if cb[0] == combo[0] and cb[2] - cb[1] == combo[2] - combo[1] and cb[3] == combo[3] and cb[1] > combo[1]:
            mask[combo_dict[cb]] = 1
            
    return mask


In [53]:
class Game(object):
    # 这里 0 始终是地主，1 始终是地主下家，2 始终是地主上家

    def __init__(self, init_data):
        self.hand = [np.zeros(15, dtype = int), np.zeros(15, dtype = int), np.zeros(15, dtype = int)]
        for player in range(3):
            for card in init_data[player]:
                self.hand[player][getCardId(card)] += 1
    
    def play(self, player, cards):
        # 模拟打牌 打出 cards 这个 list 中的所有牌
        for card in cards:
            self.hand[player][getCardId(card)] -= 1
            
    def possess(self, player, combo):
        # 判断 player 这个玩家是否拥有 combo 这个牌型的牌
        if combo == (0, 0, 0, 0):
            return True
        for i in range(combo[1], combo[2] + 1):
            if self.hand[player][i] < combo[0]:
                return False
            
        fold = [0, 0, 0, 1, 2]
        need_wings = (combo[2] - combo[1] + 1) * fold[combo[0]] if combo[3] > 0 else 0
        for i in range(15):
            if i < combo[1] or i > combo[2]:
                if self.hand[player][i] >= combo[3]:
                    need_wings -= 1
        if need_wings > 0:
            return False
        return True
    
    def getPossessMask(self, player):
        # 返回 player 拥有的牌型 mask
        mask = np.zeros(combo_cnt)
        for i in range(combo_cnt):
            if self.possess(player, combo_list[i]) == True:
                mask[i] = 1
        return mask
    
    def getMask1(self, player, combo):
        # getPossessMask 和 getComboMask 取交集
        return self.getPossessMask(player) * getComboMask(combo)
    
    def getMask2(self, player, combo, already_played):
        # 带翼的 mask，哪些翼是可以打的？
        # mask 的大小是 28, 表示 15 种单牌和 13 种对子
        # 指明 combo 后：(1)单牌/对子不能错 (2)少于1/2张的不能打 (3)和主体部分重复的不能打 (4)打过的不能打
        mask = np.ones(28)
        if combo[3] == 1:
            mask[range(15, 28)] = 0
            for i in range(13):
                if self.hand[player][i] < 1:
                    mask[i] = 0
            mask[range(combo[1], combo[2] + 1)] = 0
            for i in already_played:
                mask[i] = 0
        else:
            assert combo[3] == 2
            mask[range(0, 15)] = 0
            for i in range(13):
                if self.hand[player][i] < 2:
                    mask[i + 15] = 0
            mask[range(combo[1] + 15, combo[2] + 16)] = 0
            for i in already_played:
                mask[i + 15] = 0
        return mask
    
    def getInput(self, player):
        # 返回两个网络的输入
        # 这里包含五个部分：我自己的手牌数、对手的手牌数、我的顺子情况、三个人还剩多少张牌、我拥有牌型的 mask
        # size = 4 * 15 + 4 * 15 + 4 * 12 + 3 * 20 + 379
        p1 = (player + 1) % 3
        p2 = (player + 2) % 3

        myhand = np.zeros((4, 15))
        othershand = np.zeros((4, 15))
        for i in range(4):
            myhand[i, np.where(self.hand[player] >= i + 1)] = 1
            othershand[i, np.where(self.hand[p1] + self.hand[p2] >= i + 1)] = 1
        
        mystraight = np.zeros((4, 12))
        for i in range(4):
            k = 0
            for j in range(12):
                if self.hand[player][i] >= i + 1:
                    k += 1
                else:
                    k = 0
                mystraight[i, j] = k
                
        handcnt = np.zeros((3, 20))
        for player in range(3):
            handcnt[player, range(np.sum(self.hand[player]))] = 1

        return np.concatenate([myhand.flatten(), othershand.flatten(), mystraight.flatten(), handcnt.flatten(), self.getPossessMask(player)])
    
_input_size = 60 + 60 + 48 + 60 + combo_cnt

## 定义数据集类

In [54]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.X = []
        self.M = []
        self.Y = []
    def append(self, x, m, y):
        self.X.append(x)
        self.M.append(m)
        self.Y.append(y)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return (self.X[idx], self.M[idx], self.Y[idx])


In [55]:
DS1 = [MyDataset(), MyDataset(), MyDataset()]
DS2 = MyDataset()

In [56]:
print_cnt = 0

for file in all_json:
    print(file)
    with open(file, 'r+') as f:
        while True:
            data = f.readline()
            if(len(data) == 0):
                break
            data = json.loads(data)
            initdata = json.loads(data['initdata'])

            nb_bot = [0, 0, 0]
            exists_nb = 0
            for i in range(3):
                if "bot" in data['players'][i].keys() and data['players'][i]['bot'] == aigame_botid[0]:
                    nb_bot[i] = 1
                    exists_nb = 1
            if exists_nb == 0:
                break

            g = Game(initdata['allocation'])
            log = data['log']
            
            las_combo = (0, 0, 0, 0)
            las_play = -1
            
            for i in range(1, len(log), 2):
                player = -1
                for p in range(3):
                    if str(p) in log[i]:
                        player = p
                        break
                
                if log[i][str(p)]['verdict'] != 'OK':
                    print(log[i][str(p)]['verdict'])
                    break

                if las_play == player:
                    las_combo = (0, 0, 0, 0)
                    
                cards = log[i][str(player)]['response']
                cur_combo = getCombo(cards)
                mainbody, bywings = getPartition(cards)
                    
                if nb_bot[player]:
                    input_x = g.getInput(player)
                    input_m = g.getMask1(player, las_combo)
                    output_y = combo_dict[cur_combo]
                    assert input_m[output_y] == 1
                    if np.sum(input_m) > 1:
                        DS1[player].append(input_x, input_m, output_y)
                
                g.play(player, mainbody)
                
                if len(bywings) != 0:
                    already_played = []
                    for w in bywings:
                        assert len(w) == 1 or len(w) == 2
                        if nb_bot[player]:
                            input_x = g.getInput(player)
                            input_m = g.getMask2(player, cur_combo, already_played)
                            output_y = getCardId(w[0]) + (0 if len(w) == 1 else 15)
                            if input_m[output_y] != 1:
                                print(cards)
                                print(mainbody, bywings)
                                print(cur_combo)
                                print(g.hand[player])
                                print(already_played)
                                assert 0
                            assert input_m[output_y] == 1
                            DS.append(input_x, input_m, output_y)
                            
                        g.play(player, w)
                        already_played.append(getCardId(w[0]))
                
                if cur_combo != (0, 0, 0, 0):
                    las_combo = cur_combo
                    las_play = player
                    
                
    
    print_cnt += 1
    print("files = %d, lengths = (%d, %d, %d, %d, %d, %d)"
          % (print_cnt, len(DS1[0]), len(DS1[1]), len(DS1[2]), len(DS2)))


data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_20.json
files = 1, lengths = (1169, 1007, 1146, 165, 113, 113)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_16.json
files = 2, lengths = (2152, 1807, 2007, 303, 206, 196)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_10.json
files = 3, lengths = (2840, 2408, 2573, 391, 277, 260)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_3.json
files = 4, lengths = (3087, 2607, 2796, 427, 296, 277)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_4.json
files = 5, lengths = (3697, 3108, 3290, 520, 342, 323)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_5.json
files = 6, lengths = (4058, 3396, 3581, 580, 380, 348)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_22.json
files = 7, lengths = (5064, 4317, 4569, 733, 479, 430)
data/download_bot_matches/1_6048fc6b81fb3b738e911e3b/train/log_1_21.json
files = 8, lengths = (6193, 5322, 

files = 84, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_13.json
files = 85, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_5.json
files = 86, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_18.json
files = 87, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_12.json
files = 88, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_31.json
files = 89, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_26.json
files = 90, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/2_6048fcf381fb3b738e912cb8/train/log_2_27.json
files = 91, lengths = (23785, 20027, 21129, 3529, 2143, 

files = 155, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_17.json
files = 156, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_10.json
files = 157, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_34.json
files = 158, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_31.json
files = 159, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_22.json
files = 160, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_13.json
files = 161, lengths = (23785, 20027, 21129, 3529, 2143, 1982)
data/download_bot_matches/3_6048fd3781fb3b738e9138ac/train/log_3_30.json
files = 162, lengths = (23785, 20027, 21129, 352

In [57]:
print(len(DS1[0]), len(DS1[1]), len(DS1[2]), len(DS2))

23785 20027 21129
3529 2143 1982


## 网络框架

In [93]:
HIDDEN_SIZE = 512
torch.set_default_tensor_type(torch.DoubleTensor)

class MyModule(nn.Module):
    def __init__(self, INPUT_SIZE, OUTPUT_SIZE):
        super(MyModule, self).__init__()
        self.fc = nn.Sequential(OrderedDict([
                ('fc1', nn.Linear(INPUT_SIZE, HIDDEN_SIZE)),
                ('relu', nn.ReLU()),
                ('bn', nn.BatchNorm1d(HIDDEN_SIZE)),
                ('dropout', nn.Dropout(p = 0.1)),
                ('fc2', nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE)),
            ]))
        
    def forward(self, x, m):
        return nn.Softmax(dim = -1)(self.fc(x) * m)
    

## 模型训练（主体）

In [94]:
best_acc = 0

def getPred(output):
    return output.detach().numpy().argmax(axis = 1)

def train(net, data_loader, data_size, criterion, optimizer):
    net.train()
    for i, (x, m, y) in enumerate(data_loader):
        optimizer.zero_grad()
        output = net(x, m)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        
    total_correct = 0
    avg_loss = 0.0
    for i, (x, m, y) in enumerate(data_loader):
        output = net(x, m)
        avg_loss += criterion(output, y).sum()
        pred = getPred(output)
        total_correct += (pred == y.detach().numpy()).sum()
    avg_loss /= data_size
    cur_acc = float(total_correct) / data_size
    print('Training Avg. Loss: %f, Accuracy: %f' % (avg_loss, cur_acc))

def validate(net, data_loader, data_size, criterion, model_name):
    global best_acc
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    for i, (x, m, y) in enumerate(data_loader):
        output = net(x, m)
        avg_loss += criterion(output, y).sum()
        pred = getPred(output)
        total_correct += (pred == y.detach().numpy()).sum()
    
    avg_loss /= data_size
    cur_acc = float(total_correct) / data_size
    print('Validation Avg. Loss: %f, Accuracy: %f' % (avg_loss, cur_acc))
    
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(net.state_dict(), './model/best_model_for_' + model_name + '.pt')


In [95]:
for i in range(3):
        
    print("-----------------------------------------")
    print("Training model %d" % i)
    print("-----------------------------------------")

    net = MyModule(_input_size, combo_cnt)
    best_acc = 0
    
    train_size = int(0.9 * len(DS1[i]))
    valid_size = len(DS1[i]) - train_size

    train_data, valid_data = torch.utils.data.random_split(DS1[i], [train_size, valid_size])

    train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size = 128)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = 128)

    for epoch in range(40):
        print("epoch = ", epoch)
        train(net, data_loader = train_loader,
              data_size = train_size,
              criterion = nn.CrossEntropyLoss(),
              optimizer = optim.Adam(net.parameters(), lr = 2e-3))

        validate(net, data_loader = valid_loader,
                 data_size = valid_size,
                 criterion = nn.CrossEntropyLoss(),
                 model_name = str(i) + "mainbody")

-----------------------------------------
Training model 0
-----------------------------------------
epoch =  0
Training Avg. Loss: 0.042377, Accuracy: 0.553910
Validation Avg. Loss: 0.043064, Accuracy: 0.563262
epoch =  1
Training Avg. Loss: 0.042219, Accuracy: 0.568859
Validation Avg. Loss: 0.042934, Accuracy: 0.576293
epoch =  2


KeyboardInterrupt: 

## 模型训练（带翼）

In [88]:
net = MyModule(_input_size, 28)
best_acc = 0

train_size = int(0.9 * len(DS2))
valid_size = len(DS2) - train_size

train_data, valid_data = torch.utils.data.random_split(DS2, [train_size, valid_size])

train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size = 64)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = 64)

for epoch in range(40):

    train(net, data_loader = train_loader,
           data_size = train_size,
           criterion = nn.CrossEntropyLoss(),
           optimizer = optim.Adam(net.parameters(), lr = 2e-3))

    validate(net, data_loader = valid_loader,
              data_size = valid_size,
              criterion = nn.CrossEntropyLoss(),
              model_name = str(i) + "bywings")

-----------------------------------------
Training model %d 0
-----------------------------------------
Training Avg. Loss: 0.020236, Accuracy: 0.703086
Validation Avg. Loss: 0.027811, Accuracy: 0.614731
Training Avg. Loss: 0.007777, Accuracy: 0.887280
Validation Avg. Loss: 0.014672, Accuracy: 0.804533
Training Avg. Loss: 0.005346, Accuracy: 0.904282
Validation Avg. Loss: 0.012303, Accuracy: 0.818697
Training Avg. Loss: 0.004064, Accuracy: 0.932935
Validation Avg. Loss: 0.009874, Accuracy: 0.847025
Training Avg. Loss: 0.003234, Accuracy: 0.947103
Validation Avg. Loss: 0.010340, Accuracy: 0.821530
Training Avg. Loss: 0.003344, Accuracy: 0.941436
Validation Avg. Loss: 0.011472, Accuracy: 0.818697
Training Avg. Loss: 0.002781, Accuracy: 0.948992
Validation Avg. Loss: 0.009851, Accuracy: 0.841360
Training Avg. Loss: 0.001970, Accuracy: 0.968829
Validation Avg. Loss: 0.009216, Accuracy: 0.875354
Training Avg. Loss: 0.002005, Accuracy: 0.967884
Validation Avg. Loss: 0.008729, Accuracy: 0.858

Training Avg. Loss: 0.021504, Accuracy: 0.707235
Validation Avg. Loss: 0.029234, Accuracy: 0.592965
Training Avg. Loss: 0.009758, Accuracy: 0.864274
Validation Avg. Loss: 0.017874, Accuracy: 0.778894
Training Avg. Loss: 0.006933, Accuracy: 0.905216
Validation Avg. Loss: 0.014701, Accuracy: 0.809045
Training Avg. Loss: 0.005343, Accuracy: 0.910824
Validation Avg. Loss: 0.012797, Accuracy: 0.778894
Training Avg. Loss: 0.004479, Accuracy: 0.930454
Validation Avg. Loss: 0.012245, Accuracy: 0.814070
Training Avg. Loss: 0.004500, Accuracy: 0.927089
Validation Avg. Loss: 0.011974, Accuracy: 0.829146
Training Avg. Loss: 0.003653, Accuracy: 0.943915
Validation Avg. Loss: 0.009680, Accuracy: 0.844221
Training Avg. Loss: 0.003154, Accuracy: 0.948962
Validation Avg. Loss: 0.011024, Accuracy: 0.824121
Training Avg. Loss: 0.002701, Accuracy: 0.960179
Validation Avg. Loss: 0.010312, Accuracy: 0.834171
Training Avg. Loss: 0.002591, Accuracy: 0.956254
Validation Avg. Loss: 0.010795, Accuracy: 0.824121


## main.py

In [105]:
my_hand = []
g = Game([[], [], []])
my_pos = -1
others = []
las_combo = (0, 0, 0, 0)

def BIDDING():
    bid_val = 0
    print(json.dumps({
        "response": bid_val
    }))
    assert(0)
    exit()

def PLAYING():
    def getFromHand(idx):
        global my_hand
        for c in my_hand:
            if getCardId(c) == idx:
                my_hand.remove(c)
                return c
    
    to_play = []
    
    model_path = "./model/"
    model_name = "best_model_for_" + str(my_pos) + "mainbody.pt"
    model = MyModule(_input_size, combo_cnt)
    model.load_state_dict(torch.load(model_path + model_name))
    
    mask = g.getMask1(my_pos, las_combo)
    combo_id = -1
    if np.sum(mask) == 1:
        combo_id = np.argmax(mask)
    else:
        combo_id = model(torch.from_numpy(g.getInput(my_pos)).unsqueeze(0),
                         torch.from_numpy(mask)).unsqueeze(0).detach().numpy().argmax()
    
    combo = combo_list[combo_id]
    for i in range(combo[1], combo[2] + 1):
        for j in range(combo[0]):
            to_play.append(getFromHand(i))
    g.play(my_pos, to_play)
    
    if combo[3] != 0:
        model_name = "best_model_for_" + str(my_pos) + "bywings.pt"
        model = MyModule(_input_size, 28)
        model.load_state_dict(torch.load(model_path + model_name))

        cnt = (combo[2] - combo[1] + 1) * (1 if combo[0] == 3 else 2)
        already_played = []
        for i in range(cnt):
            wing_id = model(torch.from_numpy(g.getInput(my_pos)),
                            torch.from_numpy(g.getMask2(my_pos, combo, already_played))).detach().numpy().argmax()
            tmp = []
            if wing_id < 15:
                tmp = [getFromHand(wing_id)]
            else:
                wing_id -= 15
                tmp = [getFromHand(wing_id), getFromHand(wing_id)]
            g.play(my_pos, tmp)
            to_play.extend(tmp)
            already_played.append(wing_id)
    
    print(json.dumps({
        "response": to_play
    }))
    assert 0
    exit()
    
if __name__ == "__main__":
    initCombo()
    data = json.loads(input())
    my_hand, others_hand = data["requests"][0]["own"], []
    for i in range(54):
        if i not in my_hand:
            others_hand.append(i)

    TODO = "bidding"
    if "bid" in data["requests"][0]:
        bid_list = data["requests"][0]["bid"]
    
    for i in range(len(data["requests"])):
        request = data["requests"][i]
        if "publiccard" in request:
            bot_pos = request["pos"]
            lord_pos = request["landlord"]
            my_pos = (bot_pos - lord_pos + 3) % 3
            others = [(my_pos + 1) % 3, (my_pos + 2) % 3]
            tmp = [[], [], []]
            if my_pos == 0:
                my_hand.extend(request["publiccard"])
                tmp[0] = my_hand
                tmp[1], tmp[2] = others_hand[:17], others_hand[17:] # 随便分
            else:
                tmp[my_pos] = my_hand
                tmp[0] = others_hand[:20]
                tmp[2 if my_pos == 1 else 1] = others_hand[20:]
            g = Game(tmp)
        if "history" in request:
            history = request["history"]
            TODO = "playing"
            for j in range(2):
                p = others[j]
                cards = history[j]
                g.play(p, cards)
                cur_combo = getCombo(cards)
                if cur_combo != (0, 0, 0, 0):
                    las_combo = cur_combo

            if i < len(data["requests"]) - 1:
                cards = data["responses"][i]
                g.play(my_pos, cards)
    
    if TODO == "bidding":
        BIDDING()
    else:
        PLAYING()

{"requests":[ {"bid":[0, 0],"own":[8, 						30, 						1, 						32, 						22, 						28, 						33, 						24, 						31, 						41, 						42, 						25, 						51, 						19, 						46, 						5, 						26]}, { "history": [ 						[ 							0, 							9, 							10, 							11 						], 						[] 					], 					"own": [ 						8, 						30, 						1, 						32, 						22, 						28, 						33, 						24, 						31, 						41, 						42, 						25, 						51, 						19, 						46, 						5, 						26 					], 					"publiccard": [ 						21, 						38, 						9 					], 					"landlord": 0, 					"pos": 2, 					"finalbid": 1 					} ],"responses":[1]}


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])