Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Named tensors with typed spaces #477

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Oct 15, 2023

I took the branch from #407 and added a pytensor.xtensor.spaces module that defines types to distinguish between "unordered spaces" (BaseSpace) and "ordered spaces" (OrderedSpace).

BaseSpace and OrderedSpace are similar to sets & tuples, but do not implement some operations that would mess up interpreting them as dims.

One idea here is to apply the mathematical operations not only to the variables, but also to their spaces.

For example:

# Addition between two variables uses bilateral broadcasting
Space(["a", "b"]) + Space({"c"}) -> Space({"a", "b", "c"})

This matches broadcasting in xarray:

a = xarray.DataArray([[1,2,3]], dims=["a", "b"])
b = xarray.DataArray([1,2,3,4], dims=["c"])
assert set((b + a).dims) == {"a", "b", "c"}

However, xarray.DataArray.dims are tuples, and the commutative rule does not apply to addition of xarray.DataArray variables' dims:

assert (a + b).dims == (b + a).dims  # AssertionError

In contrast, with this PR the resulting dims become an unordered space, and the resulting XTensorType are equal:

xa = ptx.as_xtensor(a)
xb = ptx.as_xtensor(b)
xc = xa + xb

xa.type  # XTensorType(int32, OrderedSpace('a', 'b'), (1, 3))
xb.type  # XTensorType(int32, OrderedSpace('c'), (4,))
xc.type  # XTensorType(float64, Space{'c', 'a', 'b'}, (None, None, None))

assert (xa + xb).type == (xb + xa).type

This was basic math, but we could introduce XOps with XOp.infer_space methods that can implement broadcasting rules for any operation:

class XOp(Op):
    def infer_space(self, fgraph, node, input_spaces) -> BaseSpace:
        raise NotImplementedError()


class SumOverTime(XOp):
    def infer_space(self, fgraph, node, input_spaces) -> BaseSpace:
        [s] = input_spaces
        if "time" not in s:
            raise ValueError("No time dim to sum over.")
        return Space(s - {"time"})

Similarly, this should allow us to implement dot products requiring OrderedSpace inputs to produce an OrderedSpace output, or a specify_dimorder XOp that orders a BaseSpace into an OrderedSpace.

Looking at the (None, None, None) shape from the code block above, I wonder if we should type XTensorType.shape as a Mapping[DimLike, int | ScalarVariable] 馃

@ricardoV94
Copy link
Member

Looking at the (None, None, None) shape from the code block above, I wonder if we should type XTensorType.shape as a Mapping[DimLike, int | ScalarVariable]

PyTensor variables shouldn't show up in the attributes of Variable types.

@michaelosthege
Copy link
Member Author

had a few more thoughts on this, and found that also for unordered spaces we need to know which index a dimension has in the underlying array. With that information one can then index into shape as well.

Now the question is where this information should be kept. Either the XTensorType keeps it, or we don't actually make the BaseSpace unordered no, then space math would not work.

Maybe it's enough to keep a is_ordered: bool and a dims: tuple? A .space property could create the corresponding BaseSpace/OrderedSpace if needed 馃

class XTensorFromTensor(Op):
__props__ = ("dims",)

def __init__(self, dims: Iterable[DimLike]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dims should be Sequence since Iterable can exhaust...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by "exhaust" here?

class XElemwise(Op):
__props__ = ("scalar_op",)

def __init__(self, scalar_op):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing type hints

...


class Dim(DimLike):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I remember, it is a bad practice to inherit a base class from the protocol

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dim can be just a dataclass instead

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the explanation in the PEP I would disagree: https://peps.python.org/pep-0544/#explicitly-declaring-implementation

By inheriting the protocol, we enable type checkers to warn about incomplete/incorrect implementations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants