Skip to content

Commit

Permalink
[pytorch-vulkan] aten::.rand_like (#108086)
Browse files Browse the repository at this point in the history
Summary:

Before implementing `aten::.randn_like` as requested (T152843033), I think it worth to extend `aten::rand_like` from existing `aten::uniform`, since they're so similar.

Test Plan:
```
[ttingchulin@6945.od /data/sandcastle/boxes/fbsource (rand_like)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin  -- --gtest_filter="*<test>*" eg.  -- --gtest_filter="*rand_like*"
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from VulkanAPITest
[ RUN      ] VulkanAPITest.rand_like
[       OK ] VulkanAPITest.rand_like (136 ms)
[----------] 1 test from VulkanAPITest (136 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (136 ms total)
[  PASSED  ] 1 test.

[ttingchulin@6945.od /data/sandcastle/boxes/fbsource (rand_like)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin  -- --gtest_filter="*<test>*" eg.  -- --gtest_filter="*uniform*"
Building: finished in 0.1 sec (100%) 329/329 jobs, 0/329 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *uniform*
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from VulkanAPITest
[ RUN      ] VulkanAPITest.uniform
[       OK ] VulkanAPITest.uniform (131 ms)
[----------] 1 test from VulkanAPITest (131 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (131 ms total)
[  PASSED  ] 1 test.

[ttingchulin@6945.od /data/sandcastle/boxes/fbsource (rand_like)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin
ALL PASS
```

Reviewed By: yipjustin

Differential Revision: D48710273
  • Loading branch information
tina134 authored and facebook-github-bot committed Sep 5, 2023
1 parent 64ad16a commit 524d037
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
21 changes: 18 additions & 3 deletions aten/src/ATen/native/vulkan/ops/Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ using namespace api::utils;

Tensor& uniform_(
Tensor& self,
double from,
double to,
c10::optional<at::Generator> /* not implemented */) {
const double from,
const double to,
const c10::optional<at::Generator> /* not implemented */) {
TORCH_CHECK(
self.is_vulkan(),
"Vulkan: In-place operator is only supported on Vulkan tensors.");
Expand Down Expand Up @@ -57,10 +57,25 @@ Tensor& uniform_(
return self;
}

Tensor rand_like(
const at::Tensor& input_arg,
const c10::optional<c10::ScalarType> /* not implemented */,
const c10::optional<c10::Layout> /* not implemented */,
const c10::optional<c10::Device> /* not implemented */,
const c10::optional<bool> /* not implemented */,
const c10::optional<c10::MemoryFormat> /* not implemented */) {
// Returns a tensor with the same size as input that is filled with random
// numbers from a uniform distribution on the interval [0,1). To match the CPU
// implementation, we simplify the range to [0,1] and tolerate the small
// chance of 1 being sampled.
return input_arg.clone().detach().uniform_(0.0, 1.0);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::uniform_"), TORCH_FN(uniform_));
m.impl(TORCH_SELECTIVE_NAME("aten::rand_like"), TORCH_FN(rand_like));
}

#endif /* USE_VULKAN_API */
Expand Down
31 changes: 22 additions & 9 deletions aten/src/ATen/test/vulkan_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3921,16 +3921,9 @@ TEST_F(VulkanAPITest, sum_dim_keepdim_4d) {
test_sum_dim({9, 5, 7, 11}, {-2, -3, -4}, true);
}

TEST_F(VulkanAPITest, uniform) {
float a_min = -8.2f;
float a_max = -1.4f;

auto a_vulkan =
at::rand({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
a_vulkan.uniform_(a_min, a_max);
void test_uniform(at::Tensor a_vulkan, const float a_min, const float a_max) {
auto a_cpu = a_vulkan.cpu();

ASSERT_TRUE(a_cpu.max().item<float>() < a_max);
ASSERT_TRUE(a_cpu.max().item<float>() <= a_max);
ASSERT_TRUE(a_cpu.min().item<float>() >= a_min);

// Verify range, also perform a loose check with on histogram distribution.
Expand All @@ -3956,6 +3949,26 @@ TEST_F(VulkanAPITest, uniform) {
(expected_per_bin * 0.05));
}

TEST_F(VulkanAPITest, uniform) {
float a_min = -8.2f;
float a_max = -1.4f;
auto a_vulkan =
at::rand({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
a_vulkan.uniform_(a_min, a_max);
test_uniform(a_vulkan, a_min, a_max);
}

TEST_F(VulkanAPITest, rand_like) {
float a_min = 0.0f;
float a_max = 1.0f;
auto a_vulkan =
at::zeros({8, 7, 12, 10}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
const auto out_vulkan = at::rand_like(a_vulkan);
// verify that the input are still all zeros (not in-place)
ASSERT_TRUE(at::mean(a_vulkan.cpu()).item<float>() == 0.0);
test_uniform(out_vulkan, a_min, a_max);
}

void test_t(const at::IntArrayRef input_shape) {
const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
const auto out_cpu = at::t(in_cpu);
Expand Down

0 comments on commit 524d037

Please sign in to comment.