Skip to content

Commit

Permalink
Update on "Generate type match guard for torch.Size input"
Browse files Browse the repository at this point in the history
I suppose hypothetically, if the user code ends up working
polymorphically over the SizeVariable, in such a way that a tuple would
work, this type match is not necessary.  But we do not carefully test
for this.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
  • Loading branch information
ezyang committed Mar 10, 2023
2 parents aab2e4c + 10ca914 commit 3c5fa0e
Show file tree
Hide file tree
Showing 42 changed files with 1,208 additions and 386 deletions.
19 changes: 6 additions & 13 deletions aten/src/ATen/native/mps/operations/Normalization.mm
Original file line number Diff line number Diff line change
Expand Up @@ -254,20 +254,16 @@ Check if running mean exists (maybe do this check before making graph)
// Update saved mean and inverse std tensor
MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(double)epsilon
shape:@[@1]
dataType:MPSDataTypeFloat32];
dataType:input_mps_dtype];

MPSGraphTensor *varianceEps = [mpsGraph additionWithPrimaryTensor:batchVarianceTensor
secondaryTensor:epsilonTensor
name:@"varianceEps"];

MPSGraphTensor *sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps
name:@"sqrtVariance"];
float primary = 1.0f;
MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32];

scaledInverseSqrtVariance = [mpsGraph divisionWithPrimaryTensor:primaryTensor
secondaryTensor:sqrtVariance
name:nil];
scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance
name:nil];
// Update saved mean and inverse std tensor
saveMeanTensor = batchMeanTensor;
saveVarTensor = scaledInverseSqrtVariance;
Expand Down Expand Up @@ -678,13 +674,10 @@ string get_mem_string(c10::MemoryFormat memory_format) {

if(train) {
// Use save_mean and save_var
float primary = 1.0f;
MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32];
MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:MPSDataTypeFloat32];
MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:input_mps_dtype];
MPSGraphTensor *revertSaveVarTensor = saveVarTensor;
revertSaveVarTensor = [mpsGraph divisionWithPrimaryTensor: primaryTensor
secondaryTensor: revertSaveVarTensor
name: nil];
revertSaveVarTensor = [mpsGraph reciprocalWithTensor: revertSaveVarTensor
name: nil];
revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor: revertSaveVarTensor
secondaryTensor: revertSaveVarTensor
name: nil];
Expand Down
39 changes: 39 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/select_batch_4d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
* Input Texture (sampler3D)
*/
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// batch_info.x: number of texels per batch
// batch_info.y: index along batch dim to select
ivec2 batch_info;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Each invocation copies one texel: it reads the input at its own (x, y) with
 * z shifted by batch_info.y * batch_info.x, and writes that texel to the
 * output at its own invocation position.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  // Per the Block comments: x is the number of texels per batch, y is the
  // index of the batch to select.
  const int texels_per_batch = uBlock.batch_info.x;
  const int batch_index = uBlock.batch_info.y;

  // z of the input texel: start of the selected batch plus the output z.
  const uint in_z = (batch_index * texels_per_batch) + out_pos.z;

  const ivec3 in_pos = ivec3(out_pos.x, out_pos.y, in_z);
  const vec4 texel = texelFetch(uInput, in_pos, 0);
  imageStore(uOutput, out_pos, texel);
}
31 changes: 0 additions & 31 deletions aten/src/ATen/native/vulkan/glsl/select_depth.glsl

This file was deleted.

32 changes: 32 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/select_depth_3d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// depth_info.xyz: exclusive upper bounds for the invocation position; main()
//                 only writes when pos < depth_info.xyz component-wise
//                 (presumably the output texture extents -- confirm with the
//                 caller that fills this block)
// depth_info.w: index to select, NOT an extent; main() reads component
//               (w % 4) of the input texel at z = w / 4
ivec4 depth_info;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * For every invocation inside depth_info.xyz, fetches a single component from
 * the input (texel z = depth_info.w / 4, component depth_info.w % 4) at the
 * invocation's (x, y) and stores it in the first component of the output
 * texel at (x, y, 0). The remaining output components are written as 0.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Nothing to do for invocations at or beyond any of the xyz bounds.
  if (any(greaterThanEqual(pos, uBlock.depth_info.xyz))) {
    return;
  }

  // depth_info.w is split into a z-texel index and a component within it.
  const int selected = uBlock.depth_info.w;
  const ivec3 src_pos = ivec3(pos.x, pos.y, selected / 4);
  const float value = texelFetch(uInput, src_pos, 0)[selected % 4];

  imageStore(uOutput, ivec3(pos.xy, 0), vec4(value, 0, 0, 0));
}
53 changes: 53 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/select_depth_4d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
* Input Texture (sampler3D)
*/
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// depth_info.x: number of batches
// depth_info.y: number of texels per batch
// depth_info.z: index along channel dim to select
// depth_info.w: zero pad for alignment
ivec4 depth_info;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Builds one output texel whose component k holds the selected channel of
 * batch (pos.z * 4 + k). Components whose batch index is not below
 * depth_info.x are left at 0.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Field meanings follow the Block comments: batches, texels per batch,
  // and the channel index to select.
  const int num_batches = uBlock.depth_info.x;
  const int texels_per_batch = uBlock.depth_info.y;
  const int channel = uBlock.depth_info.z;

  // The selected channel is component (channel % 4) of the
  // (channel / 4)-th texel inside each batch.
  const uint src_comp = channel % 4;

  vec4 gathered = vec4(0, 0, 0, 0);
  // Stop as soon as the batch index runs past the last batch, so a partially
  // filled texel keeps zeros in its trailing components.
  for (int k = 0; k < 4 && (k + pos.z * 4) < num_batches; k++) {
    const uint src_z = (4 * texels_per_batch * pos.z) +
        (k * texels_per_batch) + (channel / 4);
    gathered[k] =
        texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0)[src_comp];
  }

  imageStore(uOutput, pos, gathered);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ layout(std430) buffer;
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
ivec3 size;
int index;
} uBlock;
// height_info.x: output texture x extent
// height_info.y: output texture y extent
// height_info.z: output texture z extent
// height_info.w: output texture w extent
ivec4 height_info;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand All @@ -21,7 +25,7 @@ void main() {
// w
const int src_x = pos.x;
// h
const int src_y = uBlock.index;
const int src_y = uBlock.height_info.w;
// c
const int src_z = pos.y;

Expand All @@ -31,7 +35,7 @@ void main() {
ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0);

// When the C-channel exceeds original block size, exit early
if (new_pos.y >= uBlock.size.y) {
if (new_pos.y >= uBlock.height_info.y) {
return;
}

Expand Down
51 changes: 51 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/select_height_4d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// height_info.x: number of batches
// height_info.y: number of texels per batch
// height_info.z: index along height dim to select
// height_info.w: zero pad for alignment
ivec4 height_info;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Builds one output texel whose component k is read from batch
 * (pos.z * 4 + k) of the input, always at input row y = height_info.z.
 * Components whose batch index is not below height_info.x are left at 0.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Field meanings follow the Block comments: batches, texels per batch,
  // and the index along the height dim to select.
  const int num_batches = uBlock.height_info.x;
  const int texels_per_batch = uBlock.height_info.y;
  const int src_y = uBlock.height_info.z;

  // pos.y is split into a z-texel offset within the batch and a component
  // within that texel.
  const int texel_in_batch = pos.y / 4;
  const int src_comp = pos.y % 4;

  vec4 gathered = vec4(0, 0, 0, 0);
  // Stop as soon as the batch index runs past the last batch, so a partially
  // filled texel keeps zeros in its trailing components.
  for (int k = 0; k < 4 && (k + pos.z * 4) < num_batches; k++) {
    const uint src_z = (pos.z * texels_per_batch * 4) +
        k * texels_per_batch + texel_in_batch;
    gathered[k] =
        texelFetch(uInput, ivec3(pos.x, src_y, src_z), 0)[src_comp];
  }

  imageStore(uOutput, pos, gathered);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@ layout(std430) buffer;
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
ivec3 size;
int index;
} uBlock;
// width_info.x: output texture x extent
// width_info.y: output texture y extent
// width_info.z: output texture z extent
// width_info.w: output texture w extent
ivec4 width_info;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

// w
const int src_x = uBlock.index;
const int src_x = uBlock.width_info.w;
// h
const int src_y = pos.x;
// c
Expand All @@ -31,7 +35,7 @@ void main() {
ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0);

// When the C-channel exceeds original block size, exit early
if (new_pos.y >= uBlock.size.y) {
if (new_pos.y >= uBlock.width_info.y) {
return;
}

Expand Down
51 changes: 51 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/select_width_4d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// width_info.x: number of batches
// width_info.y: number of texels per batch
// width_info.z: index along width dim to select
// width_info.w: zero pad for alignment
ivec4 width_info;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Builds one output texel whose component k is read from batch
 * (pos.z * 4 + k) of the input, always at input column x = width_info.z.
 * The output x coordinate supplies the input y coordinate. Components whose
 * batch index is not below width_info.x are left at 0.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Field meanings follow the Block comments: batches, texels per batch,
  // and the index along the width dim to select.
  const int num_batches = uBlock.width_info.x;
  const int texels_per_batch = uBlock.width_info.y;
  const int src_x = uBlock.width_info.z;

  // pos.y is split into a z-texel offset within the batch and a component
  // within that texel.
  const int texel_in_batch = pos.y / 4;
  const int src_comp = pos.y % 4;

  vec4 gathered = vec4(0, 0, 0, 0);
  // Stop as soon as the batch index runs past the last batch, so a partially
  // filled texel keeps zeros in its trailing components.
  for (int k = 0; k < 4 && (k + pos.z * 4) < num_batches; k++) {
    const uint src_z = (pos.z * texels_per_batch * 4) +
        k * texels_per_batch + texel_in_batch;
    gathered[k] =
        texelFetch(uInput, ivec3(src_x, pos.x, src_z), 0)[src_comp];
  }

  imageStore(uOutput, pos, gathered);
}

0 comments on commit 3c5fa0e

Please sign in to comment.