
Commit

Create new supervised.py unit tests
sarahshi committed May 16, 2024
1 parent c0e02e0 commit 1309c35
Showing 3 changed files with 223 additions and 41 deletions.
257 changes: 221 additions & 36 deletions UnitTests/test_supervised.py
@@ -1,11 +1,16 @@
import unittest
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

import mineralML as mm


class mineralML_supervised(unittest.TestCase):

def setUp(self):
self.data = {
"SampleID": [72065, 72066, 31890, 31891, 59237, 59238, 37643, 37644],
@@ -33,24 +38,23 @@ def setUp(self):

self.df = pd.DataFrame(self.data)


def test_load_minclass_nn(self):
# Load actual mineral classes
min_cat, mapping = mm.load_minclass_nn()

expected_mapping = {
0: 'Amphibole',
1: 'Biotite',
2: 'Clinopyroxene',
3: 'Garnet',
4: 'Ilmenite',
5: 'KFeldspar',
6: 'Magnetite',
7: 'Muscovite',
8: 'Olivine',
9: 'Orthopyroxene',
10: 'Plagioclase',
11: 'Spinel'
0: "Amphibole",
1: "Biotite",
2: "Clinopyroxene",
3: "Garnet",
4: "Ilmenite",
5: "KFeldspar",
6: "Magnetite",
7: "Muscovite",
8: "Olivine",
9: "Orthopyroxene",
10: "Plagioclase",
11: "Spinel",
}

expected_min_cat = list(expected_mapping.values())
@@ -59,7 +63,6 @@ def test_load_minclass_nn(self):
self.assertEqual(min_cat, expected_min_cat)
self.assertEqual(mapping, expected_mapping)


def test_prep_df_nn(self):
df_cleaned = mm.prep_df_nn(self.df.copy())

@@ -88,34 +91,216 @@ def test_prep_df_nn(self):
)

def test_norm_data_nn(self):

df_cleaned = mm.prep_df_nn(self.df.copy())
normalized_data = mm.norm_data_nn(df_cleaned)

# Check the shape of the output
self.assertEqual(normalized_data.shape, (8, 10))

# Expected normalized data
expected_normalized_data = np.array([
[-0.22643322, 0.08201123, 0.37556234, -0.09862359, -0.32940736,
-0.46334456, 0.40043244, 0.33010893, -0.22369372, -0.17032885],
[-0.24528221, 0.0942531 , 0.36763161, -0.04468734, -0.29926058,
-0.50116598, 0.48428333, 0.25318957, -0.26727423, -0.17443251],
[ 0.49548345, -0.22403538, -0.69746528, -0.36938356, -0.14852665,
-0.10836359, 1.52882587, -0.765992 , -0.63770857, -0.13749953],
[ 0.47412125, -0.23260468, -0.69746528, -0.4465124 , -0.32940736,
-0.12997583, 1.76121262, -0.73907022, -0.63770857, -0.14775869],
[-0.23899921, -0.22403538, 0.88312898, 0.12251504, 0.69558336,
-0.4849568 , 0.48428333, -0.8621412 , -0.63770857, -0.18674351],
[-0.21386722, -0.22403538, 0.95450554, 0.06857879, 0.69558336,
-0.34988033, 0.14887976, -0.8621412 , -0.63770857, -0.18674351],
[-0.20695592, -0.29626238, -0.90049194, -0.10509594, -0.26911379,
1.52606173, -0.78785449, -0.8621412 , -0.63770857, -0.17443251],
[-0.28989151, -0.28769307, -0.90128501, 0.28971741, -0.05808629,
1.18620932, -0.76629283, -0.8621412 , -0.63770857, -0.17648434]
])

np.testing.assert_almost_equal(normalized_data, expected_normalized_data, decimal=4)
expected_normalized_data = np.array(
[
[
-0.22643322,
0.08201123,
0.37556234,
-0.09862359,
-0.32940736,
-0.46334456,
0.40043244,
0.33010893,
-0.22369372,
-0.17032885,
],
[
-0.24528221,
0.0942531,
0.36763161,
-0.04468734,
-0.29926058,
-0.50116598,
0.48428333,
0.25318957,
-0.26727423,
-0.17443251,
],
[
0.49548345,
-0.22403538,
-0.69746528,
-0.36938356,
-0.14852665,
-0.10836359,
1.52882587,
-0.765992,
-0.63770857,
-0.13749953,
],
[
0.47412125,
-0.23260468,
-0.69746528,
-0.4465124,
-0.32940736,
-0.12997583,
1.76121262,
-0.73907022,
-0.63770857,
-0.14775869,
],
[
-0.23899921,
-0.22403538,
0.88312898,
0.12251504,
0.69558336,
-0.4849568,
0.48428333,
-0.8621412,
-0.63770857,
-0.18674351,
],
[
-0.21386722,
-0.22403538,
0.95450554,
0.06857879,
0.69558336,
-0.34988033,
0.14887976,
-0.8621412,
-0.63770857,
-0.18674351,
],
[
-0.20695592,
-0.29626238,
-0.90049194,
-0.10509594,
-0.26911379,
1.52606173,
-0.78785449,
-0.8621412,
-0.63770857,
-0.17443251,
],
[
-0.28989151,
-0.28769307,
-0.90128501,
0.28971741,
-0.05808629,
1.18620932,
-0.76629283,
-0.8621412,
-0.63770857,
-0.17648434,
],
]
)

np.testing.assert_almost_equal(
normalized_data, expected_normalized_data, decimal=4
)

def test_unique_mapping_nn(self):
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1])
unique, valid_mapping = mm.unique_mapping_nn(pred_class)

expected_unique = np.array([0, 2, 3, 8, 11, -1])
expected_valid_mapping = {
0: "Amphibole",
2: "Clinopyroxene",
3: "Garnet",
8: "Olivine",
11: "Spinel",
-1: "Unknown",
}

# Verify expected output
np.testing.assert_array_equal(unique, expected_unique)
self.assertEqual(valid_mapping, expected_valid_mapping)

def test_class2mineral_nn(self):
pred_class = np.array([0, 2, 3, 8, 11, 0, 3, 11, -1])
pred_mineral = mm.class2mineral_nn(pred_class)

expected_pred_mineral = np.array(
[
"Amphibole",
"Clinopyroxene",
"Garnet",
"Olivine",
"Spinel",
"Amphibole",
"Garnet",
"Spinel",
"Unknown",
]
)

# Verify expected output
np.testing.assert_array_equal(pred_mineral, expected_pred_mineral)

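Taken together, test_unique_mapping_nn and test_class2mineral_nn pin down the lookup behavior, including the -1 to "Unknown" fallback. Below is a minimal sketch of that lookup for readers skimming the diff (illustrative only; class2mineral_sketch and MAPPING_SKETCH are hypothetical names, not the mineralML implementation):

```python
import numpy as np

# Hypothetical reimplementation of the lookup the two tests above describe;
# the real mm.class2mineral_nn may differ internally.
MAPPING_SKETCH = {0: "Amphibole", 2: "Clinopyroxene", 3: "Garnet",
                  8: "Olivine", 11: "Spinel"}

def class2mineral_sketch(pred_class, mapping=MAPPING_SKETCH):
    # Any class id missing from the mapping (e.g. -1) falls back to "Unknown".
    return np.array([mapping.get(int(c), "Unknown") for c in pred_class])

print(class2mineral_sketch(np.array([0, 2, -1])))
# ['Amphibole' 'Clinopyroxene' 'Unknown']
```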

class mineralML_supervised_balancing(unittest.TestCase):
def setUp(self):
# Create a small, imbalanced dataset for testing
self.train_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
self.train_y = np.array([0, 0, 0, 1, 1])

def test_import_imblearn(self):
try:
from imblearn.over_sampling import RandomOverSampler
except ImportError:
self.fail("imbalanced-learn library not installed")

def test_balance_function(self):
train_x_balanced, train_y_balanced = mm.balance(self.train_x, self.train_y, n=3)
# Check the shape of the output
self.assertEqual(train_x_balanced.shape[0], train_y_balanced.shape[0])
self.assertEqual(train_x_balanced.shape[1], self.train_x.shape[1])

# Check that each class has the correct number of samples
unique, counts = np.unique(train_y_balanced, return_counts=True)
self.assertTrue((counts == 3).all())

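The balancing tests above only check output shapes, per-class counts, and that imbalanced-learn imports. As a rough illustration, here is one way per-class resampling to n rows could be done in plain NumPy (balance_sketch is a hypothetical stand-in, not mm.balance, which the import test suggests is built on imbalanced-learn instead):

```python
import numpy as np

def balance_sketch(X, y, n=3, seed=0):
    """Resample every class to exactly n rows (with replacement if needed)."""
    rng = np.random.default_rng(seed)
    xs, ys = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)
        take = rng.choice(idx, size=n, replace=len(idx) < n)
        xs.append(X[take])
        ys.append(y[take])
    return np.vstack(xs), np.concatenate(ys)

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 0, 0, 1, 1])
Xb, yb = balance_sketch(X, y, n=3)
print(np.unique(yb, return_counts=True))  # (array([0, 1]), array([3, 3]))
```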

class test_variational_layer(unittest.TestCase):
def setUp(self):
self.input_features = 5
self.output_features = 3
self.layer = mm.VariationalLayer(self.input_features, self.output_features)
self.input = torch.randn(10, self.input_features)

def test_initialization(self):
self.assertEqual(
self.layer.weight_mu.size(), (self.output_features, self.input_features)
)
self.assertEqual(
self.layer.weight_rho.size(), (self.output_features, self.input_features)
)
self.assertEqual(self.layer.bias_mu.size(), (self.output_features,))
self.assertEqual(self.layer.bias_rho.size(), (self.output_features,))

std = 1.0 / sqrt(self.input_features)
self.assertTrue(
torch.all(self.layer.weight_mu.data <= std)
and torch.all(self.layer.weight_mu.data >= -std)
)
self.assertTrue(
torch.all(self.layer.weight_rho.data <= std)
and torch.all(self.layer.weight_rho.data >= -std)
)

def test_forward_pass(self):
output = self.layer(self.input)
self.assertEqual(output.size(), (10, self.output_features))

def test_kl_divergence(self):
kl_div = self.layer.kl_divergence()
self.assertIsInstance(kl_div, torch.Tensor)
self.assertGreaterEqual(kl_div.item(), 0)

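These tests constrain mm.VariationalLayer to have (out, in)-shaped weight_mu/weight_rho, bias parameters of shape (out,), initial values within plus or minus 1/sqrt(in_features), a (batch, out) forward output, and a non-negative kl_divergence(). A sketch of a layer satisfying those constraints, assuming a standard-normal prior and a softplus parameterization of sigma (this is not the mineralML source; VariationalLinearSketch is a hypothetical name):

```python
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F

class VariationalLinearSketch(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        bound = 1.0 / sqrt(in_features)
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.weight_rho = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.bias_mu = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))
        self.bias_rho = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))

    def forward(self, x):
        # Reparameterization trick: sample weights/biases around their means.
        weight = self.weight_mu + F.softplus(self.weight_rho) * torch.randn_like(self.weight_mu)
        bias = self.bias_mu + F.softplus(self.bias_rho) * torch.randn_like(self.bias_mu)
        return F.linear(x, weight, bias)

    def kl_divergence(self):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all parameters.
        def kl(mu, rho):
            sigma = F.softplus(rho)
            return (0.5 * (sigma**2 + mu**2) - torch.log(sigma) - 0.5).sum()
        return kl(self.weight_mu, self.weight_rho) + kl(self.bias_mu, self.bias_rho)

layer = VariationalLinearSketch(5, 3)
print(layer(torch.randn(10, 5)).shape)      # torch.Size([10, 3])
print(layer.kl_divergence().item() >= 0)    # True
```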

if __name__ == "__main__":
4 changes: 2 additions & 2 deletions docs/examples/ml_models/mineralML_docs.ipynb
@@ -28,7 +28,7 @@
"\n",
"We have loaded in the mineralML Python package with trained machine learning models for classifying minerals. Examples workflows working with these spectra can be found on the [ReadTheDocs](https://mineralML.readthedocs.io/en/latest/). \n",
"\n",
"The Google Colab implementation here aims to get your electron microprobe compositions classified and processes. We remove degrees of freedom to simplify the process. The igneous minerals considered for this study include: amphibole, apatite, biotite, clinopyroxene, garnet, ilmenite, K-feldspar, magnetite, muscovite, olivine, orthopyroxene, plagioclase, quartz, rutile, spinel, tourmaline, and zircon. \n",
"The Google Colab implementation here aims to get your electron microprobe compositions classified and processes. We remove degrees of freedom to simplify the process. The minerals considered for this study include: amphibole, apatite, biotite, clinopyroxene, garnet, ilmenite, K-feldspar, magnetite, muscovite, olivine, orthopyroxene, plagioclase, quartz, rutile, spinel, tourmaline, and zircon. \n",
"\n",
"The files necessary include a CSV file containing your electron microprobe analyses in oxide weight percentages. Find an example [here](https://github.com/sarahshi/mineralML/blob/main/Validation_Data/lepr_allphases_lim.csv). The necessary oxides are $SiO_2$, $TiO_2$, $Al_2O_3$, $FeO_t$, $MnO$, $MgO$, $CaO$, $Na_2O$, $K_2O$, $Cr_2O_3$. For the oxides not analyzed for specific minerals, the preprocessing will fill in the nan values as 0. \n",
"\n",
@@ -58,7 +58,7 @@
"source": [
"\n",
"# Read in your dataframe of mineral data, called DF.csv. \n",
"# Prepare the dataframe by removing rows with too many NaNs, filling some with zeros, and filtering to the minerals described by mineralML. \n",
"# Prepare the dataframe by removing rows with too many NaNs, and filling in zeros. \n",
"\n",
"df_load = mm.load_df('lepr_valid_lim.csv')\n",
"df_nn = mm.prep_df_nn(df_load)\n"
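The notebook cell above stops after preprocessing. A hedged sketch of how the classification step might follow on from it, reusing the predict_class_prob_nn(df, n_iterations=250) signature visible in the supervised.py hunk below (what the call returns is not shown in this commit, so the result variable is deliberately left opaque):

```python
import mineralML as mm

df_load = mm.load_df("lepr_valid_lim.csv")   # CSV of oxide wt% analyses
df_nn = mm.prep_df_nn(df_load)               # drop NaN-heavy rows, fill zeros

# Assumed call pattern only; the structure of the returned object is not
# documented in this diff.
result = mm.predict_class_prob_nn(df_nn, n_iterations=250)
```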
3 changes: 0 additions & 3 deletions src/mineralML/supervised.py
@@ -494,9 +494,6 @@ def predict_class_prob_nn(df, n_iterations=250):

output_list = np.array(output_list)
probability_matrix = output_list.mean(axis=0)
# predict_class = np.argmax(probability_matrix, axis=1)
# predict_prob = np.max(probability_matrix, axis=1)
# predict_mineral = class2mineral_nn(predict_class)

top_two_indices = np.argsort(probability_matrix, axis=1)[:, -2:]
first_predict_prob = probability_matrix[
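For readers of the hunk above, a self-contained NumPy example of pulling the top-two classes and their probabilities out of an (n_samples, n_classes) probability matrix with argsort (the variable names beyond top_two_indices and first_predict_prob are illustrative, not from the source):

```python
import numpy as np

probability_matrix = np.array([[0.1, 0.7, 0.2],
                               [0.5, 0.3, 0.2]])

top_two_indices = np.argsort(probability_matrix, axis=1)[:, -2:]  # ascending, so last column is best
rows = np.arange(probability_matrix.shape[0])
first_predict_class = top_two_indices[:, -1]    # most probable class per sample
second_predict_class = top_two_indices[:, -2]   # runner-up class per sample
first_predict_prob = probability_matrix[rows, first_predict_class]
second_predict_prob = probability_matrix[rows, second_predict_class]

print(first_predict_class, first_predict_prob)    # [1 0] [0.7 0.5]
print(second_predict_class, second_predict_prob)  # [2 1] [0.2 0.3]
```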
