Commit 320bac4

Add initial support for PyTorch backend (#764)

HarshvirSandhu committed Jun 20, 2024
1 parent efa845a commit 320bac4
Showing 11 changed files with 471 additions and 2 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/test.yml
@@ -76,6 +76,7 @@ jobs:
float32: [0, 1]
install-numba: [0]
install-jax: [0]
install-torch: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -116,6 +117,11 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
steps:
- uses: actions/checkout@v4
with:
@@ -142,9 +148,12 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -153,6 +162,7 @@ jobs:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch }}

- name: Run tests
shell: micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
directory: ./coverage/
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}

15 changes: 15 additions & 0 deletions pytensor/compile/mode.py
@@ -28,6 +28,7 @@
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker


@@ -47,6 +48,7 @@
    "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
    "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
    "jax": JAXLinker(),
    "pytorch": PytorchLinker(),
    "numba": NumbaLinker(),
}

@@ -460,6 +462,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
        ],
    ),
)
PYTORCH = Mode(
    PytorchLinker(),
    RewriteDatabaseQuery(
        include=["fast_run"],
        exclude=[
            "cxx_only",
            "BlasOpt",
            "fusion",
            "inplace",
        ],
    ),
)
NUMBA = Mode(
    NumbaLinker(),
    RewriteDatabaseQuery(
@@ -474,6 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
    "FAST_RUN": FAST_RUN,
    "JAX": JAX,
    "NUMBA": NUMBA,
    "PYTORCH": PYTORCH,
}

instantiated_default_mode = None
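With the "PYTORCH" mode registered above, a graph can be compiled through the new backend by name. A minimal sketch (assuming PyTorch is installed; the graph and input values are illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x) + 1
# "PYTORCH" resolves to the Mode built on PytorchLinker above
f = pytensor.function([x], y, mode="PYTORCH")
f([0.0, 1.0, 2.0])
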
6 changes: 5 additions & 1 deletion pytensor/link/basic.py
@@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[Any]]) -> list[Any]:
    def jit_compile(self, fn: Callable) -> Callable:
        """JIT compile a converted ``FunctionGraph``."""

    def input_filter(self, inp: Any) -> Any:
        """Apply a filter to the data input."""
        return inp

    def output_filter(self, var: Variable, out: Any) -> Any:
        """Apply a filter to the data output by a JITed function call."""
        return out
@@ -657,7 +661,7 @@ def thunk(
            thunk_inputs=thunk_inputs,
            thunk_outputs=thunk_outputs,
        ):
            outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

            for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
                compute_map[o_var][0] = True
7 changes: 7 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -0,0 +1,7 @@
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify

# Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
# isort: on
60 changes: 60 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -0,0 +1,60 @@
from functools import singledispatch

import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
    return torch.as_tensor(data, dtype=dtype)


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
    """Create a PyTorch compatible function from a PyTensor `Op`."""
    raise NotImplementedError(
        f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
    )


@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="pytorch_funcified_fgraph",
    **kwargs,
):
    return fgraph_to_python(
        fgraph,
        pytorch_funcify,
        type_conversion_fn=pytorch_typify,
        fgraph_name=fgraph_name,
        **kwargs,
    )


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
    error = op.exc_type
    msg = op.msg

    def assert_fn(x, *conditions):
        for cond in conditions:
            if not cond.item():
                raise error(msg)
        return x

    return assert_fn


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
    def deepcopyop(x):
        return x.clone()

    return deepcopyop
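Because pytorch_funcify is a singledispatch function, support for additional Ops can be added by registering new implementations against it. A minimal sketch of the pattern (MyOp and its import are hypothetical placeholders, and torch.sqrt merely stands in for whatever the Op computes):

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from mylib.ops import MyOp  # hypothetical Op being added


@pytorch_funcify.register(MyOp)
def pytorch_funcify_MyOp(op, **kwargs):
    def myop(x):
        # torch computation equivalent to MyOp (placeholder)
        return torch.sqrt(x)

    return myop
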
36 changes: 36 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -0,0 +1,36 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise


@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    scalar_op = op.scalar_op
    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

    def elemwise_fn(*inputs):
        Elemwise._check_runtime_broadcast(node, inputs)
        return base_fn(*inputs)

    return elemwise_fn


@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
    def dimshuffle(x):
        res = torch.permute(x, op.transposition)

        shape = list(res.shape[: len(op.shuffle)])

        for augm in op.augment:
            shape.insert(augm, 1)

        res = torch.reshape(res, shape)

        if not op.inplace:
            res = res.clone()

        return res

    return dimshuffle
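As a worked illustration of the dimshuffle logic above: for a hypothetical DimShuffle with transposition (1, 0) and one augmented (broadcastable) axis appended, a (2, 3) tensor is first transposed to (3, 2) and then reshaped to (3, 2, 1):

import torch

x = torch.zeros(2, 3)
res = torch.permute(x, (1, 0))       # apply the transposition -> shape (3, 2)
res = torch.reshape(res, (3, 2, 1))  # insert the augmented axis -> shape (3, 2, 1)
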
40 changes: 40 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -0,0 +1,40 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
    ScalarOp,
)


@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
    """Return a PyTorch function that implements the same computation as the Scalar Op.

    This dispatch is expected to return a PyTorch function that works on Array inputs
    as Elemwise does, even though it is dispatched on the Scalar Op.
    """

    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

    func_name = nfunc_spec[0]

    pytorch_func = getattr(torch, func_name)

    if len(node.inputs) > op.nfunc_spec[1]:
        # Some Scalar Ops accept a variable number of inputs, behaving as a variadic function,
        # even though the base Op from `func_name` is specified as a binary Op.
        # This happens with `Add`, which can work as a `Sum` for multiple scalars.
        pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
        if not pytorch_variadic_func:
            raise NotImplementedError(
                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
            )

        def pytorch_func(*args):
            return pytorch_variadic_func(
                torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
            )

    return pytorch_func
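For instance, the variadic fallback above is what handles Add applied to more than two inputs: assuming its nfunc_variadic is "sum", the stacked, broadcast arguments are reduced with torch.sum. A small standalone sketch of that computation:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor(3.0)
c = torch.tensor(4.0)
# broadcast to a common shape, stack along a new axis, then reduce over it
torch.sum(torch.stack(torch.broadcast_tensors(a, b, c), dim=0), dim=0)  # tensor([8., 9.])
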
36 changes: 36 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -0,0 +1,36 @@
from typing import Any

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker


class PytorchLinker(JITLinker):
    """A `Linker` that compiles NumPy-based operations using torch.compile."""

    def input_filter(self, inp: Any) -> Any:
        from pytensor.link.pytorch.dispatch import pytorch_typify

        return pytorch_typify(inp)

    def output_filter(self, var: Variable, out: Any) -> Any:
        return out.cpu()

    def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
        from pytensor.link.pytorch.dispatch import pytorch_funcify

        return pytorch_funcify(
            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
        )

    def jit_compile(self, fn):
        import torch

        return torch.compile(fn)

    def create_thunk_inputs(self, storage_map):
        thunk_inputs = []
        for n in self.fgraph.inputs:
            sinput = storage_map[n]
            thunk_inputs.append(sinput)

        return thunk_inputs
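Together with the input_filter hook added in pytensor/link/basic.py above, the flow is: NumPy values held in the input storage are converted to torch tensors by pytorch_typify, the funcified graph runs under torch.compile, and results are moved back to the CPU by output_filter. A minimal standalone sketch of that round trip (fgraph_fn is a stand-in for a funcified FunctionGraph):

import numpy as np
import torch


def fgraph_fn(x):
    # placeholder for the function produced by pytorch_funcify(fgraph)
    return torch.exp(x)


compiled = torch.compile(fgraph_fn)    # jit_compile
inp = torch.as_tensor(np.arange(3.0))  # input_filter -> pytorch_typify
out = compiled(inp)
result = out.cpu()                     # output_filter
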
Empty file added tests/link/pytorch/__init__.py