Skip to content

Commit

Permalink
Batch dimension test floating point tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
smericks committed Jul 1, 2024
1 parent f046fc0 commit 0967a26
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/analysis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,13 +1273,13 @@ def test_sanity_loss(self):
# prior = proposal, so should be same as gaussian loss
loss_diff = np.abs(gaussian_loss1.loss(truth1,output1).numpy()[0] -
snpe_c_loss1.loss(truth1,output1).numpy()[0])
self.assertTrue(loss_diff < 0.0001)
self.assertTrue(loss_diff < 0.001)

# let's try adding in a batch dimension
loss_diff_array = np.abs(gaussian_loss1.loss(truth1_batched,output1_batched).numpy() -
snpe_c_loss1.loss(truth1_batched,output1_batched).numpy())
# checking that the difference in loss is never greater than 0.0001
self.assertTrue(np.sum(loss_diff_array > 0.0001) == 0)
self.assertTrue(np.sum(loss_diff_array > 0.001) == 0)

def test_ratios_loss(self):

Expand Down

0 comments on commit 0967a26

Please sign in to comment.