
Import TensorFlow 2

In [2]:
import tensorflow as tf

We want to compute the following function in the forward pass:

$y=5x^2+3$

but we will do it in two steps:

1. $y_{1}=x^{2}$
1. $y_{2}=5y_{1} + 3$

In [6]:
def step_1(x):
  return x**2

def step_2(x):
  return 5*x + 3

# Test them
x = tf.constant(6.0) # y = 5 * 6^2 + 3 = 183
y1 = step_1(x)
print("Step 1 Output",y1.numpy())
y2 = step_2(y1)
print("Step 2 Output",y2.numpy())

Step 1 Output 36.0
Step 2 Output 183.0


Now let's find the derivative of each step with respect to $x$:

1. $\frac{dy_{1}}{dx} = 2x$
1. $\frac{dy_{2}}{dx} = 10x$
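The second derivative follows from the chain rule, $\frac{dy_2}{dx}=\frac{dy_2}{dy_1}\cdot\frac{dy_1}{dx}=5\cdot 2x=10x$. As a quick check that does not depend on TensorFlow, the pure-Python sketch below compares both analytic derivatives with a central finite difference at $x=6$. The helper `numerical_derivative`, the functions `y1_of_x`/`y2_of_x`, and the step size `h` are illustrative choices, not part of the original notebook.

```python
# Pure-Python check of the chain rule (no TensorFlow needed).

def numerical_derivative(f, x, h=1e-5):
    """Central finite-difference approximation of f'(x) (illustrative helper)."""
    return (f(x + h) - f(x - h)) / (2 * h)

def y1_of_x(x):
    return x ** 2                # step 1

def y2_of_x(x):
    return 5 * y1_of_x(x) + 3    # step 2 applied to step 1

x = 6.0
analytic_dy1 = 2 * x             # dy1/dx = 2x
analytic_dy2 = 5 * (2 * x)       # chain rule: dy2/dy1 * dy1/dx = 5 * 2x = 10x

numeric_dy1 = numerical_derivative(y1_of_x, x)
numeric_dy2 = numerical_derivative(y2_of_x, x)

print("dy1/dx:", analytic_dy1, numeric_dy1)
print("dy2/dx:", analytic_dy2, numeric_dy2)
```

The finite-difference values should agree with $2x=12$ and $10x=60$ up to floating-point rounding.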

Test what TensorFlow computes for step 1

In [7]:
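with tf.GradientTape() as tape:
  tape.watch(x)  # x is a tf.constant, so it must be watched explicitly
  y1 = step_1(x)
  y2 = step_2(y1)
grad = tape.gradient(y1, x)  # compute the gradient outside the recording context
print("Gradient for step 1", grad.numpy())  # should be 2x = 12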

Gradient for step 1 12.0


Test gradient for step 2

In [8]:
with tf.GradientTape() as tape:
  tape.watch(x)
  y1 = step_1(x)
  y2 = step_2(y1)
grad = tape.gradient(y2, x)
print("Gradient for step 2", grad.numpy())  # should be 10x = 60

Gradient for step 2 60.0
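The cells above record the same computation once per gradient. A possible alternative, sketched below as our own choice rather than the notebook's, is a single tape created with `persistent=True`, a standard `tf.GradientTape` option that allows `tape.gradient` to be called more than once.

```python
import tensorflow as tf

def step_1(x):
    return x ** 2

def step_2(y1):
    return 5 * y1 + 3

x = tf.constant(6.0)
# persistent=True lets us call tape.gradient() several times on one tape.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # constants are not watched automatically
    y1 = step_1(x)
    y2 = step_2(y1)

grad1 = tape.gradient(y1, x)  # 2x
grad2 = tape.gradient(y2, x)  # 10x
del tape  # release the resources held by the persistent tape
print("Gradient for step 1", grad1.numpy())
print("Gradient for step 2", grad2.numpy())
```

This is convenient when several targets share the same forward computation; the cost is that the tape keeps its recorded operations alive until it is deleted.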


For these standard functions, TF computes both gradients correctly. But what happens if one of the functions is not differentiable?

In [20]:
def bad_step_1(x):  # not differentiable everywhere
  if x > 5:
    return x**2
  else:
    return tf.constant(0.0)  # a new constant, not connected to x

This is loosely similar to ReLU: the output is zero for inputs below a threshold.

Now,

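$y_{2} = \begin{cases} 5x^{2} + 3 & \text{if } x > 5 \\ 3 & \text{otherwise} \end{cases}$

Note that, unlike ReLU, $y_2$ is discontinuous at $x=5$: it equals $3$ at $x=5$ but approaches $5\cdot5^{2}+3=128$ as $x\to5^{+}$. The left derivative is $0$, while the right-hand difference quotient grows without bound, so $y_2$ is not differentiable at $x=5$.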
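To see this numerically, the pure-Python sketch below mirrors `bad_step_1` followed by `step_2` in a single hypothetical helper `bad_y2`, evaluates it around $x=5$, and prints one-sided difference quotients for shrinking step sizes `h`.

```python
# Pure-Python look at y2 near x = 5 (mirrors bad_step_1 followed by step_2).

def bad_y2(x):
    y1 = x ** 2 if x > 5 else 0.0
    return 5 * y1 + 3

# The value jumps from 3 to roughly 128 across x = 5.
print(bad_y2(4.999), bad_y2(5.0), bad_y2(5.001))

# One-sided difference quotients at x = 5 for shrinking h.
for h in (1e-1, 1e-2, 1e-3):
    left = (bad_y2(5.0) - bad_y2(5.0 - h)) / h
    right = (bad_y2(5.0 + h) - bad_y2(5.0)) / h
    print(h, left, right)  # left stays 0, right keeps growing as h shrinks
```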

First, test with $x>5$, where the derivatives are $\frac{dy_{1}}{dx}=2x$ and $\frac{dy_{2}}{dx}=10x$:

In [21]:
# for x > 5
x = tf.constant(7.0) # y = 5 * 7^2 + 3 = 248
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
  tape1.watch(x)
  tape2.watch(x)
  y1 = bad_step_1(x)
  y2 = step_2(y1)
grad1 = tape1.gradient(y1, x)
grad2 = tape2.gradient(y2, x)
print("First step output ",y1.numpy())
print("Second step output ",y2.numpy())
print("Gradient for Step 1 ",grad1.numpy())
print("Gradient for Step 2 ",grad2.numpy())

First step output  49.0
Second step output  248.0
Gradient for Step 1  14.0
Gradient for Step 2  70.0


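For $x\le5$, however, TensorFlow returns no gradient at all (`None`). That branch returns a new constant that is not connected to `x`, so the tape finds no path from `y1` back to `x`. Mathematically the derivative for $x<5$ is $0$, but automatic differentiation cannot recover it here.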

In [22]:
# for x <= 5
x = tf.constant(2.0) # y = 5 * 0 + 3 = 3
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
  tape1.watch(x)
  tape2.watch(x)
  y1 = bad_step_1(x)
  y2 = step_2(y1)
grad1 = tape1.gradient(y1, x)
grad2 = tape2.gradient(y2, x)
print("First step output ",y1.numpy())
print("Second step output ",y2.numpy())
print("Gradient for Step 1 ",grad1)
print("Gradient for Step 2 ",grad2)

First step output  0.0
Second step output  3.0
Gradient for Step 1  None
Gradient for Step 2  None
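If a zero gradient is acceptable for $x\le5$, there are at least two standard ways to get $0$ instead of `None`. The sketch below, which is our own addition, (1) passes `unconnected_gradients=tf.UnconnectedGradients.ZERO` to `tape.gradient`, and (2) rewrites the step with `tf.where` so that `x` stays in the graph; `where_step_1` is a hypothetical name. Either way the gradient is zero, which still stops learning through this branch.

```python
import tensorflow as tf

def bad_step_1(x):
    if x > 5:
        return x ** 2
    return tf.constant(0.0)  # unconnected to x

def where_step_1(x):
    # Both branches are tensors built from x, so x stays in the graph.
    return tf.where(x > 5, x ** 2, tf.zeros_like(x))

x = tf.constant(2.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y_bad = bad_step_1(x)
    y_where = where_step_1(x)

# Option 1: report 0 instead of None for sources the target does not depend on.
g_bad = tape.gradient(y_bad, x,
                      unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Option 2: tf.where masks the x**2 branch, so its gradient is simply 0.
g_where = tape.gradient(y_where, x)
del tape
print(g_bad.numpy(), g_where.numpy())
```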


This is a problem, because the forward pass sometimes has to include operations whose gradients are undefined or useless for backpropagation. For such functions we can define a custom gradient ourselves. In this example we choose

$\frac{dy_{1}}{dx}=2x \quad \text{if } x>5$ and

$\frac{dy_{1}}{dx}=1 \quad \text{otherwise}$

This is not the true derivative in the mathematical sense. For $x\le5$ we pass the incoming gradient `dy` **straight through** unchanged, an approach similar to the **Straight-Through Estimator** (STE).
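Combining this choice with the chain rule through step 2, where $\frac{dy_2}{dy_1}=5$, predicts the gradients that the tests should report for $x=2$ and $x=7$:

$$
\frac{dy_2}{dx} = \frac{dy_2}{dy_1}\cdot\frac{dy_1}{dx}
= 5\cdot\begin{cases} 2x & \text{if } x>5 \\ 1 & \text{otherwise} \end{cases}
\quad\Rightarrow\quad
\left.\frac{dy_2}{dx}\right|_{x=2} = 5, \qquad
\left.\frac{dy_2}{dx}\right|_{x=7} = 70
$$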

In [25]:
@tf.custom_gradient
def bad_step_grad(x):  # custom gradient everywhere
  f = bad_step_1(x)  # this is what the function does in the forward pass
  def custom_grad(dy):  # dy: gradient flowing in from later steps
    if x > 5:
      g = 2*x * dy  # same as automatic differentiation
    else:
      g = dy  # our choice: pass the gradient straight through
    return g
  return f, custom_grad
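The `if x > 5` test above assumes `x` is a scalar. A vectorized variant, sketched below under the assumption that you want the same rule applied element-wise, replaces both branches with `tf.where`; `ste_step` is a hypothetical name. For `x = [2.0, 7.0]` it should reproduce the two scalar results (gradients 5 and 70).

```python
import tensorflow as tf

@tf.custom_gradient
def ste_step(x):
    """Element-wise version of bad_step_grad (hypothetical helper)."""
    mask = x > 5
    # Forward pass: x**2 where x > 5, else 0 (same as bad_step_1).
    y = tf.where(mask, x ** 2, tf.zeros_like(x))

    def grad(dy):
        # x > 5: true derivative 2x times upstream dy.
        # Otherwise: pass dy straight through (Straight-Through Estimator).
        return tf.where(mask, 2 * x * dy, dy)

    return y, grad

x = tf.constant([2.0, 7.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = 5 * ste_step(x) + 3  # step_2 applied element-wise
g = tape.gradient(y, x)
print(y.numpy(), g.numpy())
```

Using `tf.where` in both the forward and backward functions also avoids Python control flow on tensor values.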

Time to test our custom gradient, first for $x\le5$:

In [27]:
# for x <= 5
x = tf.constant(2.0) # y = 5 * 0 + 3 = 3
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
  tape1.watch(x)
  tape2.watch(x)
  y1 = bad_step_grad(x)
  y2 = step_2(y1)
grad1 = tape1.gradient(y1, x)
grad2 = tape2.gradient(y2, x)
print("First step output ",y1.numpy())
print("Second step output ",y2.numpy())
print("Gradient for Step 1 ",grad1.numpy())
print("Gradient for Step 2 ",grad2.numpy())

First step output  0.0
Second step output  3.0
Gradient for Step 1  1.0
Gradient for Step 2  5.0


For $x>5$ it behaves exactly like the automatic gradient:

In [28]:
# for x > 5
x = tf.constant(7.0) # y = 5 * 7^2 + 3 = 248
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
  tape1.watch(x)
  tape2.watch(x)
  y1 = bad_step_grad(x)
  y2 = step_2(y1)
grad1 = tape1.gradient(y1, x)
grad2 = tape2.gradient(y2, x)
print("First step output ",y1.numpy())
print("Second step output ",y2.numpy())
print("Gradient for Step 1 ",grad1.numpy())
print("Gradient for Step 2 ",grad2.numpy())

First step output  49.0
Second step output  248.0
Gradient for Step 1  14.0
Gradient for Step 2  70.0


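This demonstration gives some insight into gradient computation in TF2. It also shows why ReLU-like functions that are not differentiable everywhere may need special handling in backpropagation, and how `tf.custom_gradient` can implement a **Straight-Through Estimator**.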