In [None]:
import torch

In [None]:
# Cache of zero-size "phony" tensors; get_phony fills it lazily, keyed by (device, requires_grad).
phonies = {}

In [None]:
# Show the cache contents; nothing has been stored yet, so an empty dict is displayed.
phonies

{}

In [None]:
def get_phony(device, requires_grad):
    """Return the cached zero-size "phony" tensor for ``(device, requires_grad)``.

    The tensor is created on the first request and stored in the module-level
    ``phonies`` dict, so later calls with the same arguments return the very
    same tensor object.

    Args:
        device: Device the phony tensor is allocated on. It is used both as
            part of the cache key and as the ``device`` argument of
            ``torch.empty``, so it must be hashable and accepted by
            ``torch.empty`` (for example a ``torch.device``).
        requires_grad: Whether the phony tensor should require gradients.

    Returns:
        A tensor with zero elements allocated on ``device``.
    """
    key = (device, requires_grad)

    try:
        # Reuse the tensor created by an earlier call with the same key.
        phony = phonies[key]
    except KeyError:
        # Allocate on the requested device. Without the explicit ``device``
        # argument every entry would be created on the default device, no
        # matter which device its cache key names.
        phony = torch.empty(0, device=device, requires_grad=requires_grad)
        phonies[key] = phony

    return phony

In [None]:
# torch.device describes where a tensor lives. torch.cuda.device, by contrast,
# is a context manager for switching the current GPU and is not a valid value
# for a tensor's `device` argument. Fall back to the CPU so the notebook also
# runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# First request for the (device, True) key: the phony is created, cached, and returned.
get_phony(device, True)

tensor([], requires_grad=True)

In [None]:
# The cache now holds an entry under the (device, True) key.
phonies[(device, True)]

tensor([], requires_grad=True)

In [None]:
class Fork(torch.autograd.Function):
    """Autograd function returning its input detached, paired with a phony tensor."""

    @staticmethod
    def forward(ctx, input):
        """Return ``(input.detach(), phony.detach())``.

        The phony is looked up via ``get_phony`` for the device of ``input``.
        """
        marker = get_phony(input.device, requires_grad=False)
        detached_input = input.detach()
        return detached_input, marker.detach()

    @staticmethod
    def backward(ctx, grad_input, grad_grad):
        """Pass the first output's gradient straight through.

        ``grad_grad`` (the gradient arriving for the phony output) is ignored.
        """
        return grad_input

In [None]:
class Join(torch.autograd.Function):
    """Autograd function taking an input and a phony and returning the input detached."""

    @staticmethod
    def forward(ctx, input, phony):
        """Return ``input.detach()``; ``phony`` does not contribute to the value."""
        result = input.detach()
        return result

    @staticmethod
    def backward(ctx, grad_input):
        """Route the incoming gradient to ``input`` and ``None`` to ``phony``."""
        return (grad_input, None)

In [None]:
def fork(input):
    """Return ``(tensor, phony)`` for ``input``.

    When gradient recording is enabled and ``input`` requires grad, both values
    come from ``Fork.apply``. Otherwise ``input`` is returned unchanged together
    with the phony obtained from ``get_phony`` for its device.
    """
    tracking = torch.is_grad_enabled() and input.requires_grad
    if not tracking:
        # Nothing to record: hand the input back untouched with a phony.
        return input, get_phony(input.device, requires_grad=False)

    forked, phony = Fork.apply(input)
    return forked, phony

In [None]:
def join(input, phony):
    """Return ``input``, passed through ``Join.apply`` when gradients are in play.

    ``Join.apply(input, phony)`` is used only if gradient recording is enabled
    and at least one of ``input`` / ``phony`` requires grad; otherwise ``input``
    is returned as is.
    """
    if not torch.is_grad_enabled():
        return input
    if input.requires_grad or phony.requires_grad:
        return Join.apply(input, phony)
    return input

In [None]:
def depend(fork_from, join_to):
    """Fork ``fork_from``, join the resulting phony into ``join_to``, and return both.

    Args:
        fork_from: Tensor passed to ``fork``.
        join_to: Tensor passed to ``join`` together with the phony that
            ``fork`` returned.

    Returns:
        Tuple ``(fork_from, join_to)`` holding the tensors returned by ``fork``
        and ``join``. The arguments are not modified in place, so callers that
        want the result must use the returned tensors, e.g.
        ``x, y = depend(x, y)``.
    """
    fork_from, phony = fork(fork_from)
    join_to = join(join_to, phony)
    # Assigning to the parameter names only rebinds locals; hand the new
    # tensors back so the work done above is not discarded.
    return fork_from, join_to

In [None]:
# Two independent leaf tensors that track gradients.
x = torch.tensor([1.0, 2.0, 3.0]).requires_grad_()
y = torch.tensor([4.0, 5.0, 6.0]).requires_grad_()

In [None]:
# The return value is not captured, so x and y below are still the original leaf tensors.
depend(x, y)

In [None]:
# Reduce y to a scalar and backpropagate from it; this populates y.grad.
y.sum().backward()

In [None]:
# Gradient of y.sum() with respect to y: one per element.
y.grad

tensor([1., 1., 1.])

In [None]:
# x took no part in computing y.sum(), so no gradient was accumulated for it (the cell shows no output).
x.grad

In [None]:
class Operation(torch.autograd.Function):
    """Toy autograd function that adds 1 to its input."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input + 1``."""
        return input + 1

    @staticmethod
    def backward(ctx, grad_input):
        """Return the gradient with respect to ``input``.

        The derivative of ``input + 1`` with respect to ``input`` is 1, so the
        incoming gradient passes through unchanged. ``forward`` takes a single
        input, so exactly one gradient is returned.
        """
        return grad_input

In [None]:
# Random 2x3 leaf tensor that tracks gradients.
batch_1 = torch.randn((2, 3)).requires_grad_()

In [None]:
# No backward pass has run for batch_1 yet, so its grad is still None and the cell shows no output.
batch_1.grad

In [None]:
# Invoke the custom function through .apply (not by calling forward directly).
output = Operation.apply(batch_1)

In [None]:
# Reduce the output to a scalar and backpropagate from it.
output.sum().backward()