Skip to content

Commit

Permalink
feat(polys): Add DomainMatrix based on poly elements
Browse files Browse the repository at this point in the history
Try using sfield(..., extension=True) to construct the domain for
DomainMatrix.
  • Loading branch information
oscarbenjamin committed Jun 5, 2020
1 parent 28abe43 commit f0b7996
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
8 changes: 5 additions & 3 deletions sympy/polys/polymatrix.py
Expand Up @@ -8,6 +8,7 @@

from sympy.matrices.dense import MutableDenseMatrix

from sympy.polys.fields import sfield
from sympy.polys.polytools import Poly
from sympy.polys.domains import EX, QQ, ZZ, AlgebraicField, PolynomialRing

Expand Down Expand Up @@ -114,9 +115,10 @@ def from_list_sympy(cls, rows):
assert len(rows) == nrows
assert all(len(row) == ncols for row in rows)

rows_sympy = [[_sympify(item) for item in row] for row in rows]
domain = cls.get_domain([item for row in rows_sympy for item in row])
domain_rows = [[domain.from_sympy(item) for item in row] for row in rows_sympy]
items_sympy = [_sympify(item) for row in rows for item in row]
K, items_K = sfield(items_sympy, extension=True)
domain_rows = [[items_K[ncols*r + c] for c in range(ncols)] for r in range(nrows)]
domain = K.to_domain()

return DomainMatrix(domain_rows, (nrows, ncols), domain)

Expand Down
8 changes: 5 additions & 3 deletions sympy/solvers/solvers.py
Expand Up @@ -45,8 +45,9 @@
from sympy.matrices.common import NonInvertibleMatrixError
from sympy.matrices import Matrix, zeros
from sympy.polys import roots, cancel, factor, Poly, degree
from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError
from sympy.polys.polymatrix import DomainMatrixDomainError, linsolve_domain
from sympy.polys.polyerrors import (GeneratorsNeeded, PolynomialError,
NotInvertible)
from sympy.polys.polymatrix import linsolve_domain
from sympy.functions.elementary.piecewise import piecewise_fold, Piecewise

from sympy.utilities.lambdify import lambdify
Expand Down Expand Up @@ -2238,7 +2239,8 @@ def solve_linear_system(system, *symbols, **flags):
# Try to use DomainMatrix
try:
sol = linsolve_domain(system, symbols)
except DomainMatrixDomainError:
except NotInvertible:
# https://github.com/sympy/sympy/issues/18874
pass
else:
if sol is not None:
Expand Down
10 changes: 5 additions & 5 deletions sympy/solvers/tests/test_recurr.py
@@ -1,5 +1,5 @@
from sympy import Eq, factorial, Function, Lambda, rf, S, sqrt, symbols, I, \
expand_func, binomial, gamma, Rational, Symbol, cos, sin, Abs
from sympy import Eq, factor, factorial, Function, Lambda, rf, S, sqrt, symbols, I, \
expand, binomial, Rational, Symbol, cos, sin, Abs
from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio
from sympy.testing.pytest import raises, slow
from sympy.abc import a, b
Expand Down Expand Up @@ -176,9 +176,9 @@ def test_rsolve():

f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n)

assert expand_func(rsolve(f, y(n), \
{y(1): binomial(2*n + 1, 3)}).rewrite(gamma)).simplify() == \
2**(2*n)*n*(2*n - 1)*(4*n**2 - 1)/12
yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)})
sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12
assert factor(expand(yn, func=True)) == sol

assert (rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n)) -
(C0*((sqrt(-a**2 + 1) - 1)/a)**n +
Expand Down

0 comments on commit f0b7996

Please sign in to comment.