# Basic implementation of MCTS for Classifier Chains

---
Importing modules to generate data

In [4]:
import numpy as np
np.set_printoptions(precision=3)
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split

### Generate some data

In [13]:
# --- Synthetic multi-label data ---------------------------------------------
# NOTE: for make_multilabel_classification, `n_labels` is the *average* number
# of labels per sample, not the number of label columns. The column count is
# `n_classes`, left at sklearn's default here, which is why Y_train is printed
# with 5 columns.
n_samples = 1000
n_features = 6
n_labels = 3
random_state = 0

X, Y = make_multilabel_classification(n_samples=n_samples, n_features=n_features,
                                      n_labels=n_labels, random_state=random_state)

# Hold out 20% of the rows for evaluation, reusing the same seed.
test_size = 0.2
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=test_size,
    random_state=random_state,
)

print(f"{X_train.shape = }\n{X_train}")
print(f"{Y_train.shape = }\n{Y_train}")

X_train.shape = (800, 6)
[[14. 16.  4.  8.  7. 10.]
 [ 9.  9.  7.  6.  9.  5.]
 [ 7.  6. 19.  3.  5.  7.]
 ...
 [12.  9.  7. 14.  6.  8.]
 [ 6.  4.  3. 12.  7.  6.]
 [13.  9.  1. 15.  9.  4.]]
Y_train.shape = (800, 5)
[[1 1 1 1 0]
 [0 0 0 1 0]
 [0 0 1 1 1]
 ...
 [1 0 0 1 0]
 [0 1 1 0 0]
 [1 0 0 0 0]]


In [7]:
# Display the training feature matrix (the notebook renders it below).
X_train

array([[14., 16.,  4.,  8.,  7., 10.],
       [ 9.,  9.,  7.,  6.,  9.,  5.],
       [ 7.,  6., 19.,  3.,  5.,  7.],
       ...,
       [12.,  9.,  7., 14.,  6.,  8.],
       [ 6.,  4.,  3., 12.,  7.,  6.],
       [13.,  9.,  1., 15.,  9.,  4.]])

In [None]:
from sklearn.multioutput import ClassifierChain
from sklearn.linear_model import LogisticRegression

# Base estimator used for every label of the chain.
solver = "liblinear"
base = LogisticRegression(solver=solver)

# ClassifierChain wraps `base` and fits one model per label when trained.
chain = ClassifierChain(base)

# Training is disabled for now; uncomment to fit on the split above:
# chain = chain.fit(X_train, Y_train)

# Monte Carlo Tree Search Pseudocode:

```
function MonteCarloTreeSearch(root)
    while time_budget_not_exceeded:
        node_to_expand = selectNodeToExpand(root)
        simulation_result = simulateRandomPlayout(node_to_expand)
        backpropagate(simulation_result, node_to_expand)
    return bestChild(root)
```

## Below is a class that implements the stopping constraints (a time budget and an iteration budget) of the search loop.

In [6]:
import time
class Constraint:
    """Budget guard for an anytime loop written as ``while constraint: ...``.

    Two independent budgets can be enabled:

    * ``time``: truthy until ``d_time`` seconds after the last ``reset()``
      (the constructor calls ``reset()``).
    * ``max_iter``: truthy for the first ``n_iter`` evaluations.

    ``bool(self)`` is True only while every enabled budget still has room.
    With ``verbose=True`` a message naming the exhausted budget is printed
    each time the check fails.
    """

    def __init__(self, time=False, max_iter=False, d_time: float = 1., n_iter: int = 100, verbose=False) -> None:
        # NOTE: the `time` parameter shadows the time module inside __init__
        # only; the other methods still resolve `time` to the module.
        assert time or max_iter, f"At least {time=} or {max_iter=} should be True"
        assert (not max_iter) or (isinstance(n_iter, int) and n_iter > 0), f"{n_iter=} should be positive if {max_iter=}"
        assert (not time) or d_time > 0, f"{d_time=} should be positive if {time=}"

        # Which budgets are active, and how large they are.
        self.time = time
        self.max_iter = max_iter
        self.d_time = d_time
        self.n_iter = n_iter

        # Deadline and check counter; both are (re)initialised by reset().
        self.end_time = None
        self.curr_iter = None
        self.reset()

        self.v = verbose

    def reset(self) -> None:
        """Restart both budgets: new deadline from now, counter back to -1."""
        self.curr_iter = -1
        self.end_time = time.time() + self.d_time

    def _bool_time(self) -> bool:
        """True when the time budget is disabled or its deadline has not passed."""
        if not self.time:
            return True
        return time.time() <= self.end_time

    def _bool_iter(self) -> bool:
        """Count this check, then report whether the iteration budget allows it."""
        self.curr_iter += 1  # advanced on every call, even when disabled
        if not self.max_iter:
            return True
        return self.curr_iter < self.n_iter

    def __bool__(self) -> bool:
        if not self.v:
            # Short-circuit: once time is up, the counter stops advancing.
            return self._bool_time() and self._bool_iter()

        # Verbose mode evaluates both budgets (so the counter always advances)
        # and then reports the exhausted one, time taking precedence.
        time_ok = self._bool_time()
        iter_ok = self._bool_iter()
        if time_ok and iter_ok:
            return True
        if not time_ok:
            print(f"Time Constraint Attained. Current iteration: {self.curr_iter:_}/{self.n_iter:_}")
        else:
            print(f"Iteration Constraint Attained. Time left: {self.end_time - time.time():.3f}/{self.d_time}s")
        return False

## Below is a class that represents the nodes in the MCTS algorithm

In [7]:
class MCTSNode:
    """A node of the Monte Carlo search tree.

    A node with ``rank == 0`` is terminal. Presumably ``rank`` is the number
    of labels still to decide below this node (MCTS builds its root with
    ``rank=n_label``) — TODO confirm. ``children`` stays ``None`` until the
    node is expanded, after which it is expected to hold ``n_children`` nodes.
    """

    def __init__(self, rank: int = 2, n_children: int = 2):
        # Search statistics; presumably accumulated reward and visit count,
        # updated by the search — TODO confirm.
        self.val = 0
        self.visit_count = 0

        self.parent = None
        self.children = None  # None means "not expanded yet"
        self.n_children = n_children
        self.rank = rank

    def __getitem__(self, key: int) -> "MCTSNode":
        """Return child number ``key`` so callers can write ``node[i]``.

        (The previous hook was named ``__get__``, which is the descriptor
        protocol, not subscription, and referenced an undefined global
        ``n_children``.)

        Raises:
            IndexError: if ``key`` is outside ``[0, n_children)``. IndexError
                also lets ``for child in node`` stop cleanly.
        """
        if not 0 <= key < self.n_children:
            raise IndexError(f"{key} is not a valid key.")
        return self.children[key]

    def is_terminal(self) -> bool:
        """Return True when no decision remains below this node (rank 0)."""
        return self.rank == 0

    def is_expanded(self) -> bool:
        """Return True once children have been attached to this node."""
        # Was `self.children == None`, which reported the opposite.
        return self.children is not None

# MCTS needs several helper routines; some of them are implemented below.

In [8]:
def randmax(A, key=None):
    """Return the index of a maximal element of ``A``, breaking ties at random.

    Args:
        A: sequence of candidates.
        key: optional callable mapping an element to the score to maximise.
            Defaults to reading the element's ``proba`` attribute, as before.
            NOTE(review): MCTSNode in this notebook defines no ``proba``
            attribute; pass ``key=`` or add it — confirm the intended score.

    Returns:
        The index (a Python ``int``) of one maximal element, chosen uniformly
        among ties.

    Raises:
        ValueError: if ``A`` is empty (propagated from ``max``).
    """
    scores = [x.proba if key is None else key(x) for x in A]
    # Compare scores with the best *score*. The old code compared each
    # `.proba` with the best *element*, so the tie list was always empty.
    best = max(scores)
    ties = [i for i, s in enumerate(scores) if s == best]
    return int(np.random.choice(ties))

def eps_greedy(node: "MCTSNode", eps: float = 0.1) -> int:
    """Pick the index of one of ``node.children`` with an epsilon-greedy rule.

    With probability ``eps`` a uniformly random child index is returned
    (exploration). Otherwise the index returned by ``randmax`` over the
    children is returned (exploitation).

    Args:
        node: an expanded node; its ``children`` must be a non-empty sequence.
        eps: exploration probability in [0, 1].

    Returns:
        An index into ``node.children``.
    """
    # `np.random.rand` must be *called*. Comparing the function object itself
    # with a float raises TypeError.
    if np.random.rand() < eps:  # explore
        # Return an index, not the child object, so callers can use it as
        # node.children[i].
        return int(np.random.randint(len(node.children)))
    return randmax(node.children)  # exploit

def select(node, selector) -> "MCTSNode":
    """Descend from ``node`` to the leaf to expand or simulate next.

    While the current node has children and is not terminal, move to the
    child whose index ``selector(current_node)`` returns.

    Args:
        node: starting node (usually the root).
        selector: callable taking a node and returning a child index,
            e.g. ``eps_greedy``.

    Returns:
        The first node reached that is unexpanded (``children is None``) or
        terminal.
    """
    # Inspect `children` directly so the loop relies only on that attribute
    # and `is_terminal()`. Subscripting the node itself (`node[ind]`) failed
    # because MCTSNode was not subscriptable.
    while node.children is not None and not node.is_terminal():
        node = node.children[selector(node)]
    return node

def back_prog(node, reward):
    """Propagate a simulation ``reward`` from ``node`` up to the root.

    Every node on the path ``node -> node.parent -> ... -> root`` gets its
    ``visit_count`` incremented by one and ``reward`` added to its ``val``.
    (Previously this was a stub that only checked its precondition.)

    Args:
        node: the terminal node where the simulation ended.
        reward: value to add at every node on the path.

    Raises:
        AssertionError: if ``node`` is not terminal.
    """
    assert (node.is_terminal()), f"The node should be terminal to back-propagate. Node rank={node.rank}"
    # The root's parent is None, which ends the walk.
    while node is not None:
        node.visit_count += 1
        node.val += reward
        node = node.parent

In [11]:
def MCTS(root=None, n_label: int = 4):
    """Driver loop of the search (currently a placeholder).

    Builds a root node of rank ``n_label`` when none is given. It then spins
    until a verbose budget of 1 second of wall time or 10_000_000 checks runs
    out. The select / simulate / back-propagate steps are not wired in yet,
    so nothing is returned.
    """
    root = MCTSNode(rank=n_label) if root is None else root

    budget = Constraint(time=True, d_time=1, max_iter=True,
                        n_iter=10_000_000, verbose=True)

    # TODO: per-iteration work, once `simulate` and `bestChild` exist:
    #     node = select(root, selector)
    #     reward = simulate(node)
    #     back_prog(node, reward)
    # then `return bestChild(root)` after the loop.
    while budget:
        pass

# Run the placeholder search loop. With the budget configured inside MCTS it
# spins for about 1 second (or 10_000_000 checks) and then prints which
# constraint stopped it.
MCTS()

Time Constraint Attained. Current iteration: 2_233_809/10_000_000
