Skip to content

Commit

Permalink
Fix the latex representation of a 1-tuple
Browse files Browse the repository at this point in the history
Without this change, a 1-item tuple looks like a set of parentheses, which is misleading.
This adds the trailing comma (or as appropriate, semicolon) that is typical of 1-tuples in python.
  • Loading branch information
eric-wieser committed May 19, 2020
1 parent 3219093 commit aba12c8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sympy/interactive/tests/test_ipythonprinting.py
Expand Up @@ -134,7 +134,7 @@ def test_builtin_containers():
([ ],)
[2] \
"""
assert app.user_ns['c']['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right]\\right)$'
assert app.user_ns['c']['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right],\\right)$'
else:
assert app.user_ns['a'][0]['text/plain'] == '(True, False)'
assert 'text/latex' not in app.user_ns['a'][0]
Expand All @@ -146,7 +146,7 @@ def test_builtin_containers():
([ ],)
[2] \
"""
assert app.user_ns['c'][0]['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right]\\right)$'
assert app.user_ns['c'][0]['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right],\\right)$'

def test_matplotlib_bad_latex():
# Initialize and setup IPython session
Expand Down
21 changes: 15 additions & 6 deletions sympy/printing/latex.py
Expand Up @@ -209,6 +209,10 @@ def __init__(self, settings=None):
def _add_parens(self, s):
return r"\left({}\right)".format(s)

# TODO: merge this with the above, which requires a lot of test changes
def _add_parens_lspace(self, s):
return r"\left( {}\right)".format(s)

def parenthesize(self, item, level, is_neg=False, strict=False):
prec_val = precedence_traditional(item)
if is_neg and strict:
Expand Down Expand Up @@ -1865,15 +1869,20 @@ def _print_frac(self, expr, exp=None):
self._print(expr.args[0]), self._print(exp))

def _print_tuple(self, expr):
if self._settings['decimal_separator'] =='comma':
return r"\left( %s\right)" % \
r"; \ ".join([self._print(i) for i in expr])
elif self._settings['decimal_separator'] =='period':
return r"\left( %s\right)" % \
r", \ ".join([self._print(i) for i in expr])
if self._settings['decimal_separator'] == 'comma':
sep = ";"
elif self._settings['decimal_separator'] == 'period':
sep = ","
else:
raise ValueError('Unknown Decimal Separator')

if len(expr) == 1:
# 1-tuple needs a trailing separator
return self._add_parens_lspace(self._print(expr[0]) + sep)
else:
return self._add_parens_lspace(
(sep + r" \ ").join([self._print(i) for i in expr]))

def _print_TensorProduct(self, expr):
elements = [self._print(a) for a in expr.args]
return r' \otimes '.join(elements)
Expand Down
3 changes: 3 additions & 0 deletions sympy/printing/tests/test_latex.py
Expand Up @@ -2567,16 +2567,19 @@ def test_latex_decimal_separator():
assert(latex([1, 2.3, 4.5], decimal_separator='comma') == r'\left[ 1; \ 2{,}3; \ 4{,}5\right]')
assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='comma') == r'\left\{1; 2{,}3; 4{,}5\right\}')
assert(latex((1, 2.3, 4.6), decimal_separator = 'comma') == r'\left( 1; \ 2{,}3; \ 4{,}6\right)')
assert(latex((1,), decimal_separator='comma') == r'\left( 1;\right)')

# period decimal_separator
assert(latex([1, 2.3, 4.5], decimal_separator='period') == r'\left[ 1, \ 2.3, \ 4.5\right]' )
assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}')
assert(latex((1, 2.3, 4.6), decimal_separator = 'period') == r'\left( 1, \ 2.3, \ 4.6\right)')
assert(latex((1,), decimal_separator='period') == r'\left( 1,\right)')

# default decimal_separator
assert(latex([1, 2.3, 4.5]) == r'\left[ 1, \ 2.3, \ 4.5\right]')
assert(latex(FiniteSet(1, 2.3, 4.5)) == r'\left\{1, 2.3, 4.5\right\}')
assert(latex((1, 2.3, 4.6)) == r'\left( 1, \ 2.3, \ 4.6\right)')
assert(latex((1,)) == r'\left( 1,\right)')

assert(latex(Mul(3.4,5.3), decimal_separator = 'comma') ==r'18{,}02')
assert(latex(3.4*5.3, decimal_separator = 'comma')==r'18{,}02')
Expand Down

0 comments on commit aba12c8

Please sign in to comment.