diff --git a/docs/migration_guide.md b/docs/migration_guide.md index 016053e84fb..193e796089f 100644 --- a/docs/migration_guide.md +++ b/docs/migration_guide.md @@ -281,15 +281,17 @@ SolaraViz(model, components=[make_space_component(agent_portrayal)]) # old from mesa.experimental import SolaraViz + def make_plot(model): ... + SolaraViz(model_cls, model_params, measures=[make_plot, "foo", ["bar", "baz"]]) # new -from mesa.visualization import SolaraViz, make_plot_measure +from mesa.visualization import SolaraViz, make_plot_component -SolaraViz(model, components=[make_plot, make_plot_measure("foo"), make_plot_measure("bar", "baz")]) +SolaraViz(model, components=[make_plot, make_plot_component("foo"), make_plot_component("bar", "baz")]) ``` #### Plotting text diff --git a/docs/overview.md b/docs/overview.md index 7d0c750ed84..92f7be06907 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -168,7 +168,7 @@ The results are returned as a list of dictionaries, which can be easily converte Mesa now uses a new browser-based visualization system called SolaraViz. This allows for interactive, customizable visualizations of your models. Here's a basic example of how to set up a visualization: ```python -from mesa.visualization import SolaraViz, make_space_component, make_plot_measure +from mesa.visualization import SolaraViz, make_space_component, make_plot_component def agent_portrayal(agent): @@ -177,20 +177,20 @@ def agent_portrayal(agent): model_params = { "N": { - "type": "SliderInt", - "value": 50, - "label": "Number of agents:", - "min": 10, - "max": 100, - "step": 1, - } + "type": "SliderInt", + "value": 50, + "label": "Number of agents:", + "min": 10, + "max": 100, + "step": 1, + } } page = SolaraViz( MyModel, [ - make_space_component(agent_portrayal), - make_plot_measure("mean_age") + make_space_component(agent_portrayal), + make_plot_component("mean_age") ], model_params=model_params ) diff --git a/docs/tutorials/visualization_tutorial.ipynb b/docs/tutorials/visualization_tutorial.ipynb index aff166e9b82..dbff7b01fa6 100644 --- a/docs/tutorials/visualization_tutorial.ipynb +++ b/docs/tutorials/visualization_tutorial.ipynb @@ -49,67 +49,43 @@ ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.075682Z", - "start_time": "2024-10-29T19:38:45.449918Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "import mesa\n", "print(f\"Mesa version: {mesa.__version__}\")\n", "\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", + "from mesa.visualization import SolaraViz, make_plot_component, make_space_component\n", "\n", "# Import the local MoneyModel.py\n", "from MoneyModel import MoneyModel\n" - ], - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mesa version: 3.0.0b2\n" - ] - } - ], - "execution_count": 1 + ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "tags": [], - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.079286Z", - "start_time": "2024-10-29T19:38:46.076984Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "def agent_portrayal(agent):\n", " return {\n", " \"color\": \"tab:blue\",\n", " \"size\": 50,\n", " }" - ], - "outputs": [], - "execution_count": 2 + ] }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "In addition to the portrayal method, we instantiate the model parameters, some of which are modifiable by user inputs. In this case, the number of agents, N, is specified as a slider of integers." - ] + "cell_type": "markdown", + "source": "In addition to the portrayal method, we instantiate the model parameters, some of which are modifiable by user inputs. In this case, the number of agents, N, is specified as a slider of integers." }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.081662Z", - "start_time": "2024-10-29T19:38:46.079838Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "model_params = {\n", " \"n\": {\n", @@ -123,13 +99,11 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ], - "outputs": [], - "execution_count": 3 + ] }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "Next, we instantiate the visualization object which (by default) displays the grid containing the agents, and timeseries of values computed by the model's data collector. In this example, we specify the Gini coefficient.\n", "\n", @@ -142,20 +116,16 @@ ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "tags": [], - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.864371Z", - "start_time": "2024-10-29T19:38:46.082810Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", "SpaceGraph = make_space_component(agent_portrayal)\n", - "GiniPlot = make_plot_measure(\"Gini\")\n", + "GiniPlot = make_plot_component(\"Gini\")\n", "\n", "page = SolaraViz(\n", " model1,\n", @@ -165,31 +135,11 @@ ")\n", "# This is required to render the visualization in the Jupyter notebook\n", "page" - ], - "outputs": [ - { - "data": { - "text/plain": [ - "Cannot show ipywidgets in text" - ], - "text/html": [ - "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "c9f2ef2b5a24483c92fa129213414a2c" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 4 + ] }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "## Part 2 - Dynamic Agent Representation \n", "\n", @@ -203,40 +153,24 @@ ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.867576Z", - "start_time": "2024-10-29T19:38:46.865205Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "import mesa\n", "print(f\"Mesa version: {mesa.__version__}\")\n", "\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", + "from mesa.visualization import SolaraViz, make_plot_component, make_space_component\n", "# Import the local MoneyModel.py\n", "from MoneyModel import MoneyModel\n" - ], - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mesa version: 3.0.0b2\n" - ] - } - ], - "execution_count": 5 + ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:46.870617Z", - "start_time": "2024-10-29T19:38:46.868336Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "def agent_portrayal(agent):\n", " size = 10\n", @@ -258,24 +192,19 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ], - "outputs": [], - "execution_count": 6 + ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:47.881911Z", - "start_time": "2024-10-29T19:38:46.871328Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", "SpaceGraph = make_space_component(agent_portrayal)\n", - "GiniPlot = make_plot_measure(\"Gini\")\n", + "GiniPlot = make_plot_component(\"Gini\")\n", "\n", "page = SolaraViz(\n", " model1,\n", @@ -285,31 +214,11 @@ ")\n", "# This is required to render the visualization in the Jupyter notebook\n", "page" - ], - "outputs": [ - { - "data": { - "text/plain": [ - "Cannot show ipywidgets in text" - ], - "text/html": [ - "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "da8518ec9ce74c068288bec0c8d3793e" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 7 + ] }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "## Part 3 - Custom Components \n", "\n", @@ -325,13 +234,10 @@ ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:47.885386Z", - "start_time": "2024-10-29T19:38:47.882808Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "import mesa\n", "print(f\"Mesa version: {mesa.__version__}\")\n", @@ -339,29 +245,16 @@ "from matplotlib.figure import Figure\n", "\n", "from mesa.visualization.utils import update_counter\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", + "from mesa.visualization import SolaraViz, make_plot_component, make_space_component\n", "# Import the local MoneyModel.py\n", "from MoneyModel import MoneyModel\n" - ], - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mesa version: 3.0.0b2\n" - ] - } - ], - "execution_count": 8 + ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:47.888491Z", - "start_time": "2024-10-29T19:38:47.886217Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "def agent_portrayal(agent):\n", " size = 10\n", @@ -383,25 +276,18 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ], - "outputs": [], - "execution_count": 9 + ] }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "Next, we update our solara frontend to use this new component" - ] + "cell_type": "markdown", + "source": "Next, we update our solara frontend to use this new component" }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:47.893643Z", - "start_time": "2024-10-29T19:38:47.891084Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "@solara.component\n", "def Histogram(model):\n", @@ -415,27 +301,20 @@ " # because plt.hist is not thread-safe.\n", " ax.hist(wealth_vals, bins=10)\n", " solara.FigureMatplotlib(fig)" - ], - "outputs": [], - "execution_count": 10 + ] }, { + "metadata": {}, "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-29T19:38:47.896565Z", - "start_time": "2024-10-29T19:38:47.894387Z" - } - }, + "outputs": [], + "execution_count": null, "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", "SpaceGraph = make_space_component(agent_portrayal)\n", - "GiniPlot = make_plot_measure(\"Gini\")" - ], - "outputs": [], - "execution_count": 11 + "GiniPlot = make_plot_component(\"Gini\")" + ] }, { "cell_type": "code", diff --git a/mesa/examples/advanced/epstein_civil_violence/app.py b/mesa/examples/advanced/epstein_civil_violence/app.py index 99304cd4618..d3eb4643d87 100644 --- a/mesa/examples/advanced/epstein_civil_violence/app.py +++ b/mesa/examples/advanced/epstein_civil_violence/app.py @@ -7,7 +7,7 @@ from mesa.visualization import ( Slider, SolaraViz, - make_plot_measure, + make_plot_component, make_space_component, ) @@ -58,7 +58,7 @@ def post_process(ax): citizen_cop_portrayal, post_process=post_process, draw_grid=False ) -chart_component = make_plot_measure( +chart_component = make_plot_component( {state.name.lower(): agent_colors[state] for state in CitizenState} ) diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py index d5bfd626e3c..fafedde6b09 100644 --- a/mesa/examples/advanced/pd_grid/app.py +++ b/mesa/examples/advanced/pd_grid/app.py @@ -3,7 +3,7 @@ """ from mesa.examples.advanced.pd_grid.model import PdGrid -from mesa.visualization import SolaraViz, make_plot_measure, make_space_component +from mesa.visualization import SolaraViz, make_plot_component, make_space_component from mesa.visualization.UserParam import Slider @@ -35,7 +35,7 @@ def pd_agent_portrayal(agent): grid_viz = make_space_component(agent_portrayal=pd_agent_portrayal) # Create plot for tracking cooperating agents over time -plot_component = make_plot_measure("Cooperating_Agents") +plot_component = make_plot_component("Cooperating_Agents") # Initialize model initial_model = PdGrid() diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py index 39969e24079..8a5441cc54d 100644 --- a/mesa/examples/advanced/sugarscape_g1mt/app.py +++ b/mesa/examples/advanced/sugarscape_g1mt/app.py @@ -4,7 +4,7 @@ from mesa.examples.advanced.sugarscape_g1mt.agents import Trader from mesa.examples.advanced.sugarscape_g1mt.model import SugarscapeG1mt -from mesa.visualization import SolaraViz, make_plot_measure +from mesa.visualization import SolaraViz, make_plot_component def SpaceDrawer(model): @@ -55,7 +55,7 @@ def portray(g): page = SolaraViz( model1, - components=[SpaceDrawer, make_plot_measure(["Trader", "Price"])], + components=[SpaceDrawer, make_plot_component(["Trader", "Price"])], name="Sugarscape {G1, M, T}", play_interval=150, ) diff --git a/mesa/examples/advanced/wolf_sheep/app.py b/mesa/examples/advanced/wolf_sheep/app.py index a8c0a1e9c49..94261021b6a 100644 --- a/mesa/examples/advanced/wolf_sheep/app.py +++ b/mesa/examples/advanced/wolf_sheep/app.py @@ -3,7 +3,7 @@ from mesa.visualization import ( Slider, SolaraViz, - make_plot_measure, + make_plot_component, make_space_component, ) @@ -68,7 +68,7 @@ def post_process(ax): space_component = make_space_component( wolf_sheep_portrayal, draw_grid=False, post_process=post_process ) -lineplot_component = make_plot_measure( +lineplot_component = make_plot_component( {"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"} ) diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index 2ab6d06bf73..ddb8933049f 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -1,7 +1,7 @@ from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel from mesa.visualization import ( SolaraViz, - make_plot_measure, + make_plot_component, make_space_component, ) @@ -37,7 +37,7 @@ def agent_portrayal(agent): # You can also author your own visualization elements, which can also be functions # that receive the model instance and return a valid solara component. SpaceGraph = make_space_component(agent_portrayal) -GiniPlot = make_plot_measure("Gini") +GiniPlot = make_plot_component("Gini") # Create the SolaraViz page. This will automatically create a server and display the # visualization elements in a web browser. diff --git a/mesa/examples/basic/schelling/app.py b/mesa/examples/basic/schelling/app.py index 86f5a2941fd..73492fa07b6 100644 --- a/mesa/examples/basic/schelling/app.py +++ b/mesa/examples/basic/schelling/app.py @@ -4,7 +4,7 @@ from mesa.visualization import ( Slider, SolaraViz, - make_plot_measure, + make_plot_component, make_space_component, ) @@ -28,7 +28,7 @@ def agent_portrayal(agent): model1 = Schelling(20, 20, 0.8, 0.2, 3) -HappyPlot = make_plot_measure({"happy": "tab:green"}) +HappyPlot = make_plot_component({"happy": "tab:green"}) page = SolaraViz( model1, diff --git a/mesa/examples/basic/virus_on_network/app.py b/mesa/examples/basic/virus_on_network/app.py index 8e82a72830c..abe123e7e78 100644 --- a/mesa/examples/basic/virus_on_network/app.py +++ b/mesa/examples/basic/virus_on_network/app.py @@ -10,7 +10,7 @@ from mesa.visualization import ( Slider, SolaraViz, - make_plot_measure, + make_plot_component, make_space_component, ) @@ -86,9 +86,17 @@ def get_resistant_susceptible_ratio(model): ), } + +def post_process_lineplot(ax): + ax.set_ylim(ymin=0) + ax.set_ylabel("# people") + ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left") + + SpacePlot = make_space_component(agent_portrayal) -StatePlot = make_plot_measure( - {"Infected": "tab:red", "Susceptible": "tab:green", "Resistant": "tab:gray"} +StatePlot = make_plot_component( + {"Infected": "tab:red", "Susceptible": "tab:green", "Resistant": "tab:gray"}, + post_process=post_process_lineplot, ) model1 = VirusOnNetwork() diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index 0e1875c751c..4bac98704cc 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -1,7 +1,7 @@ """Solara based visualization for Mesa models.""" from .components.altair import make_space_altair -from .components.matplotlib import make_plot_measure, make_space_component +from .components.matplotlib import make_plot_component, make_space_component from .solara_viz import JupyterViz, SolaraViz from .UserParam import Slider @@ -11,5 +11,5 @@ "Slider", "make_space_altair", "make_space_component", - "make_plot_measure", + "make_plot_component", ] diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 7e9982a7387..09b281a3e17 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -40,6 +40,15 @@ Network = NetworkGrid | mesa.experimental.cell_space.Network +def make_space_matplotlib(*args, **kwargs): # noqa: D103 + warnings.warn( + "make_space_matplotlib has been renamed to make_space_component", + DeprecationWarning, + stacklevel=2, + ) + return make_space_component(*args, **kwargs) + + def make_space_component( agent_portrayal: Callable | None = None, propertylayer_portrayal: dict | None = None, @@ -618,30 +627,55 @@ def _scatter(ax: Axes, arguments): ) -def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): +def make_plot_measure(*args, **kwargs): # noqa: D103 + warnings.warn( + "make_plot_measure has been renamed to make_plot_component", + DeprecationWarning, + stacklevel=2, + ) + return make_plot_component(*args, **kwargs) + + +def make_plot_component( + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable | None = None, + save_format="png", +): """Create a plotting function for a specified measure. Args: measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + save_format: save format of figure in solara backend Returns: function: A function that creates a PlotMatplotlib component. """ - def MakePlotMeasure(model): - return PlotMatplotlib(model, measure) + def MakePlotMatplotlib(model): + return PlotMatplotlib( + model, measure, post_process=post_process, save_format=save_format + ) - return MakePlotMeasure + return MakePlotMatplotlib @solara.component -def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): +def PlotMatplotlib( + model, + measure, + dependencies: list[any] | None = None, + post_process: Callable | None = None, + save_format="png", +): """Create a Matplotlib-based plot for a measure or measures. Args: model (mesa.Model): The model instance. measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. dependencies (list[any] | None): Optional dependencies for the plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + save_format: format used for saving the figure. Returns: solara.FigureMatplotlib: A component for rendering the plot. @@ -661,9 +695,13 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): for m in measure: ax.plot(df.loc[:, m], label=m) ax.legend(loc="best") + + if post_process is not None: + post_process(ax) + ax.set_xlabel("Step") # Set integer x axis ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) solara.FigureMatplotlib( - fig, format="png", bbox_inches="tight", dependencies=dependencies + fig, format=save_format, bbox_inches="tight", dependencies=dependencies )