In [None]:
import itertools
import heapq

# Graph Classes #

In [None]:
def power_set(lst):
    """Return the set of all combinations of two or more items of *lst*.

    Despite the name this is not the full power set: the empty combination and
    single-item combinations are omitted, since the pigeonhole search only
    inspects groups of at least two entries. Each combination is a tuple whose
    items keep the iteration order of *lst*.
    """
    result = set()
    for size in range(2, len(lst) + 1):
        result.update(itertools.combinations(lst, size))
    return result
"""
Generalizability:
Note about Classic Grids:
Classic grids are 9x9
To extend this to all square sudokus, the grid must be square and its side length must be a
perfect square (4x4, 9x9, 16x16, ...), so that it divides evenly into square subgrids.

To extend to any number of dimensions, we need an n-dimensional grid whose side length is a perfect
n-th power, i.e. k^n x k^n x ... x k^n (n factors).
This allows 8x8x8, 27x27x27, 64x64x64, ... cubic puzzles, and 16x16x16x16, 81x81x81x81, 256x256x256x256, ... hypercube puzzles, etc.




Grid Construction Algorithm:
1. Create and store all containers; all containers are empty
    a. Create a pointer back to "parent" Grid
    b. Create a list of empty lists; these represent mappings
    
2. Create and store all entries
    a. Create pointers to all "parent" Containers
    b. Assign Values; Depending on the value, remove the mapping
    c. Add Entry to containers
3. Update all solved square mappings (once they are solved)
4. Update all container mappings
5. Update all non-square mappings

"""

class Classic_Grid:
    """A 9x9 sudoku grid built from Classic_Container and Classic_Entry objects.

    The grid owns 27 containers (9 rows, 9 columns, 9 subgrids) and a 9x9 matrix
    of entries. Every unsolved entry keeps a list of candidate values, and every
    container keeps, per value, the entries that may still take that value.
    """

    def __init__(self, inp_squares):  # TODO: accept a path to a json file later
        """Build the grid from a 9x9 nested list of ints; 0 marks an empty square."""
        self.dims = 9  # See note about generalizability

        # Containers are grouped by kind for readability; "all" is a convenience
        # tuple for iterating over every container at once.
        self.containers = {}
        self.containers["rows"] = tuple(Classic_Container(self) for _ in range(self.dims))
        self.containers["cols"] = tuple(Classic_Container(self) for _ in range(self.dims))
        self.containers["subgrids"] = tuple(Classic_Container(self) for _ in range(self.dims))
        self.containers["all"] = (
            self.containers["rows"] + self.containers["cols"] + self.containers["subgrids"]
        )

        # Entry objects mirror inp_squares; each entry registers itself with its
        # row, column and subgrid containers while being constructed.
        self.entries = [
            [Classic_Entry(self, i, j, inp_squares[i][j]) for j in range(self.dims)]
            for i in range(self.dims)
        ]

        # Populated by solve(); defined here so the alg* methods never hit a
        # missing attribute.
        self.unsolved_entries = []

        # Entry candidates must exist before containers can build their
        # value -> candidate-entries mappings from them.
        self.update_all_entry_mappings()
        self.update_all_container_mappings()

    def visualize(self):
        """Print the grid one row per line, followed by a blank line."""
        for i in range(self.dims):
            print([str(self.entries[i][j]) for j in range(self.dims)])
        print()

    def update_all_entry_mappings(self):
        """Recompute the candidate list of every entry from its solved neighbours."""
        for entries in self.entries:
            for e in entries:
                e.update_mappings()

    def update_all_container_mappings(self):
        """Rebuild every container's value -> entries mappings.

        Must run after update_all_entry_mappings().
        """
        for c in self.containers["all"]:
            c.update_mappings()

    def _count_candidates(self):
        """Return the total number of candidate values left across all entries."""
        return sum(len(e.get_mappings()) for row in self.entries for e in row)

    def solve(self):
        """Try to solve the grid in place with three deduction rules.

        1. alg1 (naked singles): an entry with exactly one candidate takes it.
        2. alg2 (hidden singles): a value with exactly one candidate entry in a
           container is placed in that entry.
        3. alg3 (naked subsets / pigeonhole): k entries of a container whose
           candidates together cover only k values claim those values.
        There is no backtracking, so the grid may remain partially unsolved if
        these rules stall.
        """
        # Flat list of unsolved entries, used as alg1's priority queue.
        self.unsolved_entries = [e for row in self.entries for e in row if not e.get_solved()]

        self.alg1()
        self.alg2()
        self.alg3()

    def alg1(self):
        """Naked singles: repeatedly solve entries that have exactly one candidate.

        self.unsolved_entries is used as a binary heap ordered by candidate count
        (see Classic_Entry.__lt__). Solving an entry shrinks its neighbours'
        candidate lists, which breaks the heap invariant, so instead of updating
        the heap dynamically the whole heap is re-heapified after any progress.
        Entries with zero candidates (already solved, e.g. by alg2) are dropped.
        """
        while True:
            heapq.heapify(self.unsolved_entries)
            if not self.unsolved_entries:
                return

            updated = False
            while self.unsolved_entries and len(self.unsolved_entries[0].get_mappings()) < 2:
                e = heapq.heappop(self.unsolved_entries)
                if len(e.get_mappings()) == 1:
                    e.solve(e.get_mappings()[0])
                    updated = True

            if not updated:
                return

    def alg2(self):
        """Hidden singles: place any value that has only one candidate entry in a container.

        Containers are scanned naively; after any placement, alg1 is run and the
        scan repeats until no container yields a new placement.
        """
        while True:
            updated = False
            for c in self.containers["all"]:
                # Values are 1..dims inclusive.
                for val in range(1, self.dims + 1):
                    candidates = c.get_mappings(val)
                    # A placed value maps to an empty set, which is skipped here.
                    if candidates and len(candidates) == 1:
                        entry = candidates[0]
                        # Guard against stale bookkeeping: never re-solve an entry
                        # or give it a value it has already ruled out.
                        if not entry.get_solved() and val in entry.get_mappings():
                            entry.solve(val)
                            updated = True

            if not updated:
                return
            self.alg1()

    def alg3(self):
        """Pigeonhole (naked subsets) elimination, interleaved with alg1 and alg2.

        For every container, every group of >= 2 unsolved entries is examined.
        If the union of the group's candidates has exactly as many values as the
        group has entries, those values must be distributed among the group, so
        the container's pigeonhole() removes them from all other entries.
        """
        while True:
            before = self._count_candidates()

            for c in self.containers["all"]:
                for e_subset in power_set(c.get_unsolved_entries()):
                    vals = set().union(*(e.get_mappings() for e in e_subset))
                    if len(vals) == len(e_subset):
                        c.pigeonhole(e_subset, vals)

            # In a consistent puzzle, each container's full set of unsolved
            # entries passes the test above, so "some subset matched" is not
            # evidence of progress. Only a drop in the total candidate count is;
            # stopping otherwise prevents endless repetition.
            if self._count_candidates() == before:
                return

            self.alg1()
            self.alg2()
    
    
class Classic_Container:
    """A row, column or subgrid: a group of entries that must hold distinct values.

    Bookkeeping lists are indexed directly by value (index 0 is a placeholder):
      * solved_vals[val] -- the entry holding val, or None;
      * mappings[val]    -- list of unsolved entries that may still take val,
                            replaced by an empty set once val is placed.
    """

    def __init__(self, grid):
        self.grid = grid
        self.entries = []
        self.solved_vals = [None for _ in range(grid.dims + 1)]  # 1-indexed
        self.mappings = [[] for _ in range(grid.dims + 1)]
        self.mappings[0] = None  # placeholder so values 1..dims index directly

    def get_entries(self):
        return self.entries

    def get_solved_val(self, val):
        """Return the entry solved with val, or None if val is not placed yet."""
        return self.solved_vals[val]

    def get_unsolved_entries(self):
        """Return the set of this container's entries not recorded in solved_vals."""
        return set(self.entries).difference(set(self.solved_vals))

    def get_mappings(self, val):
        """Return the candidate entries for val (an empty set once val is placed)."""
        return self.mappings[val]

    def add_entry(self, e):
        """Register entry e; if it is already solved, mark its value as placed."""
        self.entries.append(e)
        if e.get_solved():
            self.mappings[e.get_val()] = set()
            self.solved_vals[e.get_val()] = e

    def update_mappings(self):
        """Add every unsolved entry to the lists of the values it may still take.

        All entries of this container must have up-to-date candidate lists.
        For solved entries, the container bookkeeping is only checked.
        """
        for e in self.entries:
            if e.get_solved():
                assert (self.mappings[e.get_val()] == set()), "Container mappings not set to {} for solved value"
                assert (self.solved_vals[e.get_val()] is e), "Container solved_value array not set to correct entry"
                for mappings in self.mappings:
                    assert (mappings is None or not e in mappings), "Solved entry should not be in any container mappings"
            else:
                for val in e.get_mappings():
                    if e not in self.mappings[val]:
                        self.mappings[val].append(e)

    def remove_mapping(self, e, val):
        """Remove e from the candidate entries of val, if it is there."""
        if e in self.mappings[val]:
            self.mappings[val].remove(e)

    @staticmethod
    def _drop_candidate(entry, val):
        """Remove val from entry's candidates and from all of entry's containers."""
        entry.remove_mapping(val)
        for c in entry.get_containers():
            c.remove_mapping(entry, val)

    def solve_entry(self, e, val):
        """Record that entry e (a member of this container) is solved with val.

        1. Every other entry listed as a candidate for val loses val, both in its
           own candidate list and in every container it belongs to, so the
           bookkeeping of its other containers does not go stale.
        2. e is removed from every other value's candidate list here.
        3. val is marked as placed: its candidate list becomes an empty set and
           solved_vals[val] points to e.
        """
        # Iterate over a copy: _drop_candidate may remove entries from
        # self.mappings[val] itself.
        for entry in list(self.mappings[val]):
            if entry is not e:
                self._drop_candidate(entry, val)

        for other in range(1, len(self.mappings)):
            if other != val:
                self.remove_mapping(e, other)

        self.mappings[val] = set()
        self.solved_vals[val] = e

    def pigeonhole(self, entries, vals):
        """Apply a naked-subset elimination to this container.

        `entries` is a group of k unsolved entries of this container whose
        candidates together cover exactly the k values in `vals`. Those values
        must therefore be distributed among those entries, so:
          * for val in vals: every entry outside the group loses val;
          * for val not in vals: every entry inside the group loses val. This is
            normally a no-op because `vals` is the union of the group's
            candidates; it only cleans up stale container bookkeeping.
        Each removal is applied to the entry and to every container it belongs to.
        """
        members = set(entries)
        for val in range(1, len(self.mappings)):
            inside_group = val in vals
            # Iterate over a copy: _drop_candidate also updates this container,
            # which would otherwise mutate the list being iterated.
            for e in list(self.mappings[val]):
                if inside_group:
                    should_drop = e not in members
                else:
                    should_drop = e in members
                if should_drop:
                    self._drop_candidate(e, val)

    def verify(self):
        """Return True if no value is held by more than one solved entry here."""
        placed = [e.get_val() for e in self.entries if e.get_solved()]
        return len(placed) == len(set(placed))

    def visualize(self):
        """Print this container's entries as a list of strings."""
        print([str(e) for e in self.entries])
        
        
        
    
class Classic_Entry:
    """A single square of a Classic_Grid.

    An entry stores its (row, col) position, its three containers, its value
    (0 while unsolved) and its candidate values. Candidates are a list while
    the entry is unsolved and an empty set once it is solved.
    """

    def __init__(self, grid, row, col, val):
        assert val >= 0 and val <= grid.dims
        self.position = (row, col)
        self.grid = grid

        # Side length of a subgrid (3 for a 9x9 grid). Subgrids are numbered
        # row-major: (row // box, col // box) read as a 2-digit base-`box` number.
        box = int(round(grid.dims ** 0.5))
        self.row = grid.containers["rows"][row]
        self.col = grid.containers["cols"][col]
        self.subgrid = grid.containers["subgrids"][(row // box) * box + col // box]
        self.containers = set([self.row, self.col, self.subgrid])

        self.val = val
        self.solved = bool(val)  # 0 marks an empty square

        # Solved entries have no candidates. Unsolved entries start with an empty
        # list that update_mappings() fills once every entry exists.
        if self.solved:
            self.mappings = set()
        else:
            self.mappings = []

        for c in self.containers:
            c.add_entry(self)

    def get_row(self):
        return self.row

    def get_col(self):
        return self.col

    def get_subgrid(self):
        return self.subgrid

    def get_containers(self):
        return self.containers

    def get_val(self):
        return self.val

    def get_solved(self):
        return self.solved

    def get_mappings(self):
        return self.mappings

    def update_mappings(self):
        """Recompute candidates: 1..dims minus values solved in any shared container.

        This is brute force. It scans every entry of all three containers.
        Solved entries are left untouched.
        """
        if self.solved:
            return
        self.mappings = list(range(1, self.grid.dims + 1))
        for c in self.containers:
            for e in c.get_entries():
                if e.get_solved() and (e.get_val() in self.mappings):
                    self.mappings.remove(e.get_val())

    def __lt__(self, other):
        # Order by candidate count so alg1's heap yields the most constrained entry first.
        return len(self.mappings) < len(other.get_mappings())

    def solve(self, val):
        """Set this entry to val and update the bookkeeping of all its containers."""
        for c in self.containers:
            c.solve_entry(self, val)

        self.val = val
        self.mappings = set()
        self.solved = True

    def remove_mapping(self, val):
        """Drop val from this entry's candidates, if present."""
        if val in self.mappings:
            self.mappings.remove(val)

    def check_subset(self, subset):
        """Return True if every value in subset is still a candidate for this entry."""
        return all(v in self.mappings for v in subset)

    def __str__(self):
        return str(self.val)
        
        
        

               

In [None]:
"""
easy grids do not require alg3 pigeonholing
hard grids do require alg3
"""

# Sample puzzles as nested lists; 0 marks an empty square.
_easy_puzzles = [
    [
        [0, 6, 0, 0, 8, 0, 4, 2, 0],
        [0, 1, 5, 0, 6, 0, 3, 7, 8],
        [0, 0, 0, 4, 0, 0, 0, 6, 0],
        [1, 0, 0, 6, 0, 4, 8, 3, 0],
        [3, 0, 6, 0, 1, 0, 7, 0, 5],
        [0, 8, 0, 3, 5, 0, 0, 0, 0],
        [8, 3, 0, 9, 4, 0, 0, 0, 0],
        [0, 7, 2, 1, 3, 0, 9, 0, 0],
        [0, 0, 9, 0, 2, 0, 6, 1, 0],
    ],
    [
        [0, 0, 8, 0, 0, 0, 0, 0, 0],
        [4, 9, 0, 1, 5, 7, 0, 0, 2],
        [0, 0, 3, 0, 0, 4, 1, 9, 0],
        [1, 8, 5, 0, 6, 0, 0, 2, 0],
        [0, 0, 0, 0, 2, 0, 0, 6, 0],
        [9, 6, 0, 4, 0, 5, 3, 0, 0],
        [0, 3, 0, 0, 7, 2, 0, 0, 4],
        [0, 4, 9, 0, 3, 0, 0, 5, 7],
        [8, 2, 7, 0, 0, 9, 0, 1, 3],
    ],
]

_hard_puzzles = [
    [
        [2, 5, 0, 0, 0, 3, 0, 9, 1],
        [3, 0, 9, 0, 0, 0, 7, 2, 0],
        [0, 0, 1, 0, 0, 6, 3, 0, 0],
        [0, 0, 0, 0, 6, 8, 0, 0, 3],
        [0, 1, 0, 0, 4, 0, 0, 0, 0],
        [6, 0, 3, 0, 0, 0, 0, 5, 0],
        [1, 3, 2, 0, 0, 0, 0, 7, 0],
        [0, 0, 0, 0, 0, 4, 0, 6, 0],
        [7, 6, 4, 0, 1, 0, 0, 0, 0],
    ],
]

# Grids are built in the same order as listed: easy ones first, then hard ones.
easy_grids = [Classic_Grid(puzzle) for puzzle in _easy_puzzles]
hard_grids = [Classic_Grid(puzzle) for puzzle in _hard_puzzles]

# Show the first easy grid before and after solving.
demo_grid = easy_grids[0]
demo_grid.visualize()
demo_grid.solve()
demo_grid.visualize()

