## PyTorch Python and Native overrides

This notebook demonstrates `__torch_function__` and `__torch_dispatch__` and their
interaction with XLA:

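- `__torch_function__` may be activated by a `TorchFunctionMode` context manager,
  and intercepts calls to `torch.*` functions and `torch.Tensor` methods. It does not
  see the individual operations run inside `loss.backward()`, for example: the whole
  backward pass shows up as a single `torch.Tensor.backward` call.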
  
- `__torch_dispatch__` may be activated by a `TorchDispatchMode` context manager,
  and intercepts `aten` operations right before they get dispatched to a backend.
  It will see what `aten` operations are run by `loss.backward()`.
  
For more details, see https://pytorch.org/docs/stable/notes/extending.html

In [14]:
# Environment for the XLA section further down (the g() cells).
# NOTE(review): these are presumably read when torch_xla / its C++ runtime starts,
# so this cell should run before torch_xla is first imported (inside g()) — confirm.
# NOTE(review): XLA_IR_DEBUG=1 presumably attaches Python source info to IR nodes
# for debugging — confirm against the torch_xla troubleshooting docs.
%env XLA_IR_DEBUG=1

# Turn on INFO logs.
%env TF_CPP_MIN_LOG_LEVEL=0

# Turn on verbose INFO logs for these C++ source files (the log lines in the
# outputs below come from ir.cpp and xla_graph_executor.cpp).
# Optional: uncomment to raise verbosity globally instead of for selected files
# (presumably applies to every file, not just those in TF_CPP_VMODULE — confirm).
# %env TF_CPP_MAX_VLOG_LEVEL=8
%env TF_CPP_VMODULE=xla_graph_executor=8,ir=8
# Optional: uncomment to print PyTorch's dispatcher trace
# (NOTE(review): may require a debug build of PyTorch — confirm).
# %env TORCH_SHOW_DISPATCH_TRACE=1

env: XLA_IR_DEBUG=1
env: TF_CPP_MIN_LOG_LEVEL=0
env: TF_CPP_VMODULE=xla_graph_executor=8,ir=8


In [15]:
from typing import Any, List
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

def _describe_arg(arg: Any) -> str:
    """Describe one argument: tensors by shape, plain lists/tuples recursively, else repr."""
    if isinstance(arg, torch.Tensor):
        # Shape only -- never read the tensor's data.
        return f"tensor(shape={tuple(arg.shape)})"
    # Exact type checks on purpose: subclasses such as torch.Size or namedtuples
    # keep their own repr (e.g. "torch.Size([10])").
    if type(arg) is list:
        return "[" + ", ".join(_describe_arg(item) for item in arg) + "]"
    if type(arg) is tuple:
        inner = ", ".join(_describe_arg(item) for item in arg)
        # Mirror Python's repr for 1-tuples: "(x,)".
        return f"({inner},)" if len(arg) == 1 else f"({inner})"
    return repr(arg)


def arg_shapes(args: List[Any]) -> str:
    """
    Format call arguments for a log line without printing tensor contents.

    Every ``torch.Tensor`` is rendered as ``tensor(shape=(...))`` instead of its
    ``repr``. Plain ``list``/``tuple`` containers are walked recursively, so
    tensors nested inside them (e.g. the tensor list given to ``torch.cat``) are
    also rendered by shape. Every other value, including ``torch.Size``, is
    rendered with ``repr``, so containers without tensors print exactly as
    ``repr`` would print them.

    Avoiding ``repr`` on tensors keeps the logs short. On lazy backends such as
    XLA, reading a tensor's values presumably also forces pending work to run.

    Args:
        args: A list or tuple of positional arguments, e.g. the ``args`` tuple
            received by ``__torch_function__`` / ``__torch_dispatch__``.

    Returns:
        str: The per-argument descriptions joined with ``", "`` (``""`` if empty).
    """
    return ', '.join(_describe_arg(arg) for arg in args)


class FunctionLog(TorchFunctionMode):
    """TorchFunctionMode that prints each intercepted torch-API call, then runs it.

    Positional arguments and keyword-argument values are formatted with
    ``arg_shapes`` so tensors are shown by shape rather than by contents.
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Normalize once so the log line and the real call see the same kwargs
        # (this also logs "**{}" rather than "**None", matching DispatchLog).
        kwargs = kwargs or {}
        # resolve_name can return None for callables it does not know; fall back to repr.
        name = resolve_name(func) or repr(func)
        # Render kwargs like a dict repr, but with tensor values shown by shape.
        kwargs_desc = "{" + ", ".join(
            f"{key!r}: {arg_shapes([value])}" for key, value in kwargs.items()
        ) + "}"
        print(f"Function Log: {name}({arg_shapes(args)}, **{kwargs_desc})")
        return func(*args, **kwargs)

class DispatchLog(TorchDispatchMode):
    """TorchDispatchMode that prints each ``aten`` op just before it is dispatched, then runs it.

    At this level the ops autograd executes inside ``loss.backward()`` are visible
    too. Positional arguments and keyword-argument values are formatted with
    ``arg_shapes`` so tensors are shown by shape rather than by contents.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Normalize once so the log line and the real call see the same kwargs.
        kwargs = kwargs or {}
        # Render kwargs like a dict repr, but with tensor values shown by shape.
        kwargs_desc = "{" + ", ".join(
            f"{key!r}: {arg_shapes([value])}" for key, value in kwargs.items()
        ) + "}"
        print(f"Dispatch Log: {func}({arg_shapes(args)}, **{kwargs_desc})")
        return func(*args, **kwargs)

def f():
    """Run one tiny CPU forward/backward pass for the logging modes to trace.

    Forward computes ``sin(x * 2).sum()`` on a random length-10 tensor created
    with ``requires_grad=True``. It then prints ``"Backward"`` and back-propagates
    from the scalar loss. Returns ``None``.
    """
    x = torch.rand(10, requires_grad=True)
    # Same op order as before: mul -> sin -> sum (the loggers record this order).
    loss = torch.sin(x * 2).sum()
    print("Backward")
    loss.backward()


In [16]:

# f() under FunctionLog: only calls made through the torch API are logged, so
# loss.backward() appears as a single torch.Tensor.backward entry.
print("TorchFunctionMode logging:")
with FunctionLog():
    f()


TorchFunctionMode logging:
Function Log: torch.rand(10, **{'requires_grad': True})
Function Log: torch.Tensor.mul(tensor(shape=(10,)), 2, **None)
Function Log: torch.sin(tensor(shape=(10,)), **None)
Function Log: torch.Tensor.sum(tensor(shape=(10,)), **None)
Backward
Function Log: torch.Tensor.backward(tensor(shape=()), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})


In [17]:

# f() under DispatchLog: every aten op is logged, including those autograd runs
# inside loss.backward() (ones_like, expand, cos, mul, detach in the output).
print("TorchDispatchMode logging:")
with DispatchLog():
    f()


TorchDispatchMode logging:
Dispatch Log: aten.rand.default([10], **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(tensor(shape=(10,)), 2, **{})
Dispatch Log: aten.sin.default(tensor(shape=(10,)), **{})
Dispatch Log: aten.sum.default(tensor(shape=(10,)), **{})
Backward
Dispatch Log: aten.ones_like.default(tensor(shape=()), **{'pin_memory': False, 'memory_format': torch.preserve_format})
Dispatch Log: aten.expand.default(tensor(shape=()), [10], **{})
Dispatch Log: aten.cos.default(tensor(shape=(10,)), **{})
Dispatch Log: aten.mul.Tensor(tensor(shape=(10,)), tensor(shape=(10,)), **{})
Dispatch Log: aten.mul.Tensor(tensor(shape=(10,)), 2, **{})
Dispatch Log: aten.detach.default(tensor(shape=(10,)), **{})
Dispatch Log: aten.detach.default(tensor(shape=(10,)), **{})


In [8]:
def g(delay: float = 1.0):
    """Build a small matmul graph on the XLA device, sync it, and return the result.

    Creates two random ``(1000, 1000)`` XLA tensors ``a`` and ``b`` with
    ``requires_grad=True`` and computes ``d = (a @ b @ a @ b).sum()``. It then
    prints "Mark step", calls ``xm.mark_step()`` (the output shows this is where
    the pending tensors are synced), and prints ``d``.

    Args:
        delay: Seconds to pause after each step. The default of ``1.0`` matches
            the previously hard-coded sleeps. The pauses appear to exist only to
            keep each step's C++ INFO log lines next to the matching Python output
            in the cell (the log timestamps are ~1 s apart). Pass ``0`` to skip them.

    Returns:
        torch.Tensor: the 0-dim result ``d`` on the XLA device.
    """
    # Deferred imports: torch_xla is only needed here. NOTE(review): presumably
    # this also ensures the %env logging settings are in place before torch_xla loads.
    import time

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    def pause() -> None:
        # Plain Python sleep: not a torch call, so it adds nothing to the traces.
        time.sleep(delay)

    a = torch.rand((1000, 1000), requires_grad=True, device=torch_xla.device())
    pause()
    b = torch.rand((1000, 1000), requires_grad=True, device=torch_xla.device())
    pause()
    c = a @ b @ a @ b
    pause()
    d = c.sum()
    pause()
    print("Mark step")
    pause()
    xm.mark_step()
    pause()
    print(d)
    pause()
    return d
    

In [9]:
# g() under FunctionLog: the torch-API calls are logged from Python. The
# interleaved C++ INFO lines show the XLA IR nodes each call creates
# (e.g. aten::mm per matmul) and the graph sync triggered by mark_step.
print("TorchFunctionMode logging:")
with FunctionLog():
    g()


TorchFunctionMode logging:
Function Log: torch.device('xla:0', **None)
Function Log: torch.rand((1000, 1000), **{'requires_grad': True, 'device': device(type='xla', index=0)})


2024-08-05 01:58:22.469995: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::expand
2024-08-05 01:58:22.471368: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mul
2024-08-05 01:58:22.471457: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::add
2024-08-05 01:58:22.471607: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::uniform


Function Log: torch.device('xla:0', **None)
Function Log: torch.rand((1000, 1000), **{'requires_grad': True, 'device': device(type='xla', index=0)})


2024-08-05 01:58:23.473611: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::expand
2024-08-05 01:58:23.474479: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mul
2024-08-05 01:58:23.474558: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::add
2024-08-05 01:58:23.474705: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::uniform


Function Log: torch.Tensor.matmul(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **None)
Function Log: torch.Tensor.matmul(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **None)
Function Log: torch.Tensor.matmul(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **None)


2024-08-05 01:58:24.476417: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm
2024-08-05 01:58:24.476969: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm
2024-08-05 01:58:24.477298: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm


Function Log: torch.Tensor.sum(tensor(shape=(1000, 1000)), **None)


2024-08-05 01:58:25.479016: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::sum


Mark step


2024-08-05 01:58:27.481775: I torch_xla/csrc/xla_graph_executor.cpp:422] 6 live tensors: devices=()
2024-08-05 01:58:27.481826: I torch_xla/csrc/xla_graph_executor.cpp:400] Trying to sync the value of 6 tensor(s)
2024-08-05 01:58:27.481936: I torch_xla/csrc/xla_graph_executor.cpp:714] Tensors graph hash 1a89380de694b4919304f231485b7de1 on device TPU:0
2024-08-05 01:58:27.482877: I torch_xla/csrc/xla_graph_executor.cpp:1514] Parameter sequence graph hash 90ec8fd148146797d064fba641daf419
2024-08-05 01:58:27.484687: I torch_xla/csrc/xla_graph_executor.cpp:1206] Graph hash 90ec8fd148146797d064fba641daf419 is computation hash d664d514aa8275302560d0223f778dd6
2024-08-05 01:58:27.484712: I torch_xla/csrc/xla_graph_executor.cpp:1232] TensorsGraphSize=19
2024-08-05 01:58:27.484742: I torch_xla/csrc/xla_graph_executor.cpp:722] waiting barrier for device TPU:0 start
2024-08-05 01:58:27.484783: I torch_xla/csrc/xla_graph_executor.cpp:725] waiting barrier for device TPU:0 done
2024-08-05 01:58:27.4

Function Log: torch.device('xla:0', **None)
Function Log: torch.Tensor.__repr__(tensor(shape=()), **{'tensor_contents': None})
tensor(6.2506e+13, device='xla:0', grad_fn=<SumBackward0>)


In [10]:

# g() under DispatchLog: the same workload seen as aten ops
# (each `@` is logged as aten.mm.default, torch.rand as aten.rand.default).
print("TorchDispatchMode logging:")
with DispatchLog():
    g()


TorchDispatchMode logging:
Dispatch Log: aten.rand.default([1000, 1000], **{'device': device(type='xla', index=0), 'pin_memory': False})


2024-08-05 01:58:29.506924: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::expand
2024-08-05 01:58:29.507616: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mul
2024-08-05 01:58:29.507701: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::add
2024-08-05 01:58:29.507869: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::uniform


Dispatch Log: aten.rand.default([1000, 1000], **{'device': device(type='xla', index=0), 'pin_memory': False})


2024-08-05 01:58:30.510177: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::expand
2024-08-05 01:58:30.510617: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mul
2024-08-05 01:58:30.510706: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::add
2024-08-05 01:58:30.510848: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::uniform


Dispatch Log: aten.mm.default(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **{})
Dispatch Log: aten.mm.default(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **{})
Dispatch Log: aten.mm.default(tensor(shape=(1000, 1000)), tensor(shape=(1000, 1000)), **{})


2024-08-05 01:58:31.512893: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm
2024-08-05 01:58:31.513474: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm
2024-08-05 01:58:31.513821: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::mm


Dispatch Log: aten.sum.default(tensor(shape=(1000, 1000)), **{})


2024-08-05 01:58:32.516048: I torch_xla/csrc/ir.cpp:53] Create XlaNode for aten::sum


Mark step


2024-08-05 01:58:34.518880: I torch_xla/csrc/xla_graph_executor.cpp:422] 6 live tensors: devices=()
2024-08-05 01:58:34.518910: I torch_xla/csrc/xla_graph_executor.cpp:400] Trying to sync the value of 6 tensor(s)
2024-08-05 01:58:34.518982: I torch_xla/csrc/xla_graph_executor.cpp:714] Tensors graph hash 1a89380de694b4919304f231485b7de1 on device TPU:0
2024-08-05 01:58:34.519587: I torch_xla/csrc/xla_graph_executor.cpp:1514] Parameter sequence graph hash 90ec8fd148146797d064fba641daf419
2024-08-05 01:58:34.521405: I torch_xla/csrc/xla_graph_executor.cpp:1206] Graph hash 90ec8fd148146797d064fba641daf419 is computation hash d664d514aa8275302560d0223f778dd6
2024-08-05 01:58:34.521428: I torch_xla/csrc/xla_graph_executor.cpp:1232] TensorsGraphSize=19
2024-08-05 01:58:34.521454: I torch_xla/csrc/xla_graph_executor.cpp:722] waiting barrier for device TPU:0 start
2024-08-05 01:58:34.521488: I torch_xla/csrc/xla_graph_executor.cpp:725] waiting barrier for device TPU:0 done
2024-08-05 01:58:34.5

tensor(6.2446e+13, device='xla:0', grad_fn=<SumBackward0>)
