From cae26ee1b447a1e618e7476ab4cd3a8c45082fd2 Mon Sep 17 00:00:00 2001 From: zartris Date: Wed, 24 May 2023 09:17:11 +0900 Subject: [PATCH 1/2] [Feature] Text can be specified from custom scenario via extra_render. --- vmas/scenarios/debug/diff_drive.py | 9 +++++++-- vmas/simulator/environment/environment.py | 24 +++++++++++++++++------ vmas/simulator/scenario.py | 2 ++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/vmas/scenarios/debug/diff_drive.py b/vmas/scenarios/debug/diff_drive.py index c5af2eba..0b1d4eb6 100644 --- a/vmas/scenarios/debug/diff_drive.py +++ b/vmas/scenarios/debug/diff_drive.py @@ -75,8 +75,8 @@ def reward(self, agent: Agent): def observation(self, agent: Agent): observations = [ - agent.state.pos, - agent.state.vel, + agent.state.pos, agent.state.rot, + agent.state.vel, agent.state.ang_vel, t ] return torch.cat( observations, @@ -103,6 +103,11 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]": line.set_color(*color) geoms.append(line) + # DO NOT COMMIT THIS INTO MASTER, IT IS ONLY TO SHOW TEXT RENDER EXAMPLE + self.render_text = [] # Reset the text from last round + for i in range(2): + self.render_text.append(f"{i}: This is a test of custom text rendering from scenario file") + return geoms diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index 62f5fe07..ae824444 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -524,11 +524,11 @@ def render( self.headless = False pyglet.options["headless"] = self.headless - self._init_rendering() + self._init_rendering(env_index) # Render comm messages + text_idx = 0 if self.world.dim_c > 0: - idx = 0 for agent in self.world.agents: if agent.silent: continue @@ -545,8 +545,8 @@ def render( word = ALPHABET[torch.argmax(agent.state.c[env_index]).item()] message = agent.name + " sends " + word + " " - self.viewer.text_lines[idx].set_text(message) - idx += 1 + self.viewer.text_lines[text_idx].set_text(message) + text_idx += 1 zoom = max(VIEWER_MIN_ZOOM, self.scenario.viewer_zoom) @@ -610,6 +610,11 @@ def render( self.viewer.add_onetime_list(self.scenario.extra_render(env_index)) + # Rendering the text set from extra_render method: + for message in self.scenario.render_text: + self.viewer.text_lines[text_idx].set_text(message) + text_idx += 1 + for entity in self.world.entities: self.viewer.add_onetime_list(entity.render(env_index=env_index)) @@ -636,22 +641,29 @@ def plot_function(self, f, precision, plot_range, cmap_range, cmap_alpha): ) self.viewer.add_onetime_list(geoms) - def _init_rendering(self): + def _init_rendering(self, env_index): from vmas.simulator import rendering self.viewer = rendering.Viewer( *self.scenario.viewer_size, visible=self.visible_display ) + self.viewer.text_lines = [] idx = 0 if self.world.dim_c > 0: - self.viewer.text_lines = [] for agent in self.world.agents: if not agent.silent: text_line = rendering.TextLine(self.viewer.window, idx) self.viewer.text_lines.append(text_line) idx += 1 + # Overhead in drawing (called once), but we get the number of lines expected to be rendered: + self.scenario.extra_render(env_index) + for _ in self.scenario.render_text: + text_line = rendering.TextLine(self.viewer.window, idx) + self.viewer.text_lines.append(text_line) + idx += 1 + @override(TorchVectorizedObject) def to(self, device: DEVICE_TYPING): device = torch.device(device) diff --git a/vmas/simulator/scenario.py b/vmas/simulator/scenario.py index 7e492051..e9f57693 100644 --- a/vmas/simulator/scenario.py +++ b/vmas/simulator/scenario.py @@ -27,6 +27,8 @@ def __init__(self): self.plot_grid = False # The distance between lines in the background grid self.grid_spacing = 0.1 + # Text to be rendered by environment: + self.render_text = [] @property def world(self): From c22ef47be0a6cfffc55d4d87f40119d6cd4c1713 Mon Sep 17 00:00:00 2001 From: zartris Date: Thu, 25 May 2023 09:56:05 +0900 Subject: [PATCH 2/2] [Feature] Update to use RenderStateSingleton for data management. --- vmas/interactive_rendering.py | 61 ++++++++++++----------- vmas/scenarios/debug/diff_drive.py | 31 ++++++++---- vmas/simulator/environment/environment.py | 20 +++----- vmas/simulator/rendering.py | 35 +++++++++++-- vmas/simulator/scenario.py | 2 - 5 files changed, 90 insertions(+), 59 deletions(-) diff --git a/vmas/interactive_rendering.py b/vmas/interactive_rendering.py index 384baf0d..4cde95f7 100644 --- a/vmas/interactive_rendering.py +++ b/vmas/interactive_rendering.py @@ -37,12 +37,12 @@ class InteractiveEnv: """ def __init__( - self, - env: GymWrapper, - control_two_agents: bool = False, - display_info: bool = True, - save_render: bool = False, - render_name: str = "interactive", + self, + env: GymWrapper, + control_two_agents: bool = False, + display_info: bool = True, + save_render: bool = False, + render_name: str = "interactive", ): self.env = env self.control_two_agents = control_two_agents @@ -68,7 +68,7 @@ def __init__( if self.control_two_agents: assert ( - self.n_agents >= 2 + self.n_agents >= 2 ), "Control_two_agents is true but not enough agents in scenario" self.env.render() @@ -103,17 +103,17 @@ def _cycle(self): for agent in self.agents ] action_list[self.current_agent_index] = self.u[ - : self.env.unwrapped().get_agent_action_size( - self.agents[self.current_agent_index] - ) - ] + : self.env.unwrapped().get_agent_action_size( + self.agents[self.current_agent_index] + ) + ] if self.n_agents > 1 and self.control_two_agents: action_list[self.current_agent_index2] = self.u2[ - : self.env.unwrapped().get_agent_action_size( - self.agents[self.current_agent_index2] - ) - ] + : self.env.unwrapped().get_agent_action_size( + self.agents[self.current_agent_index2] + ) + ] obs, rew, done, info = self.env.step(action_list) if self.display_info: @@ -124,7 +124,7 @@ def _cycle(self): message = f"Obs: {obs_str[:len(obs_str) // 2]}" self._write_values(self.text_idx + 1, message) - message = f"Rew: {round(rew[self.current_agent_index],3)}" + message = f"Rew: {round(rew[self.current_agent_index], 3)}" self._write_values(self.text_idx + 2, message) total_rew = list(map(add, total_rew, rew)) @@ -149,23 +149,26 @@ def _cycle(self): def _init_text(self): from vmas.simulator import rendering - - try: - self.text_idx = len(self.env.unwrapped().viewer.text_lines) - except AttributeError: - self.text_idx = 0 + state = rendering.RenderStateSingleton() + self.text_idx = len(state.text_lines) for i in range(N_TEXT_LINES_INTERACTIVE): text_line = rendering.TextLine( - self.env.unwrapped().viewer.window, self.text_idx + i + # self.env.unwrapped().viewer.window, + self.text_idx + i ) - self.env.unwrapped().viewer.text_lines.append(text_line) + state.text_lines.append(text_line) def _write_values(self, index: int, message: str, font_size: int = 15): - self.env.unwrapped().viewer.text_lines[index].set_text( + from vmas.simulator import rendering + rendering.RenderStateSingleton().text_lines[index].set_text( message, font_size=font_size ) + # self.env.unwrapped().viewer.text_lines[index].set_text( + # message, font_size=font_size + # ) + # keyboard event callbacks def _key_press(self, k, mod): from pyglet.window import key @@ -300,11 +303,11 @@ def format_obs(obs): def render_interactively( - scenario: Union[str, BaseScenario], - control_two_agents: bool = False, - display_info: bool = True, - save_render: bool = False, - **kwargs, + scenario: Union[str, BaseScenario], + control_two_agents: bool = False, + display_info: bool = True, + save_render: bool = False, + **kwargs, ): """ Use this script to interactively play with scenarios diff --git a/vmas/scenarios/debug/diff_drive.py b/vmas/scenarios/debug/diff_drive.py index 0b1d4eb6..10e7f0a5 100644 --- a/vmas/scenarios/debug/diff_drive.py +++ b/vmas/scenarios/debug/diff_drive.py @@ -1,6 +1,7 @@ # Copyright (c) 2022-2023. # ProrokLab (https://www.proroklab.org/) # All rights reserved. +import random import typing from typing import List @@ -52,8 +53,23 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): world.add_agent(agent) + self._init_text() + return world + def _init_text(self): + from vmas.simulator import rendering + + state = rendering.RenderStateSingleton() + # here the index an be customized to change the position of the text + offset = len(state.text_lines) + + # I used a list here but you could also add variables: + self.custom_obs_text_index = 0 + offset + state.text_lines.append(rendering.TextLine(self.custom_obs_text_index)) + self.custom_rew_text_index = 1 + offset + state.text_lines.append(rendering.TextLine(self.custom_rew_text_index)) + def reset_world_at(self, env_index: int = None): ScenarioUtils.spawn_entities_randomly( self.world.agents, @@ -76,7 +92,7 @@ def reward(self, agent: Agent): def observation(self, agent: Agent): observations = [ agent.state.pos, agent.state.rot, - agent.state.vel, agent.state.ang_vel, t + agent.state.vel, agent.state.ang_vel, ] return torch.cat( observations, @@ -85,8 +101,7 @@ def observation(self, agent: Agent): def extra_render(self, env_index: int = 0) -> "List[Geom]": from vmas.simulator import rendering - - geoms: List[Geom] = [] + state = rendering.RenderStateSingleton() # Agent rotation for agent in self.world.agents: @@ -101,14 +116,12 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]": xform.set_translation(*agent.state.pos[env_index]) line.add_attr(xform) line.set_color(*color) - geoms.append(line) + state.onetime_geoms.append(line) # DO NOT COMMIT THIS INTO MASTER, IT IS ONLY TO SHOW TEXT RENDER EXAMPLE - self.render_text = [] # Reset the text from last round - for i in range(2): - self.render_text.append(f"{i}: This is a test of custom text rendering from scenario file") - - return geoms + state.text_lines[self.custom_obs_text_index].set_text(f"custom obs text {random.randint(0, 100)}") + state.text_lines[self.custom_rew_text_index].set_text(f"custom rew text {random.randint(0, 100)}") + return [] if __name__ == "__main__": diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index 6c26fbc8..7e2547e3 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -8,6 +8,7 @@ import numpy as np import torch +import typing from gym import spaces from torch import Tensor from vmas.simulator.core import Agent, TorchVectorizedObject @@ -23,6 +24,9 @@ override, ) +if typing.TYPE_CHECKING: + from vmas.simulator.rendering import Viewer + # environment for all agents in the multiagent world # currently code assumes that no agents will be created/destroyed at runtime! @@ -63,7 +67,7 @@ def __init__( # rendering self.render_geoms_xform = None self.render_geoms = None - self.viewer = None + self.viewer: Viewer = None self.headless = None self.visible_display = None @@ -636,11 +640,6 @@ def render( self.viewer.add_onetime_list(self.scenario.extra_render(env_index)) - # Rendering the text set from extra_render method: - for message in self.scenario.render_text: - self.viewer.text_lines[text_idx].set_text(message) - text_idx += 1 - for entity in self.world.entities: self.viewer.add_onetime_list(entity.render(env_index=env_index)) @@ -679,17 +678,10 @@ def _init_rendering(self, env_index): if self.world.dim_c > 0: for agent in self.world.agents: if not agent.silent: - text_line = rendering.TextLine(self.viewer.window, idx) + text_line = rendering.TextLine(idx) self.viewer.text_lines.append(text_line) idx += 1 - # Overhead in drawing (called once), but we get the number of lines expected to be rendered: - self.scenario.extra_render(env_index) - for _ in self.scenario.render_text: - text_line = rendering.TextLine(self.viewer.window, idx) - self.viewer.text_lines.append(text_line) - idx += 1 - @override(TorchVectorizedObject) def to(self, device: DEVICE_TYPING): device = torch.device(device) diff --git a/vmas/simulator/rendering.py b/vmas/simulator/rendering.py index b688ac22..2264d7c7 100644 --- a/vmas/simulator/rendering.py +++ b/vmas/simulator/rendering.py @@ -89,6 +89,17 @@ def get_display(spec): ) +class RenderStateSingleton(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(RenderStateSingleton, cls).__new__(cls) + cls.instance.geoms = [] + cls.instance.onetime_geoms = [] + cls.instance.text_lines = [] + return cls.instance + + class Viewer(object): def __init__(self, width, height, display=None, visible=True): display = get_display(display) @@ -104,6 +115,7 @@ def __init__(self, width, height, display=None, visible=True): self.geoms = [] self.text_lines = [] self.onetime_geoms = [] + self.render_state = RenderStateSingleton() self.transform = Transform() self.bounds = None @@ -132,12 +144,15 @@ def set_bounds(self, left, right, bottom, top): def add_geom(self, geom): self.geoms.append(geom) + self.render_state.geoms.append(geom) def add_onetime(self, geom): self.onetime_geoms.append(geom) + self.render_state.onetime_geoms.append(geom) def add_onetime_list(self, geoms): self.onetime_geoms.extend(geoms) + self.render_state.onetime_geoms.extend(geoms) def render(self, return_rgb_array=False): glClearColor(1, 1, 1, 1) @@ -146,17 +161,26 @@ def render(self, return_rgb_array=False): self.window.switch_to() self.window.dispatch_events() + # self.transform.enable() + # for geom in self.geoms: + # geom.render() + # for geom in self.onetime_geoms: + # geom.render() + # self.transform.disable() self.transform.enable() - for geom in self.geoms: + for geom in self.render_state.geoms: geom.render() - for geom in self.onetime_geoms: + for geom in self.render_state.onetime_geoms: geom.render() self.transform.disable() pyglet.gl.glMatrixMode(pyglet.gl.GL_PROJECTION) pyglet.gl.glLoadIdentity() gluOrtho2D(0, self.width, 0, self.height) - for geom in self.text_lines: + # for geom in self.text_lines: + # geom.render() + + for geom in self.render_state.text_lines: geom.render() arr = None @@ -174,6 +198,7 @@ def render(self, return_rgb_array=False): arr = arr[::-1, :, 0:3] self.window.flip() + self.render_state.onetime_geoms = [] self.onetime_geoms = [] return arr @@ -306,9 +331,9 @@ def enable(self): class TextLine: - def __init__(self, window, idx): + def __init__(self, idx): self.idx = idx - self.window = window + # self.window = window pyglet.font.add_file(os.path.join(os.path.dirname(__file__), "secrcode.ttf")) if pyglet.font.have_font("Courier"): diff --git a/vmas/simulator/scenario.py b/vmas/simulator/scenario.py index e5c2b58a..f46c0ede 100644 --- a/vmas/simulator/scenario.py +++ b/vmas/simulator/scenario.py @@ -33,8 +33,6 @@ def __init__(self): self.plot_grid = False # The distance between lines in the background grid self.grid_spacing = 0.1 - # Text to be rendered by environment: - self.render_text = [] @property def world(self):