Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_convolution_backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(

constexpr auto name = "convolution_backward.out";

ET_SWITCH_FLOATH_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
conv2d_backward_impl<CTYPE>(
grad_output,
input,
Expand Down
1 change: 1 addition & 0 deletions kernels/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ set(all_test_sources
"op_clamp_test.cpp"
"op_clone_test.cpp"
"op_constant_pad_nd_test.cpp"
"op_convolution_backward_test.cpp"
"op_convolution_test.cpp"
"op_copy_test.cpp"
"op_cos_test.cpp"
Expand Down
274 changes: 145 additions & 129 deletions kernels/test/op_convolution_backward_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,139 +62,155 @@ class OpConvolutionBackwardOutTest : public OperatorTest {
grad_weight,
grad_bias);
}
};

TEST_F(OpConvolutionBackwardOutTest, SmokeTest) {
TensorFactory<ScalarType::Float> tf;
template <ScalarType DTYPE>
void test_dtype() {
TensorFactory<DTYPE> tf;

std::vector<float> grad_output_data = {
10, 12, 87, 13, 34, 87, 55, 22, 48, 33, 29, 38, 60, 49, 88, 30,
99, 19, 42, 37, 61, 31, 33, 58, 38, 23, 2, 33, 3, 21, 32, 2,
30, 72, 10, 67, 92, 19, 11, 16, 65, 37, 60, 74, 4, 19, 45, 37};
std::vector<float> input_data = {
9, 89, 45, 39, 25, 2, 97, 55, 80, 24, 18, 33, 28, 89, 19, 16, 19, 33,
69, 61, 34, 84, 58, 30, 33, 18, 75, 30, 6, 33, 42, 10, 80, 41, 66, 64,
47, 51, 67, 62, 58, 10, 97, 71, 24, 44, 84, 34, 33, 54, 8, 73, 90, 15,
21, 92, 55, 22, 56, 12, 10, 63, 32, 76, 65, 38, 95, 92, 22, 15, 37, 12,
67, 14, 60, 44, 73, 74, 23, 4, 56, 64, 88, 90, 82, 32, 91, 3, 6, 87,
55, 95, 7, 14, 24, 69, 52, 44, 14, 37, 75, 52, 37, 40, 25, 54, 4, 15,
97, 51, 46, 28, 65, 95, 50, 82, 23, 39, 50, 55, 97, 52, 91, 16, 19, 49,
61, 50, 42, 47, 87, 99, 9, 60, 22, 71, 47, 17, 0, 80, 28, 88, 93, 43,
65, 25, 88, 67, 21, 89, 24, 81, 3, 71, 20, 34, 17, 17, 94, 10, 82, 25,
10, 11, 7, 28, 77, 39, 74, 79, 17, 40, 67, 54, 49, 54, 21, 89, 17, 7,
52, 64, 68, 80, 7, 72, 44, 35, 92, 47, 4, 13, 10, 43, 64, 66, 83, 49,
81, 78, 58, 22, 86, 48, 35, 64, 98, 79, 8, 52, 56, 23, 38, 74, 16, 63,
51, 70, 44, 28, 43, 13, 51, 85, 42, 29, 64, 26, 54, 91, 9, 96, 41, 56,
7, 52, 27, 22, 69, 13, 8, 20, 22, 49, 66, 98, 77, 42, 54, 38, 70, 83,
13, 8, 21, 56, 78, 37, 28, 69, 42, 30, 91, 5, 28, 15, 20, 14, 16, 39,
95, 66, 4, 72, 52, 35, 54, 93, 87, 77, 3, 49, 82, 70, 84, 3, 73, 99,
32, 95, 58, 65, 32, 75, 34, 22, 12, 84, 63, 72, 85, 66, 63, 27, 3, 73,
45, 37, 61, 52, 41, 16, 37, 14, 80, 17, 48, 8, 87, 98, 69, 63, 92, 68,
42, 63, 5, 22, 66, 91, 74, 11, 17, 45, 45, 33, 40, 85, 26, 75, 73, 81,
54, 27, 80, 1, 44, 66, 10, 21, 15, 10, 76, 96, 0, 43, 39, 3, 57, 79,
45, 64, 58, 92, 44, 42, 7, 28, 94, 4, 8, 22, 22, 31, 75, 44, 3, 70,
83, 72, 87, 12, 20, 55, 84, 31, 50, 34, 25, 49, 29, 71, 57, 97, 25, 82,
84, 42, 86, 41, 54, 92, 34, 30, 52, 34, 84, 25, 54, 37, 38, 26, 76, 82,
34, 14, 85, 28, 93, 9};
std::vector<float> weight_data = {
2, 54, 9, 37, 0, 47, 70, 9, 84, 69, 56, 79, 25, 35, 54, 13,
65, 46, 38, 28, 74, 27, 66, 61, 20, 60, 62, 58, 15, 44, 75, 55,
7, 52, 13, 36, 39, 64, 62, 45, 100, 6, 79, 63, 63, 52, 37, 60,
78, 12, 69, 2, 74, 56, 93, 39, 62, 22, 55, 67, 68, 74, 12, 69,
15, 73, 28, 70, 86, 20, 90, 49, 52, 26, 58, 2, 82, 17, 70, 55,
54, 83, 70, 11, 27, 9, 5, 42, 34, 62, 29, 94, 69, 81, 54, 4};
std::vector<float> expected_grad_input_data = {
1134, 7578, 686, 2682, 0, 4148, 7136, 2406, 8698, 0,
3759, 6003, 2163, 2395, 0, 2929, 5830, 3469, 6955, 0,
720, 6201, 495, 2063, 0, 5260, 5989, 3060, 7079, 0,
9690, 3423, 3385, 1932, 0, 7644, 8499, 1323, 2613, 0,
4334, 6624, 8532, 9719, 0, 5496, 8601, 1157, 2215, 0,
4676, 7600, 6524, 10069, 0, 4047, 6117, 1612, 2567, 0,
5931, 5651, 5669, 6623, 0, 7674, 3291, 2748, 1654, 0,
10455, 4290, 4145, 796, 0, 9835, 5483, 11649, 5952, 0,
7098, 5460, 3101, 2443, 0, 7788, 5909, 8582, 6298, 0,
9462, 4845, 3041, 2067, 0, 7038, 6336, 10438, 6377, 0,
7518, 8187, 2079, 2773, 0, 10036, 2642, 3952, 1166, 0,
16014, 2250, 10025, 1908, 0, 9610, 298, 3868, 122, 0,
16629, 4338, 11335, 3527, 0, 11514, 5965, 4762, 2207, 0,
18552, 10755, 13309, 5996, 0, 12454, 6787, 4960, 2875, 0,
8750, 6999, 3534, 3233, 0, 14160, 9399, 9595, 8922, 0,
9110, 6567, 3820, 2351, 0, 12969, 11814, 9436, 5870, 0,
7631, 7061, 2877, 2499, 0, 8553, 13527, 3631, 6863, 0,
1361, 8634, 515, 3372, 0, 3394, 10206, 1504, 4112, 0,
5505, 17421, 4702, 11891, 0, 4233, 11894, 1739, 5014, 0,
11787, 14634, 8981, 10759, 0, 11777, 6701, 4719, 3111, 0,
18459, 7761, 12044, 7627, 0, 11214, 4556, 4374, 1594, 0,
604, 1908, 1506, 6102, 0, 2532, 4024, 1713, 6121, 0,
1878, 1814, 4761, 5397, 0, 1127, 3885, 4373, 5832, 0,
450, 1414, 1080, 4719, 0, 5210, 2683, 2765, 4252, 0,
2390, 1668, 7710, 4257, 0, 378, 1698, 3276, 6021, 0,
2866, 4881, 3547, 6822, 0, 502, 1238, 2784, 5199, 0,
2496, 3975, 2700, 5004, 0, 1220, 1990, 3633, 5763, 0,
4501, 2679, 4504, 5412, 0, 1968, 1376, 6246, 3669, 0,
3130, 272, 9345, 1950, 0, 5167, 3278, 9097, 2138, 0,
2446, 1946, 6942, 5460, 0, 5732, 3404, 7919, 5534, 0,
2038, 1614, 6978, 4635, 0, 4544, 4839, 7367, 5574, 0,
1242, 1922, 4842, 6333, 0, 1066, 236, 2236, 686, 0,
17238, 2254, 10413, 1592, 0, 991, 30, 2206, 70, 0,
18823, 6392, 12173, 2470, 0, 1142, 684, 2742, 1219, 0,
21256, 11293, 12719, 7512, 0, 1303, 649, 2818, 1669, 0,
898, 574, 2018, 1929, 0, 15720, 11989, 10517, 5972, 0,
885, 781, 2210, 1281, 0, 14601, 12198, 7915, 4958, 0,
856, 850, 1601, 1355, 0, 7039, 14083, 4113, 7490, 0,
152, 927, 287, 1902, 0, 301, 1051, 886, 2346, 0,
6821, 19615, 4491, 13281, 0, 424, 1146, 999, 2906, 0,
15177, 15480, 8849, 12442, 0, 1222, 544, 2687, 1859, 0,
20215, 9693, 11441, 4964, 0, 1206, 555, 2466, 860, 0};
std::vector<float> expected_grad_weight_data = {
9246, 22073, 12431, 19714, 11179, 19032, 8458, 6495, 18707, 13830,
20445, 17089, 17124, 18710, 11827, 17236, 16824, 9008, 14086, 18834,
17419, 16759, 13152, 9339, 13801, 20888, 13976, 27277, 13010, 23949,
9838, 11220, 17658, 15019, 25337, 17583, 13270, 21754, 16908, 20563,
20732, 13413, 20868, 27521, 19537, 21170, 15888, 10034, 19195, 16370,
40243, 25890, 40472, 30460, 21228, 21625, 13289, 24435, 19876, 29816,
24188, 23619, 13752, 16251, 18741, 19368, 24517, 34261, 27054, 31257,
21238, 18909, 15776, 16881, 34604, 22534, 28101, 23834, 18479, 16469,
12852, 16551, 14204, 29983, 20167, 24150, 14281, 17501, 15897, 16019,
21661, 32765, 23874, 26527, 20463, 18661};
std::vector<float> expected_grad_bias_data = {363, 438, 585, 501};
using CTYPE = typename decltype(tf)::ctype;
std::vector<CTYPE> grad_output_data = {
10, 12, 87, 13, 34, 87, 55, 22, 48, 33, 29, 38, 60, 49, 88, 30,
99, 19, 42, 37, 61, 31, 33, 58, 38, 23, 2, 33, 3, 21, 32, 2,
30, 72, 10, 67, 92, 19, 11, 16, 65, 37, 60, 74, 4, 19, 45, 37};
std::vector<CTYPE> input_data = {
9, 89, 45, 39, 25, 2, 97, 55, 80, 24, 18, 33, 28, 89, 19, 16, 19, 33,
69, 61, 34, 84, 58, 30, 33, 18, 75, 30, 6, 33, 42, 10, 80, 41, 66, 64,
47, 51, 67, 62, 58, 10, 97, 71, 24, 44, 84, 34, 33, 54, 8, 73, 90, 15,
21, 92, 55, 22, 56, 12, 10, 63, 32, 76, 65, 38, 95, 92, 22, 15, 37, 12,
67, 14, 60, 44, 73, 74, 23, 4, 56, 64, 88, 90, 82, 32, 91, 3, 6, 87,
55, 95, 7, 14, 24, 69, 52, 44, 14, 37, 75, 52, 37, 40, 25, 54, 4, 15,
97, 51, 46, 28, 65, 95, 50, 82, 23, 39, 50, 55, 97, 52, 91, 16, 19, 49,
61, 50, 42, 47, 87, 99, 9, 60, 22, 71, 47, 17, 0, 80, 28, 88, 93, 43,
65, 25, 88, 67, 21, 89, 24, 81, 3, 71, 20, 34, 17, 17, 94, 10, 82, 25,
10, 11, 7, 28, 77, 39, 74, 79, 17, 40, 67, 54, 49, 54, 21, 89, 17, 7,
52, 64, 68, 80, 7, 72, 44, 35, 92, 47, 4, 13, 10, 43, 64, 66, 83, 49,
81, 78, 58, 22, 86, 48, 35, 64, 98, 79, 8, 52, 56, 23, 38, 74, 16, 63,
51, 70, 44, 28, 43, 13, 51, 85, 42, 29, 64, 26, 54, 91, 9, 96, 41, 56,
7, 52, 27, 22, 69, 13, 8, 20, 22, 49, 66, 98, 77, 42, 54, 38, 70, 83,
13, 8, 21, 56, 78, 37, 28, 69, 42, 30, 91, 5, 28, 15, 20, 14, 16, 39,
95, 66, 4, 72, 52, 35, 54, 93, 87, 77, 3, 49, 82, 70, 84, 3, 73, 99,
32, 95, 58, 65, 32, 75, 34, 22, 12, 84, 63, 72, 85, 66, 63, 27, 3, 73,
45, 37, 61, 52, 41, 16, 37, 14, 80, 17, 48, 8, 87, 98, 69, 63, 92, 68,
42, 63, 5, 22, 66, 91, 74, 11, 17, 45, 45, 33, 40, 85, 26, 75, 73, 81,
54, 27, 80, 1, 44, 66, 10, 21, 15, 10, 76, 96, 0, 43, 39, 3, 57, 79,
45, 64, 58, 92, 44, 42, 7, 28, 94, 4, 8, 22, 22, 31, 75, 44, 3, 70,
83, 72, 87, 12, 20, 55, 84, 31, 50, 34, 25, 49, 29, 71, 57, 97, 25, 82,
84, 42, 86, 41, 54, 92, 34, 30, 52, 34, 84, 25, 54, 37, 38, 26, 76, 82,
34, 14, 85, 28, 93, 9};
std::vector<CTYPE> weight_data = {
2, 54, 9, 37, 0, 47, 70, 9, 84, 69, 56, 79, 25, 35, 54, 13,
65, 46, 38, 28, 74, 27, 66, 61, 20, 60, 62, 58, 15, 44, 75, 55,
7, 52, 13, 36, 39, 64, 62, 45, 100, 6, 79, 63, 63, 52, 37, 60,
78, 12, 69, 2, 74, 56, 93, 39, 62, 22, 55, 67, 68, 74, 12, 69,
15, 73, 28, 70, 86, 20, 90, 49, 52, 26, 58, 2, 82, 17, 70, 55,
54, 83, 70, 11, 27, 9, 5, 42, 34, 62, 29, 94, 69, 81, 54, 4};
std::vector<CTYPE> expected_grad_input_data = {
1134, 7578, 686, 2682, 0, 4148, 7136, 2406, 8698, 0,
3759, 6003, 2163, 2395, 0, 2929, 5830, 3469, 6955, 0,
720, 6201, 495, 2063, 0, 5260, 5989, 3060, 7079, 0,
9690, 3423, 3385, 1932, 0, 7644, 8499, 1323, 2613, 0,
4334, 6624, 8532, 9719, 0, 5496, 8601, 1157, 2215, 0,
4676, 7600, 6524, 10069, 0, 4047, 6117, 1612, 2567, 0,
5931, 5651, 5669, 6623, 0, 7674, 3291, 2748, 1654, 0,
10455, 4290, 4145, 796, 0, 9835, 5483, 11649, 5952, 0,
7098, 5460, 3101, 2443, 0, 7788, 5909, 8582, 6298, 0,
9462, 4845, 3041, 2067, 0, 7038, 6336, 10438, 6377, 0,
7518, 8187, 2079, 2773, 0, 10036, 2642, 3952, 1166, 0,
16014, 2250, 10025, 1908, 0, 9610, 298, 3868, 122, 0,
16629, 4338, 11335, 3527, 0, 11514, 5965, 4762, 2207, 0,
18552, 10755, 13309, 5996, 0, 12454, 6787, 4960, 2875, 0,
8750, 6999, 3534, 3233, 0, 14160, 9399, 9595, 8922, 0,
9110, 6567, 3820, 2351, 0, 12969, 11814, 9436, 5870, 0,
7631, 7061, 2877, 2499, 0, 8553, 13527, 3631, 6863, 0,
1361, 8634, 515, 3372, 0, 3394, 10206, 1504, 4112, 0,
5505, 17421, 4702, 11891, 0, 4233, 11894, 1739, 5014, 0,
11787, 14634, 8981, 10759, 0, 11777, 6701, 4719, 3111, 0,
18459, 7761, 12044, 7627, 0, 11214, 4556, 4374, 1594, 0,
604, 1908, 1506, 6102, 0, 2532, 4024, 1713, 6121, 0,
1878, 1814, 4761, 5397, 0, 1127, 3885, 4373, 5832, 0,
450, 1414, 1080, 4719, 0, 5210, 2683, 2765, 4252, 0,
2390, 1668, 7710, 4257, 0, 378, 1698, 3276, 6021, 0,
2866, 4881, 3547, 6822, 0, 502, 1238, 2784, 5199, 0,
2496, 3975, 2700, 5004, 0, 1220, 1990, 3633, 5763, 0,
4501, 2679, 4504, 5412, 0, 1968, 1376, 6246, 3669, 0,
3130, 272, 9345, 1950, 0, 5167, 3278, 9097, 2138, 0,
2446, 1946, 6942, 5460, 0, 5732, 3404, 7919, 5534, 0,
2038, 1614, 6978, 4635, 0, 4544, 4839, 7367, 5574, 0,
1242, 1922, 4842, 6333, 0, 1066, 236, 2236, 686, 0,
17238, 2254, 10413, 1592, 0, 991, 30, 2206, 70, 0,
18823, 6392, 12173, 2470, 0, 1142, 684, 2742, 1219, 0,
21256, 11293, 12719, 7512, 0, 1303, 649, 2818, 1669, 0,
898, 574, 2018, 1929, 0, 15720, 11989, 10517, 5972, 0,
885, 781, 2210, 1281, 0, 14601, 12198, 7915, 4958, 0,
856, 850, 1601, 1355, 0, 7039, 14083, 4113, 7490, 0,
152, 927, 287, 1902, 0, 301, 1051, 886, 2346, 0,
6821, 19615, 4491, 13281, 0, 424, 1146, 999, 2906, 0,
15177, 15480, 8849, 12442, 0, 1222, 544, 2687, 1859, 0,
20215, 9693, 11441, 4964, 0, 1206, 555, 2466, 860, 0};
std::vector<CTYPE> expected_grad_weight_data = {
9246, 22073, 12431, 19714, 11179, 19032, 8458, 6495, 18707, 13830,
20445, 17089, 17124, 18710, 11827, 17236, 16824, 9008, 14086, 18834,
17419, 16759, 13152, 9339, 13801, 20888, 13976, 27277, 13010, 23949,
9838, 11220, 17658, 15019, 25337, 17583, 13270, 21754, 16908, 20563,
20732, 13413, 20868, 27521, 19537, 21170, 15888, 10034, 19195, 16370,
40243, 25890, 40472, 30460, 21228, 21625, 13289, 24435, 19876, 29816,
24188, 23619, 13752, 16251, 18741, 19368, 24517, 34261, 27054, 31257,
21238, 18909, 15776, 16881, 34604, 22534, 28101, 23834, 18479, 16469,
12852, 16551, 14204, 29983, 20167, 24150, 14281, 17501, 15897, 16019,
21661, 32765, 23874, 26527, 20463, 18661};
std::vector<CTYPE> expected_grad_bias_data = {363, 438, 585, 501};

auto grad_output = tf.make({2, 4, 3, 2}, grad_output_data);
auto input = tf.make({2, 6, 7, 5}, input_data);
auto weight = tf.make({4, 3, 4, 2}, weight_data);
int64_t bias_sizes[1] = {4};
int64_t stride[2] = {1, 2};
int64_t padding[2] = {1, 0};
int64_t dilation[2] = {2, 1};
bool transposed = false;
int64_t output_padding[2] = {0, 0};
int64_t groups = 2;
std::array<bool, 3> output_mask_a = {true, true, true};
auto grad_input = tf.zeros({2, 6, 7, 5});
auto grad_weight = tf.zeros({4, 3, 4, 2});
auto grad_bias = tf.zeros({4});
auto grad_output = tf.make({2, 4, 3, 2}, grad_output_data);
auto input = tf.make({2, 6, 7, 5}, input_data);
auto weight = tf.make({4, 3, 4, 2}, weight_data);
int64_t bias_sizes[1] = {4};
int64_t stride[2] = {1, 2};
int64_t padding[2] = {1, 0};
int64_t dilation[2] = {2, 1};
bool transposed = false;
int64_t output_padding[2] = {0, 0};
int64_t groups = 2;
std::array<bool, 3> output_mask_a = {true, true, true};
auto grad_input = tf.zeros({2, 6, 7, 5});
auto grad_weight = tf.zeros({4, 3, 4, 2});
auto grad_bias = tf.zeros({4});

op_convolution_backward_out(
grad_output,
input,
weight,
IntArrayRef{bias_sizes, 1},
IntArrayRef{stride, 2},
IntArrayRef{padding, 2},
IntArrayRef{dilation, 2},
transposed,
IntArrayRef{output_padding, 2},
groups,
output_mask_a,
grad_input,
grad_weight,
grad_bias);

op_convolution_backward_out(
grad_output,
input,
weight,
IntArrayRef{bias_sizes, 1},
IntArrayRef{stride, 2},
IntArrayRef{padding, 2},
IntArrayRef{dilation, 2},
transposed,
IntArrayRef{output_padding, 2},
groups,
output_mask_a,
grad_input,
grad_weight,
grad_bias);
auto expected_grad_input = tf.make({2, 6, 7, 5}, expected_grad_input_data);
auto expected_grad_weight =
tf.make({4, 3, 4, 2}, expected_grad_weight_data);
auto expected_grad_bias = tf.make({4}, expected_grad_bias_data);

auto expected_grad_input = tf.make({2, 6, 7, 5}, expected_grad_input_data);
auto expected_grad_weight = tf.make({4, 3, 4, 2}, expected_grad_weight_data);
auto expected_grad_bias = tf.make({4}, expected_grad_bias_data);
if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
EXPECT_TENSOR_CLOSE_WITH_TOL(grad_input, expected_grad_input, 1e-2, 1e-8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use defaults here? EXPECT_TENSOR_CLOSE_WITH_TOL should apply the right tolerance given the type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the default rtol is 1e-5; rtol and atol are different

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, but in the same way that we have kDefaultHalfAtol and kDefaultBFloat16Atol I think we should have kDefaultHalfRtol and kDefaultBFloat16Rtol and set it to a proper value.
You seem to be using 1e-2 for most of these tests. Why not introduced kDefaultHalfRtol and kDefaultBFloat16Rtol with value 1e-2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not introduced kDefaultHalfRtol and kDefaultBFloat16Rtol with value 1e-2?

Because not all operators require the higher rtol.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not particularly uncommon to need to set rtol in pytorch core: https://github.com/search?q=repo%3Apytorch%2Fpytorch+%2Frtol%3D%5B1-9%5D%2F&type=code

EXPECT_TENSOR_CLOSE_WITH_TOL(
grad_weight, expected_grad_weight, 2e-2, 1e-8);
EXPECT_TENSOR_CLOSE_WITH_TOL(grad_bias, expected_grad_bias, 1e-2, 1e-8);
} else {
EXPECT_TENSOR_CLOSE(grad_input, expected_grad_input);
EXPECT_TENSOR_CLOSE(grad_weight, expected_grad_weight);
EXPECT_TENSOR_CLOSE(grad_bias, expected_grad_bias);
}
}
};

EXPECT_TENSOR_CLOSE(grad_input, expected_grad_input);
EXPECT_TENSOR_CLOSE(grad_weight, expected_grad_weight);
EXPECT_TENSOR_CLOSE(grad_bias, expected_grad_bias);
TEST_F(OpConvolutionBackwardOutTest, SmokeTest) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
Loading