Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys
import os

# Add the parent directory of the project to PYTHONPATH
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
119 changes: 117 additions & 2 deletions tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import unittest
from unittest.mock import Mock
from unittest.mock import Mock, patch

import ipyvuetify as vw
import solara

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import mesa
from mesa.visualization.jupyter_viz import JupyterViz, Slider, UserInputs
from mesa.visualization.jupyter_viz import (
JupyterViz,
Slider,
UserInputs,
ModelController,
split_model_params,
)


class TestMakeUserInput(unittest.TestCase):
Expand Down Expand Up @@ -151,3 +161,108 @@ def test_slider():
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider


class TestJupyterViz(unittest.TestCase):
# test for section 1.
# testing for correct init
@patch("solara.use_reactive")
@patch("solara.use_state")
def test_initialization(self, mock_use_state, mock_use_reactive):
mock_use_reactive.side_effect = [Mock(), Mock()]
mock_use_state.return_value = ({}, Mock())

@solara.component
def Test():
JupyterViz(model_class=Mock(), model_params={})

solara.render(Test(), handle_error=False)

# test for section 2.
# testing for solara.AppBar() condition
def test_name_parameter(self):
@solara.component
def Test():
return JupyterViz(model_class=Mock(__name__="TestModel"), model_params={})

with patch("solara.AppBarTitle") as mock_app_bar_title:
solara.render(Test(), handle_error=False)
mock_app_bar_title.assert_called_with("TestModel")

# testing for make_model
def test_make_model(self):
model_class = Mock()
model_params = {"mock_key": {"mock_value": 10}}

@solara.component
def Test():
return JupyterViz(model_class=model_class, model_params=model_params)

component = Test()

model_instance = component().make_model()
model_class.__new__.assert_called_with(model_class, mock_key=10, seed=0)
model_class.__init__.assert_called_with(mock_key=10)

# testing for handle_change_model_params
def test_handle_change_model_params(self):
@solara.component
def Test():
return JupyterViz(
model_class=Mock(), model_params={"mock_key": {"mock_value": 10}}
)

component = Test()
component().handle_change_model_params("mock_key", 20)
self.assertEqual(component().model_parameters["mock_key"], 20)

# test for section 3.
@patch("solara.AppBar")
@patch("solara.AppBarTitle")
def test_ui_setup(self, mock_app_bar_title, mock_app_bar):
@solara.component
def Test():
return JupyterViz(
model_class=Mock(), model_params={"mock_key": {"mock_value": 10}}
)

solara.render(Test(), handle_error=False)

mock_app_bar.asser_called()
mock_app_bar_title.assert_called_with("Mock")

@patch("solara.GridFixed")
@patch("soalra.Markdown")
def test_render_in_jupyter(self, mock_markdown, mock_grid_fixed):
@solara.component
def Test():
return JupyterViz(
model_class=Mock(), model_params={"mock_key": {"mock_value": 10}}
)

mock_grid_fixed.assert_called()
mock_markdown.assert_called()

@patch("solara.Sidebar")
@patch("solara.GridDraggable")
def test_render_in_browser(self, mock_grid_draggable, mock_sidebar):
@solara.component
def Test():
return JupyterViz(
model_class=Mock(), model_params={"mock_key": {"mock_value": 10}}
)

with patch("sys.argv", ["browser"]):
solara.render(Test(), handle_error=False)

mock_grid_draggable.assert_called()
mock_sidebar.assert_called()


if __name__ == "__main__":
# Make sure to remove this codeblock before submitting
loader = unittest.TestLoader()
suite = loader.discover("../mesa")
runner = unittest.TextTestRunner()
runner.run(suite())
# unittest.main()