Skip to content

Commit

Permalink
Fix confusion_matrix_df and prep_nn_df functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed Jan 25, 2024
1 parent e942c2d commit adcadbb
Show file tree
Hide file tree
Showing 11 changed files with 6,148 additions and 5,961 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ Validate_AE.ipynb
Train_CovMatch.py
Train_LinearModels.py
Validate_AE.py
.ipynb_checkpoints
.ipynb_checkpoints
/ignore/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ Make sure that you keep up with the latest version of mineralML. To upgrade to t
pip install mineralML --upgrade
```

Mac/Linux installation will be straightforward. Windows installations will require the additional setup of WSL
Mac/Linux installation will be straightforward. Windows installations will require the additional setup of WSL.
121 changes: 113 additions & 8 deletions Train_Supervised_Variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@
opx_bad = opx[opx.Mineral != opx.Empirical_Mineral]
cpx_bad = cpx[cpx.Mineral != cpx.Empirical_Mineral]


# %%


ss = StandardScaler()
array_norm = ss.fit_transform(min_df_lim[oxides])

# %%

# npz = np.load('parametermatrix_neuralnetwork/best_model_data.npz')
Expand All @@ -107,6 +100,26 @@
# mm.pp_matrix(df_valid_cm, cmap = cmap, savefig = 'test', figsize = (11.5, 11.5))
# plt.show()


# %%

# Step 2: Read in your DataFrame, drop rows with NaN in specific oxide columns, fill NaNs, and filter minerals
petrelli_df_load = mm.load_df('Petrelli_cpx.csv')
petrelli_df, petrelli_df_ex = mm.prep_df_nn(petrelli_df_load)
petrelli_df_pred, petrelli_probability_matrix = mm.predict_class_prob_nn(petrelli_df)

petrelli_bayes_valid_report = classification_report(
petrelli_df_pred['Mineral'], petrelli_df_pred['Predict_Mineral'], zero_division=0
)
print("Petrelli Validation Report:\n", petrelli_bayes_valid_report)

petrelli_cm = mm.confusion_matrix_df(petrelli_df_pred['Mineral'], petrelli_df_pred['Predict_Mineral'])
print("Petrelli Confusion Matrix:\n", petrelli_cm)

petrelli_cm[petrelli_cm < len(petrelli_df_pred['Predict_Mineral'])*0.0005] = 0
mm.pp_matrix(petrelli_cm) # , savefig = 'none')


# %%

# Step 2: Read in your DataFrame, drop rows with NaN in specific oxide columns, fill NaNs, and filter minerals
Expand All @@ -123,7 +136,99 @@
print("LEPR Confusion Matrix:\n", lepr_cm)

lepr_cm[lepr_cm < len(lepr_df_pred['Predict_Mineral'])*0.0005] = 0
mm.pp_matrix(lepr_cm, savefig = 'none')
mm.pp_matrix(lepr_cm)

# %%


# Step 2: Read in your DataFrame, drop rows with NaN in specific oxide columns, fill NaNs, and filter minerals
petdb_df_load = mm.load_df('Validation_Data/PetDB_validationdata_Fe.csv')
petdb_df, petdb_df_ex = mm.prep_df_nn(petdb_df_load)
petdb_df_pred, petdb_probability_matrix = mm.predict_class_prob_nn(petdb_df)

petdb_bayes_valid_report = classification_report(
petdb_df_pred['Mineral'], petdb_df_pred['Predict_Mineral'], zero_division=0
)
print("PetDB Validation Report:\n", petdb_bayes_valid_report)

petdb_cm = mm.confusion_matrix_df(petdb_df_pred['Mineral'], petdb_df_pred['Predict_Mineral'])
print("PetDB Confusion Matrix:\n", petdb_cm)

petdb_cm[petdb_cm < len(petdb_df_pred['Predict_Mineral'])*0.0005] = 0
mm.pp_matrix(petdb_cm)

# %%

# %%

georoc_df_load = mm.load_df('Validation_Data/GEOROC_validationdata_Fe.csv')
georoc_df_load['Mineral'] = georoc_df_load['Mineral'].replace('(Al)Kalifeldspar', 'KFeldspar')
georoc_df, georoc_df_ex = mm.prep_df_nn(georoc_df_load)

georoc_df_pred, georoc_probability_matrix = mm.predict_class_prob_nn(georoc_df)


georoc_bayes_valid_report = classification_report(
georoc_df_pred['Mineral'], georoc_df_pred['Predict_Mineral'], zero_division=0
)
print("GEOROC Validation Report:\n", georoc_bayes_valid_report)

georoc_cm = mm.confusion_matrix_df(georoc_df_pred['Mineral'], georoc_df_pred['Predict_Mineral'])
print("GEOROC Confusion Matrix:\n", georoc_cm)

georoc_cm[georoc_cm < len(georoc_df_pred['Predict_Mineral'])*0.0005] = 0
mm.pp_matrix(georoc_cm, savefig = None)


# %%

cascades_df_load = mm.load_df('Validation_Data/Cascades_CpxAmp_NN.csv')
cascades_df, cascades_df_ex = mm.prep_df_nn(cascades_df_load)

cascades_df_pred, cascades_probability_matrix = mm.predict_class_prob_nn(cascades_df)


cascades_bayes_valid_report = classification_report(
cascades_df_pred['Mineral'], cascades_df_pred['Predict_Mineral'], zero_division=0
)
print("Cascades Validation Report:\n", cascades_bayes_valid_report)


# %%

cascades_df_pred.to_csv('Validation_Data/Cascades_CpxAmp_NN.csv')


# %%

def confusion_matrix_df_test(given_min, pred_min):

"""
Constructs a confusion matrix as a pandas DataFrame for easy visualization and
analysis. The function first finds the unique classes and maps them to their
corresponding mineral names. Then, it uses these mappings to construct the
confusion matrix, which compares the given and predicted classes.
Parameters:
given_class (array-like): The true class labels.
pred_class (array-like): The predicted class labels.
Returns:
cm_df (DataFrame): A DataFrame representing the confusion matrix, with rows
and columns labeled by the unique mineral names found in
the given and predicted class arrays.
"""

cm_matrix = confusion_matrix(given_min, pred_min)
unique, valid_mapping = mm.unique_mapping_nn(pred_min)
cm_df = pd.DataFrame(cm_matrix, index=valid_mapping, columns=valid_mapping)

return cm_df


# %%


# %%
Expand Down
Loading

0 comments on commit adcadbb

Please sign in to comment.