Skip to content

Commit

Permalink
Combine nargs == 1 with nargs < 1 (#136)
Browse files Browse the repository at this point in the history
The rank = 1 code path is just a loop unrolling of the general code path. That's not useful, lets not do it.

Also replaces 0 with S(0), and makes nargs == 0 work for the heck of it by using the initial argument to reduce.
  • Loading branch information
eric-wieser committed May 15, 2020
1 parent 77db548 commit ae597d4
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions galgebra/lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,25 +752,17 @@ def __init__(self, f, Ga, nargs=None, fct=False):
self.f = None
self.nargs = nargs
Mlt.increment_slots(nargs, Ga)
if nargs > 1: # General tensor of raank > 1
t_indexes = nargs * [Mlt.extact_basis_indexes(self.Ga)]
self.fvalue = 0
for t_index, a_prod in zip(itertools.product(*t_indexes),
itertools.product(*self.Ga._mlt_pdiffs)):
if fct: # Tensor field
coef = Function(f+'_'+''.join(map(str, t_index)), real=True)(*self.Ga.coords)
else: # Constant Tensor
coef = symbols(f+'_'+''.join(map(str, t_index)), real=True)
coef *= reduce(lambda x, y: x*y, a_prod)
self.fvalue += coef
else: # General tensor of rank = 1
self.fvalue = 0
for t_index, a_prod in zip(Mlt.extact_basis_indexes(self.Ga), self.Ga._mlt_pdiffs[0]):
if fct: # Tensor field
coef = Function(f+'_'+''.join(map(str, t_index)), real=True)(*self.Ga.coords)
else: # Constant Tensor
coef = symbols(f+'_'+''.join(map(str, t_index)), real=True)
self.fvalue += coef * a_prod
t_indexes = Mlt.extact_basis_indexes(self.Ga)
self.fvalue = S(0)
for t_index, a_prod in zip(itertools.product(t_indexes, repeat=self.nargs),
itertools.product(*self.Ga._mlt_pdiffs)):
name = '{}_{}'.format(f, ''.join(map(str, t_index)))
if fct: # Tensor field
coef = Function(name, real=True)(*self.Ga.coords)
else: # Constant Tensor
coef = symbols(name, real=True)
self.fvalue += reduce(lambda x, y: x*y, a_prod, coef)

else:
if isinstance(f, types.FunctionType): # Tensor defined by general multi-linear function
args, _varargs, _kwargs, _defaults = inspect.getargspec(f)
Expand Down

0 comments on commit ae597d4

Please sign in to comment.