Skip to content

Commit

Permalink
Allow subclasses to override _repr_latex_
Browse files Browse the repository at this point in the history
The previous behavior was to completely ignore this hook once `init_printing` had been called, which was surprising to the user.
  • Loading branch information
eric-wieser committed Aug 10, 2020
1 parent e572934 commit 5690f54
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
19 changes: 14 additions & 5 deletions sympy/interactive/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,58 +249,67 @@ def _result_display(self, arg):
import IPython
if V(IPython.__version__) >= '0.11':

printable_types = [Printable, float, tuple, list, set,
frozenset, dict, int]
# Printable is our own type, so we handle it with methods instead of
# the approach required by builtin types. This allows downstream
# packages to override the methods in their own classes.
printable_types = [float, tuple, list, set, frozenset, dict, int]

plaintext_formatter = ip.display_formatter.formatters['text/plain']

for cls in printable_types:
# Exception to the rule above, IPython has better dispatching rules
# for plaintext printing, and we can't use `_repr_pretty_` without
# hitting a recursion error above.
for cls in printable_types + [Printable]:
plaintext_formatter.for_type(cls, _print_plain)

svg_formatter = ip.display_formatter.formatters['image/svg+xml']
if use_latex in ('svg', ):
debug("init_printing: using svg formatter")
for cls in printable_types:
svg_formatter.for_type(cls, _print_latex_svg)
Printable._repr_svg_ = _print_latex_svg
else:
debug("init_printing: not using any svg formatter")
for cls in printable_types:
# Better way to set this, but currently does not work in IPython
#png_formatter.for_type(cls, None)
if cls in svg_formatter.type_printers:
svg_formatter.type_printers.pop(cls)
Printable._repr_svg_ = None

png_formatter = ip.display_formatter.formatters['image/png']
if use_latex in (True, 'png'):
debug("init_printing: using png formatter")
for cls in printable_types:
png_formatter.for_type(cls, _print_latex_png)
Printable._repr_png_ = _print_latex_png
elif use_latex == 'matplotlib':
debug("init_printing: using matplotlib formatter")
for cls in printable_types:
png_formatter.for_type(cls, _print_latex_matplotlib)
Printable._repr_png_ = _print_latex_matplotlib
else:
debug("init_printing: not using any png formatter")
for cls in printable_types:
# Better way to set this, but currently does not work in IPython
#png_formatter.for_type(cls, None)
if cls in png_formatter.type_printers:
png_formatter.type_printers.pop(cls)
Printable._repr_png_ = None

latex_formatter = ip.display_formatter.formatters['text/latex']
if use_latex in (True, 'mathjax'):
debug("init_printing: using mathjax formatter")
for cls in printable_types:
latex_formatter.for_type(cls, _print_latex_text)
Printable._repr_latex_ = Printable._repr_latex_orig
Printable._repr_latex_ = _print_latex_text
else:
debug("init_printing: not using text/latex formatter")
for cls in printable_types:
# Better way to set this, but currently does not work in IPython
#latex_formatter.for_type(cls, None)
if cls in latex_formatter.type_printers:
latex_formatter.type_printers.pop(cls)

Printable._repr_latex_ = None

else:
Expand Down
21 changes: 21 additions & 0 deletions sympy/interactive/tests/test_ipythonprinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,24 @@ def test_matplotlib_bad_latex():
# issue 9799
app.run_cell("from sympy import Piecewise, Symbol, Eq")
app.run_cell("x = Symbol('x'); pw = format(Piecewise((1, Eq(x, 0)), (0, True)))")


def test_override_repr_latex():
# Initialize and setup IPython session
app = init_ipython_session()
app.run_cell("import IPython")
app.run_cell("from sympy import init_printing")
app.run_cell("from sympy import Symbol")
app.run_cell("init_printing(use_latex=True)")
app.run_cell("""\
class SymbolWithOverload(Symbol):
def _repr_latex_(self):
return r"Hello " + super()._repr_latex_() + " world"
""")
app.run_cell("s = SymbolWithOverload('s')")

if int(ipython.__version__.split(".")[0]) < 1:
latex = app.user_ns['s']['text/latex']
else:
latex = app.user_ns['s'][0]['text/latex']
assert latex == r'Hello $\displaystyle s$ world'

0 comments on commit 5690f54

Please sign in to comment.