diff --git a/src/dj_notebook/shell_plus.py b/src/dj_notebook/shell_plus.py index 361cded..228d25c 100644 --- a/src/dj_notebook/shell_plus.py +++ b/src/dj_notebook/shell_plus.py @@ -147,8 +147,11 @@ def model_graph(self, model: django_models.Model, max_nodes: int = 20) -> None: def csv_to_df(self, filepath_or_string: pathlib.Path | str) -> pd.DataFrame: """Read a CSV file into a Pandas DataFrame.""" + # Process as a Path object if isinstance(filepath_or_string, pathlib.Path): return pd.read_csv(filepath_or_string) + + # Process as a string, which we convert to a filebuffer buffer = io.StringIO(filepath_or_string) return pd.read_csv(buffer) diff --git a/tests/sample.csv b/tests/sample.csv new file mode 100644 index 0000000..1a16910 --- /dev/null +++ b/tests/sample.csv @@ -0,0 +1,3 @@ +Name,Age,Weight +A,1,100 +B,2,200 \ No newline at end of file diff --git a/tests/test_dj_notebook.py b/tests/test_dj_notebook.py index 16daacf..9a1bf84 100644 --- a/tests/test_dj_notebook.py +++ b/tests/test_dj_notebook.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import patch +import pandas import django.conf import pytest from dj_notebook import Plus, activate @@ -165,6 +166,31 @@ class to ensure it properly delegates to the assert result == "Mocked DataFrame" +def test_csv_to_df(): + """ + Tests the `csv_to_df` method of the `Plus` + class to ensure it returns a CSV. + + The test mocks this function to return "Mocked DataFrame" + and checks if the `Plus` method returns this when given a mock CSV. + """ + plus_instance = Plus(helpers={}) + csv_path = Path('tests/sample.csv') + with open(csv_path) as f: + csv_string = f.read() + + result_from_string = plus_instance.csv_to_df(csv_string) + result_from_path = plus_instance.csv_to_df(csv_path) + + + # assert results are dataframes + assert isinstance(result_from_string, pandas.DataFrame) + assert isinstance(result_from_path, pandas.DataFrame) + + # assert content is correct + assert result_from_string.at[0, 'Name'] == 'A' + assert result_from_path.at[0, 'Name'] == 'A' + def test_warning_when_debug_false(capfd): """ Test if the correct warning and message are displayed when DEBUG is False.