Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 32 additions & 29 deletions vmas/interactive_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class InteractiveEnv:
"""

def __init__(
self,
env: GymWrapper,
control_two_agents: bool = False,
display_info: bool = True,
save_render: bool = False,
render_name: str = "interactive",
self,
env: GymWrapper,
control_two_agents: bool = False,
display_info: bool = True,
save_render: bool = False,
render_name: str = "interactive",
):
self.env = env
self.control_two_agents = control_two_agents
Expand All @@ -68,7 +68,7 @@ def __init__(

if self.control_two_agents:
assert (
self.n_agents >= 2
self.n_agents >= 2
), "Control_two_agents is true but not enough agents in scenario"

self.env.render()
Expand Down Expand Up @@ -103,17 +103,17 @@ def _cycle(self):
for agent in self.agents
]
action_list[self.current_agent_index] = self.u[
: self.env.unwrapped().get_agent_action_size(
self.agents[self.current_agent_index]
)
]
: self.env.unwrapped().get_agent_action_size(
self.agents[self.current_agent_index]
)
]

if self.n_agents > 1 and self.control_two_agents:
action_list[self.current_agent_index2] = self.u2[
: self.env.unwrapped().get_agent_action_size(
self.agents[self.current_agent_index2]
)
]
: self.env.unwrapped().get_agent_action_size(
self.agents[self.current_agent_index2]
)
]
obs, rew, done, info = self.env.step(action_list)

if self.display_info:
Expand All @@ -124,7 +124,7 @@ def _cycle(self):
message = f"Obs: {obs_str[:len(obs_str) // 2]}"
self._write_values(self.text_idx + 1, message)

message = f"Rew: {round(rew[self.current_agent_index],3)}"
message = f"Rew: {round(rew[self.current_agent_index], 3)}"
self._write_values(self.text_idx + 2, message)

total_rew = list(map(add, total_rew, rew))
Expand All @@ -149,23 +149,26 @@ def _cycle(self):

def _init_text(self):
from vmas.simulator import rendering

try:
self.text_idx = len(self.env.unwrapped().viewer.text_lines)
except AttributeError:
self.text_idx = 0
state = rendering.RenderStateSingleton()
self.text_idx = len(state.text_lines)

for i in range(N_TEXT_LINES_INTERACTIVE):
text_line = rendering.TextLine(
self.env.unwrapped().viewer.window, self.text_idx + i
# self.env.unwrapped().viewer.window,
self.text_idx + i
)
self.env.unwrapped().viewer.text_lines.append(text_line)
state.text_lines.append(text_line)

def _write_values(self, index: int, message: str, font_size: int = 15):
self.env.unwrapped().viewer.text_lines[index].set_text(
from vmas.simulator import rendering
rendering.RenderStateSingleton().text_lines[index].set_text(
message, font_size=font_size
)

# self.env.unwrapped().viewer.text_lines[index].set_text(
# message, font_size=font_size
# )

# keyboard event callbacks
def _key_press(self, k, mod):
from pyglet.window import key
Expand Down Expand Up @@ -300,11 +303,11 @@ def format_obs(obs):


def render_interactively(
scenario: Union[str, BaseScenario],
control_two_agents: bool = False,
display_info: bool = True,
save_render: bool = False,
**kwargs,
scenario: Union[str, BaseScenario],
control_two_agents: bool = False,
display_info: bool = True,
save_render: bool = False,
**kwargs,
):
"""
Use this script to interactively play with scenarios
Expand Down
30 changes: 24 additions & 6 deletions vmas/scenarios/debug/diff_drive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2022-2023.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
import random
import typing
from typing import List

Expand Down Expand Up @@ -52,8 +53,23 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):

world.add_agent(agent)

self._init_text()

return world

def _init_text(self):
from vmas.simulator import rendering

state = rendering.RenderStateSingleton()
# here the index an be customized to change the position of the text
offset = len(state.text_lines)

# I used a list here but you could also add variables:
self.custom_obs_text_index = 0 + offset
state.text_lines.append(rendering.TextLine(self.custom_obs_text_index))
self.custom_rew_text_index = 1 + offset
state.text_lines.append(rendering.TextLine(self.custom_rew_text_index))

def reset_world_at(self, env_index: int = None):
ScenarioUtils.spawn_entities_randomly(
self.world.agents,
Expand All @@ -75,8 +91,8 @@ def reward(self, agent: Agent):

def observation(self, agent: Agent):
observations = [
agent.state.pos,
agent.state.vel,
agent.state.pos, agent.state.rot,
agent.state.vel, agent.state.ang_vel,
]
return torch.cat(
observations,
Expand All @@ -85,8 +101,7 @@ def observation(self, agent: Agent):

def extra_render(self, env_index: int = 0) -> "List[Geom]":
from vmas.simulator import rendering

geoms: List[Geom] = []
state = rendering.RenderStateSingleton()

# Agent rotation
for agent in self.world.agents:
Expand All @@ -101,9 +116,12 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]":
xform.set_translation(*agent.state.pos[env_index])
line.add_attr(xform)
line.set_color(*color)
geoms.append(line)
state.onetime_geoms.append(line)

return geoms
# DO NOT COMMIT THIS INTO MASTER, IT IS ONLY TO SHOW TEXT RENDER EXAMPLE
state.text_lines[self.custom_obs_text_index].set_text(f"custom obs text {random.randint(0, 100)}")
state.text_lines[self.custom_rew_text_index].set_text(f"custom rew text {random.randint(0, 100)}")
return []


if __name__ == "__main__":
Expand Down
20 changes: 12 additions & 8 deletions vmas/simulator/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import torch
import typing
from gym import spaces
from torch import Tensor
from vmas.simulator.core import Agent, TorchVectorizedObject
Expand All @@ -23,6 +24,9 @@
override,
)

if typing.TYPE_CHECKING:
from vmas.simulator.rendering import Viewer


# environment for all agents in the multiagent world
# currently code assumes that no agents will be created/destroyed at runtime!
Expand Down Expand Up @@ -63,7 +67,7 @@ def __init__(
# rendering
self.render_geoms_xform = None
self.render_geoms = None
self.viewer = None
self.viewer: Viewer = None
self.headless = None
self.visible_display = None

Expand Down Expand Up @@ -550,11 +554,11 @@ def render(
self.headless = False
pyglet.options["headless"] = self.headless

self._init_rendering()
self._init_rendering(env_index)

# Render comm messages
text_idx = 0
if self.world.dim_c > 0:
idx = 0
for agent in self.world.agents:
if agent.silent:
continue
Expand All @@ -571,8 +575,8 @@ def render(
word = ALPHABET[torch.argmax(agent.state.c[env_index]).item()]

message = agent.name + " sends " + word + " "
self.viewer.text_lines[idx].set_text(message)
idx += 1
self.viewer.text_lines[text_idx].set_text(message)
text_idx += 1

zoom = max(VIEWER_MIN_ZOOM, self.scenario.viewer_zoom)

Expand Down Expand Up @@ -662,19 +666,19 @@ def plot_function(self, f, precision, plot_range, cmap_range, cmap_alpha):
)
self.viewer.add_onetime_list(geoms)

def _init_rendering(self):
def _init_rendering(self, env_index):
from vmas.simulator import rendering

self.viewer = rendering.Viewer(
*self.scenario.viewer_size, visible=self.visible_display
)

self.viewer.text_lines = []
idx = 0
if self.world.dim_c > 0:
self.viewer.text_lines = []
for agent in self.world.agents:
if not agent.silent:
text_line = rendering.TextLine(self.viewer.window, idx)
text_line = rendering.TextLine(idx)
self.viewer.text_lines.append(text_line)
idx += 1

Expand Down
35 changes: 30 additions & 5 deletions vmas/simulator/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ def get_display(spec):
)


class RenderStateSingleton(object):

def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(RenderStateSingleton, cls).__new__(cls)
cls.instance.geoms = []
cls.instance.onetime_geoms = []
cls.instance.text_lines = []
return cls.instance


class Viewer(object):
def __init__(self, width, height, display=None, visible=True):
display = get_display(display)
Expand All @@ -104,6 +115,7 @@ def __init__(self, width, height, display=None, visible=True):
self.geoms = []
self.text_lines = []
self.onetime_geoms = []
self.render_state = RenderStateSingleton()
self.transform = Transform()
self.bounds = None

Expand Down Expand Up @@ -132,12 +144,15 @@ def set_bounds(self, left, right, bottom, top):

def add_geom(self, geom):
self.geoms.append(geom)
self.render_state.geoms.append(geom)

def add_onetime(self, geom):
self.onetime_geoms.append(geom)
self.render_state.onetime_geoms.append(geom)

def add_onetime_list(self, geoms):
self.onetime_geoms.extend(geoms)
self.render_state.onetime_geoms.extend(geoms)

def render(self, return_rgb_array=False):
glClearColor(1, 1, 1, 1)
Expand All @@ -146,17 +161,26 @@ def render(self, return_rgb_array=False):
self.window.switch_to()
self.window.dispatch_events()

# self.transform.enable()
# for geom in self.geoms:
# geom.render()
# for geom in self.onetime_geoms:
# geom.render()
# self.transform.disable()
self.transform.enable()
for geom in self.geoms:
for geom in self.render_state.geoms:
geom.render()
for geom in self.onetime_geoms:
for geom in self.render_state.onetime_geoms:
geom.render()
self.transform.disable()

pyglet.gl.glMatrixMode(pyglet.gl.GL_PROJECTION)
pyglet.gl.glLoadIdentity()
gluOrtho2D(0, self.width, 0, self.height)
for geom in self.text_lines:
# for geom in self.text_lines:
# geom.render()

for geom in self.render_state.text_lines:
geom.render()

arr = None
Expand All @@ -174,6 +198,7 @@ def render(self, return_rgb_array=False):
arr = arr[::-1, :, 0:3]

self.window.flip()
self.render_state.onetime_geoms = []
self.onetime_geoms = []
return arr

Expand Down Expand Up @@ -306,9 +331,9 @@ def enable(self):


class TextLine:
def __init__(self, window, idx):
def __init__(self, idx):
self.idx = idx
self.window = window
# self.window = window

pyglet.font.add_file(os.path.join(os.path.dirname(__file__), "secrcode.ttf"))
if pyglet.font.have_font("Courier"):
Expand Down