diff --git a/pipestat/pipestat.py b/pipestat/pipestat.py index dce08dd7..8a689781 100644 --- a/pipestat/pipestat.py +++ b/pipestat/pipestat.py @@ -682,7 +682,7 @@ def retrieve_one( :param str result_identifier: single record_identifier or list of result identifiers :return: Dict[str, any]: a mapping with filtered results reported for the record """ - r_id = record_identifier or self.record_identifier + record_identifier = record_identifier or self.record_identifier filter_conditions = [ { diff --git a/tests/test_pipestat.py b/tests/test_pipestat.py index 716fd487..c874f36c 100644 --- a/tests/test_pipestat.py +++ b/tests/test_pipestat.py @@ -589,6 +589,39 @@ def test_retrieve_basic( # Test Retrieve Whole Record assert isinstance(psm.retrieve_one(record_identifier=rec_id), Mapping) + @pytest.mark.parametrize( + ["rec_id", "val"], + [ + ("sample1", {"name_of_something": "test_name"}), + ("sample1", {"number_of_things": 2}), + ], + ) + @pytest.mark.parametrize("backend", ["file", "db"]) + def test_retrieve_basic_no_record_identifier( + self, + rec_id, + val, + config_file_path, + results_file_path, + schema_file_path, + backend, + ): + with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL): + results_file_path = f.name + args = dict(schema_path=schema_file_path, database_only=False) + backend_data = ( + {"config_file": config_file_path} + if backend == "db" + else {"results_file_path": results_file_path} + ) + args.update(backend_data) + args.update(record_identifier=rec_id) + psm = SamplePipestatManager(**args) + psm.report(record_identifier=rec_id, values=val, force_overwrite=True) + assert ( + psm.retrieve_one(result_identifier=list(val.keys())[0]) == list(val.items())[0][1] + ) + @pytest.mark.parametrize( ["rec_id", "val"], [