Skip to content

Commit

Permalink
Added tests to test for deterministic_maxpool3d
Browse files Browse the repository at this point in the history
  • Loading branch information
samhu1 committed Apr 26, 2024
1 parent 93746fb commit 3f14759
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions test/test_nn.py
Expand Up @@ -52,6 +52,7 @@
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.nn.functional import deterministic_max_pool3d

AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

Expand Down Expand Up @@ -7165,6 +7166,42 @@ def test_preserves_memory_format(self):
nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))

class TestDeterministicMaxPool3D(unittest.TestCase):
def setUp(self):
# This method will be called before each test.
self.input_tensor = torch.randn(1, 1, 4, 4, 4, dtype=torch.double)
self.kernel_size = 2
self.stride = 2
self.padding = 0
self.dilation = 1
self.ceil_mode = False

def test_basic_functionality(self):
# Test the basic functionality of the deterministic max pooling
deterministic_output = deterministic_max_pool3d(
self.input_tensor, self.kernel_size, self.stride, self.padding,
self.dilation, self.ceil_mode, deterministic=False
)
expected_output = torch.nn.functional.max_pool3d(
self.input_tensor, self.kernel_size, self.stride, self.padding,
self.dilation, self.ceil_mode
)
self.assertTrue(torch.allclose(deterministic_output, expected_output),
"The deterministic function does not match expected output.")

def test_deterministic_output(self):
# Test that deterministic flag provides the same output on multiple runs
output1 = deterministic_max_pool3d(
self.input_tensor, self.kernel_size, self.stride, self.padding,
self.dilation, self.ceil_mode, deterministic=True
)
output2 = deterministic_max_pool3d(
self.input_tensor, self.kernel_size, self.stride, self.padding,
self.dilation, self.ceil_mode, deterministic=True
)
self.assertTrue(torch.allclose(output1, output2),
"Deterministic outputs are not identical on repeated runs.")


class TestAddRelu(TestCase):
def test_add_relu(self):
Expand Down

0 comments on commit 3f14759

Please sign in to comment.