In [13]:
import torch
import torch._functorch.config
import torch_xla
from functorch.compile import aot_function, default_partition, make_boxed_func, min_cut_rematerialization_partition  # type: ignore


@torch.library.custom_op("xla::name_tensor", mutates_args=())
def name_tensor(t: torch.Tensor, name: str) -> torch.Tensor:
  if t is None:
    return None
  return t.clone()


@name_tensor.register_fake
def _(t: torch.Tensor, name: str) -> torch.Tensor:
  if t is None:
    return None
  return torch.empty_like(t)


def name_tensor_backward(ctx, grad):
  return grad, None


name_tensor.register_autograd(name_tensor_backward)


In [14]:
# Eager call outside any tracing: the op just returns a copy, so the values are unchanged.
a = name_tensor(torch.zeros(10), "foo")
a

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [15]:
# Trace `my_fn` with AOTAutograd and capture the forward and backward graphs it
# produces. aot_function / make_boxed_func are imported in the first cell.
a = torch.ones(10, requires_grad=True)

def my_fn(x):
  """Function under test: only tags its input as "foo"."""
  return name_tensor(x, "foo")

# Filled by get_graph with one GraphModule per compiled graph.
graphs = []

def get_graph(gm: torch.fx.GraphModule, _):
  """Compiler callback for aot_function: record ``gm`` and run it unchanged.

  The second argument (presumably the example inputs) is unused.
  """
  graphs.append(gm)
  return make_boxed_func(gm)

c = aot_function(my_fn, get_graph)(a)
c.sum().backward()
# Only one compiler is passed, yet two graphs are captured after forward +
# backward: get_graph also receives the backward graph. Forward comes first.
assert len(graphs) == 2
fw: torch.fx.GraphModule
bw: torch.fx.GraphModule
fw, bw = graphs

In [16]:
# Python source that torch.fx generated for the captured forward graph.
print(fw.code)




def forward(self, primals_1):
    name_tensor = torch.ops.xla.name_tensor.default(primals_1, 'foo');  primals_1 = None
    return (name_tensor,)
    


In [17]:
# Same forward graph as a table of nodes (opcode, name, target, args, kwargs).
# NOTE(review): print_tabular relies on the optional `tabulate` package — confirm it is installed.
fw.graph.print_tabular()

opcode         name         target                   args                kwargs
-------------  -----------  -----------------------  ------------------  --------
placeholder    primals_1    primals_1                ()                  {}
call_function  name_tensor  xla.name_tensor.default  (primals_1, 'foo')  {}
output         output       output                   ((name_tensor,),)   {}


In [18]:
fw: torch.fx.GraphModule  # annotation only, for type checkers; fw is already bound
# Dump each forward-graph node's meta dict (e.g. the FakeTensor under "val" and "tensor_meta").
for fw_node in fw.graph.nodes:
  print(fw_node.name, fw_node.meta)


primals_1 {'val': FakeTensor(..., size=(10,)), 'tensor_meta': TensorMetadata(shape=torch.Size([10]), dtype=torch.float32, requires_grad=True, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
name_tensor {'original_aten': <OpOverload(op='xla.name_tensor', overload='default')>, 'seq_nr': 27, 'val': FakeTensor(..., size=(10,)), 'tensor_meta': TensorMetadata(shape=torch.Size([10]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
output {}


In [19]:
# Backward graph: name_tensor_backward passes the gradient through unchanged,
# so the incoming tangent is returned as-is.
print(bw.code)




def forward(self, tangents_1):
    return (tangents_1,)
    


## Now what?

This is similar to JAX's `checkpoint_name` (from `jax.ad_checkpoint`), which attaches a name to an intermediate value:

```python
x = checkpoint_name(g(W1, x), name='a')
```

Next, we can write a graph pass that walks the traced graph to recover the name attached to each tensor.
The graph partitioner can then use those names to decide whether to offload each named tensor.

In [20]:
from functorch.compile import min_cut_rematerialization_partition, default_partition, make_boxed_func  # type:ignore

# Replicate regular torch checkpointing here. The low budget forces the partitioner
# to recompute tensors instead of saving them.
import torch._functorch.config
torch._functorch.config.activation_memory_budget = 0.0
torch._functorch.config.aggressive_recomputation = True
torch._functorch.config.recompute_views = True
torch._functorch.config.ban_recompute_reductions = False
torch._functorch.config.ban_recompute_not_in_allowlist = False
torch._functorch.config.ban_recompute_materialized_backward = False
torch._functorch.config.ban_recompute_long_fusible_chains = False
torch._functorch.config.ban_recompute_used_far_apart = False


In [21]:
# Capture again, now with real compute after the tagged tensor and with the
# min-cut rematerialization partitioner, which the config cell above steers.
# my_fn / get_graph are intentionally redefined, and `graphs` is reset so this
# capture starts empty. (aot_function etc. are imported in the first cell.)
a = torch.ones(10, requires_grad=True)

def my_fn(x):
  """Tag the input as "foo", then compute sin(x)**2 + 3."""
  x = name_tensor(x, "foo")
  y = torch.sin(x)
  z = y * y
  w = z + 3
  return w

graphs = []

def get_graph(gm: torch.fx.GraphModule, _):
  """Compiler callback: record ``gm`` and run it unchanged."""
  graphs.append(gm)
  return make_boxed_func(gm)

c = aot_function(my_fn, get_graph, partition_fn=min_cut_rematerialization_partition)(a)
c.sum().backward()
assert len(graphs) == 2
# Annotate before unpacking instead of re-assigning each name to itself.
fw: torch.fx.GraphModule
bw: torch.fx.GraphModule
fw, bw = graphs

In [22]:
# Forward graph under the min-cut partitioner: note which tensors it returns for backward.
print(fw.code)




def forward(self, primals_1):
    name_tensor = torch.ops.xla.name_tensor.default(primals_1, 'foo')
    sin = torch.ops.aten.sin.default(name_tensor);  name_tensor = None
    mul = torch.ops.aten.mul.Tensor(sin, sin);  sin = None
    add = torch.ops.aten.add.Tensor(mul, 3);  mul = None
    return (add, primals_1)
    


In [23]:
# Backward graph: it recomputes name_tensor and sin from primals_1 instead of
# receiving those intermediates from the forward graph.
print(bw.code)




def forward(self, primals_1, tangents_1):
    name_tensor = torch.ops.xla.name_tensor.default(primals_1, 'foo');  primals_1 = None
    sin = torch.ops.aten.sin.default(name_tensor)
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, sin);  tangents_1 = sin = None
    add_1 = torch.ops.aten.add.Tensor(mul_1, mul_1);  mul_1 = None
    cos = torch.ops.aten.cos.default(name_tensor);  name_tensor = None
    mul_3 = torch.ops.aten.mul.Tensor(add_1, cos);  add_1 = cos = None
    return (mul_3,)
    


In [34]:
def get_named_nodes(gm: torch.fx.GraphModule):
  named_nodes = {}

  for node in gm.graph.nodes:
    if node.op == "call_function":
      if hasattr(node.target, "name"):
          if node.target.name() == name_tensor._qualname:  # type: ignore
            named_nodes[node.args[0]] = node.args[1]
  
  return named_nodes

# Tagged nodes of the forward graph: {node passed to name_tensor: tag}.
named_nodes = get_named_nodes(fw)
print(named_nodes)

def get_name_in_output_indices(gm: torch.fx.GraphModule):
  """Map each tag in ``gm`` to the position of its tagged tensor among the graph outputs.

  Tagged nodes come from ``get_named_nodes``. The function then scans the values
  returned by the graph's output node. For a forward graph from AOTAutograd those
  outputs include the tensors saved for backward, which is how a tagged input can
  show up here.

  Args:
    gm: Graph module to inspect.

  Returns:
    dict mapping tag ``str`` -> output index ``int``. Tags whose tensor is not
    returned are absent. If a tagged tensor is returned more than once, or two
    tagged tensors share a tag, the later output index wins.
  """
  named_nodes = get_named_nodes(gm)
  name_in_output_indices = {}

  for node in gm.graph.nodes:
    if node.op != "output":
      continue
    assert len(node.args) <= 1
    if not node.args:
      continue
    outputs = node.args[0]
    # An FX graph may return a single value instead of a tuple/list of values.
    if not isinstance(outputs, (tuple, list)):
      outputs = (outputs,)
    for i, arg in enumerate(outputs):
      if arg in named_nodes:
        name_in_output_indices[named_nodes[arg]] = i

  return name_in_output_indices

# For each tag, the index in fw's outputs at which the tagged tensor is returned.
name_in_output_indices = get_name_in_output_indices(fw)
print(name_in_output_indices)

{primals_1: 'foo'}
{'foo': 1}


In [35]:
# The backward graph re-runs name_tensor on its saved input (see bw.code),
# so the same tag is recoverable there too.
named_nodes = get_named_nodes(bw)
print(named_nodes)

{primals_1: 'foo'}


In [None]:
def get_name_in_input_names(gm: torch.fx.GraphModule):
  """Map each tag in ``gm`` to the name of the graph input (placeholder) it tags.

  This is the input-side counterpart of ``get_name_in_output_indices``: a backward
  graph that recomputes ``name_tensor`` from a saved input exposes the tag on
  that placeholder.

  Args:
    gm: Graph module to inspect.

  Returns:
    dict mapping tag ``str`` -> placeholder target (the input's name). If two
    tagged inputs share a tag, the later placeholder wins.
  """
  named_nodes = get_named_nodes(gm)
  return {
      named_nodes[node]: node.target
      for node in gm.graph.nodes
      if node.op == "placeholder" and node in named_nodes
  }

# Recompute the tags from bw itself instead of reading the global `named_nodes`,
# which holds whichever graph was analysed last (a stale fw mapping would
# silently yield {}).
name_in_input_names = get_name_in_input_names(bw)
print(name_in_input_names)

{'foo': 'primals_1'}
