In [4]:
from itertools import product

In [1]:
productions = [
    # Start symbol: either a conjunction of two sub-programs or a single clause.
    ['S', ['and(S,S)', 'A']],

    # Clause templates.
    ['A', [
        # location symmetry, one fixed feature (location agnostic, compare feature)
        'one(B)', 'two(B)', 'three(B)',
        # location symmetry, feature symmetry (location agnostic, feature agnostic)
        'two(E)', 'three(E)',
        # location symmetry, multiple-feature symmetry (location agnostic, feature agnostic)
        'two(F,E)', 'three(F,E)',
        # feature at a location, feature specified (location given, compare feature)
        'one(B,C)', 'two(B,D)',
        # location given, feature symmetry (location given, feature agnostic)
        'two(E,D)',
    ]],

    # Features.
    ['B', ['square', 'circle', 'triangle', 'low', 'medium', 'high', 'striped', 'plain']],

    # Single locations.
    ['C', ['(0)', '(1)', '(2)']],

    # Dyad (pair) locations.
    ['D', ['(0,1)', '(0,2)', '(1,2)']],

    # Relations.
    ['E', ['same', 'unique']],

    # Number of features to compare.
    ['F', ['1', '2', '3']],
]

In [None]:
def generate_all_programs(productions, start_symbol='S', max_depth=3):
    """
    Generate every program derivable from a grammar within a depth limit.

    A right-hand side is either a bare symbol (e.g. ``'A'`` or ``'square'``)
    or a call form ``name(arg1,arg2,...)`` whose comma-separated arguments are
    symbols to expand; ``name`` may be empty, as in ``'(0,1)'``. A symbol that
    never appears as a left-hand side is a terminal and is emitted verbatim.

    Args:
        productions: List of production rules in format [lhs, [rhs1, rhs2, ...]].
            The same lhs may appear in several entries; their alternatives
            are combined.
        start_symbol: Starting non-terminal symbol.
        max_depth: Maximum depth of the derivation tree. The start symbol is
            at depth 0 and every expansion step, including the final step
            down to a terminal, adds 1. Any derivation that would place a
            symbol deeper than max_depth is discarded.

    Returns:
        set: All program strings derivable within max_depth (empty if none).
    """
    # Index the alternatives per non-terminal once, instead of scanning the
    # whole production list on every recursive call.
    alternatives = {}
    for lhs, rhs_list in productions:
        alternatives.setdefault(lhs, []).extend(rhs_list)

    # expand() is a pure function of (symbol, depth), so each pair is computed
    # once. Without this, a rule such as 'and(S,S)' expands S twice at the same
    # depth, and that duplication repeats at every level below it.
    memo = {}

    def expand(symbol, depth):
        '''Return the set of strings derivable from `symbol` placed at `depth`.'''
        # Stop the search if max depth reached (this also prunes terminals).
        if depth > max_depth:
            return set()

        # Terminal: not the left-hand side of any production.
        if symbol not in alternatives:
            return {symbol}

        key = (symbol, depth)
        if key in memo:
            # Cached sets are only read by callers, never mutated.
            return memo[key]

        results = set()
        for rhs in alternatives[symbol]:
            if '(' in rhs:
                # Compound expression like 'and(S,S)' or 'one(B)'. Split on the
                # first '(' only, so further parentheses inside the argument
                # list do not break the two-way unpacking.
                func, args = rhs.split('(', 1)
                args = args[:-1]  # drop the closing parenthesis

                # Recursively expand each argument one level deeper.
                arg_expansions = [
                    expand(arg.strip(), depth + 1) for arg in args.split(',')
                ]

                # Every combination of argument expansions is one program;
                # if any argument has no expansion the product is empty.
                for arg_combo in product(*arg_expansions):
                    results.add(f"{func}({','.join(arg_combo)})")
            else:
                # Simple rule: the right-hand side is a single symbol.
                results.update(expand(rhs, depth + 1))

        memo[key] = results
        return results

    return expand(start_symbol, 0)

# Example usage: enumerate every program derivable from 'S' with a depth
# limit of 5, which for this grammar admits two nested levels of `and`.
programs = generate_all_programs(productions, start_symbol='S', max_depth=5)

{'and(and(two(same,(0,2)),two(low)),and(two(same,(0,2)),three(low)))',
 'and(and(three(unique),one(high)),and(one(circle,(2)),two(square)))',
 'and(and(two(1,same),two(high,(1,2))),and(one(high,(2)),one(low,(1))))',
 'and(and(one(medium,(0)),one(triangle)),and(one(plain,(2)),two(medium,(0,2))))',
 'and(and(one(triangle),two(2,unique)),and(two(high),one(low)))',
 'and(and(two(circle),one(triangle,(1))),and(two(1,same),one(high,(2))))',
 'and(and(one(medium),two(square)),and(three(medium),two(low,(0,1))))',
 'and(and(one(plain,(2)),one(triangle)),and(two(2,unique),one(high)))',
 'and(and(two(square,(0,1)),two(triangle,(1,2))),and(three(1,same),two(medium,(1,2))))',
 'and(and(one(circle,(0)),one(striped,(2))),and(two(triangle,(0,2)),two(same,(1,2))))',
 'and(and(two(unique,(0,2)),two(unique)),and(three(circle),two(striped,(0,2))))',
 'and(and(one(high,(0)),two(same,(1,2))),and(two(square,(0,1)),one(high)))',
 'and(and(three(same),two(2,unique)),and(two(plain,(0,2)),three(1,same)))',
 'and