Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docs on implementing Pytorch Ops (and CumOp) #837

Merged
merged 8 commits into from
Jul 4, 2024

Conversation

HarshvirSandhu
Copy link
Contributor

Description

This PR can be used as an example for implementing Ops in PyTorch

Related Issue

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

cc @ricardoV94

dim = op.axis
mode = op.mode

def cumop(x, dim=dim, mode=mode):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed, the returned functions are never called by the user

Suggested change
def cumop(x, dim=dim, mode=mode):
def cumop(x):

Comment on lines 16 to 17
# Create test value tag for a
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for test values and tags. We're planning to deprecate that functionality as well

# For the second mode of CumOp
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here just pass the test values (instead of adding them as tags and then retrieving them)

a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

# Create the output variable
out = pt.cumsum(a, axis=0)
Copy link
Member

@ricardoV94 ricardoV94 Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test axis=None and axis=tuple(...) if supported by the original Op. If tuple is allowed make sure you have more dimensions (say 3) and only ask for a subset (say 2) of them in the axis. This is to make sure you test something that is different than axis=None or axis=int.

The axis can be parametrized (prod and add as well) instead of adding more conditions inside the test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried this on the original Op. axis=tuple(...) does not work and gives a TypeError
axis=None gives the output as a 1-D array

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Op __init__ doesn't seem to check explicitly for axes but it does assume it is either None or an int. Can we add a check and raise an explicit ValueError if it's not either?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked again, there is no error if we use axis=(0), pytorch also returns the same output.
The error only comes when there are more than 1 elements in the tuple (Even np.cumsum gives TypeError in this case).

We could try adding a check and raise, but would that be needed in other Op implementations?
Since this would be used as an example, it might be complicated if a check and raise is not needed for other implementations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(0) is 0, not a tuple with a 0 inside it, it would have to be (0,) to be a tuple with a single element inside. Does it work with (0,)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it gives a TypeError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is fine but probably gives a typeerror in an obscure place. We should raise already in the init method of the Op to save people time

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 20, 2024

Can you extend the example in the documentation page on implementing custom JAX/NUMBA Ops to mention PyTorch and include this example as well?

Perhaps you can use some fancy tab to select among the different modes in the same documentation page. Is that supported @OriolAbril ?

@OriolAbril
Copy link
Member

Not here as of now, you'd have to add an extra extension for tabs. If you'll only want tabs, then it is probably best to use https://sphinx-tabs.readthedocs.io/en/latest/, if using things like grids, dropdowns, icons... somewhere else in addition to tabs here seems a future possibility then https://sphinx-design.readthedocs.io/en/sbt-theme/ is probably best. Both should only require being added as dependencies to the doc env and adding them to the extensions varialbe in conf.py, no further configuration

@ricardoV94
Copy link
Member

Thanks @OriolAbril either of those seems perfect. Any preference?

Copy link

codecov bot commented Jun 22, 2024

Codecov Report

Attention: Patch coverage is 63.15789% with 7 lines in your changes missing coverage. Please review.

Project coverage is 80.97%. Comparing base (320bac4) to head (a6e6bd8).
Report is 22 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #837      +/-   ##
==========================================
+ Coverage   80.87%   80.97%   +0.10%     
==========================================
  Files         168      170       +2     
  Lines       46950    47044      +94     
  Branches    11472    11504      +32     
==========================================
+ Hits        37972    38096     +124     
+ Misses       6766     6734      -32     
- Partials     2212     2214       +2     
Files Coverage Δ
pytensor/link/pytorch/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/tensor/extra_ops.py 88.63% <100.00%> (+0.03%) ⬆️
pytensor/link/pytorch/dispatch/extra_ops.py 56.25% <56.25%> (ø)

... and 20 files with indirect coverage changes

dim = op.axis
mode = op.mode

def cumop(x, dim=dim):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just no need for any kwargs. The function will only ever receive the node inputs

Suggested change
def cumop(x, dim=dim):
def cumop(x):

@OriolAbril
Copy link
Member

Thanks @OriolAbril either of those seems perfect. Any preference?

I use sphinx-design more because I use its other features

# Create a symbolic input for the first input of `CumOp`
a = pt.matrix("a")

# Create test value tag for a
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Create test value tag for a
# Create test value

a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

# Create the output variable
out = pt.cumsum(a, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is fine but probably gives a typeerror in an obscure place. We should raise already in the init method of the Op to save people time

@@ -283,8 +283,11 @@ class CumOp(COp):
def __init__(self, axis: int | None = None, mode="add"):
if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
self.axis = axis
self.mode = mode
if isinstance(axis, int) or axis is None:
Copy link
Member

@ricardoV94 ricardoV94 Jun 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, it's more common to just check and raise than indenting the "correct code" and raising otherwise

Suggested change
if isinstance(axis, int) or axis is None:
if not (isinstance(axis, int) or axis is None):
# raise error
# usual code

That's how the error check above for the mode is structured as well

return res if n_outs > 1 else res[0]
.. tab-set::

.. tab-item:: JAX/Numba
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct for Numba, can you leave it as a separate tab with [in progress] text (and open an issue) or check the source code of the Numba implementation if you want to do it correctly?

This probably applies to all the tabbed sections, no reason to combine jax and numba, and the pre-existing snippets were JAX specific

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was adding a separate tab for numba and found this comment. Is there anything that should be changed in numba_funcify_DimShuffle?

# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)

Copy link
Member

@ricardoV94 ricardoV94 Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in the context of this PR, but we should open an issue here to check if that's still a problem in the newer versions of numba. Could you do that?

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all tabs look rendered correctly, only left a comment so cross references to other libraries actually work

function that performs exactly the same computations as the :class:`Op`. For
example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye`
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_).
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_) and a Pytorch equivalent :func:`torch.eye` (see `documentation <https://pytorch.org/docs/stable/generated/torch.eye.html>`_).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like this:

imatge

which is quite the weird pattern for docs, especially given jax.numpy.eye and torch.eye are already using the correct cross-referencing syntax. I would remove the manual links and use the cross-references. That is, leaving only this:

Suggested change
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_) and a Pytorch equivalent :func:`torch.eye` (see `documentation <https://pytorch.org/docs/stable/generated/torch.eye.html>`_).
and a Pytorch equivalent :func:`torch.eye`.

And doing two more changes to conf.py. First add sphinx.ext.intersphinx to the list of extensions. It is part of the main sphinx library so no need to add any extra dependency to the env file. Add

intersphinx_mapping = {
    "jax": ("https://jax.readthedocs.io/en/latest", None),
    "numpy": ("https://numpy.org/doc/stable", None),
    "torch": ("https://pytorch.org/docs/stable", None),
}

with that, the jax.numpy.eye and torch.eye will still be formatted as monospaced text but no longer be pink, they'll be blue and be clickable links to their respective API pages.

@ricardoV94 ricardoV94 added enhancement New feature or request torch PyTorch backend labels Jun 28, 2024
@ricardoV94 ricardoV94 merged commit 781073b into pymc-devs:main Jul 4, 2024
56 of 57 checks passed
@ricardoV94 ricardoV94 changed the title Add Pytorch support for Cum Op Add docs on implementing Pytorch Ops (and CumOp) Jul 4, 2024
@ricardoV94 ricardoV94 added the docs label Jul 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
docs enhancement New feature or request torch PyTorch backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants