diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2d..486fb178689 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import sys +import os + +# Add the parent directory of the project to PYTHONPATH +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) diff --git a/tests/test_jupyter_viz.py b/tests/test_jupyter_viz.py index 76ee666ee8a..d5d4f82944d 100644 --- a/tests/test_jupyter_viz.py +++ b/tests/test_jupyter_viz.py @@ -1,11 +1,21 @@ import unittest -from unittest.mock import Mock +from unittest.mock import Mock, patch import ipyvuetify as vw import solara +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + import mesa -from mesa.visualization.jupyter_viz import JupyterViz, Slider, UserInputs +from mesa.visualization.jupyter_viz import ( + JupyterViz, + Slider, + UserInputs, + ModelController, + split_model_params, +) class TestMakeUserInput(unittest.TestCase): @@ -151,3 +161,108 @@ def test_slider(): assert not slider_int.is_float_slider slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float) assert slider_dtype_float.is_float_slider + + +class TestJupyterViz(unittest.TestCase): + # test for section 1. + # testing for correct init + @patch("solara.use_reactive") + @patch("solara.use_state") + def test_initialization(self, mock_use_state, mock_use_reactive): + mock_use_reactive.side_effect = [Mock(), Mock()] + mock_use_state.return_value = ({}, Mock()) + + @solara.component + def Test(): + JupyterViz(model_class=Mock(), model_params={}) + + solara.render(Test(), handle_error=False) + + # test for section 2. + # testing for solara.AppBar() condition + def test_name_parameter(self): + @solara.component + def Test(): + return JupyterViz(model_class=Mock(__name__="TestModel"), model_params={}) + + with patch("solara.AppBarTitle") as mock_app_bar_title: + solara.render(Test(), handle_error=False) + mock_app_bar_title.assert_called_with("TestModel") + + # testing for make_model + def test_make_model(self): + model_class = Mock() + model_params = {"mock_key": {"mock_value": 10}} + + @solara.component + def Test(): + return JupyterViz(model_class=model_class, model_params=model_params) + + component = Test() + + model_instance = component().make_model() + model_class.__new__.assert_called_with(model_class, mock_key=10, seed=0) + model_class.__init__.assert_called_with(mock_key=10) + + # testing for handle_change_model_params + def test_handle_change_model_params(self): + @solara.component + def Test(): + return JupyterViz( + model_class=Mock(), model_params={"mock_key": {"mock_value": 10}} + ) + + component = Test() + component().handle_change_model_params("mock_key", 20) + self.assertEqual(component().model_parameters["mock_key"], 20) + + # test for section 3. + @patch("solara.AppBar") + @patch("solara.AppBarTitle") + def test_ui_setup(self, mock_app_bar_title, mock_app_bar): + @solara.component + def Test(): + return JupyterViz( + model_class=Mock(), model_params={"mock_key": {"mock_value": 10}} + ) + + solara.render(Test(), handle_error=False) + + mock_app_bar.asser_called() + mock_app_bar_title.assert_called_with("Mock") + + @patch("solara.GridFixed") + @patch("soalra.Markdown") + def test_render_in_jupyter(self, mock_markdown, mock_grid_fixed): + @solara.component + def Test(): + return JupyterViz( + model_class=Mock(), model_params={"mock_key": {"mock_value": 10}} + ) + + mock_grid_fixed.assert_called() + mock_markdown.assert_called() + + @patch("solara.Sidebar") + @patch("solara.GridDraggable") + def test_render_in_browser(self, mock_grid_draggable, mock_sidebar): + @solara.component + def Test(): + return JupyterViz( + model_class=Mock(), model_params={"mock_key": {"mock_value": 10}} + ) + + with patch("sys.argv", ["browser"]): + solara.render(Test(), handle_error=False) + + mock_grid_draggable.assert_called() + mock_sidebar.assert_called() + + +if __name__ == "__main__": + # Make sure to remove this codeblock before submitting + loader = unittest.TestLoader() + suite = loader.discover("../mesa") + runner = unittest.TextTestRunner() + runner.run(suite()) + # unittest.main()