In [1]:
# Run this cell to mount your Google Drive.

from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
# Root folder of the thesis project on the mounted Google Drive.
# NOTE(review): the `%cd` lines in the imports cell use '/content/drive/My Drive/...'
# (with a space) while this path uses 'MyDrive' -- confirm both resolve to the same folder.
pth = '/content/drive/MyDrive/Colab Notebooks/Thesis'

In [3]:
# Human-readable label for each of the 110 action IDs.
#   0-5    : the six non-card actions
#   6-57   : one label per card (the discard range of the action table)
#   58-109 : the same 52 card labels again (the knock range of the action table)
# Cards are ordered suit-major (S, H, D, C) and A..K within a suit.
_RANK_LABELS = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
_SUIT_LABELS = ['S', 'H', 'D', 'C']
_CARD_LABELS = [rank + suit for suit in _SUIT_LABELS for rank in _RANK_LABELS]

all_classes = ['SP0', 'SP1', 'Draw', 'Pickup', 'DH', 'GIN'] + _CARD_LABELS + _CARD_LABELS


# Gin Rummy

## Imports

In [4]:
#-------------------------------------------------------------------------------
# The following code was originally written by Todd Neller in Java.
# It was translated into Python by Anthony Hein.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# A class for modeling a game of Gin Rummy
# @author Todd W. Neller
# @version 1.0
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# Copyright (C) 2020 Todd Neller
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# Information about the GNU General Public License is available online at:
#   http://www.gnu.org/licenses/
# To receive a copy of the GNU General Public License, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.
#-------------------------------------------------------------------------------

import random
import time
import numpy as np
import os
import torch

%cd /content/drive/My Drive/Colab Notebooks/Thesis/GinRummy

from Deck import Deck
from GinRummyUtil import GinRummyUtil
from SimpleGinRummyPlayer import SimpleGinRummyPlayer

%cd /content/drive/My Drive/Colab Notebooks/Thesis/SupervisedLearning

from models import *

%cd /content/drive/My Drive/Colab Notebooks/Thesis
#-------------------------------------------------------------------------------

# TRACKING
# Plane (5x52)      Feature
# 0	 currHand       the cards in current player's hand
# 1	 topCard        the top card of the discard pile
# 2	 deadCard       the dead cards: cards in discard pile (excluding the top card)
# 3	 oppCard        opponent known cards: cards picked up from discard pile, but not discarded
# 4	 unknownCard    the unknown cards: cards in stockpile or in opponent hand (but not known)

# Action ID         Action
# 0	                score_player_0_action
# 1	                score_player_1_action
# 2	                draw_card_action
# 3	                pick_up_discard_action
# 4	                declare_dead_hand_action
# 5	                gin_action
# 6 - 57	        discard_action
# 58 - 109	        knock_action

# Knock_bin
# Action ID         Action
# 0	                No Knock
# 1	                Knock

def one_hot(cards):
    """Encode a collection of cards as a length-52 indicator vector.

    Args:
        cards: iterable of objects exposing ``getId()``; each returned id is
            used directly as the index to switch on.

    Returns:
        numpy.ndarray of 52 floats holding 1 at every id present in ``cards``
        and 0 everywhere else.
    """
    encoding = np.zeros(52)
    card_ids = [card.getId() for card in cards]
    if card_ids:
        # Integer-list indexing sets all listed positions at once.
        encoding[card_ids] = 1
    return encoding

def un_one_hot(arr):
    """Translate an indicator vector back into card names.

    Position ``i`` maps to rank ``i % 13`` (A, 2, ..., 9, T, J, Q, K) and suit
    ``i // 13`` (S, H, D, C).

    Args:
        arr: sequence of numbers; every non-zero entry marks a card.

    Returns:
        List of two-character names such as ``'AS'`` or ``'TD'``, in index order.
    """
    ranks = "A23456789TJQK"
    suits = "SHDC"
    return [
        ranks[position % 13] + suits[position // 13]
        for position, flag in enumerate(arr)
        if flag != 0
    ]

#-------------------------------------------------------------------------------

/content/drive/My Drive/Colab Notebooks/Thesis/GinRummy
/content/drive/My Drive/Colab Notebooks/Thesis/SupervisedLearning
/content/drive/My Drive/Colab Notebooks/Thesis


In [5]:
class EstimatorNetwork(nn.Module):
    ''' Fully connected network mapping a flattened game state to a
        probability distribution over the action IDs.

        Every hidden layer is a Linear layer followed by a Sigmoid (an older
        version used tanh); the output layer is a Linear layer followed by a
        Softmax over the action dimension. All in/out are torch.tensor.
    '''

    def __init__(self, mlp_layers=None, batch_norm=False):
        ''' Initialize the network
        Args:
            mlp_layers (list, optional): output size of each hidden fc layer.
                None (the default) or an empty sequence builds a network with
                no hidden layer, i.e. a single Linear + Softmax.
            batch_norm (bool): if True, apply BatchNorm1d to the flattened
                input before the first Linear layer.

        The input width (260) and the number of outputs (110) are fixed and
        exposed as ``self.state_shape`` and ``self.action_num``.
        '''
        super(EstimatorNetwork, self).__init__()

        self.action_num = 110
        self.state_shape = 260
        # Normalise to a fresh list: the default None used to crash on the
        # list concatenation below, and a tuple would have done the same.
        self.mlp_layers = [] if mlp_layers is None else list(mlp_layers)
        self.batch_norm = batch_norm

        # build the network; cast to a plain int so layer sizes are not numpy scalars
        layer_dims = [int(np.prod(self.state_shape))] + self.mlp_layers
        fc = [nn.Flatten()]
        if batch_norm:
            fc.append(nn.BatchNorm1d(layer_dims[0]))
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            fc.append(nn.Linear(in_dim, out_dim, bias=True))
            fc.append(nn.Sigmoid())
        fc.append(nn.Linear(layer_dims[-1], self.action_num, bias=True))
        fc.append(nn.Softmax(dim=1))
        self.fc_layers = nn.Sequential(*fc)

    def forward(self, s):
        ''' Predict action probabilities
        Args:
            s  (Tensor): (batch, state_shape)
        Returns:
            Tensor: (batch, action_num); each row is a softmax output and
            therefore sums to 1.
        '''
        return self.fc_layers(s)

## MLPGinRummyPlayer

In [6]:
# -------------------------------------------------------------------------------
#  MLPGinRummyPlayer
#
#  This estimation will be calculated using a Multilayer Perceptron trained on the
#  SimpleGinRummyPlayer written
#  by Calvin Tan.
#
#  @author Calvin Tan
#  @version 1.0
# -------------------------------------------------------------------------------

# -------------------------------------------------------------------------------
# The following code was originally written by Todd Neller in Java.
# It was translated into Python by May Jiang.
# -------------------------------------------------------------------------------

# -------------------------------------------------------------------------------
# Copyright (C) 2020 Todd Neller
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# Information about the GNU General Public License is available online at:
#   http://www.gnu.org/licenses/
# To receive a copy of the GNU General Public License, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.
# -------------------------------------------------------------------------------

from typing import List, TypeVar
from random import randint
from GinRummyUtil import GinRummyUtil
from GinRummyPlayer import GinRummyPlayer

# Import MLP Models
# from SupervisedLearning.models import *

Card = TypeVar('Card')

class MLPGinRummyPlayer(GinRummyPlayer):
    """Gin Rummy player whose draw, discard and knock decisions come from a model.

    The model is attached with ``loadModel`` and is evaluated on the state
    vector most recently supplied through ``updateStates``. Its output is
    indexed by action ID:

        2        draw from the stock
        3        pick up the face-up card
        6-57     discard card (ID - 6)
        58-109   knock, discarding card (ID - 58)

    The first 52 entries of the state vector are read as the indicator vector
    of this player's hand.

    NOTE: model inputs are moved to a module-level ``device`` that is not
    defined in this cell; it must exist before the player is used.
    """

    def loadModel(self, model_pt):
        """Attach the model used for every decision."""
        print('Load Model')
        self.model = model_pt

    def setVerbose(self, verbose):
        """Turn debug printing on or off.

        ``startGame`` resets the flag to False, so call this after ``startGame``.
        """
        self.playVerbose = verbose

    def updateStates(self, states):
        """Store the state vector the next model evaluation will use."""
        if self.playVerbose:
            print('Update States')
        self.state = states

    def knockAction(self) -> bool:
        """Return whether the most recent ``getDiscard`` chose a knock action."""
        return self.knock

    def _actionScores(self):
        """Evaluate the model on the current state and return a 1-D numpy array of scores."""
        state = np.expand_dims(self.state, axis=0)
        state = torch.from_numpy(state).type(torch.FloatTensor).to(device)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            action = self.model(state)
        # .cpu() is required before .numpy() whenever `device` is a GPU.
        return action.detach().cpu().numpy().reshape(-1)

    def _printDiscardDebug(self, action, discardScores, knockScores):
        """Print the hand (melds first) and the model's preferred discard/knock cards."""
        unmeldedCards = self.cards.copy()
        bestMelds = GinRummyUtil.cardsToBestMeldSets(unmeldedCards)
        if len(bestMelds) > 0:
            melds = bestMelds[0]
            for meld in melds:
                for card in meld:
                    unmeldedCards.remove(card)
            melds.extend(unmeldedCards)
        else:
            melds = unmeldedCards
        print('Current Hand:', melds)

        best = np.argmax(action)
        # Knock IDs start at 58 (inclusive); the old `> 58` test labelled
        # action 58 as a discard.
        label = 'Knock' if best >= 58 else 'Discard'
        print(label, all_classes[best], '| D:', Deck.getCard(np.argmax(discardScores)), '| K:', Deck.getCard(np.argmax(knockScores)), '|', best)
        print('MAX:{:.4f}, {:.4f}'.format(discardScores.max(), knockScores.max()))

    # Inform player of 0-based player number (0/1), starting player number (0/1), and dealt cards
    def startGame(self, playerNum: int, startingPlayerNum: int, cards: List[Card]) -> None:
        """Reset all per-round bookkeeping, including ``state``, ``knock`` and ``playVerbose``."""
        self.playerNum = playerNum
        self.startingPlayerNum = startingPlayerNum
        self.cards = list(cards)
        self.opponentKnocked = False
        self.drawDiscardBitstrings = [] # long[], or List[int]
        self.faceUpCard = None
        self.faceUpCardBool = False
        self.drawnCard = None
        self.state = None
        self.knock = False
        self.playVerbose = False

    # Return whether or not player will draw the given face-up card on the draw pile.
    def willDrawFaceUpCard(self, card: Card) -> bool:
        """Pick up the face-up card iff the model scores action 3 above action 2.

        Replaces the earlier heuristic that picked the card up only when it
        would become part of a meld. The decision is remembered in
        ``faceUpCardBool`` so ``getDiscard`` can forbid discarding that card.
        """
        self.faceUpCard = card
        # BPBD, either draw(2)->False or pickup(3)->True
        action = self._actionScores()
        if self.playVerbose:
            print('Draw new card:', action[2])
            print('Pickup from discard:', action[3])
        self.faceUpCardBool = bool(action[3] > action[2])
        return self.faceUpCardBool

    # Report that the given player has drawn a given card and, if known, what the card is.
    # If the card is unknown because it is drawn from the face-down draw pile, the drawnCard is null.
    # Note that a player that returns false for willDrawFaceUpCard will learn of their face-down draw from this method.
    def reportDraw(self, playerNum: int, drawnCard: Card) -> None:
        # Ignore other player draws.  Add to cards if playerNum is this player.
        if playerNum == self.playerNum:
            self.cards.append(drawnCard)
            self.drawnCard = drawnCard

    # Get the player's discarded card.  If you took the top card from the discard pile,
    # you must discard a different card.
    # If this is not a card in the player's possession, the player forfeits the game.
    # @return the player's chosen card for discarding
    def getDiscard(self) -> Card:
        """Choose the card to discard and record whether the model wants to knock.

        Replaces the earlier heuristic that discarded a random card leaving
        minimal deadwood. Model scores for the discard range (6-57) and the
        knock range (58-109) are masked by the hand taken from the state
        vector. The card with the highest masked score overall is returned;
        ``self.knock`` is True unless the best discard score is strictly
        greater than the best knock score.
        """
        # APBD, either discard or knock...
        # Copy the hand plane so masking does not alter the stored state.
        currHand = np.array(self.state[0:52])
        # disallow discarding PickUp FaceUp/Discarded Card
        if self.faceUpCardBool:
            currHand[self.drawnCard.getId()] = 0

        action = self._actionScores()
        discardScores = currHand * action[6:58]
        knockScores = currHand * action[58:110]
        discardMax = discardScores.max()
        knockMax = knockScores.max()

        if self.playVerbose:
            self._printDiscardDebug(action, discardScores, knockScores)

        if discardMax > knockMax:
            if self.playVerbose:
                print('Discard Action')
            self.knock = False
            return Deck.getCard(np.argmax(discardScores))

        if self.playVerbose:
            print('Knock Action')
        self.knock = True
        return Deck.getCard(np.argmax(knockScores))

    # Report that the given player has discarded a given card.
    def reportDiscard(self, playerNum: int, discardedCard: Card) -> None:
        # Ignore other player discards.  Remove from cards if playerNum is this player.
        if playerNum == self.playerNum:
            self.cards.remove(discardedCard)

    # At the end of each turn, this method is called and the player that cannot (or will not) end the round will return a null value.
    # However, the first player to "knock" (that is, end the round), and then their opponent, will return an ArrayList of ArrayLists of melded cards.
    # All other cards are counted as "deadwood", unless they can be laid off (added to) the knocking player's melds.
    # When final melds have been reported for the other player, a player should return their final melds for the round.
    # @return null if continuing play and opponent hasn't melded, or an ArrayList of ArrayLists of melded cards.
    def getFinalMelds(self) -> List[List[Card]]:
        # Check if deadwood of maximal meld is low enough to go out.
        # NOTE(review): this check does not consult self.knock; whether the
        # model's knock decision is enforced is up to the caller -- confirm.
        bestMeldSets = GinRummyUtil.cardsToBestMeldSets(self.cards) # List[List[List[Card]]]
        if not self.opponentKnocked and (len(bestMeldSets) == 0 or \
            GinRummyUtil.getDeadwoodPoints1(bestMeldSets[0], self.cards) > \
            GinRummyUtil.MAX_DEADWOOD):
            return None
        if len(bestMeldSets) == 0:
            return []
        return bestMeldSets[randint(0, len(bestMeldSets)-1)]

    # When an player has ended play and formed melds, the melds (and deadwood) are reported to both players.
    def reportFinalMelds(self, playerNum: int, melds: List[List[Card]]) -> None:
        # The melds themselves are ignored; only the fact that the opponent went out is recorded.
        if playerNum != self.playerNum:
            self.opponentKnocked = True

    # Report current player scores, indexed by 0-based player number.
    def reportScores(self, scores: List[int]) -> None:
        # Ignored by this player.
        return

    # Report layoff actions.
    def reportLayoff(self, playerNum: int, layoffCard: Card, opponentMeld: List[Card]) -> None:
        # Ignored by this player.
        return

    # Report the final hands of players.
    def reportFinalHand(self, playerNum: int, hand: List[Card]) -> None:
        # Ignored by this player.
        return

## RandGinRummyPlayer

In [7]:
# -------------------------------------------------------------------------------
#  RandGinRummyPlayer
#
#  This player decides at random whether to draw the face-up card and which
#  card to discard; it does not use a Multilayer Perceptron.
#  Written by Calvin Tan.
#
#  @author Calvin Tan
#  @version 1.0
# -------------------------------------------------------------------------------

# -------------------------------------------------------------------------------
# The following code was originally written by Todd Neller in Java.
# It was translated into Python by May Jiang.
# -------------------------------------------------------------------------------

# -------------------------------------------------------------------------------
# Copyright (C) 2020 Todd Neller
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# Information about the GNU General Public License is available online at:
#   http://www.gnu.org/licenses/
# To receive a copy of the GNU General Public License, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.
# -------------------------------------------------------------------------------

from typing import List, TypeVar
from random import randint
from GinRummyUtil import GinRummyUtil
from GinRummyPlayer import GinRummyPlayer
import random

# Import MLP Models
# from SupervisedLearning.models import *

Card = TypeVar('Card')

class RandGinRummyPlayer(GinRummyPlayer):
    """Gin Rummy player that draws and discards at random.

    The draw decision is a coin flip and the discard is a uniformly random
    card from the hand other than the face-up card last offered. Final melds
    are reported with the same rule as the other players in this notebook.
    """

    # Inform player of 0-based player number (0/1), starting player number (0/1), and dealt cards
    def startGame(self, playerNum: int, startingPlayerNum: int, cards: List[Card]) -> None:
        self.playerNum = playerNum
        self.startingPlayerNum = startingPlayerNum
        self.cards = list(cards)
        self.opponentKnocked = False
        self.drawDiscardBitstrings = [] # long[], or List[int]
        self.faceUpCard = None
        self.drawnCard = None
        self.state = None

    # Return whether or not player will draw the given face-up card on the draw pile.
    def willDrawFaceUpCard(self, card: Card) -> bool:
        # Remember the offered card so getDiscard can refuse to discard it.
        self.faceUpCard = card
        # Coin flip: 0 -> take the face-up card, 1 -> draw from the stock.
        return random.randint(0, 1) == 0

    # Report that the given player has drawn a given card and, if known, what the card is.
    # If the card is unknown because it is drawn from the face-down draw pile, the drawnCard is null.
    # Note that a player that returns false for willDrawFaceUpCard will learn of their face-down draw from this method.
    def reportDraw(self, playerNum: int, drawnCard: Card) -> None:
        # Ignore other player draws.  Add to cards if playerNum is this player.
        if playerNum == self.playerNum:
            self.cards.append(drawnCard)
            self.drawnCard = drawnCard

    # Get the player's discarded card.  If you took the top card from the discard pile,
    # you must discard a different card.
    # If this is not a card in the player's possession, the player forfeits the game.
    # @return the player's chosen card for discarding
    def getDiscard(self) -> Card:
        # Rejection sampling: redraw until the pick is not the face-up card
        # that was last offered to this player.
        while True:
            discCard = self.cards[random.randint(0, len(self.cards)-1)]
            if not (discCard == self.faceUpCard):
                return discCard

    # Report that the given player has discarded a given card.
    def reportDiscard(self, playerNum: int, discardedCard: Card) -> None:
        # Ignore other player discards.  Remove from cards if playerNum is this player.
        if playerNum == self.playerNum:
            self.cards.remove(discardedCard)

    # At the end of each turn, this method is called and the player that cannot (or will not) end the round will return a null value.
    # However, the first player to "knock" (that is, end the round), and then their opponent, will return an ArrayList of ArrayLists of melded cards.
    # All other cards are counted as "deadwood", unless they can be laid off (added to) the knocking player's melds.
    # When final melds have been reported for the other player, a player should return their final melds for the round.
    # @return null if continuing play and opponent hasn't melded, or an ArrayList of ArrayLists of melded cards.
    def getFinalMelds(self) -> List[List[Card]]:
        # Check if deadwood of maximal meld is low enough to go out.
        bestMeldSets = GinRummyUtil.cardsToBestMeldSets(self.cards) # List[List[List[Card]]]
        if not self.opponentKnocked and (len(bestMeldSets) == 0 or \
            GinRummyUtil.getDeadwoodPoints1(bestMeldSets[0], self.cards) > \
            GinRummyUtil.MAX_DEADWOOD):
            return None
        if len(bestMeldSets) == 0:
            return []
        return bestMeldSets[randint(0, len(bestMeldSets)-1)]

    # When an player has ended play and formed melds, the melds (and deadwood) are reported to both players.
    def reportFinalMelds(self, playerNum: int, melds: List[List[Card]]) -> None:
        # The melds themselves are ignored; only the fact that the opponent went out is recorded.
        if playerNum != self.playerNum:
            self.opponentKnocked = True

    # Report current player scores, indexed by 0-based player number.
    def reportScores(self, scores: List[int]) -> None:
        # Ignored by this player.
        return

    # Report layoff actions.
    def reportLayoff(self, playerNum: int, layoffCard: Card, opponentMeld: List[Card]) -> None:
        # Ignored by this player.
        return

    # Report the final hands of players.
    def reportFinalHand(self, playerNum: int, hand: List[Card]) -> None:
        # Ignored by this player.
        return

## Game Definition

In [8]:
class GinRummyGame:
    """Referee for a two-player game of Gin Rummy.

    Holds the two player agents, deals and runs rounds until one score reaches
    GinRummyUtil.GOAL_SCORE, and, for MLPGinRummyPlayer agents, builds an
    observation vector (self.states) and pushes it to the agent before the
    face-up draw decision (when one is offered) and before each discard.
    """

    # Hand size (before and after turn). After draw and before discard there is one extra card.
    HAND_SIZE = 10;

    # Whether or not to print information during game play
    playVerbose = False;

    # Two Gin Rummy players numbered according to their array index.
    # Class-level default only; __init__ replaces it with a per-instance list.
    players = [];

    # Set whether or not there is to be printed output during gameplay.
    def setPlayVerbose(self, playVerbose):
        """Store the verbosity flag on this game instance."""
        self.playVerbose = playVerbose
    
    #-------------------------------- updateState --------------------------------#
    # 2020-12-20: Define a method to append states
    # 2021-01-16: modified append state to work for either player (0 or 1)
    def updateState(self, currentPlayer, discards, oppCard):
        """Rebuild self.states, the observation vector seen by `currentPlayer`.

        Five planes are stacked and flattened, in this order: the player's own
        hand, the top discard, the rest of the discard pile, the cards
        attributed to the opponent (`oppCard`), and everything left over.

        currentPlayer -- index (0 or 1) into self.players
        discards      -- discard pile; the last element is the top card
        oppCard       -- length-52 0/1 array of cards this player has seen the
                         opponent pick up (and not yet discard)
        """
        # presumably one_hot returns a length-52 0/1 vector indexed like Card.getId() — TODO confirm
        currHand = one_hot(self.players[currentPlayer].cards)
        topCard = np.zeros(52)
        if len(discards) > 0:
            topCard[discards[-1].getId()] = 1
        # Every discard except the top one is marked as dead.
        deadCard = np.zeros(52)
        for d in range(len(discards) - 1):
            deadCard[discards[d].getId()] = 1
        # Whatever is not accounted for by the four planes above.
        unknownCard = np.ones(52) - currHand - topCard - deadCard - oppCard
        self.states = np.array([currHand, topCard, deadCard, oppCard, unknownCard]).flatten()
    #------------------------------------------------------------------------------#

    # Create a self with two given players
    def __init__(self, player0, player1):
        """Create a game in which `player0` is player 0 and `player1` is player 1."""
        self.players = []
        self.players.extend([player0, player1])

    # Play a game of Gin Rummy and return the winning player number 0 or 1.
    # @return the winning player number 0 or 1

    def play(self):
        """Play rounds until a score reaches GinRummyUtil.GOAL_SCORE.

        Returns the winning player's index (0 or 1). The method returns early
        with the other player's index when a player discards illegally, melds
        illegally, or knocks with more than GinRummyUtil.MAX_DEADWOOD deadwood.
        """
        scores = [0, 0]
        hands = []
        hands.extend([[], []])

        startingPlayer = random.randrange(2);

        # while game not over
        while scores[0] < GinRummyUtil.GOAL_SCORE and scores[1] < GinRummyUtil.GOAL_SCORE:

            currentPlayer = startingPlayer
            opponent = (1 if currentPlayer == 0 else 0)

            # get shuffled deck and deal cards
            deck = Deck.getShuffle(random.randrange(10 ** 8))
            hands[0] = []
            hands[1] = []
            # Deal alternately, one card at a time, until both hands hold HAND_SIZE cards.
            for i in range(2 * self.HAND_SIZE):
                hands[i % 2] += [deck.pop()]
            for i in range(2):
                self.players[i].startGame(i, startingPlayer, hands[i]);
                if self.playVerbose:
                    print("Player %d is dealt %s.\n" % (i, hands[i]))
            if self.playVerbose:
                print("Player %d starts.\n" % (startingPlayer))
            discards = []
            discards.append(deck.pop())
            if self.playVerbose:
                print("The initial face up card is %s.\n" % (discards[len(discards) - 1]))
            firstFaceUpCard = discards[len(discards) - 1]
            turnsTaken = 0
            knockMelds = None

            # 11/25 - Initial state, prior to any cards
            # 1/16 - Initialize oppCard to be two dimensional to track both players as opponents
            # oppCard[p] is the plane handed to player p: cards p has seen the other player pick up.
            oppCard = []
            oppCard.extend([np.zeros(52), np.zeros(52)])

            for i in range(2):
                if isinstance(self.players[i], MLPGinRummyPlayer):
                    self.players[i].setVerbose(self.playVerbose)

            # while the deck has more than two cards remaining, play round
            while len(deck) > 2:
#-------------------------------------------------------------- BPBD --------------------------------------------------------------#
                drawFaceUp = False
                faceUpCard = discards[len(discards) - 1]

                # offer draw face-up iff not 3rd turn with first face up card (decline automatically in that case)
                if not (turnsTaken == 2 and faceUpCard == firstFaceUpCard):

                    #------------------------------------ DRAW ------------------------------------#
                    # 2020-12-01  -  Track states BEFORE the player PICKUP BEFORE player DISCARDS (track_bpbd)
                    # 2021-01-16  -  Track for both players instead of just player 0
                    # Action      -  PickUp from Discard(FaceUp) or Deck (Unknown)
                    # State       -  BPBD -> APBD

                    self.updateState(currentPlayer,discards,oppCard[currentPlayer])

                    #------------------------------------------------------------------------------#

                    # 2021-01-16  -  Update player with current states
                    if isinstance(self.players[currentPlayer], MLPGinRummyPlayer):
                        self.players[currentPlayer].updateStates(self.states)

                    # Ask the current player whether to take the face-up card.
                    drawFaceUp = self.players[currentPlayer].willDrawFaceUpCard(faceUpCard)
                    
                    if self.playVerbose and not drawFaceUp and faceUpCard == firstFaceUpCard and turnsTaken < 2:
                        print("Player %d declines %s.\n" % (currentPlayer, firstFaceUpCard))

                # Skip the rest of the turn only when the initial face-up card was
                # declined during one of the first two turns.
                if not (not drawFaceUp and turnsTaken < 2 and faceUpCard == firstFaceUpCard):

                    # continue with turn if not initial declined option
                    if self.playVerbose:
                        if drawFaceUp:
                            print('drawFaceUp (Pickup discarded card)')
                        else:
                            print('Draw from deck')
                    drawCard = discards.pop() if drawFaceUp else deck.pop()
                    # The drawing player always sees the card; the other player only sees a face-up draw.
                    for i in range(2):
                        to_report = drawCard if i == currentPlayer or drawFaceUp else None
                        self.players[i].reportDraw(currentPlayer, to_report)

                    if self.playVerbose:
                        print("Player %d draws %s.\n" % (currentPlayer, drawCard))
                    hands[currentPlayer].append(drawCard)
#-------------------------------------------------------------- APBD --------------------------------------------------------------#
                    
                    # State AFTER the pickup and BEFORE the discard.
                    self.updateState(currentPlayer,discards,oppCard[currentPlayer])
                    
                    # 2021-01-16  -  Update player with current states
                    if isinstance(self.players[currentPlayer], MLPGinRummyPlayer):
                        self.players[currentPlayer].updateStates(self.states)

                    discardCard = self.players[currentPlayer].getDiscard()

                    # 2021-01-16  -  Track for both players instead of just player 0
                    # Track opponent pickup and discard after each discard 
                    # NOTE(review): this bookkeeping runs before the legality check below,
                    # so it dereferences discardCard even when the discard is illegal.

                    # Set discarded card to 0 (in case discarded card was seen)
                    oppCard[1 - currentPlayer][discardCard.getId()] = 0
                    if drawFaceUp: # if opponent draws TopCard from discard
                        oppCard[1 - currentPlayer][drawCard.getId()] = 1

                    # A discard must come from the hand and must not be the face-up card considered this turn.
                    if not discardCard in hands[currentPlayer] or discardCard == faceUpCard:
                        print("Player %d discards %s illegally and forfeits.\n" % (currentPlayer, discardCard))
                        return opponent;
                    hands[currentPlayer].remove(discardCard)
                    for i in range(2):
                        self.players[i].reportDiscard(currentPlayer, discardCard)                    
                    if self.playVerbose:
                        print("Player %d discards %s.\n" % (currentPlayer, discardCard))
                    discards.append(discardCard)

                    # Verbose only: show the hand grouped into its first best meld set plus deadwood.
                    if self.playVerbose:
                        unmeldedCards = hands[currentPlayer].copy()
                        bestMelds = GinRummyUtil.cardsToBestMeldSets(unmeldedCards)
                        if len(bestMelds) == 0:
                            print("Player %d has %s with %d deadwood.\n" % (currentPlayer, unmeldedCards, GinRummyUtil.getDeadwoodPoints3(unmeldedCards)))
                        else:
                            melds = bestMelds[0]
                            for meld in melds:
                                for card in meld:
                                    unmeldedCards.remove(card)
                            melds.extend(unmeldedCards)
                            print("Player %d has %s with %d deadwood.\n" % (currentPlayer, melds, GinRummyUtil.getDeadwoodPoints3(unmeldedCards)))

#-------------------------------------------------------------- KNOCK --------------------------------------------------------------#
                    # CHECK FOR KNOCK
                    knockMelds = self.players[currentPlayer].getFinalMelds()
                    if knockMelds != None:
                        # 2021-01-16  -  Check if MLPGinRummyPlayer knocks
                        # An MLP player additionally decides via knockAction() whether to knock;
                        # any other player knocks as soon as getFinalMelds() returns melds.
                        if isinstance(self.players[currentPlayer], MLPGinRummyPlayer):
                            knock = self.players[currentPlayer].knockAction()
                            if self.playVerbose:
                                print(knock)
                            if knock:
                                break
                        else:
                            break
                    
                turnsTaken += 1
                # Hand the turn over only while the round continues. Once the deck has
                # run down, currentPlayer/opponent keep pointing at the last mover so the
                # knock handling after the loop refers to the right player.
                if len(deck) > 2:
                    currentPlayer = 1 if currentPlayer == 0 else 0
                    opponent = 1 if currentPlayer == 0 else 0

            # NOTE(review): knockMelds holds the value from the most recent completed turn.
            # If the draw pile runs down on a turn where an MLPGinRummyPlayer returned melds
            # but knockAction() was falsy, knockMelds is still non-None here and the round is
            # scored as a knock by that player. Confirm this is intended; an earlier variant
            # of this condition also required len(deck) > 2.
            if knockMelds != None:
                # round didn't end due to non-knocking and 2 cards remaining in draw pile
                # check legality of knocking meld
                handBitstring = GinRummyUtil.cardsToBitstring(hands[currentPlayer])
                unmelded = handBitstring
                for meld in knockMelds:
                    meldBitstring = GinRummyUtil.cardsToBitstring(meld)
                    if (not meldBitstring in GinRummyUtil.getAllMeldBitstrings()) or ((meldBitstring & unmelded) != meldBitstring):
                        # non-meld or meld not in hand
                        print("Player %d melds %s illegally and forfeits.\n" % (currentPlayer, knockMelds))
                        return opponent
                    unmelded &= ~meldBitstring # remove successfully melded cards from the unmelded set

                # compute knocking deadwood
                knockingDeadwood = GinRummyUtil.getDeadwoodPoints1(knockMelds, hands[currentPlayer])
                if knockingDeadwood > GinRummyUtil.MAX_DEADWOOD:
                    print("Player %d melds %s with greater than %d deadwood and forfeits.\n" % (currentPlayer, knockMelds, knockingDeadwood))
                    return opponent

                # Report a copy so players cannot alter the referee's meld lists.
                meldsCopy = []
                for meld in knockMelds:
                    meldsCopy.append(meld.copy())
                for i in range(2):
                    self.players[i].reportFinalMelds(currentPlayer, meldsCopy)
                if self.playVerbose:
                    if knockingDeadwood > 0:
                        print("Player %d melds %s with %d deadwood from %s.\n" % (currentPlayer, knockMelds, knockingDeadwood, GinRummyUtil.bitstringToCards(unmelded)))
                    else:
                        print("Player %d goes gin with melds %s.\n" % (currentPlayer, knockMelds))

                # get opponent meld
                # NOTE(review): assumes the opponent's getFinalMelds() returns a list (not None)
                # once the knock has been reported; iterating None below would raise.
                opponentMelds = self.players[opponent].getFinalMelds();
                meldsCopy = []
                for meld in opponentMelds:
                    meldsCopy.append(meld.copy())
                for i in range(2):
                    self.players[i].reportFinalMelds(opponent, meldsCopy)

                # check legality of opponent meld
                opponentHandBitstring = GinRummyUtil.cardsToBitstring(hands[opponent])
                opponentUnmelded = opponentHandBitstring
                for meld in opponentMelds:
                    meldBitstring = GinRummyUtil.cardsToBitstring(meld)
                    if (meldBitstring not in GinRummyUtil.getAllMeldBitstrings()) or ((meldBitstring & opponentUnmelded) != meldBitstring):
                        # non-meld or meld not in hand
                        print("Player %d melds %s illegally and forfeits.\n" % (opponent, opponentMelds))
                        return currentPlayer
                    opponentUnmelded &= ~meldBitstring # remove successfully melded cards from the unmelded set

                if self.playVerbose:
                    print("Player %d melds %s.\n" % (opponent, opponentMelds))

                # lay off on knocking meld (if not gin)
                unmeldedCards = GinRummyUtil.bitstringToCards(opponentUnmelded)
                if knockingDeadwood > 0:
                    # knocking player didn't go gin
                    cardWasLaidOff = False
                    # Repeat until a full pass over the unmelded cards lays nothing off;
                    # each layoff extends a knocking meld, which may enable further layoffs.
                    while True:
                        # attempt to lay each card off
                        cardWasLaidOff = False
                        layOffCard = None
                        layOffMeld = None
                        for card in unmeldedCards:
                            for meld in knockMelds:
                                newMeld = meld.copy()
                                newMeld.append(card)
                                newMeldBitstring = GinRummyUtil.cardsToBitstring(newMeld)
                                if newMeldBitstring in GinRummyUtil.getAllMeldBitstrings():
                                    layOffCard = card
                                    layOffMeld = meld
                                    break
                            if layOffCard != None:
                                if self.playVerbose:
                                    print("Player %d lays off %s on %s.\n" % (opponent, layOffCard, layOffMeld))
                                for i in range(2):
                                    self.players[i].reportLayoff(opponent, layOffCard, layOffMeld.copy())
                                unmeldedCards.remove(layOffCard)
                                # Extends the meld inside knockMelds in place.
                                layOffMeld.append(layOffCard)
                                cardWasLaidOff = True
                                break
                        if not cardWasLaidOff:
                            break

                opponentDeadwood = 0
                for card in unmeldedCards:
                    opponentDeadwood += GinRummyUtil.getDeadwoodPoints2(card)
                if self.playVerbose:
                    print("Player %d has %d deadwood with %s\n" % (opponent, opponentDeadwood, unmeldedCards))
                # compare deadwood and compute new scores
                if knockingDeadwood == 0:
                    # gin round win
                    scores[currentPlayer] += GinRummyUtil.GIN_BONUS + opponentDeadwood
                    if self.playVerbose:
                        print("Player %d scores the gin bonus of %d plus opponent deadwood %d for %d total points.\n" % \
                        (currentPlayer, GinRummyUtil.GIN_BONUS, opponentDeadwood, GinRummyUtil.GIN_BONUS + opponentDeadwood))

                elif knockingDeadwood < opponentDeadwood:
                    # non-gin round win:
                    scores[currentPlayer] += opponentDeadwood - knockingDeadwood;
                    if self.playVerbose:
                        print("Player %d scores the deadwood difference of %d.\n" % (currentPlayer, opponentDeadwood - knockingDeadwood))

                else:
                    # undercut win for opponent
                    scores[opponent] += GinRummyUtil.UNDERCUT_BONUS + knockingDeadwood - opponentDeadwood;
                    if self.playVerbose:
                        print("Player %d undercuts and scores the undercut bonus of %d plus deadwood difference of %d for %d total points.\n" % \
                        (opponent, GinRummyUtil.UNDERCUT_BONUS, knockingDeadwood - opponentDeadwood, GinRummyUtil.UNDERCUT_BONUS + knockingDeadwood - opponentDeadwood))

                # Only a scored round passes the deal; a cancelled round keeps the same starter.
                startingPlayer = 1 if startingPlayer == 0 else 0 # starting player alternates

            # If the round ends due to a two card draw pile with no knocking, the round is cancelled.
            else:
                if self.playVerbose:
                    print("The draw pile was reduced to two cards without knocking, so the hand is cancelled.")

            # report final hands
            for i in range(2):
                for j in range(2):
                    self.players[i].reportFinalHand(j, hands[j].copy())

            # score reporting
            if self.playVerbose:
                print("Player\tScore\n0\t%d\n1\t%d\n" % (scores[0], scores[1]))
            for i in range(2):
                self.players[i].reportScores(scores.copy())

        if self.playVerbose:
            print("Player %s wins.\n" % (0 if scores[0] > scores[1] else 1))
        return 0 if scores[0] >= GinRummyUtil.GOAL_SCORE else 1

# Test Agents

## Shared

In [9]:
def testAgents(agent0,agent1,numGames,verbose):
    """Play `numGames` games between two agents and print the win totals.

    agent0 sits as player 0 and agent1 as player 1 in a single GinRummyGame
    that is reused for every game. The model of any MLPGinRummyPlayer is
    printed first, a progress line is printed every 100 games, and the final
    tally is printed at the end. Nothing is returned.
    """
    match = GinRummyGame(agent0, agent1)
    match.setPlayVerbose(verbose)

    # Show which network(s) are being evaluated.
    for seat in (0, 1):
        if isinstance(match.players[seat], MLPGinRummyPlayer):
            print(match.players[seat].model)

    # play() returns the winner's index, so summing it counts player 1's wins.
    playerOneWins = 0
    for gameIndex in range(numGames):
        showProgress = (gameIndex % 100 == 0)
        if showProgress:
            print("Game ... ", gameIndex)
        playerOneWins += match.play()

    playerZeroWins = numGames - playerOneWins
    print("Games Won: P0:%d, P1:%d.\n" % (playerZeroWins, playerOneWins))

In [10]:
# Sub-directory names used by the evaluation cells below to locate a saved model:
#   '{pth}/models/{state}/{action}/{model_name}/model.pt'
state = 'all'
action = 'all'

## test games

In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:346, P1:1654.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 1000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:267, P1:733.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:513, P1:1487.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:526, P1:1474.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:358, P1:1642.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:542, P1:1458.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:535, P1:1465.



In [None]:
# Single-game smoke run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 1
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read the checkpoint from disk once, mapped onto the available device, and
# hand that single object to the agent (a second, discarded load is wasted I/O).
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used by
# data-collection code elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
Game ...  0
Games Won: P0:1, P1:0.



## MLP Models vs. RandomAgent

In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against RandGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint location: {pth}/models/{state}/{action}/{model_name}/model.pt
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
print(model_name)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
all_states_all_actions
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:2000, P1:0.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against RandGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint location: {pth}/models/{state}/{action}/{model_name}/model.pt
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
print(model_name)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
all_states_all_actions_MLP_base_extra_knock_data_40K
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:2000, P1:0.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against RandGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint location: {pth}/models/{state}/{action}/{model_name}/model.pt
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
print(model_name)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
all_states_all_actions_2hl_extra_knock_data_40K
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:2000, P1:0.



In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against RandGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_80K'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint location: {pth}/models/{state}/{action}/{model_name}/model.pt
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
print(model_name)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
all_states_all_actions_2hl_extra_knock_data_80K
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:2000, P1:0.



In [None]:
# Evaluation run: a DQN-trained EstimatorNetwork (player 0) against
# RandGinRummyPlayer (player 1).
numGames = 2000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): this checkpoint path is relative to the working directory,
# unlike the other cells, which build their paths from `pth` — confirm the cwd.
checkpoint = torch.load('models/dqn/TEST5/model_posttrain.pth', map_location=device)
# Layer sizes passed to EstimatorNetwork; they must match the network that
# produced the checkpoint, otherwise load_state_dict will not accept the weights.
mlp_layers=[520, 110]
# Alternative for a two-hidden-layer checkpoint:
# mlp_layers=[520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm)
qnet = qnet.to(device)
# The Q-network weights are stored under the 'dqn_q_estimator' key of the checkpoint.
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
# Switch to verbose=True to print a play-by-play of every game.
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
EstimatorNetwork(
  (fc_layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=260, out_features=520, bias=True)
    (2): Sigmoid()
    (3): Linear(in_features=520, out_features=110, bias=True)
    (4): Sigmoid()
    (5): Linear(in_features=110, out_features=110, bias=True)
    (6): Softmax(dim=1)
  )
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:1996, P1:4.



In [11]:
# Evaluation run: a DQN self-play EstimatorNetwork (player 0) against
# RandGinRummyPlayer (player 1).
numGames = 1000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): this checkpoint path is relative to the working directory,
# unlike the other cells, which build their paths from `pth` — confirm the cwd.
checkpoint = torch.load('models/dqn/selfplay/TEST1/model_posttrain.pth', map_location=device)
# Layer sizes passed to EstimatorNetwork; they must match the network that
# produced the checkpoint, otherwise load_state_dict will not accept the weights.
# Alternative for a single-hidden-layer checkpoint:
# mlp_layers=[520, 110]
mlp_layers=[520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm)
qnet = qnet.to(device)
# The Q-network weights are stored under the 'dqn_q_estimator' key of the checkpoint.
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = RandGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
# Switch to verbose=True to print a play-by-play of every game.
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
EstimatorNetwork(
  (fc_layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=260, out_features=520, bias=True)
    (2): Sigmoid()
    (3): Linear(in_features=520, out_features=520, bias=True)
    (4): Sigmoid()
    (5): Linear(in_features=520, out_features=110, bias=True)
    (6): Sigmoid()
    (7): Linear(in_features=110, out_features=110, bias=True)
    (8): Softmax(dim=1)
  )
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:1000, P1:0.



## MLP Models vs. SimpleGinRummyAgent

In [None]:
# Evaluation run: the saved model named by `model_name` (player 0)
# against SimpleGinRummyPlayer (player 1).
numGames = 1000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint location: {pth}/models/{state}/{action}/{model_name}/model.pt
model = torch.load('{}/models/{}/{}/{}/model.pt'.format(pth,state,action,model_name), map_location=device)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
# Reset module-level lists; not read in this cell — presumably used elsewhere in the notebook.
states, actions = [], []
testAgents(agent0,agent1,numGames,verbose=False)

Load Model
all_states_all_actions
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:130, P1:870.



In [None]:
# Evaluate 'all_states_all_actions_MLP_base_extra_knock_data_40K' against
# SimpleGinRummyPlayer.
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 1000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
all_states_all_actions_MLP_base_extra_knock_data_40K
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:197, P1:803.



In [None]:
# Evaluate 'all_states_all_actions_2hl_extra_knock_data_40K' against
# SimpleGinRummyPlayer.
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 1000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
all_states_all_actions_2hl_extra_knock_data_40K
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:246, P1:754.



In [None]:
# Evaluate 'all_states_all_actions_2hl_extra_knock_data_80K' against
# SimpleGinRummyPlayer.
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 1000
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_80K'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
all_states_all_actions_2hl_extra_knock_data_80K
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:273, P1:727.



## MLP Model vs. MLP Model

In [None]:
# Head-to-head: 'all_states_all_actions' (player 0) versus
# 'all_states_all_actions_2hl_extra_knock_data_40K' (player 1).
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 2000

# Player 0
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)

# Player 1 (`model_name` and `model` are rebound to the second model)
agent1 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent1.loadModel(model)

states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
Load Model
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Game ...  1000
Game ...  1100
Game ...  1200
Game ...  1300
Game ...  1400
Game ...  1500
Game ...  1600
Game ...  1700
Game ...  1800
Game ...  1900
Games Won: P0:630, P1:1370.



In [None]:
# Head-to-head: 'all_states_all_actions_MLP_base_extra_knock_data_40K' (player 0)
# versus 'all_states_all_actions_2hl_extra_knock_data_40K' (player 1).
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 200

# Player 0
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_MLP_base_extra_knock_data_40K'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)

# Player 1 (`model_name` and `model` are rebound to the second model)
agent1 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent1.loadModel(model)

states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
Load Model
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Game ...  100
Games Won: P0:80, P1:120.



## Test

### Other

In [12]:
# Play a single verbose game: 'all_states_all_actions' versus SimpleGinRummyPlayer,
# printing the play-by-play log.
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 1
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=True)

Load Model
all_states_all_actions
MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Player 0 is dealt [6D, 3D, 4C, AC, KS, 7H, 6H, TD, QC, JD].

Player 1 is dealt [KD, 5S, QS, AH, 6C, 5C, 8S, 8H, 7S, 5H].

Player 1 starts.

The initial face up card is KH.

Player 1 declines KH.

Update States
Draw new card: 1.0
Pickup from discard: 6.0894706e-10
Player 0 declines KH.

Draw from deck
Player 1 draws 3H.

Player 1 discards KD.

Player 1 has [[5S, 5H, 5C], QS, AH, 6C, 8S, 8H, 7S, 3H] with 43 deadwood.

Update States
Draw new card: 0.9999256
Pickup from discard: 7.4397634e-05
Draw from deck
Player 0 draws 8D.

Update States
Current Hand: [6D, 3D, 4C, AC, KS, 7H, 6H, TD, QC, JD, 8D]
Discard KS | D: KS | K: 8D | 18
MAX:0.9420, 0.0001
Discard Action
Player 0 discards KS.

Player 0 has [6D, 3D, 4C, AC, 7H, 6H, TD, QC, JD, 8D] with 65 deadwood.

Draw from

In [13]:
# Play a single verbose game: 'all_states_all_actions_2hl_extra_knock_data_40K'
# versus SimpleGinRummyPlayer, printing the play-by-play log.
# NOTE(review): relies on `state` and `action` already being defined in the kernel.
numGames = 1
agent0 = MLPGinRummyPlayer()
model_name = 'all_states_all_actions_2hl_extra_knock_data_40K'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)
agent0.loadModel(model)
print(model_name)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=True)

Load Model
all_states_all_actions_2hl_extra_knock_data_40K
MLP_2HL(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=520, bias=True)
  (l3): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)
Game ...  0
Player 0 is dealt [8S, 6D, 9D, TH, JS, 5C, 2H, 3D, 3H, KC].

Player 1 is dealt [QS, QD, KD, 8D, 2S, KH, 3C, 4S, 7S, 5D].

Player 1 starts.

The initial face up card is 7D.

Player 1 declines 7D.

Update States
Draw new card: 0.9999914
Pickup from discard: 8.584575e-06
Player 0 declines 7D.

Draw from deck
Player 1 draws 9S.

Player 1 discards KD.

Player 1 has [QS, QD, 8D, 2S, KH, 3C, 4S, 7S, 5D, 9S] with 68 deadwood.

Update States
Draw new card: 1.0
Pickup from discard: 1.3967293e-11
Draw from deck
Player 0 draws AD.

Update States
Current Hand: [8S, 6D, 9D, TH, JS, 5C, 2H, 3D, 3H, KC, AD]
Discard JS | D: JS | K: JS | 16
MAX:0.8300, 0.0000
Discard Action
Player 0 discards JS.



### Debugging

In [None]:
# Debugging aid: build a game object from whichever `agent0` / `agent1` are
# currently bound in the kernel (they are assigned by the evaluation cells).
a = GinRummyGame(agent0, agent1)

In [None]:
# Display the `players` attribute of the game object `a` created in the previous cell.
a.players

In [None]:
# Baseline match-up: random player (player 0) versus SimpleGinRummyPlayer (player 1).
numGames = 1000
agent0 = RandGinRummyPlayer()
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
testAgents(agent0, agent1, numGames, verbose=False)

### QNet

In [None]:
# Evaluate the DQN checkpoint TEST5 against SimpleGinRummyPlayer.
# (The original cell assigned `numGames = 1` and immediately overwrote it with
# 2000; the dead assignment has been removed. Use numGames = 1 together with
# verbose=True below to inspect a single game.)
numGames = 2000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the checkpoint path is relative to the current working directory.
checkpoint = torch.load('models/dqn/TEST5/model_posttrain.pth', map_location=device)
# Hidden-layer widths used to rebuild the network before loading the saved
# weights (the other architecture tried in this notebook is [520, 520, 110]).
mlp_layers = [520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm)
qnet = qnet.to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = SimpleGinRummyPlayer()
states, actions = [], []
testAgents(agent0, agent1, numGames, verbose=False)

In [None]:
# Evaluate the DQN checkpoint TEST9 against SimpleGinRummyPlayer.
numGames = 1000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): the checkpoint path is relative to the current working directory.
checkpoint = torch.load(
    'models/dqn/TEST9/model_posttrain.pth',
    map_location=device,
)
# Hidden-layer widths used to rebuild the network before loading the saved
# weights (the alternative architecture tried in this notebook is [520, 110]).
mlp_layers = [520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
# Switch to verbose=True for a play-by-play log.
testAgents(agent0, agent1, numGames, verbose=False)

In [None]:
# Evaluate the DQN checkpoint TEST10 against SimpleGinRummyPlayer.
numGames = 1000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): the checkpoint path is relative to the current working directory.
checkpoint = torch.load(
    'models/dqn/TEST10/model_posttrain.pth',
    map_location=device,
)
# Hidden-layer widths used to rebuild the network before loading the saved
# weights (the alternative architecture tried in this notebook is [520, 110]).
mlp_layers = [520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
# Switch to verbose=True for a play-by-play log.
testAgents(agent0, agent1, numGames, verbose=False)

In [None]:
# Head-to-head between two DQN checkpoints: TEST9 (player 0) versus TEST10 (player 1).
# `checkpoint`, `mlp_layers`, `batch_norm` and `qnet` are rebound for the second
# agent, so after this cell they refer to the TEST10 network.
# NOTE(review): both checkpoint paths are relative to the current working directory.
numGames = 100
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Player 0: TEST9 ---
agent0 = MLPGinRummyPlayer()
checkpoint = torch.load(
    'models/dqn/TEST9/model_posttrain.pth',
    map_location=device,
)
mlp_layers = [520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])
agent0.loadModel(qnet)

# --- Player 1: TEST10 ---
agent1 = MLPGinRummyPlayer()
checkpoint = torch.load(
    'models/dqn/TEST10/model_posttrain.pth',
    map_location=device,
)
mlp_layers = [520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])
agent1.loadModel(qnet)

states = []
actions = []
# Switch to verbose=True for a play-by-play log.
testAgents(agent0, agent1, numGames, verbose=False)

In [15]:
# Evaluate the self-play DQN checkpoint (selfplay/TEST2) against SimpleGinRummyPlayer.
numGames = 1000
agent0 = MLPGinRummyPlayer()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): the checkpoint path is relative to the current working directory.
checkpoint = torch.load(
    'models/dqn/selfplay/TEST2/model_posttrain.pth',
    map_location=device,
)
# Hidden-layer widths used to rebuild the network before loading the saved
# weights (the alternative architecture tried in this notebook is [520, 110]).
mlp_layers = [520, 520, 110]
batch_norm = False
qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

agent0.loadModel(qnet)
agent1 = SimpleGinRummyPlayer()
states = []
actions = []
# Switch to verbose=True for a play-by-play log.
testAgents(agent0, agent1, numGames, verbose=False)

Load Model
EstimatorNetwork(
  (fc_layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=260, out_features=520, bias=True)
    (2): Sigmoid()
    (3): Linear(in_features=520, out_features=520, bias=True)
    (4): Sigmoid()
    (5): Linear(in_features=520, out_features=110, bias=True)
    (6): Sigmoid()
    (7): Linear(in_features=110, out_features=110, bias=True)
    (8): Softmax(dim=1)
  )
)
Game ...  0
Game ...  100
Game ...  200
Game ...  300
Game ...  400
Game ...  500
Game ...  600
Game ...  700
Game ...  800
Game ...  900
Games Won: P0:2, P1:998.



# Loading Agent Model Testing

In [None]:
# Smoke test: load the saved 'all_states_all_actions' model and display it.
# `state` and `action` are the two directory levels under `models/` in the path.
state, action = 'all', 'all'
model_name = 'all_states_all_actions'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.load(
    '{}/models/{}/{}/{}/model.pt'.format(pth, state, action, model_name),
    map_location=device,
)

MLP_base(
  (l1): Linear(in_features=260, out_features=520, bias=True)
  (l2): Linear(in_features=520, out_features=110, bias=True)
  (act_fnc): Sigmoid()
  (sfx): Softmax(dim=1)
)

In [None]:
# Load the recorded state / action arrays stored under data/apbd/knock.
# NOTE: this rebinds the notebook-level `states` and `actions` names, which the
# evaluation cells reset to empty lists.
state_s, action_a = 'apbd', 'knock'
data_pth = '{}/data/{}/{}'.format(pth, state_s, action_a)
states = np.load('{}/s_2k.npy'.format(data_pth))
actions = np.load('{}/a_2k.npy'.format(data_pth))

## Agent

In [None]:
# Create a standalone MLP-driven player for the manual experiments below.
agent = MLPGinRummyPlayer()

In [None]:
# Load the saved model under models/all/all/<model_name> into `agent`.
# `model_name` and `device` are taken from the kernel (set by an earlier cell).
state_s, action_a = 'all', 'all'
agent.loadModel(
    torch.load(
        '{}/models/{}/{}/{}/model.pt'.format(pth, state_s, action_a, model_name),
        map_location=device,
    )
)

Load Model


In [None]:
# Disabled scratch code (kept for reference; nothing in this cell executes).
# It ran the agent's model on the first recorded state and inspected output
# entries 6..57.
# NOTE(review): if re-enabled, it would rebind `input` (a builtin) and `action`
# (used to build model paths in other cells).
# input = np.expand_dims(states[0],axis=0)
# prob = agent.model(torch.from_numpy(input).type(torch.FloatTensor))
# action = prob.detach().numpy().reshape(-1)
# action[6:58]*np.zeros(52)
# # agent.model(np.expand_dims(,axis=0))

In [None]:
# Shuffle a deck with a random seed, deal alternately into two hands, and start
# a game for `agent` using the first hand.
deck = Deck.getShuffle(random.randrange(10 ** 8))
hands = [[], []]
for i in range(2 * GinRummyGame.HAND_SIZE):
    # Even draws go to hand 0, odd draws to hand 1.
    hands[i % 2].append(deck.pop())
# The trailing semicolon suppresses the cell's output.
agent.startGame(0, 0, hands[0]);

In [None]:
# Feed the last recorded state in `states` to the agent.
agent.updateStates(states[-1])

Update States


In [None]:
# Look up the card labelled 'AC' and ask the agent whether it would pick it up
# as the face-up card; the call's return value is the cell output.
c = Deck.strCardMap['AC']
agent.willDrawFaceUpCard(c)

Pickup from discard: 5.1510585e-10
Draw new card: 8.215045e-09
Draw Action


True

In [None]:
# Display the agent's `playerNum` attribute.
agent.playerNum

0

In [None]:
# Replay recorded sample `i`: load its state into the agent, report card `c`
# (bound in an earlier cell) via reportDraw, then request a discard.
# The statements are order-dependent because each call acts on the agent object.
i = 11
agent.updateStates(states[i])
agent.reportDraw(0, c)
agent.getDiscard()
# Cell output: the `all_classes` label at the argmax of the recorded action
# vector for the same sample, for comparison with the agent's printed choice.
all_classes[np.argmax(actions[i])]

Update States
6D 89
MAX:0.19825728237628937, 0.7989177107810974
Knock Action


'6D'

In [None]:
# Run the agent's model directly on its current `state`, adding a leading batch
# axis and converting to a float tensor on `device`; the output tensor is displayed.
agent.model(torch.from_numpy(np.expand_dims(agent.state, axis=0)).type(torch.FloatTensor).to(device))
# Equivalent two-step form, left disabled for reference:
# state = np.expand_dims(self.state, axis=0)
# action = self.model(torch.from_numpy(state).type(torch.FloatTensor))

tensor([[1.5613e-08, 1.5413e-08, 5.1511e-10, 8.2150e-09, 2.9999e-08, 1.2821e-09,
         5.5631e-09, 2.5155e-08, 2.7546e-06, 4.5472e-07, 1.0434e-03, 8.9934e-17,
         3.4545e-15, 7.6638e-20, 1.2097e-23, 9.0395e-21, 3.3336e-16, 1.2249e-17,
         1.4306e-13, 1.5554e-08, 5.1269e-08, 1.0969e-08, 7.8594e-07, 7.9670e-12,
         4.4021e-18, 5.8035e-19, 3.2541e-18, 1.2281e-18, 1.5415e-15, 2.9808e-19,
         8.6931e-15, 8.5212e-14, 1.4021e-08, 1.0776e-08, 1.0269e-08, 2.9233e-09,
         6.7081e-11, 1.9826e-01, 8.6911e-15, 1.6958e-20, 6.8392e-18, 3.8210e-16,
         1.5054e-19, 4.5664e-16, 3.2462e-10, 1.5062e-07, 2.1884e-08, 1.1281e-08,
         1.1448e-04, 5.9201e-12, 1.1257e-13, 1.4764e-16, 3.0289e-24, 8.3574e-19,
         7.6238e-17, 3.5383e-14, 8.8625e-20, 1.4193e-20, 2.3093e-07, 1.5752e-07,
         4.3524e-05, 1.8425e-06, 1.4934e-03, 3.1863e-11, 4.4363e-08, 1.9563e-07,
         4.9013e-06, 1.5853e-08, 1.0693e-06, 1.3684e-08, 8.0732e-07, 1.6183e-08,
         2.2061e-05, 9.4124e

## QNet

In [None]:
class EstimatorNetwork(nn.Module):
    ''' The function approximation network for Estimator.

        The input is flattened, optionally batch-normalised, and passed through
        one Linear + Sigmoid pair per entry of ``mlp_layers``. A final Linear
        layer maps to ``action_num`` outputs and is followed by Softmax(dim=1).
        All in/out are torch.tensor.
        (OLD) An earlier version used tanh layers instead of sigmoid layers.
    '''

    def __init__(self, mlp_layers=None, batch_norm=False):
        ''' Initialize the Q network
        Args:
            mlp_layers (list): output size of each hidden fc layer. ``None``
                (the default) or an empty sequence builds a network with no
                hidden layers, i.e. Flatten -> Linear -> Softmax.
            batch_norm (bool): if True, apply BatchNorm1d to the flattened
                input before the first Linear layer.

        The state size (260) and the number of actions (110) are fixed and
        exposed as ``self.state_shape`` and ``self.action_num``.
        '''
        super(EstimatorNetwork, self).__init__()

        self.action_num = 110
        self.state_shape = 260
        # The original default of None crashed on ``[...] + None``. Normalise
        # to a list so None, lists and tuples are all accepted; copying also
        # keeps the stored value independent of the caller's list.
        self.mlp_layers = [] if mlp_layers is None else list(mlp_layers)
        self.batch_norm = batch_norm

        # build the Q network
        # int(...) turns the numpy scalar from np.prod into a plain Python int.
        layer_dims = [int(np.prod(self.state_shape))] + self.mlp_layers
        fc = [nn.Flatten()]
        if batch_norm:
            fc.append(nn.BatchNorm1d(layer_dims[0]))
        # Module order is kept exactly as before: the state_dict keys
        # (fc_layers.<index>.*) depend on each module's position.
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            fc.append(nn.Linear(in_dim, out_dim, bias=True))
            fc.append(nn.Sigmoid())
        fc.append(nn.Linear(layer_dims[-1], self.action_num, bias=True))
        fc.append(nn.Softmax(dim=1))
        self.fc_layers = nn.Sequential(*fc)

    def forward(self, s):
        ''' Predict action values
        Args:
            s  (Tensor): (batch, state_shape)
        Returns:
            Tensor: (batch, action_num); each row sums to 1 because of the
            final Softmax(dim=1).
        '''
        return self.fc_layers(s)

In [None]:
# Build a fresh Q-network with hidden layers [520, 520, 110] on the available device.
# `state_shape` and `action_num` are set for reference only; EstimatorNetwork
# takes just the layer list and the batch-norm flag.
state_shape, action_num = 260, 110
mlp_layers = [520, 520, 110]
batch_norm = False
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

qnet = EstimatorNetwork(mlp_layers, batch_norm).to(device)

In [None]:
# Load the TEST4 DQN checkpoint onto `device`.
# NOTE(review): the path is relative to the current working directory.
checkpoint = torch.load('models/dqn/TEST4/model_posttrain.pth', map_location=device)

In [None]:
# Copy the saved Q-estimator weights (checkpoint key 'dqn_q_estimator') into `qnet`;
# the call's return value is the cell output.
qnet.load_state_dict(checkpoint['dqn_q_estimator'])

<All keys matched successfully>

In [None]:
# Run `qnet` on the last recorded state, adding a leading batch axis and
# converting to a float tensor on `device`; the output tensor is displayed.
qnet(torch.from_numpy(np.expand_dims(states[-1], axis=0)).type(torch.FloatTensor).to(device))

tensor([[7.2249e-09, 2.2420e-10, 9.7314e-16, 4.7728e-15, 1.2340e-20, 6.8497e-10,
         5.7679e-02, 3.3210e-02, 1.1054e-01, 2.0988e-02, 1.7701e-02, 6.7780e-03,
         1.2067e-02, 4.0160e-03, 2.8021e-04, 3.1774e-03, 5.9243e-03, 1.0486e-02,
         3.3039e-03, 2.7109e-02, 3.3363e-02, 4.7771e-02, 1.8946e-02, 2.7230e-02,
         9.2067e-03, 3.6598e-03, 2.2502e-03, 9.3831e-03, 7.2180e-03, 2.3893e-03,
         2.6061e-03, 6.0732e-03, 2.8625e-02, 4.8642e-02, 3.8667e-02, 3.7032e-02,
         1.3078e-02, 9.2480e-02, 9.3268e-04, 1.0151e-03, 4.8095e-03, 1.0041e-03,
         5.6105e-03, 2.7269e-03, 3.4772e-03, 3.9008e-02, 4.8259e-02, 3.4264e-02,
         4.3172e-02, 2.0446e-02, 1.8986e-02, 2.4711e-03, 7.2034e-03, 1.0481e-03,
         1.0171e-02, 4.2704e-03, 6.3416e-04, 8.6092e-03, 4.0611e-10, 3.4673e-10,
         1.6277e-10, 1.9494e-10, 3.1402e-10, 1.9318e-10, 2.3830e-10, 1.1181e-10,
         1.1989e-10, 1.7327e-10, 2.8693e-10, 8.1085e-11, 1.4320e-10, 9.0052e-11,
         2.9499e-10, 5.1579e