diff --git a/mesa/experimental/components/altair.py b/mesa/experimental/components/altair.py new file mode 100644 index 00000000000..2b08485b848 --- /dev/null +++ b/mesa/experimental/components/altair.py @@ -0,0 +1,60 @@ +import contextlib +from typing import Optional + +import solara + +with contextlib.suppress(ImportError): + import altair as alt + + +@solara.component +def SpaceAltair(model, agent_portrayal, dependencies: Optional[list[any]] = None): + space = getattr(model, "grid", None) + if space is None: + # Sometimes the space is defined as model.space instead of model.grid + space = model.space + chart = _draw_grid(space, agent_portrayal) + solara.FigureAltair(chart) + + +def _draw_grid(space, agent_portrayal): + def portray(g): + all_agent_data = [] + for content, (x, y) in space.coord_iter(): + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] # noqa: PLW2901 + for agent in content: + # use all data from agent portrayal, and add x,y coordinates + agent_data = agent_portrayal(agent) + agent_data["x"] = x + agent_data["y"] = y + all_agent_data.append(agent_data) + return all_agent_data + + all_agent_data = portray(space) + encoding_dict = { + # no x-axis label + "x": alt.X("x", axis=None, type="ordinal"), + # no y-axis label + "y": alt.Y("y", axis=None, type="ordinal"), + } + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) + ) + .mark_point(filled=True) + .properties(width=280, height=280) + # .configure_view(strokeOpacity=0) # hide grid/chart lines + ) + + return chart diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index a6ae318a822..9d29e08fb60 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -6,6 +6,7 @@ import solara from solara.alias import rv +import mesa.experimental.components.altair as components_altair import mesa.experimental.components.matplotlib as components_matplotlib from mesa.experimental.UserParam import Slider @@ -28,6 +29,10 @@ def Card( components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) + elif space_drawer == "altair": + components_altair.SpaceAltair( + model, agent_portrayal, dependencies=[current_step.value] + ) elif space_drawer: # if specified, draw agent space with an alternate renderer space_drawer(model, agent_portrayal) @@ -113,6 +118,10 @@ def render_in_jupyter(): components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) + elif space_drawer == "altair": + components_altair.SpaceAltair( + model, agent_portrayal, dependencies=[current_step.value] + ) elif space_drawer: # if specified, draw agent space with an alternate renderer space_drawer(model, agent_portrayal)