In [1]:
!pip install vmas



In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import vmas
import gym
import itertools

class GraphSageLayer(nn.Module):
    def __init__(self, dim_in: int,
                 dim_out: int,
                 agg_type: str):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.agg_type = agg_type
        self.act = nn.ReLU()

        if agg_type == 'gcn':
            self.weight = nn.Linear(dim_in, dim_out, bias=False)
            self.bias   = nn.Linear(dim_in, dim_out, bias=False)
        elif agg_type == 'mean':
            self.weight = nn.Linear(2 * dim_in, dim_out, bias=False)
        elif agg_type == 'maxpool':
            self.linear_pool = nn.Linear(dim_in, dim_in, bias=True)
            self.weight      = nn.Linear(2 * dim_in, dim_out, bias=False)
        else:
            raise RuntimeError(f"Unknown aggregation type: {agg_type}")

    def forward(self, feat: torch.Tensor,
                edge: torch.Tensor,
                degree: torch.Tensor) -> torch.Tensor:
        # feat: [N, dim_in], edge: [E,2], degree: [N]

        agg_vector = torch.zeros_like(feat)

        if self.agg_type == 'gcn':
            agg_vector.index_add_(0, edge[:,1], feat[edge[:,0]])
            inv = (1.0 / degree.clamp(min=1)).unsqueeze(-1)
            out = self.act(self.weight(agg_vector * inv) + self.bias(feat))

        elif self.agg_type == 'mean':
            agg_vector.index_add_(0, edge[:,1], feat[edge[:,0]])
            inv = (1.0 / degree.clamp(min=1)).unsqueeze(-1)
            cat = torch.cat([agg_vector * inv, feat], dim=-1)
            out = self.act(self.weight(cat))

        else:  # 'maxpool'
            src = self.act(self.linear_pool(feat))[edge[:,0]]
            idx = edge[:,1].unsqueeze(-1).expand_as(src)
            agg_vector.scatter_reduce_(0, idx, src, reduce='amax', include_self=False)
            cat = torch.cat([agg_vector, feat], dim=-1)
            out = self.act(self.weight(cat))

        return F.normalize(out, p=2, dim=-1)

def build_edge_lists(coords, n_agents, threshold: float):
    edges = [(i, i) for i in range(n_agents)]
    for i, j in itertools.combinations(range(n_agents), 2):
        dist = torch.dist(coords[i], coords[j], p=2)
        if dist <= threshold:
            edges.append((i, j))
    return edges

def build_knn_edge_lists(coords, n_agents, k: int):
    edges = []
    dist = torch.cdist(coords, coords, p=2)
    dist.fill_diagonal_(float('inf'))
    knn_idx = torch.topk(dist, k, largest=False).indices
    edges = [(i, i) for i in range(n_agents)]
    for i in range(n_agents):
        for j in knn_idx[i].tolist():
            edges.append((i, j))
    return edges

In [None]:
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import vmas
from tqdm import tqdm
import sys
import os

from distutils.util import strtobool

def parse_args():
    class Args:
        # Experiment setup
        exp_name = "notebook_run"
        scenario = "navigation"
        learning_rate = 1e-4
        seed = 1
        total_timesteps = 50_000_000
        torch_deterministic = True
        cuda = True
        track = False
        wandb_project_name = "graph-ml-projects"
        wandb_entity = None

        # GNN specific arguments
        agg_type = "gcn"
        hidden_dim = 64
        dist = 0.1

        # Algorithm hyperparameters
        num_agents = 4
        num_envs = 600
        num_steps = 100
        anneal_lr = True
        gae = True
        gamma = 0.99
        gae_lambda = 0.95
        num_minibatches = 45
        update_epochs = 4
        norm_adv = True
        clip_coef = 0.2
        clip_vloss = True
        ent_coef = 0.01
        vf_coef = 0.5
        max_grad_norm = 0.5
        target_kl = None

        # Derived values
        batch_size = num_envs * num_steps  # 600 * 100 = 60000
        minibatch_size = batch_size // num_minibatches  # 60000 // 45 ≈ 1333

    return Args()

def make_env(scenario, num_envs, continuous_actions, seed, device, n_agents):
    return vmas.make_env(
        scenario=scenario,
        num_envs=num_envs,
        device=device,
        continuous_actions=continuous_actions,
        seed=seed,
        n_agents=n_agents,
        max_steps=100,
    )

def add_agent_id(obs, num_envs, num_agents, device):
    agent_ids = torch.eye(num_agents, device=device)  # [num_agents, num_agents]
    agent_ids = agent_ids.unsqueeze(0).repeat(num_envs, 1, 1)  # [num_envs, num_agents, num_agents]
    agent_ids = agent_ids.view(-1, num_agents)  # [num_envs * num_agents, num_agents]
    return torch.cat([obs, agent_ids], dim=-1)  # [num_envs * num_agents, obs_dim + num_agents]

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class GraphSageAgent(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim, dist=1.0, agg_type='gcn'):
        super().__init__()
        self.dist = dist
        # GraphSAGE policy
        self.gsage1 = GraphSageLayer(obs_dim, hidden_dim, agg_type=agg_type)
        self.gsage2 = GraphSageLayer(hidden_dim, hidden_dim, agg_type=agg_type)
        self.policy_head = layer_init(nn.Linear(hidden_dim, act_dim), std=0.01)
        # Critic: MLP
        self.critic = nn.Sequential(
            layer_init(nn.Linear(hidden_dim, 256)), nn.Tanh(),
            layer_init(nn.Linear(256, 256)), nn.Tanh(),
            layer_init(nn.Linear(256, 1), std=1.0)
        )

    def get_value(self, communicated_obs: torch.Tensor) -> torch.Tensor:
        # communicated_obs: [-1, obs_dim]
        return self.critic(communicated_obs)

    def get_communicated_obs_slow(self, x: torch.Tensor, positions):
        # x: [-1, num_agents, obs_dim]
        # positions: [-1, num_agents, 2]
        device = x.device
        num_agents = x.shape[1]
        # apply GCN per env
        h_list = []
        for e in range(x.shape[0]):
            feat_e = x[e] # [num_agents, obs_dim]
            coord_e = positions[e] # [num_agents, 2]
            edges_e = build_edge_lists(coord_e, num_agents, self.dist)
            edges_e = torch.tensor(edges_e).to(device) # [num_edges, 2]
            degree_e = torch.bincount(edges_e[:,1], minlength=num_agents).to(device) # [num_agents, ]
            h_e = self.gsage1(feat_e, edges_e, degree_e)
            h_e = self.gsage2(h_e, edges_e, degree_e) # [num_agents, obs_dim]
            h_list.append(h_e)
        # flatten back
        h = torch.cat(h_list, dim=0)  # [-1, 64]
        return h

    def get_communicated_obs(self, x: torch.Tensor, positions: torch.Tensor):
        """
        x: [num_envs, num_agents, obs_dim]
        positions: [num_envs, num_agents, pos_dim]  (e.g. pos_dim=2)
        """
        device = x.device
        num_envs, n_agents, feat_dim = x.shape

        # 1) Flatten node features: [N, feat_dim], N = num_envs * n_agents
        x_flat = x.view(-1, feat_dim)  # N x feat_dim
        dist_mat = torch.cdist(positions, positions, p=2)
        mask = dist_mat <= self.dist  # [E, N, N]
        env_idx, src_idx, dst_idx = mask.nonzero(as_tuple=True)
        src_flat = env_idx * n_agents + src_idx
        dst_flat = env_idx * n_agents + dst_idx
        edges = torch.stack([src_flat, dst_flat], dim=1).to(device)  # E_total x 2
        N = num_envs * n_agents
        degree = torch.bincount(dst_flat, minlength=N).to(device)

        h = self.gsage1(x_flat, edges, degree)
        h = self.gsage2(h, edges, degree)  # still N x hidden_dim

        return h  # [num_envs * num_agents, hidden_dim]

    def get_action_and_value(self, communicated_obs, action=None):
        # communicated_obs: [-1, obs_dim]
        logits = self.policy_head(communicated_obs)
        dist = Categorical(logits=logits)
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), self.critic(communicated_obs)

    def get_action(self, x, deterministic=True):
        pass

if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.scenario}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = make_env(
        scenario=args.scenario,
        num_envs=args.num_envs,
        continuous_actions=False, # only consider discrete action space
        seed=args.seed,
        device=device,
        n_agents=args.num_agents,
    )

    # check dim of env
    obs_list = envs.reset()
    obs_dim = obs_list[0].shape[-1]
    act_dim = envs.action_space[0].n

    agent = GraphSageAgent(obs_dim=obs_dim + args.num_agents, act_dim=act_dim, hidden_dim=args.hidden_dim, dist=args.dist, agg_type=args.agg_type).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs * args.num_agents, args.hidden_dim)).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs * args.num_agents)).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs * args.num_agents)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs * args.num_agents)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs * args.num_agents)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs * args.num_agents)).to(device)

    # start the game
    global_step = 0
    SAVE_INTERVAL = 10_000_000
    episode_returns = np.zeros(args.num_envs, dtype=np.float32)
    episode_lengths = np.zeros(args.num_envs, dtype=np.int32)
    return_queue = deque(maxlen=100)
    length_queue = deque(maxlen=100)
    start_time = time.time()
    next_obs_list = envs.reset(seed=args.seed) # (num_envs, obs_dim) per agents
    next_obs = torch.stack(next_obs_list, dim=1).to(device) # [num_envs, num_agents, obs_dim]
    next_obs = next_obs.view(-1, obs_dim) # [num_envs * num_agents, obs_dim]
    next_obs = add_agent_id(next_obs, args.num_envs, args.num_agents, device) # [num_envs * num_agents, obs_dim + num_agents]
    next_gnn_obs = next_obs.view(-1, args.num_agents, obs_dim + args.num_agents) # [-1, num_agents, obs_dim]

    next_done = torch.zeros(args.num_envs * args.num_agents).to(device)
    num_updates = args.total_timesteps // args.batch_size

    positions = []
    for e in range(args.num_envs):
        coords = torch.stack([agent.state.pos[e] for agent in envs.agents], dim=0).to(device)  # [n_agents, pos_dim]
        positions.append(coords)
    positions = torch.stack(positions, dim=0).to(device)  # [-1, num_agents, 2]

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                communicated_obs = agent.get_communicated_obs(next_gnn_obs, positions)
                action, logprob, entropy, value = agent.get_action_and_value(communicated_obs)

            obs[step] = communicated_obs
            values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # execute the game and log data
            action_array = action.view(args.num_envs, args.num_agents).cpu().numpy()
            action_list = [action_array[:, i] for i in range(args.num_agents)]
            next_obs_list, reward_list, done, info = envs.step(action_list)

            reward = torch.stack(reward_list, dim=1).to(device)
            env_rewards = torch.sum(reward, dim=1).cpu().numpy()
            episode_returns += env_rewards
            episode_lengths += 1

            rewards[step] = reward.flatten()

            next_obs = torch.stack(next_obs_list, dim=1).to(device) # [num_envs, num_agents, obs_dim]
            next_obs = next_obs.view(-1, obs_dim) # [num_envs * num_agents, obs_dim]
            next_obs = add_agent_id(next_obs, args.num_envs, args.num_agents, device) # [num_envs * num_agents, obs_dim + num_agents]
            next_gnn_obs = next_obs.view(-1, args.num_agents, obs_dim + args.num_agents)

            next_done = done.to(device).unsqueeze(1).repeat(1, args.num_agents) # [num_envs, ] => [num_envs, num_agents]
            next_done = next_done.flatten() # [num_envs * num_agents, ]

            positions = []
            for e in range(args.num_envs):
                coords = torch.stack([agent.state.pos[e] for agent in envs.agents], dim=0).to(device)  # [n_agents, pos_dim]
                positions.append(coords)
            positions = torch.stack(positions, dim=0).to(device)  # [-1, num_agents, 2]

            episode_ret = []
            episode_len = []

            for i in range(len(done)):
                if done[i]:
                    episode_ret.append(episode_returns[i])
                    episode_len.append(episode_lengths[i])
                    episode_returns[i] = 0
                    episode_lengths[i] = 0

            # logging episode return and length
            if episode_ret and global_step > 100:
                print(f"global_step={global_step}, episodic_return={np.mean(episode_ret)}")
                writer.add_scalar("charts/episodic_return", np.mean(episode_ret), global_step)
                writer.add_scalar("charts/episodic_length", np.mean(episode_len), global_step)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(communicated_obs).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done.float()
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1].float()
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done.float()
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1].float()
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1, args.hidden_dim))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,1))
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])

                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        if global_step % SAVE_INTERVAL < 1000:
            torch.save(agent.actor.state_dict(), f"{run_name}_policy_{global_step}.pth")
            torch.save(agent.critic.state_dict(), f"{run_name}_critic_{global_step}.pth")
            print("SAVE MODEL in global_step = ", global_step)

    torch.save(agent.actor.state_dict(), f"{run_name}_policy.pth")
    torch.save(agent.critic.state_dict(), f"{run_name}_critic.pth")

    writer.close()

global_step=60000, episodic_return=-4.67769193649292
SPS: 1574
global_step=60600, episodic_return=-0.04767770320177078
global_step=61200, episodic_return=-0.043007101863622665
global_step=61800, episodic_return=-0.053333912044763565
global_step=62400, episodic_return=-0.07168986648321152
global_step=63000, episodic_return=-0.060140419751405716
global_step=63600, episodic_return=-0.03862128406763077
global_step=64200, episodic_return=-0.04448026046156883
global_step=64800, episodic_return=-0.048524558544158936
global_step=65400, episodic_return=-0.06828081607818604
global_step=66000, episodic_return=-0.08017471432685852
global_step=66600, episodic_return=-0.0742112323641777
global_step=67200, episodic_return=-0.06346938014030457
global_step=67800, episodic_return=-0.05459117889404297
global_step=68400, episodic_return=-0.05288754031062126
global_step=69000, episodic_return=-0.041751086711883545
global_step=69600, episodic_return=-0.036984387785196304
global_step=70200, episodic_return=-