## 1. Simple Computational Graph Example with PyTorch

### 1.1 Build the computational graph as shown in the figure

In [1]:
import torch

class SimpleGraph(torch.nn.Module):
    """Two-weight computational graph: L = sum(w1 * a + w2 * a).

    Both weights are single-element learnable parameters, initialised to
    5 (w1) and 6 (w2).
    """

    def __init__(self):
        super().__init__()
        # Learnable weights; Parameter tracks gradients by default.
        self.w1 = torch.nn.Parameter(torch.tensor([5.0]))
        self.w2 = torch.nn.Parameter(torch.tensor([6.0]))

    def forward(self, a):
        """Return the scalar loss L for input tensor `a`.

        Graph nodes: b = w1 * a, c = w2 * a, d = b + c, L = d.sum().
        """
        return (self.w1 * a + self.w2 * a).sum()

In [2]:
net = SimpleGraph()
a = torch.tensor([5.0])
out = net(a)

# Show every parameter next to its gradient.
# No backward pass has run yet, so each .grad is still None.
for name, weight in net.named_parameters():
    print(f"The Gradient of Loss w.r.t {name}={weight.data.item()} -> {weight.grad}")

The Gradient of Loss w.r.t w1=5.0 -> None
The Gradient of Loss w.r.t w2=6.0 -> None


### Compute the derivatives and print the parameters again

In [3]:
# Backpropagate once without keeping the graph, then report each weight and its gradient.
out.backward(retain_graph=False)
for name, weight in net.named_parameters():
    value, grad = weight.data.item(), weight.grad.item()
    print(f"The Gradient of Loss w.r.t {name}={value} -> {grad}")

The Gradient of Loss w.r.t w1=5.0 -> 5.0
The Gradient of Loss w.r.t w2=6.0 -> 5.0


In [5]:
# Call backward a second time on the same output.
# This is expected to raise; the error is caught and its message displayed.
try:
    out.backward()
except RuntimeError as err:
    print(f'Error : \n{err}')

Error : 
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.


In [6]:
# By default the intermediate values saved for backward are freed after the
# first backward pass (to save memory), so a second call raises an error.
# Passing `retain_graph=True` keeps them, allowing further backward calls.
net = SimpleGraph()
a = torch.tensor([5.0])
out = net(a)

out.backward(retain_graph=True)
for name, weight in net.named_parameters():
    value, grad = weight.data.item(), weight.grad.item()
    print(f"The Gradient of Loss w.r.t {name}={value} -> {grad}")

The Gradient of Loss w.r.t w1=5.0 -> 5.0
The Gradient of Loss w.r.t w2=6.0 -> 5.0


In [7]:
# Run backward once more on the retained graph and inspect the gradients.
out.backward(retain_graph=True)
for name, weight in net.named_parameters():
    current_grad = weight.grad.item()
    print(f"The Gradient of Loss w.r.t {name}={weight.data.item()} -> {current_grad}")

# The new gradients were added to the ones already stored in .grad.

The Gradient of Loss w.r.t w1=5.0 -> 10.0
The Gradient of Loss w.r.t w2=6.0 -> 10.0


In [8]:
# A further backward call (in its own cell) adds to the stored gradients yet again.
out.backward(retain_graph=True)
for entry in net.named_parameters():
    name, weight = entry[0], entry[1]
    print(f"The Gradient of Loss w.r.t {name}={weight.data.item()} -> {weight.grad.item()}")

The Gradient of Loss w.r.t w1=5.0 -> 15.0
The Gradient of Loss w.r.t w2=6.0 -> 15.0


### 1.2 Plot our Computational Graph

In [12]:
from torchviz import make_dot
from pathlib import Path

# Directory that receives the rendered graph images (created when missing).
fig_save = Path('./test_comp_graphs/')
fig_save.mkdir(parents=True, exist_ok=True)

# Build the graph for `out`, handing torchviz the named parameters, and render it as a PNG.
named_params = dict(net.named_parameters())
make_dot(out, params=named_params).render(fig_save / "simple_graph", format="png")

'test_comp_graphs/simple_graph.png'

## 2. Computational Graph in a Loop

In [26]:
import torch

class SimpleRecurrentGraph(torch.nn.Module):
    """Toy recurrent graph that reuses one weight at every step.

    For each input x_t in the sequence the state is updated as
    h_t = w1 * h_{t-1} * x_t, starting from the constant h_0 = 1.
    """

    def __init__(self):
        super().__init__()
        # === Initialize Weights === #
        self.w1 = torch.nn.Parameter(data=torch.Tensor([2]), requires_grad=True)
        # === Initialize first hidden input as 1 === #
        # Constant initial state: forward() reads it but never overwrites it.
        self.hidden = torch.tensor([1], requires_grad=False)

    def forward(self, a):
        """Run the recurrence over `a` and return the final hidden state.

        Args:
            a: iterable of tensors, consumed in order (one per step).

        Returns:
            The hidden state after the last step. For an empty `a` this is
            the unchanged initial state `self.hidden`.
        """
        # Keep the running state local. Storing it back on `self` would make
        # the next forward() start from a tensor tied to the previous
        # (possibly already freed) graph instead of from the initial state.
        hidden = self.hidden
        for inp in a:
            hidden = self.w1 * hidden * inp
        return hidden

In [27]:
net = SimpleRecurrentGraph()
a = [torch.Tensor([step]) for step in (1, 2, 3)]
out = net(a)

print(f'Output is : {out}')

# Show every parameter next to its gradient.
# No backward pass has run yet, so each .grad is still None.
for name, weight in net.named_parameters():
    print(f"The Gradient of Loss w.r.t {name}={weight.data.item()} -> {weight.grad}")

Output is : tensor([48.], grad_fn=<MulBackward0>)
The Gradient of Loss w.r.t w1=2.0 -> None


In [28]:
# Backpropagate from `out` so that .grad is populated on the parameters it was computed from.
out.backward()

In [29]:
# Report each parameter's current value together with its gradient tensor.
for name, weight in net.named_parameters():
    value = weight.data.item()
    print(f"The Gradient of Loss w.r.t {name}={value} -> {weight.grad}")

The Gradient of Loss w.r.t w1=2.0 -> tensor([72.])


#### Plot the graph

In [31]:
from torchviz import make_dot
from pathlib import Path

# Directory that receives the rendered graph images (created when missing).
fig_save = Path('./test_comp_graphs/')
fig_save.mkdir(parents=True, exist_ok=True)

# Build the graph for `out`, handing torchviz the named parameters, and render it as a PNG.
named_params = dict(net.named_parameters())
make_dot(out, params=named_params).render(fig_save / "simple_recurrent_graph", format="png")

'test_comp_graphs/simple_recurrent_graph.png'