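import torch
import torch.autograd as autograd

# ReLU(x) = max(0, x), written as a custom autograd Function
class ReLU(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # backward needs the input to know where x > 0
        # return torch.max(x, torch.tensor(0.0)) # 0-dim tensor, broadcasts 0.0 to perform the elementwise comparison
        return torch.max(x, torch.zeros_like(x)) # matches x's shape, device, and dtype, so no broadcasting is needed

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # dLoss/dx = dLoss/d(output) * d(output)/dx, where d(output)/dx = 1 if x > 0 else 0
        # (no clone needed: nothing is modified in place)
        return grad_output * (x > 0)

# Sanity check: with Loss = sum of outputs, dLoss/dx is 1 where x > 0 and 0 elsewhere
x = torch.randn(5, requires_grad=True)
ReLU.apply(x).sum().backward()
assert torch.equal(x.grad, (x > 0).float())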

# Chain Rule Conceptually:
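- When `loss.backward()` is called, ReLU's `backward` receives `grad_output` from the layer after it.
- So, `grad_output = dLoss/d(ReLU output)`.
- `grad_ReLU_input` is what we want to calculate, since it becomes the `grad_output` of the previous layer; this is how the gradient "back propagates".
    - `grad_ReLU_input = dLoss/d(ReLU input)` -> use this as the LHS reference point.

- By the chain rule, dLoss/d(ReLU input) = dLoss/d(ReLU output) * d(ReLU output)/d(ReLU input)
    - We know dLoss/d(ReLU output) from `grad_output`; we need to calculate d(ReLU output)/d(ReLU input).
    - d(ReLU output)/d(ReLU input) = 0 if input < 0, 1 if input > 0. It is undefined at exactly 0; the `backward` above returns 0 there because the `> 0` comparison is False.
        - This comes from differentiating each branch of max(0, x) by hand; autodifferentiation applies the same rule for you.
    - ReLU acts on each element independently, so for tensors this product is elementwise.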

- The local derivative is specifically the gradient of one operation (one edge) with respect to its direct inputs.

```text
    [x]  ← input node
      |
    [ReLU] ← operation (edge)
      |
     [y]  ← output node
```
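The chain rule for this single ReLU edge can be checked numerically. A minimal sketch, assuming a hand-picked input `x` and a hypothetical upstream gradient `grad_output` (neither comes from the original): the gradient autograd computes for `torch.relu` should equal `grad_output * (x > 0)`.

```python
import torch

# ReLU input with negative and positive entries (hand-picked for illustration)
x = torch.tensor([-2.0, -0.5, 0.5, 3.0], requires_grad=True)
y = torch.relu(x)  # ReLU output

# Hypothetical upstream gradient dLoss/d(ReLU output), as if sent back by the next layer
grad_output = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Autograd: backpropagate grad_output through the ReLU edge to get dLoss/d(ReLU input)
y.backward(grad_output)

# Manual chain rule: upstream gradient times the local derivative (1 where x > 0, else 0)
local_grad = (x.detach() > 0).float()
manual_grad_input = grad_output * local_grad

print(x.grad)
print(manual_grad_input)
```

Both tensors keep the upstream value where the input was positive and are 0 where it was negative: the negative entries block the gradient completely.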

```text
[input] → [Linear] → [a] → [ReLU] → [b] → [Linear] → [Loss]
```
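The network in this diagram can be built directly in PyTorch to see the gradient that lands on each tensor. This is a sketch under assumptions not in the diagram: the layer sizes (4 → 3 → 1) are made up, and since the diagram goes straight from the last Linear to the Loss, the Loss is taken to be that Linear's single output (`.sum()` only turns the 1×1 result into a 0-dim scalar). `retain_grad()` is needed because `a` and `b` are intermediate (non-leaf) tensors, whose `.grad` PyTorch does not keep by default.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 4 input features, 3 hidden units, 1 output
linear1 = nn.Linear(4, 3)
linear2 = nn.Linear(3, 1)

inp = torch.randn(1, 4, requires_grad=True)  # [input]
a = linear1(inp)                              # [a]
b = torch.relu(a)                             # [b]
loss = linear2(b).sum()                       # [Loss]: the single output as a 0-dim scalar

# a and b are non-leaf tensors, so keep their .grad explicitly
a.retain_grad()
b.retain_grad()

loss.backward()  # starts from dLoss/dLoss = 1 and works right to left

print("dLoss/db:", b.grad)        # same shape as b: (1, 3)
print("dLoss/da:", a.grad)        # same shape as a: (1, 3)
print("dLoss/dinput:", inp.grad)  # same shape as input: (1, 4)
```

Each gradient has the same shape as its tensor. Here `b.grad` equals the last Linear's weight, and `a.grad` is `b.grad` with zeros wherever `a` was not positive.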

## BackProp Sequential Process
     
- The goal is to get the gradient of the Loss at each tensor ([b], [a], [input]). The Loss is a scalar and the starting point, so its gradient with respect to itself is always 1 (dLoss/dLoss = 1).
- We care about dLoss/db, dLoss/da, and dLoss/dinput because they are the gradients of the corresponding tensors.
   - We use the local gradient of each operation (Linear, ReLU, Linear) to calculate them:
      - Local Gradient 3 = dLinear/db (the last Linear's output with respect to its input b)
      - Local Gradient 2 = dReLU/da
      - Local Gradient 1 = dLinear/dinput (the first Linear's output with respect to its input)
- The gradient that is always passed backward is dLoss/d(input of the current operation), which becomes dLoss/d(output of the previous operation) for the previous layer.
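To make "local gradient" concrete, each operation's local gradient can be computed as a Jacobian with `torch.autograd.functional.jacobian`, and the three multiplied together. A sketch on a single unbatched sample with made-up layer sizes; as before, the Loss is assumed to be the last Linear's single output.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# Hypothetical sizes, single unbatched sample: input (4,) -> a (3,) -> b (3,) -> Loss (scalar)
linear1 = nn.Linear(4, 3)
linear2 = nn.Linear(3, 1)
inp = torch.randn(4)

with torch.no_grad():
    a = linear1(inp)
    b = torch.relu(a)

# Local gradients: each Jacobian looks at one operation and its direct input only
J3 = jacobian(lambda t: linear2(t).sum(), b)  # dLoss/db, shape (3,) because Loss is a scalar
J2 = jacobian(torch.relu, a)                  # db/da, shape (3, 3), diagonal of 0s and 1s
J1 = jacobian(linear1, inp)                   # da/dinput, shape (3, 4), equals linear1.weight

# Chain rule: multiply the local gradients, starting from the Loss end
chained = J3 @ J2 @ J1                        # dLoss/dinput, shape (4,)

# Reference: autograd on the full forward pass
x = inp.clone().requires_grad_(True)
linear2(torch.relu(linear1(x))).sum().backward()
print(torch.allclose(chained, x.grad))
```

Building full Jacobians like this is only for illustration. Backprop never forms them; it passes one `grad_output` tensor from layer to layer and multiplies it by each local gradient in turn.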

When `loss.backward()` is called, backprop works **layer by layer, right to left**, and the multiplication happens **incrementally**, not all at once. Remember: always pass a derivative with "dLoss" as numerator to the previous layer. 

### Step-by-Step:

1. **Start**: `∂Loss/∂Loss = 1` (implicit) always.

2. **Last Linear backward** (right to left):
   - Receives: `grad_output = 1` (dLoss/dLoss)
   - Computes: `local grad 3 = ∂Loss/∂b` (in this diagram, the last Linear's output is the Loss)
   - Returns: `grad_input = 1 × local grad 3 = ∂Loss/∂b`; by the chain rule, dLoss/db = dLoss/d(Linear output) × d(Linear output)/db

3. **ReLU backward** (your custom function):
   - Receives: `grad_output = ∂Loss/∂b` (b is the ReLU output)
   - Computes: `local grad 2 = ∂b/∂a` (dReLU/da: 1 where a > 0, 0 elsewhere)
   - Returns: `grad_input = (∂Loss/∂b) × (∂b/∂a) = ∂Loss/∂a`, i.e. dLoss/da = dLoss/db × db/da

4. **First Linear backward**:
   - Receives: `grad_output = ∂Loss/∂a`
   - Computes: `local grad 1 = ∂a/∂input`
   - Returns: `grad_input = (∂Loss/∂a) × (∂a/∂input) = ∂Loss/∂input` (for a Linear layer `a = input @ W.T + bias`, this "×" is the matrix product `grad_a @ W`)
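The four steps can be reproduced by hand, passing one gradient tensor right to left, and compared with what autograd returns via `torch.autograd.grad`. A minimal sketch, assuming hypothetical layer sizes, a batch of 2, and a Loss defined as the sum of the last Linear's outputs (so step 1 seeds a tensor of ones, one per output). The names `grad_out`, `grad_b`, `grad_a`, and `grad_inp` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: [input] -> Linear(4, 3) -> [a] -> ReLU -> [b] -> Linear(3, 1) -> [Loss]
linear1 = nn.Linear(4, 3)
linear2 = nn.Linear(3, 1)

inp = torch.randn(2, 4, requires_grad=True)  # batch of 2 samples
a = linear1(inp)
b = torch.relu(a)
out = linear2(b)   # shape (2, 1)
loss = out.sum()   # assumption: Loss = sum of the last Linear's outputs

# Reference: autograd's dLoss/d(tensor) for each tensor in the diagram
auto_inp, auto_a, auto_b = torch.autograd.grad(loss, (inp, a, b))

with torch.no_grad():
    # Step 1: seed; dLoss/d(out) is 1 for every entry because Loss is their sum
    grad_out = torch.ones_like(out)

    # Step 2: last Linear, out = b @ W2.T + bias2, so dLoss/db = grad_out @ W2
    grad_b = grad_out @ linear2.weight

    # Step 3: ReLU, local gradient is 1 where a > 0 else 0, so dLoss/da = dLoss/db * (a > 0)
    grad_a = grad_b * (a > 0)

    # Step 4: first Linear, a = inp @ W1.T + bias1, so dLoss/dinput = grad_a @ W1
    grad_inp = grad_a @ linear1.weight

print(torch.allclose(grad_b, auto_b), torch.allclose(grad_a, auto_a), torch.allclose(grad_inp, auto_inp))
```

Each step only needs the gradient handed back by the step after it plus its own local gradient, which is why backprop can run incrementally instead of multiplying everything at once.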