In [1]:
#SAC implementation

In [37]:
import socket

import argparse
import os
import random
import subprocess
import time
from distutils.util import strtobool
from typing import List

import asr_pb2

import numpy as np
import pandas as pd
import torch.nn.functional as F
from collections import defaultdict
import torch
from torch import nn
from torch import optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from gym.spaces import MultiDiscrete
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder

from gevent import monkey

from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import (
    MicroRTSGridModeSharedMemVecEnv as MicroRTSGridModeVecEnv,
)
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder
from stable_baselines3.common.buffers import ReplayBuffer

In [50]:
# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    """Categorical distribution in which masked-out choices get (near) zero probability.

    Args:
        probs: forwarded unchanged to ``Categorical``.
        logits: unnormalised log-probabilities; required whenever ``masks`` is given.
        validate_args: forwarded unchanged to ``Categorical``.
        masks: tensor (or array-like) broadcastable to ``logits``; truthy entries mark
            choices that stay available. ``None`` means "no masking".
        mask_value: value written over the logits of masked-out choices. Defaults to
            ``-1e8`` in the dtype/device of ``logits``.

    Raises:
        ValueError: if ``masks`` is given without ``logits``.
    """

    def __init__(self, probs=None, logits=None, validate_args=None, masks=None, mask_value=None):
        if masks is not None:
            if logits is None:
                raise ValueError("CategoricalMasked needs `logits` when `masks` is given")
            keep = torch.as_tensor(masks, device=logits.device).bool()
            if mask_value is None:
                # Large negative logit -> probability underflows to zero after softmax.
                mask_value = torch.tensor(-1e8, dtype=logits.dtype, device=logits.device)
            logits = torch.where(keep, logits, mask_value)
        # With masks=None this is a plain Categorical (the old code crashed on [].bool()).
        super().__init__(probs, logits, validate_args)

class Transpose(nn.Module):
    """Module wrapper around ``Tensor.permute`` so a permutation can sit inside ``nn.Sequential``."""

    def __init__(self, permutation):
        super().__init__()
        # Target ordering of the input dimensions, e.g. (0, 3, 1, 2).
        self.permutation = permutation

    def forward(self, x):
        """Return ``x`` with its dimensions reordered according to ``self.permutation``."""
        dims = self.permutation
        return x.permute(*dims)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand it back.

    The weight is filled with an orthogonal matrix scaled by ``std`` and every
    bias entry is set to ``bias_const``. The layer must expose both ``weight``
    and ``bias`` tensors.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class MicroRTSStatsRecorder(VecEnvWrapper):
    """Vec-env wrapper that accumulates per-episode raw reward statistics.

    On every step the ``"raw_rewards"`` entry of each env's info dict is stored,
    both as-is and discounted by ``gamma ** t``. When an env reports ``done``,
    the summed values are attached to a copy of its info dict under
    ``"microrts_stats"`` and the accumulators for that env are cleared.

    ``reset()`` must be called before the first ``step_wait()`` because it
    creates the accumulators.
    """

    def __init__(self, env, gamma=0.99) -> None:
        super().__init__(env)
        # Discount factor applied to the per-step raw rewards.
        self.gamma = gamma

    def reset(self):
        """Reset the wrapped envs and start empty accumulators for every env."""
        first_obs = self.venv.reset()
        env_count = self.num_envs
        self.raw_rewards = [[] for _ in range(env_count)]
        # Steps elapsed in the current episode of each env (used as the discount exponent).
        self.ts = np.zeros(env_count, dtype=np.float32)
        self.raw_discount_rewards = [[] for _ in range(env_count)]
        return first_obs

    def step_wait(self):
        """Step the wrapped envs, update the accumulators and annotate finished episodes."""
        obs, rews, dones, infos = self.venv.step_wait()
        newinfos = list(infos[:])
        for env_idx, finished in enumerate(dones):
            step_rewards = infos[env_idx]["raw_rewards"]
            self.raw_rewards[env_idx].append(step_rewards)

            # Per-component rewards followed by their total, flattened to 1-D.
            with_total = np.concatenate((step_rewards, step_rewards.sum()), axis=None)
            self.raw_discount_rewards[env_idx].append((self.gamma ** self.ts[env_idx]) * with_total)
            self.ts[env_idx] += 1

            if not finished:
                continue

            # `self.rfs` is not set in this class; presumably it is forwarded from the
            # wrapped env by the wrapper base class -- verify.
            reward_names = [str(rf) for rf in self.rfs]
            discounted_names = ["discounted_" + str(rf) for rf in self.rfs] + ["discounted"]
            episode_totals = np.array(self.raw_rewards[env_idx]).sum(0)
            discounted_totals = np.array(self.raw_discount_rewards[env_idx]).sum(0)

            stats = dict(zip(reward_names, episode_totals))
            stats.update(zip(discounted_names, discounted_totals))

            episode_info = infos[env_idx].copy()
            episode_info["microrts_stats"] = stats
            newinfos[env_idx] = episode_info

            # Start fresh accumulators for the next episode of this env.
            self.raw_rewards[env_idx] = []
            self.raw_discount_rewards[env_idx] = []
            self.ts[env_idx] = 0
        return obs, rews, dones, newinfos

    
if __name__ == "__main__":
    # Notebook-level configuration and environment construction. The names defined
    # here (device, envs and the hyper-parameters) are read as globals by later cells.

    # parameters
    seed = 9
    torch_deterministic = True
    cuda = True  # set to False to force CPU even when a GPU is available
    num_selfplay_envs = 2
    num_bot_envs = 0
    partial_obs = False
    train_maps = ["maps/16x16/basesWorkers16x16A.xml"]
    gamma = 0.99
    q_lr = 3e-4
    policy_lr = 2.5e-4
    eps = 1e-5
    buffer_size = int(1e6)
    total_timesteps = 5000000
    learning_starts = int(2e4)
    update_frequency = 4

    # The original expression read `args.cuda`, but no `args` exists in this notebook,
    # so it raised NameError on every machine where CUDA is available.
    device = torch.device("cuda" if torch.cuda.is_available() and cuda else "cpu")

    # Seed every RNG in use so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = torch_deterministic

    # NOTE(review): the ai2s list has num_bot_envs entries only when num_bot_envs is 0
    # or >= 6; for values 1..5 the lengths differ -- check what the env expects.
    envs = MicroRTSGridModeVecEnv(
        num_selfplay_envs=num_selfplay_envs,
        num_bot_envs=num_bot_envs,
        partial_obs=partial_obs,
        max_steps=2000,
        render_theme=2,
        ai2s=[microrts_ai.coacAI for _ in range(num_bot_envs - 6)]
        + [microrts_ai.randomBiasedAI for _ in range(min(num_bot_envs, 2))]
        + [microrts_ai.lightRushAI for _ in range(min(num_bot_envs, 2))]
        + [microrts_ai.workerRushAI for _ in range(min(num_bot_envs, 2))],
        map_paths=[train_maps[0]],
        reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]),
        cycle_maps=train_maps,
    )
    # Wrap with per-episode reward statistics, then with the episode monitor.
    envs = MicroRTSStatsRecorder(envs, gamma)
    envs = VecMonitor(envs)

In [None]:
#https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py
#https://github.com/pranz24/pytorch-soft-actor-critic
#https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py

In [None]:

class Actor(nn.Module):
    """Convolutional policy producing one masked categorical per action component per grid cell.

    The encoder halves the spatial resolution twice; the transposed-convolution
    head brings it back and emits 78 logits per cell.
    """

    def __init__(self, envs, mapsize=16 * 16):
        """
        Args:
            envs: object whose ``observation_space.shape`` is ``(h, w, c)``.
            mapsize: number of grid cells per observation (default 16 * 16).
        """
        super().__init__()
        self.mapsize = mapsize
        _, _, c = envs.observation_space.shape
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),  # channels-last -> channels-first for Conv2d
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.actor = nn.Sequential(
            layer_init(nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)),
            nn.ReLU(),
            # NOTE(review): 78 is presumably envs.action_plane_space.nvec.sum();
            # get_action reshapes to that width, so the two must agree -- confirm.
            layer_init(nn.ConvTranspose2d(32, 78, 3, stride=2, padding=1, output_padding=1)),
            Transpose((0, 2, 3, 1)),  # back to channels-last
        )
        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_action(self, x, action=None, invalid_action_masks=None, envs=None, device=None):
        """Sample actions (or score the given ones) under the masked policy.

        Args:
            x: batch of observations fed to the encoder.
            action: actions to evaluate; when ``None`` new actions are sampled.
            invalid_action_masks: mask whose last dimension matches the per-cell
                logit width; truthy entries are selectable. When ``None`` every
                choice is treated as selectable (previously this crashed).
            envs: required despite the ``None`` default; supplies
                ``action_plane_space.nvec``.
            device: unused, kept for interface compatibility.

        Returns:
            Tuple ``(action, logprob, entropy, invalid_action_masks)``: actions
            viewed as ``(-1, mapsize, n_components)``, log-probabilities and
            entropies summed over cells and components, and the flattened mask
            that was applied.
        """
        hidden = self.encoder(x)
        logits = self.actor(hidden)

        nvec = envs.action_plane_space.nvec.tolist()
        grid_logits = logits.reshape(-1, sum(nvec))
        split_logits = torch.split(grid_logits, nvec, dim=1)

        if invalid_action_masks is None:
            # No mask supplied: allow everything.
            invalid_action_masks = torch.ones_like(grid_logits, dtype=torch.bool)
        else:
            invalid_action_masks = invalid_action_masks.view(-1, invalid_action_masks.shape[-1])
        split_invalid_action_masks = torch.split(invalid_action_masks, nvec, dim=1)

        # One masked distribution per action component; shared by both the
        # sampling path and the evaluation path.
        multi_categoricals = [
            CategoricalMasked(logits=component_logits, masks=iam, mask_value=self.mask_value)
            for (component_logits, iam) in zip(split_logits, split_invalid_action_masks)
        ]

        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        else:
            # Bring the given actions to (n_components, batch * mapsize) like the sampled ones.
            action = action.view(-1, action.shape[-1]).T

        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])

        num_predicted_parameters = len(nvec)
        logprob = logprob.T.view(-1, self.mapsize, num_predicted_parameters)
        entropy = entropy.T.view(-1, self.mapsize, num_predicted_parameters)
        action = action.T.view(-1, self.mapsize, num_predicted_parameters)
        return action, logprob.sum(1).sum(1), entropy.sum(1).sum(1), invalid_action_masks


    
class SoftQNetwork(nn.Module):
    """Convolutional critic that maps an observation batch to one scalar per sample.

    NOTE(review): the network only sees the observation, not the action, so it
    acts as a state-value head -- confirm this is what the SAC update expects.
    """

    def __init__(self, envs, mapsize=16 * 16):
        """
        Args:
            envs: object whose ``observation_space.shape`` is ``(h, w, c)``.
            mapsize: number of grid cells per observation; stored on the module.
        """
        super().__init__()
        self.mapsize = mapsize
        _height, _width, in_channels = envs.observation_space.shape

        # Layers are built in the same order as they appear in the Sequential so that
        # parameter initialisation consumes the RNG identically.
        encoder_layers = [
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # 64 * 4 * 4 hard-codes the flattened encoder output size.
        critic_layers = [
            nn.Flatten(),
            layer_init(nn.Linear(64 * 4 * 4, 128)),
            nn.ReLU(),
            layer_init(nn.Linear(128, 1), std=1),
        ]
        self.critic = nn.Sequential(*critic_layers)

        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_value(self, x):
        """Encode ``x`` and return the critic head's output."""
        features = self.encoder(x)
        return self.critic(features)
    

def build_models():
    """Create the SAC networks, optimizers and replay buffer, then run the rollout loop.

    Reads these module-level names: ``envs``, ``device``, ``q_lr``, ``policy_lr``,
    ``eps``, ``buffer_size``, ``total_timesteps``, ``learning_starts`` and
    ``update_frequency``.

    The training update itself is not written yet, and transitions are not yet
    added to the replay buffer.

    Raises:
        NotImplementedError: at the first step that is past ``learning_starts``
            and a multiple of ``update_frequency``, where the SAC update belongs.
    """
    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    # Target critics start as exact copies of the online critics.
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=q_lr, eps=eps)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=policy_lr)

    # TODO: add autotune for the entropy coefficient.
    alpha = 0.2

    # The replay buffer takes its storage dtype from the observation space.
    envs.observation_space.dtype = np.float32
    rb = ReplayBuffer(
        buffer_size,
        envs.observation_space,
        envs.action_space,
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()
    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(total_timesteps):
        if global_step < learning_starts:
            # Warm-up phase: uniformly random actions (no invalid-action masking).
            actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)])
            # NOTE(review): 256 cells x 7 action components is hard-coded here.
            actions = actions.reshape(envs.num_envs, 256, 7)
            actions = torch.Tensor(actions).to(torch.int64).to(device)
        else:
            invalid_mask = torch.tensor(envs.get_action_mask()).to(device)
            actions, _, _, _ = actor.get_action(
                torch.Tensor(obs).to(device), envs=envs, invalid_action_masks=invalid_mask, device=device
            )
        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions.cpu().numpy().reshape(envs.num_envs, -1))

        # Report only the first finished episode of this step.
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                break
        real_next_obs = next_obs.copy()

        # TODO: substitute the terminal observation once the env provides it:
        #         for idx, d in enumerate(dones):
        #             if d:
        #                 print(infos[idx])
        #                 real_next_obs[idx] = infos[idx]["terminal_observation"]

        obs = next_obs
        if global_step > learning_starts and global_step % update_frequency == 0:
            # The old placeholder raised a plain string, which Python rejects with an
            # unrelated TypeError; fail with an explicit, descriptive exception instead.
            raise NotImplementedError("SAC update step is not implemented yet")

# Entry point of this cell: builds the networks and immediately starts stepping the environments.
build_models()



In [None]:
#remarks
- Currently the actor and the critics do not share a common encoder (in the original PPO implementation they do share one).
However, according to the authors of the SAC paper, sharing a CNN encoder between the actor and the critics is not recommended for SAC.

- Autotune for alpha not implemented

- Does not handle 'terminal_observation' (not supported by the env); check how it is handled by PPO

- The replay memory does not yet save the invalid action masks!