
Commit f76da37

fix addbmm opinfo
1 parent 6443e59 commit f76da37

File tree: 3 files changed, +220 −1 lines changed
Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
# How to fix an OpInfo test

## What is an OpInfo test?

PyTorch maintains a list of Python objects (OpInfo) that keep
track of how to test each op. This is useful to us because it
ensures that the ops we implement produce the same results
PyTorch would produce.

Context:
* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261
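
For a sense of what these tests do, here is a minimal sketch (not the
actual torch_xla2 harness in test/test_ops.py) of how OpInfo entries
drive a reference test: every op in `op_db` ships sample inputs that
can be run through both the reference op and our implementation.

```python
# Minimal sketch, assuming only stock PyTorch; the real harness also
# runs each sample through torch_xla2 and diffs the two results.
import torch
from torch.testing._internal.common_methods_invocations import op_db

for opinfo in op_db:
  if opinfo.name != 'addbmm':
    continue
  for sample in opinfo.sample_inputs('cpu', dtype=torch.int64):
    expected = opinfo.op(sample.input, *sample.args, **sample.kwargs)
    # ... run the same sample through the torch_xla2 op and compare ...
```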

## How to fix one

### Remove one op from skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor.
Remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
```

### Run the test to see the failure

The error we get:

```
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```

"No lowering found" means torch_xla2 does not yet have an implementation registered for `aten::addbmm`. From here we have two strategies for fixing this test:

1. Implement the `aten::addbmm` operator using Jax ops; or,
2. Implement the `aten::addbmm` operator using torch ops (this is commonly known as a "decomposition").

Either way works for torch_xla2. For ops that are not "Core Aten", we sometimes implement them in torch ops, with the goal of
upstreaming the decomposition to [pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.

For illustration purposes, let's implement this op in Jax.

(NOTE: this doesn't stop us from upstreaming a decomposition later if we want.)
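
Jax implementations are registered through the `@op` decorator that
appears in the code below. As a hypothetical sketch of what such a
registry could look like (the real decorator lives in
`torch_xla2/_ops.py` and may differ in details):

```python
# Hypothetical sketch of an ATen-op-to-Jax-function registry.
_jax_impls = {}

def op(aten_op):
  def register(fn):
    _jax_impls[aten_op] = fn  # looked up later when lowering aten_op
    return fn
  return register
```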

### First Impl

To implement this op using Jax ops, we first look up its exact
semantics on this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html

From its math formula, we can implement it as follows:

```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  return beta * input + alpha * mm
```
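
The einsum string `'bxy, byz -> xz'` contracts `y` within each batch
slice (a matmul) and sums out the batch dimension `b`, which is exactly
the sum-over-batches term in the formula. A quick standalone sanity check:

```python
import jax.numpy as jnp
import numpy as np

b1 = jnp.asarray(np.random.randn(3, 2, 4))
b2 = jnp.asarray(np.random.randn(3, 4, 5))

# A batched matmul followed by a sum over the batch dimension...
expected = sum(b1[i] @ b2[i] for i in range(3))
# ...is what the single einsum computes.
actual = jnp.einsum('bxy, byz -> xz', b1, b2)
print(jnp.allclose(actual, expected, atol=1e-5))  # True
```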

Now run the test again:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```

(NOTE: the exact test command is printed out when we run
`pytest test/test_ops.py`, so we can run only the failed test instead of running all tests.)

We now see this error:

```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
    diff_output(
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
    testcase.assertTrue(
AssertionError: False is not true
```

This tells us that our implementation did not produce
the same result as the op in PyTorch.

To debug this, let's figure out which exact input caused the failure.
We can do this by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. There we can
inspect the values of `res` and `res2`, as well as the `sample_input`.
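
For example, with `pdb` (the exact line may drift as the file changes):

```python
# In test/test_ops.py, immediately before diff_output(...) is called:
import pdb; pdb.set_trace()
# (Pdb) p sample_input   # the failing OpInfo sample
# (Pdb) p res            # result from eager torch
# (Pdb) p res2           # result from the torch_xla2 implementation
```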

The sample input we get is:

```
SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
        [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
        [-2, -1, 5, -2, -3, 0, 5, -4, 9, -6],
        [-1, -7, 6, 3, 8, 3, 8, 9, -5, 7],
        [-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8],
        [-6, -2, 5, 7, 7],
        [-8, -3, 2, 5, -3],
        [-4, 7, 0, -9, 8],
        [ 3, 9, -9, -2, 0]],

        [[-7, 1, -3, 7, -4],
        [ 3, 5, 4, 6, 5],
        [-2, 8, 3, 5, 7],
        [ 8, -2, -8, 2, 0],
        [ 6, 1, -8, 8, 0]],

        [[ 2, -1, -5, -8, -9],
        [ 5, 0, -4, -1, -6],
        [-6, 2, -5, -2, -5],
        [-5, -3, -5, -4, 9],
        [-3, 4, -9, -9, 7]],

        [[ 2, 5, -7, -3, 8],
        [-5, -7, -8, -4, 4],
        [-4, -6, -3, 0, 6],
        [ 8, 0, -3, -8, 2],
        [-4, 3, -9, -6, 7]],

        [[ 2, 1, -6, 2, 8],
        [ 2, 6, 4, 1, 8],
        [-9, 9, -5, 8, 3],
        [-5, 0, -2, 4, 0],
        [ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5],
        [-4, -7, 2, 2, 1, -9, 2, 7, -1, -1],
        [ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4],
        [-4, 1, -9, 3, 4, 6, 0, -2, -2, -7],
        [ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]],

        [[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5],
        [-5, -3, -2, 8, 1, -2, 4, -7, 5, 3],
        [-4, 4, 1, -4, -8, 2, -5, 2, 9, -7],
        [ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4],
        [-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]],

        [[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3],
        [-3, 3, -9, -7, -9, -8, 1, -3, 7, -2],
        [ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7],
        [-1, 6, -8, 7, -1, -5, -8, 6, -2, 8],
        [-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]],

        [[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8],
        [-6, -4, 3, 9, -9, -8, -7, 3, 9, 0],
        [ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7],
        [-6, 9, 5, -1, 7, 7, 8, -3, -8, 0],
        [-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]],

        [[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8],
        [-8, -8, -2, -5, -7, -8, 1, 0, 0, -6],
        [ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8],
        [-4, 0, 9, 6, -1, -6, 6, -6, -2, -1],
        [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```

And the `res` from torch is:

```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```

A few observations:

1. The input tensors are of type int64.
2. alpha and beta are both floats.

So one can suspect that it has to do with rounding.
Reading the doc more carefully, we find this sentence:

> For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.

So torch likely first casts the float alpha and beta to integers, which yields 0, and then uses them in the math, producing a matrix of all zeros.
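
We can confirm the truncation in a couple of lines:

```python
import jax.numpy as jnp

# Casting the float coefficients to an integer input dtype truncates
# them toward zero (int64 behaves the same), so both terms vanish.
print(jnp.array(0.6).astype(jnp.int32))  # 0
print(jnp.array(0.2).astype(jnp.int32))  # 0
```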

### Second Impl

```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  alpha = jnp.array(alpha).astype(batch1.dtype)
  beta = jnp.array(beta).astype(batch1.dtype)
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  return jax.lax.cond(beta == 0,
                      lambda: alpha * mm,
                      lambda: beta * input + alpha * mm)
```

Adding the type casts makes the test pass. (The `jax.lax.cond` mirrors torch's documented behavior that when `beta` is 0, `input` is ignored entirely, so `nan` and `inf` in it are not propagated.)

### Submit

Now, let's remove the pdb breakpoints and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993

experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",

experimental/torch_xla2/torch_xla2/_ops.py

Lines changed: 9 additions & 0 deletions
@@ -415,6 +415,15 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
   self += alpha * jnp.matmul(mat1, mat2)
   return self
 
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+  alpha = jnp.array(alpha).astype(batch1.dtype)
+  beta = jnp.array(beta).astype(batch1.dtype)
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return jax.lax.cond(beta == 0,
+                      lambda: alpha * mm,
+                      lambda: beta*input + alpha*mm)
+
 
 @op(torch.ops.aten.gelu)
 def _aten_gelu(self, *, approximate="none"):
