diff --git a/vmas/interactive_rendering.py b/vmas/interactive_rendering.py index 384baf0d..4cde95f7 100644 --- a/vmas/interactive_rendering.py +++ b/vmas/interactive_rendering.py @@ -37,12 +37,12 @@ class InteractiveEnv: """ def __init__( - self, - env: GymWrapper, - control_two_agents: bool = False, - display_info: bool = True, - save_render: bool = False, - render_name: str = "interactive", + self, + env: GymWrapper, + control_two_agents: bool = False, + display_info: bool = True, + save_render: bool = False, + render_name: str = "interactive", ): self.env = env self.control_two_agents = control_two_agents @@ -68,7 +68,7 @@ def __init__( if self.control_two_agents: assert ( - self.n_agents >= 2 + self.n_agents >= 2 ), "Control_two_agents is true but not enough agents in scenario" self.env.render() @@ -103,17 +103,17 @@ def _cycle(self): for agent in self.agents ] action_list[self.current_agent_index] = self.u[ - : self.env.unwrapped().get_agent_action_size( - self.agents[self.current_agent_index] - ) - ] + : self.env.unwrapped().get_agent_action_size( + self.agents[self.current_agent_index] + ) + ] if self.n_agents > 1 and self.control_two_agents: action_list[self.current_agent_index2] = self.u2[ - : self.env.unwrapped().get_agent_action_size( - self.agents[self.current_agent_index2] - ) - ] + : self.env.unwrapped().get_agent_action_size( + self.agents[self.current_agent_index2] + ) + ] obs, rew, done, info = self.env.step(action_list) if self.display_info: @@ -124,7 +124,7 @@ def _cycle(self): message = f"Obs: {obs_str[:len(obs_str) // 2]}" self._write_values(self.text_idx + 1, message) - message = f"Rew: {round(rew[self.current_agent_index],3)}" + message = f"Rew: {round(rew[self.current_agent_index], 3)}" self._write_values(self.text_idx + 2, message) total_rew = list(map(add, total_rew, rew)) @@ -149,23 +149,26 @@ def _cycle(self): def _init_text(self): from vmas.simulator import rendering - - try: - self.text_idx = len(self.env.unwrapped().viewer.text_lines) - except AttributeError: - self.text_idx = 0 + state = rendering.RenderStateSingleton() + self.text_idx = len(state.text_lines) for i in range(N_TEXT_LINES_INTERACTIVE): text_line = rendering.TextLine( - self.env.unwrapped().viewer.window, self.text_idx + i + # self.env.unwrapped().viewer.window, + self.text_idx + i ) - self.env.unwrapped().viewer.text_lines.append(text_line) + state.text_lines.append(text_line) def _write_values(self, index: int, message: str, font_size: int = 15): - self.env.unwrapped().viewer.text_lines[index].set_text( + from vmas.simulator import rendering + rendering.RenderStateSingleton().text_lines[index].set_text( message, font_size=font_size ) + # self.env.unwrapped().viewer.text_lines[index].set_text( + # message, font_size=font_size + # ) + # keyboard event callbacks def _key_press(self, k, mod): from pyglet.window import key @@ -300,11 +303,11 @@ def format_obs(obs): def render_interactively( - scenario: Union[str, BaseScenario], - control_two_agents: bool = False, - display_info: bool = True, - save_render: bool = False, - **kwargs, + scenario: Union[str, BaseScenario], + control_two_agents: bool = False, + display_info: bool = True, + save_render: bool = False, + **kwargs, ): """ Use this script to interactively play with scenarios diff --git a/vmas/scenarios/debug/diff_drive.py b/vmas/scenarios/debug/diff_drive.py index c5af2eba..10e7f0a5 100644 --- a/vmas/scenarios/debug/diff_drive.py +++ b/vmas/scenarios/debug/diff_drive.py @@ -1,6 +1,7 @@ # Copyright (c) 2022-2023. # ProrokLab (https://www.proroklab.org/) # All rights reserved. +import random import typing from typing import List @@ -52,8 +53,23 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): world.add_agent(agent) + self._init_text() + return world + def _init_text(self): + from vmas.simulator import rendering + + state = rendering.RenderStateSingleton() + # here the index an be customized to change the position of the text + offset = len(state.text_lines) + + # I used a list here but you could also add variables: + self.custom_obs_text_index = 0 + offset + state.text_lines.append(rendering.TextLine(self.custom_obs_text_index)) + self.custom_rew_text_index = 1 + offset + state.text_lines.append(rendering.TextLine(self.custom_rew_text_index)) + def reset_world_at(self, env_index: int = None): ScenarioUtils.spawn_entities_randomly( self.world.agents, @@ -75,8 +91,8 @@ def reward(self, agent: Agent): def observation(self, agent: Agent): observations = [ - agent.state.pos, - agent.state.vel, + agent.state.pos, agent.state.rot, + agent.state.vel, agent.state.ang_vel, ] return torch.cat( observations, @@ -85,8 +101,7 @@ def observation(self, agent: Agent): def extra_render(self, env_index: int = 0) -> "List[Geom]": from vmas.simulator import rendering - - geoms: List[Geom] = [] + state = rendering.RenderStateSingleton() # Agent rotation for agent in self.world.agents: @@ -101,9 +116,12 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]": xform.set_translation(*agent.state.pos[env_index]) line.add_attr(xform) line.set_color(*color) - geoms.append(line) + state.onetime_geoms.append(line) - return geoms + # DO NOT COMMIT THIS INTO MASTER, IT IS ONLY TO SHOW TEXT RENDER EXAMPLE + state.text_lines[self.custom_obs_text_index].set_text(f"custom obs text {random.randint(0, 100)}") + state.text_lines[self.custom_rew_text_index].set_text(f"custom rew text {random.randint(0, 100)}") + return [] if __name__ == "__main__": diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index 4d6f41e9..7e2547e3 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -8,6 +8,7 @@ import numpy as np import torch +import typing from gym import spaces from torch import Tensor from vmas.simulator.core import Agent, TorchVectorizedObject @@ -23,6 +24,9 @@ override, ) +if typing.TYPE_CHECKING: + from vmas.simulator.rendering import Viewer + # environment for all agents in the multiagent world # currently code assumes that no agents will be created/destroyed at runtime! @@ -63,7 +67,7 @@ def __init__( # rendering self.render_geoms_xform = None self.render_geoms = None - self.viewer = None + self.viewer: Viewer = None self.headless = None self.visible_display = None @@ -550,11 +554,11 @@ def render( self.headless = False pyglet.options["headless"] = self.headless - self._init_rendering() + self._init_rendering(env_index) # Render comm messages + text_idx = 0 if self.world.dim_c > 0: - idx = 0 for agent in self.world.agents: if agent.silent: continue @@ -571,8 +575,8 @@ def render( word = ALPHABET[torch.argmax(agent.state.c[env_index]).item()] message = agent.name + " sends " + word + " " - self.viewer.text_lines[idx].set_text(message) - idx += 1 + self.viewer.text_lines[text_idx].set_text(message) + text_idx += 1 zoom = max(VIEWER_MIN_ZOOM, self.scenario.viewer_zoom) @@ -662,19 +666,19 @@ def plot_function(self, f, precision, plot_range, cmap_range, cmap_alpha): ) self.viewer.add_onetime_list(geoms) - def _init_rendering(self): + def _init_rendering(self, env_index): from vmas.simulator import rendering self.viewer = rendering.Viewer( *self.scenario.viewer_size, visible=self.visible_display ) + self.viewer.text_lines = [] idx = 0 if self.world.dim_c > 0: - self.viewer.text_lines = [] for agent in self.world.agents: if not agent.silent: - text_line = rendering.TextLine(self.viewer.window, idx) + text_line = rendering.TextLine(idx) self.viewer.text_lines.append(text_line) idx += 1 diff --git a/vmas/simulator/rendering.py b/vmas/simulator/rendering.py index b688ac22..2264d7c7 100644 --- a/vmas/simulator/rendering.py +++ b/vmas/simulator/rendering.py @@ -89,6 +89,17 @@ def get_display(spec): ) +class RenderStateSingleton(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(RenderStateSingleton, cls).__new__(cls) + cls.instance.geoms = [] + cls.instance.onetime_geoms = [] + cls.instance.text_lines = [] + return cls.instance + + class Viewer(object): def __init__(self, width, height, display=None, visible=True): display = get_display(display) @@ -104,6 +115,7 @@ def __init__(self, width, height, display=None, visible=True): self.geoms = [] self.text_lines = [] self.onetime_geoms = [] + self.render_state = RenderStateSingleton() self.transform = Transform() self.bounds = None @@ -132,12 +144,15 @@ def set_bounds(self, left, right, bottom, top): def add_geom(self, geom): self.geoms.append(geom) + self.render_state.geoms.append(geom) def add_onetime(self, geom): self.onetime_geoms.append(geom) + self.render_state.onetime_geoms.append(geom) def add_onetime_list(self, geoms): self.onetime_geoms.extend(geoms) + self.render_state.onetime_geoms.extend(geoms) def render(self, return_rgb_array=False): glClearColor(1, 1, 1, 1) @@ -146,17 +161,26 @@ def render(self, return_rgb_array=False): self.window.switch_to() self.window.dispatch_events() + # self.transform.enable() + # for geom in self.geoms: + # geom.render() + # for geom in self.onetime_geoms: + # geom.render() + # self.transform.disable() self.transform.enable() - for geom in self.geoms: + for geom in self.render_state.geoms: geom.render() - for geom in self.onetime_geoms: + for geom in self.render_state.onetime_geoms: geom.render() self.transform.disable() pyglet.gl.glMatrixMode(pyglet.gl.GL_PROJECTION) pyglet.gl.glLoadIdentity() gluOrtho2D(0, self.width, 0, self.height) - for geom in self.text_lines: + # for geom in self.text_lines: + # geom.render() + + for geom in self.render_state.text_lines: geom.render() arr = None @@ -174,6 +198,7 @@ def render(self, return_rgb_array=False): arr = arr[::-1, :, 0:3] self.window.flip() + self.render_state.onetime_geoms = [] self.onetime_geoms = [] return arr @@ -306,9 +331,9 @@ def enable(self): class TextLine: - def __init__(self, window, idx): + def __init__(self, idx): self.idx = idx - self.window = window + # self.window = window pyglet.font.add_file(os.path.join(os.path.dirname(__file__), "secrcode.ttf")) if pyglet.font.have_font("Courier"):