In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
#| default_exp hierarchy

In [None]:
#| export
#import numpy as np
import uuid, json
from os import sep


In [None]:
#| export
from pct.putils import UniqueNamer
from pct.functions import BaseFunction, HPCTFUNCTION, WeightedSum
from pct.environments import EnvironmentFactory, ControlEnvironment
from pct.errors import BaseErrorCollector
from pct.putils import floatListsToString, PCTRunProperties, FunctionsData, FunctionsList, list_of_ones
from pct.nodes import PCTNode


ModuleNotFoundError: No module named 'pct'

In [None]:
#| include: false
from pct.functions import Proportional

## Defining the Hierarchy

Definition of the `PCTHierarchy` class: a hierarchical perceptual control system made up of levels of `PCTNode`s.

In [None]:
#| export
class PCTHierarchy():
    "A hierarchical perceptual control system, of PCTNodes."
    
    def __init__(self, levels=0, cols=0, pre=None, post=None, name="pcthierarchy", clear_names=True, links="single",
                    history=False, build=True, error_collector=None, namespace=None, **pargs):
        """Create a hierarchy of ``levels`` x ``cols`` PCTNodes, optionally linked together.

        Args:
            levels: number of levels (rows) of nodes.
            cols: number of nodes (columns) on every level.
            pre: list of pre-processor functions; ``None`` gives an empty list. The list object
                passed in is used directly (not copied).
            post: list of post-processor functions; ``None`` gives an empty list (not copied either).
            name: base name, made unique within ``namespace`` via ``UniqueNamer``.
            clear_names: if True, call ``UniqueNamer.clear`` for ``namespace`` first.
            links: "single" (link to the node in the same column of the adjacent level) or
                "dense" (WeightedSum perception/reference functions linked to every node of the
                adjacent level).
            history: passed to every node; also creates ``FunctionsData`` to record pre/post data.
            build: if True, build each node's internal links and the links between levels.
            error_collector: optional collector updated after every step of ``__call__``.
            namespace: naming namespace; a fresh ``uuid.uuid1()`` when ``None``.
            **pargs: accepted for compatibility; not used here.
        """
        if namespace is None:
            namespace = uuid.uuid1()
        self.namespace = namespace

        self.error_collector = error_collector
        self.links_built = False
        self.order = None
        self.history = history
        if clear_names:
            UniqueNamer.getInstance().clear(namespace=namespace)
        self.name = UniqueNamer.getInstance().get_name(namespace=namespace, name=name)
        self.preCollection = [] if pre is None else pre
        self.postCollection = [] if post is None else post
        self.hierarchy = []
        self.prepost_data = FunctionsData() if history else None

        for r in range(levels):
            col_list = []
            for c in range(cols):
                node = self._create_node(r, c, levels, cols, links, history, namespace)
                if build:
                    node.build_links()
                    # links to the level below; earlier levels are already in self.hierarchy
                    self.handle_perception_links(node, r, c, links)
                    self.handle_reference_links(node, r, c, links)
                col_list.append(node)
            self.hierarchy.append(col_list)

    def _create_node(self, r, c, levels, cols, links, history, namespace):
        """Construct the PCTNode for level ``r``, column ``c`` (helper for ``__init__``)."""
        name = f'level{r}col{c}'
        if links != "dense":
            return PCTNode(name=name, history=history, namespace=namespace)

        # Keep the original construction order (perception function, then reference function,
        # then the node) in case names are allocated sequentially within the namespace.
        perc = WeightedSum(weights=list_of_ones(cols), namespace=namespace) if r > 0 else None
        ref = WeightedSum(weights=list_of_ones(cols), namespace=namespace) if r < levels - 1 else None

        if r == 0:
            if levels > 1:
                return PCTNode(reference=ref, name=name, history=history, namespace=namespace)
            return PCTNode(name=name, history=history, namespace=namespace)
        if r == levels - 1:
            # top level: no reference function from above
            return PCTNode(perception=perc, name=name, history=history, namespace=namespace)
        return PCTNode(perception=perc, reference=ref, history=history, name=name, namespace=namespace)
        
        
    def __call__(self, verbose=False):

        for ctr in range(len(self.preCollection)):
            func = self.preCollection[ctr]
            func(verbose)
            if self.prepost_data != None:                
                self.prepost_data.add_data(func)
                if ctr == 0:
                    if hasattr(func, 'reward_sum') :
                        self.prepost_data.add_reward_sum(func)
                    if hasattr(func, 'reward') :
                        self.prepost_data.add_reward(func)
                    if hasattr(func, 'fitness'):
                        self.prepost_data.add_fitness(func)

        if verbose:
            print()

        if self.order==None:
            for level in range(len(self.hierarchy)):
                for col in range(len(self.hierarchy[level])):
                    node  = self.hierarchy[level][col]
                    if verbose:
                        print(node.get_name(), end =" ")
                    node(verbose)
        elif self.order=="Down":
            for level in range(len(self.hierarchy)-1, -1, -1):
                for col in range(len(self.hierarchy[level])-1, -1, -1):
                    node  = self.hierarchy[level][col]
                    if verbose:
                        print(node.get_name(), end =" ")
                    node(verbose)                       
        else:
            for node_name in self.order:
                if verbose:
                    print(node_name, end =" ")
                FunctionsList.getInstance().get_function(self.namespace, node_name)(verbose)
        
        for func in self.postCollection:
            func(verbose)          
            if self.prepost_data != None:
                self.prepost_data.add_data(func)
        
        output = self.get_output_function().get_value()
        
        if self.error_collector != None:
            self.error_collector.add_data(self)
        
        if verbose:
            print()
        
        return output
        

    def is_fitness_close_to_zero(self):
        return self.get_environment().is_fitness_close_to_zero()

    def get_environment_score(self):
        return self.get_environment().get_environment_score()
    
    # def is_environment_terminated(self):
    #     return self.get_environment().is_environment_terminated()

    def set_name(self, name):
        self.name=name    
    
    def get_prepost_data (self):
        return self.prepost_data 
    
    def set_order(self, order):
        self.order=order
        
    def get_output_function(self):
        if len(self.postCollection) > 0:
            return self.postCollection[-1]
        
        return self.hierarchy[-1][-1].get_output_function()
        
    def add_preprocessor(self, func):
        self.preCollection.append(func)
        
    def add_postprocessor(self, func):
        self.postCollection.append(func)

    def get_preprocessor(self):
        return self.preCollection
        
    def get_postprocessor(self):
        return self.postCollection

    def is_environment_resolved(self):
        environment = self.get_environment()
        if hasattr( environment, 'is_environment_resolved' ) and callable( environment.is_environment_resolved ):
            return self.get_environment().is_environment_resolved()

        return None

    def get_environment(self):
        return self.get_preprocessor()[0]
    
    def has_environment(self):
        return len(self.preCollection) > 0 and isinstance(self.preCollection[0], ControlEnvironment)

    def get_environment_name(self):
        return self.get_preprocessor()[0].get_name()

    def set_run_steps(self, steps):
        if self.has_environment():
            env = self.get_environment()
            if hasattr(env, 'set_run_steps'):
                env.set_run_steps(steps)

    def set_current_step(self, step):
        if self.has_environment():
            env = self.get_environment()
            if hasattr(env, 'set_current_step'):
                env.set_current_step(step)

    def run(self, steps=1, verbose=False):
        self.set_run_steps(steps)
        for i in range(steps):
            self.step = i
            try:
                if verbose:
                    print(f'[{i}]', end=' ')
                self.set_current_step(i)
                out = self(verbose)
            except Exception as ex:
                # if self.error_collector != None:
                #     print(f'<{i} {self.error_collector.error()}>')
                if ex.__str__().startswith('1000'):
                    self.error_collector.override_value()
                    if verbose:
                        print(f'Current score={self.error_collector.error()}')                    
                    return False
                elif ex.__str__().startswith('1001'):
                    return False

                raise ex

            if self.error_collector:
                if verbose:
                    print(f'Current score={self.error_collector.error()}')
            
                if self.history :
                    self.prepost_data.add_value('error', self.error_collector.error())
            
                if self.error_collector.is_terminated():
                #     print(f'<{i} {self.error_collector.error()}>')
                    return out
                    
        # if self.error_collector != None:
        #     print(f'<{i} {self.error_collector.error()}>')
        return out
    
    def last_step(self):
        return self.step
        
    def get_node(self, level, col):
        return self.hierarchy[level][col]
    
    def get_error_collector(self):
        return self.error_collector

    def set_error_collector(self, error_collector):
        self.error_collector = error_collector

    def handle_perception_links(self, node, level, col, links_type):
        if level == 0 or links_type == None:
            return
        
        if links_type == "single":
            node.add_link("perception", self.hierarchy[level-1][col].get_function("perception"))
        
        if links_type == "dense":
            for column in range(len(self.hierarchy[level-1])):
                node.add_link("perception", self.hierarchy[level-1][column].get_function("perception"))

    def handle_reference_links(self, thisnode, level, col, links_type):
        if level == 0 or links_type == None:
            return
        
        if links_type == "single":
            thatnode = self.hierarchy[level-1][col]
            thatnode.add_link("reference", thisnode.get_function("output"))
        
        if links_type == "dense":
            for column in range(len(self.hierarchy[level-1])):
                thatnode = self.hierarchy[level-1][column]
                thatnode.add_link("reference", thisnode.get_function("output"))

    def get_grid(self):
        return [self.get_columns(level) for level in range(self.get_levels())]

    def get_node_positions(self, align='horizontal'):
        """Return networkx multipartite-layout positions for the hierarchy graph.

        Args:
            align: alignment passed to ``networkx.multipartite_layout``.

        Returns:
            The dict of node positions produced by ``multipartite_layout`` (layers taken
            from each graph node's "layer" attribute).
        """
        # networkx is not imported at module level (draw/graph import it locally);
        # without this import the method raised NameError.
        import networkx as nx
        graph = self.graph()
        return nx.multipartite_layout(graph, subset_key="layer", align=align)

    def get_level(self, level):

        return [ self.hierarchy[level] ]

    
    def get_top_level(self):
        levels = self.get_levels()

        return [ self.hierarchy[levels-1] ]
            
    def draw(self, with_labels=True, with_edge_labels=False,  font_size=12, font_weight='bold', font_color='black', 
            color_mapping={'PL':'aqua','OL':'limegreen','CL':'goldenrod', 'RL':'red', 'I':'silver', 'A':'yellow'},
            node_size=500, arrowsize=25, align='horizontal', file=None, figsize=(8,8), move={}, draw_fig=True,
            node_color=None, layout={'r':2,'c':1,'p':2, 'o':0}, funcdata=False, interactive_mode=False, experiment=None):
        
        if not draw_fig :
            return None
        import networkx as nx
        import matplotlib.pyplot as plt
        import plotly.tools as tls
        if not interactive_mode:
            plt.switch_backend('agg')
        self.graphv = self.graph(layout=layout, funcdata=funcdata)
        if node_color==None:
            node_color = self.get_colors(self.graphv, color_mapping)

        pos = nx.multipartite_layout(self.graphv, subset_key="layer", align=align)
        
        for key in move.keys():            
            pos[key][0]+=move[key][0]
            pos[key][1]+=move[key][1]
        
        fig = plt.figure(figsize=figsize) 

        if with_edge_labels:
            edge_labels = self.get_edge_labels_wrapper(funcdata)
            nx.draw_networkx_edge_labels(self.graphv, pos=pos, edge_labels=edge_labels, font_size=font_size, 
                font_weight=font_weight, font_color='red', horizontalalignment='left')
            
        nx.draw(self.graphv, pos=pos, with_labels=with_labels, font_size=font_size, font_weight=font_weight, 
                font_color=font_color, node_color=node_color,  node_size=node_size, arrowsize=arrowsize)
        
        plt.title(self.name)
        plt.tight_layout()

        if file:
            plt.savefig(file)

        if experiment:
            experiment.log_figure(figure_name=self.name,figure=fig)
                # experiment.log_image(file)
                # plotly_fig = tls.mpl_to_plotly(fig)
                # plotly_fig.write_html(file)
                # experiment.log_html(open(file,encoding='utf-8').read()) # added ,encoding='utf-8'

        return fig


    def get_colors(self, graph, color_mapping):
        colors=[]
        for node in graph:
            color = 'darkorchid'
            for key in color_mapping.keys():                
                if node.startswith(key):
                    color = color_mapping[key]
                    break
            colors.append(color)
        return colors
            
    def reset(self):
        for func in self.preCollection:
            func.set_value(0)               

        for func in self.postCollection:
            func.reset_value()

        for level in self.hierarchy:
            for node in level:
                node.reset()

    def remove_links(self):
        # remove links with weights of 0
        for func in self.postCollection:
            func.remove_links()
                    
        for func in self.preCollection:
            func.remove_links()
            
        for level in self.hierarchy:
            for node in level:
                node.remove_links()

    def list_link_names(self):
        link_names=[]
        for func in self.postCollection:
            for link in func.links:
                if isinstance(link, str):
                    link_names.append(link)
                else:
                    link_names.append(link.get_name())
                    
        for func in self.preCollection:
            for link in func.links:
                if isinstance(link, str):
                    link_names.append(link)
                else:
                    link_names.append(link.get_name())
            
        for level in self.hierarchy:
            for node in level:
                node.list_link_names(link_names)

        return link_names


    def consolidate(self):
        self.remove_links()

        # for outputs, comparators and references
        linklist = self.list_link_names()
        for level in self.hierarchy:
            for node in level:
                node.consolidate(linklist)

        # for perceptions
        linklist = self.list_link_names()
        for level in self.hierarchy:
            for node in level:
                node.consolidate(linklist)
        
        level_ctr = 0
        for level in self.hierarchy:
            invalid_nodes = []
            ctr = 0
            for node in level:
                if node.is_empty():
                    invalid_nodes.append(ctr)
                else:   
                    if level_ctr == self.get_levels()-1 :
                        if  node.is_reference_empty():
                            invalid_nodes.append(ctr)

                ctr+=1
            for node_ctr in reversed(invalid_nodes):    
                del level[node_ctr]

            level_ctr+=1
            
    
            
    def reset_checklinks(self, val=True):
        for func in self.postCollection:
            func.reset_checklinks(val)
                    
        for func in self.preCollection:
            func.reset_checklinks(val)
            
        for level in self.hierarchy:
            for node in level:
                node.reset_checklinks(val)
                
    def get_edge_labels_wrapper(self, funcdata=False):
        if funcdata:
            return self.get_edge_labels_funcdata()
        else:
            return self.get_edge_labels()

        
    def get_edge_labels_funcdata(self):
        labels={}
    
        for func in self.postCollection:
            func.get_weights_labels_funcdata(labels)
                    
        for func in self.preCollection:
            func.get_weights_labels_funcdata(labels)
            
        for level in self.hierarchy:
            for node in level:
                node.get_edge_labels_funcdata(labels)
                
        return labels
        
        
    def get_edge_labels(self):
        labels={}
    
        for func in self.postCollection:
            func.get_weights_labels(labels)
                    
        for func in self.preCollection:
            func.get_weights_labels(labels)
            
        for level in self.hierarchy:
            for node in level:
                node.get_edge_labels(labels)
                
        return labels

    def change_namespace(self):
        """Move the hierarchy and all its functions/nodes into a fresh ``uuid1`` namespace.

        The hierarchy name is re-registered with ``UniqueNamer`` in the new namespace.
        """
        fresh = uuid.uuid1()
        self.namespace = fresh
        self.name = UniqueNamer.getInstance().get_name(namespace=fresh, name=self.name)

        for func in [*self.postCollection, *self.preCollection]:
            func.change_namespace(fresh)
        for node in (n for row in self.hierarchy for n in row):
            node.change_namespace(fresh)
                
    
    def get_graph(self):
        return self.graphv
    
    def clear_graph(self):
        self.graphv.clear()

    def graph(self, layout={'r':2,'c':1,'p':2, 'o':0}, funcdata=False):
        """Build and return a networkx DiGraph of the hierarchy.

        Args:
            layout: forwarded to the node graph-data methods.
            funcdata: use ``set_graph_data_funcdata`` instead of ``set_graph_data``.
        """
        import networkx as nx
        digraph = nx.DiGraph()
        populate = self.set_graph_data_funcdata if funcdata else self.set_graph_data
        populate(digraph, layout=layout)
        return digraph
    
    
    def set_graph_data(self, graph, layout={'r':2,'c':1,'p':2, 'o':0}):
        layer=0
        if len(self.preCollection)>0 or len(self.postCollection)>0:
            layer=1
            
        for func in self.postCollection:
            func.set_graph_data(graph, layer=0)  

        for func in self.preCollection:
            func.set_graph_data(graph, layer=0)   
                    
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])-1, -1, -1):
            #for col in range(len(self.hierarchy[level])):
                self.hierarchy[level][col].set_graph_data(graph, layer=layer, layout=layout)
            layer+=3

            
            
    def set_graph_data_funcdata(self, graph, layout={'r':2,'c':1,'p':2, 'o':0}):
        layer=0
        if len(self.preCollection)>0 or len(self.postCollection)>0:
            layer=1
            
        for func in self.postCollection:
            func.set_graph_data_funcdata(graph, layer=0)  

        for func in self.preCollection:
            func.set_graph_data_funcdata(graph, layer=0)   
                    
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])-1, -1, -1):
            #for col in range(len(self.hierarchy[level])):
                self.hierarchy[level][col].set_graph_data_funcdata(graph, layer=layer, layout=layout)
            layer+=3
            
            
            
#     def draw_nodes(self, with_labels=True, with_edge_labels=False,  font_size=12, font_weight='bold', node_color=None,  
#          color_mapping={'L':'red', 'I':'silver', 'A':'yellow'},
#          node_size=500, arrowsize=25, align='horizontal', file=None, figsize=(8,8), move={}):
#         graph = self.graph_nodes()
#         if node_color==None:
#             node_color = self.get_colors(graph, color_mapping)

#         pos = nx.multipartite_layout(graph, subset_key="layer", align=align)

#         for key in move.keys():            
#             pos[key][0]+=move[key][0]
#             pos[key][1]+=move[key][1]

#         plt.figure(figsize=figsize) 
#         if with_edge_labels:
#             edge_labels = self.get_edge_labels_nodes()
#             nx.draw_networkx_edge_labels(graph, pos=pos, edge_labels=edge_labels, font_size=font_size, font_weight=font_weight, 
#                 font_color='red')
#         nx.draw(graph, pos=pos, with_labels=with_labels, font_size=font_size, font_weight=font_weight, 
#                 node_color=node_color,  node_size=node_size, arrowsize=arrowsize)

#         if file != None:
#             plt.title(self.name)
#             plt.savefig(file)

    def get_edge_labels_nodes(self, node_list):
        labels={}

        for func in self.postCollection:
            func.get_weights_labels_nodes(labels, node_list)

        for func in self.preCollection:
            func.get_weights_labels_nodes(labels, node_list)

        for level in self.hierarchy:
            for node in level:
                node.get_edge_labels(labels)

        return labels

    def change_link_name(self, old_name, new_name):
        for func in self.postCollection:
            func.links = [new_name if i==old_name else i for i in func.links ]

        for func in self.preCollection:
            func.links = [new_name if i==old_name else i for i in func.links ]

        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                self.hierarchy[level][col].change_link_name(old_name, new_name)
        
    def set_suffixes(self):
        """Append each function's suffix to its name and re-key the namespace registry.

        Pass 1: every BaseFunction in this namespace with a non-empty suffix is renamed to
        ``name + suffix``, and links referring to the old registry key are renamed too.
        Pass 2: registry entries whose key no longer equals the function's name are moved
        to the new key.
        """
        functions_list = FunctionsList.getInstance()

        for key in functions_list.functions[self.namespace].keys():
            func = functions_list.get_function(self.namespace, key)
            if not isinstance(func, BaseFunction):
                continue
            base_name = func.get_name()
            suffix = func.get_suffix()
            if len(suffix) > 0:
                func.name = base_name + suffix
                self.change_link_name(key, func.name)

        # snapshot the keys: the registry is modified while re-keying
        for key in list(functions_list.functions[self.namespace].keys()):
            func = functions_list.get_function(self.namespace, key)
            if not isinstance(func, BaseFunction):
                continue
            new_key = func.get_name()
            if key != new_key:
                moved = functions_list.functions[self.namespace].pop(key)
                functions_list.functions[self.namespace][new_key] = moved


    def get_levels(self):
        return len(self.hierarchy)
    
    def get_columns(self, level):
        return len(self.hierarchy[level])

#     def graph_nodes(self):
#         graph = nx.DiGraph()

#         self.set_graph_data_nodes(graph)

#         return graph

#     def set_graph_data_nodes(self, graph):
#         layer=0
#         if len(self.preCollection)>0 or len(self.postCollection)>0:
#             layer=1

#         node_list={}
#         for level in range(len(self.hierarchy)):
#             for col in range(len(self.hierarchy[level])-1, -1, -1):
#                 node = self.hierarchy[level][col]
#                 node.get_node_list(node_list)

#         for func in self.preCollection:
#             node_list[func.get_name()] = func.get_name()

#         for func in self.postCollection:
#             node_list[func.get_name()] = func.get_name()

#         for func in self.postCollection:
#             func.set_graph_data_node(graph, layer=0, node_list=node_list)

#         for func in self.preCollection:
#             func.set_graph_data_node(graph, layer=0, node_list=node_list)

#         edges = []
#         for level in range(len(self.hierarchy)):
#             for col in range(len(self.hierarchy[level])-1, -1, -1):
#                 node = self.hierarchy[level][col]
#                 graph.add_node(node.get_name(), layer=level+layer)

#                 for func in node.referenceCollection:
#                     for link in func.links:
#                         if isinstance(link, str):
#                             name=link
#                         else:
#                             name = link.get_name()                            
#                         edges.append((node_list[name],node.get_name()))

#                 for func in node.perceptionCollection:
#                     for link in func.links:
#                         if isinstance(link, str):
#                             name=link
#                         else:
#                             name = link.get_name()                            
#                         edges.append((node_list[name],node.get_name()))
                        
#         graph.add_edges_from(edges)

    def validate_links(self):
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                node = self.hierarchy[level][col]
                ref = node.get_function_from_collection(HPCTFUNCTION.REFERENCE)
                ref_name = ref.get_name()
                target_level = level+1
                links = ref.get_links()
                for link in links:
                    if isinstance(link, str):
                        link_name = link
                    else:
                        link_name = link.get_name()
                    link_level = link_name[2:3]
                    if target_level != eval(link_level):
                        msg = f'Ref {ref_name} link level for {link_name} different to {target_level}'
                        print(msg)
                        self.summary()
                        raise Exception(msg)
                    
        return True
                    
    
    def build_links(self):
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                self.hierarchy[level][col].build_links()
        

    def clear_values(self):
        for func in self.postCollection:
            func.value = 0

        for func in self.preCollection:
            func.value = 0
                    
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                self.hierarchy[level][col].clear_values()

    def error(self):
        error = 0
        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                    error += self.hierarchy[level][col].get_function("comparator").get_value()
        return error

    def insert_level(self, level):
        cols_list=[]
        self.hierarchy.insert(level, cols_list)
        
    def remove_level(self, lvl):
        level = self.hierarchy.pop(lvl)
        for node in level:
            node.delete_contents()
            del node
        del level


    def remove_nodes(self, level, num_nodes):        
        for _ in range(num_nodes):
            del self.hierarchy[level][-1]

    def remove_node(self, level, num):        
        del self.hierarchy[level][num]
            
    def get_summary(self):
        str = "**************************\nPRE: "
        
        for func in self.preCollection:
            str += func.output_string() + " "
        str += "\n"
        
        if self.order==None:
            for level in range(len(self.hierarchy)):
                for col in range(len(self.hierarchy[level])):
                    str += f'L{level}C{col} '
                    node  = self.hierarchy[level][col]
                    str += node.get_summary() + " \n"
        elif self.order=="Down":
            for level in range(len(self.hierarchy)-1, -1, -1):
                for col in range(len(self.hierarchy[level])-1, -1, -1):
                    str += f'L{level}C{col} '
                    node  = self.hierarchy[level][col]
                    str += node.get_summary() + " \n"
        
        str += "POST: "

        for func in self.postCollection:
            str += func.output_string()+ " "
        str += "\n"
        
        return str

                                    
    def get_namespace(self):
        return self.namespace

    def check_namespace(self):

        namespace = self.namespace
        for func in self.preCollection:
            func.check_namespace(higher_namespace=namespace)   

        if self.order==None:
            for level in range(len(self.hierarchy)):
                for col in range(len(self.hierarchy[level])):
                    self.hierarchy[level][col].check_namespace(higher_namespace=namespace)
        elif self.order=="Down":
            for level in range(len(self.hierarchy)-1, -1, -1):
                for col in range(len(self.hierarchy[level])-1, -1, -1):
                    self.hierarchy[level][col].check_namespace(higher_namespace=namespace)
                                            
        for func in self.postCollection:
            func.check_namespace(higher_namespace=namespace) 



            
    def summary(self, build=False, extra=False, check_namespace=False):
        print("**************************")
        print(self.name, type(self).__name__, self.get_grid(), self.namespace)                
        print("--------------------------")
        print("PRE:", end=" ")
        if check_namespace:
            namespace = self.namespace
        else:
            namespace = None
        if len(self.preCollection) == 0:
            print("None")
        for func in self.preCollection:
            func.summary(extra=extra, higher_namespace=namespace)   
        
            
        if self.order==None:
            for level in range(len(self.hierarchy)):
                print(f'Level {level} Cols {self.get_columns(level)}')
                for col in range(len(self.hierarchy[level])):
                    self.hierarchy[level][col].summary(build=build, extra=extra, higher_namespace=namespace)
        elif self.order=="Down":
            for level in range(len(self.hierarchy)-1, -1, -1):
                print(f'Level {level} Cols {self.get_columns(level)}')
                for col in range(len(self.hierarchy[level])-1, -1, -1):
                    self.hierarchy[level][col].summary(build=build, extra=extra, higher_namespace=namespace)
                                            
        print("POST:", end=" ")
        if len(self.postCollection) == 0:
            print("None")
        for func in self.postCollection:
            func.summary(extra=extra, higher_namespace=namespace)   


        print("**************************")
            
    def save(self, file=None, indent=4):
        import json
        jsondict = json.dumps(self.get_config(), indent=indent)
        f = open(file, "w")
        f.write(jsondict)
        f.close()
        
    @classmethod
    def load(cls, file, clear=True, namespace=None):
        if clear:
            FunctionsList.getInstance().clear()

        with open(file) as f:
            config = json.load(f)
        return cls.from_config(config, namespace=namespace)
                
    def get_config(self, zero=1):
        config = {"type": type(self).__name__,
                    "name": self.name}        
        
        pre = {}
        for i in range(len(self.preCollection)):
            pre[f'pre{i}']=self.preCollection[i].get_config(zero=zero)
        config['pre']=pre

        
        levels = {}
        for lvl in range(len(self.hierarchy)):
            level ={'level':lvl}
            columns={}
            for col in range(len(self.hierarchy[lvl])):
                column={'col':col}
                if not self.hierarchy[lvl][col].is_empty():
                    nodeconfig = self.hierarchy[lvl][col].get_config(zero=zero)
                    #print(nodeconfig)
                    column['node']=nodeconfig
                    #print(column)

                columns[f'col{col}']=column
            level['nodes']=columns
            levels[f'level{lvl}']=level
        config['levels']=levels
        
        post = {}
        for i in range(len(self.postCollection)):
            post[f'post{i}']=self.postCollection[i].get_config(zero=zero)
        config['post']=post
        return config       

    
    @classmethod
    def from_config(cls, config, history=False, namespace=None):
        """Build a hierarchy from a configuration dict such as the one ``get_config`` returns.

        Args:
            config: dict with ``'name'``, ``'pre'``, ``'levels'`` and ``'post'`` entries.
                Column entries without a ``'node'`` key, which ``get_config`` writes for
                empty nodes, are rebuilt as placeholder nodes.
            history: passed to the new hierarchy and to every node.
            namespace: namespace for function names; when None, the hierarchy
                constructor chooses one.

        Returns:
            The reconstructed PCTHierarchy.
        """
        # Pass history so the hierarchy sets up its pre/post data store, matching
        # from_config_with_environment.
        hpct = PCTHierarchy(name=config['name'], namespace=namespace, history=history)
        namespace = hpct.namespace

        preCollection = []
        PCTNode.collection_from_config(preCollection, config['pre'], namespace=namespace)

        postCollection = []
        PCTNode.collection_from_config(postCollection, config['post'], namespace=namespace)

        hpct.preCollection = preCollection
        hpct.postCollection = postCollection

        hpct.hierarchy = []
        levels = config['levels']

        # do in order of perceptions from bottom
        # then from top references, comparator and output
        for level_value in levels.values():
            cols = []
            for nodes_value in level_value['nodes'].values():
                if 'node' in nodes_value:
                    node = PCTNode.from_config(nodes_value['node'], namespace=namespace, perception=True, history=history)
                else:
                    # Placeholder built the same way as in from_config_with_environment.
                    node = PCTNode(default=False, namespace=namespace, history=history)
                cols.append(node)
            hpct.hierarchy.append(cols)

        for level_value in reversed(list(levels.values())):
            for nodes_value in reversed(list(level_value['nodes'].values())):
                if 'node' not in nodes_value:
                    continue
                node = hpct.get_node(level_value['level'], nodes_value['col'])
                PCTNode.from_config(config=nodes_value['node'], namespace=namespace, reference=True, comparator=True, output=True, node=node, history=history)

        return hpct
    

    def add_node(self, node, level=-1, col=-1):
        
        if len(self.hierarchy)==0:
            self.hierarchy.append([])

        if level<0 and col<0:
            self.hierarchy[0].append(node)
        else:
            levels = len(self.hierarchy)
            if level == levels:
                self.hierarchy.append([])      
            self.hierarchy[level].insert(col, node)
        
    def insert_function(self, level=None, col=None, collection=None, function=None, position=-1):
        self.hierarchy[level][col].insert_function(collection, function, position)

    def replace_function(self, level=None, col=None, collection=None, function=None, position=-1):
        self.hierarchy[level][col].replace_function(collection, function, position)

    def get_function(self, level=None, col=None, collection=None, position=-1):
        return self.hierarchy[level][col].get_function(collection, position)

    def set_links(self, func_name, *link_names):
        """Replace the links of function ``func_name`` with the functions named in ``link_names``.

        Every name is resolved in this hierarchy's namespace through the
        ``FunctionsList`` registry. Existing links are cleared first, then the new
        links are added in the order given.
        """
        owner = FunctionsList.getInstance().get_function(self.namespace, func_name)
        owner.clear_links()
        for name in link_names:
            linked = FunctionsList.getInstance().get_function(self.namespace, name)
            owner.add_link(linked)
            
    def add_links(self, func_name, *link_names):
        """Append links from function ``func_name`` to each function named in ``link_names``.

        Names are resolved in this hierarchy's namespace through the
        ``FunctionsList`` registry. Existing links are kept. ``func_name`` is only
        looked up when at least one link name is given.
        """
        for name in link_names:
            owner = FunctionsList.getInstance().get_function(self.namespace, func_name)
            linked = FunctionsList.getInstance().get_function(self.namespace, name)
            owner.add_link(linked)
            
            
    def get_history_data(self):
        history_data = self.get_prepost_data()
        #for key in history_data.data.keys():
        #    print(key)

        for level in range(len(self.hierarchy)):
            for col in range(len(self.hierarchy[level])):
                node = self.get_node(level,col)
                for key in node.history.data['refcoll'].keys():
                    #print(key)
                    history_data.add_list(key,node.history.data['refcoll'][key])
                for key in node.history.data['percoll'].keys():
                    #print(key)
                    history_data.add_list(key,node.history.data['percoll'][key])
                for key in node.history.data['comcoll'].keys():
                    #print(key)
                    history_data.add_list(key,node.history.data['comcoll'][key])
                for key in node.history.data['outcoll'].keys():
                    #print(key)
                    history_data.add_list(key,node.history.data['outcoll'][key])

        return history_data.data            
    
    def hierarchy_plots(self, title='plot', plot_items=None, figsize=(15,4), file=None, experiment=None, history=None):
        """Plot selected history signals together on one new figure.

        Args:
            title: figure title. Also the figure name when logging to ``experiment``.
            plot_items: mapping of history key -> legend label for each series to draw.
                Defaults to no series.
            figsize: matplotlib figure size.
            file: if not None, the figure is saved to this path.
            experiment: if truthy, ``experiment.log_figure`` is called with the figure.
            history: mapping of signal name -> sequence of values. Defaults to
                ``self.get_history_data()``.

        Returns:
            The matplotlib figure. It is not closed here.

        Note:
            ``style.use('fivethirtyeight')`` changes matplotlib's global style.
        """
        from matplotlib import style
        import matplotlib.pyplot as plt

        if plot_items is None:
            plot_items = {}
        # `is None` (not `== None`) so array-like or DataFrame histories are not compared elementwise.
        if history is None:
            history = self.get_history_data()

        style.use('fivethirtyeight')

        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_subplot(1, 1, 1)

        for key, label in plot_items.items():
            series = history[key]
            # Per-series x axis: signals of different lengths can share a figure, and
            # an empty history no longer fails when nothing is plotted.
            ax1.plot(list(range(len(series))), series, label=label)

        plt.title(title)
        plt.legend()

        if file is not None:
            plt.savefig(file)

        if experiment:
            experiment.log_figure(figure_name=title, figure=fig)

        return fig

    def hierarchy_plot(self, title='plot', plot_items={}, figsize=(15,4), file=None, experiment=None, history=None):
        from matplotlib import style
        import matplotlib.pyplot as plt
        if history == None:
            history = self.get_history_data()

        num_items = len(history[list(history.keys())[0]])
        #x = np.linspace(0, num_items-1, num_items)
        x =  [i for i in range(num_items)]
        style.use('fivethirtyeight')

        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_subplot(1,1,1)

        for key in plot_items.keys():    
            ax1.plot(x, history[key], label=plot_items[key])

        # if experiment or file:
        plt.title(title)
        plt.legend()

        if file != None:
            plt.savefig(file)
            
        # plt.show()
        if experiment:
            experiment.log_figure(figure_name=title,figure=fig)

        return fig


    def get_parameters_list(self):

        pre = []
        post = []

        for func in self.preCollection:
            pre.append(func.get_parameters_list())
        
        for func in self.postCollection:
            post.append(func.get_parameters_list())
                    
        lowest = [pre, post]

        hpct=[lowest]
        
        for level in self.hierarchy:
            level_list=[]
            for node in level:
                level_list.append(node.get_parameters_list())
            hpct.append(level_list)
                
        return hpct
    
    @classmethod
    def from_config_with_environment(cls, config, seed=None, history=False, suffixes=False, environment_properties=None):
        """Create an individual from a provided configuration.

        The ``'pre0'`` entry of ``config['pre']`` describes the environment. It is
        created through ``EnvironmentFactory`` from its ``'type'`` and ``'env_name'``.
        It then receives ``environment_properties`` and the links listed in its
        ``'links'`` mapping, and becomes the first preprocessor. The remaining
        ``pre`` entries, the nodes in ``config['levels']`` and the ``post`` entries
        are rebuilt in the new hierarchy's namespace.

        Args:
            config: configuration dict in the ``get_config`` layout. It is not modified.
            seed: passed to the environment factory.
            history: passed to the hierarchy and to every node.
            suffixes: when true, ``set_suffixes()`` is called on the result.
            environment_properties: passed to the environment's ``set_properties``.

        Returns:
            The new PCTHierarchy.
        """
        hpct = PCTHierarchy(history=history)
        namespace = hpct.namespace

        # Work on a copy. Popping 'pre0' from config['pre'] itself would mutate the
        # caller's config, so building again from the same dict would raise KeyError.
        coll_dict = dict(config['pre'])
        env_dict = coll_dict.pop('pre0')

        env = EnvironmentFactory.createEnvironment(env_dict['type'], namespace=namespace, seed=seed, gym_name=env_dict['env_name'])
        env.set_properties(environment_properties)
        for link in env_dict['links'].values():
            env.add_link(link)
        preCollection = [env]
        PCTNode.collection_from_config(preCollection, coll_dict, namespace=namespace)

        hpct.preCollection = preCollection

        hpct.hierarchy = []
        levels = config['levels']

        # do in order of perceptions from bottom
        # then from top references, comparator and output
        for level_value in levels.values():
            cols = []
            for nodes_value in level_value['nodes'].values():
                if 'node' in nodes_value:
                    node = PCTNode.from_config(nodes_value['node'], namespace=namespace, perception=True, history=history)
                else:
                    # get_config omits 'node' for empty nodes; keep a placeholder at that position.
                    node = PCTNode(default=False, namespace=namespace, history=history)
                cols.append(node)
            hpct.hierarchy.append(cols)

        for level_value in reversed(list(levels.values())):
            for nodes_value in reversed(list(level_value['nodes'].values())):
                node = hpct.get_node(level_value['level'], nodes_value['col'])
                if 'node' in nodes_value:
                    PCTNode.from_config(config=nodes_value['node'], namespace=namespace, reference=True, comparator=True, output=True, node=node, history=history)

        postCollection = []
        PCTNode.collection_from_config(postCollection, config['post'], namespace=namespace)
        hpct.postCollection = postCollection

        if suffixes:
            hpct.set_suffixes()
        return hpct
    
    def formatted_config(self, places=3):
        """Return a human-readable dump of the hierarchy's parameters.

        Layout:
        - a ``grid:`` line from ``get_grid()``;
        - an ``env: <pre params> act: <post params>`` line;
        - per level, a ``level<n>`` header followed by one
          ``col: <c> ref: ... per: ... out: ...`` entry per column. These use the
          first three items of each node's parameter list.

        Numbers are formatted with ``floatListsToString(..., places)``. Columns of
        the last level are not newline-terminated.
        """
        params = self.get_parameters_list()
        final_index = len(params) - 1
        pieces = [f'grid: {self.get_grid()}\n']

        for index, entry in enumerate(params):
            if index == 0:
                # entry is [pre-processor params, post-processor params]
                pieces.extend([f'env: {entry[0]} act: ', floatListsToString(entry[1], places), '\n'])
                continue

            pieces.append(f'level{index - 1} \n')
            for column, node_params in enumerate(entry):
                pieces.extend([
                    f'col: {column} ',
                    'ref: ', floatListsToString(node_params[0], places),
                    ' per: ', floatListsToString(node_params[1], places),
                    ' out: ', floatListsToString(node_params[2], places),
                ])
                if index < final_index:
                    pieces.append('\n')

        return ''.join(pieces)


    def get_plots_config(self, plots, title_prefix=""):
        if isinstance(plots, list):
            return plots

        def create_named_plot_item(name=None, ptitle=None):
            plot_item = {}
            signals = {}
            signals[name] = name
            plot_item['plot_items'] = signals
            title = ptitle if ptitle else name
            plot_item['title'] =  f'{title_prefix}{title}'
            return plot_item

        def create_plot_item(func1, func2=None, ptitle=None):
            plot_item = {}
            signals = {}
            if func2:
                signals[func2.get_name()] = func2.get_name()
                signals[func1.get_name()] = func1.get_name()
            else:
                signals[func1.get_name()] = func1.get_name()
            plot_item['plot_items'] = signals
            title = ptitle if ptitle else func1.get_name()
            plot_item['title'] =  f'{title_prefix}{title}'
            return plot_item

        plots_list = []
        top_done=False
        plot_items = plots.split(',')
        for plot_item in plot_items:
            if plot_item == 'scTop':
                top_done=True
                for level in self.get_top_level():
                    if isinstance(level, list):
                        for node in level:
                            plots_list.append(create_plot_item(node.get_reference_function(), node.get_perception_function(), node.get_name()))
                    else:
                        plots_list.append(create_plot_item(level.get_reference_function(), level.get_perception_function(), level.get_name()))

            if plot_item == 'scEdges':
                top_done=True
                for func in self.get_preprocessor()[1:]:
                    plots_list.append(create_plot_item(func))

                for func in self.get_postprocessor():
                    plots_list.append(create_plot_item(func))

                for level in self.get_top_level():
                    if isinstance(level, list):
                        for node in level:
                            plots_list.append(create_plot_item(node.get_reference_function(), node.get_perception_function(), node.get_name()))
                    else:
                        plots_list.append(create_plot_item(level.get_reference_function(), level.get_perception_function(), level.get_name()))

            if plot_item == 'scZero':
                if self.get_levels() == 1 and top_done:
                    pass
                else:
                    for level in self.get_level(0):
                        if isinstance(level, list):
                            for node in level:
                                if node.has_reference_function():
                                    plots_list.append(create_plot_item(node.get_reference_function(), node.get_perception_function(), node.get_name()))
                        else:
                            if level.has_reference_function():
                                plots_list.append(create_plot_item(level.get_reference_function(), level.get_perception_function(), level.get_name()))

            if plot_item == 'scFitness':
                plots_list.append(create_named_plot_item('fitness', 'Fitness'))

            if plot_item == 'scReward':
                plots_list.append(create_named_plot_item('reward', 'Reward'))
                plots_list.append(create_named_plot_item('reward_sum', 'RewardSum'))

            if plot_item == 'scError':
                plots_list.append(create_named_plot_item('error', 'Error'))


        return plots_list

    


                    
    ## run_from_file
    @classmethod
    def run_from_file(cls, filename, min=None, env_props=None, seed=None, render=False, history=False, move=None, plots=None, hpct_verbose= False,
                    runs=None, plots_dir=None, early_termination = None, draw_file=None, experiment=None, log_experiment_figure=False, suffixes=False,
                    enhanced_environment_properties=None, title_prefix="", video=False):
        """Load a saved run-properties file and run the hierarchy it describes.

        Settings the caller does not supply are read from the file:
        ``environment_properties`` (unless ``env_props`` is given), ``runs``,
        ``seed`` and ``early_termination``. The error-collector settings and the
        config always come from the file. ``enhanced_environment_properties`` is
        merged over the environment properties. ``video`` (when truthy) is added
        as the ``'video'`` property. The caller's ``env_props`` dict is never
        modified.

        When ``experiment`` is given, all properties except ``config`` are logged
        to it before running.

        Returns:
            ``(hierarchy, score)`` as returned by ``run_from_config``.
        """
        prp = PCTRunProperties()
        prp.load_db(filename)
        if experiment:
            # Log everything except the config, then put the config back.
            config = prp.db.pop('config')
            experiment.log_parameters(prp.db)
            prp.db['config'] = config
            if 'environment_properties' in prp.db:
                if 'history' in prp.db['environment_properties']:
                    ep = eval(prp.db['environment_properties'])
                    experiment.log_metric('history', ep['history'])

        # NOTE(security): values from the properties file are evaluated with eval();
        # only run files from trusted sources.
        error_collector_type = prp.db['error_collector_type'].strip()
        error_response_type = prp.db['error_response_type']
        error_limit = eval(prp.db['error_limit'])
        if env_props is None:
            environment_properties = eval(prp.db['environment_properties'])
        else:
            environment_properties = env_props
        error_properties = prp.get_error_properties()

        if enhanced_environment_properties is not None:
            environment_properties = environment_properties | enhanced_environment_properties
            if 'dataset' in environment_properties:
                if environment_properties['dataset'] == 'test':
                    # NOTE(review): assumes error_properties[1] is a mutable pair whose
                    # second item is the initial value. Confirm against get_error_properties.
                    error_properties[1][1] = environment_properties['initial']
        if runs is None:
            runs = eval(prp.db['runs'])
        config = eval(prp.db['config'])
        if seed is None:
            seed = eval(prp.db['seed'])
        if early_termination is None:
            early_termination = eval(prp.db['early_termination'])

        if video:
            # Build a new dict: environment_properties may be the caller's env_props.
            environment_properties = dict(environment_properties, video=video)

        hierarchy, score = cls.run_from_config(config, min=min, render=render,  error_collector_type=error_collector_type, error_response_type=error_response_type,
                                                error_properties=error_properties, error_limit=error_limit, steps=runs, hpct_verbose=hpct_verbose, history=history,
                                                environment_properties=environment_properties, seed=seed, early_termination=early_termination, move=move, plots=plots,
                                                suffixes=suffixes, plots_dir=plots_dir, draw_file=draw_file, experiment=experiment, log_experiment_figure=log_experiment_figure,
                                                title_prefix = title_prefix)

        return hierarchy, score

    def run_hierarchy(self, render=False, steps=500, hpct_verbose=False):
        env = self.get_preprocessor()[0]
        env.set_render(render)
        if hpct_verbose:
            self.summary()
            print(self.formatted_config())
        self.run(steps, hpct_verbose)
        env.close()

        score = self.get_environment_score() if self.get_environment_score() is not None else self.get_error_collector().error()

        return score    
 
    def draw_hierarchy(self,  draw_file=False, move=None, with_edge_labels=True, font_size=6, node_size=100, experiment=None, log_experiment_figure=False):
        # draw network file
        move = {} if move == None else move
        if experiment or draw_file:
            if log_experiment_figure:
                self.draw(file=draw_file, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size, experiment=experiment)
            else:
                self.draw(file=draw_file, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size)
            if draw_file:
                print(draw_file)

    def plot_hierarchy(self, plots=None,
        history=False, plots_figsize=(15,4), plots_dir=None, experiment=None, title_prefix=""):
        if history:
            if plots:
                plots = self.get_plots_config(plots, title_prefix)
                
                for plot in plots:
                    plotfile=None
                    if plots_dir:
                        plotfile = plots_dir + sep + plot['title'] + '-' + str(self.get_namespace()) + '.png'
                    fig = self.hierarchy_plots(title=plot['title'], plot_items=plot['plot_items'], figsize=plots_figsize, file=plotfile, experiment=experiment)
                    import matplotlib.pyplot as plt
                    plt.close(fig)  # Close the figure here

    @classmethod
    def run_from_config(cls, config, min=None, render=False,  error_collector_type=None, error_response_type=None, 
        error_properties=None, error_limit=100, steps=500, hpct_verbose=False, early_termination=False, 
        seed=None, draw_file=False, move=None, with_edge_labels=True, font_size=6, node_size=100, plots=None,
        history=False, suffixes=False, plots_figsize=(15,4), plots_dir=None, flip_error_response=False, 
        environment_properties=None, experiment=None, log_experiment_figure=False, title_prefix="", nevals=1    ):
        "Run an individual from a provided configuration."

        if callable(min):
            raise Exception("min must not be a function")


        hierarchy = cls.from_config_with_environment(config, seed=seed, history=history, suffixes=suffixes, environment_properties=environment_properties)
        env = hierarchy.get_preprocessor()[0]
        env.set_render(render)
        env.early_termination = early_termination
        env.reset(full=False, seed=seed)
        if error_collector_type is not None:
            error_collector = BaseErrorCollector.collector(error_response_type, error_collector_type, error_limit, min, properties=error_properties, flip_error_response=flip_error_response)
            hierarchy.set_error_collector(error_collector)
        if hpct_verbose:
            hierarchy.summary()
            print(hierarchy.formatted_config())
        score = 0

        for i in range(nevals):
            hierarchy.reset()
            env.reset(full=False, seed=seed+i)
            hierarchy.get_error_collector().reset()

            hierarchy.run(steps, hpct_verbose)
    
            # error_score=hierarchy.get_error_collector().error()
            # environment_score = hierarchy.get_environment_score()
            # current_error = environment_score if environment_score is not None else error_score
            current_error = hierarchy.get_environment_score() if hierarchy.get_environment_score() is not None else hierarchy.get_error_collector().error()

            score += current_error


        env.close()
        score = score / nevals
        # draw network file
        move = {} if move == None else move
        if experiment or draw_file:
            if log_experiment_figure:
                hierarchy.draw(file=draw_file, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size, experiment=experiment)
            else:
                hierarchy.draw(file=draw_file, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size)
            if draw_file:
                print(draw_file)
        
        if history:
            if plots:
                plots = hierarchy.get_plots_config(plots, title_prefix)
                
                for plot in plots:
                    plotfile=None
                    if plots_dir:
                        plotfile = plots_dir + sep + plot['title'] + '-' + str(hierarchy.get_namespace()) + '.png'
                    fig = hierarchy.hierarchy_plots(title=plot['title'], plot_items=plot['plot_items'], figsize=plots_figsize, file=plotfile, experiment=experiment)
                    import matplotlib.pyplot as plt
                    plt.close(fig)  # Close the figure here



        return hierarchy, score    
    

    @classmethod
    def load_from_file(cls, filename, min=None, env_props=None, seed=None, render=False, runs=None, early_termination = False, experiment=None,  hpct_verbose= False, history=False, additional_props=None):
        """Load (without running) the hierarchy described by a saved run-properties file.

        The error-collector settings and the config come from the file. So do the
        environment properties unless ``env_props`` is given, the seed when
        ``seed`` is None, and ``early_termination`` when it is explicitly None.
        ``additional_props`` is merged into a copy of the environment properties,
        so a caller-supplied ``env_props`` dict is not modified. When
        ``experiment`` is given, all properties except ``config`` are logged to it.

        Returns:
            ``(hierarchy, env, environment_properties)``, where the first two are
            as returned by ``load_from_config``.
        """
        prp = PCTRunProperties()
        prp.load_db(filename)
        if experiment:
            # Log everything except the config, then put the config back.
            config = prp.db.pop('config')
            experiment.log_parameters(prp.db)
            prp.db['config'] = config
            if 'environment_properties' in prp.db:
                if 'history' in prp.db['environment_properties']:
                    ep = eval(prp.db['environment_properties'])
                    experiment.log_metric('history', ep['history'])

        # NOTE(security): values from the properties file are evaluated with eval();
        # only load files from trusted sources.
        error_collector_type = prp.db['error_collector_type'].strip()
        error_response_type = prp.db['error_response_type']
        error_limit = eval(prp.db['error_limit'])
        if env_props is None:
            environment_properties = eval(prp.db['environment_properties'])
        else:
            environment_properties = env_props
        error_properties = prp.get_error_properties()
        if additional_props:
            # Merge into a copy so a caller-supplied env_props dict is left untouched.
            environment_properties = dict(environment_properties)
            environment_properties.update(additional_props)
        if runs is None:
            # NOTE(review): runs is resolved here but not passed on to load_from_config.
            runs = eval(prp.db['runs'])
        config = eval(prp.db['config'])
        if seed is None:
            seed = eval(prp.db['seed'])
        if early_termination is None:
            early_termination = eval(prp.db['early_termination'])

        hierarchy, env = cls.load_from_config(config, min=min, render=render,  error_collector_type=error_collector_type, error_response_type=error_response_type,
                                                error_properties=error_properties, error_limit=error_limit, hpct_verbose=hpct_verbose,  history=history,
                                                environment_properties=environment_properties, seed=seed, early_termination=early_termination)

        return hierarchy, env, environment_properties


    @classmethod
    def load_from_config(cls, config, min=None, render=False,  error_collector_type=None, error_response_type=None, 
        error_properties=None, error_limit=100, hpct_verbose=False, early_termination=None, 
        seed=None, history=False, suffixes=False, flip_error_response=False, environment_properties=None):
        "Load an individual from a provided configuration."

        if callable(min):
            raise Exception("min must not be a function")

        hierarchy = cls.from_config_with_environment(config, seed=seed, history=history, suffixes=suffixes, environment_properties=environment_properties)
        env = hierarchy.get_preprocessor()[0]
        env.set_render(render)
        env.early_termination = early_termination
        env.reset(full=False, seed=seed)
        if error_collector_type is not None:
            error_collector = BaseErrorCollector.collector(error_response_type, error_collector_type, error_limit, min, properties=error_properties, flip_error_response=flip_error_response)
            hierarchy.set_error_collector(error_collector)
        if hpct_verbose:
            hierarchy.summary()
            print(hierarchy.formatted_config())
        
        return hierarchy, env    


    @classmethod
    def run_and_draw_hierarchy(cls, hierarchy, env, steps=500, hpct_verbose=False, draw_file=False, draw_filename=None, move=None, with_edge_labels=True, font_size=6, node_size=100, plots=None,
        history=False, plots_figsize=(15,4), plots_dir=None, experiment=None, log_experiment_figure=False, funcdata=True, draw_figsize=(8,8)):

        hierarchy.run(steps, hpct_verbose)
        env.close()
        
        # draw network file
        move = {} if move == None else move
        dfig = None
        if experiment or draw_file:
            if log_experiment_figure:
                dfig = hierarchy.draw(file=draw_filename, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size, experiment=experiment, funcdata=funcdata, figsize=draw_figsize)
            else:
                dfig = hierarchy.draw(file=draw_filename, move=move, with_edge_labels=with_edge_labels, font_size=font_size, node_size=node_size, funcdata=funcdata, figsize=draw_figsize)
            if draw_filename:
                print(draw_filename)
        
        pfigs=None
        if history:
            if plots:
                plots = hierarchy.get_plots_config(plots)
                pfigs = []
                for plot in plots:
                    plotfile=None
                    if plots_dir:
                        plotfile = plots_dir+ sep +plot['title']+'.png'
                    pfig = hierarchy.hierarchy_plots(title=plot['title'], plot_items=plot['plot_items'], figsize=plots_figsize, file=plotfile, experiment=experiment)
                    pfigs.append(pfig)

        error_score=hierarchy.get_error_collector().error()
        environment_score = hierarchy.get_environment_score()
        score = environment_score if environment_score is not None else error_score

        return score, dfig, pfigs


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()