In [16]:
from functools import reduce
from operator import add
from numpy import inf
import numpy as np

class CompPair(object):
    """Pair ``(torder, xorder)`` ordered lexicographically: by ``torder``, then ``xorder``.

    Comparisons against objects that are not CompPair instances return
    ``NotImplemented``, so ``pair == None`` is False instead of raising AttributeError.
    """
    def __init__(self, torder, xorder):
        self.torder = torder
        self.xorder = xorder

    def __repr__(self):
        return f'({self.torder}, {self.xorder})'

    def _key(self):
        # Tuple comparison is exactly "torder first, ties broken by xorder".
        return (self.torder, self.xorder)

    def __eq__(self, other):
        if not isinstance(other, CompPair):
            return NotImplemented
        return self._key() == other._key()

    def __ge__(self, other):
        if not isinstance(other, CompPair):
            return NotImplemented
        return self._key() >= other._key()

    def __le__(self, other):
        if not isinstance(other, CompPair):
            return NotImplemented
        return self._key() <= other._key()

    def __gt__(self, other):
        if not isinstance(other, CompPair):
            return NotImplemented
        return self._key() > other._key()

    def __lt__(self, other):
        if not isinstance(other, CompPair):
            return NotImplemented
        return self._key() < other._key()

    def __hash__(self):
        return hash(self._key())
    
class CompList(object):
    """Wrap a list so that ``>=`` (and, by reflection, ``<=``) compare lexicographically.

    Only the positions both lists share are compared (``zip`` semantics): when one
    list is a prefix of the other, each is considered ``>=`` the other.
    """
    def __init__(self, in_list):
        self.in_list = in_list

    def __repr__(self):
        return repr(self.in_list)

    def __eq__(self, other):
        return self.in_list == other.in_list

    def __ge__(self, other):
        verdict = True
        for mine, theirs in zip(self.in_list, other.in_list):
            if mine > theirs:
                break
            if mine < theirs:
                verdict = False
                break
        return verdict

class DerivativeOrder(CompPair):
    """Orders of differentiation: ``torder`` in time and ``xorder`` in space.

    Construction, equality, hashing and ``>=`` are inherited from CompPair, so orders
    compare first by ``torder`` and then by ``xorder``.
    """
    
class Observable(object):
    """A named observable together with its tensor rank (number of own indices).

    Instances compare equal and hash alike when both ``string`` and ``rank`` match.
    """
    def __init__(self, string, rank):
        self.string = string
        self.rank = rank

    def __repr__(self):
        return self.string

    def __str__(self):
        return self.string

    def __eq__(self, other):
        # Python only calls __eq__ (never __equals__); value equality keeps == consistent
        # with the value-based __hash__ below.
        if not isinstance(other, Observable):
            return NotImplemented
        return self.string == other.string and self.rank == other.rank

    # Backward-compatible alias for any direct callers of the old method name.
    __equals__ = __eq__

    def __hash__(self):
        return hash((self.string, self.rank))

class LibraryPrimitive(object):
    """An observable with time and space derivatives applied to it.

    ``rank`` is ``dorder.xorder + observable.rank``: one index per spatial derivative
    plus the observable's own indices.
    """
    def __init__(self, dorder, observable):
        self.simple = True
        self.dorder = dorder
        self.observable = observable
        self.rank = dorder.xorder + observable.rank

    @staticmethod
    def _deriv_prefix(var, order):
        """Return '' for order 0, 'd<var> ' for order 1 and 'd<var>^<order> ' otherwise."""
        if order == 0:
            return ""
        if order == 1:
            return f"d{var} "
        return f"d{var}^{order} "

    def __repr__(self):
        tpart = self._deriv_prefix("t", self.dorder.torder)
        xpart = self._deriv_prefix("x", self.dorder.xorder)
        return f'{tpart}{xpart}{self.observable}'

    def __str__(self):
        return repr(self)

    def __eq__(self, other):
        same_derivs = self.dorder == other.dorder
        return same_derivs and self.observable == other.observable

    def __hash__(self):
        return hash((self.dorder, self.observable))
    
class IndexedPrimitive(LibraryPrimitive):
    """LibraryPrimitive with explicit derivative orders along every dimension.

    ``dimorders`` holds one derivative order per spatial dimension followed by the
    time-derivative order as its last entry. ``obs_dim`` selects the observable
    component printed as a subscript (no subscript is printed when it is None).
    """
    # Letters used to print spatial dimensions and the observable component.
    dim_to_let = {0: 'x', 1: 'y', 2: 'z'}

    def __init__(self, prim, space_orders=None, obs_dim=None, newords=None):
        """Create from a primitive plus per-dimension derivative orders.

        Normal form ``IndexedPrimitive(prim, space_orders, obs_dim)``: ``prim`` is a
        LibraryPrimitive, ``space_orders`` gives the spatial orders and the time order
        is appended from ``prim.dorder.torder``.
        Modifying form ``IndexedPrimitive(prim, newords=...)``: ``prim`` is an
        IndexedPrimitive whose ``obs_dim`` is kept while its orders become ``newords``.

        NOTE(review): ``dorder`` and ``rank`` are copied from ``prim`` unchanged, so
        after the modifying form they do not reflect ``newords``.
        """
        self.simple = True
        self.dorder = prim.dorder
        self.observable = prim.observable
        self.rank = prim.rank
        if newords is None:  # normal constructor
            # list() accepts tuples too; for a numpy array, `array + [t]` would
            # broadcast-add instead of appending the time order.
            self.dimorders = list(space_orders) + [self.dorder.torder]
            self.obs_dim = obs_dim
        else:  # modifying constructor
            self.dimorders = newords
            self.obs_dim = prim.obs_dim
        self.ndims = len(self.dimorders)
        self.nderivs = sum(self.dimorders)

    def __repr__(self):
        torder = self.dimorders[-1]
        xstring = ""
        for i in range(len(self.dimorders) - 1):  # every entry except the trailing time order
            let = self.dim_to_let[i]
            xorder = self.dimorders[i]
            if xorder == 0:
                continue
            elif xorder == 1:
                xstring += f"d{let} "
            else:
                xstring += f"d{let}^{xorder} "
        if torder == 0:
            tstring = ""
        elif torder == 1:
            tstring = "dt "
        else:
            tstring = f"dt^{torder} "
        if self.obs_dim is None:
            dimstring = ""
        else:
            dimstring = f"_{self.dim_to_let[self.obs_dim]}"
        return f'{tstring}{xstring}{self.observable}{dimstring}'

    def __eq__(self, other):
        return (self.dimorders == other.dimorders and self.observable == other.observable
                and self.obs_dim == other.obs_dim)

    def __hash__(self):
        # Defining __eq__ without __hash__ sets __hash__ to None, which made instances
        # unhashable unlike their LibraryPrimitive base; hash the fields __eq__ compares.
        return hash((tuple(self.dimorders), self.observable, self.obs_dim))

    def __mul__(self, other):
        """Return the IndexedTerm product ``self * other``."""
        if isinstance(other, IndexedTerm):
            return IndexedTerm(observable_list=[self] + other.observable_list)
        else:
            return IndexedTerm(observable_list=[self, other])

    def succeeds(self, other, dim):
        """Return True if ``other`` equals this primitive differentiated once more along ``dim``."""
        copyorders = self.dimorders.copy()
        copyorders[dim] += 1
        return copyorders == other.dimorders and self.observable == other.observable and self.obs_dim == other.obs_dim

    def diff(self, dim):
        """Return a new IndexedPrimitive with one more derivative along ``dim``."""
        newords = self.dimorders.copy()
        newords[dim] += 1
        return IndexedPrimitive(self, newords=newords)
    
class LibraryTensor(object):
    """Product of LibraryPrimitives (a library term before its indices are labelled).

    Construct with a single LibraryPrimitive (``simple`` is True) or with a list of
    primitives (``simple`` is False). ``rank`` is the sum of the factors' ranks.
    """
    def __init__(self, observables):
        if isinstance(observables, LibraryPrimitive):  # a single derivative-of-observable factor
            self.simple = True
            self.observable_list = [observables]
            self.rank = observables.rank
        else:  # a product given as a list of factors
            self.simple = False
            self.observable_list = observables
            self.rank = sum(obs.rank for obs in observables)

    def __mul__(self, other):
        """Concatenate factors with another LibraryTensor; multiplying by 1 returns self.

        Raises:
            ValueError: for any other operand.
        """
        if isinstance(other, LibraryTensor):
            return LibraryTensor(self.observable_list + other.observable_list)
        elif other == 1:
            return self
        else:
            raise ValueError(f"Cannot multiply {self}, {other}")

    def __rmul__(self, other):
        # Reached for ``other * self`` when ``other`` is not a LibraryTensor (e.g. ``1 * t``).
        # The bare name ``__mul__`` is not defined at module level, so call the bound method.
        return self.__mul__(other)

    def __repr__(self):
        # join (rather than reduce) also handles an empty factor list.
        return ' * '.join(str(obs) for obs in self.observable_list)

    def __str__(self):
        return self.__repr__()
    
class LibraryTerm(object):
    """A LibraryTensor together with a concrete assignment of index labels.

    ``labels`` maps label numbers to the index slots they occupy. Slot ``2*k`` collects
    the derivative indices of factor ``k``; slot ``2*k + 1`` collects the factor's own
    indices. Label numbers are printed as letters via ``num_to_let``.
    """
    num_to_let = {0: 'i', 1: 'j', 2: 'k', 3: 'l', 4: 'm', 5: 'n', 6: 'p'}

    def __init__(self, libtensor, labels):
        self.observable_list = libtensor.observable_list
        # the tensor rank is stored reduced mod 2
        self.rank = libtensor.rank % 2
        self.labels = labels
        n_slots = 2 * len(self.observable_list)
        self.index_list = [[] for _ in range(n_slots)]
        for number, positions in labels.items():
            letter = self.num_to_let[number]
            for pos in positions:
                self.index_list[pos].append(letter)

    def __add__(self, other):
        if not isinstance(other, LibraryTerm):
            # anything else is expected to expose a ``term_list`` (a TermSum)
            return TermSum([self] + other.term_list)
        return TermSum([self, other])

    def __repr__(self):
        rendered = (label_repr(self.observable_list[k], self.index_list[2 * k], self.index_list[2 * k + 1])
                    for k in range(len(self.observable_list)))
        return reduce(lambda acc, piece: acc + ' * ' + piece, rendered)
    
class IndexedTerm(object):
    """Product of IndexedPrimitives (factors with per-dimension derivative orders).

    An empty ``observable_list`` stands for the constant term 1; ``drop`` produces such
    a term when the only factor is removed, and it prints as ``1``.
    """
    def __init__(self, libterm=None, space_orders=None, obs_dims=None, observable_list=None):
        """Build from a LibraryTerm or directly from a list of factors.

        From scratch, ``IndexedTerm(libterm, space_orders, obs_dims)`` converts each factor
        of ``libterm`` to an IndexedPrimitive using the matching entries of
        ``space_orders`` (per-factor spatial orders) and ``obs_dims``.
        Directly, ``IndexedTerm(observable_list=[...])`` wraps an existing list of factors.
        """
        if observable_list is None:  # normal "from scratch" constructor
            self.rank = libterm.rank
            self.observable_list = libterm.observable_list.copy()
            for i, (obs, sp_ord, obs_dim) in enumerate(zip(libterm.observable_list, space_orders, obs_dims)):
                self.observable_list[i] = IndexedPrimitive(obs, sp_ord, obs_dim)
            self.ndims = len(space_orders[0]) + 1  # spatial dimensions plus time
            # largest total derivative count carried by any single factor
            self.nderivs = np.max([p.nderivs for p in self.observable_list])
        elif len(observable_list) > 0:  # direct constructor from a non-empty factor list
            # NOTE(review): rank and ndims are read from the first factor only — confirm intent.
            self.rank = observable_list[0].rank
            self.ndims = observable_list[0].ndims
            self.observable_list = observable_list
            self.nderivs = np.max([p.nderivs for p in self.observable_list])
        else:  # empty product: the term is simply equal to 1
            self.rank = 0  # was left unset, so reading .rank raised AttributeError
            self.observable_list = []
            self.ndims = 0
            self.nderivs = 0

    def __repr__(self):
        if not self.observable_list:
            # reduce() on an empty list raised TypeError; the empty product is 1
            return "1"
        return ' * '.join(str(obs) for obs in self.observable_list)

    def __mul__(self, other):
        """Return the product with another IndexedTerm or a single factor."""
        if isinstance(other, IndexedTerm):
            return IndexedTerm(observable_list=self.observable_list + other.observable_list)
        else:
            return IndexedTerm(observable_list=self.observable_list + [other])

    def drop(self, obs):
        """Return a new term without (the first factor equal to) ``obs``.

        A single-factor term always becomes the empty term (the constant 1).
        """
        if len(self.observable_list) > 1:
            remaining = self.observable_list.copy()
            remaining.remove(obs)
        else:
            remaining = []
        return IndexedTerm(observable_list=remaining)

    def diff(self, dim):
        """Yield the product-rule terms of the derivative along ``dim``, one per factor."""
        for obs in self.observable_list:
            yield obs.diff(dim) * self.drop(obs)
    
def label_repr(prim, ind1, ind2):
    """Render a LibraryPrimitive with explicit index letters.

    ``ind1`` holds the letters of the spatial derivatives (runs of equal letters are
    collapsed by ``compress``); ``ind2`` holds the observable's own letters, of which
    only the first is printed, and only for rank-1 observables.
    """
    order_t = prim.dorder.torder
    observable = prim.observable
    if order_t == 0:
        t_part = ""
    elif order_t == 1:
        t_part = "dt "
    else:
        t_part = f"dt^{order_t} "
    if prim.dorder.xorder == 0:
        x_part = ""
    else:
        x_part = reduce(add, (f"d{letter} " for letter in compress(ind1)))
    if observable.rank != 1:
        obs_part = observable.string
    else:
        obs_part = observable.string + "_" + ind2[0]
    return t_part + x_part + obs_part

def compress(labels):
    """Collapse runs of identical consecutive labels into power notation.

    A run of length ``m > 1`` becomes ``'<label>^<m>'``; e.g. ``['i', 'j', 'j']``
    becomes ``['i', 'j^2']``. Non-adjacent repeats stay separate.
    """
    from itertools import groupby

    compressed = []
    for label, run in groupby(labels):
        count = sum(1 for _ in run)
        # The previous index-skipping loop emitted repeated 'x^2' entries for runs >= 3.
        compressed.append(label if count == 1 else f"{label}^{count}")
    return compressed

# Default observables used by generate_terms_to.
rho = Observable(string='rho', rank=0)  # rank 0: printed without an index subscript
v = Observable(string='v', rank=1)  # rank 1: printed with one index subscript
    
#def raw_library_tensors(a, b, c, d, max_order=DerivativeOrder(inf, inf)):
#    #print(a, b, c, d, max_order)
#    if a==0 and b==1:
#        do = DerivativeOrder(c, d)
#        if max_order>=do:
#            prim = LibraryPrimitive(do, v)
#            yield LibraryTensor(prim)
#        return
#    if a==1 and b==0:
#        do = DerivativeOrder(c, d)
#        if max_order>=do:
#            prim = LibraryPrimitive(do, rho)
#            yield LibraryTensor(prim)
#        return
#    for i in range(c+1):
#        for j in range(d+1):
#            if max_order>=DerivativeOrder(i, j):
#                do = DerivativeOrder(i, j) 
#                if a>0:
#                    prim = LibraryPrimitive(do, rho) 
#                    term1 = LibraryTensor(prim)
#                    if a==1: # reset max_order since we are going to b terms
#                        do = DerivativeOrder(inf, inf)
#                    for term2 in raw_library_tensors(a-1, b, c-i, d-j, max_order=do):
#                        yield term1*term2
#                else:
#                    prim = LibraryPrimitive(do, v)
#                    term1 = LibraryTensor(prim)
#                    for term2 in raw_library_tensors(a, b-1, c-i, d-j, max_order=do):
#                        yield term1*term2
                        
def raw_library_tensors(observables, obs_orders, nt, nx, max_order=DerivativeOrder(inf, inf), zeroidx=0):
    """Yield every LibraryTensor with the given factor counts and derivative totals.

    Parameters
    ----------
    observables : sequence of Observable
        Candidate observables; ``obs_orders[k]`` counts the factors of ``observables[k]``.
    obs_orders : sequence of int
        Number of factors of each observable in the product.
    nt, nx : int
        Total numbers of time and space derivatives. All of them are distributed among
        the factors; the final factor placed takes whatever remains.
    max_order : DerivativeOrder
        Upper bound (in CompPair's lexicographic order) on the derivative order of the
        first factor placed by this call.
    zeroidx : int
        Index at which to start searching for the first observable with a nonzero count.

    Repeated factors of the same observable are emitted with non-increasing derivative
    orders, presumably so that each product is generated only once up to reordering.
    """
    #print(obs_orders, nt, nx, max_order)
    # advance to the first observable that still has factors to place
    while obs_orders[zeroidx]==0:
        zeroidx += 1
        if zeroidx==len(observables):
            return
    # exactly one factor left: it absorbs all remaining derivatives
    if sum(obs_orders)==1:
        i = obs_orders.index(1)
        do = DerivativeOrder(nt, nx)
        if max_order>=do:
            prim = LibraryPrimitive(do, observables[i])
            yield LibraryTensor(prim)
        return
    # choose i time and j space derivatives for the next factor of observables[zeroidx]
    for i in range(nt+1):
        for j in range(nx+1):
            if max_order>=DerivativeOrder(i, j):
                do = DerivativeOrder(i, j) 
                prim = LibraryPrimitive(do, observables[zeroidx]) 
                term1 = LibraryTensor(prim)
                new_orders = list(obs_orders)
                new_orders[zeroidx] -= 1
                if obs_orders[zeroidx]==1: # reset max_order since we are going to next terms
                    do = DerivativeOrder(inf, inf)
                # otherwise `do` caps the next (identical) factor's order
                for term2 in raw_library_tensors(observables, new_orders, nt-i, nx-j, max_order=do):
                    yield term1*term2
                        
# make a dictionary of how paired indices are placed
def place_pairs(*rank_array, min_ind2=0, curr_ind=1, start=0, answer_dict=None):
    """Yield every way of contracting the free index slots in ``rank_array`` in pairs.

    ``rank_array[s]`` is the number of free indices in slot ``s``. Each yielded dict
    extends ``answer_dict`` with entries ``{label: (slot1, slot2)}`` for labels
    ``curr_ind, curr_ind + 1, ...``. A slot may pair with itself only when it has at
    least two free indices. Nothing is yielded when the indices cannot all be paired.

    ``min_ind2`` and ``start`` are recursion state: the partner slot search starts at
    ``min_ind2`` and the first slot considered is ``start``.
    """
    if answer_dict is None:
        # fresh dict per call; the old shared ``dict()`` default was yielded to callers
        answer_dict = {}
    # skip exhausted slots (bounds checked first, so an empty rank_array is safe);
    # moving to a new first slot restarts the partner search
    while start < len(rank_array) and rank_array[start] <= 0:
        start += 1
        min_ind2 = 0
    if start >= len(rank_array):  # every index has been paired
        yield answer_dict
        return
    ind1 = start
    for ind2 in range(min_ind2, len(rank_array)):
        if (ind1 == ind2 and rank_array[ind1] == 1) or rank_array[ind2] == 0:
            continue
        # partners of the same first slot are taken in non-decreasing order,
        # presumably so that no pairing is generated twice
        min_ind2 = ind2
        dict1 = answer_dict.copy()
        dict1[curr_ind] = (ind1, ind2)
        copy_array = np.array(rank_array)
        copy_array[ind1] -= 1
        copy_array[ind2] -= 1
        yield from place_pairs(*copy_array, min_ind2=min_ind2, curr_ind=curr_ind + 1,
                               start=start, answer_dict=dict1)
            
def place_indices(*rank_array):
    """Yield every labelling of the index slots described by ``rank_array``.

    With an even number of indices, all of them are contracted in pairs. With an odd
    number, each slot that has an index is tried in turn as the holder of the single
    free index (label 0), and the remaining indices are paired.
    """
    total = sum(rank_array)
    if total % 2 == 0:
        # only paired indices allowed
        yield from place_pairs(*rank_array)
        return
    # exactly one unpaired index
    for single_ind, count in enumerate(rank_array):
        if not count > 0:
            continue
        reduced = np.array(rank_array)
        reduced[single_ind] -= 1
        yield from place_pairs(*reduced, answer_dict={0: [single_ind]})

def list_labels(tensor):
    """Return every accepted index labelling (as a dict) for the factors of ``tensor``.

    Each factor contributes two slots: its spatial derivative order followed by the
    rank of its observable. Candidates come from ``place_indices`` and are kept only
    when ``test_valid_label`` accepts them.
    """
    rank_array = [count
                  for prim in tensor.observable_list
                  for count in (prim.dorder.xorder, prim.observable.rank)]
    candidates = place_indices(*rank_array)
    return [labels for labels in candidates
            if test_valid_label(labels, tensor.observable_list)]

# NOTE: the lexicographic ordering rule may fail to identify equivalent labellings
# for N>=5, e.g. dj v_i * dk v_j * v_k = dj v_k * dk v_i * v_j
def test_valid_label(output_dict, obs_list):
    """Check whether an index labelling is in canonical (non-decreasing) order.

    ``output_dict`` maps label numbers to the index slots they occupy, where slot
    ``2*i`` holds the derivative indices of ``obs_list[i]`` and slot ``2*i + 1`` its
    own indices. The labelling is accepted when, for every pair of identical factors
    ``i < j``, the labels of factor ``i`` compare lexicographically ``<=`` those of
    factor ``j``.
    """
    if len(output_dict) < 2:  # fewer than two labels cannot be out of order
        return True
    count = len(obs_list)
    # pairs of identical factors; the cost is negligible for reasonably small N
    twins = [(i, j) for i in range(count) for j in range(i + 1, count)
             if obs_list[i] == obs_list[j]]
    if not twins:  # all factors distinct: any labelling is canonical
        return True
    slots = [[] for _ in range(2 * count)]
    for label, positions in output_dict.items():
        for pos in positions:
            slots[pos].append(label)
    for i, j in twins:
        left = CompList(slots[2 * i] + slots[2 * i + 1])
        right = CompList(slots[2 * j] + slots[2 * j + 1])
        if not left <= right:  # labels decrease between identical factors
            return False
    return True

#def generate_terms_to(order, max_observables=999):
#    libterms = list()
#    N = order # max number of "blocks" to include
#    for a in range(min(N, max_observables)+1): # number of rhos
#        for b in range(min(N, max_observables)-a+1): # number of vs
#            for c in range(N-b-a+1): # number of dts
#                for d in range(N-c-b-a+1): # number of dxs
#                    if a+b>0: # not a valid term of no rho or v
#                        for tensor in raw_library_tensors(a, b, c, d):
#                            for label in list_labels(tensor):
#                                libterms.append(LibraryTerm(tensor, label))
#    return libterms

def generate_terms_to(order, observables=None, max_observables=999):
    """Return every labelled LibraryTerm built from at most ``order`` "blocks".

    A block is one observable factor, one time derivative or one space derivative.
    ``partition`` enumerates every (observable counts..., nt, nx) tuple whose total is
    at most ``order``.

    Parameters
    ----------
    order : int
        Upper bound on (number of observable factors + nt + nx).
    observables : list of Observable, optional
        Observables to build terms from; defaults to ``[rho, v]``.
    max_observables : int
        Maximum number of observable factors per term (inclusive).
    """
    if observables is None:  # avoid a shared mutable default list
        observables = [rho, v]
    libterms = []
    K = len(observables)
    for part in partition(order, K + 2):  # K observable counts + (nt, nx)
        n_obs = sum(part[:K])
        # not a valid term if there are no observables or the maximum is exceeded
        # (the previous `< max_observables` wrongly excluded terms with exactly the max)
        if not 0 < n_obs <= max_observables:
            continue
        nt, nx = part[-2:]
        obs_orders = part[:-2]
        for tensor in raw_library_tensors(observables, obs_orders, nt, nx):
            for label in list_labels(tensor):
                libterms.append(LibraryTerm(tensor, label))
    return libterms

def partition(n, k):
    """Yield every k-tuple of non-negative integers whose sum is at most ``n``.

    Tuples come out in lexicographic order, e.g. ``partition(1, 2)`` yields
    ``(0, 0), (0, 1), (1, 0)``. Nothing is yielded when ``k < 1`` or ``n < 0``.

    Parameters
    ----------
    n : int
        Upper bound on the sum of each tuple.
    k : int
        Length of each tuple.
    """
    if k < 1:
        return
    if k == 1:
        yield from ((i,) for i in range(n + 1))
        return
    for first in range(n + 1):
        for rest in partition(n - first, k - 1):
            yield (first,) + rest

class TermSum(object):
    """A sum of library terms; ``rank`` is taken from the first term in the list."""
    def __init__(self, term_list):
        self.term_list = term_list
        self.rank = term_list[0].rank

    def __add__(self, other):
        extra = other.term_list if isinstance(other, TermSum) else [other]
        return TermSum(self.term_list + extra)

    def __repr__(self):
        return reduce(lambda acc, text: acc + ' + ' + text, map(str, self.term_list))