In [6]:
import torch # import main library
from torch.autograd import Function # import Function to create custom activations

In [7]:
import torch
import torch._C as _C
import torch.utils.hooks as hooks
from torch._six import with_metaclass
import functools
from collections import OrderedDict


class _ContextMethodMixin(object):

    def save_for_backward(self, *tensors):
        """Saves given tensors for a future call to :func:`~Function.backward`.

        **This should be called at most once, and only from inside the**
        :func:`forward` **method.**

        Later, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute; or, if the corresponding Variable is needed (e.g. for double
        backwards), those can be accessed through the :attr:`saved_variables`
        attribute.  Before returning them to the user, a check is made, to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``.
        """
        self.to_save = tensors

    def mark_dirty(self, *args):
        """Marks given tensors as modified in an in-place operation.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be inputs.**

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.
        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        """Marks that given pairs of distinct tensors are sharing storage.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be pairs of
        (input, output).**

        If some of the outputs are going to be tensors sharing storage with
        some of the inputs, all pairs of (input_arg, output_arg) should be
        given to this function, to ensure correctness checking of in-place
        modification. The only exception is when an output is exactly the same
        tensor as input (e.g. in-place ops). In such case it's easy to conclude
        that they're sharing data, so we don't require specifying such
        dependencies.

        This function is not needed in most functions. It's primarily used in
        indexing and transpose ops.
        """
        self.shared_pairs = pairs

    def mark_non_differentiable(self, *args):
        """Marks outputs as non-differentiable.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be outputs.**

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be ``None``.

        This is used e.g. for indices returned from a max :class:`Function`.
        """
        self.non_differentiable = args

In [54]:
class SignEst(Function):
    """Sign activation with a clipped pass-through gradient.

    Forward maps every element ``>= 0`` to ``1`` and every element ``< 0`` to
    ``-1``.  Backward lets the incoming gradient through unchanged where
    ``|input| < 1`` and zeroes it where ``|input| >= 1``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return the element-wise sign of ``input`` (with zero mapped to +1).

        Args:
            ctx: context object; only ``ctx.save_for_backward`` is used here.
            input: tensor to binarize. It is not modified.

        Returns:
            A new tensor of the same shape as ``input`` holding ``1`` / ``-1``.
        """
        ctx.save_for_backward(input)  # backward needs |input| to build its mask

        # Work on a copy so the caller's tensor stays untouched.
        output = input.clone()
        output[input >= 0] = 1.
        output[input < 0] = -1.

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gate ``grad_output`` by the saved input: keep it where ``|input| < 1``.

        Args:
            ctx: the context that was passed to :meth:`forward`.
            grad_output: gradient with respect to the output of :meth:`forward`.

        Returns:
            Gradient with respect to ``input``: ``grad_output`` where
            ``|input| < 1`` and ``0`` elsewhere.
        """
        # Use the tensor saved in forward instead of a hard-coded sample.
        # Contexts that only record ``to_save`` (such as a bare
        # ``_ContextMethodMixin`` driven by hand) have no ``saved_tensors``,
        # so fall back to what ``save_for_backward`` stored.
        saved = getattr(ctx, "saved_tensors", None)
        if saved is None:
            saved = ctx.to_save
        x, = saved

        # 1 where the gradient passes, 0 where it is blocked. NaN inputs fail
        # the comparison and therefore receive a zero gradient.
        pass_through = (torch.abs(x) < 1.).to(x.dtype)

        return pass_through * grad_output

In [55]:
# Sample input: values with |x| < 1 and |x| >= 1, both signs, plus an exact zero.
a = torch.tensor([0.7, -1.2, 0., 2.3])
# A bare mixin instance serves as a hand-built context object for the direct
# SignEst.forward / SignEst.backward calls in the following cells.
ctx = _ContextMethodMixin()

In [56]:
# Call forward directly (rather than through SignEst.apply), passing the
# hand-built ctx, and print the binarized result.
b = SignEst.forward(ctx, a)
print(b)

tensor([ 1., -1.,  1.,  1.])


In [57]:
# Upstream gradient, one entry per element of `a`.
grad_output = torch.tensor([0.5, 1.5, 2., 3.])
# Call backward directly with the same ctx; the bare last expression displays
# the resulting gradient as the cell output.
grad_input = SignEst.backward(ctx, grad_output)
grad_input

tensor([0.5000, 0.0000, 2.0000, 0.0000])