Skip to content

Commit

Permalink
Add complex autograd support for diag, diagflat
Browse files Browse the repository at this point in the history
  • Loading branch information
rjkilpatrick committed Jan 16, 2021
1 parent 2001f3a commit 22afebc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5022,10 +5022,10 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
'exp', 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul',
'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr',
'narrow', 'swapaxes', 'swapdims', 'tensor_split', 'tile'] + separate_complex_tests
'bmm', 'mv', 'ger', 'diag', 'diagflat', 'diagonal', 'atan', 'angle', 'tanh',
'fill_', 'sub', 'exp', 'mean', 'inverse', 'triangular_solve', 'solve',
'addcmul', 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', 'narrow',
'swapaxes', 'swapdims', 'tensor_split', 'tile'] + separate_complex_tests

def add_test(
name,
Expand Down
11 changes: 6 additions & 5 deletions tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c',
'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
'bmm', 'diagonal', 'diag', 'diagflat', 'alias', 'atan', 'log', 'log10', 'log1p',
'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh',
'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve',
'linalg_cholesky', 'addcmul', 'addcdiv', 'matrix_exp', 'linalg_eigh',
'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c', 'linalg_solve',
'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
'l1_loss_backward'
}

Expand Down

0 comments on commit 22afebc

Please sign in to comment.