In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

### Simple examples of implementing custom `torch.autograd.Function` subclasses

In [2]:
class MyReLU(torch.autograd.Function):
    """
    Rectified linear unit implemented as a custom ``torch.autograd.Function``.

    A custom autograd Function subclasses ``torch.autograd.Function`` and
    provides static ``forward`` and ``backward`` methods that operate on Tensors.
    """

    @staticmethod
    def forward(ctx, inputs):
        """
        Compute ``max(inputs, 0)`` elementwise.

        ``ctx`` is the context object shared with ``backward``; the input is
        stashed with ``ctx.save_for_backward`` because the backward pass needs
        to know which elements were negative.
        """
        ctx.save_for_backward(inputs)
        return torch.clamp(inputs, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Map dL/d(output) to dL/d(input).

        Positions whose saved input was strictly negative receive a zero
        gradient; every other position (including inputs exactly equal to 0)
        passes ``grad_output`` through unchanged. A new tensor is returned,
        so ``grad_output`` itself is not modified.
        """
        (saved_inputs,) = ctx.saved_tensors
        negative_mask = saved_inputs < 0
        return grad_output.masked_fill(negative_mask, 0)

In [3]:
from torch.autograd import gradcheck

# gradcheck compares the analytic backward against finite differences, so it
# needs float64 inputs. Seed for reproducibility, and build inputs of both signs
# that stay at least 0.1 away from ReLU's kink at 0. This way both branches of
# MyReLU.backward are exercised, and the finite-difference step (eps=1e-6)
# never straddles the non-differentiable point.
torch.manual_seed(0)
magnitudes = torch.rand(2, 3, dtype=torch.double) + 0.1
signs = torch.tensor([[1.0, -1.0, 1.0], [-1.0, 1.0, -1.0]], dtype=torch.double)
inputs = (magnitudes * signs).requires_grad_()
relu = MyReLU.apply
test = gradcheck(relu, (inputs,), eps=1e-6, atol=1e-4)
test

In [4]:
class MatrixSum(torch.autograd.Function):
    """
    Implements a simple matrix sum (sum of all elements) with autograd support.
    """

    @staticmethod
    def forward(ctx, inputs):
        """
        Return the sum of all elements of ``inputs`` as a 0-dim tensor.

        The gradient of a full sum depends only on the input's shape, not its
        values. So only the shape is recorded on ``ctx``; saving the tensor with
        ``save_for_backward`` would keep it alive for no reason.
        """
        ctx.input_shape = inputs.shape
        return inputs.sum()

    @staticmethod
    def backward(ctx, grad_outputs):
        """
        Broadcast the incoming gradient back to the input's shape.

        d(sum(x))/dx is a tensor of ones. So dL/dx = grad_outputs * ones, which is
        exactly ``grad_outputs`` (0-dim, since the output is a scalar) expanded
        to ``input_shape``.
        """
        return grad_outputs.expand(ctx.input_shape)

In [5]:
class MySub(torch.autograd.Function):
    """
    Implements a custom subtraction ``inp - alpha * other`` with autograd support.

    Uses the ``forward`` / ``setup_context`` split, so ``forward`` takes no ``ctx``.
    ``forward`` broadcasts ``inp`` against ``other``. The backward pass
    therefore sums each gradient back down to the corresponding operand's
    original shape.
    """

    @staticmethod
    def forward(inp, other, alpha = 1):
        """Return ``inp - alpha * other`` (with standard tensor broadcasting)."""
        return inp - alpha * other

    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Record what ``backward`` needs: ``alpha`` and the operand shapes.

        ``inputs`` is the tuple of arguments given to ``forward``. If ``alpha``
        was omitted it may not appear here (whether defaults are bound depends on
        the PyTorch version), so fall back to ``forward``'s default of 1.
        Only shapes are stored, since the gradient does not depend on the operand
        values. A non-tensor operand gets ``None`` and receives no gradient.
        """
        inp, other, *rest = inputs
        ctx.alpha = rest[0] if rest else 1
        ctx.inp_shape = inp.shape if torch.is_tensor(inp) else None
        ctx.other_shape = other.shape if torch.is_tensor(other) else None

    @staticmethod
    def _sum_to_shape(grad, shape):
        """Sum ``grad`` over the dimensions that broadcasting added or expanded, giving ``shape``."""
        # Leading dimensions that the operand did not have at all.
        extra = grad.dim() - len(shape)
        if extra > 0:
            grad = grad.sum(dim=tuple(range(extra)))
        # Dimensions that were size 1 in the operand but expanded in the output.
        # (Guarded: sum(dim=()) would reduce over *all* dimensions.)
        expanded = tuple(
            i for i, size in enumerate(shape) if size == 1 and grad.shape[i] != 1
        )
        if expanded:
            grad = grad.sum(dim=expanded, keepdim=True)
        return grad

    @staticmethod
    def backward(ctx, grad_out):
        """
        Return gradients for ``(inp, other, alpha)``.

        d(out)/d(inp) = 1 and d(out)/d(other) = -alpha elementwise, each reduced
        to the operand's shape. ``alpha`` is treated as a constant (``None``).
        """
        d_inp = None
        if ctx.inp_shape is not None:
            d_inp = MySub._sum_to_shape(grad_out, ctx.inp_shape)
        d_other = None
        if ctx.other_shape is not None:
            d_other = MySub._sum_to_shape(-ctx.alpha * grad_out, ctx.other_shape)
        return d_inp, d_other, None

### Simple example of implementing custom autograd for non-differentiable functions

In [27]:
# Defining a non-differentiable cross-entropy loss using NumPy

def nondiff_crossentropyloss(y_true, y_pred):
    """
    Cross-entropy summed over all elements and divided by ``y_true.shape[0]``, computed in NumPy.

    Because the computation leaves PyTorch, the result is not connected to the
    autograd graph. Wrap it in ``NonDiffCrossEntropy`` to obtain gradients.

    Args:
        y_true: target tensor (presumably one-hot or probability rows; not checked).
            May be an integer tensor.
        y_pred: predicted probabilities, elementwise-compatible with ``y_true``.
            Values are clipped to ``[1e-7, 1 - 1e-7]`` before the log.

    Returns:
        0-dim tensor with ``y_pred``'s dtype on ``y_pred``'s device, equal to
        ``sum(-y_true * log(clip(y_pred))) / y_true.shape[0]``.
    """
    # NumPy can only read host memory, so move off the GPU before converting.
    y_true_numpy = y_true.detach().cpu().numpy()
    y_pred_numpy = y_pred.detach().cpu().numpy()
    # Clipping ensures we never hit the log(0) scenario.
    y_pred_numpy_clipped = np.clip(y_pred_numpy, 1e-7, 1 - 1e-7)
    np_loss = np.sum(-y_true_numpy * np.log(y_pred_numpy_clipped))
    # Take dtype/device from the predictions: integer targets (e.g. one-hot
    # LongTensors) would otherwise truncate the loss to an integer.
    loss = torch.tensor(np_loss, dtype=y_pred.dtype, device=y_pred.device)
    return loss / y_true.shape[0]

class NonDiffCrossEntropy(torch.autograd.Function):
    """
    Autograd-compatible wrapper around the NumPy-based ``nondiff_crossentropyloss``.

    The forward pass delegates to the NumPy implementation, which autograd cannot
    trace. The backward pass supplies the analytic gradient by hand, making
    the loss differentiable with respect to ``y_pred``.
    """

    # Clipping bound; must match the np.clip call in nondiff_crossentropyloss.
    _EPS = 1e-7

    @staticmethod
    def forward(ctx, y_true, y_pred):
        """Compute the loss and stash both inputs for the backward pass."""
        ctx.save_for_backward(y_true, y_pred)
        return nondiff_crossentropyloss(y_true, y_pred)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return ``(None, dL/dy_pred)``.

        For L = sum(-y_true * log(clip(y_pred))) / N, the derivative with respect
        to y_pred is -y_true / y_pred / N inside the clipping interval. It is 0
        where clipping is active, because the clipped loss is flat there. This
        keeps the gradient consistent with the forward pass and avoids the nan/inf
        that dividing by an unclipped y_pred == 0 would produce.
        No gradient is returned for ``y_true``.
        """
        y_true, y_pred = ctx.saved_tensors
        eps = NonDiffCrossEntropy._EPS
        n = y_true.shape[0]
        clipped = (y_pred < eps) | (y_pred > 1 - eps)
        # Divide by the clamped value so masked-out entries never produce nan/inf.
        safe_pred = y_pred.clamp(eps, 1 - eps)
        dy_pred = (-y_true / safe_pred).masked_fill(clipped, 0) / n
        return None, dy_pred * grad_output

In [28]:
# gradcheck needs float64 inputs. Seed for reproducibility, and keep the
# predictions in [0.05, 0.95) so finite differences stay away from both the
# log singularity at 0 and the clipping bounds used by the loss.
torch.manual_seed(0)
A = torch.rand(3, 3, requires_grad=False, dtype=torch.double)
B = (0.05 + 0.9 * torch.rand(3, 3, dtype=torch.double)).requires_grad_()
loss = NonDiffCrossEntropy.apply
gradcheck(loss, (A, B))

True