Skip to content

Commit

Permalink
Update supervised.py to coerce and warn about strings in data
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed May 28, 2024
1 parent 4bdc5a1 commit 75f4d3b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
Change Log
==========


Version 0.0.0.7
===============
Update the prep_df_nn function to coerce non-numeric strings in the oxide columns to NaN and to emit a UserWarning about the coerced values.


Version 0.0.0.6
===============
Update the prep_df_nn function to remove the mineral filter and to create missing columns (while emitting a UserWarning).
Expand Down
2 changes: 1 addition & 1 deletion src/mineralML/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module
__version__ = '0.0.0.6'
__version__ = '0.0.0.7'
27 changes: 25 additions & 2 deletions src/mineralML/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,25 @@ def prep_df_nn(df):
stacklevel=2,
)

# Convert columns to numeric, coercing errors to NaN
df[oxides] = df[oxides].apply(pd.to_numeric, errors='coerce')

# Warn if any non-numeric values were coerced to NaN
if df[oxides].isnull().any().any():
warnings.warn(
"Some non-numeric values were found in the oxides columns and have been coerced to NaN.",
UserWarning,
stacklevel=2,
)

# Drop rows with fewer than 6 non-NaN values in the oxides columns
df.dropna(subset=oxides, thresh=6, inplace=True)

# Fill remaining NaN values with 0 for oxides, keep NaN for 'Mineral'
df[oxides] = df[oxides].fillna(0)
df.loc[:, oxides] = df.loc[:, oxides].fillna(0)

# Ensure only oxides, 'Mineral', and 'SampleID' columns are kept
df = df[oxides + ["Mineral", "SampleID"]]
df = df.loc[:, oxides + ["Mineral", "SampleID"]]

# Ensure SampleID is the index
df.set_index("SampleID", inplace=True)
Expand Down Expand Up @@ -139,10 +150,22 @@ def norm_data_nn(df):
"K2O",
"Cr2O3",
]

mean, std = load_scaler("scaler_nn.npz")

# Ensure that mean and std are Series objects with indices matching the columns
if not isinstance(mean, pd.Series) or not isinstance(std, pd.Series):
raise ValueError("mean and std should be Series")

for col in oxides:
if col not in mean.index or col not in std.index:
raise ValueError(f"Missing mean or std for column: {col}")

df = df.reset_index(drop=False)
scaled_df = df[oxides].copy()

# scaled_df = df[oxides].reset_index(drop=True).copy()

if df[oxides].isnull().any().any():
df, _ = prep_df_nn(df)
else:
Expand Down

0 comments on commit 75f4d3b

Please sign in to comment.