-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial support for PyTorch backend (#764)
- Loading branch information
1 parent
efa845a
commit 320bac4
Showing
11 changed files
with
471 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# isort: off | ||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify | ||
|
||
# # Load dispatch specializations | ||
import pytensor.link.pytorch.dispatch.scalar | ||
import pytensor.link.pytorch.dispatch.elemwise | ||
# isort: on |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from functools import singledispatch | ||
|
||
import torch | ||
|
||
from pytensor.compile.ops import DeepCopyOp | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.link.utils import fgraph_to_python | ||
from pytensor.raise_op import CheckAndRaise | ||
|
||
|
||
@singledispatch | ||
def pytorch_typify(data, dtype=None, **kwargs): | ||
r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" | ||
return torch.as_tensor(data, dtype=dtype) | ||
|
||
|
||
@singledispatch | ||
def pytorch_funcify(op, node=None, storage_map=None, **kwargs): | ||
"""Create a PyTorch compatible function from an PyTensor `Op`.""" | ||
raise NotImplementedError( | ||
f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation" | ||
) | ||
|
||
|
||
@pytorch_funcify.register(FunctionGraph) | ||
def pytorch_funcify_FunctionGraph( | ||
fgraph, | ||
node=None, | ||
fgraph_name="pytorch_funcified_fgraph", | ||
**kwargs, | ||
): | ||
return fgraph_to_python( | ||
fgraph, | ||
pytorch_funcify, | ||
type_conversion_fn=pytorch_typify, | ||
fgraph_name=fgraph_name, | ||
**kwargs, | ||
) | ||
|
||
|
||
@pytorch_funcify.register(CheckAndRaise) | ||
def pytorch_funcify_CheckAndRaise(op, **kwargs): | ||
error = op.exc_type | ||
msg = op.msg | ||
|
||
def assert_fn(x, *conditions): | ||
for cond in conditions: | ||
if not cond.item(): | ||
raise error(msg) | ||
return x | ||
|
||
return assert_fn | ||
|
||
|
||
@pytorch_funcify.register(DeepCopyOp) | ||
def pytorch_funcify_DeepCopyOp(op, **kwargs): | ||
def deepcopyop(x): | ||
return x.clone() | ||
|
||
return deepcopyop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
|
||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify | ||
from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||
|
||
|
||
@pytorch_funcify.register(Elemwise) | ||
def pytorch_funcify_Elemwise(op, node, **kwargs): | ||
scalar_op = op.scalar_op | ||
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) | ||
|
||
def elemwise_fn(*inputs): | ||
Elemwise._check_runtime_broadcast(node, inputs) | ||
return base_fn(*inputs) | ||
|
||
return elemwise_fn | ||
|
||
|
||
@pytorch_funcify.register(DimShuffle) | ||
def pytorch_funcify_DimShuffle(op, **kwargs): | ||
def dimshuffle(x): | ||
res = torch.permute(x, op.transposition) | ||
|
||
shape = list(res.shape[: len(op.shuffle)]) | ||
|
||
for augm in op.augment: | ||
shape.insert(augm, 1) | ||
|
||
res = torch.reshape(res, shape) | ||
|
||
if not op.inplace: | ||
res = res.clone() | ||
|
||
return res | ||
|
||
return dimshuffle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
|
||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify | ||
from pytensor.scalar.basic import ( | ||
ScalarOp, | ||
) | ||
|
||
|
||
@pytorch_funcify.register(ScalarOp) | ||
def pytorch_funcify_ScalarOp(op, node, **kwargs): | ||
"""Return pytorch function that implements the same computation as the Scalar Op. | ||
This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does, | ||
even though it's dispatched on the Scalar Op. | ||
""" | ||
|
||
nfunc_spec = getattr(op, "nfunc_spec", None) | ||
if nfunc_spec is None: | ||
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}") | ||
|
||
func_name = nfunc_spec[0] | ||
|
||
pytorch_func = getattr(torch, func_name) | ||
|
||
if len(node.inputs) > op.nfunc_spec[1]: | ||
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function, | ||
# even though the base Op from `func_name` is specified as a binary Op. | ||
# This happens with `Add`, which can work as a `Sum` for multiple scalars. | ||
pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None) | ||
if not pytorch_variadic_func: | ||
raise NotImplementedError( | ||
f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs" | ||
) | ||
|
||
def pytorch_func(*args): | ||
return pytorch_variadic_func( | ||
torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0 | ||
) | ||
|
||
return pytorch_func |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Any | ||
|
||
from pytensor.graph.basic import Variable | ||
from pytensor.link.basic import JITLinker | ||
|
||
|
||
class PytorchLinker(JITLinker): | ||
"""A `Linker` that compiles NumPy-based operations using torch.compile.""" | ||
|
||
def input_filter(self, inp: Any) -> Any: | ||
from pytensor.link.pytorch.dispatch import pytorch_typify | ||
|
||
return pytorch_typify(inp) | ||
|
||
def output_filter(self, var: Variable, out: Any) -> Any: | ||
return out.cpu() | ||
|
||
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): | ||
from pytensor.link.pytorch.dispatch import pytorch_funcify | ||
|
||
return pytorch_funcify( | ||
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs | ||
) | ||
|
||
def jit_compile(self, fn): | ||
import torch | ||
|
||
return torch.compile(fn) | ||
|
||
def create_thunk_inputs(self, storage_map): | ||
thunk_inputs = [] | ||
for n in self.fgraph.inputs: | ||
sinput = storage_map[n] | ||
thunk_inputs.append(sinput) | ||
|
||
return thunk_inputs |
Empty file.
Oops, something went wrong.