In [1]:
import pandas as pd
import json
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

In [2]:
# Build a supervised (state -> action) dataset from an MCTS search-tree dump.
def load_data(filename, sep='\t', visit_threshold=5):
    """Return one (features, action) row per node that has a well-visited child.

    For every node (row) in the tree dump, its children are the rows whose
    ``Parent_Name`` equals the node's ``Name``. Children with fewer than
    ``visit_threshold`` visits are ignored; among the remaining children the
    one with the highest ``Value`` is selected (ties go to the child that
    appears first in the file; NaN values are ignored). The selected child's
    ``Game_Features`` (a JSON object) and its ``Action_Name`` become one
    output row.

    Parameters
    ----------
    filename : str or path-like
        Delimited file with at least the columns ``Name``, ``Parent_Name``,
        ``Visits``, ``Value``, ``Action_Name`` and ``Game_Features``.
    sep : str, default tab
        Field delimiter passed to ``pd.read_csv``.
    visit_threshold : int, default 5
        Minimum ``Visits`` for a child to be considered.

    Returns
    -------
    pd.DataFrame
        Column ``Action`` followed by one column per key of the *first* row's
        ``Game_Features``, in that key order. Features missing from a selected
        child are NaN; feature keys absent from the first row are ignored.
        An empty input file yields an empty frame with only ``Action``.
    """
    df = pd.read_csv(filename, sep=sep)
    if df.empty:
        return pd.DataFrame({"Action": []})

    # The output schema is taken from the first row's feature dict.
    columns = ["Action", *json.loads(df.iloc[0]['Game_Features'])]

    # Select the best eligible child of every parent in a single pass instead
    # of re-filtering the whole frame once per row (which was O(n^2)).
    # Rows without a parent or without a Value can never be selected.
    eligible = df[df['Visits'] >= visit_threshold].dropna(
        subset=['Parent_Name', 'Value']
    )
    # Stable descending sort + keep='first' reproduces nlargest(keep='first'):
    # on ties, the child appearing earliest in the file wins.
    best = (
        eligible.sort_values('Value', ascending=False, kind='stable')
        .drop_duplicates(subset='Parent_Name', keep='first')
    )
    best_child = {
        parent: (action, features_json)
        for parent, action, features_json in zip(
            best['Parent_Name'], best['Action_Name'], best['Game_Features']
        )
    }

    records = []
    # Walk nodes in file order so the output order follows the input rows
    # (a Name that appears on several rows yields one record per occurrence).
    for name in df['Name']:
        if name not in best_child:
            continue
        action, features_json = best_child[name]
        record = json.loads(features_json)
        record['Action'] = action
        records.append(record)

    # Fixed columns: missing features become NaN instead of misaligning lists.
    return pd.DataFrame(records, columns=columns)

In [3]:
# Load every MCTS run log and stack the extracted (state, best action) rows.
# NOTE(review): files are read with load_data's default tab separator despite
# the .csv extension — presumably they are TSV; confirm.
N_RUNS = 18
# NOTE(review): run 3 is excluded but the reason is not recorded here —
# TODO document why (corrupt/incomplete file?) or drop the exclusion.
SKIPPED_RUNS = {3}

dfs = [
    load_data(f"MCTS_test_{i}.csv")
    for i in range(N_RUNS)
    if i not in SKIPPED_RUNS
]

# ignore_index=True: each per-file frame has its own 0..k RangeIndex, so a
# plain concat would produce duplicate index labels across runs.
df = pd.concat(dfs, ignore_index=True)
df

Unnamed: 0,Action,SCORE,SCORE_ADV,ORDINAL,OUR_TURN,HAS_WON,FINAL_ORD,ROUND,PROTECTED,HIDDEN,...,COUNTESS_DISCARD,PRINCESS_DISCARD,GUARD_OTHER,PRIEST_OTHER,BARON_OTHER,HANDMAID_OTHER,PRINCE_OTHER,KING_OTHER,COUNTESS_OTHER,PRINCESS_OTHER
0,Baron - compare the cards with player 1,0.2,0.4,0.5,1.0,0.0,0.0,0.1,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Draw a card and remove protection status.,0.0,0.0,0.5,1.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Draw a card and remove protection status.,0.0,0.0,0.5,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Draw a card and remove protection status.,0.2,0.4,0.5,1.0,0.0,0.0,0.1,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Baron - compare the cards with player 1,0.2,0.4,0.5,1.0,0.0,0.0,0.1,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13,Priest - see the cards of player 0,1.2,1.6,0.5,1.0,0.0,0.0,0.8,0.0,2.0,...,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
14,Baron - compare the cards with player 0,1.4,1.6,0.5,1.0,1.0,0.5,0.9,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
15,Draw a card and remove protection status.,1.2,1.6,0.5,1.0,0.0,0.0,0.8,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
16,Draw a card and remove protection status.,1.2,1.6,0.5,1.0,0.0,0.0,0.8,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
# The label is the chosen action; every other column is a feature.
y = df['Action']
X = df.drop(columns=['Action'])
features = X.columns

# Heuristic: cap the tree at one leaf per distinct action (class) observed.
# dropna=False keeps NaN counted as a value, matching len(unique()).
max_leaf = y.nunique(dropna=False)

In [5]:
# Hold out 25% of rows (scikit-learn's default) for evaluation; the fixed
# seed makes the split reproducible across Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

In [6]:
# Fit a leaf-limited decision tree; random_state fixes the estimator's
# internal randomness so the fitted tree is reproducible.
clf = DecisionTreeClassifier(
    max_leaf_nodes=max_leaf,
    random_state=42,
).fit(X_train, y_train)
clf

DecisionTreeClassifier(max_leaf_nodes=24, random_state=42)

In [7]:
# Render the fitted tree as indented text rules; print (not display) so the
# embedded newlines are shown as line breaks.
text_representation = tree.export_text(
    clf,
    feature_names=features.tolist(),
)
print(text_representation)

|--- DRAW_DECK <= 0.59
|   |--- PROTECTED <= 0.50
|   |   |--- FINAL_ORD <= 0.25
|   |   |   |--- CARDS <= 0.11
|   |   |   |   |--- COUNTESS_DISCARD <= 0.50
|   |   |   |   |   |--- HANDMAID_DISCARD <= 0.25
|   |   |   |   |   |   |--- ROUND <= 0.90
|   |   |   |   |   |   |   |--- class: Baron - compare the cards with player 1
|   |   |   |   |   |   |--- ROUND >  0.90
|   |   |   |   |   |   |   |--- class: King - trade hands with player 1
|   |   |   |   |   |--- HANDMAID_DISCARD >  0.25
|   |   |   |   |   |   |--- class: Guard - guess if player 1 holds card Priest
|   |   |   |   |--- COUNTESS_DISCARD >  0.50
|   |   |   |   |   |--- class: Countess - needs to be discarded if the player also holds King or Prince
|   |   |   |--- CARDS >  0.11
|   |   |   |   |--- CARDS <= 0.89
|   |   |   |   |   |--- GUARD <= 0.50
|   |   |   |   |   |   |--- CARDS <= 0.23
|   |   |   |   |   |   |   |--- class: Priest - see the cards of player 1
|   |   |   |   |   |   |--- CARDS >  0.23
|   | 