Skip to content

Commit

Permalink
Update test_loading.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed Dec 10, 2023
1 parent 9f909d5 commit 437bf13
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions UnitTests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,31 @@ def test_load_df(self, mock_read_csv):
pd.testing.assert_frame_equal(df, mock_df)

@patch('numpy.load')
def test_load_scaler(self, mock_np_load):
# Set up the mock return values
@patch('os.path.dirname')
def test_load_scaler(self, mock_dirname, mock_np_load):
# Mock the current directory
mock_dirname.return_value = '/path/to/current/dir'

# Set up the mock return values for numpy.load
mean_array = np.random.rand(10) # Create an array with 10 elements
std_array = np.random.rand(10) # Create an array with 10 elements
mock_np_load.return_value = {'mean': mean_array, 'scale': std_array}

mean, std = mm.load_scaler()
# Update the function call to include scaler_path
mean, std = mm.load_scaler('scaler_ae.npz')

expected_index = ['SiO2', 'TiO2', 'Al2O3', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'Cr2O3']
self.assertTrue((mean == pd.Series(mean_array, index=expected_index)).all())
self.assertTrue((std == pd.Series(std_array, index=expected_index)).all())

# Assert numpy.load was called with the correct full file path
full_scaler_path = '/path/to/current/dir/scaler_ae.npz'
mock_np_load.assert_called_with(full_scaler_path)

# Test for FileNotFoundError
mock_np_load.side_effect = FileNotFoundError
with self.assertRaises(FileNotFoundError):
mm.load_scaler()
mm.load_scaler('non_existing_path.npz')


if __name__ == '__main__':
Expand Down

0 comments on commit 437bf13

Please sign in to comment.