In [1]:
import torch

# Graph

In [2]:
a = torch.tensor([2.], requires_grad=True)
b = torch.tensor([1.], requires_grad=True)

In [3]:
c = a + b
d = b + 1
e = c * d

In [4]:
print(type(a))
print(type(b))
print(type(c))
print(type(d))
print(type(e))

<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>


In [5]:
# grads populated for non-leaf nodes
c.retain_grad()
d.retain_grad()
e.retain_grad()

In [6]:
print(e)

tensor([6.], grad_fn=<MulBackward0>)


In [7]:
e.backward()

In [8]:
print(a.grad)

tensor([2.])


In [9]:
print(a.grad, b.grad, c.grad, d.grad, e.grad)

tensor([2.]) tensor([5.]) tensor([2.]) tensor([3.]) tensor([1.])


# require_grad vs retain_grad

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(2, 3)  # Input layer (2 features) -> Hidden layer (3 neurons)
        self.fc2 = nn.Linear(3, 1)  # Hidden layer (3 neurons) -> Output layer (1 neuron)

    def forward(self, x):
        x = self.fc1(x)  # First layer (non-leaf tensor)
        x = torch.relu(x)  # Apply ReLU activation (non-leaf tensor)
        x = self.fc2(x)  # Second layer (non-leaf tensor)
        return x

# Create the network
net = SimpleNet()

# Define a loss function and an optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# Create input and target tensors
inputs = torch.tensor([[1.0, 2.0]], requires_grad=True)  # Input is a leaf tensor
target = torch.tensor([[0.5]])  # Target output (no need for gradient tracking)

# Perform a forward pass
output = net(inputs)  # Forward pass through the network (output is a non-leaf tensor)
loss = criterion(output, target)  # Compute loss (non-leaf tensor)

# Perform backward pass
loss.backward()

# Let's inspect the leaf and non-leaf tensors
print("Is inputs a leaf tensor?", inputs.is_leaf)  # Should print True
print("Is output a leaf tensor?", output.is_leaf)  # Should print False
print("Is fc1's weight a leaf tensor?", net.fc1.weight.is_leaf)  # Should print True (weights are leaf tensors)
print("Gradient for inputs:", inputs.grad)  # Gradients should be calculated for leaf tensors

Is inputs a leaf tensor? True
Is output a leaf tensor? False
Is fc1's weight a leaf tensor? True
Gradient for inputs: tensor([[1.2854, 1.5133]])
