In [None]:
import networkx

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental import tree, graphviz

In [None]:
# Dependency records as returned by `tree`; the keys are used below as the
# names of the modules that live inside the package.
module_dependencies = tree.get_module_dependencies()

# Iterating the mapping yields its keys, so this is the set of internal modules.
internal_modules = set(module_dependencies)

In [None]:
root = 'pymedphys'

# Dependency records are indexed positionally. Index 2 is the name that gets
# exposed (names starting with "_" are treated as private and skipped).
# NOTE(review): indices 0 and 1 look like (import target, providing module);
# confirm against `tree.get_module_dependencies`.
top_level_api = [
    entry for entry in module_dependencies[root]
    if not entry[2].startswith('_')
]

# Entries whose first two fields agree are treated as whole-module exposures.
module_apis = [entry[0] for entry in top_level_api if entry[0] == entry[1]]

# Public records of each module that is exposed directly under the root.
second_level_apis = {
    module: [
        entry for entry in module_dependencies[module]
        if not entry[2].startswith('_')
    ]
    for module in module_apis
}

# Map "<root>.<name>" to the module that provides it, for every top-level
# entry that is not itself a whole-module exposure.
exposure_module_maps = {}
for entry in top_level_api:
    if entry[0] != entry[1]:
        exposure_module_maps[f"{root}.{entry[2]}"] = entry[1]

# Then do the same one level down, keyed as "<module>.<name>".
for module, second_level_api in second_level_apis.items():
    exposure_module_maps.update(
        {f"{module}.{entry[2]}": entry[1] for entry in second_level_api}
    )

exposure_module_maps

In [None]:
# Pick the first public top-level API record as the one to graph, and display it.
api_to_graph = top_level_api[0]
api_to_graph

In [None]:
# Traversal of the pymedphys public API.
#
# Starting from the module referenced by `api_to_graph`, follow every
# dependency that is publicly named (no leading underscore) and internal to
# the package, building a directed graph of "module -> module it depends on".

start_module = api_to_graph[1]

di_graph = networkx.DiGraph()
di_graph.add_node(start_module)
traversal_nodes = {start_module}

while traversal_nodes:
    node = traversal_nodes.pop()

    for dependency in module_dependencies[node]:
        dependency_module = dependency[1]

        # Skip private names, anything outside the package and self-references.
        if (
            dependency[2].startswith('_')
            or dependency_module not in internal_modules
            or dependency_module == node
        ):
            continue

        # Only queue modules that have not been discovered yet, so every
        # module is expanded at most once and the loop terminates.
        if dependency_module not in di_graph:
            di_graph.add_node(dependency_module)
            traversal_nodes.add(dependency_module)

        # Always record the edge, even when the target was discovered earlier.
        # Previously the edge was only added on first discovery, which dropped
        # shared dependencies and made the result depend on the arbitrary
        # order in which `set.pop()` hands back modules.
        di_graph.add_edge(node, dependency_module)

In [None]:
# Quick visual check of the graph built by the traversal cell.
networkx.draw(di_graph)

In [None]:
def create_href(text):
    """Return a ``#``-prefixed anchor for ``text``.

    Every underscore and every dot in ``text`` is replaced by a hyphen.
    """
    hyphenated = text.translate(str.maketrans("_.", "--"))
    return f"#{hyphenated}"


def create_link(text):
    """Return a bracketed ``URL`` attribute that points at the anchor for ``text``."""
    return f'[URL="{create_href(text)}"]'


def create_labels(label_map):
    """Build one labelled, linked node statement per entry of ``label_map``.

    ``label_map`` maps a node name to the label to show for it. Each entry
    becomes a line of the form ``"<node>" [label="<label>"] [URL="..."];`` and
    the lines are concatenated in the mapping's iteration order. An empty
    mapping gives an empty string.
    """
    return "".join(
        f'"{node}" [label="{label}"] {create_link(node)};\n'
        for node, label in label_map.items()
    )


In [None]:
# graphviz.dot_string_to_svg(
#     """
#         digraph sample {
#             A -> B;
#             B -> C;
#             C -> E;
#         }
#     """, 
#     'test.svg'
# )

In [None]:
# One quoted `"source" -> "target";` statement per graph edge, newline-terminated.
edges = "".join(
    f'"{edge[0]}" -> "{edge[1]}";\n' for edge in di_graph.edges
)
    
# print(edges)

In [None]:
# Wrap the edge statements in a digraph and hand the DOT text to the project's
# graphviz helper together with the output filename.
dot_source = f"""
        digraph sample {{
            {edges}
        }}
    """
graphviz.dot_string_to_svg(dot_source, 'test.svg')

In [None]:
#     print(f"""
#         digraph sample {{
#             {edges}
#         }}
#     """)

In [None]:
# module_dependencies

In [None]:
def is_stable_public_api(module_name):
    """Return True when ``module_name`` contains none of the non-stable markers.

    The check is a plain substring test, so a marker matches anywhere in the
    dotted name (for example ``'cli'`` also matches ``'pymedphys.clinical'``).
    """
    unstable_markers = ('._', 'beta', 'tests', 'docs', 'cli', 'experimental')
    return not any(marker in module_name for marker in unstable_markers)

In [None]:
# Keep only the module names that pass the stable-public-API filter
# (iterating the mapping yields its keys).
public_modules = list(filter(is_stable_public_api, module_dependencies))

In [None]:
# TODO:

# Get each publicly exposed API function/class.
# Draw a dependency tree for those publicly exposed APIs.
# Have a tree per API.

# The below doesn't achieve that yet.

In [None]:
# Display the module names that passed `is_stable_public_api`.
public_modules

In [None]:
def append_module_dependencies_to_graph(di_graph, modules, dependency_map):
    """Add one layer of modules and their dependencies to ``di_graph``.

    Each entry of ``modules`` is added as a node if missing, and an edge is
    added from it to every item listed for it in ``dependency_map``.

    Parameters
    ----------
    di_graph
        Graph to extend in place; must support ``in``, ``add_node`` and
        ``add_edge``.
    modules
        Iterable of module names to expand. Each one must be a key of
        ``dependency_map``.
    dependency_map
        Mapping from a module to the dependencies recorded for it.

    Returns
    -------
    tuple
        The same ``di_graph`` and the set of dependencies that were newly
        added as nodes during this call and are themselves keys of
        ``dependency_map``, i.e. the next layer to expand.
    """
    next_layer = set()

    for module in modules:
        if module not in di_graph:
            di_graph.add_node(module)

        for dependency in dependency_map[module]:
            already_known = dependency in di_graph or dependency in modules
            if not already_known:
                di_graph.add_node(dependency)

                # Only dependencies that have records of their own can be
                # expanded in a later pass.
                if dependency in dependency_map:
                    next_layer.add(dependency)

            di_graph.add_edge(module, dependency)

    return di_graph, next_layer

In [None]:
# Start a fresh graph (replacing the earlier traversal graph), seeded with
# every stable public module.
di_graph = networkx.DiGraph()
modules_to_traverse = public_modules

# Expand one layer per iteration: each call returns the dependencies it newly
# added that are themselves keys of `module_dependencies`. The loop ends once
# a pass discovers nothing new.
while modules_to_traverse:
    di_graph, modules_to_traverse = append_module_dependencies_to_graph(
        di_graph, modules_to_traverse, module_dependencies)

In [None]:
# Draw the graph assembled from the stable public modules.
networkx.draw(di_graph)