remove all torch.ones from tests (#3494)
Summary:
Pull Request resolved: #3494

Replace all torch.ones with instances of torch.randn

Reviewed By: digantdesai

Differential Revision: D56907873

fbshipit-source-id: 809b2a42ee628b6dc50b57e22ce5b7318092f8cb
mcr229 authored and facebook-github-bot committed May 7, 2024
1 parent 2835d01 commit 0beb072
Showing 6 changed files with 44 additions and 38 deletions.
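Why the swap matters (my gloss on the summary above, not text from the PR): with all-ones inputs, many different computations collapse to the same output, so comparing a delegated kernel against eager PyTorch can pass even when the kernel is wrong; randomized inputs make such coincidences vanishingly unlikely. A minimal, hypothetical illustration — broken_add is invented for this sketch and is not from the PR:

import torch


def broken_add(x, y):
    # Hypothetical bug: ignores y and doubles x instead of adding.
    return 2 * x


x, y = torch.ones(4), torch.ones(4)
# All-ones inputs hide the bug: 2 * 1 == 1 + 1 elementwise, so the check passes.
assert torch.allclose(broken_add(x, y), x + y)

x, y = torch.randn(4), torch.randn(4)
# Random inputs expose it: 2 * x == x + y only if x happens to equal y.
assert not torch.allclose(broken_add(x, y), x + y)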
10 changes: 5 additions & 5 deletions backends/xnnpack/test/ops/add.py
@@ -58,17 +58,17 @@ def _test_add(self, inputs):
         )

     def test_fp16_add(self):
-        inputs = (torch.ones(1).to(torch.float16), torch.ones(1).to(torch.float16))
+        inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
         self._test_add(inputs)

     def test_fp32_add(self):
-        inputs = (torch.ones(1), torch.ones(1))
+        inputs = (torch.randn(1), torch.randn(1))
         self._test_add(inputs)

     def test_fp32_add_constant(self):
         inputs = (torch.randn(4, 4, 4),)
         (
-            Tester(self.AddConstant(torch.ones(4, 4, 4)), inputs)
+            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
             .export()
             .check_count({"torch.ops.aten.add.Tensor": 4})
             .to_edge()
@@ -84,7 +84,7 @@ def test_fp32_add_constant(self):
     def test_qs8_add_constant(self):
         inputs = (torch.randn(4, 4, 4),)
         (
-            Tester(self.AddConstant(torch.ones(4, 4, 4)), inputs)
+            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
             .quantize()
             .export()
             .check_count({"torch.ops.aten.add.Tensor": 4})
@@ -95,7 +95,7 @@ def test_qs8_add_constant(self):
             .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
             .to_executorch()
             .serialize()
-            .run_method_compare_outputs()
+            .run_method_and_compare_outputs()
         )

     def test_qs8_add(self):
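For orientation, the chained calls above are stages of the ExecuTorch XNNPACK Tester harness. Below is a rough paraphrase of the qs8 constant-add test after this change; the import path and the AddConstant module are reconstructions (assumptions, not copied from this PR), and the intermediate partition/check stages of the real test are omitted.

import torch

# Assumed import path for the test harness; the real location may differ.
from executorch.backends.xnnpack.test.tester import Tester


class AddConstant(torch.nn.Module):
    # Hypothetical reconstruction: adds a fixed tensor to the input.
    def __init__(self, constant):
        super().__init__()
        self.constant = constant

    def forward(self, x):
        return x + self.constant


inputs = (torch.randn(4, 4, 4),)
(
    Tester(AddConstant(torch.randn(4, 4, 4)), inputs)
    .quantize()                        # qs8 path: insert quant/dequant pairs
    .export()                          # capture the graph
    .to_edge()                         # lower to the edge dialect
    .to_executorch()                   # emit the ExecuTorch program
    .serialize()                       # flatbuffer serialization
    .run_method_and_compare_outputs()  # run and compare against eager outputs
)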
54 changes: 27 additions & 27 deletions backends/xnnpack/test/ops/cat.py
@@ -78,8 +78,8 @@ def test_fp16_cat2(self):
         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
         """
         inputs = (
-            torch.ones(1, 2, 3).to(torch.float16),
-            torch.ones(3, 2, 3).to(torch.float16),
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
         )
         self._test_cat(self.Cat2(), inputs)

@@ -88,9 +88,9 @@ def test_fp16_cat3(self):
         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
         """
         inputs = (
-            torch.ones(1, 2, 3).to(torch.float16),
-            torch.ones(3, 2, 3).to(torch.float16),
-            torch.ones(2, 2, 3).to(torch.float16),
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
         )
         self._test_cat(self.Cat3(), inputs)

@@ -99,44 +99,44 @@ def test_fp16_cat4(self):
         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
         """
         inputs = (
-            torch.ones(1, 2, 3).to(torch.float16),
-            torch.ones(3, 2, 3).to(torch.float16),
-            torch.ones(2, 2, 3).to(torch.float16),
-            torch.ones(5, 2, 3).to(torch.float16),
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
         )
         self._test_cat(self.Cat4(), inputs)

     def test_fp32_cat2(self):
-        inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3))
+        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat2(), inputs)

     def test_fp32_cat3(self):
-        inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3))
+        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
         self._test_cat(self.Cat3(), inputs)

     def test_fp32_cat4(self):
         inputs = (
-            torch.ones(1, 2, 3),
-            torch.ones(3, 2, 3),
-            torch.ones(2, 2, 3),
-            torch.ones(5, 2, 3),
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(2, 2, 3),
+            torch.randn(5, 2, 3),
         )
         self._test_cat(self.Cat4(), inputs)

     def test_qs8_cat2(self):
-        inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3))
+        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)

     def test_qs8_cat3(self):
-        inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3))
+        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
         self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)

     def test_qs8_cat4(self):
         inputs = (
-            torch.ones(1, 2, 3),
-            torch.ones(3, 2, 3),
-            torch.ones(2, 2, 3),
-            torch.ones(5, 2, 3),
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(2, 2, 3),
+            torch.randn(5, 2, 3),
         )
         self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)

@@ -145,11 +145,11 @@ def test_fp32_cat_unsupported(self):
         XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
         """
         inputs = (
-            torch.ones(1, 2, 3),
-            torch.ones(3, 2, 3),
-            torch.ones(2, 2, 3),
-            torch.ones(5, 2, 3),
-            torch.ones(1, 2, 3),
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(2, 2, 3),
+            torch.randn(5, 2, 3),
+            torch.randn(1, 2, 3),
         )
         (
             Tester(self.Cat5(), inputs)
@@ -169,7 +169,7 @@ def forward(self, x, y):
             return torch.cat([x, y], -1)

     def test_fp32_cat_negative_dim(self):
-        inputs = (torch.ones(3, 2, 3), torch.ones(3, 2, 1))
+        inputs = (torch.randn(3, 2, 3), torch.randn(3, 2, 1))
         self._test_cat(self.CatNegativeDim(), inputs)

     class CatNhwc(torch.nn.Module):
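A quick shape check on the inputs used in the cat tests above (plain PyTorch, independent of the Tester harness): concatenation only requires the non-concatenated dimensions to match, so the (1/3/2/5, 2, 3) tensors stack along dim 0 into an (11, 2, 3) result, and the negative-dim case concatenates along the last axis.

import torch

# Same shapes as the Cat4 tests above; only dim 0 differs between inputs.
xs = [torch.randn(n, 2, 3) for n in (1, 3, 2, 5)]
assert torch.cat(xs, dim=0).shape == (1 + 3 + 2 + 5, 2, 3)  # (11, 2, 3)

# Negative dims count from the end, as in the CatNegativeDim module above.
a, b = torch.randn(3, 2, 3), torch.randn(3, 2, 1)
assert torch.cat([a, b], -1).shape == (3, 2, 4)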
12 changes: 9 additions & 3 deletions backends/xnnpack/test/ops/div.py
@@ -43,15 +43,21 @@ def _test_div(self, inputs):
         )

     def test_fp16_div(self):
-        inputs = (torch.ones(1).to(torch.float16), torch.ones(1).to(torch.float16))
+        # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
+        inputs = (
+            (torch.randn(1) + 4).to(torch.float16),
+            (torch.randn(1) + 4).to(torch.float16),
+        )
         self._test_div(inputs)

     def test_fp32_div(self):
-        inputs = (torch.ones(1), torch.ones(1))
+        # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
+        inputs = (torch.randn(1) + 4, torch.randn(1) + 4)
         self._test_div(inputs)

     def test_fp32_div_single_input(self):
-        inputs = (torch.ones(1),)
+        # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
+        inputs = (torch.randn(1) + 4,)
         (
             Tester(self.DivSingleInput(), inputs)
             .export()
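The `+ 4` shift added in the div tests keeps randomly drawn denominators away from zero: for a standard normal, a sample below -4 (which would make the shifted value non-positive) occurs with probability of roughly 3e-5 per draw, so the quotient stays well conditioned and output comparisons stay stable. A small sketch of the same idea:

import torch

# Shifting a standard normal by +4 keeps divisors away from 0:
# P(randn() + 4 <= 0) = P(Z <= -4) ≈ 3.2e-5 per element.
numerator = torch.randn(1) + 4
denominator = torch.randn(1) + 4
quotient = numerator / denominator  # no near-zero blow-ups in practice

# An unshifted randn denominator could land near 0 and make the comparison
# between the delegated kernel and eager PyTorch numerically flaky.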
2 changes: 1 addition & 1 deletion backends/xnnpack/test/ops/maxpool2d.py
@@ -111,7 +111,7 @@ def forward(self, x):

         # Parameter order is kernel_size, stride, padding.
         for maxpool_params in [(4,), (4, 2), (4, 2, 2)]:
-            inputs = (torch.ones(1, 2, 8, 8),)
+            inputs = (torch.randn(1, 2, 8, 8),)
             (
                 Tester(MaxPool(maxpool_params), inputs)
                 .quantize()
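For reference, the pooled output shape for the test input above can be checked directly in eager PyTorch; with the (4, 2, 2) parameter tuple (kernel_size=4, stride=2, padding=2), an 8x8 spatial input pools down to 5x5:

import torch

# kernel_size=4, stride=2, padding=2, matching the (4, 2, 2) tuple in the test.
pool = torch.nn.MaxPool2d(4, 2, 2)
y = pool(torch.randn(1, 2, 8, 8))
# H_out = floor((8 + 2*2 - 4) / 2) + 1 = 5, and likewise for W.
assert y.shape == (1, 2, 5, 5)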
2 changes: 1 addition & 1 deletion backends/xnnpack/test/ops/pow.py
@@ -53,7 +53,7 @@ def test_fp32_pow_unsupported(self):
         attempt to delegate other powers.
         """

-        inputs = (torch.ones(5),)
+        inputs = (torch.randn(5),)
         (
             Tester(self.Pow(3), inputs)
             .export()
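The pow test above drives the delegate with exponent 3, which the (partially elided) docstring says should not be lowered; in eager PyTorch the op itself is simply an elementwise power:

import torch

x = torch.randn(5)
# Elementwise cube; the test checks that this exponent is *not* delegated to XNNPACK.
assert torch.allclose(torch.pow(x, 3), x * x * x)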
2 changes: 1 addition & 1 deletion backends/xnnpack/test/ops/sigmoid.py
@@ -36,7 +36,7 @@ def _test_sigmoid(self, inputs):
         )

     def test_fp16_sigmoid(self):
-        inputs = (torch.ones(4).to(torch.float16),)
+        inputs = (torch.randn(4).to(torch.float16),)
         self._test_sigmoid(inputs)

     def test_fp32_sigmoid(self):
