From 1309c3573c8e7b3088dc5acc07c291065804502e Mon Sep 17 00:00:00 2001 From: Sarah Shi Date: Thu, 16 May 2024 11:27:24 -0400 Subject: [PATCH] Create new supervised.py unit tests --- UnitTests/test_supervised.py | 257 ++++++++++++++++--- docs/examples/ml_models/mineralML_docs.ipynb | 4 +- src/mineralML/supervised.py | 3 - 3 files changed, 223 insertions(+), 41 deletions(-) diff --git a/UnitTests/test_supervised.py b/UnitTests/test_supervised.py index c35681e..76c9f9c 100644 --- a/UnitTests/test_supervised.py +++ b/UnitTests/test_supervised.py @@ -1,11 +1,16 @@ import unittest import numpy as np import pandas as pd + +import torch +import torch.nn as nn +import torch.nn.functional as F +from math import sqrt + import mineralML as mm class mineralML_supervised(unittest.TestCase): - def setUp(self): self.data = { "SampleID": [72065, 72066, 31890, 31891, 59237, 59238, 37643, 37644], @@ -33,24 +38,23 @@ def setUp(self): self.df = pd.DataFrame(self.data) - def test_load_minclass_nn(self): # Load actual mineral classes min_cat, mapping = mm.load_minclass_nn() expected_mapping = { - 0: 'Amphibole', - 1: 'Biotite', - 2: 'Clinopyroxene', - 3: 'Garnet', - 4: 'Ilmenite', - 5: 'KFeldspar', - 6: 'Magnetite', - 7: 'Muscovite', - 8: 'Olivine', - 9: 'Orthopyroxene', - 10: 'Plagioclase', - 11: 'Spinel' + 0: "Amphibole", + 1: "Biotite", + 2: "Clinopyroxene", + 3: "Garnet", + 4: "Ilmenite", + 5: "KFeldspar", + 6: "Magnetite", + 7: "Muscovite", + 8: "Olivine", + 9: "Orthopyroxene", + 10: "Plagioclase", + 11: "Spinel", } expected_min_cat = list(expected_mapping.values()) @@ -59,7 +63,6 @@ def test_load_minclass_nn(self): self.assertEqual(min_cat, expected_min_cat) self.assertEqual(mapping, expected_mapping) - def test_prep_df_nn(self): df_cleaned = mm.prep_df_nn(self.df.copy()) @@ -88,7 +91,6 @@ def test_prep_df_nn(self): ) def test_norm_data_nn(self): - df_cleaned = mm.prep_df_nn(self.df.copy()) normalized_data = mm.norm_data_nn(df_cleaned) @@ -96,26 +98,209 @@ def test_norm_data_nn(self): self.assertEqual(normalized_data.shape, (8, 10)) # Expected normalized data - expected_normalized_data = np.array([ - [-0.22643322, 0.08201123, 0.37556234, -0.09862359, -0.32940736, - -0.46334456, 0.40043244, 0.33010893, -0.22369372, -0.17032885], - [-0.24528221, 0.0942531 , 0.36763161, -0.04468734, -0.29926058, - -0.50116598, 0.48428333, 0.25318957, -0.26727423, -0.17443251], - [ 0.49548345, -0.22403538, -0.69746528, -0.36938356, -0.14852665, - -0.10836359, 1.52882587, -0.765992 , -0.63770857, -0.13749953], - [ 0.47412125, -0.23260468, -0.69746528, -0.4465124 , -0.32940736, - -0.12997583, 1.76121262, -0.73907022, -0.63770857, -0.14775869], - [-0.23899921, -0.22403538, 0.88312898, 0.12251504, 0.69558336, - -0.4849568 , 0.48428333, -0.8621412 , -0.63770857, -0.18674351], - [-0.21386722, -0.22403538, 0.95450554, 0.06857879, 0.69558336, - -0.34988033, 0.14887976, -0.8621412 , -0.63770857, -0.18674351], - [-0.20695592, -0.29626238, -0.90049194, -0.10509594, -0.26911379, - 1.52606173, -0.78785449, -0.8621412 , -0.63770857, -0.17443251], - [-0.28989151, -0.28769307, -0.90128501, 0.28971741, -0.05808629, - 1.18620932, -0.76629283, -0.8621412 , -0.63770857, -0.17648434] - ]) - - np.testing.assert_almost_equal(normalized_data, expected_normalized_data, decimal=4) + expected_normalized_data = np.array( + [ + [ + -0.22643322, + 0.08201123, + 0.37556234, + -0.09862359, + -0.32940736, + -0.46334456, + 0.40043244, + 0.33010893, + -0.22369372, + -0.17032885, + ], + [ + -0.24528221, + 0.0942531, + 0.36763161, + -0.04468734, + -0.29926058, + -0.50116598, + 0.48428333, + 0.25318957, + -0.26727423, + -0.17443251, + ], + [ + 0.49548345, + -0.22403538, + -0.69746528, + -0.36938356, + -0.14852665, + -0.10836359, + 1.52882587, + -0.765992, + -0.63770857, + -0.13749953, + ], + [ + 0.47412125, + -0.23260468, + -0.69746528, + -0.4465124, + -0.32940736, + -0.12997583, + 1.76121262, + -0.73907022, + -0.63770857, + -0.14775869, + ], + [ + -0.23899921, + -0.22403538, + 0.88312898, + 0.12251504, + 0.69558336, + -0.4849568, + 0.48428333, + -0.8621412, + -0.63770857, + -0.18674351, + ], + [ + -0.21386722, + -0.22403538, + 0.95450554, + 0.06857879, + 0.69558336, + -0.34988033, + 0.14887976, + -0.8621412, + -0.63770857, + -0.18674351, + ], + [ + -0.20695592, + -0.29626238, + -0.90049194, + -0.10509594, + -0.26911379, + 1.52606173, + -0.78785449, + -0.8621412, + -0.63770857, + -0.17443251, + ], + [ + -0.28989151, + -0.28769307, + -0.90128501, + 0.28971741, + -0.05808629, + 1.18620932, + -0.76629283, + -0.8621412, + -0.63770857, + -0.17648434, + ], + ] + ) + + np.testing.assert_almost_equal( + normalized_data, expected_normalized_data, decimal=4 + ) + + def test_unique_mapping_nn(self): + pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1]) + unique, valid_mapping = mm.unique_mapping_nn(pred_class) + + expected_unique = np.array([0, 2, 3, 8, 11, -1]) + expected_valid_mapping = { + 0: "Amphibole", + 2: "Clinopyroxene", + 3: "Garnet", + 8: "Olivine", + 11: "Spinel", + -1: "Unknown", + } + + # Verify expected output + np.testing.assert_array_equal(unique, expected_unique) + self.assertEqual(valid_mapping, expected_valid_mapping) + + def test_class2mineral_nn(self): + pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1]) + pred_mineral = mm.class2mineral_nn(pred_class) + + expected_pred_mineral = np.array( + [ + "Amphibole", + "Clinopyroxene", + "Garnet", + "Olivine", + "Spinel", + "Amphibole", + "Garnet", + "Spinel", + "Unknown", + ] + ) + + # Verify expected output + np.testing.assert_array_equal(pred_mineral, expected_pred_mineral) + + +class mineralML_supervised_balancing(unittest.TestCase): + def setUp(self): + # Create a small, imbalanced dataset for testing + self.train_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + self.train_y = np.array([0, 0, 0, 1, 1]) + + def test_import_imblearn(self): + try: + from imblearn.over_sampling import RandomOverSampler + except ImportError: + self.fail("imbalanced-learn library not installed") + + def test_balance_function(self): + train_x_balanced, train_y_balanced = mm.balance(self.train_x, self.train_y, n=3) + # Check the shape of the output + self.assertEqual(train_x_balanced.shape[0], train_y_balanced.shape[0]) + self.assertEqual(train_x_balanced.shape[1], self.train_x.shape[1]) + + # Check that each class has the correct number of samples + unique, counts = np.unique(train_y_balanced, return_counts=True) + self.assertTrue((counts == 3).all()) + + +class test_variational_layer(unittest.TestCase): + def setUp(self): + self.input_features = 5 + self.output_features = 3 + self.layer = mm.VariationalLayer(self.input_features, self.output_features) + self.input = torch.randn(10, self.input_features) + + def test_initialization(self): + self.assertEqual( + self.layer.weight_mu.size(), (self.output_features, self.input_features) + ) + self.assertEqual( + self.layer.weight_rho.size(), (self.output_features, self.input_features) + ) + self.assertEqual(self.layer.bias_mu.size(), (self.output_features,)) + self.assertEqual(self.layer.bias_rho.size(), (self.output_features,)) + + std = 1.0 / sqrt(self.input_features) + self.assertTrue( + torch.all(self.layer.weight_mu.data <= std) + and torch.all(self.layer.weight_mu.data >= -std) + ) + self.assertTrue( + torch.all(self.layer.weight_rho.data <= std) + and torch.all(self.layer.weight_rho.data >= -std) + ) + + def test_forward_pass(self): + output = self.layer(self.input) + self.assertEqual(output.size(), (10, self.output_features)) + + def test_kl_divergence(self): + kl_div = self.layer.kl_divergence() + self.assertIsInstance(kl_div, torch.Tensor) + self.assertGreaterEqual(kl_div.item(), 0) if __name__ == "__main__": diff --git a/docs/examples/ml_models/mineralML_docs.ipynb b/docs/examples/ml_models/mineralML_docs.ipynb index 71ccad9..c5d499d 100644 --- a/docs/examples/ml_models/mineralML_docs.ipynb +++ b/docs/examples/ml_models/mineralML_docs.ipynb @@ -28,7 +28,7 @@ "\n", "We have loaded in the mineralML Python package with trained machine learning models for classifying minerals. Examples workflows working with these spectra can be found on the [ReadTheDocs](https://mineralML.readthedocs.io/en/latest/). \n", "\n", - "The Google Colab implementation here aims to get your electron microprobe compositions classified and processes. We remove degrees of freedom to simplify the process. The igneous minerals considered for this study include: amphibole, apatite, biotite, clinopyroxene, garnet, ilmenite, K-feldspar, magnetite, muscovite, olivine, orthopyroxene, plagioclase, quartz, rutile, spinel, tourmaline, and zircon. \n", + "The Google Colab implementation here aims to get your electron microprobe compositions classified and processes. We remove degrees of freedom to simplify the process. The minerals considered for this study include: amphibole, apatite, biotite, clinopyroxene, garnet, ilmenite, K-feldspar, magnetite, muscovite, olivine, orthopyroxene, plagioclase, quartz, rutile, spinel, tourmaline, and zircon. \n", "\n", "The files necessary include a CSV file containing your electron microprobe analyses in oxide weight percentages. Find an example [here](https://github.com/sarahshi/mineralML/blob/main/Validation_Data/lepr_allphases_lim.csv). The necessary oxides are $SiO_2$, $TiO_2$, $Al_2O_3$, $FeO_t$, $MnO$, $MgO$, $CaO$, $Na_2O$, $K_2O$, $Cr_2O_3$. For the oxides not analyzed for specific minerals, the preprocessing will fill in the nan values as 0. \n", "\n", @@ -58,7 +58,7 @@ "source": [ "\n", "# Read in your dataframe of mineral data, called DF.csv. \n", - "# Prepare the dataframe by removing rows with too many NaNs, filling some with zeros, and filtering to the minerals described by mineralML. \n", + "# Prepare the dataframe by removing rows with too many NaNs, and filling in zeros. \n", "\n", "df_load = mm.load_df('lepr_valid_lim.csv')\n", "df_nn = mm.prep_df_nn(df_load)\n" diff --git a/src/mineralML/supervised.py b/src/mineralML/supervised.py index 934534e..45b4312 100644 --- a/src/mineralML/supervised.py +++ b/src/mineralML/supervised.py @@ -494,9 +494,6 @@ def predict_class_prob_nn(df, n_iterations=250): output_list = np.array(output_list) probability_matrix = output_list.mean(axis=0) - # predict_class = np.argmax(probability_matrix, axis=1) - # predict_prob = np.max(probability_matrix, axis=1) - # predict_mineral = class2mineral_nn(predict_class) top_two_indices = np.argsort(probability_matrix, axis=1)[:, -2:] first_predict_prob = probability_matrix[