In [31]:
import torch
import math


def f(x, i):
    """Toy differentiable op: ``sin(pi * x)`` plus the scalar sum of ``x**2``.

    The scalar term is broadcast over every element, so each output element
    depends on all of ``x``. A trace line tagged with step index ``i`` is
    printed so that (re)computations show up in the cell output.
    """
    print(f"f[{i}] func")
    wave = (x * math.pi).sin()
    sq_total = x.pow(2).sum()
    return wave + sq_total
def g(x, f, i):
    """Toy differentiable op combining ``x`` with a precomputed tensor ``f``.

    Returns ``cos(pi * x) + 2 * f + mean(x)``. Here ``f`` is a tensor
    argument, not the module-level function of the same name. A trace line
    tagged with step index ``i`` is printed for every call.
    """
    print(f"g[{i}] func")
    wave = (x * math.pi).cos()
    return wave + 2 * f + x.mean()


class fn(torch.autograd.Function):
    """Custom autograd op wrapping ``f`` that recomputes ``f`` in backward.

    Only the input ``x`` is saved in forward. ``backward`` rebuilds the graph
    of ``f`` on a detached copy of ``x`` (checkpoint-style) and differentiates
    it with ``torch.autograd.grad``.
    """

    @staticmethod
    def forward(ctx, x, i):
        """Evaluate ``f(x, i)``.

        Args:
            x: input tensor; saved for the backward recomputation.
            i: non-tensor step index, used only for trace prints.
        """
        print(f'fn[{i}] forward')
        ctx.save_for_backward(x)
        y = f(x, i)
        ctx.i = i
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Recompute ``f`` and return the vector-Jacobian product wrt ``x``.

        Returns:
            ``(grad_x, None)``. The ``None`` is for the non-tensor ``i``.
        """
        print(f'fn[{ctx.i}] backward')
        x, = ctx.saved_tensors
        i = ctx.i

        # Grad mode is needed only for the recomputation. The context manager
        # restores whatever mode was active before (which may be True under
        # create_graph=True) even if f raises. A bare
        # set_grad_enabled(False) afterwards would clobber it.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_()
            y = f(x_, i)
            # The recomputed graph is local and used exactly once, so there
            # is no need to retain it.
            grad, = torch.autograd.grad(y, x_, grad_outputs=grad_output)
        return grad, None
    
class gn(torch.autograd.Function):
    """Custom autograd op wrapping ``g`` that recomputes ``g`` in backward.

    Both tensor inputs ``x`` and ``f`` are saved. ``backward`` rebuilds the
    graph of ``g`` on detached copies and differentiates it with
    ``torch.autograd.grad`` wrt both of them.
    """

    @staticmethod
    def forward(ctx, x, f, i):
        """Evaluate ``g(x, f, i)``.

        Args:
            x: input tensor.
            f: tensor argument to ``g`` (in the driver cell, the output of ``fn``).
            i: non-tensor step index, used only for trace prints.
        """
        print(f'gn[{i}] forward')
        ctx.save_for_backward(x, f)
        y = g(x, f, i)
        ctx.i = i
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Recompute ``g`` and return vector-Jacobian products wrt ``x`` and ``f``.

        Returns:
            ``(grad_x, grad_f, None)``. The ``None`` is for the non-tensor ``i``.
        """
        print(f'gn[{ctx.i}] backward')
        x, f_saved = ctx.saved_tensors
        i = ctx.i

        # Scoped grad mode: restores the caller's mode on exit or on error
        # instead of forcing it to False afterwards.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_()
            f_ = f_saved.detach().requires_grad_()
            y = g(x_, f_, i)
            # Local one-shot graph, so there is no need to retain it.
            grad_x, grad_f = torch.autograd.grad(
                y, (x_, f_), grad_outputs=grad_output
            )
        return grad_x, grad_f, None


# Call-style aliases: custom autograd Functions are invoked through .apply.
fn_, gn_ = fn.apply, gn.apply

In [32]:

# Evaluate fn only every F_REFRESH_EVERY steps and reuse its output for the
# gn steps in between. In the backward trace, each fn node runs only after
# all the gn nodes that consumed its output.
N_STEPS = 8
F_REFRESH_EVERY = 4

x = torch.linspace(-1, 1, 15, requires_grad=True, dtype=torch.float64)

gacc = torch.zeros_like(x)
for i in range(N_STEPS):
    if i % F_REFRESH_EVERY == 0:
        # Step 0 always takes this branch, so fyi is bound before first use.
        fyi = fn_(x, i)
    gacc = gacc + gn_(x, fyi, i)

gy = torch.sum(gacc)
gy.backward()

fn[0] forward
f[0] func
gn[0] forward
g[0] func
gn[1] forward
g[1] func
gn[2] forward
g[2] func
gn[3] forward
g[3] func
fn[4] forward
f[4] func
gn[4] forward
g[4] func
gn[5] forward
g[5] func
gn[6] forward
g[6] func
gn[7] forward
g[7] func
gn[7] backward
g[7] func
gn[6] backward
g[6] func
gn[5] backward
g[5] func
gn[4] backward
g[4] func
fn[4] backward
f[4] func
gn[3] backward
g[3] func
gn[2] backward
g[2] func
gn[1] backward
g[1] func
gn[0] backward
g[0] func
fn[0] backward
f[0] func


In [7]:
# Verify the hand-written backward passes of fn / gn against finite
# differences. The step index is not a tensor, so it is bound in a lambda and
# gradcheck only perturbs the float64 tensor inputs. Fresh leaf copies keep
# the check independent of the graph and .grad left by the previous cell.
x_check = x.detach().clone().requires_grad_()
f_check = torch.linspace(
    0, 1, x_check.numel(), dtype=torch.float64, requires_grad=True
)

gradcheck_kwargs = dict(eps=1e-6, atol=1e-4, rtol=1e-2, raise_exception=True)
fn_ok = torch.autograd.gradcheck(lambda t: fn_(t, 0), (x_check,), **gradcheck_kwargs)
gn_ok = torch.autograd.gradcheck(
    lambda t, u: gn_(t, u, 0), (x_check, f_check), **gradcheck_kwargs
)
fn_ok and gn_ok

[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] forward
[func]
[fn_a] for

True

[fn_a] forward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] backward
[fn_a] ba