# 2D String Wave Function Collapse

In [1]:
import random
import math

In [2]:
class Patch:
    """Represents a single sub-string for purpose of wave-collapse.

    Attributes
    ----------
    core : int
        The central element of the sub-string
    raw_patch : list <int>
        The complete list of elements of the sub-string
    radius : int
        The radius of the sub-string, or the number of elements to one side of the core
    length : int
        The length of the sub-string (2*radius + 1)
    frequency : int
        The number of occurrences of this sub-string within the base text
    """

    def __init__(self, core, raw_patch, radius, frequency=1):
        """
        Parameters
        ----------
        core : int
            The central element of the sub-string
        raw_patch : list <int>
            The complete list of elements of the sub-string
        radius : int
            The radius of the sub-string, or the number of elements to one side of the core
        frequency : int, optional
            The number of occurrences of this sub-string within the base text (default 1)
        """

        self.core = core
        self.raw_patch = raw_patch
        self.radius = radius
        self.length = 2 * radius + 1
        self.frequency = frequency

    def __repr__(self):
        return "{}(core={!r}, raw_patch={!r}, radius={!r}, frequency={!r})".format(
            type(self).__name__, self.core, self.raw_patch, self.radius, self.frequency
        )

    def set_frequency(self, frequency):
        """Overwrite the occurrence count.

        Parameters
        ----------
        frequency : int
            The number of occurrences of this sub-string within the base text
        """

        self.frequency = frequency

    def add_frequency(self, increment):
        """Increase the occurrence count.

        Parameters
        ----------
        increment : int
            The number of new occurrences of this sub-string to add
        """

        self.frequency = self.frequency + increment

    def same_pattern(self, other):
        """Check whether this Patch and another represent the same pattern.

        Parameters
        ----------
        other : Patch
            The Patch to compare against

        Returns
        -------
        bool
            True if both the core and the full element list are equal
            (frequency is ignored).
        """

        return self.core == other.core and self.raw_patch == other.raw_patch

    def match_surroundings(self, surroundings):
        """Check if this Patch could exist at the center of the possible wave surrounding it.

        Parameters
        ----------
        surroundings : list <WaveElement>
            The possible wave surrounding this patch, aligned index-for-index
            with ``raw_patch``.

        Returns
        -------
        bool
            If this Patch could exist at the center of the surroundings.
        """

        # The center is skipped: it is the element being tested.
        # For each offset, check the left side first and then the right side.
        return all(
            self.raw_patch[i] in surroundings[i].possible_cores
            and self.raw_patch[-1 - i] in surroundings[-1 - i].possible_cores
            for i in range(self.radius)
        )

In [3]:
class WaveElement:
    """Represents a single location in the wave function as a super-position of Patches.

    Attributes
    ----------
    normalization : int
        Sum of the frequencies of the Patches still possible
    patches : dict {<int> : list <Patch>}
        The dictionary of Patches in super-position, indexed by Patch.core.
        A core's list may be empty once all of its Patches were removed.
    possible_cores : list <int>
        List of cores that still have at least one Patch
    collapsed : bool
        True when the WaveElement has collapsed to a core, False otherwise
    selected_core : int
        If the WaveElement has collapsed, the chosen core; -1 otherwise
    """

    def __init__(self, all_patches=None):
        """
        Parameters
        ----------
        all_patches : iterable <Patch>, optional
            Patches to be put in super-position (none by default)
        """

        self.normalization = 0
        self.patches = {}
        self.possible_cores = []
        self.collapsed = False
        self.selected_core = -1
        # None instead of a shared mutable [] default argument.
        for p in all_patches or ():
            self.add_patch(p)

    def add_patch(self, patch):
        """Adds one Patch to the super-position, updating values accordingly.

        Parameters
        ----------
        patch : Patch
            A single Patch

        Raises
        ------
        Exception
            If the WaveElement is already collapsed.
        """

        if self.collapsed:
            raise Exception("calling add_patch despite being collapsed")
        self.patches.setdefault(patch.core, []).append(patch)
        self.normalization = self.normalization + patch.frequency
        if patch.core not in self.possible_cores:
            self.possible_cores.append(patch.core)

    def subtract_patch(self, patch):
        """Removes every Patch with the same pattern from the super-position.

        Removing a pattern that is not present is a no-op.

        Parameters
        ----------
        patch : Patch
            A single Patch

        Raises
        ------
        Exception
            If the WaveElement is already collapsed.
        """

        if self.collapsed:
            raise Exception("calling subtract_patch despite being collapsed")
        bucket = self.patches.get(patch.core)
        if bucket is None:
            # No Patch with this core was ever added, so there is nothing to remove.
            return
        survivors = []
        for p in bucket:
            if p.same_pattern(patch):
                self.normalization = self.normalization - p.frequency
            else:
                survivors.append(p)
        # Mutate in place so the existing list object in self.patches is kept.
        bucket[:] = survivors
        # Guard against a repeated removal: the core may already be gone.
        if not bucket and patch.core in self.possible_cores:
            self.possible_cores.remove(patch.core)

    def get_collapse_quality(self):
        """Calculates the quality of the collapse of this WaveElement (lower is better).

        Quality should mathematically be Shannon entropy over the per-core
        frequency totals. Here the normalization (sum of remaining
        frequencies) is used for simplicity.

        Raises
        ------
        Exception
            If the WaveElement is already collapsed.
        """

        if self.collapsed:
            raise Exception("calling get_collapse_quality despite being collapsed")
        return self.normalization

    def fixed_collapse(self, core):
        """Collapses the WaveElement to the chosen core.

        Parameters
        ----------
        core : int
            The core to collapse to

        Raises
        ------
        Exception
            If the WaveElement is already collapsed.
        """

        if self.collapsed:
            raise Exception("calling fixed_collapse despite being collapsed")
        self.collapsed = True
        self.selected_core = core
        self.normalization = 1
        self.patches = {}
        self.possible_cores = [self.selected_core]

    def max_collapse(self):
        """Collapses the WaveElement to the core of the highest-frequency Patch.

        Ties go to the last such Patch encountered.

        Raises
        ------
        Exception
            If already collapsed, or if no Patches remain.
        """

        if self.collapsed:
            raise Exception("calling max_collapse despite being collapsed")
        candidates = [p for bucket in self.patches.values() for p in bucket]
        # Check the actual Patches, not the dict: buckets may exist but be empty.
        if not candidates:
            raise Exception("Failed max_collapse : No patches")
        max_patch = None
        max_freq = -1
        for p in candidates:
            if p.frequency >= max_freq:
                max_patch = p
                max_freq = p.frequency
        self.fixed_collapse(max_patch.core)

    def max_core_collapse(self):
        """Collapses the WaveElement to the still-possible core with the highest total frequency.

        Ties go to the last such core encountered.

        Raises
        ------
        Exception
            If already collapsed, or if no Patches remain.
        """

        if self.collapsed:
            raise Exception("calling max_core_collapse despite being collapsed")
        if not self.possible_cores:
            raise Exception("Failed max_core_collapse : No patches")

        max_core = -1
        max_freq = -1
        for c, bucket in self.patches.items():
            if not bucket:
                # This core was eliminated by culling; never collapse to it.
                continue
            total = sum(p.frequency for p in bucket)
            if total >= max_freq:
                max_core = c
                max_freq = total
        self.fixed_collapse(max_core)

    def probable_collapse(self):
        """Collapses the WaveElement at random, weighting each Patch by its frequency.

        Raises
        ------
        Exception
            If already collapsed, or if no Patches remain.
        """

        if self.collapsed:
            raise Exception("calling probable_collapse despite being collapsed")
        if self.normalization <= 0:
            raise Exception("Failed probable_collapse : No patches")

        # randint is inclusive at both ends. Drawing from 1..N gives each Patch
        # exactly `frequency` winning values; drawing from 0..N would
        # over-weight the first Patch by one.
        countdown = random.randint(1, self.normalization)
        for bucket in self.patches.values():
            for p in bucket:
                countdown = countdown - p.frequency
                if countdown <= 0:
                    self.fixed_collapse(p.core)
                    return
        raise Exception("Failed probable_collapse : Countdown did not end")

    def cull_patches(self, surroundings):
        """Removes Patches from the super-position based on surroundings.

        Parameters
        ----------
        surroundings : list <WaveElement>
            The possible wave surrounding this WaveElement.

        Returns
        -------
        bool
            If this culling changed the super-position.

        Raises
        ------
        Exception
            If the WaveElement is already collapsed.
        """

        if self.collapsed:
            raise Exception("calling cull_patches despite being collapsed")

        change = False
        for bucket in self.patches.values():
            # Collect first: subtract_patch rewrites the bucket in place.
            doomed = [p for p in bucket if not p.match_surroundings(surroundings)]
            for p in doomed:
                change = True
                self.subtract_patch(p)
        return change

    def check_collapse(self):
        """Test whether this WaveElement is collapsed, or should be, and collapse it if so.

        Returns
        -------
        bool
            If this WaveElement is now collapsed.
        """

        if self.collapsed:
            return True
        if len(self.possible_cores) == 1:
            self.fixed_collapse(self.possible_cores[0])
            return True
        return False
            
        

In [4]:
class Wave:
    """Represents a complete wave function in super-position.

    The waveform is laid out as ``radius`` elements fixed to the start core,
    then ``max_size`` free elements, then ``radius`` elements fixed to the end core.

    Attributes
    ----------
    waveform : list <WaveElement>
        The current state of the wave function
    patches_list : list <Patch>
        Every possible Patch
    radius : int
        The radius of interaction and Patches
    max_size : int
        The maximum length of the wave function
    word_start_core : int
        The core value that will represent the start of the wave
    word_end_core : int
        The core value that will represent the end of the wave
    success : bool
        If the wave collapsed into a valid state (updated by collapse())
    worst_quality : int
        The collapse quality of a freshly populated (fully uncollapsed) element,
        i.e. the highest quality value any element can have
    """

    def __init__(self, radius, max_size, word_start_core, word_end_core, patches_list):
        """
        Parameters
        ----------
        radius : int
        max_size : int
        word_start_core : int
        word_end_core : int
        patches_list : list <Patch>
        """

        self.waveform = []
        self.patches_list = patches_list
        self.radius = radius
        self.max_size = max_size
        self.word_start_core = word_start_core
        self.word_end_core = word_end_core
        self.success = True
        # Set before populate(), which computes the real value. Setting it after
        # populate() would discard that value.
        self.worst_quality = 0
        self.populate()

    def populate(self):
        """Generates the complete super-position of the wave and populates it with WaveElements."""

        total = 2 * self.radius + self.max_size
        self.waveform = [WaveElement(self.patches_list) for _ in range(total)]
        self.worst_quality = self.waveform[0].get_collapse_quality()
        for i in range(self.radius):
            self.waveform[i].fixed_collapse(self.word_start_core)
            self.waveform[-1 - i].fixed_collapse(self.word_end_core)

    def seed_collapse(self):
        """Chooses a random free element of the wave and collapses it to begin the wave function collapse.

        Returns
        -------
        int
            The index into ``waveform`` of the element collapsed
        """

        # Return the waveform index (including the radius offset) so it can be
        # passed straight to propogate_from().
        index = self.radius + random.randint(0, self.max_size - 1)
        self.waveform[index].probable_collapse()
        return index

    def do_best_collapse(self):
        """Find and collapse the best (lowest-quality-value) uncollapsed WaveElement.

        Ties are broken uniformly at random.

        Returns
        -------
        int
            The index of the element collapsed

        Raises
        ------
        Exception
            If no element is left uncollapsed.
        """

        best_quality = None
        best_indices = []
        for i, element in enumerate(self.waveform):
            if element.collapsed:
                continue
            quality = element.get_collapse_quality()
            if best_quality is None or quality < best_quality:
                best_quality = quality
                best_indices = [i]
            elif quality == best_quality:
                best_indices.append(i)
        if not best_indices:
            raise Exception("Failed do_best_collapse : No uncollapsed elements")
        i = random.choice(best_indices)
        self.waveform[i].probable_collapse()
        return i

    def cull_at(self, index):
        """Removes invalid Patches from the selected WaveElement, bringing it closer to collapse.

        Parameters
        ----------
        index : int
            The index of the WaveElement to cull. It must lie at least
            ``radius`` away from both ends of the waveform.

        Returns
        -------
        bool
            If the cull changed anything (always False for collapsed elements)
        """

        element = self.waveform[index]
        if element.collapsed:
            return False
        wave_selection = self.waveform[index - self.radius:index + self.radius + 1]
        return element.cull_patches(wave_selection)

    def empty_propogate(self, max_cycles=1):
        """Scans over the wave and culls WaveElements based on their neighbors.

        Each cycle sweeps the free region forwards and then backwards. Cycles
        stop early once a full cycle changes nothing.

        Parameters
        ----------
        max_cycles : int, optional
            Upper bound on the number of forward+backward sweeps (default 1)
        """

        inner = range(self.radius, self.radius + self.max_size)
        for _ in range(max_cycles):
            changed = False
            # cull_at goes first so that it is never short-circuited away once
            # `changed` is True.
            for j in inner:
                changed = self.cull_at(j) or changed
            for j in reversed(inner):
                changed = self.cull_at(j) or changed
            if not changed:
                break

    def propogate_from(self, index):
        """Culls the neighbors within ``radius`` of a WaveElement, then the whole wave.

        Parameters
        ----------
        index : int
            The waveform index of the WaveElement to cull around
        """

        lo = self.radius
        hi = len(self.waveform) - self.radius  # exclusive
        for offset in range(1, self.radius + 1):
            right_index = index + offset
            left_index = index - offset
            # Only cull indices whose full surroundings fit inside the waveform.
            if lo <= right_index < hi:
                self.cull_at(right_index)
            if lo <= left_index < hi:
                self.cull_at(left_index)

        self.empty_propogate()

    def check_fully_collapsed(self):
        """Check if every element of the wave is collapsed.

        As a side effect, each element visited is collapsed if only one core
        remains. Evaluation stops at the first element that is not collapsed,
        so later elements are not touched.

        Returns
        -------
        bool
            If every element of the wave is collapsed
        """

        return all(w.check_collapse() for w in self.waveform)

    def check_failed_collapse(self):
        """Check if the collapse has failed to resolve correctly.

        This wave function collapse algorithm does not guarantee correct results,
        as that is the domain of quantum computing. Instead, it can easily identify
        when it has reached a dead-end wave, so that the caller can restart.

        Returns
        -------
        bool
            If any element has no possible cores left
        """

        return any(not w.possible_cores for w in self.waveform)

    def collapse(self):
        """Collapses the wave from its super-position completely.

        Returns
        -------
        bool
            If the collapse was successful (also stored in ``self.success``)
        """

        self.propogate_from(self.seed_collapse())

        while not (self.check_fully_collapsed() or self.check_failed_collapse()):
            self.empty_propogate()
            if self.check_failed_collapse():
                break
            self.propogate_from(self.do_best_collapse())
        self.empty_propogate()
        self.success = not self.check_failed_collapse()
        return self.success

In [12]:
class TextWaveHandler:
    """Manipulates text to work in the generic wave collapse function.

    Attributes
    ----------
    line_delimiter : str
        What string to split the input text by
    max_size : int
        What size to cap the wave function at
    radius : int
        What radius of interaction should be used by the wave collapse
    count_table : dict {<str> : dict {<str> : <int>} }
        Table for counting sub-strings by central character
    total_table : dict {<str> : int}
        Table for counting incidences of each character
    total_phonemes : int
        Total number of characters read from the input text
    patches_list : list <Patch>
        List of patches produced from the input text
    phoneme_list : list <str>
        List of unique characters read from the text
    word_start : int
        The value assigned to the starting character
    word_end : int
        The value assigned to the ending character
    wave : Wave
        The Wave object built to produce similar text to the input text
    padding_left : str
        The single character used to pad the text to the left
    padding_right : str
        The single character used to pad the text to the right
    """

    def __init__(self, line_delimiter='\n', max_size=10, radius=1, padding_left="+", padding_right="-"):
        """
        Parameters
        ----------
        line_delimiter : str
        max_size : int
        radius : int
        padding_left : str
        padding_right : str
        """

        self.line_delimiter = line_delimiter
        self.max_size = max_size
        self.radius = radius
        self.count_table = {}
        self.total_table = {}
        self.total_phonemes = 0
        self.patches_list = []
        self.phoneme_list = []
        self.word_start = 0
        self.word_end = 1
        self.padding_left = padding_left
        self.padding_right = padding_right
        self.wave = None

    def num_to_phoneme(self, num):
        """Converts from the wave's representation of a value to the original string.

        Parameters
        ----------
        num : int
            The value to convert

        Returns
        -------
        str
            The converted value
        """

        return self.phoneme_list[num]

    def phoneme_to_num(self, text):
        """Converts from the string of a value to the wave's numerical representation.

        Parameters
        ----------
        text : str
            The string to convert

        Returns
        -------
        int
            The converted value

        Raises
        ------
        ValueError
            If the string has not been read yet.
        """

        return self.phoneme_list.index(text)

    def read_text(self, text):
        """Reads the input text and captures all values to generate the wave function super-position derived from it.

        May be called more than once. Counts accumulate across calls, and
        ``patches_list`` is rebuilt from the accumulated counts each time.

        Parameters
        ----------
        text : str
            The string to read
        """

        # Read text into tables of text.
        pad_left = self.padding_left * (self.radius + 1)
        pad_right = self.padding_right * (self.radius + 1)
        width = 2 * self.radius + 1
        for word in text.split(self.line_delimiter):
            padded_word = pad_left + word + pad_right
            for i in range(len(padded_word) - 2 * self.radius):
                patch_text = padded_word[i:i + width]
                phoneme = padded_word[i + self.radius]
                if phoneme not in self.phoneme_list:
                    self.phoneme_list.append(phoneme)
                    self.count_table[phoneme] = {}
                    self.total_table[phoneme] = 0
                counts = self.count_table[phoneme]
                counts[patch_text] = counts.get(patch_text, 0) + 1
                self.total_table[phoneme] = self.total_table[phoneme] + 1
                self.total_phonemes = self.total_phonemes + 1
        self.word_start = self.phoneme_to_num(self.padding_left)
        self.word_end = self.phoneme_to_num(self.padding_right)

        # Turn the tables of text into a list of Patch objects for use in Wave().
        # Rebuild from scratch: count_table is cumulative, so appending would
        # duplicate every Patch from earlier calls.
        self.patches_list = []
        for phoneme, counts in self.count_table.items():
            core = self.phoneme_to_num(phoneme)
            for patch_string, count in counts.items():
                raw_patch = [self.phoneme_to_num(c) for c in patch_string]
                self.patches_list.append(Patch(core, raw_patch, self.radius, count))

    def generate_wave(self):
        """Creates a fresh Wave object from the current Patches."""

        self.wave = Wave(self.radius, self.max_size, self.word_start, self.word_end, self.patches_list)

    def print_wave_raw(self):
        """Prints the contents of the collapsed wave as text, without the padding characters."""

        out = "".join(self.num_to_phoneme(we.selected_core) for we in self.wave.waveform)
        # Strip the configured padding rather than a hard-coded "+-".
        print(out.strip(self.padding_left + self.padding_right))

    def mid_print(self):
        """Prints each element's possible cores, selected core, and collapsed flag while the wave isn't collapsed."""

        out = "\n"
        for we in self.wave.waveform:
            var = "[" + ''.join(self.num_to_phoneme(k) for k in we.possible_cores) \
                + "|" + ((str(we.selected_core) + "|") if we.collapsed else "") + str(we.collapsed) + "]"
            out = out + var + "\n"
        print(out)

    def produce(self, n, max_out=2000):
        """Produces and prints random strings from the wave function collapse.

        Parameters
        ----------
        n : int
            The number of random strings to produce
        max_out : int, optional
            The maximum number of attempts to make before ending prematurely

        Returns
        -------
        int
            The number of strings successfully produced (at most ``n``)
        """

        produced = 0
        while produced < n:
            max_out = max_out - 1
            if max_out < 0:
                print("maxed out production loops")
                break
            self.generate_wave()
            if self.wave.collapse():
                self.print_wave_raw()
                produced = produced + 1
        return produced
            

In [7]:
# Training corpus: one U.S. president name per line. TextWaveHandler.read_text()
# splits on its line_delimiter ('\n' by default), so each line becomes one
# padded training word. Repeated lines (e.g. "Grover Cleveland") raise that
# name's patch frequencies.
s="""George Washington
John Adams
Thomas Jefferson
James Madison
James Monroe
John Quincy Adams
Andrew Jackson
Martin Van Buren
William H. Harrison
John Tyler
James K. Polk
Zachary Taylor
Millard Fillmore
Franklin Pierce
James Buchanan
Abraham Lincoln
Andrew Johnson
Ulysses S. Grant
Rutherford B. Hayes
James A. Garfield
Chester A. Arthur
Grover Cleveland
Benjamin Harrison
Grover Cleveland
William McKinley
Theodore Roosevelt
William H. Taft
Woodrow Wilson
Warren G. Harding
Calvin Coolidge
Herbert Hoover
Franklin D. Roosevelt
Harry S. Truman
Dwight D. Eisenhower
John F. Kennedy
Lyndon B. Johnson
Richard M. Nixon
Gerald R. Ford
Jimmy Carter
Ronald Reagan
George H. W. Bush
Bill Clinton
George W. Bush
Barack Obama
Donald J. Trump"""

In [13]:
# Radius-2 handler: each patch spans 5 characters, and generated strings have
# at most 16 free slots.
wh = TextWaveHandler(radius=2, max_size=16)
wh.read_text(s)
wh.generate_wave()

In [14]:
# Print 20 successfully collapsed samples, within produce()'s default attempt limit.
wh.produce(n=20)

Geodoren Adams
Cary S. H. Poln
Ulysses Monroe
Abrantonroveld
John Tayes Bush
Zachard J. Poln
John F. Johnson
Ulysses Kenhower
Cald McKin Taft
Will Coodonrover
Dwight D. Nixon
Zacharriseverce
Groveld Reagan
Ulyssester Adama
Ulysses Monroe
Maran Burennedy
Johnsonankliaman
George Harrison
Benjames Kenjams
Will Clevelanton


In [15]:
# Radius-1 handler: each patch spans 3 characters, and generated strings have
# at most 8 free slots.
w2 = TextWaveHandler(max_size=8, radius=1)
w2.read_text(s)
# Build the wave for w2 itself; this previously called wh.generate_wave() by mistake.
w2.generate_wave()

In [18]:
# Print 20 successfully collapsed samples from the radius-1 model.
w2.produce(n=20)

Doeodrr
Jachassh
Joom Ker
And Bama
Jam Jeov
Wington
Jachnter
eon Wig
Bus Bur
Jin Jont
Foodrton
Grin Han
Gerran
Fon Harg
Wafthndy
Zamorce
Banarace
Rusharge
Gackseve
Trraler
