In [1]:
import torch
import numpy

In [2]:
# Custom addition op, implemented as a torch.autograd.Function
class MyAdd(torch.autograd.Function):
    """Element-wise addition with a hand-written backward pass.

    Call it through ``MyAdd.apply(x1, x2)``. The two inputs may have
    different (broadcast-compatible) shapes; each gradient is reduced back
    to the shape of the input it belongs to.
    """

    @staticmethod
    def forward(ctx, x1, x2):
        """Return ``x1 + x2`` and record what backward will need.

        Args:
            ctx: autograd context object shared with ``backward``.
            x1: first tensor operand.
            x2: second tensor operand.

        Returns:
            The (possibly broadcast) sum of the two operands.
        """
        # ctx is a context where we can save whatever backward needs.
        # The derivative of a sum does not depend on the input values,
        # so only the shapes are kept instead of the tensors themselves.
        ctx.x1_shape = x1.shape
        ctx.x2_shape = x2.shape
        return x1 + x2

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to both operands.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_x1, grad_x2)``; an entry is ``None`` when the
            corresponding input does not require a gradient.
        """
        # d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1, so each input simply receives
        # grad_output. Where an input was broadcast in forward, its gradient
        # is the sum over the broadcast dimensions (sum_to_size does that and
        # is a no-op when the shapes already match).
        grad_x1 = grad_x2 = None
        if ctx.needs_input_grad[0]:
            grad_x1 = grad_output.sum_to_size(ctx.x1_shape)
        if ctx.needs_input_grad[1]:
            grad_x2 = grad_output.sum_to_size(ctx.x2_shape)
        # need to return grads in order
        # of inputs to forward (excluding ctx)
        return grad_x1, grad_x2

In [3]:
# Demo: push two random vectors through MyAdd and inspect their gradients.
myadd = MyAdd.apply  # short alias; the Function is invoked through .apply
x1 = torch.randn(3, requires_grad=True)
x2 = torch.randn(3, requires_grad=True)
print(f'x1: {x1}')
print(f'x2: {x2}')
y = myadd(x1, x2)
print(f' y: {y}')
# Reduce to a scalar so backward() can be called without arguments.
z = torch.mean(y)
print(f' z: {z}, z.grad_fn: {z.grad_fn}')
z.backward()
# z is the mean of three sums, so every input element gets gradient 1/3.
print(f'x1.grad: {x1.grad}')
print(f'x2.grad: {x2.grad}')

x1: tensor([-0.2656,  0.7681, -0.5370], requires_grad=True)
x2: tensor([1.7043, 3.0458, 0.8894], requires_grad=True)
 y: tensor([1.4387, 3.8139, 0.3524], grad_fn=<MyAddBackward>)
 z: 1.8683284521102905, z.grad_fn: <MeanBackward0 object at 0x7fadb31eb8e0>
x1.grad: tensor([0.3333, 0.3333, 0.3333])
x2.grad: tensor([0.3333, 0.3333, 0.3333])


In [4]:
# Custom split op, implemented as a torch.autograd.Function
class MySplit(torch.autograd.Function):
    """Duplicate a tensor into two independent copies.

    Call it through ``MySplit.apply(x)``. Because both outputs are copies of
    the same input, the gradient flowing back to that input is the sum of
    the gradients arriving at the two outputs.
    """

    @staticmethod
    def forward(ctx, x):
        """Return two separate clones of ``x``.

        Args:
            ctx: autograd context object (nothing is stored on it, since
                backward only needs the incoming gradients).
            x: tensor to duplicate.

        Returns:
            A tuple ``(x1, x2)`` of clones of ``x``.
        """
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_x1, grad_x2):
        """Combine the gradients of both outputs into the input gradient.

        Args:
            ctx: autograd context object (unused).
            grad_x1: gradient of the loss w.r.t. the first output.
            grad_x2: gradient of the loss w.r.t. the second output.

        Returns:
            ``grad_x1 + grad_x2``, the gradient w.r.t. ``x``.
        """
        # The prints are part of the demonstration: they show what autograd
        # hands to backward for each of the two outputs.
        print(f'grad_x1: {grad_x1}')
        print(f'grad_x2: {grad_x2}')
        return grad_x1 + grad_x2

In [5]:
# Demo: duplicate a random vector with MySplit, recombine, and backpropagate.
split = MySplit.apply  # short alias; the Function is invoked through .apply
x = torch.randn(4, requires_grad=True)
print(f' x: {x}')
x1, x2 = split(x)
print(f'x1: {x1}')
print(f'x2: {x2}')
# Adding the two copies back together gives 2 * x.
y = torch.add(x1, x2)
print(f' y: {y}')
z = torch.mean(y)
print(f' z: {z}, z.grad_fn: {z.grad_fn}')
z.backward()
# Each branch carries 1/4 per element, so the summed gradient on x is 1/2.
print(f' x.grad: {x.grad}')

 x: tensor([-2.3724,  0.4210,  0.6205,  1.0377], requires_grad=True)
x1: tensor([-2.3724,  0.4210,  0.6205,  1.0377], grad_fn=<MySplitBackward>)
x2: tensor([-2.3724,  0.4210,  0.6205,  1.0377], grad_fn=<MySplitBackward>)
 y: tensor([-4.7448,  0.8420,  1.2410,  2.0755], grad_fn=<AddBackward0>)
 z: -0.1465766429901123, z.grad_fn: <MeanBackward0 object at 0x7fadb4f739d0>
grad_x1: tensor([0.2500, 0.2500, 0.2500, 0.2500])
grad_x2: tensor([0.2500, 0.2500, 0.2500, 0.2500])
 x.grad: tensor([0.5000, 0.5000, 0.5000, 0.5000])


In [8]:
# Custom max op, implemented as a torch.autograd.Function
class MyMax(torch.autograd.Function):
    """Global maximum of a tensor, with the forward value computed in NumPy.

    Call it through ``MyMax.apply(x)``. The result is a 0-dim tensor with the
    dtype and device of ``x``. In backward the upstream gradient is routed to
    the position(s) holding the maximum; if several elements tie, it is
    shared evenly between them so the routed gradient still sums to
    ``grad_output`` (the convention ``torch.amax`` documents).
    """

    @staticmethod
    def forward(ctx, x):
        """Compute ``max(x)`` and save the gradient-routing weights.

        Args:
            ctx: autograd context object shared with ``backward``.
            x: non-empty input tensor.

        Returns:
            A 0-dim tensor holding the largest element of ``x``.
        """
        # example where we explicitly use non-torch code
        values = x.detach()
        # NumPy can only read host memory, hence the .cpu() hop
        # (a no-op when x already lives on the CPU).
        maximum = values.cpu().numpy().max()
        result = torch.tensor(maximum.item(), dtype=x.dtype, device=x.device)

        # 1 where x attains the maximum, 0 elsewhere, kept in x's own float
        # dtype so backward does not lose precision.
        mask = values.eq(result)
        weights = mask.to(x.dtype if x.is_floating_point() else torch.float32)
        # Share the gradient between ties. clamp_min(1) keeps the weights at
        # zero (instead of 0/0) when nothing compares equal, e.g. a NaN max.
        weights = weights / weights.sum().clamp_min(1)
        ctx.save_for_backward(weights)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Route ``grad_output`` to the maximal element(s) of the input.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient of the loss w.r.t. the returned maximum.

        Returns:
            A tensor shaped like the forward input: ``grad_output`` scaled by
            the saved weights, zero everywhere except at the maximum.
        """
        weights = ctx.saved_tensors[0]
        # The print is part of the demonstration: it shows the upstream
        # gradient autograd passes in.
        print(f'grad_output: {grad_output}')
        return grad_output * weights

In [14]:
# Demo: take the max of a random vector with MyMax and backpropagate.
mymax = MyMax.apply  # short alias; the Function is invoked through .apply
x = torch.randn(5, requires_grad=True)
print(f'x: {x}')
y = mymax(x)
print(f'y: {y}, y.grad_fn: {y.grad_fn}')
# y is already a scalar, so backward() needs no explicit gradient.
y.backward()
print(f'x.grad: {x.grad}')

x: tensor([ 0.1040,  2.2136,  1.9183, -0.3757,  1.1987], requires_grad=True)
y: 2.2136030197143555, y.grad_fn: <torch.autograd.function.MyMaxBackward object at 0x7fadb4f5c5f0>
grad_output: 1.0
x.grad: tensor([0., 1., 0., 0., 0.])
