Implement basic Alloc Ops in PyTorch
ricardoV94 committed Jun 20, 2024
1 parent 320bac4 commit 60246ad
Showing 2 changed files with 64 additions and 1 deletion.
31 changes: 31 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -6,6 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange


@singledispatch
@@ -58,3 +59,33 @@ def deepcopyop(x):
        return x.clone()

    return deepcopyop


@pytorch_funcify.register(AllocEmpty)
def pytorch_funcify_AllocEmpty(op, **kwargs):
    dtype = getattr(torch, op.dtype)

    def alloc_empty(*shape):
        return torch.empty(shape, dtype=dtype)

    return alloc_empty


@pytorch_funcify.register(Alloc)
def pytorch_funcify_alloc(op, **kwargs):
    def alloc(value, *shape):
        out = torch.empty(shape, dtype=value.dtype)
        out[...] = value  # broadcast value to shape of out
        return out

    return alloc


@pytorch_funcify.register(ARange)
def pytorch_funcify_arange(op, **kwargs):
    dtype = getattr(torch, op.dtype)

    def arange(start, stop, step):
        return torch.arange(start, stop, step, dtype=dtype)

    return arange
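
The Alloc implementation relies on PyTorch's broadcasting rules for in-place assignment: torch.empty allocates an uninitialized tensor of the target shape, and out[...] = value broadcasts value into it, mirroring NumPy semantics. A minimal standalone sketch of that pattern in plain PyTorch (illustration only, not part of this commit; the shapes are arbitrary):

import torch

# Broadcast a length-3 vector into an uninitialized (2, 4, 3) tensor,
# the same pattern used by the Alloc dispatch above.
value = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
out = torch.empty((2, 4, 3), dtype=value.dtype)
out[...] = value  # value is broadcast along the leading dimensions

assert out.shape == (2, 4, 3)
assert torch.equal(out[1, 2], value)

# AllocEmpty and ARange map PyTensor dtype strings to torch dtypes via
# getattr, e.g. getattr(torch, "float32") is torch.float32.
assert getattr(torch, "float32") is torch.float32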
34 changes: 33 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -12,6 +12,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector


@@ -191,7 +192,7 @@ def test_shared_updates(device):
    assert isinstance(a.get_value(), np.ndarray)


-def test_pytorch_checkandraise():
+def test_checkandraise():
    check_and_raise = CheckAndRaise(AssertionError, "testing")

    x = scalar("x")
@@ -203,3 +204,34 @@ def test_pytorch_checkandraise():
    with pytest.raises(AssertionError, match="testing"):
        y_fn(0.0)
    assert y_fn(4).item() == 4


def test_alloc_and_empty():
    dim0 = as_tensor(5, dtype="int64")
    dim1 = scalar("dim1", dtype="int64")

    out = empty((dim0, dim1, 3), dtype="float32")
    fn = function([dim1], out, mode=pytorch_mode)
    res = fn(7)
    assert res.shape == (5, 7, 3)
    assert res.dtype == torch.float32

    v = vector("v", shape=(3,), dtype="float64")
    out = alloc(v, dim0, dim1, 3)
    compare_pytorch_and_py(
        FunctionGraph([v, dim1], [out]),
        [np.array([1, 2, 3]), np.array(7)],
    )


def test_arange():
    start = scalar("start", dtype="int64")
    stop = scalar("stop", dtype="int64")
    step = scalar("step", dtype="int64")

    out = arange(start, stop, step, dtype="int16")

    compare_pytorch_and_py(
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )

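A hedged end-to-end usage sketch of the new Ops through the PyTorch backend. It assumes the backend is registered under the mode name "PYTORCH" (as with the "JAX" and "NUMBA" modes); the tests above use a local pytorch_mode object instead, and the backend returns torch tensors rather than NumPy arrays:

import numpy as np
import pytensor
import pytensor.tensor as pt

# Broadcast a length-3 vector into a (2, n, 3) tensor, with n symbolic.
v = pt.vector("v", shape=(3,), dtype="float64")
n = pt.scalar("n", dtype="int64")
out = pt.alloc(v, 2, n, 3)

# mode="PYTORCH" is an assumption here; see the hedge above.
fn = pytensor.function([v, n], out, mode="PYTORCH")
res = fn(np.array([1.0, 2.0, 3.0]), 4)
print(res.shape)  # torch.Size([2, 4, 3]); res is a torch tensor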