Skip to content

Commit

Permalink
Add Lowering for softshrink_backward
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Jul 19, 2023
1 parent 4448c78 commit e9d4e13
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 477 deletions.
10 changes: 10 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,16 @@ def fn(a):
with self.assertRaisesRegex(RuntimeError, ""):
fn(torch.randn(1, 5))

def test_softshrink_backward(self):
grad_output = torch.randn(1)
lambd = 0.5

def fn(a, grad_output, lambd):
a = a.cos()
return torch.ops.aten.softshrink_backward(grad_output, a, lambd)

self.common(fn, (torch.randn(1), grad_output, lambd))

def test_inductor_assert(self):
@torch._dynamo.optimize("inductor", dynamic=True)
def fn(a):
Expand Down

0 comments on commit e9d4e13

Please sign in to comment.