Skip to content

Commit

Permalink
Update scalers for release
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed Dec 10, 2023
1 parent c249b75 commit 9f909d5
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/mineralML/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module
__version__ = '0.0.0.2'
__version__ = '0.0.0.3'
8 changes: 4 additions & 4 deletions src/mineralML/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,25 +102,25 @@ def load_df(filepath):
return df


def load_scaler():
def load_scaler(scaler_path):

"""
Loads a pre-fitted scaler's mean and std from a .npz file. This scaler is a StandardScaler
for normalizing or standardizing input data before passing it to a machine learning model.
Returns:
mean, std (pandas Series): The mean and std from the scaler object 'scaler.npz'.
    mean, std (pandas Series): The mean and std from the scaler object ('scaler_ae.npz' or 'scaler_nn.npz').
Raises:
FileNotFoundError: If 'scaler.npz' is not found in the expected directory.
    FileNotFoundError: If the given scaler file ('scaler_ae.npz' or 'scaler_nn.npz') is not found in the expected directory.
Exception: Propagates any exception raised during the scaler loading process.
"""

# Define the path to the scaler relative to this file's location.
current_dir = os.path.dirname(__file__)
scaler_path = os.path.join(current_dir, 'scaler.npz') # Note the .joblib extension
scaler_path = os.path.join(current_dir, scaler_path) # Resolve the given .npz filename relative to this module

oxides = ['SiO2', 'TiO2', 'Al2O3', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'Cr2O3']

Expand Down
Binary file added src/mineralML/scaler_ae.npz
Binary file not shown.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/mineralML/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def norm_data_nn(df):
"""

oxides = ['SiO2', 'TiO2', 'Al2O3', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'Cr2O3']
mean, std = load_scaler()
mean, std = load_scaler('scaler_nn.npz')

scaled_df = df[oxides].copy()

Expand Down
6 changes: 3 additions & 3 deletions src/mineralML/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def norm_data_ae(df):
"""

oxides = ['SiO2', 'TiO2', 'Al2O3', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'Cr2O3']
mean, std = load_scaler()
mean, std = load_scaler('scaler_ae.npz')

if df[oxides].isnull().any().any():
df, _ = prep_df_ae(df)
Expand Down Expand Up @@ -543,8 +543,8 @@ def get_latent_space(df):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

current_dir = os.path.dirname(__file__)
# model_path = os.path.join(current_dir, 'ae_best_model.pt')
model_path = os.path.join(current_dir, 'ae_best_model_noP_tanh.pt')
model_path = os.path.join(current_dir, 'ae_best_model.pt')
# model_path = os.path.join(current_dir, 'ae_best_model_noP_tanh.pt')
model = Tanh_Autoencoder(input_dim=10, hidden_layer_sizes=(256, 64, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
load_model(model, optimizer, model_path)
Expand Down

0 comments on commit 9f909d5

Please sign in to comment.