In [None]:
!pip install vmas benchmarl pyvirtualdisplay moviepy
!apt-get install python3-opengl
import pyvirtualdisplay
display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
display.start()

In [None]:
import torch

def format_pytorch_version(version):
    """Return ``version`` without its local build suffix (everything from ``+`` on).

    Example: ``"2.1.0+cu121"`` -> ``"2.1.0"``; a version without ``+`` is returned as-is.
    """
    public_part, _separator, _local_tag = version.partition('+')
    return public_part

# torch.__version__ may carry a local build tag such as "+cu121"; strip it so the
# PyG wheel URL can append its own "+{CUDA}" tag.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(version=TORCH_version)

def format_cuda_version(version):
    """Convert a CUDA version string into the tag used by PyG's wheel index.

    ``"12.1"`` becomes ``"cu121"``. CPU-only PyTorch builds report
    ``torch.version.cuda`` as ``None``; those map to PyG's ``"cpu"`` tag
    instead of crashing on ``None.replace``.

    Args:
        version: CUDA version string (e.g. ``"11.8"``) or ``None``.

    Returns:
        ``"cu<digits>"`` for a CUDA version, ``"cpu"`` for ``None``.
    """
    if version is None:
        return 'cpu'
    return 'cu' + version.replace('.', '')

# CUDA toolkit version torch was built with (None on CPU-only builds), turned into
# the wheel tag used in the PyG wheel-index URL.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# torch-cluster ships compiled wheels per torch/CUDA combination; IPython expands
# {TORCH} and {CUDA} into the shell command to pick the matching wheel index.
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

In [None]:
import torch

import vmas.simulator.core
import vmas.simulator.utils
from vmas.simulator.dynamics.common import Dynamics


class FixedWingKinematicBicycle(Dynamics):
    """Kinematic-bicycle dynamics whose forward speed never drops below ``min_v``.

    Mirrors a kinematic bicycle model; the only behavioral change is that the
    velocity command is clamped to ``[min_v, max_v]``, so with a positive ``min_v``
    the agent can never stop or reverse (presumably to mimic a fixed-wing
    aircraft, hence the name).

    Each agent action has two components: ``u[:, 0]`` is the velocity command and
    ``u[:, 1]`` the steering command (clamped to ``±max_steering_angle``). The
    desired state change over one ``dt`` is integrated with Euler or RK4 and then
    converted into the force and torque written to the agent's state.

    Args:
        world: VMAS world; its ``dt`` and ``device`` are used.
        width: stored on the instance but not read by these dynamics.
        l_f: front length of the bicycle model.
        l_r: rear length of the bicycle model. Only ``l_r / (l_f + l_r)`` and
            ``l_f + l_r`` enter the equations.
        max_steering_angle: bound applied to the steering command.
        min_v: lower bound of the velocity command (default 0.3).
        max_v: upper bound of the velocity command (default 1.0).
        integration: ``"rk4"`` (default) or ``"euler"``.

    Raises:
        AssertionError: if ``integration`` is unknown or ``min_v > max_v``.
    """

    def __init__(
        self,
        world: vmas.simulator.core.World,
        width: float,
        l_f: float,
        l_r: float,
        max_steering_angle: float,
        min_v: float = 0.3,
        max_v: float = 1.0,
        integration: str = "rk4",
    ):
        super().__init__()
        assert integration in (
            "rk4",
            "euler",
        ), "Integration method must be 'euler' or 'rk4'."
        # torch.clamp(x, min, max) silently returns `max` everywhere when min > max,
        # which would hide a misconfigured speed range; fail fast instead.
        assert min_v <= max_v, "min_v must not be greater than max_v."
        self.width = width  # stored but not read by these dynamics
        self.l_f = l_f
        self.l_r = l_r
        self.max_steering_angle = max_steering_angle
        self.dt = world.dt
        self.integration = integration
        self.world = world
        self.min_v = min_v
        self.max_v = max_v

    def f(self, state, steering_command, v_command):
        """Time derivative of ``state`` under the kinematic bicycle model.

        Args:
            state: ``[batch, 3]`` tensor of ``(x, y, yaw)``.
            steering_command: ``[batch]`` steering angles.
            v_command: ``[batch]`` speeds.

        Returns:
            ``[batch, 3]`` tensor ``(dx/dt, dy/dt, dyaw/dt)``.
        """
        theta = state[:, 2]  # Yaw angle
        wheelbase = self.l_f + self.l_r
        # Slip angle, in [-pi, pi]. The denominator is created with ones_like so it
        # already lives on the command's device/dtype, instead of a new host scalar
        # being built and copied to the device on every call (4x per step with RK4).
        beta = torch.atan2(
            torch.tan(steering_command) * self.l_r / wheelbase,
            torch.ones_like(steering_command),
        )
        dx = v_command * torch.cos(theta + beta)
        dy = v_command * torch.sin(theta + beta)
        dtheta = v_command / wheelbase * torch.cos(beta) * torch.tan(steering_command)
        return torch.stack((dx, dy, dtheta), dim=1)  # [batch_size, 3]

    def euler(self, state, steering_command, v_command):
        """State change over one ``dt`` using the explicit (forward) Euler method."""
        return self.dt * self.f(state, steering_command, v_command)

    def runge_kutta(self, state, steering_command, v_command):
        """State change over one ``dt`` using the classic fourth-order Runge-Kutta method.

        The commands are held constant over the step; only the state is advanced
        between the four slope evaluations.
        """
        k1 = self.f(state, steering_command, v_command)
        k2 = self.f(state + self.dt * k1 / 2, steering_command, v_command)
        k3 = self.f(state + self.dt * k2 / 2, steering_command, v_command)
        k4 = self.f(state + self.dt * k3, steering_command, v_command)
        return (self.dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

    @property
    def needed_action_size(self) -> int:
        """Two action components per agent: velocity and steering."""
        return 2

    def process_action(self):
        """Convert the agent's (velocity, steering) action into a force and a torque.

        Reads ``self.agent.action.u``. Writes ``self.agent.state.force`` (x/y) and
        ``self.agent.state.torque``. ``self.agent`` is presumably attached by the
        ``Dynamics`` base class when the agent is created.
        """
        # Velocity command, clamped so the agent always moves forward at least at
        # min_v (the only change relative to the plain bicycle dynamics).
        v_command = torch.clamp(self.agent.action.u[:, 0], self.min_v, self.max_v)
        # Steering command, kept within the physical steering limit.
        steering_command = torch.clamp(
            self.agent.action.u[:, 1],
            -self.max_steering_angle,
            self.max_steering_angle,
        )

        # Current state of the agent: (x, y, yaw) per environment.
        state = torch.cat((self.agent.state.pos, self.agent.state.rot), dim=1)

        v_cur_x = self.agent.state.vel[:, 0]  # Current velocity in x-direction
        v_cur_y = self.agent.state.vel[:, 1]  # Current velocity in y-direction
        v_cur_angular = self.agent.state.ang_vel[:, 0]  # Current angular velocity

        # Desired change in state over one dt.
        if self.integration == "euler":
            delta_state = self.euler(state, steering_command, v_command)
        else:
            delta_state = self.runge_kutta(state, steering_command, v_command)

        # Accelerations chosen so that v * dt + a * dt**2 equals the desired change.
        # NOTE(review): this presumes the simulator's update makes the displacement
        # v*dt + a*dt**2; drag and substeps in the world will perturb it.
        acceleration_x = (delta_state[:, 0] - v_cur_x * self.dt) / self.dt**2
        acceleration_y = (delta_state[:, 1] - v_cur_y * self.dt) / self.dt**2
        acceleration_angular = (
            delta_state[:, 2] - v_cur_angular * self.dt
        ) / self.dt**2

        # Newton's second law for the linear and angular parts.
        force_x = self.agent.mass * acceleration_x
        force_y = self.agent.mass * acceleration_y
        torque = self.agent.moment_of_inertia * acceleration_angular

        # Hand the force and torque to the physics engine.
        self.agent.state.force[:, vmas.simulator.utils.X] = force_x
        self.agent.state.force[:, vmas.simulator.utils.Y] = force_y
        self.agent.state.torque = torque.unsqueeze(-1)

In [None]:
import typing
from typing import Callable, Dict, List

import torch
from torch import Tensor

from vmas import render_interactively
from vmas.simulator.core import Agent, Entity, Landmark, Box, Sphere, World
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.sensors import Lidar
from vmas.simulator.utils import Color, ScenarioUtils, X, Y
from vmas.simulator.dynamics.kinematic_bicycle import KinematicBicycle

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom


class FWDiscoveryScenario(BaseScenario):
    """Cooperative target-discovery scenario for fixed-wing bicycle agents.

    Agents and target landmarks are spawned in the
    ``[-x_semidim, x_semidim] x [-y_semidim, y_semidim]`` arena. A target counts as
    covered in an environment when at least ``agents_per_target`` agents are
    closer than ``covering_range``. Each agent earns ``covering_rew_coeff`` per
    covered target it is close to. Covered targets are then either respawned at a
    random position (``targets_respawn=True``) or parked far outside the arena.

    Lines tagged ``# Modification`` mark changes relative to the scenario this one
    was adapted from (presumably VMAS's built-in ``discovery`` scenario).
    """

    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """Create the world with its agents and target landmarks.

        Args:
            batch_dim: batch size of every per-environment tensor (one row per
                parallel environment).
            device: torch device for all tensors created here.
            **kwargs: scenario options, each popped with the default shown below;
                leftover keys are passed to ``ScenarioUtils.check_kwargs_consumed``.

        Returns:
            The populated ``World``.
        """
        # Arena size, spawning and sensing ranges.
        self.n_agents = kwargs.pop("n_agents", 5)
        self.n_targets = kwargs.pop("n_targets", 7)
        self.x_semidim = kwargs.pop("x_semidim", 1)
        self.y_semidim = kwargs.pop("y_semidim", 1)
        self._min_dist_between_entities = kwargs.pop("min_dist_between_entities", 0.2)
        self._lidar_range = kwargs.pop("lidar_range", 0.35)
        self._covering_range = kwargs.pop("covering_range", 0.25)

        # Lidar configuration: the target lidar always exists, the agent lidar is optional.
        self.use_agent_lidar = kwargs.pop("use_agent_lidar", False)
        self.n_lidar_rays_entities = kwargs.pop("n_lidar_rays_entities", 15)
        self.n_lidar_rays_agents = kwargs.pop("n_lidar_rays_agents", 12)

        # Task rules. `_agents_per_target` is the coverage threshold (not to be
        # confused with `self.agents_per_target`, the per-target counts set in reward()).
        self._agents_per_target = kwargs.pop("agents_per_target", 2)
        self.targets_respawn = kwargs.pop("targets_respawn", True)
        self.shared_reward = kwargs.pop("shared_reward", False)

        # Reward terms. The penalties are added to the reward as-is, so they should be <= 0.
        self.agent_collision_penalty = kwargs.pop("agent_collision_penalty", 0)
        self.covering_rew_coeff = kwargs.pop("covering_rew_coeff", 1.0)
        self.time_penalty = kwargs.pop("time_penalty", 0)
        # Modification: forwarded to every Agent (presumably draws actions when rendering).
        self.render_action = kwargs.pop("render_action", False)
        ScenarioUtils.check_kwargs_consumed(kwargs)

        # "Communication" range, only used to draw agent-agent lines in extra_render.
        self._comms_range = self._lidar_range
        self.min_collision_distance = 0.005
        self.agent_radius = 0.05
        self.target_radius = self.agent_radius

        self.viewer_zoom = 1
        self.target_color = Color.GREEN

        # Make world
        world = World(
            batch_dim,
            device,
            x_semidim=self.x_semidim,
            y_semidim=self.y_semidim,
            collision_force=500,
            substeps=2,
            drag=0.25,
        )

        # Add agents. Each lidar only sees entities whose name starts with the given prefix.
        entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
            "agent"
        )
        entity_filter_targets: Callable[[Entity], bool] = lambda e: e.name.startswith(
            "target"
        )
        # 45-degree steering limit shared by every agent.
        _max_steering_angle = torch.pi/4
        for i in range(self.n_agents):
            # Constraint: all agents have same action range and multiplier
            agent = Agent(
                name=f"agent_{i}",
                collide=True,
                color=Color.ORANGE, # Modification (cosmetic only)
                # Modification (with the dynamics below): a box body, 2*radius long and radius wide.
                shape=Box(length=self.agent_radius * 2, width=self.agent_radius),
                sensors=(
                    # sensors[0]: target lidar, always present.
                    [
                        Lidar(
                            world,
                            n_rays=self.n_lidar_rays_entities,
                            max_range=self._lidar_range,
                            entity_filter=entity_filter_targets,
                            render_color=Color.GREEN,
                        )
                    ]
                    # sensors[1]: agent lidar, only when use_agent_lidar is set.
                    + (
                        [
                            Lidar(
                                world,
                                angle_start=0.05,
                                angle_end=2 * torch.pi + 0.05,
                                n_rays=self.n_lidar_rays_agents,
                                max_range=self._lidar_range,
                                entity_filter=entity_filter_agents,
                                render_color=Color.BLUE,
                            )
                        ]
                        if self.use_agent_lidar
                        else []
                    )
                ),
                dynamics=FixedWingKinematicBicycle(
                    world,
                    width=self.agent_radius,
                    l_f=self.agent_radius,
                    l_r=self.agent_radius,
                    max_steering_angle=_max_steering_angle
                ), # Modification: fixed-wing bicycle dynamics
                render_action=self.render_action # Modification
            )
            # Per-agent reward buffers of shape [batch_dim], reset and refilled in place each step.
            agent.collision_rew = torch.zeros(batch_dim, device=device)
            agent.covering_reward = agent.collision_rew.clone()
            world.add_agent(agent)

        # Targets are static, collidable landmarks.
        self._targets = []
        for i in range(self.n_targets):
            target = Landmark(
                name=f"target_{i}",
                collide=True,
                movable=False,
                shape=Sphere(radius=self.target_radius),
                color=self.target_color,
            )
            world.add_landmark(target)
            self._targets.append(target)

        # Placeholders until the first reward() call; reward() replaces
        # covered_targets with a boolean [batch_dim, n_targets] tensor.
        self.covered_targets = torch.zeros(batch_dim, self.n_targets, device=device)
        self.shared_covering_rew = torch.zeros(batch_dim, device=device)

        return world

    def reset_world_at(self, env_index: int = None):
        """Reset environment ``env_index``, or every environment when it is ``None``.

        Clears the all-time coverage record, then respawns targets and agents at
        random positions inside the arena via ``ScenarioUtils.spawn_entities_randomly``.
        """
        placable_entities = self._targets[: self.n_targets] + self.world.agents
        if env_index is None:
            # Full reset: (re)allocate the [batch_dim, n_targets] boolean record used by done().
            self.all_time_covered_targets = torch.full(
                (self.world.batch_dim, self.n_targets),
                False,
                device=self.world.device,
            )
        else:
            self.all_time_covered_targets[env_index] = False
        ScenarioUtils.spawn_entities_randomly(
            entities=placable_entities,
            world=self.world,
            env_index=env_index,
            min_dist_between_entities=self._min_dist_between_entities,
            x_bounds=(-self.world.x_semidim, self.world.x_semidim),
            y_bounds=(-self.world.y_semidim, self.world.y_semidim),
        )
        # NOTE(review): self._targets holds exactly n_targets landmarks, so this slice
        # is empty and the loop never runs; it only matters if extra targets are added.
        for target in self._targets[self.n_targets :]:
            target.set_pos(self.get_outside_pos(env_index), batch_index=env_index)

    def reward(self, agent: Agent):
        """Return ``agent``'s reward for the current step, shape ``[batch_dim]``.

        The reward is the collision penalty, plus the covering reward (the agent's own,
        or the shared one when ``shared_reward``), plus the time penalty.

        NOTE(review): per-step shared quantities are computed when ``agent`` is the
        first agent, and covered targets are relocated when it is the last. This
        assumes ``reward`` is called once per agent per step in ``world.agents``
        order; confirm against the VMAS environment loop.
        """
        is_first = agent == self.world.agents[0]
        is_last = agent == self.world.agents[-1]

        if is_first:
            # Constant time penalty for every environment.
            self.time_rew = torch.full(
                (self.world.batch_dim,),
                self.time_penalty,
                device=self.world.device,
            )
            # Positions stacked on dim 1 (one slot per agent / per target).
            self.agents_pos = torch.stack(
                [a.state.pos for a in self.world.agents], dim=1
            )
            self.targets_pos = torch.stack([t.state.pos for t in self._targets], dim=1)
            # Pairwise distances indexed [env, agent, target].
            self.agents_targets_dists = torch.cdist(self.agents_pos, self.targets_pos)
            # Number of agents within covering range of each target: [env, target].
            self.agents_per_target = torch.sum(
                (self.agents_targets_dists < self._covering_range).type(torch.int),
                dim=1,
            )
            # Boolean [env, target]: enough agents around the target this step.
            self.covered_targets = self.agents_per_target >= self._agents_per_target

            # Shared reward = sum of all agents' covering rewards; agent_reward also
            # refreshes each agent.covering_reward in place as a side effect.
            self.shared_covering_rew[:] = 0
            for a in self.world.agents:
                self.shared_covering_rew += self.agent_reward(a)
            # NOTE(review): the divisor is a hard-coded 2 rather than
            # self._agents_per_target; confirm this is intended when agents_per_target != 2.
            self.shared_covering_rew[self.shared_covering_rew != 0] /= 2

        # Avoid collisions with each other: add the penalty once per other agent
        # closer than min_collision_distance.
        agent.collision_rew[:] = 0
        for a in self.world.agents:
            if a != agent:
                agent.collision_rew[
                    self.world.get_distance(a, agent) < self.min_collision_distance
                ] += self.agent_collision_penalty

        if is_last:
            if self.targets_respawn:
                # Draw a new random position for each target (via
                # find_random_pos_for_entity, away from all agents and the other
                # targets) and apply it only in environments where it is covered.
                occupied_positions_agents = [self.agents_pos]
                for i, target in enumerate(self._targets):
                    occupied_positions_targets = [
                        o.state.pos.unsqueeze(1)
                        for o in self._targets
                        if o is not target
                    ]
                    occupied_positions = torch.cat(
                        occupied_positions_agents + occupied_positions_targets,
                        dim=1,
                    )
                    pos = ScenarioUtils.find_random_pos_for_entity(
                        occupied_positions,
                        env_index=None,
                        world=self.world,
                        min_dist_between_entities=self._min_dist_between_entities,
                        x_bounds=(-self.world.x_semidim, self.world.x_semidim),
                        y_bounds=(-self.world.y_semidim, self.world.y_semidim),
                    )

                    target.state.pos[self.covered_targets[:, i]] = pos[
                        self.covered_targets[:, i]
                    ].squeeze(1)
            else:
                # Record which targets were ever covered (drives done()), then park
                # the covered ones far outside the arena.
                self.all_time_covered_targets += self.covered_targets
                for i, target in enumerate(self._targets):
                    target.state.pos[self.covered_targets[:, i]] = self.get_outside_pos(
                        None
                    )[self.covered_targets[:, i]]
        covering_rew = (
            agent.covering_reward
            if not self.shared_reward
            else self.shared_covering_rew
        )

        return agent.collision_rew + covering_rew + self.time_rew

    def get_outside_pos(self, env_index):
        """Random positions far outside the arena, used to park targets.

        Both coordinates are drawn uniformly from ``[-1000 * x_semidim, -10 * x_semidim]``
        (``x_semidim`` is used for y as well). Returns shape ``[1, dim_p]`` for a single
        environment, ``[batch_dim, dim_p]`` when ``env_index`` is ``None``.
        """
        return torch.empty(
            (
                (1, self.world.dim_p)
                if env_index is not None
                else (self.world.batch_dim, self.world.dim_p)
            ),
            device=self.world.device,
        ).uniform_(-1000 * self.world.x_semidim, -10 * self.world.x_semidim)

    def agent_reward(self, agent):
        """Covering reward of ``agent``: covered targets within its range, times the coefficient.

        Relies on ``agents_targets_dists`` and ``covered_targets`` computed in
        ``reward()`` for the current step. Writes the result into
        ``agent.covering_reward`` in place and returns that same tensor.
        """
        agent_index = self.world.agents.index(agent)

        agent.covering_reward[:] = 0
        targets_covered_by_agent = (
            self.agents_targets_dists[:, agent_index] < self._covering_range
        )
        num_covered_targets_covered_by_agent = (
            targets_covered_by_agent * self.covered_targets
        ).sum(dim=-1)
        agent.covering_reward += (
            num_covered_targets_covered_by_agent * self.covering_rew_coeff
        )
        return agent.covering_reward

    def observation(self, agent: Agent):
        """Dict observation for ``agent``.

        ``"obs"`` is the target-lidar reading (concatenated with the agent lidar when
        ``use_agent_lidar``), alongside ``"pos"`` and ``"vel"``. Bicycle-type dynamics
        also expose ``"rot"`` and ``"ang_vel"``. The training cell's ``GnnConfig``
        reads positions from the ``"pos"`` key.
        """
        lidar_1_measures = agent.sensors[0].measure()
        obs = {"obs" : torch.cat(
            [lidar_1_measures]
            + ([agent.sensors[1].measure()] if self.use_agent_lidar else []),
            dim=-1),
                "pos" : agent.state.pos,
                "vel" : agent.state.vel
        }
        if isinstance(agent.dynamics, KinematicBicycle) or isinstance(agent.dynamics, FixedWingKinematicBicycle):
            obs.update({
                "rot": agent.state.rot,
                "ang_vel": agent.state.ang_vel
            })
        return obs

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """Per-agent logging info: reward components and number of covered targets."""
        info = {
            "covering_reward": (
                agent.covering_reward
                if not self.shared_reward
                else self.shared_covering_rew
            ),
            "collision_rew": agent.collision_rew,
            "targets_covered": self.covered_targets.sum(-1),
        }
        return info

    def done(self):
        """An environment is done once every target has been covered at least once.

        With ``targets_respawn=True`` the record is never updated, so this stays False
        and episodes must end by another mechanism (e.g. a step limit).
        """
        return self.all_time_covered_targets.all(dim=-1)

    def extra_render(self, env_index: int = 0) -> "List[Geom]":
        """Extra geometry for environment ``env_index``.

        Draws a covering-range circle around each target, and a line between every
        pair of agents closer than ``_comms_range``.
        """
        from vmas.simulator import rendering

        geoms: List[Geom] = []
        # Target ranges
        for target in self._targets:
            range_circle = rendering.make_circle(self._covering_range, filled=False)
            xform = rendering.Transform()
            xform.set_translation(*target.state.pos[env_index])
            range_circle.add_attr(xform)
            range_circle.set_color(*self.target_color.value)
            geoms.append(range_circle)
        # Communication lines (each unordered agent pair visited once)
        for i, agent1 in enumerate(self.world.agents):
            for j, agent2 in enumerate(self.world.agents):
                if j <= i:
                    continue
                agent_dist = torch.linalg.vector_norm(
                    agent1.state.pos - agent2.state.pos, dim=-1
                )
                if agent_dist[env_index] <= self._comms_range:
                    color = Color.BLACK.value
                    line = rendering.Line(
                        (agent1.state.pos[env_index]),
                        (agent2.state.pos[env_index]),
                        width=1,
                    )
                    xform = rendering.Transform()
                    line.add_attr(xform)
                    line.set_color(*color)
                    geoms.append(line)

        return geoms

In [None]:
import copy
from typing import Callable, Optional
from benchmarl.environments import VmasTask
from benchmarl.utils import DEVICE_TYPING
from torchrl.envs import EnvBase, VmasEnv

def get_env_fun(
    self,
    num_envs: int,
    continuous_actions: bool,
    seed: Optional[int],
    device: DEVICE_TYPING):
    """Return a zero-argument factory that builds a TorchRL ``VmasEnv`` for this task.

    Meant to be assigned as a BenchMARL VMAS task's ``get_env_fun``. The
    ``NAVIGATION`` task is redirected to the custom ``FWDiscoveryScenario``; every
    other task uses the VMAS scenario named after the task.

    Args:
        num_envs: number of vectorized environments inside each ``VmasEnv``.
        continuous_actions: forwarded to ``VmasEnv``.
        seed: forwarded to ``VmasEnv``.
        device: device the environments are created on.

    Returns:
        A callable with no arguments that returns a new ``VmasEnv``.
    """
    config = copy.deepcopy(self.config)
    use_custom_scenario = (hasattr(self, "name") and self.name == "NAVIGATION") or (
        self is VmasTask.NAVIGATION
    )

    def make_env():
        # Build a fresh scenario for every environment. The scenario keeps per-world
        # state on `self` (world, targets, reward buffers), so one instance shared
        # between environments would let them overwrite each other's state.
        scenario = FWDiscoveryScenario() if use_custom_scenario else self.name.lower()
        return VmasEnv(
            scenario=scenario,
            num_envs=num_envs,
            continuous_actions=continuous_actions,
            seed=seed,
            device=device,
            categorical_actions=True,
            **config)

    return make_env

In [None]:
# Install the custom get_env_fun on BenchMARL's VMAS task type. Depending on the
# installed BenchMARL version (presumably), that type is either `VmasClass` or the
# `VmasTask` enum itself; patch whichever one is importable.
try:
    from benchmarl.environments import VmasClass
    VmasClass.get_env_fun = get_env_fun
except ImportError:
    # Expected fallback, not a failure: report which class is being patched.
    print("benchmarl.environments.VmasClass not available; patching VmasTask.get_env_fun instead")
    VmasTask.get_env_fun = get_env_fun

In [None]:
import os

import wandb
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from Kaggle's secret store (never hard-code it), force online
# logging, and authenticate.
secrets = UserSecretsClient()
os.environ.update(
    {
        "WANDB_API_KEY": secrets.get_secret("WANDB_API_KEY"),
        "WANDB_MODE": "online",
    }
)
wandb.login()

In [None]:
import torch
import torch_geometric
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import GnnConfig, MlpConfig, SequenceModelConfig

# --- Experiment: start from BenchMARL's default YAML and override what this run needs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.sampling_device = DEVICE
experiment_config.train_device = DEVICE

experiment_config.max_n_frames = 20_000_000
experiment_config.gamma = 0.99
# 1000 parallel envs x 100 steps (task max_steps) = 100k frames per collection batch.
experiment_config.on_policy_collected_frames_per_batch = 100_000
experiment_config.on_policy_n_envs_per_worker = 1000
experiment_config.on_policy_n_minibatch_iters = 45
experiment_config.on_policy_minibatch_size = 4096
experiment_config.evaluation = True
experiment_config.render = True
experiment_config.share_policy_params = True
# Evaluate every two collection batches.
experiment_config.evaluation_interval = 200_000
experiment_config.evaluation_episodes = 20
experiment_config.loggers = ["wandb"]

# --- Task: NAVIGATION is redirected to FWDiscoveryScenario by the patched get_env_fun.
# Apart from max_steps, every key is a kwarg read in FWDiscoveryScenario.make_world.
task = VmasTask.NAVIGATION.get_from_yaml()
task.config = {
    "max_steps" : 100,
    "n_agents" : 4,
    "shared_reward" : False,
    "x_semidim" : 1,
    "y_semidim" : 1,
    "render_action" : True,
    "agents_per_target" : 1,
    "use_agent_lidar" : False,
    "agent_collision_penalty" : -1,
    "time_penalty" : -0.01
}

# --- Algorithm: MAPPO with critic parameters shared across agents.
mappo_algorithm_config = MappoConfig(
    share_param_critic=True,
    clip_epsilon=0.2,
    entropy_coef=0.001,
    critic_coef=1,
    loss_critic_type="l2",
    lmbda=0.9,
    scale_mapping="biased_softplus_1.0",
    use_tanh_normal=True,
    minibatch_advantage=False,
)

# --- Critic: plain two-layer MLP.
critic_model_config = MlpConfig(
    num_cells=[256, 256],
    layer_class=torch.nn.Linear,
    activation_class=torch.nn.SiLU,
)

# --- Policy: GATv2 message passing between agents within comms_radius, then an MLP.
# The arena spans [-1, 1]^2 (diagonal ~2.83), so a radius of 2 links most, but not
# necessarily all, agent pairs. Positions come from the scenario's "pos" observation.
comms_radius = 2
gnn_config = GnnConfig(
    topology="from_pos",
    edge_radius=comms_radius,
    self_loops=False,
    gnn_class=torch_geometric.nn.conv.GATv2Conv,
    gnn_kwargs={"add_self_loops": False, "residual": True},
    position_key="pos",
    pos_features=2,
    exclude_pos_from_node_features=True,
)
mlp_config = MlpConfig.get_from_yaml()

model_config_gnn = SequenceModelConfig(
    model_configs=[gnn_config, mlp_config], intermediate_sizes=[256]
)

# --- Run.
experiment = Experiment(
    task=task,
    algorithm_config=mappo_algorithm_config,
    model_config=model_config_gnn,
    critic_model_config=critic_model_config,
    seed=1337,
    config=experiment_config,
)
experiment.run()