In [1]:
%matplotlib inline

from model import Schelling

import io
import warnings

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import solara
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.figure import Figure

import mesa
from mesa.experimental.cell_space import VoronoiGrid
from mesa.space import PropertyLayer
from mesa.visualization.utils import update_counter


def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
    """Draw a grid's agents (and optionally its property layers) onto an axes.

    Args:
        space: the model's grid; must expose ``width``, ``height`` and
            ``coord_iter()`` (consumed by ``_get_agent_data``).
        space_ax: Matplotlib ``Axes`` to draw into.
        agent_portrayal: callable mapping an agent to a dict with optional
            ``"size"``, ``"color"`` and ``"shape"`` keys.
        propertylayer_portrayal: property-layer styling; when falsy, no
            property layers are drawn.
        model: the model instance, forwarded to ``draw_property_layers``.
    """
    if propertylayer_portrayal:
        # ``draw_property_layers`` is not defined anywhere in this notebook, so a
        # bare call raised NameError as soon as a property-layer portrayal was
        # passed. Bring it in from Mesa's matplotlib components module.
        # NOTE(review): this module path and the (ax, space, portrayal, model)
        # signature match Mesa 3.0; confirm against the installed Mesa version.
        from mesa.visualization.components.matplotlib import draw_property_layers

        draw_property_layers(space_ax, space, propertylayer_portrayal, model)

    agent_data = _get_agent_data(space, agent_portrayal)

    # One data unit per cell, so the axis limits are exactly the grid extent.
    space_ax.set_xlim(0, space.width)
    space_ax.set_ylim(0, space.height)
    _split_and_scatter(agent_data, space_ax)

    # Dotted gray line on every cell boundary: width + 1 vertical lines and
    # height + 1 horizontal lines.
    for x in range(space.width + 1):
        space_ax.axvline(x, color="gray", linestyle=":")
    for y in range(space.height + 1):
        space_ax.axhline(y, color="gray", linestyle=":")


def _get_agent_data(space, agent_portrayal):
    """Helper function to get agent data for visualization."""
    x, y, s, c, m = [], [], [], [], []
    for agents, pos in space.coord_iter():
        if not agents:
            continue
        if not isinstance(agents, list):
            agents = [agents]  # noqa PLW2901
        for agent in agents:
            data = agent_portrayal(agent)
            x.append(pos[0] + 0.5)  # Center the agent in the cell
            y.append(pos[1] + 0.5)  # Center the agent in the cell
            default_size = (180 / max(space.width, space.height)) ** 2
            s.append(data.get("size", default_size))
            c.append(data.get("color", "tab:blue"))
            m.append(data.get("shape", "o"))
    return {"x": x, "y": y, "s": s, "c": c, "m": m}


def _split_and_scatter(portray_data, space_ax):
    """Helper function to split and scatter agent data."""
    for marker in set(portray_data["m"]):
        mask = [m == marker for m in portray_data["m"]]
        space_ax.scatter(
            [x for x, show in zip(portray_data["x"], mask) if show],
            [y for y, show in zip(portray_data["y"], mask) if show],
            s=[s for s, show in zip(portray_data["s"], mask) if show],
            c=[c for c, show in zip(portray_data["c"], mask) if show],
            marker=marker,
        )


def agent_portrayal(agent):
    """Return the scatter style for a single agent.

    Agents whose ``type`` equals 0 are drawn orange; all other types are blue.
    Only ``"color"`` is set, so size and shape fall back to the defaults used by
    ``_get_agent_data``.
    """
    if agent.type == 0:
        return {"color": "tab:orange"}
    return {"color": "tab:blue"}


# Model instance rendered by every visualization variant below.
# NOTE(review): positional args are presumably (width, height, density,
# minority fraction, homophily) — i.e. a 200 x 200 grid — confirm against the
# Schelling signature in model.py.
model = Schelling(200, 200, 0.8, 0.2, 3)

<IPython.core.display.Javascript object>

In [2]:
@solara.component
def SpaceMatplotlibPng(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Solara component rendering the model's grid as a tightly-cropped PNG.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: forwarded unchanged to ``solara.FigureMatplotlib``.
    """
    # Reading the reactive counter subscribes this component to its updates.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    # Only ``mesa.space._Grid`` spaces are drawn; anything else leaves the axes empty.
    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    solara.FigureMatplotlib(
        space_fig, format="png", bbox_inches="tight", dependencies=dependencies
    )


SpaceMatplotlibPng(model, agent_portrayal, None)

In [3]:
@solara.component
def SpaceMatplotlibPng(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Solara component rendering the model's grid as a 300-DPI, tightly-cropped PNG.

    NOTE: this re-binds ``SpaceMatplotlibPng``, replacing the default-DPI
    component of the same name defined in the previous cell.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: forwarded unchanged to ``solara.FigureMatplotlib``.
    """
    # Reading the reactive counter subscribes this component to its updates.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    # Only ``mesa.space._Grid`` spaces are drawn; anything else leaves the axes empty.
    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    solara.FigureMatplotlib(
        space_fig, format="png", dpi=300, bbox_inches="tight", dependencies=dependencies
    )


SpaceMatplotlibPng(model, agent_portrayal, None)

In [4]:
@solara.component
def SpaceMatplotlibSvg(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Solara component rendering the model's grid as a tightly-cropped SVG.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: forwarded unchanged to ``solara.FigureMatplotlib``.
    """
    # Reading the reactive counter subscribes this component to its updates.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    # Only ``mesa.space._Grid`` spaces are drawn; anything else leaves the axes empty.
    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    solara.FigureMatplotlib(
        space_fig, format="svg", bbox_inches="tight", dependencies=dependencies
    )


SpaceMatplotlibSvg(model, agent_portrayal, None)

In [5]:
def SpaceMatplotlibPng(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory PNG (default DPI, no cropping).

    A plain function, not a Solara component. It repeats the component's drawing
    work so ``%%timeit`` can time figure construction, drawing and PNG
    encoding. The encoded image is discarded.

    NOTE: this re-binds ``SpaceMatplotlibPng``, shadowing the Solara component
    of the same name defined in earlier cells.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="png")

In [6]:
%%timeit
SpaceMatplotlibPng(model, agent_portrayal, None)

863 ms ± 50.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
def SpaceMatplotlibPngTightBbox(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory PNG with ``bbox_inches="tight"``.

    A plain function, not a Solara component. It repeats the component's drawing
    work at the default savefig DPI so ``%%timeit`` can time it. The encoded
    image is discarded.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="png", bbox_inches="tight")

In [8]:
%%timeit
SpaceMatplotlibPngTightBbox(model, agent_portrayal, None)

886 ms ± 33.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
def SpaceMatplotlibPngDpi(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory 300-DPI PNG (no cropping).

    A plain function, not a Solara component. It repeats the component's drawing
    work so ``%%timeit`` can time it. The encoded image is discarded.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="png", dpi=300)

In [10]:
%%timeit
SpaceMatplotlibPngDpi(model, agent_portrayal, None)

1.37 s ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
def SpaceMatplotlibPngDpiBbox(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory 300-DPI PNG with ``bbox_inches="tight"``.

    A plain function, not a Solara component. It repeats the component's drawing
    work so ``%%timeit`` can time it. The encoded image is discarded.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="png", dpi=300, bbox_inches="tight")

In [12]:
%%timeit
SpaceMatplotlibPngDpiBbox(model, agent_portrayal, None)

1.38 s ± 38.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
def SpaceMatplotlibSvg(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory SVG (no cropping).

    A plain function, not a Solara component. It repeats the component's drawing
    work so ``%%timeit`` can time figure construction, drawing and SVG
    serialisation. The output is discarded.

    NOTE: this re-binds ``SpaceMatplotlibSvg``, shadowing the Solara component
    of the same name defined in an earlier cell.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="svg")

In [14]:
%%timeit
SpaceMatplotlibSvg(model, agent_portrayal, None)

3.89 s ± 178 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
def SpaceMatplotlibSvgTightBbox(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[object] | None = None,
):
    """Benchmark: render the model's grid to an in-memory SVG with ``bbox_inches="tight"``.

    A plain function, not a Solara component. It repeats the component's drawing
    work so ``%%timeit`` can time it. The output is discarded.

    Args:
        model: model exposing its space as ``grid`` (preferred) or ``space``.
        agent_portrayal: callable mapping an agent to its scatter style dict.
        propertylayer_portrayal: property-layer styling, or ``None`` for none.
        dependencies: unused; kept so the signature matches the component.
    """
    # Mirror the component's reactive read so the timed work matches it.
    update_counter.get()
    space_fig = Figure()
    space_ax = space_fig.subplots()

    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    if isinstance(space, mesa.space._Grid):
        _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

    with io.BytesIO() as buffer:
        space_fig.savefig(buffer, format="svg", bbox_inches="tight")

In [16]:
%%timeit
SpaceMatplotlibSvgTightBbox(model, agent_portrayal, None)

3.96 s ± 87.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
