Sketch of dim-ed tensors #407

Draft · wants to merge 1 commit into main
8 changes: 8 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,8 @@
import warnings
import pytensor.xtensor.rewriting

from pytensor.xtensor.variable import XTensorVariable, XTensorConstant, as_xtensor, as_xtensor_variable
from pytensor.xtensor.type import XTensorType


warnings.warn("xtensor module is experimental and full of bugs")
70 changes: 70 additions & 0 deletions pytensor/xtensor/basic.py
@@ -0,0 +1,70 @@
from itertools import chain

import pytensor.scalar as ps
from pytensor.graph import Apply, Op
import pytensor.xtensor as px
from pytensor.tensor import TensorType


class TensorFromXTensor(Op):
    __props__ = ()

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, px.XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x.copy()


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(Op):
    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must have a TensorType, got {type(x.type)}")
        output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x.copy()


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)


class XElemwise(Op):
    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        # TODO: Check dim lengths match
        inputs = [px.as_xtensor_variable(inp) for inp in inputs]
        # TODO: This ordering is different than what xarray does
        unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs)))
        # TODO: Fix dtype
        output_type = px.XTensorType(
            "float64", dims=unique_dims, shape=(None,) * len(unique_dims)
        )
        outputs = [output_type() for _ in range(self.scalar_op.nout)]
        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs) -> None:
        raise NotImplementedError("xtensor operations must be rewritten as tensor operations")


add = XElemwise(ps.add)
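
Taken together these ops are enough to sketch a dim-aware graph, even though XElemwise itself cannot run before the rewrite below. A minimal usage sketch, assuming the `XTensorVariable` helpers imported in `__init__.py` are available:

import pytensor.tensor as pt
from pytensor.xtensor.basic import add, xtensor_from_tensor

# Wrap ordinary tensors with named dims
x = xtensor_from_tensor(pt.vector("x"), dims=("a",))
y = xtensor_from_tensor(pt.vector("y"), dims=("b",))

# XElemwise takes the sorted union of the input dims (unlike xarray's ordering)
out = add(x, y)
assert out.type.dims == ("a", "b")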
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -0,0 +1 @@
import pytensor.xtensor.rewriting.basic
26 changes: 26 additions & 0 deletions pytensor/xtensor/rewriting/basic.py
@@ -0,0 +1,26 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import expand_dims
from pytensor.tensor.elemwise import Elemwise
from pytensor.xtensor.basic import tensor_from_xtensor, XElemwise, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    # Convert inputs to TensorVariables and add broadcastable dims
    output_dims = node.outputs[0].type.dims

    tensor_inputs = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
        tensor_inp = tensor_from_xtensor(inp)
        tensor_inp = expand_dims(tensor_inp, axis)
        tensor_inputs.append(tensor_inp)

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(*tensor_inputs, return_list=True)

    # TODO: copy_stack_trace
    new_outs = [xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs]
    return new_outs
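
A sketch of the effect, reusing the `add` example from basic.py: each input gains broadcastable axes for the output dims it lacks, so ordinary Elemwise broadcasting aligns dims by name. Whether the leftover XTensorFromTensor/TensorFromXTensor pairs get cleaned up further is not addressed in this draft:

import pytensor
import pytensor.tensor as pt
from pytensor.xtensor.basic import add, tensor_from_xtensor, xtensor_from_tensor

x_ = pt.vector("x")
y_ = pt.vector("y")
out = tensor_from_xtensor(add(
    xtensor_from_tensor(x_, dims=("a",)),
    xtensor_from_tensor(y_, dims=("b",)),
))

# Compiling triggers "xcanonicalize" (registered in rewriting/utils.py below):
# x_ is expanded to shape (a, 1) and y_ to (1, b) before the Elemwise add.
fn = pytensor.function([x_, y_], out)
pytensor.dprint(fn)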
33 changes: 33 additions & 0 deletions pytensor/xtensor/rewriting/utils.py
@@ -0,0 +1,33 @@
from typing import Union

from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import RewriteDatabase, EquilibriumDB


optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)


def register_xcanonicalize(
    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
            return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register

    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter
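
Usage mirrors pytensor's existing `register_canonicalize`-style helpers: apply the decorator directly, or call it with extra tag strings first (the str branch above, where the string becomes a tag). A sketch with a hypothetical no-op rewrite:

from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import XElemwise
from pytensor.xtensor.rewriting.utils import register_xcanonicalize

@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def local_example_rewrite(fgraph, node):
    # Hypothetical: returning None leaves the node unchanged
    return None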
147 changes: 147 additions & 0 deletions pytensor/xtensor/type.py
@@ -0,0 +1,147 @@
from typing import Iterable, Optional, Sequence, Union

import numpy as np

import pytensor
from pytensor.graph.type import HasDataType
from pytensor.tensor.type import TensorType


class XTensorType(TensorType, HasDataType):
    """A `Type` for tensors with named dimensions."""

    __props__ = ("dtype", "shape", "dims")

    def __init__(
        self,
        dtype: Union[str, np.dtype],
        *,
        dims: Sequence[str],
        shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(dtype, shape=shape, name=name)
        if not isinstance(dims, (list, tuple)):
            raise TypeError("dims must be a list or tuple")
        self.dims = tuple(dims)

    def clone(
        self,
        dtype=None,
        dims=None,
        shape=None,
        **kwargs,
    ):
        if dtype is None:
            dtype = self.dtype
        if dims is None:
            dims = self.dims
        if shape is None:
            shape = self.shape
        return type(self)(dtype, shape=shape, dims=dims, **kwargs)

    def filter(self, value, strict=False, allow_downcast=None):
        # TODO: Implement this; for now values pass through unchecked
        return value

    def convert_variable(self, var):
        # TODO: Implement this; the dims-aware logic below is not yet enabled
        return var

        res = super().convert_variable(var)

        if res is None:
            return res

        if not isinstance(res.type, type(self)):
            return None

        if res.dims != self.dims:
            # TODO: Does this make sense?
            return None

        return res

    def __hash__(self):
        return super().__hash__() ^ hash(self.dims)

    def __repr__(self):
        # TODO: Add `?` for unknown shapes like `TensorType` does
        return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"

    def __eq__(self, other):
        res = super().__eq__(other)

        if isinstance(res, bool):
            return res and other.dims == self.dims

        return res

    def is_super(self, otype):
        # TODO: Implement this; the dims check below is not yet enabled
        return True

        if not super().is_super(otype):
            return False

        if self.dims == otype.dims:
            return True

        return False


# TODO: Implement creation helper `xtensor`

pytensor.compile.register_view_op_c_code(
    XTensorType,
    """
    Py_XDECREF(%(oname)s);
    %(oname)s = %(iname)s;
    Py_XINCREF(%(oname)s);
    """,
    1,
)
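
A quick sketch of how dims behave on the type itself (hypothetical dims and shapes): they take part in equality, hashing, and the repr alongside dtype and shape:

from pytensor.xtensor.type import XTensorType

t1 = XTensorType("float64", dims=("a", "b"), shape=(None, None))
t2 = XTensorType("float64", dims=("a", "b"), shape=(None, None))
t3 = XTensorType("float64", dims=("b", "a"), shape=(None, None))

assert t1 == t2 and hash(t1) == hash(t2)
assert t1 != t3  # same dtype and shape, different dims
print(t1)  # XTensorType(float64, ('a', 'b'), (None, None))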