From 3f14759fd761f8a4ee1f5987091889a8536fe5b1 Mon Sep 17 00:00:00 2001 From: samhu1 Date: Fri, 26 Apr 2024 13:17:51 -0400 Subject: [PATCH] Added tests to test for deterministic_maxpool3d --- test/test_nn.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index 008354ad721e..e823cc894565 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -52,6 +52,7 @@ from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on from torch.types import _TensorOrTensors from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.nn.functional import deterministic_max_pool3d AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() @@ -7165,6 +7166,42 @@ def test_preserves_memory_format(self): nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5) self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last)) +class TestDeterministicMaxPool3D(unittest.TestCase): + def setUp(self): + # This method will be called before each test. + self.input_tensor = torch.randn(1, 1, 4, 4, 4, dtype=torch.double) + self.kernel_size = 2 + self.stride = 2 + self.padding = 0 + self.dilation = 1 + self.ceil_mode = False + + def test_basic_functionality(self): + # Test the basic functionality of the deterministic max pooling + deterministic_output = deterministic_max_pool3d( + self.input_tensor, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, deterministic=False + ) + expected_output = torch.nn.functional.max_pool3d( + self.input_tensor, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode + ) + self.assertTrue(torch.allclose(deterministic_output, expected_output), + "The deterministic function does not match expected output.") + + def test_deterministic_output(self): + # Test that deterministic flag provides the same output on multiple runs + output1 = deterministic_max_pool3d( + self.input_tensor, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, deterministic=True + ) + output2 = deterministic_max_pool3d( + self.input_tensor, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, deterministic=True + ) + self.assertTrue(torch.allclose(output1, output2), + "Deterministic outputs are not identical on repeated runs.") + class TestAddRelu(TestCase): def test_add_relu(self):