In [None]:
import os
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
print("TensorFlow version:", tf.__version__)

# Optional: uncomment the lines below to switch Keras to the 'mixed_float16' policy.
# from tensorflow.keras.mixed_precision import Policy, set_global_policy

# Set the global policy to mixed precision
# policy = Policy('mixed_float16')
# set_global_policy(policy)

# print(f"Mixed precision policy set to: {policy}")

In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import random
import matplotlib.cm as cm
import matplotlib.colors as colors


In [None]:
# Blackjack model: cards, game state, sessions and value grids.
class Card:
    """One playing card.

    Attributes:
        kind: suit name (the deck in Game uses 'Diamond', 'Club', 'Spade', 'Heart').
        show: rank label ('A', '2'-'9', 'T', 'J', 'Q', 'K').
    """

    def __init__(self, kind, show):
        self.show = show
        self.kind = kind

    def toString(self):
        """Render the card as '[<show>,<kind>]', e.g. '[A,Spade]'."""
        pieces = ['[', self.show, ',', self.kind, ']']
        return ''.join(pieces)

class Game:
    """State of one blackjack deal: a shuffled 52-card deck plus the
    player's hand (``PH``) and the dealer's hand (``DH``).

    ``usefulAce`` is True when the player's most recently computed total
    (see ``sum("Player")``) counts an ace as 11 without busting.
    """

    # Rank labels whose blackjack value is 10.
    _TEN_VALUED = ('T', 'J', 'Q', 'K')

    def __init__(self):
        self.deck = self.createDeck()
        random.shuffle(self.deck)
        self.PH = []  # player's hand (list of Card)
        self.DH = []  # dealer's hand (list of Card)
        # NOTE: this instance attribute shadows the ``usefulAce`` method below,
        # so ``game.usefulAce`` on an instance is always this bool.
        self.usefulAce = False

    def createDeck(self):
        """Return a fresh, unshuffled 52-card deck (4 suits x 13 ranks)."""
        kinds = ['Diamond', 'Club', 'Spade', 'Heart']
        shows = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
        return [Card(kind, show) for kind in kinds for show in shows]

    def newHand(self, target):
        """Deal the top card of the deck to ``target``.

        Args:
            target: 'Player' for the player; any other value means the dealer.

        Raises:
            IndexError: if the deck is empty.
        """
        card = self.deck.pop()  # pop() removes the card, so it can never be dealt twice
        if target == 'Player':
            self.PH.append(card)
        else:
            self.DH.append(card)

    @classmethod
    def _handValue(cls, hand):
        """Return ``(total, usable_ace)`` for a list of cards.

        Every ace starts at 11. While the total exceeds 21, aces are demoted
        to 1 one at a time (at most once per ace). ``usable_ace`` is True if
        at least one ace still counts as 11 in the returned total.
        """
        total = 0
        soft_aces = 0  # aces currently counted as 11
        for card in hand:
            if card.show in cls._TEN_VALUED:
                total += 10
            elif card.show == 'A':
                total += 11
                soft_aces += 1
            else:
                total += int(card.show)
        while total > 21 and soft_aces > 0:
            total -= 10
            soft_aces -= 1
        return total, soft_aces > 0

    def sum(self, target):
        """Return the best blackjack total of ``target``'s hand.

        For 'Player' this also refreshes ``self.usefulAce``. Any other target
        is treated as the dealer and leaves ``usefulAce`` untouched.
        """
        if target == "Player":
            total, self.usefulAce = self._handValue(self.PH)
            return total
        total, _ = self._handValue(self.DH)
        return total

    def displayHand(self, target, showAll=False):
        """Print ``target``'s hand as a sequence of '[show,kind]/' tokens.

        For the dealer, the second card is printed as '???' unless
        ``showAll`` is truthy, so the hole card stays hidden during play.
        """
        if target == 'Player':
            print("".join(card.toString() + "/" for card in self.PH))
            return
        tokens = []
        for position, card in enumerate(self.DH):
            hidden = (not showAll) and position == 1
            tokens.append(("???" if hidden else card.toString()) + "/")
        print("".join(tokens))

    # Utility accessors
    def usefulAce(self):
        """Return the usable-ace flag.

        NOTE(review): unreachable through an instance because ``__init__``
        assigns an attribute with the same name; kept only so that
        ``Game.usefulAce(game)`` keeps working.
        """
        return self.usefulAce

    def dealderFaceUp(self):
        """Return the rank label of the dealer's first (face-up) card.

        Raises:
            IndexError: if the dealer has no cards yet.
        """
        return self.DH[0].show

class Session:
    """Play one round of blackjack between an agent (the player) and a fixed dealer.

    The round is recorded in ``gameTree`` as: 'START'; one dict per player
    state (the agent's choice is stored under 'Player Action' once made);
    if the player stood, 'DEALDERROUND' and one dict per dealer step;
    finally 'END' followed by the winner ('Player', 'Dealer' or 'Tie').
    """

    def __init__(self):
        self.seq = []
        self.game = Game()
        self.stop = False
        self.winner = None
        self.someoneBusted = False
        self.gameTree = ['START']

    def _playerNode(self, result):
        """Snapshot the player's current state as a gameTree node.

        The hand is copied so later hits don't rewrite earlier nodes.
        'Player Sum' is evaluated before 'Useful Ace' because sum() refreshes the flag.
        """
        return {'Player Cards': list(self.game.PH),
                'Player Sum': self.game.sum("Player"),
                'Useful Ace': self.game.usefulAce,
                'Dealer Face': self.game.dealderFaceUp(),
                'Result': result}

    def _dealerNode(self, action, result):
        """Snapshot one dealer step (with a copy of the dealer's hand) as a gameTree node."""
        return {'Dealer Action': action, "Dealer Cards": list(self.game.DH), "Result": result}

    def runAgent(self, Agent):
        """Deal the opening hands, let ``Agent`` play, then the dealer, and record the winner.

        Args:
            Agent: object with ``policy(state)`` returning 'h' (hit) or 's' (stand).
        """
        # Opening deal alternates player / dealer, two cards each.
        for target in ("Player", "Dealer", "Player", "Dealer"):
            self.game.newHand(target)
        self.gameTree.append(self._playerNode('Continue'))

        self.someoneBusted = False
        if not self.stop:
            self.playerTurnAgent(Agent)
            if not self.stop:  # the player stood, so the dealer plays
                self.gameTree.append('DEALDERROUND')
                self.dealerTurn()
                self.stop = True

        if not self.someoneBusted:
            self.winner = self.checkWin()

        self.gameTree.append('END')
        self.gameTree.append(self.winner)

    def playerTurnAgent(self, Agent):
        """Query ``Agent.policy`` until it stands ('s') or the player busts.

        The state passed to the policy is ``(int(usefulAce), dealer face-up
        rank, player sum)``. It is recomputed after every hit; previously the
        opening state was reused for every decision.
        """
        while not self.stop:
            playerSum = self.game.sum("Player")  # also refreshes usefulAce
            location = (int(self.game.usefulAce), self.game.dealderFaceUp(), playerSum)
            action = Agent.policy(location)  # expected: 'h' or 's'
            self.gameTree[-1]['Player Action'] = action
            if action == 'h':
                self.game.newHand("Player")
                if self.bust("Player"):
                    self.winner = 'Dealer'
                    self.needStop()
                    self.someoneBusted = True
                    self.gameTree.append(self._playerNode("Busted"))
                else:
                    self.gameTree.append(self._playerNode("Continue"))
            elif action == 's':
                self.gameTree.append(self._playerNode("End Round"))
                break
            else:
                print("agent gives illegal action.")

    def dealerTurn(self):
        """Dealer hits below 17 and stands on 17 or more; the turn ends at once on a bust."""
        while not self.stop:
            if self.game.sum("Dealer") >= 17:
                self.gameTree.append(self._dealerNode('s', "End Round"))
                break
            self.game.newHand("Dealer")
            if self.bust("Dealer"):
                self.winner = 'Player'
                self.someoneBusted = True
                self.gameTree.append(self._dealerNode('h', "Busted"))
                break  # without this a spurious 'End Round' node followed the bust
            self.gameTree.append(self._dealerNode('h', "Continue"))

    def bust(self, target):
        """Return True if ``target``'s hand total exceeds 21."""
        return self.game.sum(target) > 21

    def needStop(self):
        """Mark the round as finished."""
        self.stop = True

    def checkWin(self):
        """Compare hand totals and return 'Player', 'Tie' or 'Dealer'.

        Only meaningful when neither side has busted.
        """
        PV = self.game.sum("Player")
        DV = self.game.sum("Dealer")
        if PV > DV:
            return "Player"
        if PV == DV:
            return "Tie"
        return "Dealer"

class Grid:
    """A 13 x 20 table of floats for plotting, e.g. p(hit | state).

    Rows are dealer face-up ranks ('A'..'K') and columns are player sums 2..21.
    """

    def __init__(self):
        self.L = 13  # number of rows (dealer face-up ranks)
        self.W = 20  # number of columns (player sums 2..21)
        self.grid = np.zeros((self.L, self.W))
        self.row_indexes = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
        self.col_indexes = list(range(2, 22))

    def getValue(self, row, col):
        """Return the value at rank label ``row`` and player sum ``col``.

        Raises:
            ValueError: if ``row`` or ``col`` is not a valid index label.
        """
        return self.grid[self.row_indexes.index(row), self.col_indexes.index(col)]

    def updateValue(self, row, col, val):
        """Set the cell at rank label ``row`` ('A', '2', ...) and player sum ``col`` (int 2-21).

        Returns:
            True.

        Raises:
            ValueError: if ``row`` or ``col`` is not a valid index label.
        """
        self.grid[self.row_indexes.index(row), self.col_indexes.index(col)] = val
        return True

    def plotGrid(self):
        """Draw the grid as a 'coolwarm' heat map with each cell's value annotated.

        Values are rounded to 2 decimals for display only; ``self.grid`` is
        not modified. Text colours come from the 'hot' colormap over [-1, 1].

        Returns:
            0.
        """
        shown = np.round(self.grid, decimals=2)  # round a copy, keep stored values exact

        fig, ax = plt.subplots(figsize=(12, 12))
        ax.set_xticks(np.arange(len(self.col_indexes)), labels=self.col_indexes)
        ax.set_yticks(np.arange(len(self.row_indexes)), labels=self.row_indexes)

        norm = colors.Normalize(vmin=-1, vmax=1)
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        colorMap = plt.get_cmap('hot')
        ax.imshow(shown, cmap='coolwarm')

        for i in range(len(self.row_indexes)):
            for j in range(len(self.col_indexes)):
                value = shown[i, j]
                color = tuple(round(component, 4) for component in colorMap(norm(value)))
                ax.text(j, i, value, ha="center", va="center", color=color, fontsize=8)

        plt.show()
        return 0

    @staticmethod
    def plotQTable(Qs):
        """Placeholder for plotting several Q tables; currently does nothing and returns 0.

        ``Qs`` is described in the original notes as a dict keyed like
        ``(Hit, A)``, ``(Hit, no A)``, ... mapping to tables — TODO confirm.
        Declared static so both ``Grid.plotQTable(Qs)`` and ``grid.plotQTable(Qs)`` work.
        """
        return 0

In [None]:
class DNN:
    """Keras MLP mapping a state vector to a softmax distribution over actions.

    Args:
        num_input: length of the state vector.
        num_output: number of actions (softmax outputs).
    """

    def __init__(self, num_input, num_output):
        self.num_input = num_input
        self.num_output = num_output
        self.model = Sequential()
        self.model.add(Input(shape=(self.num_input,)))
        # NOTE(review): predict() and backward() call the model without
        # training=True, so these Dropout layers are presumably inactive — confirm.
        self.model.add(Dropout(0.5))
        self.model.add(Dense(256, use_bias=False))
        self.model.add(Dropout(0.5))
        self.model.add(Activation('relu'))
        self.model.add(Dense(256, use_bias=False))
        self.model.add(Dropout(0.2))
        self.model.add(Activation('relu'))
        self.model.add(Dense(num_output, use_bias=False))
        self.model.add(Activation('softmax'))
        self.optimizer = Adam(learning_rate=0.001)

    def predict(self, X):
        """Return the model's action probabilities for the batch ``X`` (no progress bar)."""
        return self.model.predict(X, verbose=0)

    def compile(self, optimizer, loss):
        """Compile the underlying Keras model (``backward`` does not need this)."""
        self.model.compile(optimizer=optimizer, loss=loss)

    def backward(self, lr, gt, b, at, st, yeta_t):
        """Apply one REINFORCE-with-baseline step for a single (state, action, return).

        Minimises ``-(gt - b) * yeta_t * log(pi(a_t | s_t))`` where
        ``pi(a_t | s_t) = sum(at * model(st))``. The previous loss
        ``-sum(gt * at * y)`` ignored ``b`` and ``yeta_t`` and followed
        grad(pi) instead of grad(log pi).

        Args:
            lr: unused; the step size is fixed by ``self.optimizer``.
            gt: return from step t.
            b: baseline subtracted from ``gt``.
            at: one-hot vector of the action taken (length ``num_output``).
            st: numeric state batch, as produced by ``nn_agent.to_numeric_state``.
            yeta_t: discount weight for step t (gamma ** t).

        Returns:
            The scalar loss tensor.
        """
        epsilon = 1e-5  # floor on the action probability so log() stays finite
        advantage = float(gt) - float(b)
        with tf.GradientTape() as tape:
            y = self.model(st)
            prob = tf.reduce_sum(tf.cast(at, y.dtype) * y, axis=-1)
            log_prob = tf.math.log(tf.clip_by_value(prob, epsilon, 1.0))
            loss = -advantage * float(yeta_t) * tf.reduce_sum(log_prob)
        variables = self.model.trainable_variables
        grads = tape.gradient(loss, variables)
        # Pair gradients with the same variable list they were taken against.
        self.optimizer.apply_gradients(zip(grads, variables))
        return loss


#------------------------------------------------------------------------------------------

class nn_agent:
    """Stochastic-policy agent backed by a ``DNN`` over the 3-feature game state.

    With probability ``epsilon`` the action is drawn uniformly from
    ``action_space``. Otherwise it is sampled from the network's softmax output.
    """

    epsilon = 0.1  # default exploration probability

    # Dealer face-up rank label -> numeric feature value fed to the network.
    DF_refer = {'A': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
                '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, 'K': 13}

    def __init__(self, action_space, epsilon=0.1):
        self.action_space = action_space
        self.epsilon = epsilon
        # One softmax output per action (2 for ['h', 's']).
        self.net = DNN(3, len(action_space))

    def policy(self, state):
        """Return an action from ``action_space`` for the game ``state``."""
        if np.random.uniform(0, 1) < self.epsilon:
            return np.random.choice(self.action_space)
        # Sample from the distribution rather than taking argmax: argmax would
        # make the policy deterministic and break the policy-gradient update.
        probs = np.asarray(self.pdf(state)[0], dtype=np.float64)
        probs = probs / probs.sum()  # renormalise the float32 softmax output
        return np.random.choice(self.action_space, p=probs)

    def pdf(self, state):
        """Return the network's action probabilities for ``state``, shape (1, n_actions)."""
        return self.net.predict(self.to_numeric_state(state))

    def backward_at_t(self, lr, gt, b, at, st, yeta_t):
        """Run one policy-gradient update for step t; ``st`` is a game-state tuple."""
        self.net.backward(lr, gt, b, at, self.to_numeric_state(st), yeta_t)

    # helper method
    def to_numeric_state(self, state):
        """Convert a game state to the network's float32 input of shape (1, 3).

        ``state`` is the tuple Session builds:
        ``(int(usefulAce), dealer face-up rank label, player sum)``.

        Raises:
            KeyError: if the dealer rank label is not one of 'A'..'K'.
        """
        features = [int(state[0]), self.DF_refer[state[1]], state[2]]
        return np.array(features, dtype=np.float32).reshape(1, -1)


class nn_traject:
    """Turn one played Session into a flat trajectory for policy-gradient training.

    ``generateTraject`` returns ``[s1, a1, r1, s2, a2, r2, ...]``:
    - each state is ``(usefulAce, dealer face-up rank, player sum)``;
    - each action is the agent's choice;
    - every reward is 0 except the last, which is +1 (player wins), 0 (tie) or -1 (dealer wins).
    """

    def __init__(self, agent, QTable):
        self.agent = agent
        self.session = Session()
        self.DF = None  # dealer face-up rank of the last state read, e.g. 'A'
        self.winner = None  # 'Player' | 'Dealer' | 'Tie' once known
        # Position of each game-tree node in (S, A) space plus its reward:
        # [s1, a1, r1, s2, a2, r2, ..., sn, an, rn]
        self.SATraj = []
        # NOTE(review): QTable is accepted but not used.

    def currentState(self, node):
        """Return ``(usefulAce, dealer face, player sum)`` for a gameTree node dict; records ``self.DF``."""
        PS = node['Player Sum']
        usefulAce = node['Useful Ace']
        self.DF = node['Dealer Face']
        return (usefulAce, self.DF, PS)

    def currentReward(self):
        """Reward for ``self.winner``: 1 Player, 0 Tie, -1 otherwise; 0 while undecided."""
        if self.winner is None:
            return 0  # game still in progress
        return {'Player': 1, 'Tie': 0}.get(self.winner, -1)

    def generateTraject(self):
        """Run the session with ``self.agent`` and return the flattened trajectory.

        Raises:
            ValueError: if a player node without an action has an unexpected 'Result'.
        """
        self.session.runAgent(self.agent)
        nodes = iter(self.session.gameTree)
        next(nodes)  # skip 'START'

        # Phase 1: the player's decisions.
        while True:
            node = next(nodes)
            state = self.currentState(node)
            if 'Player Action' in node:
                # A decision point: record (state, action, reward-so-far), which is 0
                # here because the winner is not known yet.
                self.SATraj.extend([state, node['Player Action'], self.currentReward()])
                continue
            if 'Result' not in node:
                break  # not a player node; move on to the dealer phase
            result = node['Result']
            if result == 'End Round':
                # Player stood: drop the placeholder reward of the last decision;
                # the real reward is appended once the game ends.
                self.SATraj.pop()
                break
            if result == 'Busted':
                # Player busted: the round is over, so replace the placeholder
                # reward with the final one.
                self.winner = self.session.winner
                self.SATraj.pop()
                self.SATraj.append(self.currentReward())
                return self.SATraj
            raise ValueError(f"unexpected player node result: {result!r}")

        # Phase 2: the dealer's turn; nothing to record until the game ends.
        while next(nodes) != 'END':
            pass
        self.winner = self.session.winner
        self.SATraj.append(self.currentReward())
        return self.SATraj


class nn_mc:
    """Monte-Carlo policy-gradient (REINFORCE) training loop for ``nn_agent``.

    Args:
        M: number of episodes (games) to train on.
        gamma: discount factor for returns and per-step weights.
    """

    def __init__(self, M, gamma=0.95):
        self.M = M
        self.gamma = gamma
        self.agent = None

    def start(self):
        """Train a fresh agent for ``self.M`` episodes and store it in ``self.agent``."""
        # Index 0: UA=0, index 1: UA=1. Passed to nn_traject, which ignores it.
        QTs = {'h': [Grid(), Grid()], 's': [Grid(), Grid()]}
        action_space = ['h', 's']
        agent = nn_agent(action_space)
        lr = 1 / 100  # forwarded to the network update
        gamma = self.gamma
        for k in range(self.M):  # previously the global M, silently ignoring self.M
            gts = [0]  # returns seen this episode; their mean is used as the baseline
            if k % 10000 == 0:
                print(k)
            m = nn_traject(agent, QTs).generateTraject()
            # Regroup the flat [s, a, r, s, a, r, ...] list into (s, a, r) triples.
            steps = [tuple(m[i:i + 3]) for i in range(0, len(m), 3)]
            rewards = [r for _, _, r in steps]
            for i, (st, action, _) in enumerate(steps):
                # G_t = sum_{j >= t} gamma^(j - t) * r_j (the exponent used to be j - t + 1).
                gt = np.sum([r * gamma ** (j - i) for j, r in enumerate(rewards[i:], start=i)])
                gts.append(gt)
                at = np.array([float(action == a) for a in action_space])  # one-hot action
                yeta_t = gamma ** i
                b = np.mean(gts)
                agent.backward_at_t(lr, gt, b, at, st, yeta_t)

        self.agent = agent

    def pdf_to_action_table(self):
        """Return ``[g0, g1]``: Grids of p(a='h' | s) for UA=0 and UA=1.

        Every (dealer face, player sum) cell is filled; previously only the
        UA=0 grid was. Requires ``start()`` to have run.
        """
        action_table = [Grid(), Grid()]  # index 0: UA=0, index 1: UA=1
        for UA, grid in enumerate(action_table):
            for i in grid.row_indexes:
                for j in grid.col_indexes:
                    # Output 0 is 'h' because action_space is ['h', 's'].
                    prob_hit = self.agent.pdf([UA, i, j])[0][0]
                    grid.updateValue(i, j, prob_hit)
        return action_table

In [None]:
# Train the policy-gradient agent for M episodes, then tabulate and plot
# p(hit | state) for hands without a usable ace (UA = 0).
with tf.device('/GPU:0'):
    M = 100  # number of training episodes
    sim = nn_mc(M)
    sim.start()

    AT_hit = sim.pdf_to_action_table()  # list of Grids; index 0 is UA = 0
    no_ace_table = AT_hit[0]
    print(no_ace_table.grid)
    no_ace_table.plotGrid()


In [None]:
# Display the raw p(hit | state) values for hands without a usable ace (UA = 0).
AT_hit[0].grid