# Deep Reinforcement Learning–Based Chess Bot
## The bot learns chess from scratch, solely by playing against itself.


## <center>Let's Begin The Journey !!!</center>
<center>🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉🕉</center>

# Importing Dependencies

In [None]:
!pip install python-chess

In [None]:
import numpy as np
import chess
import tensorflow as tf
import pandas as pd
from keras import Model, optimizers
from keras.models import load_model
from keras.layers import *
from keras.utils import plot_model
import sys
import os
import gc
import json
from google.colab import drive
from collections import deque
import multiprocessing
import threading
import requests
import time
import tempfile
import traceback
import warnings
# Mount Google Drive: the code below reads the server URL from, and writes
# exploration backups and logs to, /content/gdrive/My Drive/CHESS-AI.
drive.mount('/content/gdrive')

# Creating the Model Class

In [None]:
class Model_handler :
  """Client for a remotely trained network, plus position <-> tensor encoding.

  The training server's base URL is read from a file on the mounted Google
  Drive. The handler registers a model id with the server, downloads the
  latest ``.keras`` model for local inference, submits training batches and
  converts positions into the network's input tensors.

  A position ("state") is a "_"-joined string of five board FENs (placement
  part only): four history boards, oldest first, then the current board.
  History entries may be empty strings.
  """

  def __init__(self, id = "None") :
    """Register with the server and store the model id it answers with.

    Args:
      id: model id sent to the server's ``/ready-model/`` endpoint (the string
        "None" by default); the response body becomes ``self.id``.
    """
    self.server_url = self.get_server_url()
    self.session = requests.Session()
    self.id = self.send_request("/ready-model/", id, data = None).text
    self.is_updated_version = False

  def load_latest(self) :
    """Download the latest model from the server and load it for inference.

    ``tf.keras.models.load_model`` takes a path, so the payload is staged in a
    temporary ``.keras`` file. That file is removed once loading finishes so
    repeated model updates do not accumulate temp files.
    """
    response = self.send_request("/download-model/", self.id, data = None)
    # delete=False: the file must survive closing so load_model can reopen it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".keras") as temp_model_file:
      temp_model_file.write(response.content)
    try :
      self.model = tf.keras.models.load_model(temp_model_file.name, compile = False)
    finally :
      os.remove(temp_model_file.name)
    self.is_updated_version = True

  def check_update_status(self, acknowledge_id) :
    """Return True once the server reports job ``acknowledge_id`` as "Done"."""
    response = self.send_request("/check-update-status/", acknowledge_id, data = None)
    return response.text == "Done"

  def predict(self, state, color) : #as per the network
    """Evaluate a single position.

    Args:
      state: position string (see class docstring).
      color: side to move (True = white).

    Returns:
      ``(value, moves)``: the network's first output (value) as a Python
      float, and the legal moves in UCI notation sorted by policy score,
      highest first.
    """
    prediction = self.model(self.convert_state_to_input(state, color))
    return self.convert_output_to_probs(state, color, prediction[0].numpy()[0][0], prediction[1][0])

  def bulk_predict(self, arr) :
    """Set ``.value`` on every node in ``arr`` using batched forward passes.

    Args:
      arr: sequence of tuples whose first element is a node exposing
        ``state``, ``color`` and a writable ``value`` attribute.
    """
    states = [x[0].state for x in arr]
    colors = [x[0].color for x in arr]
    batch_size = 64
    inputs, masks = self.convert_state_to_input(states, colors)
    for i in range(0, len(states), batch_size):
      value_predictions, _ = self.model(inputs = (inputs[i:i+batch_size], masks[i:i+batch_size]))
      # Convert the batch to NumPy once rather than once per element.
      batch_values = value_predictions.numpy()
      for j, value in enumerate(batch_values) :
        arr[i+j][0].value = value[0]

  def update_model(self, states, colors, values, actions,config) :
    """Submit one training batch to the server and return its acknowledgement id.

    Also marks the local model as stale (``is_updated_version = False``).
    """
    println("Model Update")
    inputs, masks = self.convert_state_to_input(states, colors)
    result = self.send_request("/train/", self.id, data = {
        "inputs": inputs.tolist(),
        "masks": masks.tolist(),
        "values": values,
        "actions": actions,
        "config": config
      })
    self.is_updated_version = False
    return result.text

  def get_server_url(self, path = None) :
    """Return the training server's base URL.

    Args:
      path: file holding the URL; defaults to ``My Drive/CHESS-AI/url.txt`` on
        the Colab Drive mount.

    Surrounding whitespace (e.g. a trailing newline) is stripped because the
    value is concatenated directly into request URLs.
    """
    if path is None :
      path = os.path.join('/content/gdrive', 'My Drive', 'CHESS-AI', 'url.txt')
    with open(path, 'r') as drive_file :
      return drive_file.read().strip()

  def send_request(self, path, id, data = None) :
    """Call ``server_url + path + id`` and retry until the server answers 200.

    Sends a GET when ``data`` is None, otherwise a POST with ``data`` as JSON.
    On a non-200 answer the URL is re-read from Drive and ``/ready-model/`` is
    polled every 10 s. Once it answers 200, the ready-model response itself is
    returned for ``/ready-model/`` requests; any other request is re-issued.
    On an exception the session and URL are refreshed and the request is
    retried after a short pause. This method never gives up.
    """
    while True :
      try :
        if data is None :
          response = self.session.get(url = self.server_url + path + id)
        else :
          response = self.session.post(url = self.server_url + path + id, json = data)
        if response.status_code == 200:
          return response
        while True :
          self.server_url = self.get_server_url()
          println("Retrying request with URL "+self.server_url)
          if path == "/ready-model/" :
            temp = self.session.get(url = self.server_url+"/ready-model/"+id)
            if temp.status_code == 200 :
              return temp
          else :
            temp = self.session.get(url = self.server_url+"/ready-model/"+self.id)
            if temp.status_code == 200 :
              break  # server is back: re-issue the original request
          time.sleep(10)
      except Exception as e:
        print(e)
        self.session = requests.Session()
        self.server_url = self.get_server_url()
        # Avoid a hot retry loop (and log flood) while the server is down.
        time.sleep(1)

  def _encode_state(self, state, color) :
    # Encode one position string into (8, 8, 60) piece planes (oldest board
    # first) and the current board's (8, 8, 10) legal-move mask.
    boards = state.split("_")
    planes, mask = self.convert_board_to_input(boards[-1], color)
    for i in range(1,5) :
      planes = np.concatenate([self.convert_board_to_input(boards[-1-i], color, False), planes], axis = 2)
    return planes, mask

  def convert_state_to_input(self, state, color) :
    """Encode one or many position strings into network inputs.

    Args:
      state: a position string, or a sequence of them.
      color: side to move for ``state``; a sequence aligned with ``state``
        when ``state`` is a sequence.

    Returns:
      ``(inputs, masks)`` of shapes (N, 8, 8, 60) and (N, 8, 8, 10); N is 1
      for a single string.
    """
    if isinstance(state, str) :
      planes, mask = self._encode_state(state, color)
      return (np.expand_dims(planes, axis = 0), np.expand_dims(mask, axis = 0))
    inputs = np.empty(shape = (len(state),8,8,60))
    masks = np.empty(shape = (len(state),8,8,10))
    for s, single_state in enumerate(state) :
      inputs[s], masks[s] = self._encode_state(single_state, color[s])
    return (inputs, masks)

  def convert_board_to_input(self, state, color, current = True) :
    """Encode one board FEN.

    Args:
      state: board-placement FEN ("" allowed for a missing history board when
        ``current`` is False).
      color: side to move.
      current: True for the current board, False for a history board.

    Returns:
      For the current board, ``(planes, mask)``. ``planes`` is (8, 8, 12) with
      1 on occupied squares and 0.5 on legal destination squares in the moving
      piece's channel. ``mask`` is (8, 8, 10) marking legal moves in policy
      layout. For history boards, only the (8, 8, 12) occupancy planes.
    """
    if current :
      board = chess.Board(state)
      board.turn = color
      arr = np.zeros((8,8,12), dtype = np.float16)
      arr2 = np.zeros((8,8,10), dtype = np.float16)
      piece_to_value = self.get_piece_to_value(color)
      piece_to_value2 = self.get_piece_to_value(color, False)
      for i in range(64) :
        piece = board.piece_at(i)
        if piece is not None :
          arr[i//8,i%8,piece_to_value[piece.symbol()]] = 1
      for move in board.legal_moves :
        square = move.to_square
        symbol = board.piece_at(move.from_square).symbol()
        arr[square//8, square%8, piece_to_value[symbol]] = 0.5
        if move.promotion is not None :
          # Promotions: row = promoted piece type - 2 (knight..queen), column = file.
          arr2[move.promotion-2, move.from_square%8, 9] = 1
        else :
          arr2[square//8, square%8, piece_to_value2[symbol]] = 1
          # NOTE(review): this bumps the knight/bishop/rook channel after that
          # symbol's FIRST legal move. Only that one move uses the base
          # channel; every later move of that symbol uses the +1 channel.
          # convert_action_to_softmax instead picks +1 when another same-symbol
          # piece sits on a lower square, so the two can disagree. It is left
          # unchanged because changing it alters the trained model's layout.
          if piece_to_value2[symbol] in (1, 3, 5) :
            piece_to_value2[symbol] += 1
      return (arr, arr2)
    else :
      arr = np.zeros((8,8,12), dtype = np.float16)
      if len(state) == 0 :
        return arr
      board = chess.Board(state)
      board.turn = color
      piece_to_value = self.get_piece_to_value(color)
      for i in range(64) :
        piece = board.piece_at(i)
        if piece is not None :
          arr[i//8,i%8,piece_to_value[piece.symbol()]] = 1
      return arr

  def get_piece_to_value(self, color, inp = True) :
    """Return a piece-symbol -> channel map, relative to the side ``color``.

    With ``inp`` True: the 12 input planes, own pieces in 0-5 and opponent
    pieces in 6-11. Otherwise: base policy channels for own pieces only.
    Knight, bishop and rook each own two consecutive channels (1-2, 3-4, 5-6);
    channel 9 is used for promotions by the encoders.
    """
    if inp :
      if(color == 1) :
        return {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n':7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
      return {
        'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
        'P': 6, 'N':7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11
      }
    else :
      if(color == 1) :
        return {
        'P': 0, 'N': 1, 'B': 3, 'R': 5, 'Q': 7, 'K': 8
        }
      return {
        'p': 0, 'n': 1, 'b': 3, 'r': 5, 'q': 7, 'k': 8
      }

  def convert_action_to_softmax(self, states, colors, moves) :
    """Encode each (state, color, UCI move) as a flattened 8*8*10 one-hot policy target.

    Returns:
      A list of 640-element Python lists, one per move.
    """
    array = []
    for i, state in enumerate(states) :
      board = chess.Board(state.split("_")[-1])
      board.turn = colors[i]
      move = chess.Move.from_uci(moves[i])
      to_square = move.to_square
      from_square = move.from_square
      piece_to_value = self.get_piece_to_value(colors[i],False)
      symbol = board.piece_at(from_square).symbol()
      piece_type = piece_to_value[symbol]
      arr = np.zeros((8,8,10), dtype = np.float64)
      if piece_type in (1, 3, 5) :
        # Use the second channel when another piece with the same symbol
        # stands on a lower-numbered square.
        has_twin_below = any(
          board.piece_at(sq) is not None and board.piece_at(sq).symbol() == symbol
          for sq in range(from_square))
        channel = piece_type + 1 if has_twin_below else piece_type
        arr[to_square//8, to_square%8, channel] = 1
      elif move.promotion is not None :
        arr[move.promotion-2, from_square%8, 9] = 1
      else :
        arr[to_square//8, to_square%8, piece_type] = 1
      array.append(arr.flatten().tolist())
    return array

  def convert_output_to_probs(self, state, color, value_output, policy_output) :
    """Turn raw network outputs into ``(value, legal UCI moves ranked by policy score)``.

    Uses the same channel assignment as the legal-move mask in
    ``convert_board_to_input``.
    """
    policy = tf.reshape(policy_output, [8,8,10])
    board = chess.Board(state.split("_")[-1])
    board.turn = color
    piece_to_value = self.get_piece_to_value(color,False)
    move_dict = {}
    for move in board.legal_moves :
      to_square = move.to_square
      from_square = move.from_square
      symbol = board.piece_at(from_square).symbol()
      piece_type = piece_to_value[symbol]
      if move.promotion is not None :
        move_dict[move.uci()] = policy[move.promotion-2, from_square%8, 9]
      else :
        move_dict[move.uci()] = policy[to_square//8, to_square%8, piece_type]
        if piece_type in (1, 3, 5) :
          piece_to_value[symbol] += 1
    ranked = sorted(move_dict.items(), key = lambda item: item[1], reverse = True)
    return (value_output.item(), [uci for uci, _ in ranked])

# Creating the Environment Class

In [None]:
class Env :
  """Self-play chess environment around a python-chess board.

  Besides the board it keeps the placement FENs of the previous four
  positions, so ``get_curr_state`` can return the "_"-joined five-board
  history string used as a position key and network input. It also shapes
  per-move rewards.
  """
  def __init__(self,board = None, turn = True) :
    """Create an environment.

    Args:
      board: None for the standard starting position, a ``chess.Board``
        (copied, so the caller's board is never mutated), or a history string
        as produced by ``get_curr_state`` (empty entries allowed).
      turn: side to move (True = white).
    """
    self.game_on = True
    self.history = deque(["", "", "", ""], maxlen = 4)
    if board is None :
      self.board = chess.Board()
    elif not isinstance(board, str) :
      self.board = board.copy()
    else :
      states = board.split("_")
      for i in range(4):
        if(len(states[i])>0) :
          self.history.append(states[i])
      if(len(states[-1])>0) :
        self.board = chess.Board(states[-1])
      else :
        self.board = chess.Board()
    self.turn = turn
    self.board.turn = turn
    self.repeat_penalty_w = 0
    self.repeat_penalty_b = 0

  def reset(self,agent_w,agent_b) :
    """Return to the starting position and clear both agents' episode lists."""
    self.board = chess.Board()
    self.game_on = True
    # Restore the four "" placeholders (not just clear()) so get_curr_state
    # keeps producing five "_"-separated fields.
    self.history = deque(["", "", "", ""], maxlen = 4)
    self.turn = self.board.turn
    self.repeat_penalty_w = 0
    self.repeat_penalty_b = 0
    agent_w.episodes = []
    agent_b.episodes = []

  def print(self) :
    """Render the board (assumes an IPython/Colab session providing ``display``)."""
    display(self.board)

  def get_legal_moves(self, is_uci = False) :
    """Return the legal moves, as UCI strings when ``is_uci`` else as ``chess.Move``."""
    if(is_uci) :
      return [move.uci() for move in self.board.legal_moves]
    return list(self.board.legal_moves)

  def get_curr_state(self) :
    """Return "h1_h2_h3_h4_current": the history boards (oldest first) and the current board."""
    return "_".join([*self.history, self.board.board_fen()])

  def do_action(self,action,agent_color) :
    """Play UCI ``action`` for ``agent_color``.

    Returns:
      ``(immediate reward for the mover, new state string)``.
    """
    self.history.append(self.board.board_fen())
    pieces_before = self.board.piece_map()
    self.board.push_uci(action)
    self.turn = not agent_color
    return (self.get_immediate_reward(agent_color, pieces_before), self.get_curr_state())

  def get_immediate_reward(self,color, map) :
    """Return +1/-1/0 for a win/loss/draw, or the shaped reward if the game goes on.

    Sets ``game_on`` to False when the game has ended (draws may be claimed).
    """
    outcome = self.board.outcome(claim_draw = True)
    if(outcome is None) :
      return self.evaluation_fx(color, map)
    println(outcome.termination)
    self.game_on = False
    winner = outcome.winner
    if(winner is None) :
      return 0
    if winner == color :
      return 1
    return -1

  def evaluation_fx(self, color, map) :
    """Shaped reward for the mover ``color`` after a non-terminal move.

    The reward is mobility of the mover's non-king pieces, plus a capture
    bonus (captured piece value / 10), minus an accumulated penalty for moving
    the same piece again or straight back.

    Args:
      color: the side that just moved.
      map: ``piece_map()`` from before the move.
    """
    board = chess.Board(self.board.board_fen())
    board.turn = color
    new_map = self.board.piece_map()
    capture = 0  ## reward for capturing
    if len(map) != len(new_map) :
      from collections import Counter
      piece_value = [1, 3, 3, 5, 9, 0]
      # Count opponent piece types before/after; the one that vanished was
      # captured. Own pieces are skipped so a promoting pawn is not counted.
      lost = (Counter(p.piece_type for p in map.values() if p.color != color)
              - Counter(p.piece_type for p in new_map.values() if p.color != color))
      capture = sum(piece_value[piece_type-1] * n for piece_type, n in lost.items())/10
    moves = [0,0,0,0,0]  ## legal-move counts for pawn..queen (king excluded)
    for move in board.legal_moves :
      piece = board.piece_type_at(move.from_square)-1
      if(piece<5) :
        moves[piece] += 1
    prev_moves = self.board.move_stack
    if len(prev_moves)>=3 :
      # prev_moves[-1] is this move, [-3] the mover's previous move.
      if prev_moves[-3].from_square==prev_moves[-1].to_square and prev_moves[-1].from_square==prev_moves[-3].to_square :
        # moved straight back
        if color :
          self.repeat_penalty_w += 0.35
        else :
          self.repeat_penalty_b += 0.35
      elif prev_moves[-3].to_square==prev_moves[-1].from_square :
        # moved the same piece again
        if color :
          self.repeat_penalty_w += 0.32
        else :
          self.repeat_penalty_b += 0.32
      else :
        if color :
          self.repeat_penalty_w = 0
        else :
          self.repeat_penalty_b = 0

    # Divisors presumably approximate per-type maximum mobility (TODO confirm).
    base = (moves[0]/16 + moves[1]/15 + moves[2]/24 + moves[3]/26 + moves[4]/24)/4+capture
    if color :
      return base-self.repeat_penalty_w
    else :
      return base-self.repeat_penalty_b

# Creating the Agent Class

In [None]:
class Agent :
  """One self-play side: plans each move with TD tree search, then trains the shared model.

  Constructor args:
    config: hyper-parameter dict (see the training configuration cell).
    model: the ``Model_handler`` shared by both sides.
    color: True for white, False for black.
  """
  def __init__(self,config, model, color) :
    self.config = config
    self.model = model
    self.color = color
    self.episodes = []

  def get_action(self, state, color, ep_no, step_no, start_iter, experience, saved_tree, seed, logs) :
    """Search from ``state``, train the model on the results and return a move.

    Steps:
      1. Run ``TD_Tree_Search``, resuming from ``start_iter``, ``experience``
         and ``saved_tree`` when given.
      2. Write the collected experience to this seed's log CSV on Drive.
      3. Submit the experience to the server and wait until it reports the
         update done.
      4. Reload the model and pick a move from the new policy with
         ``exploit_epsilon`` exploration.

    Returns:
      The chosen move in UCI notation.

    Raises:
      Whatever the search raised; nothing is sent for training in that case.
    """
    try :
      tree = TD_Tree_Search(self.model, self, state, color, ep_no, step_no, start_iter, experience, seed)
      experience = tree.traverse(saved_tree)
    except Exception :
      # Never continue with partial search data; it would be sent for training
      # and ``tree`` may not exist. Report and propagate.
      traceback.print_exc()
      raise
    states = []
    colors = []
    values = []
    moves = []
    for key, value in experience.items() :
      data = key.split("-")  # "state-color-move"
      states.append(data[0])
      # The colour was stored via str(bool). Compare instead of eval(): keys can
      # come from a backup JSON on Drive.
      colors.append(data[1] == "True")
      moves.append(data[2])
      values.append(value)
    # NOTE(review): ``logs`` is rebound locally, so the caller's DataFrame never
    # receives these rows. Each save writes the caller's frame plus this
    # step's rows only. Confirm whether accumulation across steps was intended.
    if states :
      logs = pd.concat([logs, pd.DataFrame({'Step_No': [step_no] * len(states),
                                            'States': states,
                                            'Colors': colors,
                                            'Moves': moves,
                                            'Values' : values,
                              })], ignore_index= True)
    acknowlegement_id = self.update_model(states, colors, values, self.model.convert_action_to_softmax(states, colors, moves))
    println(("Update Request sent with acknowledgement No. -", acknowlegement_id))
    logs.to_csv(os.path.join('/content/gdrive', 'My Drive', 'CHESS-AI', self.model.id, "Exploration_Backup", "logs_"+str(seed)+".csv"))
    while not self.model.check_update_status(acknowlegement_id) :
      time.sleep(10)
    self.model.load_latest()
    println("Model Update completed")
    return tree.choose_actions(self.model.predict(state, color)[1],self.config["exploit_epsilon"])

  def update_model(self, states, colors, values, actions) :
    """Forward a training batch (and this agent's config) to the model handler."""
    return self.model.update_model(states, colors, values, actions, self.config)

# Creating the Trainer Class

In [None]:
class Trainer :
  """Run self-play episodes between two agents that share one model.

  Constructing a Trainer immediately runs the whole training loop (see
  ``start_training``). It resumes from this seed's Drive backup when one
  exists.
  """
  def __init__(self, agent_w, agent_b, n_episodes, seed, init_board = None, init_turn = True) :
    # Fresh default board per Trainer: a ``chess.Board()`` default argument
    # would be one object shared (and mutated) across all Trainers.
    self.init_board = chess.Board() if init_board is None else init_board
    self.agent_w = agent_w
    self.agent_b = agent_b
    self.n_eps = n_episodes
    self.init_turn = init_turn
    self.seed = seed
    self.logs = pd.DataFrame({'Step_No': [],
                                'States': [],
                                'Colors': [],
                                'Moves': [],
                                'Values' : [],
                              })
    self.start_training()

  def _backup_dir(self) :
    # Drive folder (created by the server) holding exploration backups and logs.
    return os.path.join('/content/gdrive', 'My Drive', 'CHESS-AI', self.id, "Exploration_Backup")

  def _fresh_init_state(self) :
    # Boards are copied so moves played in one episode cannot leak into the
    # next episode's start; history strings are immutable and returned as is.
    if isinstance(self.init_board, str) :
      return self.init_board
    return self.init_board.copy()

  def start_training(self) :
    """Play episodes ``start_eps`` .. ``n_episodes - 1``; the side to move picks each action."""
    if self.agent_w.model.id != self.agent_b.model.id :
      println("Error: Different Model IDs used for both agent use single model for both")
      return
    self.model = self.agent_w.model
    self.id = self.agent_w.model.id
    println("Training Initiated")
    start_eps, start_step, start_iter, experience, tree, init_state, init_turn = self.load_backup_experiences()
    for i in range(start_eps,self.n_eps) :
      self.env = Env(init_state, init_turn)
      step = start_step
      while(self.env.game_on) :
        state = self.env.get_curr_state()
        agent = self.agent_w if self.env.turn else self.agent_b
        action = agent.get_action(state,self.env.turn, i, step, start_iter, experience, tree, self.seed, self.logs)
        reward, _ = self.env.do_action(action, self.env.turn)
        step +=1
        println(("Ep_No.- ", i+1, "Move_No.- ", step, state, self.env.turn, reward))
        # Resume data applies only to the first move after a restart.
        start_iter = 0
        experience = {}
        tree = None
      self.env.reset(self.agent_w,self.agent_b)
      # Only the first (possibly resumed) episode may start mid-game.
      start_step = 0
      init_state = self._fresh_init_state()
      init_turn = self.init_turn

  def load_backup_experiences(self) :
    """Return resume data for this seed, or a fresh start when no backup exists.

    Blocks (polling every 5 s) until the server-created backup folder appears
    on Drive.

    Returns:
      ``(start_eps, start_step, start_iter, experience, tree, init_state, init_turn)``.
    """
    backup_dir = self._backup_dir()
    while not os.path.exists(backup_dir) :
      println(f'Directory created by server is missing in Google Drive this may be a synchronization issue in GDrive, retrying after 5 secs')
      time.sleep(5)
    file_path = os.path.join(backup_dir, "exploration_"+str(self.seed)+".json")
    logs_path = os.path.join(backup_dir, "logs_"+str(self.seed)+".csv")
    if not os.path.exists(file_path):
      self.logs.to_csv(logs_path)
      return (0, 0, 0, {}, None, self._fresh_init_state(), self.init_turn)
    # index_col=0: the CSV is written with its index. Without this, every
    # save/load round trip would add another "Unnamed: 0" column.
    self.logs = pd.read_csv(logs_path, index_col = 0)
    with open(file_path, 'r') as drive_file :
      explr_dict = json.load(drive_file)
    println(f'Backup Loaded, training resumed from Ep - {explr_dict["start_eps"]}, Step - {explr_dict["start_step"]} and Iteration - {explr_dict["start_iter"]}')
    return (explr_dict["start_eps"], explr_dict["start_step"], explr_dict["start_iter"], explr_dict["experience"], explr_dict["tree"], explr_dict["init_state"], explr_dict["init_turn"])

# Creating the TD_Tree_Search and TD_Search_Node Classes

In [None]:
class TD_Tree_Search :
  """Search tree whose node values are refined by TD(lambda) over simulated games.

  The search is run for one move (``state``/``color``) of one episode. It is
  resumable: progress is saved to a JSON backup on Drive and can be restored
  via ``start_iter``, ``experience`` and a saved tree dict.
  """
  def __init__(self, model, agent, state, color, ep_no, step_no, start_iter, experience, seed) :
    self.model = model
    self.agent = agent
    self.config = agent.config
    self.state = state
    self.color = color
    self.ep_no = ep_no
    self.step_no = step_no
    self.start_iter = start_iter
    self.seed = seed
    self.experience = experience

  def traverse(self, saved_tree) :
    """Run simulations ``start_iter`` .. ``no_simulations - 1`` from ``self.state``.

    Each simulation:
      1. Walks the tree epsilon-greedily over the model's ranked moves.
      2. Expands one new leaf.
      3. Plays a uniformly random rollout of at most 256 plies from the leaf.
      4. TD-updates the visited tree nodes and records their values in
         ``self.experience``.

    A backup is written before the first simulation (fresh runs) and after
    every ``batch_size`` simulations.

    Returns:
      ``self.experience``: maps "state-color-move" keys to node values.
    """
    self.root = self.load_from_dict(saved_tree)
    if not self.model.is_updated_version :
      self.model.load_latest()
    if self.start_iter == 0 :
      self.save_backup_experiences(-1, self.root)
    for self.i in range(self.start_iter, self.config["no_simulations"]) :
      startt = time.time()
      # NOTE(review): seeds are built by string concatenation, so distinct
      # (step, iteration) pairs can collide (e.g. "1"+"11" vs "11"+"1").
      self.rand_gen = np.random.default_rng(int(str(self.seed)+str(self.ep_no)+str(self.step_no)+str(self.i)))
      node = self.root
      env = Env(self.state, self.color)
      path = []
      # Selection: follow already-tried actions until a new one is chosen or the game ends.
      while(env.game_on) :
        if not hasattr(node, "value") :
          predictions = self.model.predict(node.state, node.color)
          node.value = predictions[0]
          node.policy = predictions[1]
        action = self.choose_actions(node.policy,self.config["epsilon"])
        reward, next_state = env.do_action(action,node.color)
        path.append((node, action, reward))
        if(action not in node.actions) :
          # Expansion: exactly one new leaf per simulation. (If the game ends
          # at an existing node, nothing is added.)
          node = node.add_child(next_state, not node.color, action,reward)
          break
        node = node.children[next_state]
      endt = time.time()
      temp_path = self.rollout(env, node, path)
      endr = time.time()
      self.backpropogate(path, temp_path)
      # Drop the leaf's rollout subtree and value so the model re-evaluates it later.
      node.del_children(env)
      endb = time.time()
      selection_avg = (endt-startt)/len(path) if path else "n/a"
      rollout_avg = (endr-endt)/len(temp_path) if temp_path else "n/a"
      println(f"Iter - {self.i+1} Total {len(path)+len(temp_path)} steps in {endb-startt} secs, <Selection - Steps = {len(path)}, Time = {endt-startt}, Avg = {selection_avg}>, <Rollout - Steps = {len(temp_path)}, Time = {endr-endt}, Avg = {rollout_avg}>")
      if (self.i+1) % self.config["batch_size"] == 0 :
        self.save_backup_experiences(self.i, self.root)
    # Generator used by the caller's final move choice.
    self.rand_gen = np.random.default_rng(int(str(self.seed)+str(self.ep_no)+str(self.step_no)))
    return self.experience

  def rollout(self, env, init_node, path) :
    """Play uniformly random moves from ``init_node`` until game end or 256 plies.

    New nodes are chained as children, and all visited nodes are then
    evaluated in one ``bulk_predict`` call. ``path`` is not used.

    Returns:
      A list of ``(node, action, reward)`` tuples.
    """
    node = init_node
    temp_path = []
    while(env.game_on) :
      action = self.choose_actions(env.get_legal_moves(True),1)
      reward, next_state = env.do_action(action,node.color)
      temp_path.append((node, action, reward))
      if len(temp_path) == 256 :
        break
      node = node.add_child(next_state, not node.color, action,reward)
    self.model.bulk_predict(temp_path)
    return temp_path

  def backpropogate(self, path, temp_path) :
    """TD(lambda) update of the values of the tree nodes in ``path``.

    Entries of ``path + temp_path`` are ``(node, action, reward)``. A side's
    reward is its own immediate reward minus the opponent's following reward.
    Targets bootstrap from the same side's node two plies later. Only
    ``path`` nodes are updated; rollout nodes supply bootstrap values. The
    updated ``path`` values are stored in ``self.experience``.

    NOTE(review): on the second-to-last step the white branch treats an
    opponent reward of +/-1 as terminal; the black branch has no such case.
    Confirm the asymmetry is intended.
    """
    eps = path+temp_path
    eligibility_traces_white = {state[0].state : 0.0 for state in eps if state[0].color == True}
    eligibility_traces_black = {state[0].state : 0.0 for state in eps if state[0].color == False}
    for i in range(len(eps)) :
      state = eps[i]
      if state[0].color :
        if i < len(eps) - 2 :
          next_state = eps[i + 2]
          reward = eps[i][2]-eps[i+1][2]
          delta = reward + self.config["discount"] * next_state[0].value - state[0].value
        elif i < len(eps) - 1 :
          if eps[i+1][2] == 1 or eps[i+1][2] == -1:
            reward = -1*eps[i+1][2]
          else :
            reward = eps[i][2]-eps[i+1][2]
          delta = reward - state[0].value
        else :
          reward = eps[i][2]
          delta = reward - state[0].value
        eligibility_traces_white[state[0].state] += 1
        for s in range(len(path)):
          if path[s][0].color :
            path[s][0].value += self.config["td_learning_rate"] * delta * eligibility_traces_white[path[s][0].state]
            eligibility_traces_white[path[s][0].state] *= self.config["discount"] * self.config["lmbda"]
      else :
        if i < len(eps) - 2 :
          next_state = eps[i + 2]
          reward = eps[i][2]-eps[i+1][2]
          delta = reward + self.config["discount"] * next_state[0].value - state[0].value
        elif i < len(eps) - 1 :
          reward = eps[i][2]-eps[i+1][2]
          delta = reward - state[0].value
        else:
          reward = eps[i][2]
          delta = reward - state[0].value
        eligibility_traces_black[state[0].state] += 1
        for s in range(len(path)):
          if path[s][0].color == False :
            path[s][0].value += self.config["td_learning_rate"] * delta * eligibility_traces_black[path[s][0].state]
            eligibility_traces_black[path[s][0].state] *= self.config["discount"] * self.config["lmbda"]
    for node, action, _ in path :
      key = node.state +"-"+ str(node.color) +"-"+ action
      self.experience[key] = node.value

  def choose_actions(self, probs, epsilon) :
    """With probability 1 - epsilon return ``probs[0]`` (top-ranked), else a uniformly random entry."""
    if(self.rand_gen.random() > epsilon) :
        return probs[0]
    else :
      return self.rand_gen.choice(probs)

  @staticmethod
  def _json_default(obj) :
    # Node values can end up as NumPy scalars (e.g. taken from ``.numpy()``
    # arrays), which ``json`` cannot serialise natively.
    if isinstance(obj, np.generic) :
      return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

  def save_backup_experiences(self, index, root) :
    """Write a resumable snapshot (iteration ``index`` done) to this seed's Drive backup."""
    explr_dict = {
        "start_eps" : self.ep_no,
        "start_step" : self.step_no,
        "start_iter" : index+1,
        "tree" : root.convert_to_dict(),
        "experience" : self.experience,
        "init_state" : self.state,
        "init_turn" : self.color
    }
    file_path = os.path.join('/content/gdrive', 'My Drive', 'CHESS-AI', self.model.id, "Exploration_Backup", "exploration_"+str(self.seed)+".json")
    # Serialise before touching the file, then move a complete temp file into
    # place, so a failure mid-save cannot leave a truncated backup.
    payload = json.dumps(explr_dict, indent=2, default=self._json_default)
    temp_path = file_path + ".tmp"
    with open(temp_path, 'w') as drive_file:
      drive_file.write(payload)
    os.replace(temp_path, file_path)
    println(f"Saved backup till Ep - {self.ep_no}, Step - {self.step_no} and Iteration - {index}")

  def load_from_dict(self, tree, parent = None) :
    """Rebuild a node tree from ``convert_to_dict`` output; None gives a fresh root for ``self.state``."""
    if tree is None :
      return TD_Search_Node(self.state, self.color)
    root = TD_Search_Node(tree["state"], tree["color"], parent)
    root.actions = tree["actions"]
    for state, subtree in tree["children"].items() :
      root.children[state] = self.load_from_dict(subtree, root)
    return root

class TD_Search_Node :
  """One position in the search tree plus the actions already tried from it."""
  def __init__(self,state,color,parent_node = None) :
    self.state = state          # position string
    self.color = color          # side to move at this node
    self.parent = parent_node
    self.actions = {}           # UCI move -> [child state, immediate reward]
    self.children = {}          # child state -> TD_Search_Node

  def add_child(self,state,color,action,reward) :
    """Create, register and return the child reached by ``action``."""
    node = TD_Search_Node(state,color,self)
    self.actions[action] = [state,reward]
    self.children[state] = node
    return node

  def convert_to_dict(self) :
    """Return this subtree as nested JSON-friendly dicts (values and policies are not included)."""
    return {
        "state" : self.state,
        "color" : self.color,
        "actions" : self.actions,
        "children" : {state : child.convert_to_dict() for state, child in self.children.items()},
    }

  def del_children(self, env) :
    """Forget this node's subtree and cached ``value``.

    ``env`` is accepted for backward compatibility and is not used.
    """
    try :
      del self.value
    except AttributeError :
      pass  # node was never evaluated
    self.actions = {}
    self.children = {}

# Storing Training Configurations

In [None]:
# Training configuration for the white agent. Both sides currently use the
# same values, so black starts as an independent copy (edit either freely).
agent_w_config = dict(
    epsilon = 0.78,               # exploration rate for moves inside the tree search
    exploit_epsilon = 0,          # exploration rate for the move actually played
    discount = 0.83,              # TD discount factor
    base_learning_rate = 0.05,    # forwarded to the training server with each batch
    actor_learning_rate = 0.01,
    critic_learning_rate = 0.05,
    td_learning_rate = 0.1,       # step size of the in-tree TD(lambda) update
    lmbda = 0.9,                  # eligibility-trace decay
    batch_size = 32,              # simulations between Drive backups
    no_simulations = 32 * 500,    # in multiple of batch_size
    actor_coefficient = 1,
    critic_coefficient = 3,
    entropy_coefficient = 0.1,
)
agent_b_config = dict(agent_w_config)

In [None]:
# Silence FutureWarnings so they do not clutter the training output.
warnings.simplefilter(action='ignore', category=FutureWarning)
def println(output) :
  """Print ``output`` prefixed with the current process name, then flush stdout.

  Flushing right away keeps output from pool worker processes from sitting in
  buffers.
  """
  process_tag = multiprocessing.current_process().name
  print(process_tag, output, "\n", flush=True)

# Training Pipeline

In [None]:
# @title Training Initiate Code {display-mode: "form"}
def _run_trainer(agent_w, agent_b, n_episodes, seed) :
  # Run one Trainer to completion inside a worker and return None, so the pool
  # does not have to pickle the whole Trainer (model, env, logs) back.
  Trainer(agent_w, agent_b, n_episodes, seed)

model = Model_handler("25_02_2024-10_52_47")
agent_w = Agent(agent_w_config,model,True)
agent_b = Agent(agent_b_config,model,False)
n_episodes = 2
#seeds = [13, 3]
seeds = [1001, 1111]
# One task per seed; never start more worker processes than CPU cores or seeds.
n_process = min(multiprocessing.cpu_count(), len(seeds))
use_multiprocessing = True
try :
  if use_multiprocessing :
    with multiprocessing.Pool(processes = n_process) as pool:
      result = pool.starmap(_run_trainer, [(agent_w, agent_b, n_episodes, seed) for seed in seeds])
  else :
    trainer = Trainer(agent_w, agent_b, n_episodes, 2002)
except Exception :
  # Leaving the ``with`` block has already terminated the pool (if it was created).
  traceback.print_exc()
  gc.collect()