From 0c235a4cdf2cfbc93b0db5971642795fee9045e0 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Wed, 4 Dec 2019 17:28:03 +0000 Subject: [PATCH] Combine nargs == 1 with nargs < 1 The rank = 1 code path is just a loop unrolling of the general code path. That's not useful, lets not do it. Also replaces 0 with S(0), and makes nargs == 0 work for the heck of it by using the initial argument to reduce. --- galgebra/lt.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/galgebra/lt.py b/galgebra/lt.py index b2e4babb..6a7d3b7d 100644 --- a/galgebra/lt.py +++ b/galgebra/lt.py @@ -752,25 +752,17 @@ def __init__(self, f, Ga, nargs=None, fct=False): self.f = None self.nargs = nargs Mlt.increment_slots(nargs, Ga) - if nargs > 1: # General tensor of raank > 1 - t_indexes = nargs * [Mlt.extact_basis_indexes(self.Ga)] - self.fvalue = 0 - for t_index, a_prod in zip(itertools.product(*t_indexes), - itertools.product(*self.Ga._mlt_pdiffs)): - if fct: # Tensor field - coef = Function(f+'_'+''.join(map(str, t_index)), real=True)(*self.Ga.coords) - else: # Constant Tensor - coef = symbols(f+'_'+''.join(map(str, t_index)), real=True) - coef *= reduce(lambda x, y: x*y, a_prod) - self.fvalue += coef - else: # General tensor of rank = 1 - self.fvalue = 0 - for t_index, a_prod in zip(Mlt.extact_basis_indexes(self.Ga), self.Ga._mlt_pdiffs[0]): - if fct: # Tensor field - coef = Function(f+'_'+''.join(map(str, t_index)), real=True)(*self.Ga.coords) - else: # Constant Tensor - coef = symbols(f+'_'+''.join(map(str, t_index)), real=True) - self.fvalue += coef * a_prod + t_indexes = Mlt.extact_basis_indexes(self.Ga) + self.fvalue = S(0) + for t_index, a_prod in zip(itertools.product(t_indexes, repeat=self.nargs), + itertools.product(*self.Ga._mlt_pdiffs)): + name = '{}_{}'.format(f, ''.join(map(str, t_index))) + if fct: # Tensor field + coef = Function(name, real=True)(*self.Ga.coords) + else: # Constant Tensor + coef = symbols(name, real=True) + self.fvalue += reduce(lambda x, y: x*y, a_prod, coef) + else: if isinstance(f, types.FunctionType): # Tensor defined by general multi-linear function args, _varargs, _kwargs, _defaults = inspect.getargspec(f)