diff --git a/UnitTests/test_loading.py b/UnitTests/test_loading.py index 1265aa7..d2e8035 100644 --- a/UnitTests/test_loading.py +++ b/UnitTests/test_loading.py @@ -76,22 +76,31 @@ def test_load_df(self, mock_read_csv): pd.testing.assert_frame_equal(df, mock_df) @patch('numpy.load') - def test_load_scaler(self, mock_np_load): - # Set up the mock return values + @patch('os.path.dirname') + def test_load_scaler(self, mock_dirname, mock_np_load): + # Mock the current directory + mock_dirname.return_value = '/path/to/current/dir' + + # Set up the mock return values for numpy.load mean_array = np.random.rand(10) # Create an array with 10 elements std_array = np.random.rand(10) # Create an array with 10 elements mock_np_load.return_value = {'mean': mean_array, 'scale': std_array} - mean, std = mm.load_scaler() + # Update the function call to include scaler_path + mean, std = mm.load_scaler('scaler_ae.npz') expected_index = ['SiO2', 'TiO2', 'Al2O3', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'Cr2O3'] self.assertTrue((mean == pd.Series(mean_array, index=expected_index)).all()) self.assertTrue((std == pd.Series(std_array, index=expected_index)).all()) + # Assert numpy.load was called with the correct full file path + full_scaler_path = '/path/to/current/dir/scaler_ae.npz' + mock_np_load.assert_called_with(full_scaler_path) + # Test for FileNotFoundError mock_np_load.side_effect = FileNotFoundError with self.assertRaises(FileNotFoundError): - mm.load_scaler() + mm.load_scaler('non_existing_path.npz') if __name__ == '__main__':