diff --git a/aten/src/ATen/native/vulkan/ops/Random.cpp b/aten/src/ATen/native/vulkan/ops/Random.cpp index fe4d41af6d28b..a0a16538abea4 100644 --- a/aten/src/ATen/native/vulkan/ops/Random.cpp +++ b/aten/src/ATen/native/vulkan/ops/Random.cpp @@ -14,9 +14,9 @@ using namespace api::utils; Tensor& uniform_( Tensor& self, - double from, - double to, - c10::optional /* not implemented */) { + const double from, + const double to, + const c10::optional /* not implemented */) { TORCH_CHECK( self.is_vulkan(), "Vulkan: In-place operator is only supported on Vulkan tensors."); @@ -57,10 +57,25 @@ Tensor& uniform_( return self; } +Tensor rand_like( + const at::Tensor& input_arg, + const c10::optional /* not implemented */, + const c10::optional /* not implemented */, + const c10::optional /* not implemented */, + const c10::optional /* not implemented */, + const c10::optional /* not implemented */) { + // Returns a tensor with the same size as input that is filled with random + // numbers from a uniform distribution on the interval [0,1). To match the CPU + // implementation, we simplify the range to [0,1] and tolerate the small + // chance of 1 being sampled. + return input_arg.clone().detach().uniform_(0.0, 1.0); +} + #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::uniform_"), TORCH_FN(uniform_)); + m.impl(TORCH_SELECTIVE_NAME("aten::rand_like"), TORCH_FN(rand_like)); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 7a3072d4d4f94..fa21380ac3b2d 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -3921,16 +3921,9 @@ TEST_F(VulkanAPITest, sum_dim_keepdim_4d) { test_sum_dim({9, 5, 7, 11}, {-2, -3, -4}, true); } -TEST_F(VulkanAPITest, uniform) { - float a_min = -8.2f; - float a_max = -1.4f; - - auto a_vulkan = - at::rand({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(); - a_vulkan.uniform_(a_min, a_max); +void test_uniform(at::Tensor a_vulkan, const float a_min, const float a_max) { auto a_cpu = a_vulkan.cpu(); - - ASSERT_TRUE(a_cpu.max().item() < a_max); + ASSERT_TRUE(a_cpu.max().item() <= a_max); ASSERT_TRUE(a_cpu.min().item() >= a_min); // Verify range, also perform a loose check with on histogram distribution. @@ -3956,6 +3949,26 @@ TEST_F(VulkanAPITest, uniform) { (expected_per_bin * 0.05)); } +TEST_F(VulkanAPITest, uniform) { + float a_min = -8.2f; + float a_max = -1.4f; + auto a_vulkan = + at::rand({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(); + a_vulkan.uniform_(a_min, a_max); + test_uniform(a_vulkan, a_min, a_max); +} + +TEST_F(VulkanAPITest, rand_like) { + float a_min = 0.0f; + float a_max = 1.0f; + auto a_vulkan = + at::zeros({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(); + const auto out_vulkan = at::rand_like(a_vulkan); + // verify that the input are still all zeros (not in-place) + ASSERT_TRUE(at::mean(a_vulkan.cpu()).item() == 0.0); + test_uniform(out_vulkan, a_min, a_max); +} + void test_t(const at::IntArrayRef input_shape) { const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::t(in_cpu);