In [1]:
from src.luxai_s3.wrappers import LuxAIS3GymEnv
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import math
from collections import deque
from torch.distributions import Categorical
from tqdm import tqdm

In [2]:
env = LuxAIS3GymEnv(numpy_output=True)
env.reset()

action = dict(
                player_0=np.random.randint(0,5,size=(env.env_params.max_units, 3)),
                player_1=np.random.randint(0,5,size=(env.env_params.max_units, 3))
            )
for i in range(10) :
    obs, reward, terminated, truncated, info = env.step(action)
print(reward)

{'player_0': array(0, dtype=int32), 'player_1': array(0, dtype=int32)}


In [3]:
def obs_to_state(obs:dict) -> torch.Tensor:
    list_state = []

    #Units
    list_state.append(torch.from_numpy(obs['units']['position'].astype(float)).flatten()) #position
    list_state.append(torch.from_numpy(obs['units']['energy'].astype(float)).flatten()) #energy
    list_state.append(torch.from_numpy(obs['units_mask'].astype(float)).flatten()) #unit_mask
    
    #Map
    list_state.append(torch.from_numpy(obs['sensor_mask'].astype(float)).flatten()) #sensor_mask
    list_state.append(torch.from_numpy(obs['map_features']['energy'].astype(float)).flatten()) #map_energy
    list_state.append(torch.from_numpy(obs['map_features']['tile_type'].astype(float)).flatten()) #map_tile_type

    list_state.append(torch.from_numpy(obs['relic_nodes'].astype(float)).flatten()) #relic_nodes
    list_state.append(torch.from_numpy(obs['relic_nodes_mask'].astype(float)).flatten()) #relic_nodes_mask

    #Game
    list_state.append(torch.from_numpy(obs['team_points'].astype(float)).flatten()) #team_points
    list_state.append(torch.from_numpy(obs['team_wins'].astype(float)).flatten()) #team_wins

    list_state.append(torch.from_numpy(obs['steps'].astype(float)).flatten()) #steps
    list_state.append(torch.from_numpy(obs['match_steps'].astype(float)).flatten()) #match_steps
    
    return torch.cat(list_state)

state = obs_to_state(obs['player_0'])
state = obs_to_state(obs['player_1'])

In [7]:
from kits.python.agent import Agent
from kits.python.lux.kit import from_json


env = LuxAIS3GymEnv(numpy_output=True)
agent_dict = dict()
agent_dict["player_0"] = Agent("player_0", dict(max_units=16))
agent_dict["player_1"] = Agent("player_1", dict(max_units=16))

states_ = []
for i in range(1) :

    obs,_ = env.reset()
    for step in range(505) :
        action = dict(
                    player_0=agent_dict["player_0"].act(step=step,obs=obs['player_0']),
                    player_1=agent_dict["player_1"].act(step=step,obs=obs['player_1']),
                )
        obs, reward, terminated, truncated, info = env.step(action)
        state_0 = obs_to_state(obs['player_0'])
        state_1 = obs_to_state(obs['player_1'])
        states_.append(state_0)
        states_.append(state_1)

states = torch.stack(states_,dim=0)
print(torch.mean(states,dim=0))
print(torch.std(states,dim=0))
torch.save(torch.mean(states,dim=0),"mean_state.pt")
torch.save(torch.std(states,dim=0),"std_state.pt")

KeyError: 'max_units'

In [121]:
class Policy(nn.Module) :
    def __init__(self,n_input,n_action,n_units,sap_range) :

        super(Policy,self).__init__()

        self.n_units = n_units
        self.n_action = n_action
        self.sap_range = sap_range

        self.inputs = nn.Linear(n_input,512,dtype=torch.double)

        self.hidden1 = nn.Linear(512,128,dtype=torch.double)
        self.hidden2 = nn.Linear(128,32,dtype=torch.double)

        self.actor_action = []
        self.actor_dx = []
        self.actor_dy = []

        self.critic_action = []
        self.critic_dx = []
        self.critic_dy = []

        for unit in range(self.n_units) :
            self.actor_action.append(nn.Linear(32,self.n_action,dtype=torch.double))
            self.actor_dx.append(nn.Linear(32,self.sap_range*2 + 1,dtype=torch.double))
            self.actor_dy.append(nn.Linear(32,self.sap_range*2 + 1,dtype=torch.double))

        self.critic = nn.Linear(32,1,dtype=torch.double)

    def forward(self,x) :

        x = F.relu(self.inputs(x))
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))

        actor_action = torch.zeros(self.n_units,self.n_action)
        actor_dx = torch.zeros(self.n_units,self.sap_range*2 + 1)
        actor_dy = torch.zeros(self.n_units,self.sap_range*2 + 1)

        for unit in range(self.n_units) :
            actor_action[unit] = F.log_softmax(self.actor_action[unit](x),dim=-1)
            actor_dx[unit] = F.log_softmax(self.actor_dx[unit](x),dim=-1)
            actor_dy[unit] = F.log_softmax(self.actor_dy[unit](x),dim=-1)

        value = self.critic(x)

        return actor_action,actor_dx,actor_dy,value

In [122]:
lr = 1e-5
n_input = 1880
n_units = env.env_params.max_units
sap_range = env.env_params.unit_sap_range
n_action = 6

model = Policy(n_input,n_action,n_units,sap_range=sap_range)
optimizer = torch.optim.Adam(model.parameters(),lr=lr)
num_episodes = 2
n = 1000
victory_bonus = 0

match_step = env.env_params.max_steps_in_match + 1
len_episode = match_step * env.env_params.match_count_per_episode


In [123]:
for i in range(n) :

    values = torch.zeros(2,num_episodes,len_episode)
    rewards = torch.zeros(2,num_episodes,len_episode)
    log_probs = torch.zeros(2,num_episodes,len_episode)

    for episode in range(1):
        # Reset the environment and get the initial state
        obs, _ = env.reset()
        state_0 = obs_to_state(obs['player_0'])
        state_1 = obs_to_state(obs['player_1'])

        base_reward_0 = 0
        base_reward_1 = 0
        # Keep track of the states, actions, and rewards for each step in the episode

        # Run the episode
        for step in range(50):

            # Get the action probabilities from the policy network
            actor_action_0,actor_dx_0,actor_dy_0, value_0 = model(state_0)
            actor_action_1,actor_dx_1,actor_dy_1, value_1 = model(state_1)

            action_0 = torch.zeros(n_units,3,dtype=torch.int)
            action_0[:,0] = Categorical(logits=actor_action_0).sample()
            action_0[:,1] = Categorical(logits=actor_dx_0).sample() - sap_range
            action_0[:,2] = Categorical(logits=actor_dy_0).sample() - sap_range

            action_1 = torch.zeros(n_units,3,dtype=torch.int)
            action_1[:,0] = Categorical(logits=actor_action_1).sample()
            action_1[:,1] = Categorical(logits=actor_dx_1).sample() - sap_range
            action_1[:,2] = Categorical(logits=actor_dy_1).sample() - sap_range

            log_prob_0 = torch.sum(actor_action_0[torch.arange(n_units),action_0[:,0]])
            log_prob_0 += torch.sum(actor_dx_0[torch.arange(n_units),action_0[:,1]])
            log_prob_0 += torch.sum(actor_dy_0[torch.arange(n_units),action_0[:,2]])

            log_prob_1 = torch.sum(actor_action_1[torch.arange(n_units),action_1[:,0]])
            log_prob_1 += torch.sum(actor_dx_1[torch.arange(n_units),action_1[:,1]])
            log_prob_1 += torch.sum(actor_dy_1[torch.arange(n_units),action_1[:,2]])

            action = dict(player_0 = np.array(action_0,dtype=int), player_1 = np.array(action_1,dtype=int))

            # Take the chosen action and observe the next state and reward
            obs, reward, truncated, done, info = env.step(action)
            next_state_0 = obs_to_state(obs['player_0'])
            next_state_1 = obs_to_state(obs['player_1'])

            if step == 0 :
                reward_memory = reward

            if reward['player_0'] > reward_memory['player_0'] or reward['player_1'] > reward_memory['player_1'] :
                base_reward_0 = reward_0
                base_reward_1 = reward_1

                if reward['player_0'] > reward_memory['player_0'] :
                    base_reward_0 += victory_bonus
                else : 
                    base_reward_1 += victory_bonus
                reward_memory = reward


            reward_0 = obs['player_0']['team_points'][0] + base_reward_0
            reward_1 = obs['player_1']['team_points'][1] + base_reward_1
                
            # Store the current state, action, and reward

            values[0,episode,step] = value_0
            values[1,episode,step] = value_1

            rewards[0,episode,step] = reward_0
            rewards[1,episode,step] = reward_1

            log_probs[0,episode,step] = log_prob_0
            log_probs[1,episode,step] = log_prob_1

            state_0 = next_state_0
            state_1 = next_state_1

    y = rewards[:,:,-1].view(2,num_episodes,1) - rewards

    y = y.flatten()
    values = values.flatten()
    log_probs = log_probs.flatten()

    policy_loss = torch.mean(log_probs * (y - values))
    value_loss = F.mse_loss(y,values)

    loss = policy_loss + value_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(base_reward_0,base_reward_1,loss)
        


torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 16, 2])
torch.Size([2, 16])
torch.Size([

KeyboardInterrupt: 