Skip to content

Commit

Permalink
Update supervised.py unit tests to remove -1
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed May 16, 2024
1 parent 1309c35 commit 45a3325
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions UnitTests/test_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,24 @@ def test_norm_data_nn(self):
)

def test_unique_mapping_nn(self):
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1])
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11])
unique, valid_mapping = mm.unique_mapping_nn(pred_class)

expected_unique = np.array([0, 2, 3, 8, 11, -1])
expected_unique = np.array([0, 2, 3, 8, 11])
expected_valid_mapping = {
0: "Amphibole",
2: "Clinopyroxene",
3: "Garnet",
8: "Olivine",
11: "Spinel",
-1: "Unknown",
}

# Verify expected output
np.testing.assert_array_equal(unique, expected_unique)
self.assertEqual(valid_mapping, expected_valid_mapping)

def test_class2mineral_nn(self):
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1])
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11])
pred_mineral = mm.class2mineral_nn(pred_class)

expected_pred_mineral = np.array(
Expand All @@ -235,7 +234,6 @@ def test_class2mineral_nn(self):
"Amphibole",
"Garnet",
"Spinel",
"Unknown",
]
)

Expand All @@ -249,12 +247,6 @@ def setUp(self):
self.train_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
self.train_y = np.array([0, 0, 0, 1, 1])

def test_import_imblearn(self):
try:
from imblearn.over_sampling import RandomOverSampler
except ImportError:
self.fail("imbalanced-learn library not installed")

def test_balance_function(self):
train_x_balanced, train_y_balanced = mm.balance(self.train_x, self.train_y, n=3)
# Check the shape of the output
Expand Down

0 comments on commit 45a3325

Please sign in to comment.