In [None]:
import random
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
import copy
from milestone2 import *

In [None]:
# Board constants shared by the environment, the agents and the feature code.
STARTING = -1        # position of a goti that has not entered the board yet
DESTINATION = 56     # final position; a player wins once all four gotis are here
# Squares on which a goti cannot be captured (includes the home column 51-56).
SAFE_SQUARES = [0, 8, 13, 21, 25, 26, 34, 39, 47, *range(51, 57)]
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

cuda


## Environment

In [3]:
class Ludo:
    """Two-player Ludo environment: red is player 0, yellow is player 1.

    Every public call returns the state tuple built by `_get_state`:
    `(red Gotis, yellow Gotis, remaining dice roll, terminated, player_turn)`.
    The tuple holds live references to the environment's own objects, so take
    a copy before calling `step` if a before/after comparison is needed.
    """

    def __init__(self, render_mode=""):
        self.all_gotis = [Gotis("red"), Gotis("yellow")]
        self.dice = Dice()
        self.terminated = False
        self.player_turn = 0
        self.roll = self.dice.roll()
        self.render_mode = render_mode

    def __repr__(self):
        gotis_repr = "\n\n".join([repr(g) for g in self.all_gotis])
        return (
            f"{gotis_repr}\n\n"
            f"Dice Roll: {self.roll}\n"
            f"Terminated: {self.terminated}\n"
            f"Player Turn: {self.player_turn}"
        )

    def step(self, action=None):
        """Apply one action and return the resulting state tuple.

        `action` is a `(dice_index, goti_index)` pair taken from
        `get_action_space()`, or None to pass when no move is possible.
        Any other action ends the game and hands the turn to the opponent.
        Once the game is terminated the state is returned unchanged.
        """
        if self.terminated:
            return self._get_state()

        if self._no_action_possible(action):
            return self._change_turn()

        if not self._is_valid_input(action):
            return self._handle_invalid_action()

        return self._perform_move(action)

    def reset(self):
        """Start a fresh game and return its initial state tuple."""
        # Re-initialise with the current render mode; calling __init__() with
        # no arguments would silently reset it to the default.
        self.__init__(render_mode=self.render_mode)
        return self._get_state()

    def get_action_space(self):
        """Return every legal `(dice_index, goti_index)` pair for the player to move."""
        gotis = self.all_gotis[self.player_turn].gotis

        return [
            (dice_index, goti_index)
            for dice_index, dice in enumerate(self.roll)
            for goti_index, goti in enumerate(gotis)
            if self._is_valid_move(goti, dice)
        ]

    def check_win(self, gotis):
        """Return True when every goti of the given `Gotis` team is at DESTINATION."""
        return all(goti.position == DESTINATION for goti in gotis.gotis)

    def _perform_move(self, action):
        """Consume one die, move the chosen goti and resolve the consequences.

        Order of resolution: a capture grants an extra roll; otherwise a
        completed team ends the game; otherwise reaching DESTINATION grants an
        extra roll; otherwise the same player continues while dice remain;
        otherwise the turn passes.
        """
        dice_index, goti_number = action
        dice = self.roll.pop(dice_index)
        mover = self.player_turn
        opponent = 1 - mover

        self.all_gotis[mover].move_goti(goti_number, dice)

        current_goti = self.all_gotis[mover].gotis[goti_number]
        # The opponent stores positions in its own frame, so translate first.
        position_in_opponent_view = current_goti.convert_into_opponent_position()

        if self.all_gotis[opponent].kill_goti(position_in_opponent_view):
            return self._get_extra_turn()

        if self.check_win(self.all_gotis[mover]):
            return self._handle_win()

        if current_goti.position == DESTINATION:
            return self._get_extra_turn()

        if len(self.roll) >= 1:
            return self._get_state()

        return self._change_turn()

    def _no_action_possible(self, action):
        """True when the caller passed (action is None) and no legal move exists."""
        return action is None and not self.get_action_space()

    def _is_valid_move(self, goti, dice):
        """A goti at STARTING needs a 6; any other goti must not overshoot DESTINATION."""
        if goti.position == STARTING:
            return dice == 6
        return goti.position + dice <= DESTINATION

    def _is_valid_input(self, action):
        return action in self.get_action_space()

    def _change_turn(self):
        """Pass the turn to the other player and roll the dice for them."""
        # Keep the index an int (0/1); `not` would turn it into a bool.
        self.player_turn = 1 - self.player_turn
        self.roll = self.dice.roll()
        return self._get_state()

    def _handle_invalid_action(self):
        """End the game; the reported player_turn becomes the opponent."""
        self.terminated = True
        self.player_turn = 1 - self.player_turn
        return self._get_state()

    def _handle_win(self):
        """End the game; player_turn stays on the player who just moved."""
        self.terminated = True
        return self._get_state()

    def _get_extra_turn(self):
        """Append a bonus roll to the dice the current player still holds."""
        new_roll = self.dice.roll()
        self.roll += new_roll
        return self._get_state()

    def _get_state(self):
        return (
            self.all_gotis[0],
            self.all_gotis[1],
            self.roll,
            self.terminated,
            self.player_turn,
        )


class Gotis:
    """The four gotis (tokens) that belong to one colour."""

    def __init__(self, color: str):
        self.color = color.capitalize()
        self.gotis = [Goti() for _ in range(4)]

    def __repr__(self):
        header = f"{self.color} Gotis' Distance from starting point:"
        rows = []
        for number, goti in enumerate(self.gotis, start=1):
            rows.append(f" Goti {number}: {goti.position}")
        return header + "\n" + "\n".join(rows)

    def move_goti(self, goti_number, dice):
        """Advance the goti at index `goti_number` by the die value `dice`."""
        chosen = self.gotis[goti_number]
        chosen.move(dice)

    def kill_goti(self, position):
        """Try to capture a goti standing on `position` (this team's own frame).

        Gotis are checked in index order and the search stops at the first
        one that is actually captured. Returns True if a capture happened.
        """
        for index in range(4):
            candidate = self.gotis[index]
            if candidate.position != position:
                continue
            if candidate.kill_goti():
                return True
        return False


class Goti:
    """A single token; `position` is its distance from its own starting square."""

    def __init__(self, position=STARTING):
        self.position = position
        assert STARTING <= self.position <= DESTINATION

    def __repr__(self):
        return f"Goti's Distance from starting point: {self.position}"

    def move(self, dice):
        """Advance by `dice`; leaving STARTING needs a 6, overshooting is ignored."""
        at_base = self.position == STARTING
        if at_base and dice == 6:
            self.position = 0
        elif not at_base:
            target = self.position + dice
            if target <= DESTINATION:
                self.position = target

    def convert_into_opponent_position(self):
        """Translate this position into the opponent's frame.

        Returns -2 when the position has no counterpart there (still at
        STARTING, beyond square 50, or exactly on square 25).
        """
        pos = self.position
        has_counterpart = STARTING < pos <= 50 and pos != 25
        if not has_counterpart:
            return -2  # Position cannot be converted

        # The two frames are offset by 26 squares.
        shift = 26 if pos <= 24 else -26
        return pos + shift

    def kill_goti(self):
        """Send the goti back to -1 unless it stands on a safe square; report success."""
        if self.position in SAFE_SQUARES:
            return False
        self.position = -1
        return True


class Dice:
    """Die with the 'roll again on six' rule."""

    def roll(self):
        """Roll until a non-six appears, at most three times.

        Returns the list of faces thrown, or an empty list when all three
        throws were sixes.
        """
        faces = []
        while len(faces) < 3:
            face = self.simulate_one_dice_roll()
            faces.append(face)
            if face != 6:
                return faces
        # Three sixes in a row forfeit the whole roll.
        return []

    def simulate_one_dice_roll(self):
        """Return a single throw, uniform over 1..6."""
        return random.randint(1, 6)


## Actor-Critic

In [4]:
class ActorCritic(torch.nn.Module):
    """Separate actor and critic MLPs that share one Adam optimizer.

    Both heads are 3-layer perceptrons with 256 hidden units and in-place
    ReLUs. The actor's raw outputs are unnormalised scores; callers apply
    softmax themselves.
    """

    def __init__(self, actor_input_dim,critic_input_dim,actor_output_dim,critic_output_dim):
        super().__init__()

        # Actor is built before the critic so parameter order stays stable.
        self.actor = self._build_mlp(actor_input_dim, actor_output_dim)
        self.critic = self._build_mlp(critic_input_dim, critic_output_dim)

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=0.0001,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01,
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=30000, eta_min=1e-6
        )

    @staticmethod
    def _build_mlp(input_dim, output_dim):
        """Linear(input, 256) -> ReLU -> Linear(256, 256) -> ReLU -> Linear(256, output)."""
        return torch.nn.Sequential(
            torch.nn.Linear(input_dim, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, output_dim),
        )

    @staticmethod
    def _as_tensor_on(state, device):
        """Convert `state` to a float32 tensor if needed and move it to `device` (or DEVICE)."""
        tensor = state if torch.is_tensor(state) else torch.tensor(state, dtype=torch.float32)
        return tensor.to(device if device else DEVICE)

    def forward_actor(self,state,device=None):
        """Run the actor head on `state` and return its raw output."""
        return self.actor(self._as_tensor_on(state, device))

    def forward_critic(self,state,device=None):
        """Run the critic head on `state` and return its raw output."""
        return self.critic(self._as_tensor_on(state, device))




## Heuristic agent

In [None]:
class HeuristicAgent:
    """Greedy one-step look-ahead opponent.

    Each legal action is tried on a deep copy of the environment and scored
    with hand-tuned bonuses; the best-scoring action is played.
    """

    def __init__(self, rng_seed=None):
        # NOTE: this seeds the module-level `random` generator, which is also
        # the one used by `Dice`, so a seed makes whole games reproducible.
        if rng_seed is not None:
            random.seed(rng_seed)

    def score_action(self, env, action):
        """Return a heuristic score for playing `action` in `env`.

        `env` itself is not modified. The score is 100 for an immediately
        winning move; otherwise it is the sum of:
          * 50 if more opponent gotis sit at STARTING afterwards (a capture),
          * 10 for each own goti that left STARTING,
          * 30 for each own goti at DESTINATION and 8 for each own goti on a
            safe square after the move,
          * 0.5 per square of forward movement,
          * a random tie-breaker below 0.01.
        """
        sim = copy.deepcopy(env)
        next_state = sim.step(action)

        own_index = env.player_turn
        if next_state[3] and next_state[4] == own_index:
            return 100.0

        reward = 0.0

        opp_index = 1 - own_index
        at_base_before = sum(1 for g in env.all_gotis[opp_index].gotis if g.position == STARTING)
        at_base_after = sum(1 for g in next_state[opp_index].gotis if g.position == STARTING)
        if at_base_after > at_base_before:
            reward += 50.0

        for i in range(4):
            before = env.all_gotis[own_index].gotis[i].position
            after = next_state[own_index].gotis[i].position

            # Gotis still at the start earn nothing.
            if after == STARTING:
                continue
            if before == STARTING and after == 0:
                reward += 10.0
            if after == DESTINATION:
                reward += 30.0
            if after in SAFE_SQUARES:
                reward += 8.0
            if before != STARTING and after >= before:
                reward += 0.5 * (after - before)

        reward += random.random() * 0.01
        return reward

    def get_action(self, env):
        """Return the highest-scoring legal action, or None if there is none."""
        action_space = env.get_action_space()
        if not action_space:
            return None
        # Score every action exactly once, in order; max() keeps the first of
        # any equal scores, matching a stable descending sort.
        scored = [(self.score_action(env, a), a) for a in action_space]
        return max(scored, key=lambda pair: pair[0])[1]




## Policies

In [None]:

def get_win_percentages(n, policy1, policy2):
    """Play the two policies against each other and return their win rates.

    Half of the games (`n // 2`) are played with `policy1` moving first and
    the other half with `policy2` moving first. Each policy must expose
    `get_action(state, action_space)`.

    Returns `[policy1_win_percent, policy2_win_percent]`. The percentages are
    taken over the games actually played, so they sum to 100 even for odd
    `n`; `[0.0, 0.0]` is returned when no game is played.
    """
    env = Ludo()
    wins = [0, 0]
    policies = [policy1, policy2]
    games_per_side = n // 2

    for i in range(2):
        for _ in tqdm(range(games_per_side)):

            state = env.reset()
            terminated = False
            player_turn = 0

            while not terminated:

                action_space = env.get_action_space()
                action = policies[player_turn].get_action(state, action_space)

                state = env.step(action)
                terminated, player_turn = state[3], state[4]

            # In the second half the seats are swapped, so seat `player_turn`
            # belongs to policy `player_turn - i` (index -1 wraps to policy2).
            wins[player_turn - i] += 1

        policies[0], policies[1] = policies[1], policies[0]

    games_played = 2 * games_per_side
    if games_played == 0:
        return [0.0, 0.0]

    return [(win / games_played) * 100 for win in wins]


class Policy_Random:
    """Baseline policy that plays a uniformly random legal action."""

    def get_action(self, state, action_space):
        """Return a random element of `action_space`, or None when it is empty."""
        return random.choice(action_space) if action_space else None


class testpolicy():
    """Greedy evaluation policy backed by a saved `ActorCritic` checkpoint.

    Exposes `get_action(state, action_space)` on raw environment state tuples,
    which is the interface `get_win_percentages` expects. Features are computed
    from the point of view of the player to move, so the policy can sit in
    either seat. The checkpoint must have been trained on the same features.
    """

    def __init__(self, weights_path=r"best.pth"):
        """Load actor-critic weights from `weights_path`.

        The file must be a dict with a 'model_state_dict' entry holding an
        `ActorCritic(8, 11, 1, 1)` state dict. The training loop in this
        notebook writes such a file to 'best1.pth'; pass that path to
        evaluate it.
        """
        # get_action() queries an actor-critic network, so that is what has
        # to be constructed and loaded here.
        self.actor_critic=ActorCritic(8,11,1,1).to(DEVICE)

        best_weight=torch.load(weights_path,map_location=DEVICE)
        self.actor_critic.load_state_dict(best_weight['model_state_dict'])
        self.actor_critic.eval()

    def get_action(self,state,action_space):
        """Return the action with the highest actor score, or None if there is none."""
        if len(action_space)==0:
            return None

        state_action_feats,_=self.get_state_values(state,action_space)

        # The actor scores per-action feature rows, not the raw state tuple.
        with torch.no_grad():
            scores=self.actor_critic.forward_actor(state_action_feats).squeeze(-1)

        # softmax is monotonic, so the argmax of the raw scores is the same.
        action_idx=torch.argmax(scores).item()
        return action_space[action_idx]

    def get_state_action_features(self,state,action):
        """Describe the position reached by playing `action` as 8 numbers.

        Returns `[is_in_danger, can_kill, is_safe, dist_home, is_home_path,
        can_enter, progress, dice_value/6]`.
        """
        dice_index,goti_index=action

        red_gotis,yellow_gotis,dice_roll,terminated,player_id=state

        if int(player_id)==0:
            my_goti=red_gotis.gotis[goti_index]
            dushman_goti=yellow_gotis.gotis
        else:
            my_goti=yellow_gotis.gotis[goti_index]
            dushman_goti=red_gotis.gotis

        dice_value=dice_roll[dice_index]
        old_goti_pos=my_goti.position

        # Position after the move, following Goti.move: a goti at base only
        # leaves on a 6 and an overshoot leaves the goti where it is.
        if old_goti_pos==STARTING:
            my_goti_pos=0 if dice_value==6 else old_goti_pos
        elif old_goti_pos+dice_value<=DESTINATION:
            my_goti_pos=old_goti_pos+dice_value
        else:
            my_goti_pos=old_goti_pos

        # Each player counts squares from its own start, so the landing square
        # must be translated before comparing it with opponent positions
        # (-2 means it has no counterpart in the opponent's frame).
        landing_in_opp_view=Goti(my_goti_pos).convert_into_opponent_position()
        dushman_positions=[d.position for d in dushman_goti if d.position!=STARTING]

        # is in danger: an opponent goti sits 1..6 squares behind the landing square
        is_in_danger=int(
            landing_in_opp_view>=0
            and any(1<=landing_in_opp_view-p<=6 for p in dushman_positions)
        )

        # can kill: an opponent goti stands on the landing square and is not safe
        can_kill=int(
            landing_in_opp_view>=0
            and any(p==landing_in_opp_view and p not in SAFE_SQUARES for p in dushman_positions)
        )

        # is safe
        is_safe=int(my_goti_pos in SAFE_SQUARES)

        # distance from home
        dist_home=1
        if my_goti_pos!=STARTING:
            dist_home=(DESTINATION-my_goti_pos)/DESTINATION

        # is on home path
        is_home_path=int((DESTINATION-my_goti_pos)<=5)

        # can enter board
        can_enter=int(old_goti_pos==STARTING and dice_value==6)

        # progress
        # NOTE(review): same formula as dist_home (remaining distance), kept
        # as-is so the feature layout matches existing checkpoints.
        progress=0
        if my_goti_pos!=STARTING:
            progress=(DESTINATION-my_goti_pos)/DESTINATION

        return [is_in_danger,can_kill,is_safe,dist_home,is_home_path,can_enter,progress,dice_value/6]

    def get_state_features(self,gotis):
        """Summarise one team.

        Returns `(g_home, g_goal, g_safe, min_dist, max_dist)`: gotis at
        STARTING, gotis at DESTINATION, gotis on safe squares, and the
        smallest / largest normalised remaining distance among gotis that are
        on the board but not on a safe square (0.0 when there are none).
        """
        positions=[g.position for g in gotis.gotis]

        g_home=sum(p==STARTING for p in positions)
        g_goal=sum(p==DESTINATION for p in positions)
        g_safe=sum(p in SAFE_SQUARES for p in positions)

        exposed=[
            (DESTINATION-p)/DESTINATION
            for p in positions
            if p!=STARTING and p not in SAFE_SQUARES
        ]
        min_dist=min(exposed,default=0.0)
        max_dist=max(exposed,default=0.0)

        return g_home,g_goal,g_safe,min_dist,max_dist

    @staticmethod
    def _min_max(values):
        """Rescale `values` to [0, 1] over the whole array.

        A constant array maps to zeros and an empty one is returned as-is,
        instead of dividing by zero.
        """
        arr=np.asarray(values,dtype=np.float64)
        if arr.size==0:
            return arr
        low,high=arr.min(),arr.max()
        if high==low:
            return np.zeros_like(arr)
        return (arr-low)/(high-low)

    def get_state_values(self,state,action_space):
        """Return `(features, state_key)` for `state`.

        `features` has one row of 8 min-max scaled values per action in
        `action_space`; `state_key` is the min-max scaled 11-value summary of
        the whole position.
        """
        gotis_red,gotis_yellow,dice_roll,terminated,player_turn=state

        r_home,r_goal,r_safe,r_min_dist,r_max_dist=self.get_state_features(gotis_red)
        y_home,y_goal,y_safe,y_min_dist,y_max_dist=self.get_state_features(gotis_yellow)

        r_active=4-(r_home+r_goal)
        y_active=4-(y_home+y_goal)

        # 1. whether any 6 exists in the roll
        has_six = 1 if 6 in dice_roll else 0

        # 2. number of sixes in the roll
        num_sixes = dice_roll.count(6)

        state_key=[r_home,r_goal,r_active,r_safe,y_home,y_goal,y_active,y_safe,has_six,num_sixes,sum(dice_roll)]

        # Min-max scaling removes any constant shift, so a preceding
        # `x - mean/std` step would not change the result and is left out.
        state_key=self._min_max(state_key)

        features=self._min_max([self.get_state_action_features(state,a) for a in action_space])

        return features,state_key


class policy():
    """Actor-critic learner that plays red (player 0) against `HeuristicAgent`.

    `self.actor_critic` is the network. Training is started with
    `train_actor_critic(episodes, _lambda)`; `Q_net_app` is an alias for it.
    """

    class _GreedyPlayer:
        """Adapter that lets `get_win_percentages` play the owner's greedy policy."""

        def __init__(self, owner):
            self._owner = owner

        def get_action(self, state, action_space):
            return self._owner.get_action_eval(state, action_space)

    def __init__(self,epsilon,alpha,gamma):
        """Store the exploration rate `epsilon`, `alpha` and the discount `gamma`.

        `alpha` is only stored; the optimizer's learning rate is fixed inside
        `ActorCritic`.
        """
        self.env=Ludo()
        self.epsilon=epsilon
        self.alpha=alpha
        self.gamma=gamma
        self.q_table=dict()
        self.Model=dict()
        self.current_state=0
        # Move the network to DEVICE right away: the forward helpers send
        # their inputs there, so evaluation also works before training.
        self.actor_critic=ActorCritic(8,11,1,1).to(DEVICE)

    def update_q_table(self,state,action_space):
        """Add a zero-initialised row for `state` if the table has none yet."""
        if state not in self.q_table:
            self.q_table[state]={action:0.0 for action in action_space}

    def save_q_table(self,name):
        """Save the Q-table with `np.save` under `name`."""
        np.save(name,self.q_table)

    def get_action_eval(self,state,action_space):
        """Greedy action for a raw environment `state`, or None if no action exists."""
        if len(action_space)==0:
            return None

        state_action_feats,_=self.get_state_values(state,action_space)

        self.actor_critic.eval()
        with torch.no_grad():
            scores=self.actor_critic.forward_actor(state_action_feats).squeeze(-1)

        # softmax is monotonic, so the argmax of the raw scores is the same.
        return action_space[torch.argmax(scores).item()]

    def get_state_action_features(self,state,action):
        """Describe the position reached by playing `action` as 8 numbers.

        Returns `[is_in_danger, can_kill, is_safe, dist_home, is_home_path,
        can_enter, progress, dice_value/6]`, computed from the point of view
        of the player to move (`state[4]`).
        """
        dice_index,goti_index=action

        red_gotis,yellow_gotis,dice_roll,terminated,player_id=state

        if int(player_id)==0:
            my_goti=red_gotis.gotis[goti_index]
            dushman_goti=yellow_gotis.gotis
        else:
            my_goti=yellow_gotis.gotis[goti_index]
            dushman_goti=red_gotis.gotis

        dice_value=dice_roll[dice_index]
        old_goti_pos=my_goti.position

        # Position after the move, following Goti.move: a goti at base only
        # leaves on a 6 and an overshoot leaves the goti where it is.
        if old_goti_pos==STARTING:
            my_goti_pos=0 if dice_value==6 else old_goti_pos
        elif old_goti_pos+dice_value<=DESTINATION:
            my_goti_pos=old_goti_pos+dice_value
        else:
            my_goti_pos=old_goti_pos

        # Each player counts squares from its own start, so the landing square
        # must be translated before comparing it with opponent positions
        # (-2 means it has no counterpart in the opponent's frame).
        landing_in_opp_view=Goti(my_goti_pos).convert_into_opponent_position()
        dushman_positions=[d.position for d in dushman_goti if d.position!=STARTING]

        # is in danger: an opponent goti sits 1..6 squares behind the landing square
        is_in_danger=int(
            landing_in_opp_view>=0
            and any(1<=landing_in_opp_view-p<=6 for p in dushman_positions)
        )

        # can kill: an opponent goti stands on the landing square and is not safe
        can_kill=int(
            landing_in_opp_view>=0
            and any(p==landing_in_opp_view and p not in SAFE_SQUARES for p in dushman_positions)
        )

        # is safe
        is_safe=int(my_goti_pos in SAFE_SQUARES)

        # distance from home
        dist_home=1
        if my_goti_pos!=STARTING:
            dist_home=(DESTINATION-my_goti_pos)/DESTINATION

        # is on home path
        is_home_path=int((DESTINATION-my_goti_pos)<=5)

        # can enter board
        can_enter=int(old_goti_pos==STARTING and dice_value==6)

        # progress
        # NOTE(review): same formula as dist_home (remaining distance), kept
        # as-is so the feature layout does not change.
        progress=0
        if my_goti_pos!=STARTING:
            progress=(DESTINATION-my_goti_pos)/DESTINATION

        return [is_in_danger,can_kill,is_safe,dist_home,is_home_path,can_enter,progress,dice_value/6]

    def get_state_features(self,gotis):
        """Summarise one team.

        Returns `(g_home, g_goal, g_safe, min_dist, max_dist)`: gotis at
        STARTING, gotis at DESTINATION, gotis on safe squares, and the
        smallest / largest normalised remaining distance among gotis that are
        on the board but not on a safe square (0.0 when there are none).
        """
        positions=[g.position for g in gotis.gotis]

        g_home=sum(p==STARTING for p in positions)
        g_goal=sum(p==DESTINATION for p in positions)
        g_safe=sum(p in SAFE_SQUARES for p in positions)

        exposed=[
            (DESTINATION-p)/DESTINATION
            for p in positions
            if p!=STARTING and p not in SAFE_SQUARES
        ]
        min_dist=min(exposed,default=0.0)
        max_dist=max(exposed,default=0.0)

        return g_home,g_goal,g_safe,min_dist,max_dist

    @staticmethod
    def _min_max(values):
        """Rescale `values` to [0, 1] over the whole array.

        A constant array maps to zeros and an empty one is returned as-is,
        instead of dividing by zero.
        """
        arr=np.asarray(values,dtype=np.float64)
        if arr.size==0:
            return arr
        low,high=arr.min(),arr.max()
        if high==low:
            return np.zeros_like(arr)
        return (arr-low)/(high-low)

    def _encode_state(self,state):
        """Return the min-max scaled 11-value summary of `state` (critic input)."""
        gotis_red,gotis_yellow,dice_roll,terminated,player_turn=state

        r_home,r_goal,r_safe,r_min_dist,r_max_dist=self.get_state_features(gotis_red)
        y_home,y_goal,y_safe,y_min_dist,y_max_dist=self.get_state_features(gotis_yellow)

        r_active=4-(r_home+r_goal)
        y_active=4-(y_home+y_goal)

        # 1. whether any 6 exists in the roll
        has_six = 1 if 6 in dice_roll else 0

        # 2. number of sixes in the roll
        num_sixes = dice_roll.count(6)

        state_key=[r_home,r_goal,r_active,r_safe,y_home,y_goal,y_active,y_safe,has_six,num_sixes,sum(dice_roll)]

        # Min-max scaling removes any constant shift, so a preceding
        # `x - mean/std` step would not change the result and is left out.
        return self._min_max(state_key)

    def get_state_values(self,state,action_space):
        """Return `(features, state_key)` for `state`.

        `features` has one row of 8 min-max scaled values per action in
        `action_space`; `state_key` is the critic input from `_encode_state`.
        """
        state_key=self._encode_state(state)
        features=self._min_max([self.get_state_action_features(state,a) for a in action_space])
        return features,state_key

    def get_action(self,state,action_space):
        """Epsilon-greedy action selection.

        `state` is normally the per-action feature matrix from
        `get_state_values`; a raw environment state tuple is also accepted and
        converted first. Returns None when `action_space` is empty.
        """
        if len(action_space)==0:
            return None

        if isinstance(state,tuple):
            state,_=self.get_state_values(state,action_space)

        if np.random.rand()<self.epsilon:
            return random.choice(action_space)

        with torch.no_grad():
            scores=self.actor_critic.forward_actor(state).squeeze(-1)

        return action_space[torch.argmax(scores).item()]

    def get_reward(self, current_state, next_state,player_id):
        """Shaped reward for player `player_id` for the step current -> next.

        `current_state` must be a snapshot taken before the step (see
        `_snapshot`): the tuples returned by the environment share their
        `Gotis` objects, so two live tuples always show identical positions.

        Per own goti: +0.01 for moving on the board, +0.1 for leaving base,
        +0.1 for newly standing on a safe square, +0.5 for reaching
        DESTINATION, -0.4 for being sent back to base. +0.3 per opponent goti
        sent back to base. On termination +1.0 if this player has won,
        otherwise -1.0.
        """
        me=int(player_id)
        foe=1-me

        my_prev_goti_p=[g.position for g in current_state[me].gotis]
        my_curr_goti_p=[g.position for g in next_state[me].gotis]
        dushman_prev_p=[g.position for g in current_state[foe].gotis]
        dushman_curr_p=[g.position for g in next_state[foe].gotis]

        reward = 0

        for before,after in zip(my_prev_goti_p,my_curr_goti_p):
            # move forward
            if before!=STARTING and after!=STARTING and after!=DESTINATION and before!=after:
                reward += 0.01

            # get out of base
            if before==STARTING and after>=0:
                reward += 0.1

            # land on safe square
            if after in SAFE_SQUARES and before not in SAFE_SQUARES:
                reward += 0.1

            # own goti captured (was on the board, is back at base)
            if before!=STARTING and after==STARTING:
                reward -= 0.4

            # get to destination
            if after==DESTINATION and before!=DESTINATION:
                reward += 0.5

        # kill opponent: detected from the opponent's own before/after
        # positions, which avoids comparing across coordinate frames
        for before,after in zip(dushman_prev_p,dushman_curr_p):
            if before!=STARTING and after==STARTING:
                reward += 0.3

        if next_state[3]:  # terminated
            if self.env.check_win(next_state[me]):
                reward += 1.0
            else:
                reward -= 1.0

        return reward

    @staticmethod
    def _snapshot(state):
        """Deep-copy a state tuple so it survives the next `env.step` unchanged."""
        return copy.deepcopy(state)

    def _critic_value(self,state):
        """Critic estimate for a raw state, as a plain float without gradients."""
        with torch.no_grad():
            return self.actor_critic.forward_critic(self._encode_state(state)).squeeze().item()

    def _n_step_update(self,state_feats,action_feats,action_idx,rewards,masks,bootstrap):
        """One gradient step on the first transition of an n-step window.

        The return is the discounted sum of `rewards` plus `bootstrap`
        (`masks` zero the tail after a terminal step). Returns the loss value.
        """
        returns=bootstrap
        for r,m in zip(reversed(rewards),reversed(masks)):
            returns=r+self.gamma*returns*m

        value=self.actor_critic.forward_critic(state_feats).squeeze()
        advantage=returns-value

        logits=self.actor_critic.forward_actor(action_feats).squeeze(-1)
        log_prob=torch.log_softmax(logits,dim=0)[action_idx]

        # The advantage is a constant weight for the policy gradient; only the
        # critic loss may push gradients into the value estimate.
        actor_loss=-log_prob*advantage.detach()
        critic_loss=advantage**2
        loss=actor_loss+critic_loss

        self.actor_critic.optimizer.zero_grad()
        loss.backward()
        self.actor_critic.optimizer.step()

        return loss.item()

    def train_actor_critic(self,episodes,_lambda):
        """Train for `episodes` games using `_lambda`-step returns.

        Side effects: appends "episode,total_reward" lines to
        'episode_rewards.log', prints progress every 500 episodes, and every
        1000 episodes plays 1000 evaluation games against `Policy_Random`,
        saving the network to 'best1.pth' when the win rate does not drop.
        """
        best_res=0
        self.actor_critic.to(DEVICE)
        opp=HeuristicAgent()

        for episode in range(episodes):
            # get_action_eval() switches the network to eval mode.
            self.actor_critic.train()

            self.current_state=self.env.reset()
            terminated=self.current_state[3]
            player_turn=self.current_state[4]

            running_loss=0.0
            updates=0
            total_rewards=0.0

            head=None          # (state_feats, action_feats, action_idx) opening the window
            rewards_buf=[]
            masks_buf=[]

            while not terminated:
                action_space=self.env.get_action_space()

                if len(action_space)==0:
                    self.current_state=self.env.step(None)
                    terminated=self.current_state[3]
                    player_turn=self.current_state[4]
                    continue

                # The environment mutates its Gotis in place, so keep a copy
                # of the position before the move for the reward computation.
                before=self._snapshot(self.current_state)

                if int(player_turn)==1:
                    action=opp.get_action(self.env)
                    next_state=self.env.step(action)

                    # Charge the learner's latest move for what the opponent
                    # just did to it (captures, or winning the game).
                    # Nothing is charged if the window has just been flushed.
                    if rewards_buf:
                        penalty=self.get_reward(before,next_state,0)
                        rewards_buf[-1]+=penalty
                        total_rewards+=penalty
                        if next_state[3]:
                            masks_buf[-1]=0

                else:
                    state_action_feats,state_feats=self.get_state_values(self.current_state,action_space)
                    action=self.get_action(state_action_feats,action_space)

                    next_state=self.env.step(action)
                    reward=self.get_reward(before,next_state,player_turn)

                    if not rewards_buf:
                        head=(state_feats,state_action_feats,action_space.index(action))
                    rewards_buf.append(reward)
                    masks_buf.append(0 if next_state[3] else 1)
                    total_rewards+=reward

                terminated=next_state[3]
                player_turn=next_state[4]
                self.current_state=next_state

                if rewards_buf and (len(rewards_buf)>=_lambda or terminated):
                    # Bootstrap from the critic unless the game is over.
                    bootstrap=0.0 if terminated else self._critic_value(next_state)
                    running_loss+=self._n_step_update(*head,rewards_buf,masks_buf,bootstrap)
                    updates+=1
                    head=None
                    rewards_buf.clear()
                    masks_buf.clear()

            self.actor_critic.scheduler.step()

            self.epsilon = max(self.epsilon * 0.99995, 0.05)

            with open("episode_rewards.log", "a") as f:
                f.write(f"{episode},{total_rewards}\n")

            if episode % 500==0:
                mean_loss=running_loss/max(updates,1)
                print(f"episodes: [{episode}/{episodes}] , epsilon: [{self.epsilon}], lr: [{self.actor_critic.scheduler.get_last_lr()}] , loss(actor_critic): [{mean_loss}]")

            if episode==990:
                torch.save({"model_state_dict":self.actor_critic.state_dict()},'best1.pth')

            if episode % 1000==0 and episode>0:
                # Evaluate the network being trained, not a checkpoint on disk.
                win_percentage=get_win_percentages(1000,self._GreedyPlayer(self),Policy_Random())
                print(win_percentage)

                if win_percentage[0]>=best_res:
                    best_res=win_percentage[0]
                    torch.save({"model_state_dict":self.actor_critic.state_dict()},'best1.pth')
                    print("best model saved")

    # The notebook's driver cell starts training through this name.
    Q_net_app=train_actor_critic

In [None]:
# Build the learner; positional arguments follow policy.__init__: epsilon, alpha, gamma.
# NOTE(review): despite the name, `env` is a `policy` instance, not a Ludo environment.
env=policy(0.1,0.0001,0.95)
# NOTE(review): confirm that `policy` exposes a `Q_net_app` method before running this
# cell; the arguments are presumably (episodes, n-step window length) - verify.
env.Q_net_app(30000,5)