diff --git a/src/sage/combinat/boltzmann_sampling/generator.pyx b/src/sage/combinat/boltzmann_sampling/generator.pyx index e5ffbb707ec..61a1f799f58 100644 --- a/src/sage/combinat/boltzmann_sampling/generator.pyx +++ b/src/sage/combinat/boltzmann_sampling/generator.pyx @@ -80,7 +80,7 @@ from functools import reduce from sage.libs.gmp.random cimport gmp_randinit_set, gmp_randinit_default from sage.libs.gmp.types cimport gmp_randstate_t -from sage.misc.randstate cimport randstate, current_randstate +from sage.misc.randstate cimport randstate, current_randstate, SAGE_RAND_MAX from .grammar import Atom, Product, Ref, Union from .oracle import SimpleOracle, find_singularity @@ -231,6 +231,70 @@ cdef c_generate(first_rule, rules, builders, randstate rstate): obj, = generated return obj +cdef int rand_int(int bound, randstate rstate): + r = rstate.c_random() + v = r % bound + if r - v > SAGE_RAND_MAX - bound + 1: + return rand_int(bound, rstate) + else: + return v + +cdef select(int r, sizes): + r = r + for i in range(len(sizes)): + r -= sizes[i] + if r < 0: + sizes[i] -= 1 + return i + +cdef shuffle(bag, sizes, total_size, rstate): + total_size = total_size + bags = [[] for __ in sizes] + for total_size in range(total_size, 0, -1): + r = rand_int(total_size, rstate) + i = select(r, sizes) + label = bag.pop() + bags[i].append(label) + return bags + +cdef tree_size(tree): + __, __, size = tree + return size + +cdef labelling(tree, builders, randstate rstate): + bag = list(range(tree_size(tree))) + generated = [] + todo = [(tree, bag)] + + while todo: + (type, content, size), bag = todo.pop() + if type == REF: + rule_id, children = content + todo.append((FUNCTION, rule_id, size)) + todo.append(children) + elif type == ATOM: + name = content + assert len(bag) == size + generated.append((name, bag)) + elif type == PRODUCT: + sizes = [tree_size(t) for t in content] + bags = shuffle(bag, sizes, size, rstate) + todo.append((TUPLE, len(content), size)) + for i in range(len(content)): + todo.append((content[i], bags[i])) + elif type == TUPLE: + nargs = content + t = tuple(generated[-nargs:]) + generated = generated[:-nargs] + generated.append(t) + elif type == FUNCTION: + func = builders[content] + x = generated.pop() + generated.append(func(x)) + + tree, = generated + return tree + cdef c_gen(first_rule, rules, int size_min, int size_max, int max_retry, builders): """Search for a tree in a given size window. Wrapper around c_simulate and c_generate.""" @@ -329,6 +393,39 @@ cdef make_default_builder(rule): subbuilders = [make_default_builder(component) for component in rule.args] return ProductBuilder(subbuilders) +# --- +# Generic tree builders with size annotations +# --- + +cdef size_ref_builder(id): + def build(x): + __, __, size = x + return (REF, (id, x), size) + return build + +cdef size_atom_builder(x): + name, size = x + return (ATOM, name, size) + +cdef size_product_builder(builders): + def build(terms): + t = tuple(builders[i](terms[i]) for i in range(len(terms))) + size = sum([s for __, __, s in t]) + return (PRODUCT, t, size) + return build + +cdef size_builder(name_to_id, rule): + if isinstance(rule, Ref): + return size_ref_builder(name_to_id[rule.name]) + elif isinstance(rule, Atom): + return size_atom_builder + elif isinstance(rule, Union): + subbuilders = [size_builder(name_to_id, component) for component in rule.args] + return UnionBuilder(*subbuilders) + elif isinstance(rule, Product): + subbuilders = [size_builder(name_to_id, component) for component in rule.args] + return size_product_builder(subbuilders) + # --- # High level interface # ---