In [2]:
import itertools
import random
from collections import deque, namedtuple
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm_notebook

from tic_env import TictactoeEnv, OptimalPlayer
from utils import play_game, Metric

# Learning from Expert

In [4]:
##################################### Helpers ###################################
class DQN(nn.Module):
    """Fully connected Q-network: 18 input features -> 9 action scores.

    The stack is Linear(18,128) followed by two Linear(128,128) layers, each
    with a ReLU, and a final Linear(128,9) output layer.
    """

    def __init__(self):
        super().__init__()

        width = 128
        # Input layer, then two identical hidden layers, then the output head.
        stack = [nn.Linear(3*3*2, width), nn.ReLU()]
        for _ in range(2):
            stack += [nn.Linear(width, width), nn.ReLU()]
        stack.append(nn.Linear(width, 9))

        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 9 action scores for ``x`` (last dimension must be 18)."""
        return self.fc(x)

class ReplayMemory(object):
    """Fixed-capacity replay buffer that also carries training bookkeeping.

    Once ``capacity`` transitions are stored, the oldest one is dropped for
    every new one pushed. The object additionally holds the optimizer, a step
    counter and a running loss so they can travel together with the buffer.
    """

    def __init__(self, capacity,batch_size):
        # Record type stored in the buffer.
        self.Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))
        self.memory = deque([], maxlen=capacity)
        self.batch_size = batch_size
        self.device = 'cpu'
        # Training state; the optimizer is attached by the caller after construction.
        self.optimizer = None
        self.step = 0
        self.accumulated_loss = 0

    def push(self, *args):
        """Save a transition built from (state, action, next_state, reward)."""
        record = self.Transition(*args)
        self.memory.append(record)

    def sample(self):
        """Draw ``batch_size`` stored transitions without replacement."""
        return random.sample(self.memory, k=self.batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

class deep_Q_player:
    """Epsilon-greedy player whose action values come from a network."""

    def __init__(self):
        self.player = ''

    def act(self,state,model:nn.Module,epsilon:float):
        """Pick an action index in 0..8 for ``state``.

        The greedy action is the argmax of ``model(state)`` along dimension 1
        (so the model output must describe a single state); with probability
        ``epsilon`` a uniformly random index is returned instead.
        """
        # Inference only: no gradient tracking is needed to pick a move.
        with torch.no_grad():
            q_values = model(state)

        greedy_action = q_values.max(1)[1].item()

        explore = np.random.random() < epsilon
        if explore:
            return np.random.choice(range(9)).item()

        return greedy_action

def grid2tensor(grid:np.array,player:str):
    """Encode a board as a (1, 18) float tensor from ``player``'s point of view.

    The first 9 entries mark the cells held by ``player`` and the last 9 the
    cells held by the opponent, both in row-major order. 'X' stones are the
    cells equal to 1 in ``grid`` and the other player's are the cells equal to -1.
    """
    own_mark = 1 if player == 'X' else -1

    # Plane 0: own stones, plane 1: opponent stones.
    planes = np.zeros((2, 3, 3))
    planes[0] = np.where(grid == own_mark, 1, 0)
    planes[1] = np.where(grid == -own_mark, 1, 0)

    return torch.tensor(planes).reshape(-1, 18).float()

def test_policy(eps_optimalplayer,q_table=None,verbose=False,DQN_policy_net=None,num_games=500):
    """Evaluate an agent against ``OptimalPlayer(epsilon=eps_optimalplayer)``.

    The agent plays 'X' for the first ``num_games // 2`` games and 'O' for the
    rest. NumPy's global RNG is re-seeded with the game index before each game;
    the RNG state found on entry is restored on exit so that evaluating in the
    middle of training does not force the same random sequence onto the caller
    after every evaluation.

    Parameters
    ----------
    eps_optimalplayer : epsilon handed to ``OptimalPlayer``.
    q_table : required when ``DQN_policy_net`` is None; passed to ``agent``.
    verbose : print the score, the win counts and the illegal-move count.
    DQN_policy_net : network used greedily (epsilon 0) by a ``deep_Q_player``.
    num_games : number of evaluation games (default 500).

    Returns
    -------
    (M, num_illegal_steps) where ``M = (agent wins - opponent wins) / num_games``
    and ``num_illegal_steps`` counts games lost by the DQN agent through a move
    rejected by the environment with ``ValueError``.

    Raises
    ------
    ValueError if ``num_games`` is not positive.
    """
    if num_games <= 0:
        raise ValueError("num_games must be a positive integer")

    env = TictactoeEnv() # environment

    if DQN_policy_net is None:
        assert q_table is not None, "Provide q_table"
        # NOTE(review): `agent` is not defined in the visible part of the
        # notebook; presumably the tabular Q-learning player - confirm.
        agent_player = agent(q_table)
    else:
        agent_player = deep_Q_player()

    #-- Holders
    wins_count = {'optimal_player': 0, 'agent': 0, 'draw': 0}
    players = dict()
    num_illegal_steps = 0
    first_half = num_games // 2

    # Seeding below is local to the evaluation: remember the caller's RNG state.
    saved_rng_state = np.random.get_state()
    try:
        for episode in range(num_games):

            env.reset()
            np.random.seed(episode)

            if episode < first_half:
                agent_symbol, optimal_symbol = 'X', 'O'
            else:
                agent_symbol, optimal_symbol = 'O', 'X'

            player_opt = OptimalPlayer(epsilon=eps_optimalplayer,player=optimal_symbol)
            players[optimal_symbol] = (player_opt,'optimal_player')
            players[agent_symbol] = (agent_player,'agent')

            for j in range(9):

                #-- Whose turn is it, and what does the board look like
                turn = env.current_player
                grid,_,_ = env.observe()
                current_player, _ = players[turn]

                #-- Playing with DQN-agent
                if DQN_policy_net is not None and turn == agent_symbol:
                    state = grid2tensor(grid,agent_symbol)
                    move = current_player.act(state,DQN_policy_net,0)

                    try:
                        env.step(move,print_grid=False)

                    except ValueError:
                        # A rejected move forfeits the game to the opponent.
                        env.end = True
                        env.winner = optimal_symbol
                        num_illegal_steps += 1

                else:
                    move = current_player.act(grid)
                    env.step(move,print_grid=False)

                #-- Check whether the game has finished
                if env.end:
                    if env.winner is not None:
                        _,winner = players[env.winner]
                        wins_count[winner] += 1
                    else:
                        wins_count['draw'] += 1

                    break
    finally:
        np.random.set_state(saved_rng_state)

    M = (wins_count['agent']-wins_count['optimal_player'])/num_games

    if verbose :
        string ="M_rand"
        if eps_optimalplayer < 1:
            string = "M_opt"
        print(string+" : ",M)
        print(wins_count,'\n')
        print('Number of illegal steps',num_illegal_steps)

    return M,num_illegal_steps

def update_policy(policy_net:nn.Module,
                  target_net:nn.Module,
                  memory:"ReplayMemory",
                  criterion=nn.SmoothL1Loss(),# F.huber_loss,
                  gamma=0.99,
                  online_update=False,online_state_action_reward=(None,None,None,None)):
    """Run one optimisation step of ``policy_net`` towards the TD target.

    The target is ``reward + gamma * max_a target_net(next_state)[a]``, with
    the bootstrap term set to 0 when ``next_state`` is None (terminal).

    Parameters
    ----------
    policy_net : network being trained.
    target_net : network used (without gradient) for the bootstrap term.
    memory : object providing ``optimizer``, ``device``, ``batch_size``,
        ``step``, ``accumulated_loss`` and, for the replay branch,
        ``sample()``, ``Transition`` and ``len()``.
    criterion : loss applied to (Q(s, a), target).
    gamma : discount factor.
    online_update : if True, learn from the single transition given in
        ``online_state_action_reward``; otherwise learn from a sampled batch.
    online_state_action_reward : tuple ordered as
        ``(state, next_state, action, reward)``.

    Returns
    -------
    False when the replay branch is requested but fewer than
    ``memory.batch_size`` transitions are stored; True after an update.
    """
    if online_update :
        #-- Compute Q values
        state,next_state,action,reward = online_state_action_reward
        state = state.to(memory.device)
        state_action_values = policy_net(state)[:,action] # take Q(state,action)

        next_state_values = torch.zeros(1,device=memory.device)
        if next_state is not None:
            next_state = next_state.to(memory.device)
            next_state_values = target_net(next_state).max(1)[0].detach() # take max Q(state',action')

        #-- Compute target
        target = reward + gamma*next_state_values

        #-- Update gradients
        memory.optimizer.zero_grad()
        loss = criterion(state_action_values,target)
        loss.backward()
        # The call parentheses matter: a bare `optimizer.step` is a no-op
        # attribute access and the weights would never change.
        memory.optimizer.step()
        memory.step += 1

    else:
        if len(memory) < memory.batch_size:
            return False

        #-- Sample transitions and regroup them field by field
        transitions = memory.sample()
        batch = memory.Transition(*zip(*transitions))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=memory.device, dtype=torch.bool)

        state_batch = torch.cat(batch.state).to(memory.device)
        action_batch = torch.cat(batch.action).to(memory.device)
        reward_batch = torch.cat(batch.reward).to(memory.device)

        #-- Compute Q values
        memory.optimizer.zero_grad()
        state_action_values = policy_net(state_batch).gather(1, action_batch.unsqueeze(1))

        #-- Compute target; terminal transitions keep a bootstrap value of 0
        next_state_values = torch.zeros(memory.batch_size,device=memory.device)
        # torch.cat rejects an empty list, so skip the bootstrap entirely when
        # every sampled transition is terminal.
        if non_final_mask.any():
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None]).to(memory.device)
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
        target = reward_batch + gamma*next_state_values

        #-- Update gradients
        loss = criterion(state_action_values,target.unsqueeze(1))
        loss.backward()
        memory.optimizer.step()
        memory.step += 1

    memory.accumulated_loss += loss.item()
    return True

In [None]:
def deep_q_learning(epsilon,num_episodes:int,
                    env:TictactoeEnv,
                    path_save:str,
                    eps_opt=0.5,gamma=0.99,
                    render=False,test=False,online_update=False,wandb_tag="DQN"):
    """Train a DQN agent against ``OptimalPlayer(epsilon=eps_opt)``.

    The agent and the opponent swap symbols every game. Every 250 completed
    games the mean agent reward of that window is stored and, if ``test`` is
    set, the current policy is scored with ``test_policy`` at opponent
    epsilon 0 and 1. The target network is synchronised every 500 games.

    Parameters
    ----------
    epsilon : callable mapping the 1-based episode index to the exploration
        probability handed to the agent.
    num_episodes : number of training games.
    env : environment instance; it is reset at the start of every game.
    path_save : currently unused.
    eps_opt : epsilon of the training opponent.
    gamma : discount factor forwarded to ``update_policy``.
    render : print the outcome and render the board after every game.
    test : evaluate the policy every 250 games.
    online_update : learn from each transition as it happens instead of from
        replay-buffer batches.
    wandb_tag : currently unused (experiment logging is disabled).

    Returns
    -------
    (wins_count, agent_mean_rewards, M_opts, M_rands)
    """
    log_every = 250            # window used for mean rewards / evaluation
    target_update_every = 500  # games between target-network synchronisations

    #-- agent
    agent = deep_Q_player()

    #-- Initialize Q networks
    policy_net = DQN()
    target_net = DQN()
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    #-- Initialize hyperparameters (gamma comes from the argument)
    batch_size = 64
    lr = 5e-4
    memory = ReplayMemory(10000,batch_size)
    memory.optimizer = torch.optim.Adam(policy_net.parameters(),lr=lr)
    policy_net.to(memory.device)
    target_net.to(memory.device)

    #-- Holders
    wins_count = {'optimal_player': 0, 'agent': 0, 'draw': 0}
    players = dict()
    M_opts = list()
    M_rands = list()
    accumulate_reward = 0
    agent_mean_rewards = [0]*int(num_episodes//log_every)

    for episode in range(1,num_episodes+1):

        env.reset()

        #-- Swap symbols every game
        if episode % 2 == 0 :
            opponent_symbol, agent_learner = 'X', 'O'
        else:
            opponent_symbol, agent_learner = 'O', 'X'

        player_opt = OptimalPlayer(epsilon=eps_opt,player=opponent_symbol)
        players[opponent_symbol] = 'optimal_player'
        players[agent_learner] = 'agent'

        current_state = None  # last state in which the agent moved
        A = None              # action taken in that state

        for j in range(9):

            #-- Agent plays
            if env.current_player == agent_learner :

                current_state = grid2tensor(env.observe()[0],agent_learner)
                A = agent.act(current_state, policy_net,epsilon(episode))

                try :
                    env.step(A,print_grid=False)

                #-- A move rejected by the environment forfeits the game
                except ValueError :
                    env.end = True
                    env.winner = opponent_symbol

            #-- Optimal player plays
            if not env.end :

                grid,_,_ = env.observe()
                move = player_opt.act(grid)
                env.step(move,print_grid=False)

                #-- Non-terminal transition: the agent will move again from next_state.
                #   Terminal outcomes are handled once, in the `env.end` block below.
                if current_state is not None and not env.end :
                    next_state = grid2tensor(env.observe()[0],agent_learner)
                    agent_reward = env.reward(agent_learner)
                    memory.push(current_state, torch.tensor([A]), next_state, torch.tensor([agent_reward]))

                    if online_update:
                        update_policy(policy_net,target_net,memory,gamma=gamma,
                                      online_update=True,
                                      online_state_action_reward=(current_state,next_state,A,agent_reward))

            #-- Learn from the replay buffer if applicable
            if not online_update :
                update_policy(policy_net,target_net, memory,gamma=gamma, online_update=False)

            #-- Check whether the game has finished
            if env.end :
                agent_reward = env.reward(agent_learner)

                if current_state is not None :
                    memory.push(current_state, torch.tensor([A]), None, torch.tensor([agent_reward]))

                    if online_update:
                        update_policy(policy_net,target_net,memory,gamma=gamma,
                                      online_update=True,
                                      online_state_action_reward=(current_state,None,A,agent_reward))

                #-- Bookkeeping
                winner = players[env.winner] if env.winner is not None else 'draw'
                wins_count[winner] += 1
                accumulate_reward += agent_reward

                if render :
                    print(f"Episode {episode} ; Winner is {winner}.")
                    env.render()

                break # stop for-loop

        #-- Close the logging window once its last game has been played
        if episode % log_every == 0 :
            agent_mean_rewards[episode//log_every - 1] = accumulate_reward/log_every
            accumulate_reward = 0
            memory.accumulated_loss = 0

            if test:
                M_opt,_ = test_policy(0,q_table=None,DQN_policy_net=policy_net,verbose=False)
                M_rand,_ = test_policy(1,q_table=None,DQN_policy_net=policy_net,verbose=False)
                M_opts.append(M_opt)
                M_rands.append(M_rand)

        #-- Log results
        if episode % 5000 == 0 :
            print(f"\nEpisode : {episode}")
            print(wins_count)

        #-- Update target network (load_state_dict copies the values itself)
        if episode % target_update_every == 0:
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()

    return wins_count,agent_mean_rewards,M_opts,M_rands


In [None]:
#-- Q.11 & Q.12
def eps_1(x):
    """Constant exploration schedule: 0.3 whatever the episode index."""
    return 0.3

# Disabled run: one training without (do=False) and one with (do=True) online updates.
if False:
    test = False
    for do in (False, True):
        env = TictactoeEnv()
        wins_count, agent_mean_rewards, M_opts, M_rands = deep_q_learning(
            epsilon=eps_1,
            num_episodes=int(20e3),
            eps_opt=0.5,
            env=env,
            path_save=None,
            gamma=0.99,
            render=False,
            test=test,
            wandb_tag="V3",
            online_update=do,
        )

In [None]:
#-- Q.13
eps_min = 0.1
eps_max = 0.8

# Disabled run: sweep over N_star, the horizon of the exploration decay.
if False:
    env = TictactoeEnv()
    test = True
    for N_star in (1, 10e3, 20e3, 30e3, 40e3):
        print('-'*20,' N_star : ',N_star,'-'*20)

        def eps_2(x):
            """Linear decay from eps_max as x grows towards N_star, floored at eps_min."""
            return max(eps_min, eps_max*(1-x/N_star))

        wins_count, agent_mean_rewards, M_opts, M_rands = deep_q_learning(
            epsilon=eps_2,
            num_episodes=int(20e3),
            eps_opt=0.5,
            env=env,
            path_save=None,
            gamma=0.99,
            render=False,
            test=test,
            wandb_tag=f"V3--{int(N_star)}",
            online_update=False,
        )

In [None]:
#- Q.14
eps_min = 0.1
eps_max = 0.8

# Active run: sweep over the epsilon of the training opponent.
if True:
    env = TictactoeEnv()
    test = True
    N_star = 1 # best N_star from Q.13

    def eps_2(x):
        """Linear decay from eps_max as x grows towards N_star, floored at eps_min."""
        return max(eps_min, eps_max*(1-x/N_star))

    for eps_opt in (0, 0.25, 0.5, 0.75, 1):
        print('-'*20,' eps_opt : ',eps_opt,'-'*20)
        wins_count, agent_mean_rewards, M_opts, M_rands = deep_q_learning(
            epsilon=eps_2,
            num_episodes=int(20e3),
            eps_opt=eps_opt,
            env=env,
            path_save=None,
            gamma=0.99,
            render=False,
            test=test,
            wandb_tag=f"V3--eps_opt:{eps_opt}",
            online_update=False,
        )

# Self-play

In [72]:
def Metric_Q(policy,Q_table,optimal=False):
    """Score a tabular policy over 500 games as (wins - losses) / games.

    The evaluated player uses the first symbol of each pairing ('X' for the
    first 250 games, 'O' for the last 250) and acts with epsilon 0. The
    opponent is an ``OptimalPlayer`` with epsilon 0 when ``optimal`` is truthy
    and epsilon 1 otherwise. NumPy's global RNG is re-seeded (without a fixed
    seed) before every game.

    NOTE(review): ``get_state`` is not defined in the visible part of the
    notebook - confirm it is available when this cell runs.
    """
    Turns = np.array([['X','O']]*250+[['O','X']]*250)
    opponent_epsilon = 0. if optimal else 1.

    N_wins = 0
    N_losses = 0
    N = 0
    for new_symbol, test_symbol in Turns:
        np.random.seed()

        player_test = OptimalPlayer(epsilon=opponent_epsilon, player=test_symbol)
        player_new = policy(player=new_symbol,epsilon=0)
        env = TictactoeEnv()

        while not env.end:
            if env.current_player == player_new.player:
                state = get_state(env.grid,player_new)
                move = player_new.act(state,env.grid,Q_table)
            else:
                move = player_test.act(env.grid)

            # Flat indices 0..8 are converted to (row, column).
            if not isinstance(move,tuple):
                move = (int(move/3),move%3)
            env.step(move, print_grid=False)

        if env.winner==player_new.player:
            N_wins += 1
        if env.winner==player_test.player:
            N_losses += 1
        N += 1
        env.reset()

    return (N_wins - N_losses)/N

class deep_Q_player:
    """Epsilon-greedy player driven by a Q-network (self-play version).

    Parameters
    ----------
    player : symbol played by this instance.
    epsilon : default probability of playing a uniformly random cell.
    """

    def __init__(self,player='X', epsilon=0):
        self.player=player
        self.epsilon=epsilon

    def act(self,state,model,epsilon=None):
        """Return the chosen cell as a Python int in 0..8.

        ``model(state)`` may return the 9 scores in any shape holding a single
        state (e.g. ``(9,)`` or ``(1, 9)``). Ties between maximal scores are
        broken uniformly at random. ``epsilon`` overrides ``self.epsilon``
        for this call when given, which also keeps the three-argument call
        ``act(state, model, epsilon)`` used elsewhere in the notebook working.
        """
        # Scores are only read, never differentiated: drop the graph and go
        # through NumPy explicitly instead of applying NumPy ops to a tensor.
        with torch.no_grad():
            action_scores = torch.as_tensor(model(state))
        scores = action_scores.detach().cpu().numpy().reshape(-1)

        best_actions = np.flatnonzero(scores == scores.max())
        action = int(np.random.choice(best_actions))

        exploration = self.epsilon if epsilon is None else epsilon
        if np.random.random() < exploration:
            action = int(np.random.choice(range(9)))

        return action

def play_game_self_deepQ(env,model,p1,p2):
    """Play one self-play game and collect each player's transitions.

    Both players choose their moves with the same ``model``. Every transition
    is a tuple ``(state, action, reward, next_state)`` seen from the mover's
    side:

    * while the game goes on, ``next_state`` is the board at the mover's next
      turn (after the opponent has replied);
    * when the game ends, both players' last moves are stored with
      ``next_state=None`` and their own ``env.reward``;
    * a move onto an occupied cell stops the game at once and is stored with
      reward -1 and ``next_state=None``; the opponent's unfinished move is
      dropped because it has no outcome.

    Returns
    -------
    (env, R_t) : the environment as left by the game and the list of
    transitions.

    Raises
    ------
    KeyError if ``env.current_player`` matches neither player's symbol.
    """
    R_t = []
    movers = {p1.player: p1, p2.player: p2}
    pending = {}  # symbol -> (state, action) still waiting for its outcome

    while not env.end:
        mover = movers[env.current_player]
        symbol = mover.player
        state = grid2tensor(env.grid, symbol)

        # The opponent has answered and the game goes on: the mover's
        # previous move can now be completed with the state it led to.
        if symbol in pending:
            prev_state, prev_action = pending.pop(symbol)
            R_t.append((prev_state, prev_action, env.reward(symbol), state))

        action = int(mover.act(state, model))
        position = (action // 3, action % 3)

        # Assumes empty cells hold 0, consistent with the +1/-1 stone
        # encoding read by grid2tensor.
        if env.grid[position] != 0:
            R_t.append((state, action, -1, None))
            break

        env.step(position, print_grid=False)
        pending[symbol] = (state, action)

    # Game over: the last move of each player receives its final reward.
    if env.end:
        for symbol, (last_state, last_action) in pending.items():
            R_t.append((last_state, last_action, env.reward(symbol), None))

    return env, R_t

def Q_loss(r,max_new_Q_val,Q_val,gamma=.99):
    """Half squared TD error: ``0.5 * (r + gamma*max_new_Q_val - Q_val)**2``.

    Works element-wise on tensors and also on plain numbers; no reduction is
    applied, so the result has the shape of the inputs.
    """
    td_error = r + gamma*max_new_Q_val - Q_val
    return 0.5 * td_error ** 2

In [7]:
# Hyper-parameters for the self-play DQN cells.
# NOTE(review): meanings inferred from the names - confirm against their use.
gamma = 0.99         # presumably the discount factor
buff_size = 10_000   # presumably the replay-buffer capacity
batch_size = 64      # presumably the minibatch size
lr = 5e-4            # presumably the optimizer learning rate

class DQN(nn.Module):
    """Fully connected Q-network mapping an 18-feature board to 9 action scores."""

    def __init__(self):
        super(DQN,self).__init__()

        self.fc=nn.Sequential(
            nn.Linear(3*3*2,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,9)
        )

    def forward(self,x):
        """Score the 9 actions for one state or for a batch of states.

        * An input holding exactly 18 values (e.g. ``(3, 3, 2)`` or ``(18,)``)
          is treated as a single state and yields a ``(9,)`` tensor.
        * Any other input is flattened to ``(-1, 18)`` and yields ``(B, 9)``,
          which makes minibatch training possible.

        Gotcha: a batch containing a single state also holds 18 values and
        therefore comes back as ``(9,)``, not ``(1, 9)``.
        """
        if x.numel() == 18:
            return self.fc(x.reshape(18))
        return self.fc(x.reshape(-1, 18))

        

In [63]:
def grid2tensor(grid,player):
    """Encode a board as a (3, 3, 2) float tensor from ``player``'s side.

    Channel 0 is 1 where ``player`` has a stone and channel 1 is 1 where the
    opponent has one; every other entry is 0. 'X' stones are the cells equal
    to 1 in ``grid`` and 'O' stones the cells equal to -1.

    Raises
    ------
    ValueError if ``player`` is neither 'X' nor 'O'.
    """
    if player == 'X':
        own_mark = 1
    elif player == 'O':
        own_mark = -1
    else:
        raise ValueError(f"player must be 'X' or 'O', got {player!r}")

    cells = torch.as_tensor(grid)
    # Stack the two boolean masks along a new last axis, then cast to float.
    return torch.stack((cells == own_mark, cells == -own_mark), dim=-1).float()

In [65]:
model = DQN()

# The target network must be a separate copy: `target = model` only creates a
# second name for the online network, so the targets would follow every update.
target = DQN()
target.load_state_dict(model.state_dict())
target.eval()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# NOTE(review): nb_games and target_update_step are not defined in the visible
# part of the notebook; these values mirror the learning-from-expert section
# (20e3 games, target refresh every 500 games) - confirm or adjust.
nb_games = 20000
target_update_step = 500

R = []  # replay buffer of (state, action, reward, next_state) tuples
p1 = deep_Q_player('X')
p2 = deep_Q_player('O')

# Play and learn at the same time: one gradient step after every game.
for game in range(nb_games):
    env = TictactoeEnv()
    env, R_t = play_game_self_deepQ(env, model, p1, p2)
    R.extend(R_t)
    del R[:-buff_size]  # keep only the newest buff_size transitions

    if len(R) >= batch_size:
        sample = random.sample(R, batch_size)

        rews = torch.tensor([float(reward) for _, _, reward, _ in sample])
        # States are scored one at a time so this also works with a network
        # whose forward pass accepts a single state only.
        Q_vals = torch.stack([model(state).reshape(-1)[int(action)]
                              for state, action, _, _ in sample])
        with torch.no_grad():
            # Terminal transitions (next_state is None) bootstrap from 0.
            max_new_Q_vals = torch.stack([
                target(next_state).max() if next_state is not None else torch.tensor(0.0)
                for _, _, _, next_state in sample])

        loss = Q_loss(rews, max_new_Q_vals, Q_vals, gamma).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if game % target_update_step == 0:
        target.load_state_dict(model.state_dict())
    



tensor([[[0., 0.],
         [0., 1.],
         [0., 0.]],

        [[0., 0.],
         [1., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]]])

In [69]:
def m(a):
    """Rebind the local name ``a`` to 0; nothing is returned."""
    a = 0


# The assignment inside m() only rebinds its local name, so both prints show 1.
a = 1
print(a)
m(a)
print(a)

1
1
