In [1]:
import json
import os

In [3]:
class TLCState:
    def __init__(self, key, state_s) -> None:
        self.key = key
        self.state_s = state_s
        self.parse()

    def parse(self):
        pass

    def __repr__(self) -> str:
        return self.state_s


In [58]:
# TODO: 
# 1. Define structure for graph
# 2. Define parser to parse string and interpret the state variables
# 3. Define abstraction to transform graph and collapse states
class Graph:
    def __init__(self, nodes) -> None:
        self.nodes = dict([(key,GraphNode(node)) for key,node in nodes.items()])
        self.start_nodes = []
        for n, node in self.nodes.items():
            if not node.have_prev():
                self.start_nodes.append(n)
        
        def assign_depth(cur, nodes):
            cur.depth = min([n.depth+1 for n in [nodes[p] for p in cur.prev]]) if cur.have_prev() else 0
        self.traverse(assign_depth)

        depths = {}
        for key, node in self.nodes.items():
            if node.depth not in depths:
                depths[node.depth] = []
            depths[node.depth].append(key)
        
        for _, values in depths.items():
            values.sort()
            for (i,key) in enumerate(values):
                self.nodes[key].sibling = i

    def get_next(self, key):
        return [str(self.nodes[k]) for k in self.nodes[key].next]
    
    def max_depth(self):
        depth = 0
        for _, node in self.nodes.items():
            if node.depth > depth:
                depth = node.depth
        return depth

    # BFS traversal of the graph
    def traverse(self, visit_func):
        q = list(self.start_nodes)
        visited = set()
        while len(q) > 0:
            cur_node = self.nodes[q.pop(0)]
            if cur_node.key in visited:
                continue
            visited.add(cur_node.key)
            visit_func(cur_node, self.nodes)
            q+=list(cur_node.next)
            

class GraphNode:
    def __init__(self, node) -> None:
        self.key = node["Key"]
        self.state = TLCState(node["Key"], node["State"])
        self.visits = node["Visits"]
        self.next = set(node["Next"].keys()) if "Next" in node else set()
        self.prev = set(node["Prev"].keys()) if "Prev" in node else set()
        self.depth = -1
        self.sibling = -1
    
    def have_prev(self):
        return len(self.prev) != 0

    def __repr__(self) -> str:
        return str({"Key": self.key, "State": str(self.state), "Visits": self.visits})
    

In [59]:
def read_data(graph_file_path, name=""):
    if name == "":
        name = os.path.basename(graph_file_path)
    graph = {}
    with open(graph_file_path) as f:
        graph = json.load(f)
    return Graph(graph["Nodes"])

In [60]:
def next(node_key, graph):
    if node_key not in graph["Nodes"] or "Next" not in graph["Nodes"][node_key]:
        return []
    next_nodes = [graph["Nodes"][n] for n in graph["Nodes"][node_key]["Next"].keys()]
    return [{"Key": n["Key"], "State": n["State"], "Visits": n["Visits"]} for n in next_nodes]

In [61]:
def filter_visits(min_threshold, graph):
    return [graph["Nodes"][n] for n in graph["Nodes"].keys() if graph["Nodes"][n]["Visits"]> min_threshold]
        

In [62]:
def max_depth(graph):
    depth = 0
    for n in graph["Nodes"].keys():
        if graph["Nodes"][n]["Depth"] > depth:
            depth = graph["Nodes"][n]["Depth"]
    return depth

In [63]:
def at_depth(depth, graph):
    filtered_nodes = [graph["Nodes"][n] for n in graph["Nodes"].keys() if graph["Nodes"][n]["Depth"] == depth]
    return [{"Key": n["Key"], "State": n["State"], "Visits": n["Visits"]} for n in filtered_nodes]

In [64]:
import matplotlib.pyplot as plt
import numpy as np

def default_filter(visits):
    return visits

def compare_visit_hists(graphs, filter=default_filter):
    fig, ax = plt.subplots()
    for graph in graphs:
        visits = [graph["Nodes"][node]["Visits"] for node in graph["Nodes"]]
        visits = filter(visits)
        ax.hist(visits, bins=30,linewidth=0.5, edgecolor="white", label=graph["Name"])

    ax.legend()
    plt.show()

def min_visits(min_visit):
    return lambda visits: [v for v in visits if v >= min_visit]

def max_visits(max_visit):
    return lambda visits: [v for v in visits if v < max_visit]

def between_visits(min_v, max_v):
    return lambda visits: [v for v in visits if v >= min_v and v < max_v]
    

In [66]:
graph_random = read_data("/Users/srinidhin/random_10k_75_3_3/visit_graph_random.json")
graph_swapint = read_data("/Users/srinidhin/random_10k_75_3_3/visit_graph_scaleUpInt.json")

graph_random.max_depth()

36