Skip to content

Commit

Permalink
Call self.save_checkpoint, remove create_dataloader
Browse files (browse the repository at this point in the history)
create_dataloader does not appear to be used anywhere (to be confirmed), so the function and its unit test are commented out.
  • Loading branch information
sarahshi committed Jan 25, 2024
1 parent 04a0da6 commit c906d48
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 53 deletions.
62 changes: 31 additions & 31 deletions UnitTests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,38 +107,38 @@ def test_load_scaler(self, mock_dirname, mock_np_load):
mm.load_scaler('non_existing_path.npz')


class test_CreateDataLoader(unittest.TestCase):

def setUp(self):
self.df = pd.DataFrame({
'feature1': np.random.rand(10),
'feature2': np.random.rand(10),
'Mineral': ['Mineral1', 'Mineral2', 'Mineral1', 'Mineral2', 'Mineral1',
'Mineral2', 'Mineral1', 'Mineral2', 'Mineral1', 'Mineral2']
})

@patch('mineralML.load_minclass_nn')
@patch('mineralML.norm_data')
def test_create_dataloader(self, mock_norm_data, mock_load_minclass_nn):
# Mock return values for dependencies
mock_norm_data.return_value = self.df
mock_load_minclass_nn.return_value = (['Mineral1', 'Mineral2'], None)

# Create DataLoader
dataloader = mm.create_dataloader(self.df, batch_size=2, shuffle=False)

# Check if DataLoader is created and has correct properties
self.assertIsNotNone(dataloader, "DataLoader not created")
self.assertIsInstance(dataloader, torch.utils.data.DataLoader, "Returned object is not a DataLoader")
# class test_CreateDataLoader(unittest.TestCase):

# def setUp(self):
# self.df = pd.DataFrame({
# 'feature1': np.random.rand(10),
# 'feature2': np.random.rand(10),
# 'Mineral': ['Mineral1', 'Mineral2', 'Mineral1', 'Mineral2', 'Mineral1',
# 'Mineral2', 'Mineral1', 'Mineral2', 'Mineral1', 'Mineral2']
# })

# @patch('mineralML.load_minclass_nn')
# @patch('mineralML.norm_data')
# def test_create_dataloader(self, mock_norm_data, mock_load_minclass_nn):
# # Mock return values for dependencies
# mock_norm_data.return_value = self.df
# mock_load_minclass_nn.return_value = (['Mineral1', 'Mineral2'], None)

# # Create DataLoader
# dataloader = mm.create_dataloader(self.df, batch_size=2, shuffle=False)

# # Check if DataLoader is created and has correct properties
# self.assertIsNotNone(dataloader, "DataLoader not created")
# self.assertIsInstance(dataloader, torch.utils.data.DataLoader, "Returned object is not a DataLoader")

# Check DataLoader's batch size
for batch in dataloader:
self.assertEqual(len(batch), 2) # Assuming LabelDataset returns a tuple
break # We just need to check the first batch
# # Check DataLoader's batch size
# for batch in dataloader:
# self.assertEqual(len(batch), 2) # Assuming LabelDataset returns a tuple
# break # We just need to check the first batch

# Check if dependent functions are called
mock_norm_data.assert_called_once_with(self.df)
mock_load_minclass_nn.assert_called_once()
# # Check if dependent functions are called
# mock_norm_data.assert_called_once_with(self.df)
# mock_load_minclass_nn.assert_called_once()


class test_NetworkWeights(unittest.TestCase):
Expand Down Expand Up @@ -251,7 +251,7 @@ def save_checkpoint(model, optimizer, path):
def test_load_model(self):
with TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "model_checkpoint.pth")
save_checkpoint(self.model, self.optimizer, filepath)
self.save_checkpoint(self.model, self.optimizer, filepath)

# Create new model and optimizer for loading
loaded_model = MockModel()
Expand Down
44 changes: 22 additions & 22 deletions src/mineralML/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,36 +139,36 @@ def load_scaler(scaler_path):
return mean, std


def create_dataloader(df, batch_size=128, shuffle=False):
# def create_dataloader(df, batch_size=128, shuffle=False):

"""
# """

Creates a DataLoader for the given DataFrame. It normalizes the input features and converts
the 'Mineral' column to categorical codes based on predefined mineral classes. The resulting
DataLoader can be used to iterate over the dataset in batches during model training or evaluation.
The function relies on the 'load_minclass_nn' function to obtain the list of category names
for the 'Mineral' column and the 'norm_data' function to normalize the feature columns
before creating the DataLoader.
# Creates a DataLoader for the given DataFrame. It normalizes the input features and converts
# the 'Mineral' column to categorical codes based on predefined mineral classes. The resulting
# DataLoader can be used to iterate over the dataset in batches during model training or evaluation.
# The function relies on the 'load_minclass_nn' function to obtain the list of category names
# for the 'Mineral' column and the 'norm_data' function to normalize the feature columns
# before creating the DataLoader.

Parameters:
df (DataFrame): The DataFrame containing features and mineral labels to load into the DataLoader.
batch_size (int): The number of samples to load per batch. Defaults to 128.
shuffle (bool): Whether to shuffle the data before loading it. Defaults to False.
# Parameters:
# df (DataFrame): The DataFrame containing features and mineral labels to load into the DataLoader.
# batch_size (int): The number of samples to load per batch. Defaults to 128.
# shuffle (bool): Whether to shuffle the data before loading it. Defaults to False.

Returns:
dataloader (DataLoader): A PyTorch DataLoader object ready for model consumption.
# Returns:
# dataloader (DataLoader): A PyTorch DataLoader object ready for model consumption.

"""
# """

min_cat, _ = load_minclass_nn()
data_x = norm_data(df)
df['Mineral'] = pd.Categorical(df['Mineral'], categories=min_cat)
data_y = df['Mineral'].cat.codes.values
# min_cat, _ = load_minclass_nn()
# data_x = norm_data_nn(df)
# df['Mineral'] = pd.Categorical(df['Mineral'], categories=min_cat)
# data_y = df['Mineral'].cat.codes.values

label_dataset = LabelDataset(data_x, data_y)
dataloader = DataLoader(label_dataset, batch_size, shuffle)
# label_dataset = LabelDataset(data_x, data_y)
# dataloader = DataLoader(label_dataset, batch_size, shuffle)

return dataloader
# return dataloader


def weights_init(m):
Expand Down

0 comments on commit c906d48

Please sign in to comment.