# Extending PyTorch differentiable functions

In this notebook you'll see how to add your own custom differentiable function to PyTorch: you subclass `torch.autograd.Function` and specify both its `forward` and its `backward` pass.

In [1]:
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl

Collecting torch==0.3.0.post4
[?25l  Downloading http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl (592.3MB)
[K     |████████████████████████████████| 592.3MB 1.1MB/s 
[31mERROR: torchvision 0.7.0+cu101 has requirement torch==1.6.0, but you'll have torch 0.3.0.post4 which is incompatible.[0m
[31mERROR: fastai 1.0.61 has requirement torch>=1.0.0, but you'll have torch 0.3.0.post4 which is incompatible.[0m
Installing collected packages: torch
  Found existing installation: torch 1.6.0+cu101
    Uninstalling torch-1.6.0+cu101:
      Successfully uninstalled torch-1.6.0+cu101
Successfully installed torch-0.3.0.post4


In [2]:
!pip3 install torchvision

Collecting torch==1.6.0
[?25l  Downloading https://files.pythonhosted.org/packages/38/53/914885a93a44b96c0dd1c36f36ff10afe341f091230aad68f7228d61db1e/torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl (748.8MB)
[K     |████████████████████████████████| 748.8MB 22kB/s 
Installing collected packages: torch
  Found existing installation: torch 0.3.0.post4
    Uninstalling torch-0.3.0.post4:
      Successfully uninstalled torch-0.3.0.post4
Successfully installed torch-1.6.0


In [3]:
# Import some libraries
import torch
import numpy

In [4]:
# Custom addition module
class MyAdd(torch.autograd.Function):
    """Addition ``x1 + x2`` with a hand-written backward pass.

    The derivative of a sum with respect to each operand is 1, so the
    gradient for each input is simply the upstream gradient. Nothing from
    the forward pass is needed to compute it, hence nothing is stored on
    ``ctx``.
    """

    @staticmethod
    def forward(ctx, x1, x2):
        """Compute the sum of the two inputs.

        Args:
            ctx: context object shared with ``backward``. It is where
                values needed by the backward pass would be stashed
                (e.g. with ``ctx.save_for_backward``); addition needs none.
            x1: first operand.
            x2: second operand.

        Returns:
            ``x1 + x2``.
        """
        return x1 + x2

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate the upstream gradient to both operands.

        Args:
            ctx: context object populated by ``forward`` (unused here
                apart from ``needs_input_grad``).
            grad_output: gradient of the loss w.r.t. the output of
                ``forward``.

        Returns:
            A tuple ``(grad_x1, grad_x2)``; an entry is ``None`` when the
            corresponding input does not require a gradient.
        """
        # need to return grads in order
        # of inputs to forward (excluding ctx)
        grad_x1 = grad_output if ctx.needs_input_grad[0] else None
        grad_x2 = grad_output if ctx.needs_input_grad[1] else None
        return grad_x1, grad_x2

In [6]:
# Let's try out the addition module
x1 = torch.randn(3, requires_grad=True)
x2 = torch.randn(3, requires_grad=True)
print(f'x1: {x1}')
print(f'x2: {x2}')
myadd = MyAdd.apply  # aliasing the apply method
y = myadd(x1, x2)
print(f' y: {y}')
z = y.mean()
print(f' z: {z}, z.grad_fn: {z.grad_fn}')
z.backward()
print(f'x1.grad: {x1.grad}')
print(f'x2.grad: {x2.grad}')

# Sanity check: z = mean(x1 + x2), so dz/dx1 = dz/dx2 = 1/N for every element.
expected_grad = torch.full_like(x1, 1 / x1.numel())
assert torch.allclose(x1.grad, expected_grad)
assert torch.allclose(x2.grad, expected_grad)

x1: tensor([ 1.0212, -1.1056,  0.4566], requires_grad=True)
x2: tensor([0.1209, 1.1487, 0.9463], requires_grad=True)
 y: tensor([1.1421, 0.0431, 1.4029], grad_fn=<MyAddBackward>)
 z: 0.8626823425292969, z.grad_fn: <MeanBackward0 object at 0x7ff71288a828>
x1.grad: tensor([0.3333, 0.3333, 0.3333])
x2.grad: tensor([0.3333, 0.3333, 0.3333])


In [7]:
# Custom split module
class MySplit(torch.autograd.Function):
    """Fan a tensor out into two copies.

    Both outputs are clones of the input, so in the backward pass the
    gradient of the input is the sum of the gradients of the two copies.
    The backward pass needs nothing from the forward pass, so nothing is
    stored on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return two separate clones of ``x``.

        Args:
            ctx: context object shared with ``backward`` (unused).
            x: tensor to duplicate.

        Returns:
            A tuple ``(x1, x2)`` of two clones of ``x``.
        """
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_x1, grad_x2):
        """Sum the gradients coming back from the two copies.

        Args:
            ctx: context object shared with ``forward`` (unused).
            grad_x1: gradient of the loss w.r.t. the first output.
            grad_x2: gradient of the loss w.r.t. the second output.

        Returns:
            Gradient of the loss w.r.t. ``x``, i.e. ``grad_x1 + grad_x2``.
        """
        # Printed on purpose: lets the demo show what each branch sends back.
        print(f'grad_x1: {grad_x1}')
        print(f'grad_x2: {grad_x2}')
        return grad_x1 + grad_x2

In [8]:
# Let's try out the split module
x = torch.randn(4, requires_grad=True)
print(f' x: {x}')
split = MySplit.apply
x1, x2 = split(x)
print(f'x1: {x1}')
print(f'x2: {x2}')
y = x1 + x2
print(f' y: {y}')
z = y.mean()
print(f' z: {z}, z.grad_fn: {z.grad_fn}')
z.backward()
print(f' x.grad: {x.grad}')

# Sanity check: z = mean(2 * x), so dz/dx = 2/N for every element.
assert torch.allclose(x.grad, torch.full_like(x, 2 / x.numel()))

 x: tensor([-0.0619, -0.1381, -0.1914, -0.4372], requires_grad=True)
x1: tensor([-0.0619, -0.1381, -0.1914, -0.4372], grad_fn=<MySplitBackward>)
x2: tensor([-0.0619, -0.1381, -0.1914, -0.4372], grad_fn=<MySplitBackward>)
 y: tensor([-0.1238, -0.2762, -0.3828, -0.8744], grad_fn=<AddBackward0>)
 z: -0.4143010079860687, z.grad_fn: <MeanBackward0 object at 0x7ff712894908>
grad_x1: tensor([0.2500, 0.2500, 0.2500, 0.2500])
grad_x2: tensor([0.2500, 0.2500, 0.2500, 0.2500])
 x.grad: tensor([0.5000, 0.5000, 0.5000, 0.5000])


In [9]:
# Custom max module
class MyMax(torch.autograd.Function):
    """Maximum over all elements of a tensor, computed outside of torch.

    The forward pass deliberately goes through NumPy to show that a custom
    ``Function`` can wrap non-torch code as long as ``backward`` supplies
    the gradient by hand.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the largest element of ``x`` as a 0-dim tensor.

        Args:
            ctx: context object shared with ``backward``; receives the
                per-element gradient weights.
            x: input tensor.

        Returns:
            A 0-dim tensor holding ``max(x)``, with the dtype and device
            of ``x``.
        """
        # example where we explicitly use non-torch code
        detached = x.detach()
        # .cpu() is a no-op for CPU tensors and makes .numpy() legal otherwise.
        maximum = detached.cpu().numpy().max().item()

        # 1 where x attains the maximum, 0 elsewhere.
        weight_dtype = x.dtype if x.is_floating_point() else torch.float32
        weights = detached.eq(maximum).to(weight_dtype)
        # Share the gradient equally between tied maxima so that the
        # weights always sum to 1. clamp(min=1) avoids 0/0 when nothing
        # compares equal to the maximum (e.g. the maximum is NaN).
        weights = weights / weights.sum().clamp(min=1)

        ctx.save_for_backward(weights)
        return torch.tensor(maximum, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Route the upstream gradient to the maximal element(s).

        Args:
            ctx: context object holding the weights saved in ``forward``.
            grad_output: gradient of the loss w.r.t. the returned maximum.

        Returns:
            Gradient of the loss w.r.t. ``x``: ``grad_output`` at the
            position of the maximum (split evenly between ties), zero
            everywhere else.
        """
        (weights,) = ctx.saved_tensors
        return grad_output * weights

In [10]:
# Let's try out the max module
x = torch.randn(5, requires_grad=True)
print(f'x: {x}')
mymax = MyMax.apply
y = mymax(x)
print(f'y: {y}, y.grad_fn: {y.grad_fn}')
y.backward()
print(f'x.grad: {x.grad}')

# Sanity check: only the maximal element(s) may receive any gradient.
x_values = x.detach()
assert torch.all(x.grad[x_values < x_values.max()] == 0)
assert torch.all(x.grad[x_values == x_values.max()] > 0)

x: tensor([ 0.1203, -0.9562,  1.1946, -0.5337,  0.0887], requires_grad=True)
y: 1.1945806741714478, y.grad_fn: <torch.autograd.function.MyMaxBackward object at 0x7ff7128df668>
x.grad: tensor([0., 0., 1., 0., 0.])
