-
Notifications
You must be signed in to change notification settings - Fork 22.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pytorch-vulkan] add aten::randn_like & aten::normal_
Summary: Implemented `aten::normal_` shader and used it to create `aten::randn_like`. Ops defintions: https://pytorch.org/docs/stable/generated/torch.randn_like.html https://pytorch.org/docs/stable/generated/torch.Tensor.normal_.html Test Plan: ``` [ttingchulin@53491.od /data/sandcastle/boxes/fbsource (randn)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*<test>*" eg. -- --gtest_filter="*randn_like*" [==========] Running 2 tests from 1 test suite. [----------] Global test environment set-up. [----------] 2 tests from VulkanAPITest [ RUN ] VulkanAPITest.randn_like [ OK ] VulkanAPITest.randn_like (230 ms) [ RUN ] VulkanAPITest.randn_like_large [ OK ] VulkanAPITest.randn_like_large (570 ms) [----------] 2 tests from VulkanAPITest (801 ms total) [----------] Global test environment tear-down [==========] 2 tests from 1 test suite ran. (801 ms total) [ PASSED ] 2 tests. [ttingchulin@53491.od /data/sandcastle/boxes/fbsource (randn)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*<test>*" eg. -- --gtest_filter="*normal_*" [==========] Running 3 tests from 1 test suite. [----------] Global test environment set-up. [----------] 3 tests from VulkanAPITest [ RUN ] VulkanAPITest.normal_ [ OK ] VulkanAPITest.normal_ (222 ms) [ RUN ] VulkanAPITest.normal_large [ OK ] VulkanAPITest.normal_large (136 ms) [ RUN ] VulkanAPITest.normal_error [ OK ] VulkanAPITest.normal_error (37 ms) [----------] 3 tests from VulkanAPITest (396 ms total) [----------] Global test environment tear-down [==========] 3 tests f. ``` Reviewed By: yipjustin Differential Revision: D48814024
- Loading branch information
1 parent
405f014
commit ba6202d
Showing
4 changed files
with
186 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
#include "random.h" | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
|
||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput; | ||
layout(set = 0, binding = 1) uniform PRECISION restrict Block { | ||
ivec3 size; | ||
float mean; | ||
float std; | ||
} uBlock; | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
void main() { | ||
ivec3 pos = ivec3(gl_GlobalInvocationID); | ||
|
||
if (all(lessThan(pos, uBlock.size))) { | ||
vec4 v = vec4( | ||
get_gaussrand(ivec4(pos, -20), uBlock.mean, uBlock.std), | ||
get_gaussrand(ivec4(pos, 40), uBlock.mean, uBlock.std), | ||
get_gaussrand(ivec4(pos, -30), uBlock.mean, uBlock.std), | ||
get_gaussrand(ivec4(pos, 15), uBlock.mean, uBlock.std)); | ||
imageStore(uOutput, pos, v); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters