### This notebook trains our agent in the Lux environment so we can see how well it does :D

In [1]:
# module imports
import os

import numpy as np
import torch
import torch.optim as optim
from luxai_s3.wrappers import LuxAIS3GymEnv

from ac2methods import compute_advantages, compute_weight_loss, compute_action_loss
from agent import Agent
from network import AgentNetwork, compute_network_difference, has_converged
from rewards import calculate_rewards

In [2]:
# Build the Lux AI S3 gym environment with numpy outputs enabled, then reset it
# once so the environment parameters can be read from the returned info.
env = LuxAIS3GymEnv(numpy_output=True)
(obs, info) = env.reset()

# environment parameters; the cells below use them to construct the agents
env_cfg = info["params"]



In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build one Agent per player id. Each agent gets its own freshly constructed
# AgentNetwork (moved to `device`); both are created with identical arguments.
map_shape = (env_cfg["map_width"], env_cfg["map_height"])
players = {
    player_id: Agent(
        player_id,
        env_cfg,
        # third AgentNetwork argument is the literal 6; its meaning is defined
        # in network.py and is not visible from this notebook
        AgentNetwork(map_shape, env_cfg["max_units"], 6).to(device),
        device,
    )
    for player_id in ("player_0", "player_1")
}

In [5]:
# Adam optimizer over the parameters of player_0's network (learning rate 1e-4).
learner_net = players["player_0"].net
optimizer = optim.Adam(learner_net.parameters(), lr=1e-4)

In [6]:
# Self-play training loop.
#
# player_0's network is updated after every environment step using the losses
# from ac2methods; player_1 acts as the adversary and has player_0's weights
# copied into it every 5 episodes. The loop runs until has_converged(...)
# returns a truthy value.

# set some hyperparams
episode_num = 0
reward_history = []   # one entry per episode: list of player_0's per-step rewards
network_difs = []     # one entry per episode: compute_network_difference result
wins = 0              # episodes in which player_0's final env reward beat player_1's
gamma = 0.99          # forwarded to compute_advantages
lambda_ = 0.95        # forwarded to compute_advantages
value_coeff=0.5       # forwarded to both loss helpers
entropy_coeff=0.01    # forwarded to both loss helpers
win_rates = []        # running wins / episodes-played, one entry per episode
# Floor applied to network outputs before taking their log, so an exact zero
# cannot produce -inf and turn the loss into inf/NaN.
log_eps = 1e-8

# torch.save does not create missing parent directories; make sure the
# checkpoint folder exists up front instead of failing at the end of episode 0.
os.makedirs("models", exist_ok=True)

while True:
    obs, info = env.reset()
    game_done = False
    step = 0
    last_obs = {}
    last_actions = {}
    print(f"episode num: {episode_num}")

    # initialize rewards array and trajectories
    rewards = {
        "player_0": [],
        "player_1": []
    }

    # Env reward seen on the previous step, kept as plain Python scalars so the
    # change detection below is an ordinary dict comparison rather than a
    # comparison of numpy arrays.
    last_env_reward = {
        "player_0": 0,
        "player_1": 0
    }

    while not game_done:
        actions = {}
        # store current observations for learning
        last_obs = {
            "player_0": obs["player_0"].copy(),
            "player_1": obs["player_1"].copy()
        }

        # get network output, including actions
        network_outs = {}
        for id_, agent in players.items():

            network_outs[id_] = agent.act_train(step=step, obs=obs[id_])

            actions[id_] = agent.sample_actions(network_outs[id_][1].detach().cpu(), network_outs[id_][2].detach().cpu())

            # Save a copy of this player's own actions (previously the whole,
            # partially filled `actions` dict was stored under each key).
            # NOTE(review): assumes sample_actions returns something
            # np.array can copy (array / nested list) — confirm.
            last_actions[id_] = np.array(actions[id_], copy=True)

        # step in environment for both agents
        obs, reward, terminated, truncated, info = env.step(actions)

        # NOTE(review): assumes each player's env reward holds exactly one
        # value (0-d or single-element array) — confirm against the env.
        current_env_reward = {
            "player_0": np.asarray(reward["player_0"]).item(),
            "player_1": np.asarray(reward["player_1"]).item()
        }

        # A change in the env reward is mapped to a result label from
        # player_0's point of view; None means nothing changed this step.
        match_result = None
        if current_env_reward != last_env_reward:
            if current_env_reward["player_0"] > last_env_reward["player_0"]:
                match_result = "win"
            elif current_env_reward["player_1"] > last_env_reward["player_1"]:
                match_result = "loss"
            else:
                match_result = "draw"

        last_env_reward = current_env_reward

        # calc rewards for both agents
        for id_, agent in players.items():
            map_memory, enemy_memory, ally_memory, relic_points, _, _ = agent.process_obs(obs[id_])
            rewards[id_].append(calculate_rewards(network_outs[id_][0].squeeze(0).detach().cpu().numpy(), map_memory, enemy_memory, ally_memory, relic_points, match_result))

        # calc whether game is finished
        dones = {k: terminated[k] | truncated[k] for k in terminated}

        # Compute returns and advantages for player 0 from this single step
        returns, advantages = compute_advantages(
            rewards=[rewards["player_0"][-1]],
            values=[network_outs["player_0"][3].squeeze(0).squeeze(-1).detach().cpu().numpy()],
            gamma=gamma,
            lambda_=lambda_
        )

        # compute losses
        # NOTE(review): outputs [1] and [2] are treated as probabilities here
        # (they are logged and also fed to sample_actions) — confirm in network.py.
        weight_loss = compute_weight_loss(
            log_probs=torch.cat(
                (
                    network_outs["player_0"][1].clamp(min=log_eps).log(),
                    network_outs["player_0"][2].clamp(min=log_eps).log(),
                ),
                dim=-1,
            ).to(device),
            advantages=torch.tensor(advantages, dtype=torch.float32).to(device),
            values=network_outs["player_0"][3].squeeze(-1),
            returns=torch.tensor(returns, dtype=torch.float32).to(device),
            entropy_coeff=entropy_coeff,
            value_coeff=value_coeff
        )
        action_loss = compute_action_loss(
            log_probs=network_outs["player_0"][1].clamp(min=log_eps).log(),
            advantages=torch.tensor(advantages, dtype=torch.float32).to(device),
            values=network_outs["player_0"][3].squeeze(-1),
            returns=torch.tensor(returns, dtype=torch.float32).to(device),
            entropy_coeff=entropy_coeff,
            value_coeff=value_coeff
        )

        # backpropagation and optimization
        optimizer.zero_grad()
        total_loss = weight_loss + action_loss
        total_loss.backward()
        # clip gradients to prevent overflow
        torch.nn.utils.clip_grad_norm_(players["player_0"].net.parameters(), max_norm=1.0)

        optimizer.step()

        print(f"Step {step} of episode {episode_num} completed. Loss: {total_loss.item():.4f}")

        if dones["player_0"] or dones["player_1"]:
            game_done = True
            # save model weights
            torch.save(players["player_0"].net.state_dict(), f"models/agent_network_episode_{episode_num}")
            wins += int(current_env_reward["player_0"] > current_env_reward["player_1"])
            print(wins)
            win_rates.append(wins / (episode_num + 1))
        step += 1

    # store rewards
    reward_history.append(rewards["player_0"])

    # Difference between learner and adversary networks, measured before the
    # adversary is (possibly) re-synced below.
    network_dif = compute_network_difference(players["player_0"].net, players["player_1"].net)
    network_difs.append(network_dif)

    # update adversary to current state dict every 5 episodes
    if episode_num % 5 == 0:
        players["player_1"].net.load_state_dict(players["player_0"].net.state_dict())

    # if network has converged according to our criterion break out of training loop
    if has_converged(win_rates, network_difs):
        print(f"agent converged after {episode_num + 1} episodes!")
        break

    episode_num += 1

episode num: 0


ValueError: Expected 2D array, got 1D array instead:
array=[100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
 100. 100.].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.