From 320bac49f4a084c1292a32b1946a32d8b6e41ef8 Mon Sep 17 00:00:00 2001
From: Harshvir Sandhu <75773763+HarshvirSandhu@users.noreply.github.com>
Date: Thu, 20 Jun 2024 21:18:45 +0530
Subject: [PATCH] Add initial support for PyTorch backend (#764)

---
 .github/workflows/test.yml                 |  13 +-
 pytensor/compile/mode.py                   |  15 ++
 pytensor/link/basic.py                     |   6 +-
 pytensor/link/pytorch/dispatch/__init__.py |   7 +
 pytensor/link/pytorch/dispatch/basic.py    |  60 ++++++
 pytensor/link/pytorch/dispatch/elemwise.py |  36 ++++
 pytensor/link/pytorch/dispatch/scalar.py   |  40 ++++
 pytensor/link/pytorch/linker.py            |  36 ++++
 tests/link/pytorch/__init__.py             |   0
 tests/link/pytorch/test_basic.py           | 205 +++++++++++++++++++++
 tests/link/pytorch/test_elemwise.py        |  55 ++++++
 11 files changed, 471 insertions(+), 2 deletions(-)
 create mode 100644 pytensor/link/pytorch/dispatch/__init__.py
 create mode 100644 pytensor/link/pytorch/dispatch/basic.py
 create mode 100644 pytensor/link/pytorch/dispatch/elemwise.py
 create mode 100644 pytensor/link/pytorch/dispatch/scalar.py
 create mode 100644 pytensor/link/pytorch/linker.py
 create mode 100644 tests/link/pytorch/__init__.py
 create mode 100644 tests/link/pytorch/test_basic.py
 create mode 100644 tests/link/pytorch/test_elemwise.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 096b47ab72..023519c268 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -76,6 +76,7 @@ jobs:
         float32: [0, 1]
         install-numba: [0]
         install-jax: [0]
+        install-torch: [0]
         part:
           - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
           - "tests/scan"
@@ -116,6 +117,11 @@ jobs:
             fast-compile: 0
             float32: 0
             part: "tests/link/jax"
+          - install-torch: 1
+            python-version: "3.10"
+            fast-compile: 0
+            float32: 0
+            part: "tests/link/pytorch"
     steps:
       - uses: actions/checkout@v4
         with:
@@ -142,9 +148,12 @@ jobs:
       - name: Install dependencies
         shell: micromamba-shell {0}
         run: |
+          micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
           if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
+          if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
+
           pip install -e ./
           micromamba list && pip freeze
           python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

@@ -153,6 +162,7 @@ jobs:
           PYTHON_VERSION: ${{ matrix.python-version }}
           INSTALL_NUMBA: ${{ matrix.install-numba }}
           INSTALL_JAX: ${{ matrix.install-jax }}
+          INSTALL_TORCH: ${{ matrix.install-torch }}

       - name: Run tests
         shell: micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
       - name: Install dependencies
        shell: micromamba-shell {0}
         run: |
-          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
+          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
           pip install -e ./
           micromamba list && pip freeze
           python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
         directory: ./coverage/
         fail_ci_if_error: true
         token: ${{ secrets.CODECOV_TOKEN }}
+
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index cf8dd9e73e..16019d4187 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -28,6 +28,7 @@
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.link.jax.linker import JAXLinker
 from pytensor.link.numba.linker import NumbaLinker
+from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.link.vm import VMLinker


@@ -47,6 +48,7 @@
     "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
     "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
     "jax": JAXLinker(),
+    "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
 }

@@ -460,6 +462,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
         ],
     ),
 )
+PYTORCH = Mode(
+    PytorchLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "fusion",
+            "inplace",
+        ],
+    ),
+)
 NUMBA = Mode(
     NumbaLinker(),
     RewriteDatabaseQuery(
@@ -474,6 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
+    "PYTORCH": PYTORCH,
 }

 instantiated_default_mode = None
diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index 6ae4bd7682..767656e081 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[An
     def jit_compile(self, fn: Callable) -> Callable:
         """JIT compile a converted ``FunctionGraph``."""

+    def input_filter(self, inp: Any) -> Any:
+        """Apply a filter to the data input."""
+        return inp
+
     def output_filter(self, var: Variable, out: Any) -> Any:
         """Apply a filter to the data output by a JITed function call."""
         return out
@@ -657,7 +661,7 @@ def thunk(
             thunk_inputs=thunk_inputs,
             thunk_outputs=thunk_outputs,
         ):
-            outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
+            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

             for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
                 compute_map[o_var][0] = True
diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
new file mode 100644
index 0000000000..b6af171995
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -0,0 +1,7 @@
+# isort: off
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+
+# Load dispatch specializations
+import pytensor.link.pytorch.dispatch.scalar
+import pytensor.link.pytorch.dispatch.elemwise
+# isort: on
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
new file mode 100644
index 0000000000..c74df67b5b
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -0,0 +1,60 @@
+from functools import singledispatch
+
+import torch
+
+from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.fg import FunctionGraph
+from pytensor.link.utils import fgraph_to_python
+from pytensor.raise_op import CheckAndRaise
+
+
+@singledispatch
+def pytorch_typify(data, dtype=None, **kwargs):
+    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
+    return torch.as_tensor(data, dtype=dtype)
+
+
+@singledispatch
+def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
+    """Create a PyTorch-compatible function from a PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request that we prioritize this operation"
+    )
+
+
+@pytorch_funcify.register(FunctionGraph)
+def pytorch_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="pytorch_funcified_fgraph",
+    **kwargs,
+):
+    return fgraph_to_python(
+        fgraph,
+        pytorch_funcify,
+        type_conversion_fn=pytorch_typify,
+        fgraph_name=fgraph_name,
+        **kwargs,
+    )
+
+
+@pytorch_funcify.register(CheckAndRaise)
+def pytorch_funcify_CheckAndRaise(op, **kwargs):
+    error = op.exc_type
+    msg = op.msg
+
+    def assert_fn(x, *conditions):
+        for cond in conditions:
+            if not cond.item():
+                raise error(msg)
+        return x
+
+    return assert_fn
+
+
+@pytorch_funcify.register(DeepCopyOp)
+def pytorch_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return x.clone()
+
+    return deepcopyop
diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
new file mode 100644
index 0000000000..f39e108bed
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -0,0 +1,36 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
+
+
+@pytorch_funcify.register(Elemwise)
+def pytorch_funcify_Elemwise(op, node, **kwargs):
+    scalar_op = op.scalar_op
+    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        return base_fn(*inputs)
+
+    return elemwise_fn
+
+
+@pytorch_funcify.register(DimShuffle)
+def pytorch_funcify_DimShuffle(op, **kwargs):
+    def dimshuffle(x):
+        res = torch.permute(x, op.transposition)
+
+        shape = list(res.shape[: len(op.shuffle)])
+
+        for augm in op.augment:
+            shape.insert(augm, 1)
+
+        res = torch.reshape(res, shape)
+
+        if not op.inplace:
+            res = res.clone()
+
+        return res
+
+    return dimshuffle
diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
new file mode 100644
index 0000000000..56ec438c9f
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -0,0 +1,40 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar.basic import (
+    ScalarOp,
+)
+
+
+@pytorch_funcify.register(ScalarOp)
+def pytorch_funcify_ScalarOp(op, node, **kwargs):
+    """Return a PyTorch function that implements the same computation as the scalar `Op`.
+
+    This dispatch is expected to return a PyTorch function that works on array inputs as `Elemwise` does,
+    even though it is dispatched on the scalar `Op`.
+    """
+
+    nfunc_spec = getattr(op, "nfunc_spec", None)
+    if nfunc_spec is None:
+        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
+
+    func_name = nfunc_spec[0]
+
+    pytorch_func = getattr(torch, func_name)
+
+    if len(node.inputs) > op.nfunc_spec[1]:
+        # Some Scalar Ops accept a variable number of inputs, behaving as a variadic function,
+        # even though the base Op from `func_name` is specified as a binary Op.
+        # This happens with `Add`, which can work as a `Sum` of multiple scalars.
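+        # For example, `pt.add(x, y, z)` builds a node with three inputs, while
+        # `Add.nfunc_spec == ("add", 2, 1)` names the binary `torch.add`; in that
+        # case we fall back to the reduction named by `op.nfunc_variadic`
+        # ("sum" for `Add`).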
+        pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
+        if not pytorch_variadic_func:
+            raise NotImplementedError(
+                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
+            )
+
+        def pytorch_func(*args):
+            return pytorch_variadic_func(
+                torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
+            )
+
+    return pytorch_func
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
new file mode 100644
index 0000000000..035d654c83
--- /dev/null
+++ b/pytensor/link/pytorch/linker.py
@@ -0,0 +1,36 @@
+from typing import Any
+
+from pytensor.graph.basic import Variable
+from pytensor.link.basic import JITLinker
+
+
+class PytorchLinker(JITLinker):
+    """A `Linker` that compiles NumPy-based operations using ``torch.compile``."""
+
+    def input_filter(self, inp: Any) -> Any:
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
+        return pytorch_typify(inp)
+
+    def output_filter(self, var: Variable, out: Any) -> Any:
+        return out.cpu()
+
+    def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
+        from pytensor.link.pytorch.dispatch import pytorch_funcify
+
+        return pytorch_funcify(
+            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
+        )
+
+    def jit_compile(self, fn):
+        import torch
+
+        return torch.compile(fn)
+
+    def create_thunk_inputs(self, storage_map):
+        thunk_inputs = []
+        for n in self.fgraph.inputs:
+            sinput = storage_map[n]
+            thunk_inputs.append(sinput)
+
+        return thunk_inputs
diff --git a/tests/link/pytorch/__init__.py b/tests/link/pytorch/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
new file mode 100644
index 0000000000..68d937fce8
--- /dev/null
+++ b/tests/link/pytorch/test_basic.py
@@ -0,0 +1,205 @@
+from collections.abc import Callable, Iterable
+from functools import partial
+
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.compile.mode import get_mode
+from pytensor.compile.sharedvalue import SharedVariable, shared
+from pytensor.configdefaults import config
+from pytensor.graph.basic import Apply
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
+from pytensor.raise_op import CheckAndRaise
+from pytensor.tensor.type import scalar, vector
+
+
+torch = pytest.importorskip("torch")
+
+
+pytorch_mode = get_mode("PYTORCH")
+py_mode = get_mode("FAST_COMPILE")
+
+
+def compare_pytorch_and_py(
+    fgraph: FunctionGraph,
+    test_inputs: Iterable,
+    assert_fn: Callable | None = None,
+    must_be_device_array: bool = True,
+    pytorch_mode=pytorch_mode,
+    py_mode=py_mode,
+):
+    """Compare the Python and PyTorch-compiled outputs of a function graph for equality.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        PyTensor function graph object.
+    test_inputs: iter
+        Numerical inputs with which to test the function graph.
+    assert_fn: func, opt
+        Assert function used to check for equality between the Python and PyTorch
+        results. If not provided, ``np.testing.assert_allclose`` is used.
+    must_be_device_array: Bool
+        If True, check that the results are ``torch.Tensor`` instances (on a CUDA
+        device for single outputs).
+    """
+    if assert_fn is None:
+        assert_fn = partial(np.testing.assert_allclose)
+
+    fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
+
+    pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
+    pytorch_res = pytensor_torch_fn(*test_inputs)
+
+    if must_be_device_array:
+        if isinstance(pytorch_res, list):
+            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
+        else:
+            assert pytorch_res.device.type == "cuda"
+
+    pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
+    py_res = pytensor_py_fn(*test_inputs)
+
+    if len(fgraph.outputs) > 1:
+        for j, p in zip(pytorch_res, py_res):
+            assert_fn(j.cpu(), p)
+    else:
+        assert_fn([pytorch_res[0].cpu()], py_res)
+
+    return pytensor_torch_fn, pytorch_res
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_pytorch_FunctionGraph_once(device):
+    """Make sure that an output is only computed once when it's referenced multiple times."""
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    from pytensor.link.pytorch.dispatch import pytorch_funcify
+
+    with torch.device(device):
+        x = vector("x")
+        y = vector("y")
+
+        class TestOp(Op):
+            def __init__(self):
+                self.called = 0
+
+            def make_node(self, *args):
+                return Apply(self, list(args), [x.type() for x in args])
+
+            def perform(self, node, inputs, outputs):
+                for i, inp in enumerate(inputs):
+                    outputs[i][0] = inp[0]
+
+        @pytorch_funcify.register(TestOp)
+        def pytorch_funcify_TestOp(op, **kwargs):
+            def func(*args, op=op):
+                op.called += 1
+                for arg in args:
+                    assert arg.device.type == device
+                return list(args)
+
+            return func
+
+        op1 = TestOp()
+        op2 = TestOp()
+
+        q, r = op1(x, y)
+        outs = op2(q + r, q + r)
+
+        out_fg = FunctionGraph([x, y], outs, clone=False)
+        assert len(out_fg.outputs) == 2
+
+        out_torch = pytorch_funcify(out_fg)
+
+        x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX))
+        y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX))
+
+        res = out_torch(x_val, y_val)
+
+        for output in res:
+            assert torch.equal(
+                output, torch.tensor([3, 5]).to(getattr(torch, config.floatX))
+            )
+
+        assert len(res) == 2
+        assert op1.called == 1
+        assert op2.called == 1
+
+        res = out_torch(x_val, y_val)
+
+        for output in res:
+            assert torch.equal(
+                output, torch.tensor([3, 5]).to(getattr(torch, config.floatX))
+            )
+
+        assert len(res) == 2
+        assert op1.called == 2
+        assert op2.called == 2
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_shared(device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    with torch.device(device):
+        a = shared(np.array([1, 2, 3], dtype=config.floatX))
+        pytensor_torch_fn = function([], a, mode="PYTORCH")
+        pytorch_res = pytensor_torch_fn()
+
+        assert isinstance(pytorch_res, torch.Tensor)
+        assert isinstance(a.get_value(), np.ndarray)
+        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
+
+        pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
+        pytorch_res = pytensor_torch_fn()
+
+        assert isinstance(pytorch_res, torch.Tensor)
+        assert isinstance(a.get_value(), np.ndarray)
+        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
+
+        new_a_value = np.array([3, 4, 5], dtype=config.floatX)
+        a.set_value(new_a_value)
+
+        pytorch_res = pytensor_torch_fn()
+        assert isinstance(pytorch_res, torch.Tensor)
+        np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_shared_updates(device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    with torch.device(device):
+        a = shared(0)
+
+        pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")
+        res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
+        assert res1 == 0
+        assert res2 == 1
+        assert a.get_value() == 2
+        assert isinstance(a.get_value(), np.ndarray)
+
+        a.set_value(5)
+        res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
+        assert res1 == 5
+        assert res2 == 6
+        assert a.get_value() == 7
+        assert isinstance(a.get_value(), np.ndarray)
+
+
+def test_pytorch_checkandraise():
+    check_and_raise = CheckAndRaise(AssertionError, "testing")
+
+    x = scalar("x")
+    conds = (x > 0, x > 3)
+    y = check_and_raise(x, *conds)
+
+    y_fn = function([x], y, mode="PYTORCH")
+
+    with pytest.raises(AssertionError, match="testing"):
+        y_fn(0.0)
+    assert y_fn(4).item() == 4
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
new file mode 100644
index 0000000000..1d843b8051
--- /dev/null
+++ b/tests/link/pytorch/test_elemwise.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+import pytensor.tensor as pt
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import elemwise as pt_elemwise
+from pytensor.tensor.type import matrix, tensor, vector
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+def test_pytorch_Dimshuffle():
+    a_pt = matrix("a")
+
+    x = a_pt.T
+    x_fg = FunctionGraph([a_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+
+    x = a_pt.dimshuffle([0, 1, "x"])
+    x_fg = FunctionGraph([a_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+
+    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
+    x = a_pt.dimshuffle((0,))
+    x_fg = FunctionGraph([a_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+
+    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
+    x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
+    x_fg = FunctionGraph([a_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+
+
+def test_multiple_input_output():
+    x = vector("x")
+    y = vector("y")
+    out = pt.mul(x, y)
+
+    fg = FunctionGraph(outputs=[out], clone=False)
+    compare_pytorch_and_py(fg, [[1.5], [2.5]])
+
+    x = vector("x")
+    y = vector("y")
+    div = pt.int_div(x, y)
+    pt_sum = pt.add(y, x)
+
+    fg = FunctionGraph(outputs=[div, pt_sum], clone=False)
+    compare_pytorch_and_py(fg, [[1.5], [2.5]])
+
+
+def test_pytorch_elemwise():
+    x = pt.vector("x")
+    out = pt.log(1 - x)
+
+    fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(fg, [[0.9, 0.9]])
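
A minimal end-to-end sketch of what this patch enables (an illustration, not part of
the diff; it assumes the patch is applied and `torch` is importable). `pt.exp` goes
through the generic `ScalarOp` dispatch above via `nfunc_spec = ("exp", 1, 1)`, is
wrapped by the `Elemwise` dispatch, and the whole graph is compiled with
`torch.compile` by `PytorchLinker`:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    out = pt.exp(x)

    # "PYTORCH" resolves to the Mode registered in pytensor/compile/mode.py;
    # inputs are converted with pytorch_typify and outputs returned via .cpu(),
    # so the result is a torch.Tensor on the CPU.
    fn = pytensor.function([x], out, mode="PYTORCH")
    print(fn(np.array([0.0, 1.0], dtype=pytensor.config.floatX)))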