In [20]:
def argsorted(x, reverse=True):
    """Return the indices of ``x`` ordered by the values they point to.

    By default the order is descending (largest value first); pass
    ``reverse=False`` for ascending order. Equal values keep their original
    index order, because ``sorted`` is stable even with ``reverse=True``.

    >>> argsorted([3, 1, 2])
    [0, 2, 1]
    """
    return sorted(range(len(x)), key=x.__getitem__, reverse=reverse)

In [21]:
class Builder:
    """Base interface for producing branches during a tree search.

    Subclasses override both hooks. The base implementations do nothing and
    return None.
    """

    def germinate(self, seed):
        """Return the initial list of branches grown from ``seed``."""
        pass

    def build(self, branch):
        """Return the child branches of ``branch``."""
        # Was ``passz`` — an undefined name that raised NameError when called.
        pass

In [22]:
class Plucker:
    """Base strategy deciding which branches survive a tree search.

    Each hook receives the current list of branches and is expected to edit
    it in place. The base implementations are no-ops that return None.
    """

    def prune(self, branches):
        """Trim ``branches`` in place after a growth step. No-op here."""
        pass

    def pluck(self, branches):
        """Reduce ``branches`` in place to the final answer(s). No-op here."""
        pass

    def stop(self, branches):
        """Return truthy to end the search. The base version returns None."""
        pass

    def _select_branches(self, branches, selected_idx):
        """Keep only the branches whose positions appear in ``selected_idx``.

        ``branches`` is mutated in place. Survivors keep their original
        relative order; the order of ``selected_idx`` does not matter, and
        indices outside ``range(len(branches))`` are ignored.
        """
        # A set gives O(1) membership; the old version did a list lookup per
        # element (O(n^2)) after an unnecessary sort.
        keep = set(selected_idx)
        branches[:] = [branch for i, branch in enumerate(branches) if i in keep]

In [23]:
class TreeSolver:
    """Grow candidate branches level by level until a plucker stops the search.

    ``builder`` provides ``germinate(seed)`` (the initial branch list) and
    ``build(branch)`` (the children of one branch). ``plucker`` provides
    ``prune``, ``stop`` and ``pluck``, which edit the branch list in place;
    a truthy ``stop`` ends the growth loop.
    """

    def __init__(self, builder, plucker):
        self.builder = builder
        self.plucker = plucker
        # Working set of the search in progress; None between solve() calls.
        self.branches = None

    def solve(self, seed):
        """Run the search from ``seed`` and return the first branch left after ``pluck``.

        Raises:
            RuntimeError: if the branch list becomes empty before ``stop``
                fires (the previous loop spun forever in that case).
            IndexError: if ``pluck`` leaves no branches.
        """
        try:
            # Always start fresh: a list left behind by an earlier aborted call
            # must not be mistaken for a search in progress.
            self.branches = self.builder.germinate(seed)
            while True:
                self.plucker.prune(self.branches)
                if self.plucker.stop(self.branches):
                    break
                if not self.branches:
                    raise RuntimeError(
                        "no branches left to expand before the stop condition was met"
                    )
                self.branches = self.get_new_branches()
            self.plucker.pluck(self.branches)
            return self.branches[0]
        finally:
            self.branches = None

    def get_new_branches(self):
        """Return the children of every branch in ``self.branches``, concatenated in order."""
        return [child for branch in self.branches for child in self.builder.build(branch)]

In [60]:
class BeamPlucker(Plucker):
    """Beam-search plucker: keeps a bounded number of branches per step.

    Args:
        beam_size: number of branches ``prune`` keeps after each growth step.
        max_length: ``stop`` fires once any branch has at least this many nodes.
        leaf_test: called on a branch's last node; a truthy result also makes
            that branch finished.
        to_probabilities: if None, the highest-``preference`` branches are kept
            deterministically. Otherwise it maps the list of preferences to
            sampling probabilities, and branches are drawn without replacement
            with ``np.random.choice``.
    """

    def __init__(self, beam_size=1, max_length=1, leaf_test=lambda x: 0, to_probabilities=None):
        self.beam_size = beam_size
        self.max_length = max_length
        self.leaf_test = leaf_test
        # Previously dropped, so the beam step silently read a global of the same name.
        self.to_probabilities = to_probabilities

    def prune(self, branches):
        """Keep at most ``beam_size`` branches, in place."""
        self.__beam_branches(branches, self.beam_size)

    def pluck(self, branches):
        """Reduce ``branches`` in place to a single branch."""
        self.__beam_branches(branches)

    def stop(self, branches):
        """End the search if any branch is finished.

        A branch is finished when it has at least ``max_length`` nodes or its
        last node passes ``leaf_test``. If any are finished, only those are kept
        (in place) and True is returned; otherwise ``branches`` is untouched and
        False is returned.
        """
        # leaf_test is evaluated first so it still runs on every branch.
        finished = [
            i for i, branch in enumerate(branches)
            if self.leaf_test(branch.nodes[-1]) or branch.size() >= self.max_length
        ]
        if not finished:
            return False
        self._select_branches(branches, finished)
        return True

    def __beam_branches(self, branches, beam_size=1):
        """Keep at most ``beam_size`` branches of ``branches``, in place."""
        if not branches:
            return  # nothing to choose from; np.random.choice rejects an empty population
        # Sampling without replacement cannot draw more items than exist.
        beam_size = min(beam_size, len(branches))
        preferences = [branch.preference for branch in branches]

        if self.to_probabilities is None:
            # Use the requested size (previously self.beam_size, which made
            # pluck() keep a whole beam instead of the single best branch).
            selected_idx = argsorted(preferences)[:beam_size]
        else:
            selected_idx = list(np.random.choice(
                len(branches), beam_size, replace=False,
                p=self.to_probabilities(preferences),
            ))
        self._select_branches(branches, selected_idx)

In [61]:
class TestBranch:
    """Minimal branch for the toy builder: a list of nodes plus a preference score."""

    def __init__(self, nodes=None, preference=None):
        # A literal ``[]`` default is evaluated once and would be shared by every
        # instance created without ``nodes``; build a fresh list per call instead.
        self.nodes = [] if nodes is None else nodes
        self.preference = preference

    def size(self):
        """Return the number of nodes in the branch."""
        return len(self.nodes)

In [117]:
class TestBuilder(Builder):
    """Toy builder used to exercise the solver.

    Germination yields ten single-node branches ``[0] .. [9]`` whose
    preferences are ``seed .. seed + 9``. Each branch then expands into ten
    identical children that append twice the last node and add that last node
    to the preference.
    """

    def germinate(self, seed):
        scores = np.arange(10) + seed
        seeds = []
        for position, score in enumerate(scores):
            seeds.append(TestBranch([position], score))
        return seeds

    def build(self, branch):
        tail = branch.nodes[-1]
        # The node list is rebuilt for every child, so siblings never share one list.
        return [
            TestBranch(branch.nodes + [tail * 2], branch.preference + tail)
            for _ in range(10)
        ]

In [11]:
from vaialgs import TreeSolver, BeamPlucker, Builder

In [63]:
def to_probabilities(preferences):
    """Convert raw preference scores into softmax probabilities.

    Returns a list of numpy floats that sums to 1. An empty input gives ``[]``.
    """
    scores = np.asarray(preferences, dtype=float)
    if scores.size == 0:
        return []
    # Subtracting the max leaves the softmax unchanged but keeps exp() from
    # overflowing to inf (and inf/inf to nan) for large preferences.
    probabilities = np.exp(scores - scores.max())
    probabilities /= probabilities.sum()
    return list(probabilities)

In [118]:
# Toy search: TestBuilder supplies the branches; BeamPlucker stops once a
# branch reaches 5 nodes and is given to_probabilities for weighting choices.
solver = TreeSolver(
    TestBuilder(),
    BeamPlucker(max_length=5, to_probabilities=to_probabilities),
)

In [165]:
# Run the search from seed 10 and display the node list of the branch solve() returns.
# NOTE(review): the beam step samples with np.random.choice when a
# to_probabilities function is in play, so this output may vary between runs.
solver.solve(10).nodes

[6, 12, 24, 48, 96]