Skip to content

Commit

Permalink
Modify prep_df_nn - remove mineral filter, create missing columns with UserWarning
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed May 15, 2024
1 parent 423f799 commit 3ba10bd
Show file tree
Hide file tree
Showing 7 changed files with 20,409 additions and 2,812 deletions.
14,910 changes: 14,910 additions & 0 deletions DF.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mineralML_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read in the CSV and prepare data for analysis (fill in nans, limit to trained igneous minerals)"
"## Read in the CSV and prepare data for analysis (fill in nans, limit to trained igneous minerals). This CSV is an example file of minerals from LEPR. "
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions src/mineralML/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# %%
# %%

__author__ = 'Sarah Shi'
__author__ = "Sarah Shi"

import os
import copy
Expand Down
103 changes: 70 additions & 33 deletions src/mineralML/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# %%
# %%

import numpy as np
import pandas as pd
Expand All @@ -8,12 +8,21 @@
import matplotlib
from matplotlib import pyplot as plt

# %%

def pp_matrix(df_cm, annot=True, cmap="BuGn", fmt=".2f",
fz=12, lw=0.5, cbar=False, figsize=[10.5, 10.5],
show_null_values=0, pred_val_axis="x"): #, savefig = None,):

# %%


def pp_matrix(
df_cm,
annot=True,
cmap="BuGn",
fmt=".2f",
fz=12,
lw=0.5,
cbar=False,
figsize=[10.5, 10.5],
show_null_values=0,
pred_val_axis="x",
): # , savefig = None,):
"""
Creates and displays a confusion matrix visualization using Seaborn's heatmap function.
Expand Down Expand Up @@ -59,11 +68,21 @@ def pp_matrix(df_cm, annot=True, cmap="BuGn", fmt=".2f",
ax1 = fig1.gca() # Get Current Axis
ax1.cla() # clear existing plot

ax = sns.heatmap(df_cm, annot=annot, annot_kws={"size": fz}, linewidths=lw, ax=ax1, cbar=cbar, cmap=cmap, linecolor="w", fmt=fmt,)
ax = sns.heatmap(
df_cm,
annot=annot,
annot_kws={"size": fz},
linewidths=lw,
ax=ax1,
cbar=cbar,
cmap=cmap,
linecolor="w",
fmt=fmt,
)

# set ticklabels rotation
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=13, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=35, fontsize=13, va='top')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=13, ha="right")
ax.set_yticklabels(ax.get_yticklabels(), rotation=35, fontsize=13, va="top")

# Turn off all the ticks
for t in ax.xaxis.get_major_ticks():
Expand All @@ -89,7 +108,9 @@ def pp_matrix(df_cm, annot=True, cmap="BuGn", fmt=".2f",
posi += 1

# set text
txt_res = config_cell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
txt_res = config_cell_text_and_colors(
array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values
)

text_add.extend(txt_res[0])
text_del.extend(txt_res[1])
Expand All @@ -106,12 +127,11 @@ def pp_matrix(df_cm, annot=True, cmap="BuGn", fmt=".2f",
ax.set_ylabel(ylbl)
plt.tight_layout() # set layout slim

# if savefig != None:
# if savefig != None:
# plt.savefig(savefig + '.pdf')


def insert_totals(df_cm):

"""
Inserts total sums for each row and column into the confusion matrix DataFrame.
Expand All @@ -125,32 +145,35 @@ def insert_totals(df_cm):
Returns:
None: The function modifies the DataFrame in place.
Note:
If 'sum_row' or 'sum_col' already exist in the DataFrame, they will be recalculated.
"""

# Check if 'sum_row' and 'sum_col' already exist and remove them if they do
if 'sum_row' in df_cm.columns:
df_cm.drop('sum_row', axis=1, inplace=True)
if 'sum_col' in df_cm.index:
df_cm.drop('sum_col', axis=0, inplace=True)
if "sum_row" in df_cm.columns:
df_cm.drop("sum_row", axis=1, inplace=True)
if "sum_col" in df_cm.index:
df_cm.drop("sum_col", axis=0, inplace=True)

# Calculate the column sums (stored in the 'sum_col' row) and the row sums (stored in the 'sum_row' column)
sum_col = df_cm.sum(axis=0).astype(int) # sum columns
sum_lin = df_cm.sum(axis=1).astype(int) # sum rows

# Add 'sum_row' and 'sum_col' to the dataframe
df_cm['sum_row'] = sum_lin
df_cm.loc['sum_col'] = sum_col
df_cm.at['sum_col', 'sum_row'] = sum_lin.sum() # Set the bottom right cell to the grand total

# Add 'sum_row' and 'sum_col' to the dataframe
df_cm["sum_row"] = sum_lin
df_cm.loc["sum_col"] = sum_col
df_cm.at[
"sum_col", "sum_row"
] = sum_lin.sum() # Set the bottom right cell to the grand total

def config_cell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0):

def config_cell_text_and_colors(
array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0
):
"""
Configures cell text and colors for confusion matrix visualization.
Adjusts the text and background colors of cells in the confusion matrix based on their values.
Expand All @@ -173,7 +196,7 @@ def config_cell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz,
Note:
The function modifies text and background colors based on the value in each cell.
"""
"""

import matplotlib.font_manager as fm

Expand Down Expand Up @@ -208,12 +231,18 @@ def config_cell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz,

# text to DEL
text_del.append(oText)
warnings.filterwarnings("ignore", category=DeprecationWarning)

warnings.filterwarnings("ignore", category=DeprecationWarning)

# text to ADD
font_prop = fm.FontProperties(weight="bold", size=fz)
text_kwargs = dict(color="k", ha="center",va="center", gid="sum", fontproperties=font_prop,)
text_kwargs = dict(
color="k",
ha="center",
va="center",
gid="sum",
fontproperties=font_prop,
)
lis_txt = ["%d" % (cell_val), per_ok_s, "%.1f%%" % (per_err)]
lis_kwa = [text_kwargs]
dic = text_kwargs.copy()
Expand All @@ -222,9 +251,18 @@ def config_cell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz,
dic = text_kwargs.copy()
dic["color"] = "r"
lis_kwa.append(dic)
lis_pos = [(oText._x, oText._y - 0.3), (oText._x, oText._y), (oText._x, oText._y + 0.3),]
lis_pos = [
(oText._x, oText._y - 0.3),
(oText._x, oText._y),
(oText._x, oText._y + 0.3),
]
for i in range(len(lis_txt)):
newText = dict(x=lis_pos[i][0], y=lis_pos[i][1], text=lis_txt[i], kw=lis_kwa[i],)
newText = dict(
x=lis_pos[i][0],
y=lis_pos[i][1],
text=lis_txt[i],
kw=lis_kwa[i],
)
text_add.append(newText)

# set background color for sum cells (last line and last column)
Expand Down Expand Up @@ -255,4 +293,3 @@ def config_cell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz,
oText.set_color("r")

return text_add, text_del

Loading

0 comments on commit 3ba10bd

Please sign in to comment.