In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np

from env.chooseenv import make
from run_utils import get_players_and_action_space_list, run_game

In [2]:
def convert_act(joint_act_idx):
    joint_act = [[[0] * 4 for _ in range(1)] for _ in range(2)]
    for i in range(2):
        joint_act[i][0][joint_act_idx[i]] = 1
    return joint_act

In [4]:
game_type = 'snakes_1v1'
game = make(game_type)
game.max_step = 50
log_mode = True
policy_list = ['random'] * len(game.agent_nums)
multi_part_agent_ids, action_spaces = get_players_and_action_space_list(game)
run_game(game, game_type, multi_part_agent_ids, action_spaces, policy_list, log_mode, plot_game=False)

[[[0, 0, 0, 1]], [[1, 0, 0, 0]]]
[[[0, 0, 1, 0]], [[1, 0, 0, 0]]]
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 1, 0, 0]]]
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 0, 0, 1]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[1, 0, 0, 0]]]
[[[0, 0, 1, 0]], [[0, 0, 0, 1]]]
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
Step10
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 0, 1, 0]], [[0, 1, 0, 0]]]
[[[0, 0, 0, 1]], [[0, 0, 1, 0]]]
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 0, 1, 0]], [[1, 0, 0, 0]]]
[[[0, 1, 0, 0]], [[0, 1, 0, 0]]]
[[[0, 1, 0, 0]], [[0, 1, 0, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 0, 0, 1]], [[0, 1, 0, 0]]]
Step20
[[[0, 1, 0, 0]], [[0, 1, 0, 0]]]
[[[0, 0, 0, 1]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 0, 1]]]
[[[1, 0, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 0, 1]]]
[[[0, 0, 1, 0]], [[1, 0, 0, 0]]]
[[[0, 0, 1, 0]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 1, 0]]]
[[[0, 1, 0, 0]], [[0, 0, 0, 1]]]
[[[0, 0, 1, 0]], [[0, 0, 0, 1]]]
Step30
[[[1, 0, 0, 0]], [[0, 

In [5]:
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

In [6]:
device = torch.device('cpu')
#if(torch.cuda.is_available()): 
#    device = torch.device('cuda:0') 
#    torch.cuda.empty_cache()
#    print("Device set to : " + str(torch.cuda.get_device_name(device)))
#else:
#    print("Device set to : cpu")

In [7]:
class RolloutBuffer:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.state_values = []
        self.is_terminals = []
    
    def clear(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.state_values[:]
        del self.is_terminals[:]

In [8]:
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space
        
        if has_continuous_action_space:
            self.action_dim = action_dim
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)
        # actor
        if has_continuous_action_space :
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Tanh()
                        )
        else:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Softmax(dim=-1)
                        )
        # critic
        self.critic = nn.Sequential(
                        nn.Linear(state_dim, 64),
                        nn.Tanh(),
                        nn.Linear(64, 64),
                        nn.Tanh(),
                        nn.Linear(64, 1)
                    )
        
    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def forward(self):
        raise NotImplementedError
    
    def act(self, state):

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)
        state_val = self.critic(state)

        return action.detach(), action_logprob.detach(), state_val.detach()
    
    def evaluate(self, state, action):

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            
            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
            
            # For Single Action Environments.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        
        return action_logprobs, state_values, dist_entropy

In [9]:
class PPO:
    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        
        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        
        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)

        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")
        print("--------------------------------------------------------------------------------------------")

    def select_action(self, state):

        if self.has_continuous_action_space:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.detach().cpu().numpy().flatten()
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)
            
            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.item()

    def update(self):
        # Monte Carlo estimate of returns
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
            
        # Normalizing the rewards
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)

        # calculate advantages
        advantages = rewards.detach() - old_state_values.detach()

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):

            # Evaluating old actions and values
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # match state_values tensor dimensions with rewards tensor
            state_values = torch.squeeze(state_values)
            
            # Finding the ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss  
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages

            # final loss of clipped objective PPO
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
            
            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()
    
    def save(self, checkpoint_path):
        torch.save(self.policy_old.state_dict(), checkpoint_path)
   
    def load(self, checkpoint_path):
        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        
        

In [31]:
env = make(game_type)

# state space dimension
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

has_continuous_action_space = False
max_ep_len = 50
max_training_timesteps = 1000
print_freq = max_ep_len * 10 
log_freq = max_ep_len * 2
save_model_freq = int(1e5)

action_std = 0.6
action_std_decay_rate = 0.05 
min_action_std = 0.1 
action_std_decay_freq = int(2.5e5)

update_timestep = max_ep_len * 4
K_epochs = 80

eps_clip = 0.2
gamma = 0.99

lr_actor = 0.0003
lr_critic = 0.001

random_seed = 0 

In [32]:
ppo_agents = [
            PPO(
    state_dim, 
    action_dim, 
    lr_actor, 
    lr_critic, 
    gamma, 
    K_epochs, 
    eps_clip, 
    has_continuous_action_space, 
    action_std) for _ in range(2)
]

In [46]:
# printing and logging variables
print_running_rewards = [0,0]
print_running_episodes = 0

log_running_rewards = [0,0]
log_running_episodes = 0

time_step = 0
i_episode = 0

# training loop
while time_step <= max_training_timesteps:
    state = env.reset()
    current_ep_rewards = [0,0]

    for t in range(1, max_ep_len+1):

        # select action with policy
        actions = [ppo_agents[i].select_action(state) for i in range(2)]
        conv_acts = convert_act(actions)
        state, dict_obs, reward, done, _, _ = env.step(conv_acts)

        # saving reward and is_terminals
        for i in range(2):
            ppo_agents[i].buffer.rewards.append(reward[i])
            ppo_agents[i].buffer.is_terminals.append(done)
            current_ep_rewards[i] += reward[i]

        time_step +=1
        
        # update PPO agent
        if time_step % update_timestep == 0:
            for i in range(2):
                ppo_agents[i].update()

        # if continuous action space; then decay action std of ouput action distribution
        if has_continuous_action_space and time_step % action_std_decay_freq == 0:
            ppo_agent.decay_action_std(action_std_decay_rate, min_action_std)

        # log in logging file
        if time_step % log_freq == 0:
            log_avg_rewards = [0,0]

            for i in range(2):                
                # log average reward till last episode
                log_avg_rewards[i] = log_running_rewards[i] / log_running_episodes
                log_avg_rewards[i] = round(log_avg_rewards[i], 4)

                #log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                #log_f.flush()

            log_running_reward = 0
            log_running_episodes = 0

        # printing average reward
        if time_step % print_freq == 0:
            print_avg_rewards = [0,0]

            for i in range(2):
                # print average reward till last episode
                print_avg_rewards[i] = print_running_rewards[i] / print_running_episodes
                print_avg_rewards[i] = round(print_avg_rewards[i], 2)

            print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step, print_avg_rewards))

            print_running_reward = 0
            print_running_episodes = 0

    print_running_rewards = np.array(print_running_rewards) + np.array(current_ep_rewards)
    print_running_episodes += 1

    log_running_rewards =  np.array(log_running_rewards) + np.array(current_ep_rewards)
    log_running_episodes += 1

    i_episode += 1

Episode : 9 		 Timestep : 500 		 Average Reward : [1.22, 0.56]
Episode : 19 		 Timestep : 1000 		 Average Reward : [2.7, 1.2]
