In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from fireplace.game import Game
from fireplace.player import Player
from fireplace.utils import random_draft
from fireplace import cards
from fireplace import actions
from fireplace.exceptions import GameOver, InvalidAction
from hearthstone.enums import CardClass, CardType
import random
import numpy as np
import sys
# Initialize fireplace's card database (its log output is shown below this cell).
cards.db.initialize()

[fireplace.__init__]: Initializing card database
[fireplace.__init__]: Merged 4139 cards


In [38]:
from fastai.conv_learner import *

In [2]:
def setup_game():
    """
    Create and start a two-player game with randomly drafted decks.

    Each player gets a random class id in 6..9 (priest, rogue, shaman,
    warlock), a deck drafted for that class and the class's default hero.
    The mulligan is resolved without replacing any card.

    Returns:
        game: the started Game entity, positioned after the mulligan phase
    """
    # Roll both class ids first, then draft, so the random calls happen in
    # the order: class 1, class 2, deck 1, deck 2.
    class_ids = [random.randint(6, 9) for _ in range(2)]
    hero_classes = [CardClass(class_id) for class_id in class_ids]
    decks = [random_draft(hero_class) for hero_class in hero_classes]

    first, second = (
        Player(name, deck, hero_class.default_hero)
        for name, deck, hero_class in zip(("Player1", "Player2"), decks, hero_classes)
    )

    game = Game(players=(first, second))
    game.start()

    # Mulligan is skipped for now: sampling zero cards keeps the whole opening hand.
    for participant in game.players:
        replaced = random.sample(participant.choice.cards, 0)
        participant.choice.choose(*replaced)

    return game


In [3]:
# Start one game with randomly drafted decks; fireplace logs its setup trace below.
game = setup_game()

[fireplace.entity]: Setting up game Game(players=(Player(name='Player1', hero=None), Player(name='Player2', hero=None)))
[fireplace.entity]: Tossing the coin... Player2 wins!
[fireplace.actions]: Player(name='Player1', hero=<Hero ('Valeera Sanguinar')>) triggering <TargetedAction: Summon(<Summon.CARD>=<HeroPower ('Dagger Mastery')>)> targeting [Player(name='Player1', hero=<Hero ('Valeera Sanguinar')>)]
[fireplace.actions]: Player1 summons [<HeroPower ('Dagger Mastery')>]
[fireplace.entity]: Empty stack, refreshing auras and processing deaths
[fireplace.actions]: Player(name='Player1', hero=<Hero ('Valeera Sanguinar')>) triggering <TargetedAction: Summon(<Summon.CARD>=<Hero ('Valeera Sanguinar')>)> targeting [Player(name='Player1', hero=<Hero ('Valeera Sanguinar')>)]
[fireplace.actions]: Player1 summons [<Hero ('Valeera Sanguinar')>]
[fireplace.entity]: Empty stack, refreshing auras and processing deaths
[fireplace.entity]: Player(name='Player1', hero=<Hero ('Valeera Sanguinar')>) shuff

In [4]:
# Layout constants for the vector built by get_state().
_MAX_FIELD_MINIONS = 7    # board slots encoded per player
_MAX_HAND_CARDS = 10      # hand slots encoded for the observing player
_MINION_FEATURES = 10     # values stored per board slot
_HAND_FEATURES = 10       # values stored per hand slot
_FIELD_OFFSET = 33        # first index after the hero/mana/weapon header
_STATE_SIZE = (_FIELD_OFFSET
               + 2 * _MAX_FIELD_MINIONS * _MINION_FEATURES
               + _MAX_HAND_CARDS * _HAND_FEATURES)  # 273

# Integer card-type codes compared against `card.type` for cards in hand.
_TYPE_MINION = 4
_TYPE_SPELL = 5
_TYPE_WEAPON = 7


def _encode_weapon(weapon):
    """Return [equipped y/n, power, durability] for a weapon (or None)."""
    if weapon is None:
        return [0, 0, 0]
    # NOTE(review): `weapon.damage` is stored as the weapon's power; confirm
    # that this attribute is the attack value and not damage taken.
    return [1, weapon.damage, weapon.durability]


def _encode_minion(minion):
    """Return the 10 features describing one minion on the board."""
    return [
        1,                            # slot filled
        minion.atk,
        minion.max_health,
        minion.health,
        minion.can_attack() * 1,
        minion.has_deathrattle * 1,
        minion.divine_shield * 1,
        minion.taunt * 1,
        minion.stealthed * 1,
        minion.silenced * 1,
    ]


def _encode_hand_card(card):
    """Return the 10 features describing one card in hand."""
    is_minion = card.type == _TYPE_MINION
    # Minion-only attributes are read only for minions; other card types
    # leave those slots at zero.
    return [
        1,                                            # slot filled
        1 if is_minion else 0,
        card.atk if is_minion else 0,
        card.health if is_minion else 0,
        card.divine_shield * 1 if is_minion else 0,
        card.has_deathrattle * 1 if is_minion else 0,
        card.taunt * 1 if is_minion else 0,
        1 if card.type == _TYPE_WEAPON else 0,
        1 if card.type == _TYPE_SPELL else 0,
        card.cost,
    ]


def get_state(game, player):
    """
    Encode the game as a flat feature vector from one player's perspective.

    Args:
        game: the current game object (not read; kept for the call signature)
        player: the player from whose perspective to analyze the state
    Returns:
        A 1-D numpy int32 array with 273 entries:
            0-9      one-hot class of `player` (class ids start at 1)
            10-19    one-hot class of the opponent
            20-21    hero health: player, opponent
            22       player's hero power usable y/n
            23-24    max mana: player, opponent
            25       player's currently available mana
            26-28    player's weapon: equipped y/n, power, durability
            29-31    opponent's weapon: equipped y/n, power, durability
            32       number of cards in the opponent's hand
            33-102   player's board, 7 slots x 10 features
            103-172  opponent's board, 7 slots x 10 features
            173-272  player's hand, 10 slots x 10 features

        Every hand slot is 10 wide so that a minion's attack and health are
        stored in separate positions. Unused slots stay zero.
    """
    p1 = player
    p2 = player.opponent
    s = np.zeros(_STATE_SIZE, dtype=np.int32)

    # Classes are numbered from 1, hence the -1 when one-hot encoding.
    s[p1.hero.card_class - 1] = 1
    s[10 + p2.hero.card_class - 1] = 1

    s[20] = p1.hero.health
    s[21] = p2.hero.health
    s[22] = p1.hero.power.is_usable() * 1
    s[23] = p1.max_mana
    s[24] = p2.max_mana
    s[25] = p1.mana
    s[26:29] = _encode_weapon(p1.weapon)
    s[29:32] = _encode_weapon(p2.weapon)
    s[32] = len(p2.hand)

    # Board: the player's minions first, then the opponent's.
    i = _FIELD_OFFSET
    for owner in (p1, p2):
        minion_count = len(owner.field)
        for j in range(_MAX_FIELD_MINIONS):
            if j < minion_count:
                s[i:i + _MINION_FEATURES] = _encode_minion(owner.field[j])
            i += _MINION_FEATURES

    # Hand: only the observing player's cards are visible.
    hand_count = len(p1.hand)
    for j in range(_MAX_HAND_CARDS):
        if j < hand_count:
            s[i:i + _HAND_FEATURES] = _encode_hand_card(p1.hand[j])
        i += _HAND_FEATURES

    return s

In [5]:
# Encode the freshly started game from Player1's perspective as a flat feature vector.
s = get_state(game, game.player1)

In [23]:
class ResBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    Args:
        ni: number of input channels
        nf: number of output channels
        stride: stride of the first convolution (and of the shortcut
            projection); the second convolution always uses stride 1
        kernel_size: kernel size of both convolutions; use an odd value so
            that the `kernel_size // 2` padding keeps the main path and the
            shortcut at the same spatial size

    When `stride != 1` or `ni != nf` the input cannot be added to the
    convolution output directly, so it is first projected with a strided 1x1
    convolution followed by batch norm. Otherwise the input is added as-is.
    """

    def __init__(self, ni, nf, stride=2, kernel_size=3):
        super(ResBlock, self).__init__()

        padding = kernel_size // 2  # equals 1 for the default 3x3 kernel

        self.conv1 = nn.Conv2d(ni, nf, stride=stride,
                kernel_size=kernel_size, padding=padding)
        # Only the first convolution down-samples, so the main path and the
        # shortcut shrink by the same factor.
        self.conv2 = nn.Conv2d(nf, nf, stride=1,
                kernel_size=kernel_size, padding=padding)

        self.bn1 = nn.BatchNorm2d(nf)
        self.bn2 = nn.BatchNorm2d(nf)

        if stride != 1 or ni != nf:
            # Projection shortcut: match both channel count and spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(ni, nf, kernel_size=1, stride=stride),
                nn.BatchNorm2d(nf),
            )
        else:
            self.shortcut = None  # identity shortcut

    def forward(self, x):
        """Apply the block to a (N, ni, H, W) tensor and return the activated sum."""
        residual = x if self.shortcut is None else self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(residual + out)

In [31]:
class DQN(nn.Module):  # layers = [16, 32, 64, 128, 256]
    """Residual trunk with a policy head and a value head.

    Args:
        layers: channel widths of the trunk, e.g. [16, 32, 64, 128, 256].
            `layers[0]` is the width after the stem convolution and
            `layers[-1]` is the width fed to both heads.
        state: an array-like with a `.size` attribute; only the size is used,
            to dimension the fully connected layers of the heads.

    The heads flatten their feature maps to `4 * state.size` and
    `2 * state.size` values, so the trunk must keep the spatial size
    unchanged; every residual block is therefore built with `stride=1`.

    NOTE(review): this presumes the input is shaped (N, 4, H, W) with
    H * W == state.size - confirm against the code that builds the batch.
    NOTE(review): the blocks in `layers1` change the channel count, so
    `ResBlock` must support `ni != nf`.
    NOTE(review): the policy head emits `state.size` log-probabilities;
    confirm that this matches the intended action space.
    """

    def __init__(self, layers, state):
        super(DQN, self).__init__()
        self.state_size = state.size

        # Stem convolution; 4 input channels
        # (original note: append last two turns onto input as 3rd dim).
        self.conv1 = nn.Conv2d(4, layers[0], kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(layers[0])

        # Residual trunk. Each stage applies layers1[k] -> layers2[k] ->
        # layers3[k]; only the first block of a stage changes the width, the
        # other two must accept what it outputs.
        widths = list(zip(layers[:-1], layers[1:]))
        self.layers1 = nn.ModuleList([ResBlock(ni, nf, stride=1)
            for ni, nf in widths])
        self.layers2 = nn.ModuleList([ResBlock(nf, nf, stride=1)
            for _, nf in widths])
        self.layers3 = nn.ModuleList([ResBlock(nf, nf, stride=1)
            for _, nf in widths])

        # Policy head
        self.act_conv1 = nn.Conv2d(layers[-1], 4, kernel_size=1)
        self.act_bn1 = nn.BatchNorm2d(4)
        self.act_fc1 = nn.Linear(4 * self.state_size, self.state_size)

        # Value head
        self.val_conv1 = nn.Conv2d(layers[-1], 2, kernel_size=1)
        self.val_bn1 = nn.BatchNorm2d(2)
        self.val_fc1 = nn.Linear(2 * self.state_size, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        """Run the trunk and both heads.

        Returns:
            x_act: log-probabilities over `state_size` outputs, shape (N, state_size)
            x_val: tanh-squashed value estimate in [-1, 1], shape (N, 1)
        """
        # Stem: conv -> batch norm -> ReLU
        x = F.relu(self.bn1(self.conv1(state_input)))

        # Residual trunk, one stage (three blocks) at a time
        for block1, block2, block3 in zip(self.layers1, self.layers2, self.layers3):
            x = block3(block2(block1(x)))

        # Policy head (action log-probabilities)
        x_act = F.relu(self.act_bn1(self.act_conv1(x)))
        x_act = x_act.view(-1, 4 * self.state_size)
        # Normalize across the action dimension, not across the batch.
        x_act = F.log_softmax(self.act_fc1(x_act), dim=1)

        # Value head (score of the board state)
        x_val = F.relu(self.val_bn1(self.val_conv1(x)))
        x_val = x_val.view(-1, 2 * self.state_size)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = torch.tanh(self.val_fc2(x_val))

        return x_act, x_val

In [32]:
class PolicyEvaluator():
    """Owns a DQN policy/value network together with its Adam optimizer.

    Args:
        layers: channel widths forwarded to DQN
        state: array-like with a `.size` attribute, forwarded to DQN
        model_file: optional path of a saved state dict to load into the network

    The network runs on the GPU when CUDA is available and on the CPU otherwise.
    """

    def __init__(self, layers, state, model_file=None):
        self.state_size = state.size
        self.weight_decay = 1e-4
        self.use_cuda = torch.cuda.is_available()

        self.policy_value_net = DQN(layers, state)
        if self.use_cuda:
            self.policy_value_net = self.policy_value_net.cuda()
        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.weight_decay)

        if model_file:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            # Tensors are mapped to CPU storage first so a checkpoint saved on
            # a GPU machine also loads without CUDA; load_state_dict then
            # copies them onto whatever device the network lives on.
            net_params = torch.load(model_file,
                                    map_location=lambda storage, loc: storage)
            self.policy_value_net.load_state_dict(net_params)

    def policy_value(self, state_batch):
        """Evaluate a batch of states.

        Args:
            state_batch: data accepted by `torch.FloatTensor(...)`
        Returns:
            act_probs: numpy array of action probabilities (exp of the
                network's log-probabilities)
            value: numpy array with the network's value output
        """
        state_batch = torch.FloatTensor(state_batch)
        if self.use_cuda:
            state_batch = state_batch.cuda()
        state_batch = Variable(state_batch)
        log_act_probs, value = self.policy_value_net(state_batch)
        act_probs = np.exp(log_act_probs.data.cpu().numpy())
        return act_probs, value.data.cpu().numpy()

In [33]:
# Alias the encoded feature vector under the name used by the model cells.
state = s

In [34]:
# Instantiate the network (no .cuda() call here); DQN reads only `state.size` from `state`.
net = DQN([16, 32, 64, 128, 256], state)