# How to fix an OpInfo test

## What is an OpInfo test?

PyTorch maintains a list of Python objects (OpInfo) that describe
how to test each op. This is useful to us because it
ensures that the ops we implement produce the same results
PyTorch would produce.

Context:
* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261
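
To make this concrete, the gist of an OpInfo-driven test looks roughly like the
sketch below (a simplified illustration, not the actual code in test/test_ops.py):

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Each OpInfo in op_db knows how to build sample inputs for its op.
for opinfo in op_db:
    for sample in opinfo.sample_inputs(device="cpu", dtype=torch.float32):
        expected = opinfo.op(sample.input, *sample.args, **sample.kwargs)
        # test_ops.py would run the same op through torch_xla2 here and
        # compare that result against `expected`.
```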

## How to fix one

### Remove one op from the skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor.
Remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
```

### Run the test to see the failure

The error we get:

```
E         RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n     python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```

From here we have two strategies for fixing this test:

1. Implement the `aten::addbmm` operator using JAX ops, or
2. Implement the `aten::addbmm` operator using torch ops (commonly known as a "decomposition"); a rough sketch of this route is shown below.

Either way works for torch_xla2. For ops that are not "Core Aten" we sometimes implement them in torch ops, with the goal of
upstreaming the decomposition to [pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.
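
For reference, the decomposition route (option 2) amounts to expressing `addbmm` in
terms of existing torch ops. The snippet below is only a sketch of that idea; it does
not show the actual registration mechanics in torch_xla2 or in pytorch's decomposition table.

```python
import torch

# Sketch of a decomposition: addbmm(input, batch1, batch2) is
# beta * input + alpha * sum_over_batches(batch1[i] @ batch2[i]).
# (This ignores the integer alpha/beta subtlety discussed later in this doc.)
def addbmm_decomposition(input, batch1, batch2, *, beta=1, alpha=1):
    # torch.bmm does the batched matmul; summing over dim 0 collapses the batch.
    mm = torch.bmm(batch1, batch2).sum(dim=0)
    return beta * input + alpha * mm
```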

For illustration purposes, let's implement this op in JAX.

(NOTE: this doesn't stop us from upstreaming a decomposition later if we want to.)

### First Impl

To implement this op using JAX ops, we first look up its exact
semantics on this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html

From its math formula, `out = beta * input + alpha * (sum over the batch dim of batch1 @ batch2)`, we can implement it as follows.
```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  # einsum 'bxy, byz -> xz' does the batched matmul and sums over the batch dim b.
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  return beta * input + alpha * mm
```

Now run the test again:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```

(NOTE: the exact test command is printed out when we run
`pytest test/test_ops.py`, so we can run only the failed test instead of running all tests.)
We now see this error:

```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
    diff_output(
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
    testcase.assertTrue(
AssertionError: False is not true
```

This tells us that our implementation did not produce
the same result as the op in PyTorch.

To debug this, let's figure out exactly which input caused the mismatch.
We can do this by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. There we can
inspect the values of `res` and `res2`, as well as the `sample_input`.
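
For example, one quick way to do this (an illustration; any debugger works) is to drop a
`pdb` breakpoint into `run_export_and_compare` in test/test_ops.py just before the comparison:

```python
# Inside run_export_and_compare, right before diff_output(...) is called:
import pdb; pdb.set_trace()  # then inspect res, res2 and sample_input interactively
```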

The sample input we get is:
```
SampleInput(input=tensor([[-3, -3,  9,  8, -8, -3, -4,  2,  2,  2],
        [-5,  1, -9,  9,  1, -5,  6,  1, -4, -5],
        [-2, -1,  5, -2, -3,  0,  5, -4,  9, -6],
        [-1, -7,  6,  3,  8,  3,  8,  9, -5,  7],
        [-3, -4, -9,  9,  7, -3, -8,  2,  5, -3]]), args=(tensor([[[-2,  4, -2,  5,  8],
         [-6, -2,  5,  7,  7],
         [-8, -3,  2,  5, -3],
         [-4,  7,  0, -9,  8],
         [ 3,  9, -9, -2,  0]],

        [[-7,  1, -3,  7, -4],
         [ 3,  5,  4,  6,  5],
         [-2,  8,  3,  5,  7],
         [ 8, -2, -8,  2,  0],
         [ 6,  1, -8,  8,  0]],

        [[ 2, -1, -5, -8, -9],
         [ 5,  0, -4, -1, -6],
         [-6,  2, -5, -2, -5],
         [-5, -3, -5, -4,  9],
         [-3,  4, -9, -9,  7]],

        [[ 2,  5, -7, -3,  8],
         [-5, -7, -8, -4,  4],
         [-4, -6, -3,  0,  6],
         [ 8,  0, -3, -8,  2],
         [-4,  3, -9, -6,  7]],

        [[ 2,  1, -6,  2,  8],
         [ 2,  6,  4,  1,  8],
         [-9,  9, -5,  8,  3],
         [-5,  0, -2,  4,  0],
         [ 5,  8, -4,  9,  7]]]), tensor([[[-1, -8,  3,  5, -8,  2, -5,  0, -9, -5],
         [-4, -7,  2,  2,  1, -9,  2,  7, -1, -1],
         [ 1,  8, -6, -4, -6, -8, -7, -9,  7,  4],
         [-4,  1, -9,  3,  4,  6,  0, -2, -2, -7],
         [ 5,  5,  0,  8, -3,  7, -7,  8,  3,  5]],

        [[ 8, -4, -9,  9,  5,  0,  5,  0, -5,  5],
         [-5, -3, -2,  8,  1, -2,  4, -7,  5,  3],
         [-4,  4,  1, -4, -8,  2, -5,  2,  9, -7],
         [ 9,  6, -8, -3,  3,  1,  4,  6, -5, -4],
         [-2,  1,  5,  5,  2,  6,  7, -3, -7,  3]],

        [[ 9, -8,  5, -3, -1,  2, -9, -5, -1, -3],
         [-3,  3, -9, -7, -9, -8,  1, -3,  7, -2],
         [ 8, -1,  8, -8, -7,  4,  8,  8,  5, -7],
         [-1,  6, -8,  7, -1, -5, -8,  6, -2,  8],
         [-5, -5,  8,  6,  0,  1,  3, -2, -3, -9]],

        [[ 7, -2,  6, -8, -5,  3,  2, -1, -5,  8],
         [-6, -4,  3,  9, -9, -8, -7,  3,  9,  0],
         [ 1,  3,  4,  4, -5, -2, -4, -2,  3, -7],
         [-6,  9,  5, -1,  7,  7,  8, -3, -8,  0],
         [-1, -6, -3,  3,  3, -8, -4,  9, -5,  7]],

        [[-5, -3, -9,  6, -1, -7,  9, -8,  1, -8],
         [-8, -8, -2, -5, -7, -8,  1,  0,  0, -6],
         [ 7, -5,  2,  2,  0, -9, -5, -7,  1,  8],
         [-4,  0,  9,  6, -1, -6,  6, -6, -2, -1],
         [ 7,  3,  0,  1,  1, -9,  5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```

And the `res` from torch is:

```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```

So a few observations:
1. The input tensors are of type int64.
2. alpha and beta are both floats.

So one can suspect that it has to do with rounding.
Reading the doc more carefully, we can find this sentence:

    For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.

So likely torch first cast the float alpha and beta to integers, which yields 0, then used them in the math to get a matrix of all zeros.
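
In other words (an illustration of the suspected behavior, not the actual PyTorch code):

```python
# With integer tensors, beta=0.6 and alpha=0.2 are truncated to integers:
beta, alpha = int(0.6), int(0.2)   # both become 0
# so beta * input + alpha * mm collapses to an all-zero matrix,
# which matches the `res` we observed above.
```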

### Second Impl

```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  # Cast alpha and beta to the input dtype, matching PyTorch's behavior for
  # integer tensors (0.6 and 0.2 become 0).
  alpha = jnp.array(alpha).astype(batch1.dtype)
  beta = jnp.array(beta).astype(batch1.dtype)
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  # When beta == 0, PyTorch ignores `input` entirely (so nan/inf in it are not
  # propagated); jax.lax.cond mirrors that.
  return jax.lax.cond(beta == 0,
           lambda: alpha * mm,
           lambda: beta * input + alpha * mm)
```

Adding the type casts makes the test pass.
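
With the casts in place, re-running the same test from before confirms it now passes:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```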

### Submit
Now, let's remove the pdb and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993
 | 211 | +
  | 
0 commit comments