Skip to content

Commit

Permalink
Merge pull request #25186 from Kishore96in/add_TensorHead_comm_to_arg…
Browse files Browse the repository at this point in the history
…s_onmaster

Fix caching of TensorHead
  • Loading branch information
Upabjojr committed Apr 8, 2024
2 parents 8e00716 + c88a6d8 commit 08b6c58
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions sympy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \
bsgs_direct_product, canonicalize, riemann_bsgs
from sympy.core import Basic, Expr, sympify, Add, Mul, S
from sympy.core.cache import clear_cache
from sympy.core.containers import Tuple, Dict
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Symbol, symbols
Expand Down Expand Up @@ -902,6 +903,9 @@ def set_comm(self, i, j, c):
if c not in (0, 1, None):
raise ValueError('`c` can assume only the values 0, 1 or None')

i = sympify(i)
j = sympify(j)

if i not in self._comm_symbols2i:
n = len(self._comm)
self._comm.append({})
Expand All @@ -921,6 +925,14 @@ def set_comm(self, i, j, c):
self._comm[ni][nj] = c
self._comm[nj][ni] = c

"""
Cached sympy functions (e.g. expand) may have cached the results of
expressions involving tensors, but those results may not be valid after
changing the commutation properties. To stay on the safe side, we clear
the cache of all functions.
"""
clear_cache()

def set_comms(self, *args):
"""
Set the commutation group numbers ``c`` for symbols ``i, j``.
Expand Down Expand Up @@ -1805,8 +1817,7 @@ def __new__(cls, name, index_types, symmetry=None, comm=0):
else:
assert symmetry.rank == len(index_types)

obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry)
obj.comm = TensorManager.comm_symbols2i(comm)
obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry, sympify(comm))
return obj

@property
Expand All @@ -1821,6 +1832,10 @@ def index_types(self):
def symmetry(self):
return self.args[2]

@property
def comm(self):
return TensorManager.comm_symbols2i(self.args[3])

@property
def rank(self):
return len(self.index_types)
Expand Down Expand Up @@ -4262,11 +4277,13 @@ def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indic
raise NotImplementedError("Wild matching based on symmetry is not implemented.")

obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), sympify(symmetry), sympify(comm), sympify(unordered_indices))
obj.comm = TensorManager.comm_symbols2i(comm)
obj.unordered_indices = unordered_indices

return obj

@property
def unordered_indices(self):
return self.args[4]

def __call__(self, *indices, **kwargs):
tensor = WildTensor(self, indices, **kwargs)
return tensor.doit()
Expand Down

0 comments on commit 08b6c58

Please sign in to comment.