diff --git a/torch_scatter/div.py b/torch_scatter/div.py index 13d76f70..ef8da970 100644 --- a/torch_scatter/div.py +++ b/torch_scatter/div.py @@ -18,7 +18,7 @@ def forward(ctx, out, src, index, dim): @staticmethod def backward(ctx, grad_out): - out, src, index = ctx.saved_variables + out, src, index = ctx.saved_tensors grad_src = None if ctx.needs_input_grad[1]: diff --git a/torch_scatter/max.py b/torch_scatter/max.py index b437ac70..6cf10703 100644 --- a/torch_scatter/max.py +++ b/torch_scatter/max.py @@ -19,7 +19,7 @@ def forward(ctx, out, src, index, dim): @staticmethod def backward(ctx, grad_out, grad_arg): - index, arg = ctx.saved_variables + index, arg = ctx.saved_tensors grad_src = None if ctx.needs_input_grad[1]: diff --git a/torch_scatter/min.py b/torch_scatter/min.py index dd387ed7..ad2670a2 100644 --- a/torch_scatter/min.py +++ b/torch_scatter/min.py @@ -19,7 +19,7 @@ def forward(ctx, out, src, index, dim): @staticmethod def backward(ctx, grad_out, grad_arg): - index, arg = ctx.saved_variables + index, arg = ctx.saved_tensors grad_src = None if ctx.needs_input_grad[1]: diff --git a/torch_scatter/mul.py b/torch_scatter/mul.py index 1ad19c63..9ba65ec6 100644 --- a/torch_scatter/mul.py +++ b/torch_scatter/mul.py @@ -18,7 +18,7 @@ def forward(ctx, out, src, index, dim): @staticmethod def backward(ctx, grad_out): - out, src, index = ctx.saved_variables + out, src, index = ctx.saved_tensors grad_src = None if ctx.needs_input_grad[1]: