In [1]:
#SAC implementation

In [2]:
import socket

import argparse
import os
import random
import subprocess
import time
from distutils.util import strtobool
from typing import List

import asr_pb2

import numpy as np
import pandas as pd
import torch.nn.functional as F
from collections import defaultdict
import torch
from torch import nn
from torch import optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from gym.spaces import MultiDiscrete
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder

from gevent import monkey

from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import (
    MicroRTSGridModeSharedMemVecEnv as MicroRTSGridModeVecEnv,
)

In [16]:
class MicroRTSStatsRecorder(VecEnvWrapper):
    """Vectorized-env wrapper that accumulates per-episode raw reward statistics.

    For each sub-environment it collects the ``raw_rewards`` entry of every
    step info, together with a gamma-discounted copy that carries one extra
    trailing element (the sum of that step's raw rewards). When a
    sub-environment reports ``done``, the per-episode sums are stored under
    ``info["microrts_stats"]`` and the accumulators for that env are cleared.

    NOTE(review): ``self.rfs`` is not assigned in this class; presumably it is
    resolved on the wrapped env through ``VecEnvWrapper`` attribute
    forwarding -- confirm against the wrapped env.
    """

    def __init__(self, env, gamma=0.99) -> None:
        """Wrap ``env`` and store the discount factor ``gamma``."""
        super().__init__(env)
        self.gamma = gamma

    def reset(self):
        """Reset the wrapped envs and (re)create the per-env accumulators.

        Returns:
            Whatever the wrapped env's ``reset()`` returns.
        """
        first_obs = self.venv.reset()
        n_envs = self.num_envs
        self.raw_rewards = [[] for _ in range(n_envs)]
        # Per-env step counter since the last episode start (discount exponent).
        self.ts = np.zeros(n_envs, dtype=np.float32)
        self.raw_discount_rewards = [[] for _ in range(n_envs)]
        return first_obs

    def step_wait(self):
        """Step the wrapped envs and attach episode statistics on ``done``.

        Returns:
            ``(obs, rews, dones, infos)`` from the wrapped env, where the info
            of every finished sub-environment is replaced by a shallow copy
            that additionally holds a ``"microrts_stats"`` dict.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        patched_infos = list(infos[:])
        for env_idx, finished in enumerate(dones):
            # assumes infos[env_idx]["raw_rewards"] is a numpy array -- TODO confirm
            step_rewards = infos[env_idx]["raw_rewards"]
            self.raw_rewards[env_idx].append(step_rewards)

            # Flatten the raw rewards and append their total as a last element.
            with_total = np.concatenate((step_rewards, step_rewards.sum()), axis=None)
            discount = self.gamma ** self.ts[env_idx]
            self.raw_discount_rewards[env_idx].append(discount * with_total)
            self.ts[env_idx] += 1

            if not finished:
                continue

            # Episode ended: summarise, publish, then clear this env's state.
            episode_info = infos[env_idx].copy()
            rf_labels = [str(rf) for rf in self.rfs]
            stats = dict(zip(rf_labels, np.array(self.raw_rewards[env_idx]).sum(0)))
            discounted_labels = ["discounted_" + label for label in rf_labels] + ["discounted"]
            stats.update(zip(discounted_labels, np.array(self.raw_discount_rewards[env_idx]).sum(0)))
            episode_info["microrts_stats"] = stats

            self.raw_rewards[env_idx] = []
            self.raw_discount_rewards[env_idx] = []
            self.ts[env_idx] = 0
            patched_infos[env_idx] = episode_info
        return obs, rews, dones, patched_infos

    
if __name__ == "__main__":
    # Builds the seeded, monitored MicroRTS vectorized environment.

    # Parameters
    seed = 9
    # Opt-in flag for GPU use. The original expression read `args.cuda`, but no
    # `args` object is defined in the visible notebook, so the cell raised
    # NameError on any machine where CUDA was available.
    cuda = True
    torch_deterministic = True
    num_selfplay_envs = 2
    num_bot_envs = 0
    partial_obs = False
    train_maps = ["maps/16x16/basesWorkers16x16A.xml"]
    gamma = 0.99

    device = torch.device("cuda" if torch.cuda.is_available() and cuda else "cpu")

    # Seed every RNG source used here so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = torch_deterministic

    # Opponent bots. `range()` of a negative number is empty, so no coacAI
    # opponents are added until num_bot_envs exceeds 6; with num_bot_envs == 0
    # the whole list is empty.
    # NOTE(review): presumably len(ai2s) has to equal num_bot_envs, which this
    # formula only guarantees for some values -- confirm before raising
    # num_bot_envs.
    ai2s = (
        [microrts_ai.coacAI for _ in range(num_bot_envs - 6)]
        + [microrts_ai.randomBiasedAI for _ in range(min(num_bot_envs, 2))]
        + [microrts_ai.lightRushAI for _ in range(min(num_bot_envs, 2))]
        + [microrts_ai.workerRushAI for _ in range(min(num_bot_envs, 2))]
    )

    envs = MicroRTSGridModeVecEnv(
        num_selfplay_envs=num_selfplay_envs,
        num_bot_envs=num_bot_envs,
        partial_obs=partial_obs,
        max_steps=2000,
        render_theme=2,
        ai2s=ai2s,
        map_paths=[train_maps[0]],
        # NOTE(review): presumably one weight per reward function of the env -- confirm order.
        reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]),
        cycle_maps=train_maps,
    )
    # Record per-episode raw/discounted reward stats, then standard episode monitoring.
    envs = MicroRTSStatsRecorder(envs, gamma)
    envs = VecMonitor(envs)
    # Video capture is disabled; it depends on `args` and `experiment_name`,
    # neither of which is defined in the visible part of this notebook.
#     if args.capture_video:
#         envs = VecVideoRecorder(
#             envs, f"videos/{experiment_name}", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000
#         )
    assert isinstance(envs.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"

<stable_baselines3.common.vec_env.vec_monitor.VecMonitor object at 0x7f995f8a9a00>
