In [71]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue, LifoQueue
import numpy as np
from tqdm.auto import tqdm

In [72]:

def getTrueElements(single_set):
    """Return how many entries of ``single_set`` are truthy.

    For the boolean masks stored in ``SETS`` this is the size of the set.
    """
    n_true = np.sum(single_set)
    return n_true

def count_true_elements(frequencies):
    """Build a histogram {size: occurrences} from ``(size, _)`` pairs.

    Only the first element of each pair is used; the second is ignored.
    Keys appear in the order in which each size is first encountered.
    """
    histogram = {}
    for size, _unused in frequencies:
        histogram[size] = histogram.get(size, 0) + 1
    return histogram


# Problem instance: NUM_SETS random candidate subsets of a universe of
# PROBLEM_SIZE elements. Each subset is stored as a boolean mask and every
# element is included independently with probability 0.2. No seed is set
# here, so the instance is different on every run.
PROBLEM_SIZE = 10
NUM_SETS = 20
SETS = tuple(
    np.array([random() < 0.2 for _element in range(PROBLEM_SIZE)])
    for _subset in range(NUM_SETS)
)

# Histogram {set size: number of sets of that size}. Each entry of
# `frequencies` is a (size, 1) pair, the shape count_true_elements expects.
frequencies = [(getTrueElements(mask), 1) for mask in SETS]
counts = count_true_elements(frequencies)

# Search node: `taken` and `not_taken` hold indices into SETS.
State = namedtuple('State', 'taken not_taken')

In [73]:
def goal_check(state):
    """Return True when the sets indexed by ``state.taken`` cover every element.

    The union is accumulated with element-wise OR over the boolean masks in
    ``SETS``, starting from an all-False mask of length ``PROBLEM_SIZE``.
    """
    covered = np.array([False for _ in range(PROBLEM_SIZE)])
    for index in state.taken:
        covered = np.logical_or(covered, SETS[index])
    return np.all(covered)

def distance(state):
    """Path cost g(n) for A*: the number of sets taken so far."""
    taken = state.taken
    return len(taken)

def heuristic1(state, sets=None, problem_size=None):
    """Admissible A* heuristic: lower bound on the number of sets still needed.

    Let ``u`` be the number of elements not yet covered by ``state.taken``.
    An untaken set can cover at most as many new elements as its own size,
    so at least ``k`` more sets are required, where ``k`` is the smallest
    count such that the ``k`` largest untaken sets have total size >= ``u``.
    This never overestimates the remaining cost, which keeps A* optimal.

    Args:
        state: a ``State``; only ``state.taken`` (indices into ``sets``) is read.
        sets: sequence of boolean numpy masks; defaults to the global ``SETS``.
        problem_size: size of the universe; defaults to the global
            ``PROBLEM_SIZE``.

    Returns:
        int: 0 for a goal state, otherwise the lower bound ``k``. If the
        untaken sets cannot cover the remaining elements at all, the number
        of untaken sets is returned (such a state is a dead end anyway).
    """
    if sets is None:
        sets = SETS
    if problem_size is None:
        problem_size = PROBLEM_SIZE

    covered = reduce(
        np.logical_or,
        [sets[i] for i in state.taken],
        np.zeros(problem_size, dtype=bool),
    )
    elements_to_take = problem_size - int(np.sum(covered))

    # Sizes of the sets still available, largest first. They are derived
    # directly from the untaken sets, so no stale zero-count sizes can be
    # picked and there is no max() over an empty collection.
    remaining_sizes = sorted(
        (int(np.sum(sets[i])) for i in range(len(sets)) if i not in state.taken),
        reverse=True,
    )

    h = 0
    for size in remaining_sizes:
        if elements_to_take <= 0:
            break
        # A set covers at most `size` new elements: subtract the set's size,
        # not the number of sets that share that size.
        elements_to_take -= size
        h += 1
    return h

In [74]:
# Sanity check: taking every set must cover the whole universe; otherwise no
# goal state exists and the search below could never terminate successfully.
assert goal_check(
    State(set(range(NUM_SETS)), set())
), "Problem not solvable"

In [75]:
# A* search over subsets of SETS with f = distance (sets taken) + heuristic1.
# Frontier entries are (f, tie_breaker, state). The increasing tie_breaker
# keeps equal-f entries in insertion order instead of falling back to
# comparing State namedtuples (set `<` is a subset test, not a total order).
frontier = PriorityQueue()
tie_breaker = 0
state = State(set(), set(range(NUM_SETS)))
frontier.put((distance(state) + heuristic1(state), tie_breaker, state))
# `taken` subsets already expanded, stored as frozensets. The same subset is
# reachable through many orderings of actions, so this avoids re-expanding it.
expanded = set()
counter = 0
_, _, current_state = frontier.get()
with tqdm(total=None) as pbar:
    while not goal_check(current_state):
        key = frozenset(current_state.taken)
        if key not in expanded:
            expanded.add(key)
            counter += 1
            for action in current_state.not_taken:
                new_state = State(
                    current_state.taken ^ {action},
                    current_state.not_taken ^ {action},
                )
                tie_breaker += 1
                frontier.put(
                    (distance(new_state) + heuristic1(new_state), tie_breaker, new_state)
                )
                pbar.update(1)

        # PriorityQueue.get() blocks forever on an empty queue; fail loudly instead.
        if frontier.empty():
            raise RuntimeError("Frontier exhausted without reaching a goal state")
        _, _, current_state = frontier.get()

print(
    f"Solved in {counter:,} steps ({len(current_state.taken)} tiles)"
)


75it [00:00, 72784.08it/s]

Solved in 4 steps (3 tiles)



