Skip to content

Commit

Permalink
Update bugs in test_core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahshi committed Jan 25, 2024
1 parent 4aa9ebb commit 04a0da6
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions UnitTests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def setUp(self):
'Mineral2', 'Mineral1', 'Mineral2', 'Mineral1', 'Mineral2']
})

@patch('mm.load_minclass_nn')
@patch('mm.norm_data')
@patch('mineralML.load_minclass_nn')
@patch('mineralML.norm_data')
def test_create_dataloader(self, mock_norm_data, mock_load_minclass_nn):
# Mock return values for dependencies
mock_norm_data.return_value = self.df
Expand All @@ -141,34 +141,35 @@ def test_create_dataloader(self, mock_norm_data, mock_load_minclass_nn):
mock_load_minclass_nn.assert_called_once()


class MockNetwork(nn.Module):
def __init__(self):
super(MockNetwork, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.bn1 = nn.BatchNorm2d(20)
class test_NetworkWeights(unittest.TestCase):

def is_normal_tensor(tensor, mean, std):
return torch.all(torch.abs(tensor - mean) < 3 * std)
class MockNetwork(nn.Module):
def __init__(self):
super(test_NetworkWeights.MockNetwork, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.bn1 = nn.BatchNorm2d(20)

class test_weights(unittest.TestCase):
def is_normal_tensor(self, tensor, mean, std):
return torch.all(torch.abs(tensor - mean) < 3 * std).item()

def test_weights_init(self):
# Create a mock network
net = MockNetwork()
def setUp(self):
self.net = test_NetworkWeights.MockNetwork()

def test_weights_init(self):
# Apply the weights_init function
net.apply(mm.weights_init)
self.net.apply(mm.weights_init)

# Check if weights and biases of BatchNorm layers are initialized correctly
for module in net.modules():
for module in self.net.modules():
if isinstance(module, nn.BatchNorm2d):
# Check weights
self.assertTrue(is_normal_tensor(module.weight.data, 1.0, 0.02),
self.assertTrue(self.is_normal_tensor(module.weight.data, 1.0, 0.02),
"Weights of BatchNorm layer are not properly initialized")
# Check biases
self.assertTrue(torch.all(module.bias.data == 0),
self.assertTrue(torch.all(module.bias.data == 0).item(),
"Biases of BatchNorm layer are not initialized to 0")


class test_same_seeds(unittest.TestCase):

def test_reproducibility(self):
Expand All @@ -191,20 +192,22 @@ def test_reproducibility(self):
self.assertEqual(np_rand, np_rand_repeat, "NumPy random numbers do not match")
self.assertEqual(py_rand, py_rand_repeat, "Python random numbers do not match")



class MockModel(nn.Module):
def __init__(self):
super(MockModel, self).__init__()
self.conv = nn.Conv2d(1, 20, 5)


class test_SaveModel(unittest.TestCase):

def setUp(self):
self.model = MockModel()
self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def compare_state_dicts(self, dict1, dict2):
self.assertEqual(set(dict1.keys()), set(dict2.keys()))
for key in dict1:
self.assertTrue(torch.equal(dict1[key], dict2[key]), f"Mismatch in tensors for key: {key}")

def test_save_model_ae(self):
with TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "model_ae.pth")
Expand All @@ -217,7 +220,7 @@ def test_save_model_ae(self):
checkpoint = torch.load(filepath)
self.assertIn('params', checkpoint)
self.assertIn('optimizer', checkpoint)
self.assertDictEqual(checkpoint['params'], self.model.state_dict())
self.compare_state_dicts(checkpoint['params'], self.model.state_dict())

def test_save_model_nn(self):
best_model_state = self.model.state_dict()
Expand All @@ -232,19 +235,19 @@ def test_save_model_nn(self):
checkpoint = torch.load(filepath)
self.assertIn('params', checkpoint)
self.assertIn('optimizer', checkpoint)
self.assertDictEqual(checkpoint['params'], best_model_state)
self.compare_state_dicts(checkpoint['params'], best_model_state)


def save_checkpoint(model, optimizer, path):
check_point = {'params': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(check_point, path)

class test_LoadModel(unittest.TestCase):

def setUp(self):
self.model = MockModel()
self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def save_checkpoint(model, optimizer, path):
check_point = {'params': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(check_point, path)

def test_load_model(self):
with TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "model_checkpoint.pth")
Expand Down

0 comments on commit 04a0da6

Please sign in to comment.