# **THOAD USER GUIDE**

### **What is thoad**

The thoad (Torch High Order Auto-Differentiation) package contains a pure-Python implementation of a reverse-mode auto-differentiation engine for PyTorch.
It is developed in Python 3.12 and uses PyTorch (2.4+) as its only dependency.

## **1. Computing Arbitrary Order Derivatives**

In [2]:
from typing import Callable, Optional, Tuple, Union

### **1.1 Setting up dependencies**

Import **PyTorch** and **thoad**.

In [3]:
import torch

device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
from thoad import backward, Controller

> Since thoad relies heavily on the `torch.einsum` operator, enabling the `opt_einsum` backend can improve performance.

In [5]:
import torch.backends.opt_einsum as opt_einsum

if opt_einsum.is_available():
    opt_einsum.enabled = True
    opt_einsum.strategy = "greedy"
    print("opt_einsum backend enabled")
else:
    print("opt_einsum backend is not available")

opt_einsum backend enabled


### **1.2 Calling thoad auto-differentiation via the backward function**

Define the dynamic computational graph as you normally would in PyTorch: create a series of traceable tensors (`requires_grad=True`) and feed them into a composition of operators.

In [6]:
X1: torch.Tensor = torch.rand(size=(2,3), requires_grad=True, device=device)
X2: torch.Tensor = torch.rand(size=(3,4), requires_grad=True, device=device)

def forward(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(input=(X1 @ X2))

Y: torch.Tensor = forward(X1=X1, X2=X2)

Call the `thoad.backward` function to compute the partial derivatives of order *o* with respect to the inputs $\left( \text{i.e. } \frac{\partial^{o} Y}{\partial X_{i_1} \cdots \partial X_{i_o}} \right)$. This attaches two new attributes to the input tensors: `hgrad` and `hdata`.


`thoad.backward` takes the following arguments.

* `tensor: torch.Tensor`: Output tensor (must have `requires_grad=True` and a `grad_fn`); root to differentiate from.
* `order: int`: Positive derivative order $o \ge 1$.
* `gradient: Optional[torch.Tensor] = None`: (Optional) Upstream seed with same shape as `tensor`; for $o=1$ performs a VJP; if `None`, uses the identity seed.
* `crossings: bool = False`: (Optional) If `True`, computes mixed partials across different terminal tensors (cross-terminal derivatives).
* `groups: Optional[Iterable[Iterable[torch.Tensor]]] = None`: (Optional) Terminal groups that allow mixed partials only **within** each group; mutually exclusive with `crossings=True`.
* `keep_batch: bool = False`: (Optional) Controls whether mutually independent dimensions (described by **Indeps**) are kept in their unified form. If `True`, unified batch-independent dimensions remain unified in the resulting `hgrad`/`hdata` shapes; otherwise they are expanded back to the full input shapes.
* `keep_schwarz: bool = False`: (Optional) Controls whether variable-order symmetries (described by **VPerm**) are resolved or preserved. If `True`, derivatives that are symmetric under Schwarz's theorem remain stored as permuted references rather than expanded tensors.

In [7]:
# execute auto-differentiation from Y tensor
o: int = 2
backward(tensor=Y, order=o)

<thoad.user.interface.Controller at 0x19058e23bc0>

Access the newly included attributes to get the derivatives w.r.t. input tensors.

- `Tensor.hgrad` contains a tuple of length *o* with the derivative tensors from order 1 to *o*.
- `Tensor.hdata` contains a tuple of length *o* with nested tuples of length 3 containing metadata (**shapes**, **indeps**, **variable permutations**) about the corresponding tensor.

hgrad tensors have the following shape:

  - (output_flattened_size, *(shape_input_tensor_0), ..., *(shape_input_tensor_o-1))

If `thoad.backward` is called with the default configuration, the `hdata` information has no practical value.
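As a quick sanity check of the shape layout above, you can compute the expected `hgrad` shape in plain Python from the output shape and the shapes of the differentiated inputs. This is a minimal sketch that assumes the default configuration (`keep_batch=False`), so every input keeps its full shape. `expected_hgrad_shape` is a hypothetical helper, not a thoad function. The printed shapes should match the assertions on `X1.hgrad` in the checks that follow.

```python
from math import prod
from typing import Sequence, Tuple

def expected_hgrad_shape(
    output_shape: Sequence[int],
    input_shapes: Sequence[Sequence[int]],
) -> Tuple[int, ...]:
    """Expected hgrad shape: flattened output size, then one input shape per differentiation."""
    shape = [prod(output_shape)]      # output_flattened_size
    for input_shape in input_shapes:  # *(shape_input_tensor_k) for k = 0..o-1
        shape.extend(input_shape)
    return tuple(shape)

# Y = sigmoid(X1 @ X2) with X1: (2, 3) and X2: (3, 4), so Y: (2, 4)
print(expected_hgrad_shape((2, 4), [(2, 3)]))          # (8, 2, 3)
print(expected_hgrad_shape((2, 4), [(2, 3), (2, 3)]))  # (8, 2, 3, 2, 3)
```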


In [8]:
# torch.autograd.functional.hessian only supports functions with a scalar output
fwd_sum: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
fwd_sum = lambda X1, X2: torch.sum(forward(X1=X1, X2=X2))

In [9]:
# obtain input gradients
grad_X1: torch.Tensor = X1.hgrad[0]
grad_X1X1: torch.Tensor = X1.hgrad[1]
grad_X2: torch.Tensor = X2.hgrad[0]
grad_X2X2: torch.Tensor = X2.hgrad[1]

In [10]:
# check X1 derivatives
assert grad_X1.shape == (Y.numel(), *X1.shape)
assert grad_X1X1.shape == (Y.numel(), *X1.shape, *X1.shape)
assert torch.allclose(
    grad_X1.flatten(),
    torch.autograd.functional.jacobian(func=forward, inputs=(X1, X2))[0].flatten(),
)
assert torch.allclose(
    grad_X1X1.sum(0).flatten(),
    torch.autograd.functional.hessian(func=fwd_sum, inputs=(X1, X2))[0][0].flatten(),
)

# check X2 derivatives
assert grad_X2.shape == (Y.numel(), *X2.shape)
assert grad_X2X2.shape == (Y.numel(), *X2.shape, *X2.shape)
assert torch.allclose(
    grad_X2.flatten(),
    torch.autograd.functional.jacobian(func=forward, inputs=(X1, X2))[1].flatten(),
)
assert torch.allclose(
    grad_X2X2.sum(0).flatten(),
    torch.autograd.functional.hessian(func=fwd_sum, inputs=(X1, X2))[1][1].flatten(),
)
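The `.sum(0)` reductions in the Hessian checks above work because differentiation is linear. `torch.autograd.functional.hessian` differentiates the scalar `fwd_sum`, and its second derivative equals the sum, over the flattened output, of the per-element second derivatives that `hgrad[1]` stores. Let $y_n$ be the $n$-th entry of the flattened output $Y$, and let $X_a, X_b$ be any two inputs:

$$
\frac{\partial^{2}}{\partial X_a \, \partial X_b} \sum_{n} y_n \;=\; \sum_{n} \frac{\partial^{2} y_n}{\partial X_a \, \partial X_b}
$$

The left-hand side is block $(a, b)$ of the `hessian` result. The right-hand side is the corresponding `hgrad` tensor summed over its leading (output) dimension.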

To obtain the cross derivatives:
1. First, use the *crossings* argument of `backward` to tell thoad to compute them.
2. Then, keep the `Controller` returned by `backward` and call its `Controller.fetch_hgrad` method with the desired variables.

> Note: setting *crossings=True* has a significant impact on performance.

> Note: thoad overwrites, rather than accumulates, the derivatives computed by repeated AD calls.

In [11]:
# execute auto-differentiation from Y tensor
o: int = 2
ctrl: Controller = backward(tensor=Y, order=o, crossings=True)

In [12]:
# obtain input gradients
grad_X1X2: torch.Tensor
grad_X1X2, _ = ctrl.fetch_hgrad(variables=(X1, X2))
grad_X2X1: torch.Tensor
grad_X2X1, _ = ctrl.fetch_hgrad(variables=(X2, X1))

In [13]:
assert torch.allclose(
    grad_X1X2.sum(0).flatten(),
    torch.autograd.functional.hessian(func=fwd_sum, inputs=(X1, X2))[0][1].flatten(),
)
assert torch.allclose(
    grad_X2X1.sum(0).flatten(),
    torch.autograd.functional.hessian(func=fwd_sum, inputs=(X1, X2))[1][0].flatten(),
)

As an alternative to *crossings*, **thoad** also provides the *groups* argument. It lets the user explicitly specify the subsets of tensors among which cross derivatives must be computed. Cross derivatives among tensors that do not share a common group are not computed.

In [14]:
X1: torch.Tensor = torch.rand(size=(2,3), requires_grad=True, device=device)
X2: torch.Tensor = torch.rand(size=(3,4), requires_grad=True, device=device)
X3: torch.Tensor = torch.rand(size=(2,4), requires_grad=True, device=device)

def forward(X1: torch.Tensor, X2: torch.Tensor, X3: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(input=(X1 @ X2 + X3))

Y: torch.Tensor = forward(X1=X1, X2=X2, X3=X3)

In [15]:
# execute auto-differentiation from Y tensor
o: int = 2
ctrl: Controller = backward(tensor=Y, order=o, groups=({X1, X2}, {X1, X3}))

In [16]:
ctrl.fetch_hgrad(variables=(X1, X2))
ctrl.fetch_hgrad(variables=(X2, X1))
ctrl.fetch_hgrad(variables=(X1, X3))
ctrl.fetch_hgrad(variables=(X3, X1))

for pair in [(X2, X3), (X3, X2)]:
    try:
        ctrl.fetch_hgrad(variables=pair)
    except Exception as e:
        print(e)

'No gradient saved for given key'
'No gradient saved for given key'
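The selection rule behind *groups* can be modeled in plain Python. An ordered combination of variables is eligible only if all of its distinct variables appear together in at least one group. This is an illustrative sketch of the behaviour shown above, not thoad's internal code. `is_computed` is a hypothetical helper, and the sketch assumes that non-crossed derivatives such as `(X1, X1)` are always computed.

```python
from typing import Hashable, Iterable, Sequence, Set

def is_computed(
    variables: Sequence[Hashable],
    groups: Iterable[Set[Hashable]],
) -> bool:
    """Return True if this ordered variable combination would get a derivative."""
    distinct = set(variables)
    # Non-crossed derivatives (one repeated variable) are assumed to need no group.
    if len(distinct) == 1:
        return True
    # Crossed derivatives need every involved variable inside one common group.
    return any(distinct <= set(group) for group in groups)

groups = ({"X1", "X2"}, {"X1", "X3"})
for combo in [("X1", "X2"), ("X3", "X1"), ("X2", "X3"), ("X3", "X2")]:
    print(combo, is_computed(combo, groups))
```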


### **1.3 Calling thoad auto-differentiation via the Controller interface**

Define the **PyTorch** graph the same way as before.

In [17]:
X1: torch.Tensor = torch.rand(size=(2,3), requires_grad=True, device=device)
X2: torch.Tensor = torch.rand(size=(3,4), requires_grad=True, device=device)

def forward(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(input=(X1 @ X2))

Y: torch.Tensor = forward(X1=X1, X2=X2)

Instantiate a `Controller`, passing the output tensor as its `tensor` argument. Then call its `backward` method with the same arguments as `thoad.backward`, except `tensor`.

In [18]:
o: int = 2
ctrl: Controller = Controller(tensor=Y)
ctrl.backward(order=o)

Finally, obtain the derivatives with the `Controller.fetch_hgrad` method, as before.

### **1.4 Using derivative metadata**

#### **About Tensor Metadata (shapes, indeps, variable permutations)**

**Shapes = Tuple[Shape, ...] = Tuple[Tuple[int, ...], ...]**  
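Indicates the shape of each variable with respect to which the derivative tensor is differentiated. With the default configuration, `hgrad` is laid out as the flattened output dimension followed by these shapes:
```
ctrl = backward(tensor=Y, order=3, crossings=True)
hgrad, hdata = ctrl.fetch_hgrad(variables=(X1, X2, X1))
shapes = hdata[0]  # one shape per differentiated variable
assert hgrad.shape == (hgrad.shape[0], *(size for shape in shapes for size in shape))
```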

**Indeps = Tuple[Indep, ...] = Tuple[Tuple[Union[None, int], ...], ...]**  
Mutually independent dimensions are those whose crossed terms are all null, so only their n-dimensional diagonal contains meaningful elements. thoad's back-propagation leverages this by unifying each group of mutually independent dimensions into a single one.  
If `thoad.backward` is called with `keep_batch=True`, the unified mutually independent dimensions are kept unified, and `hgrad` has the following shape:
- (output_flattened_size, *(independent_dimensions), *(non_independent_shape_input_tensor_0), ..., *(non_independent_shape_input_tensor_o-1))  

Each Indep has the same length as **independent_dimensions**. Its *k*-th entry indicates which dimension of the corresponding differentiated variable is unified into the *k*-th independent position. If no dimension of that variable is unified into that position, the entry is `None`.

*Examples:*
- *shapes=( (2,4),(3,5),(2,4) ) + indeps=( (None,None),(None,None),(None,None) ) -> hgrad.shape=(Y.numel(), 1,1, 2,4, 3,5, 2,4)*
- *shapes=( (2,4),(3,5),(2,4) ) + indeps=( (0,None),(None,1),(0,None) ) -> hgrad.shape=(Y.numel(), 2,5, 4, 3, 4)*
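A small plain-Python helper that applies the Indeps convention reproduces the examples above. Each independent position takes the size of the dimension unified into it, or 1 if no dimension is. Each variable then contributes its remaining, non-unified dimensions. This sketch of the documented layout assumes that all dimensions unified into one position have the same size. `hgrad_shape_with_indeps` is a hypothetical name, not a thoad function.

```python
from typing import List, Optional, Sequence, Tuple

def hgrad_shape_with_indeps(
    output_numel: int,
    shapes: Sequence[Tuple[int, ...]],
    indeps: Sequence[Tuple[Optional[int], ...]],
) -> Tuple[int, ...]:
    """(output, *independent_dimensions, *non-independent dims of each variable)."""
    # One independent position per Indep entry; size 1 when nothing is unified into it.
    independent: List[int] = [1] * len(indeps[0])
    for shape, indep in zip(shapes, indeps):
        for position, dim in enumerate(indep):
            if dim is not None:
                independent[position] = shape[dim]
    # Each variable then contributes its non-unified dimensions, in their original order.
    remaining: List[int] = []
    for shape, indep in zip(shapes, indeps):
        unified = {dim for dim in indep if dim is not None}
        remaining.extend(size for dim, size in enumerate(shape) if dim not in unified)
    return (output_numel, *independent, *remaining)

shapes = ((2, 4), (3, 5), (2, 4))
# Y.numel() is taken as 8 purely for illustration.
print(hgrad_shape_with_indeps(8, shapes, ((None, None), (None, None), (None, None))))
# (8, 1, 1, 2, 4, 3, 5, 2, 4)
print(hgrad_shape_with_indeps(8, shapes, ((0, None), (None, 1), (0, None))))
# (8, 2, 5, 4, 3, 4)
```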

**VPerm = Tuple[int, ...]**  
In subgraphs of operators satisfying Schwarz's theorem, derivatives with respect to permutations of the same set of variables are symmetric. thoad's back-propagation leverages these symmetries by skipping the computation of a derivative when a symmetric one has already been computed. Instead, it expresses that derivative as the already computed one wrapped by a permutation.  
If `thoad.backward` is called with `keep_schwarz=True`, derivatives expressed as permutations of others are kept in that form, and the corresponding permutation is stored in VPerm.
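The saving that Schwarz symmetry enables can be illustrated in plain Python. Each ordered combination of variables maps to a canonical (sorted) combination plus the permutation that recovers the original order, so only canonical combinations need to be computed. With *n* variables and order *o*, that leaves $\binom{n+o-1}{o}$ combinations instead of $n^o$. This sketch only models the idea. Its permutation convention is an assumption and may differ from the one thoad stores in VPerm: position *i* of the requested combination reads slot `perm[i]` of the canonical one. `canonicalize` is a hypothetical helper.

```python
from itertools import product
from math import comb
from typing import Sequence, Tuple

def canonicalize(variables: Sequence[str]) -> Tuple[Tuple[str, ...], Tuple[int, ...]]:
    """Map an ordered variable combination to (canonical combination, permutation)."""
    # Stable sort of the positions by variable name: equal names keep their order.
    sort_idx = sorted(range(len(variables)), key=lambda i: variables[i])
    canonical = tuple(variables[i] for i in sort_idx)
    # perm[i] is the slot of the canonical combination that holds variables[i].
    perm = [0] * len(variables)
    for slot, i in enumerate(sort_idx):
        perm[i] = slot
    return canonical, tuple(perm)

names = ("X1", "X2")
order = 2
combos = list(product(names, repeat=order))     # all n**o ordered combinations
unique = {canonicalize(c)[0] for c in combos}   # combinations actually computed
print(len(combos), len(unique))                 # 4 3
print(len(unique) == comb(len(names) + order - 1, order))  # True
print(canonicalize(("X2", "X1")))               # (('X1', 'X2'), (1, 0))
```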

In [19]:
X1: torch.Tensor = torch.rand(size=(2,3), requires_grad=True, device=device)
X2: torch.Tensor = torch.rand(size=(3,4), requires_grad=True, device=device)

def forward(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(input=(X1 @ X2))

Y: torch.Tensor = forward(X1=X1, X2=X2)

In [20]:
o: int = 2
ctrl: Controller = backward(
    tensor=Y,
    order=o,
    crossings=True,
    keep_batch=True,
    keep_schwarz=True,
)

In [21]:
from thoad.typing import Shape, Indep, VPerm

grad_X1X2: torch.Tensor
metadata_X1X2: Tuple[Tuple[Shape, ...], Tuple[Indep, ...], VPerm]
grad_X1X2, metadata_X1X2 = ctrl.fetch_hgrad(variables=(X1, X2))

## **2. Other Utils**

### **2.1 Checking operator compatibility**

As of 2025, **PyTorch** has more than 2000 operators. **thoad** does not support all of them, but it provides two tools to check whether a given **PyTorch** graph is composed entirely of supported operators:

- the `Controller.compatible` attribute
- the `Controller.display_graph` method

The `Controller.compatible` attribute is a boolean indicating whether all the operators in the subgraph hanging from the tensor passed to the `Controller` constructor are supported.

In [22]:
X: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)

In [23]:
Y: torch.Tensor = torch.sigmoid(input=X)
ctrl: Controller = Controller(tensor=Y)
print(ctrl.compatible)

True


In [24]:
Y: torch.Tensor = torch.special.ndtr(input=X)
ctrl: Controller = Controller(tensor=Y)
print(ctrl.compatible)

False


The `Controller.display_graph` method displays a diagram of the full subgraph hanging from the tensor passed to the `Controller` constructor. It adds a *not supported* flag next to each unsupported **PyTorch** `grad_fn` function.

In [25]:
X1: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)
X2: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)

In [26]:
Y: torch.Tensor = torch.special.erf(input=(X1 @ X2))
ctrl: Controller = Controller(tensor=Y)
ctrl.display_graph()

┬·<ErfBackward0 object at 0x000001906F116200> (not supported)
└─·<DotBackward0 object at 0x000001906F116020>
  ├─·<AccumulateGrad object at 0x000001906F116320>
  └─·<AccumulateGrad object at 0x000001906F115FF0>
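To see what such a check involves, the sketch below walks an autograd graph through the `next_functions` attribute that PyTorch `grad_fn` nodes expose, and flags node types missing from a given set of supported names. It is not thoad's implementation. The `supported_names` set stands in for thoad's real index, and `find_unsupported` is a hypothetical helper. So the sketch runs without PyTorch, the usage exercises it on small stand-in nodes that mirror the graph above; passing a real `Y.grad_fn` should work the same way.

```python
from typing import Any, List, Set

def find_unsupported(root: Any, supported: Set[str]) -> List[str]:
    """Walk a grad_fn graph depth-first; return the class names of unsupported nodes."""
    unsupported: List[str] = []
    seen: Set[int] = set()
    stack: List[Any] = [root]
    while stack:
        node = stack.pop()
        # Inputs that need no gradient appear as None in next_functions.
        if node is None or id(node) in seen:
            continue
        seen.add(id(node))
        name = type(node).__name__
        if name not in supported:
            unsupported.append(name)
        # Each next_functions entry is a (node, input_index) pair.
        stack.extend(child for child, _ in getattr(node, "next_functions", ()))
    return unsupported

# Stand-in nodes mimicking the graph displayed above (no PyTorch needed).
class AccumulateGrad:
    next_functions = ()

class DotBackward0:
    def __init__(self) -> None:
        self.next_functions = ((AccumulateGrad(), 0), (AccumulateGrad(), 0))

class ErfBackward0:
    def __init__(self) -> None:
        self.next_functions = ((DotBackward0(), 0),)

supported_names = {"DotBackward0", "AccumulateGrad"}  # stand-in for thoad's index
print(find_unsupported(ErfBackward0(), supported_names))  # ['ErfBackward0']
```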


### **2.2 Saving intermediate gradients** (experimental feature)

By default, **thoad**'s auto-differentiation only saves derivatives with respect to the leaf tensors of the graph. To also retain the derivative with respect to some combination of intermediate tensors, use the `Controller.require_grad_` method.

In [27]:
X: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)
Y: torch.Tensor = torch.sigmoid(input=X)
Z: torch.Tensor = torch.softmax(input=Y, dim=0)

In [28]:
ctrl: Controller = Controller(tensor=Z)
ctrl.require_grad_(variables=(Y,))

In [29]:
# execute auto-differentiation
o: int = 2
ctrl.backward(order=o)

# check that Y derivatives are indeed saved
ctrl.fetch_hgrad(variables=(Y,))


(tensor([[1.4211e-14]], grad_fn=<SumBackward1>), (((1,),), ((None,),), (0,)))

### **2.3 Removing dynamic tensor attributes**

As seen before, running the auto-differentiation process attaches two derivative-related attributes (`hgrad` and `hdata`) to the tensors. Some users may find this dynamic attachment of attributes undesirable, so the `Controller.clear` method is provided to remove them.

In [30]:
ctrl.clear()

assert "hgrad" not in dir(X)
assert "hdata" not in dir(X)
assert "hgrad" not in dir(Y)
assert "hdata" not in dir(Y)

### **2.4 Accessing high order backward functions**

During auto-differentiation, **thoad** handles each supported operator with its own `ExtendedAutogradFunction` implementation. Each one is equivalent to the corresponding `torch.autograd.Function`, but can compute internal derivatives of arbitrary order. The one-to-one mapping between them is held in an index exposed through the readable and writable `Controller.index` attribute. This attribute can be used to override existing `ExtendedAutogradFunction`s or to register new ones.

In [31]:
X: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)
Y: torch.Tensor = torch.sigmoid(input=X)
ctrl: Controller = Controller(tensor=Y)

In [32]:
from thoad.typing import AutogradFunction
from thoad.differentiation import ExtendedAutogradFunction

xbackwards: dict[type[AutogradFunction], type[ExtendedAutogradFunction]] = ctrl.index
grad_fn: Optional[AutogradFunction] = Y.grad_fn
assert grad_fn is not None
print(type(Y.grad_fn), " -> ", xbackwards[type(grad_fn)])
ctrl.index = xbackwards

<class 'SigmoidBackward0'>  ->  <class 'thoad.differentiation.internals.mathematic.sigmoid.SigmoidXBackward0'>


### **2.5 Modifying config settings**

**thoad** runs the auto-differentiation process according to three settings, each with a default value. These settings can be modified dynamically at any point in the program. The settings are:

- `config.DEBUG`: Controls the activation of a series of internal checks to catch bugs early (defaults to *False*).
- `config.BATCH_OPTIMIZATION`: Controls the unification of batch dimensions, i.e. subsets of dimensions whose pairwise off-diagonal elements are all null (defaults to *True*).
- `config.SCHWARZ_OPTIMIZATION`: Controls the use of variable symmetries to avoid repeated derivative computations (defaults to *True*).
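These settings are plain module attributes, so one convenient way to change them only within a block of code is a small context manager that restores the previous values on exit. This is a generic sketch, not a thoad utility, and `temporary_settings` is a hypothetical helper. The usage runs it on a stand-in namespace so it works without thoad; with thoad, pass `config` (i.e. `thoad.config`) instead. The sketch assumes thoad reads the settings when a backward pass runs, as the dynamic modification described above suggests. It does not verify this.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any, Iterator

@contextmanager
def temporary_settings(module: Any, **overrides: Any) -> Iterator[None]:
    """Set attributes on `module` for the duration of a with-block, then restore them."""
    # getattr fails fast (AttributeError) on a misspelled setting name.
    previous = {name: getattr(module, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(module, name, value)
        yield
    finally:
        for name, value in previous.items():
            setattr(module, name, value)

# Stand-in for thoad.config holding the documented defaults.
config_stub = SimpleNamespace(DEBUG=False, BATCH_OPTIMIZATION=True, SCHWARZ_OPTIMIZATION=True)
with temporary_settings(config_stub, DEBUG=True):
    print(config_stub.DEBUG)  # True
print(config_stub.DEBUG)      # False
```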

In [33]:
import thoad.config as config

config.DEBUG = False
config.BATCH_OPTIMIZATION = True
config.SCHWARZ_OPTIMIZATION = True

## **3. Registering Hooks**

To offer some of the flexibility of **PyTorch**'s autograd, **thoad** also allows registering hooks on the auto-differentiation process. However, because of internal design differences, **thoad**'s hooks are not attached to backward functions. They are attached to combinations of internal variables (i.e. graph nodes, which are tensors). To register a backward hook, use the `Controller.register_backward_hook` method.

Hook functions must accept two arguments and return the (possibly modified) `grad_data`:
1. `grad_data`: *Tuple[torch.Tensor, Tuple[Shape, ...], Tuple[Indep, ...], VPerm]*  
   All external derivative data. The returned `grad_data` must keep the shapes unchanged.
2. `context`: *dict[AutogradFunction, set[torch.Tensor]]*  
   A dictionary mapping each of the hook's backward functions (i.e. functions pointing to any of the hook's key tensors) to the tensors they point to. An *AutogradFunction* generally holds the operator context saved during the forward pass. The *set[torch.Tensor]* helps the user identify each backward function when an operator appears more than once.

In [34]:
X: torch.Tensor = torch.rand(size=(1,), requires_grad=True, device=device)
Y: torch.Tensor = torch.sigmoid(input=X)
Z: torch.Tensor = torch.softmax(input=Y, dim=0)

In [35]:
def hook(
    grad_data: Tuple[
        Optional[torch.Tensor],
        Optional[Tuple[Shape, ...]],
        Optional[Tuple[Indep, ...]],
        Optional[VPerm]
    ],
    context: dict[AutogradFunction, set[torch.Tensor]],
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[Tuple[Shape, ...]],
        Optional[Tuple[Indep, ...]],
        Optional[VPerm]
    ]:
    return grad_data

In [36]:
o: int = 2
ctrl: Controller = Controller(tensor=Z)
ctrl.register_backward_hook(variables=(Y, Y), hook=hook)
ctrl.backward(order=o)
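A hook does not have to be the identity. The factory below builds a hook that records the shapes it receives and returns `grad_data` unchanged, which respects the requirement that returned shapes stay as they are. `make_shape_logging_hook` is a hypothetical name, and the hook follows the signature documented above. So the sketch runs without thoad, the usage feeds it stand-in values. With thoad, you would pass the result to `ctrl.register_backward_hook(variables=(Y, Y), hook=...)`.

```python
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

GradData = Tuple[Any, Optional[Tuple[Tuple[int, ...], ...]], Any, Any]
Hook = Callable[[GradData, Dict[Any, Set[Any]]], GradData]

def make_shape_logging_hook(log: List[Tuple[Tuple[int, ...], ...]]) -> Hook:
    """Build a hook that records the differentiation shapes it receives."""
    def shape_hook(grad_data: GradData, context: Dict[Any, Set[Any]]) -> GradData:
        _derivative, shapes, _indeps, _vperm = grad_data
        if shapes is not None:
            log.append(shapes)
        # Returning grad_data untouched leaves every shape as it was.
        return grad_data
    return shape_hook

# Stand-in grad_data: (derivative, shapes, indeps, vperm); no real tensor needed here.
shape_log: List[Tuple[Tuple[int, ...], ...]] = []
shape_hook = make_shape_logging_hook(shape_log)
fake_grad_data = (None, ((1,), (1,)), ((None,), (None,)), (0, 1))
assert shape_hook(fake_grad_data, {}) is fake_grad_data
print(shape_log)  # [((1,), (1,))]
```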

## **4. Registering New High Order Backward Functions**

### **4.1 Registering new high‑order backward functions**

thoad lets you extend the engine with custom higher‑order rules for PyTorch ops by registering an *extended* class mapped from a concrete `grad_fn` type.

**How to access the index.**
- The *Function Transcoder* maintains a dictionary `index: dict[type[AutogradFunction], type[ExtendedAutogradFunction]]`.
- You can read/override it from a `Controller` via the property `controller.index`.
- If your op is missing, add a mapping from the concrete `grad_fn` type to your extended class.
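Because `controller.index` is readable and writable, the read, modify and write-back steps listed above can be wrapped in a small helper. This sketch is not part of thoad, and `register_xbackward` is a hypothetical name. It copies the dictionary before modifying it, in case the returned index is shared; that behaviour of the property is an assumption. The usage relies on stand-in classes so it runs without thoad.

```python
from typing import Any, Dict

def register_xbackward(controller: Any, grad_fn_type: type, xbackward_cls: type) -> None:
    """Map `grad_fn_type` to `xbackward_cls` in the controller's function index."""
    index: Dict[type, type] = dict(controller.index)  # read (and copy) the current map
    index[grad_fn_type] = xbackward_cls               # add or override one entry
    controller.index = index                          # write it back

# Stand-ins so the sketch runs without thoad.
class FakeController:
    def __init__(self) -> None:
        self.index: Dict[type, type] = {}

class FakeBackward0: ...
class FakeXBackward: ...

ctrl_stub = FakeController()
register_xbackward(ctrl_stub, FakeBackward0, FakeXBackward)
print(ctrl_stub.index[FakeBackward0] is FakeXBackward)  # True
```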

**Two families you can inherit from.**
1. `ContractiveFunction` — you *produce internal derivatives* of the operator with respect to an **indexed** output (`out_id`) and a tuple of **indexed** inputs (`inp_ids`). This is the most general path.
2. `DirectFunction` — you *directly transform* a provided derivative into the external derivative layout. Only valid for operators whose internal derivatives are **first‑order only** (no internal higher‑order accumulation).

In short: use **Contractive** when you must *compute* internal derivatives per `(out_id, inp_ids)`; use **Direct** when you only need to *relabel/route* already available 1st‑order internals into the external differentiation.

In [None]:
from typing import Type
from thoad.typing import AutogradFunction
from thoad.differentiation import ExtendedAutogradFunction

T: torch.Tensor = torch.rand(3, requires_grad=True)
G: torch.Tensor = (T + 1).sum()
ctrl = Controller(G)

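# Read the current map
func_index: dict[Type[AutogradFunction], Type[ExtendedAutogradFunction]] = ctrl.index
print(f"Registered ops: {len(func_index)}")

# Add/override a mapping, e.g. func_index[type(some_grad_fn)] = MyOpXBackward
ctrl.index = func_index
print("Custom mapping installed.")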

Registered ops: 72
Custom mapping installed.


### **4.2 ContractiveFunction**

`ContractiveFunction` computes **internal derivatives** indexed by `(out_id, inp_ids)` and returns them with an **Einstein‑style notation** the engine knows how to contract.

**You must provide:**
- `check_shape(out_id, inp_id, shape, indep, crossed) -> (Shape, Indep)`
  - Validate/normalize the shape and independence mask of the external derivative slot; set `self._shape` as needed.
- `_extract_context()`
  - Pull raw data from the wrapped `grad_fn` (e.g., saved tensors/scalars).
- `_process_context()`
  - Turn raw context into ready‑to‑use buffers/constants.
- `compute_internal(out_id: int, inp_id: tuple[int, ...]) -> IDData`
  - Produce the internal derivative **tensor** and its **notation**.
  - Convention: optionally factor logic by defining private helpers per case, e.g. `_compute_internal_0_1_2()`.

**Returned data:**
- `IDData = tuple[Tensor, Notation]`
- `Notation` encodes (a) external indices; (b) internal indices per variable; and (c) meta row with `([shape_dims], [batch/schwarz flags])`.

**Notes.**
- `ContractiveFunction` supports multi‑output ops (`out_id`) and multi‑input selections (`inp_ids`); `inp_ids` has exactly as many entries as the *internal derivative order* you are producing.
- Set class attribute `schwarz: bool` if you want symmetric accumulation across internal slots (when applicable).

In [None]:
from thoad.typing import IDData, Notation
from thoad.differentiation import ContractiveFunction

class MyOpXBackward(ContractiveFunction):
    
    schwarz: bool = True

    def check_shape(
        self,
        out_id: int,
        inp_id: int,
        shape: Shape,
        indep: Indep,
        crossed: bool,
        ) -> Tuple[Shape, Indep]:
        self._shape = shape
        return (shape, indep)

    def _extract_context(self) -> None:
        self._context = {}
        self._process_context()

    def _process_context(self) -> None:
        assert self._context is not None
        self._processed_context = {}

    def _compute_internal_0_0(self) -> IDData:
        assert self._shape is not None
        derivative: torch.Tensor = torch.ones(
            self._shape,
            dtype=self._dtype,
            device=self._device,
        )
        notation: Notation = [
            (tuple(range(len(self._shape))), tuple(range(len(self._shape)))),
            (tuple(range(len(self._shape))),),
            (tuple(self._shape), tuple(False for _ in self._shape)),
        ]
        return (derivative, notation)

    def compute_internal(self, out_id: int, inp_id: Tuple[int, ...]) -> IDData:
        if (out_id, tuple(inp_id)) == (0, (0,)):
            return self._compute_internal_0_0()
        raise NotImplementedError((out_id, tuple(inp_id)))

### **4.3 DirectFunction**

`DirectFunction` **does not** compute new internal tensors. Instead, it **routes** existing (first‑order) internals into the external derivative layout, hence it only works when the underlying op has **1st‑order‑only** internal derivatives.

**You must provide:**
- Maintain `self._indeps: list[Indep]`, with one entry per *unique* differentiable input of the op; fill it in `check_shape`.
- `check_shape(out_id, inp_id, shape, indep, crossed)`
  - Validate/normalize the external slot and cache `self._shape` and the appropriate entry in `self._indeps`.
- `_extract_context()` and `_process_context()`
  - As in the contractive case; prepare anything you need for reshaping/validation.
- `transform(derivative, shapes, indeps, out_id, inp_id) -> EDData`
  - Receive a derivative **already computed elsewhere** plus shape/independence metadata for each external slot.
  - Validate alignments (see `_check_transform`) and return the re‑packed `(derivative, shapes, indeps)` (and optionally a variable permutation if your engine expects it).

**Constraints and alignment.**
- `out_id[i]` is `None` iff `inp_id[i]` is `None`.
- When `inp_id[i] = j`, the `indeps[i]` must equal `self._indeps[j]` (same independence pattern).
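The two alignment constraints above can be written as a small plain-Python check. It is similar in spirit to the validation the section attributes to `_check_transform`. This illustrative sketch is not thoad's implementation, and `alignment_errors` is a hypothetical name. It returns a list of violations rather than raising, so an empty list means the inputs are aligned.

```python
from typing import List, Optional, Sequence, Tuple

IndepT = Tuple[Optional[int], ...]

def alignment_errors(
    out_id: Sequence[Optional[int]],
    inp_id: Sequence[Optional[int]],
    indeps: Sequence[IndepT],
    op_indeps: Sequence[IndepT],
) -> List[str]:
    """List violations of the two alignment constraints (empty list means aligned)."""
    errors: List[str] = []
    for i, (o, j) in enumerate(zip(out_id, inp_id)):
        # out_id[i] is None iff inp_id[i] is None.
        if (o is None) != (j is None):
            errors.append(f"slot {i}: out_id and inp_id must be both None or both set")
        # When inp_id[i] = j, indeps[i] must equal self._indeps[j].
        elif j is not None and tuple(indeps[i]) != tuple(op_indeps[j]):
            errors.append(f"slot {i}: indeps {indeps[i]} != op indeps {op_indeps[j]}")
    return errors

op_indeps = [(None,)]  # stand-in for self._indeps of a one-input op
print(alignment_errors((0, None), (0, None), ((None,), (None,)), op_indeps))  # []
print(len(alignment_errors((0, 0), (0, None), ((None,), (None,)), op_indeps)))  # 1
```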

In [None]:
from thoad.typing import StaticEDData
from thoad.differentiation import DirectFunction

class MyPointwiseXBackward(DirectFunction):
    schwarz: bool = True

    def __init__(
            self,
            grad_fn: AutogradFunction,
            order: int,
            dtype: torch.dtype,
            device: torch.device,
        ) -> None:
        super().__init__(
            grad_fn=grad_fn,
            order=order,
            dtype=dtype,
            device=device
        )
        self._indeps = [None]

    def check_shape(
            self,
            out_id: int,
            inp_id: int,
            shape: Shape,
            indep: Indep,
            crossed: bool,
        ) -> Tuple[Shape, Indep]:
        self._shape = shape
        self._indeps[0] = indep
        return (shape, indep)

    def _extract_context(self) -> None:
        self._context = {}
        self._process_context()

    def _process_context(self) -> None:
        assert self._context is not None
        self._processed_context = {}

    def _transform_0_0(
            self,
            derivative: torch.Tensor,
            shapes: Tuple[Shape, ...],
            indeps: Tuple[Indep, ...],
            variables: Tuple[int, ...],
        ) -> StaticEDData:
        # TODO: Transform derivative differentiations
        return (derivative, shapes, indeps)

    def transform(
            self,
            derivative: torch.Tensor,
            shapes: Tuple[Shape, ...],
            indeps: Tuple[Indep, ...],
            out_id: Tuple[Union[None, int], ...],
            inp_id: Tuple[Union[None, int], ...],
        ) -> StaticEDData:
        if bool(getattr(config, "DEBUG", False)):
            self._check_transform(
                derivative=derivative,
                shapes=shapes,
                indeps=indeps,
                out_id=out_id,
                inp_id=inp_id,
            )
        assert all(oo in (None, 0) for oo in out_id)
        assert all(ii in (None, 0) for ii in inp_id)
        variables = tuple(i for i, ii in enumerate(inp_id) if ii == 0)
        return self._transform_0_0(
            derivative=derivative,
            shapes=shapes,
            indeps=indeps,
            variables=variables,
        )