### A notebook for looking into visualising workflow dependencies

**Note: this notebook requires additional dependencies: `networkx`, `matplotlib` and `plotly`.**

In [None]:
%load_ext autoreload
%autoreload 2

import tempfile
import numpy as np
import pytest

import networkx
import matplotlib.pyplot as plt
from plotly import graph_objects as go

from hpcflow.api import (
    hpcflow, Workflow, WorkflowTemplate, TaskSchema, Task, Parameter, InputValue, 
    SchemaInput, ValueSequence, InputSource, SchemaOutput, InputSourceType, TaskSourceType,
    ElementPropagation,
)
from hpcflow.sdk.core.utils import read_YAML
from hpcflow.sdk.core.errors import MissingInputs
from hpcflow.sdk.core.zarr_io import ZarrEncodable

# Load the hpcflow configuration, using the system temporary directory as the
# configuration directory.
hpcflow.load_config(config_dir=tempfile.gettempdir())

In [None]:
# Schema inputs for parameters p1-p4. Despite the `param_` prefix, each name is
# bound to a SchemaInput wrapping the corresponding Parameter; p1-p3 carry a
# default value, p4 does not.
param_p1 = SchemaInput(Parameter("p1"), default_value=1001)
param_p2 = SchemaInput(Parameter("p2"), default_value=np.array([2002, 2003]))
param_p3 = SchemaInput(Parameter("p3"), default_value=3001)
param_p4 = SchemaInput(Parameter("p4"))

In [None]:
# Task 1: schema "ts1" takes p1 as input and declares p3 as output; p1 is
# given the explicit value 101.
s1 = TaskSchema("ts1", actions=[], inputs=[param_p1], outputs=[param_p3])
t1 = Task(
    schemas=s1,
    inputs=[InputValue(param_p1, 101)],
)

# Task 2: schema "ts2" takes p2 and p3 as inputs and declares p4 as output;
# p2 is supplied as a two-value sequence.
s2 = TaskSchema("ts2", actions=[], inputs=[param_p2, param_p3], outputs=[param_p4])
t2 = Task(
    schemas=s2,
    sequences=[ValueSequence('inputs.p2', values=[201, 202], nesting_order=1)],
)

# Task 3: schema "ts3" takes p3 and p4 as inputs and declares no outputs.
s3 = TaskSchema("ts3", actions=[], inputs=[param_p3, param_p4])
t3 = Task(schemas=s3, nesting_order={'inputs.p3': 0, 'inputs.p4': 1})

# Combine the three tasks into a template and create the workflow from it,
# using the system temporary directory as the path.
wkt = WorkflowTemplate(name="w1", tasks=[t1, t2, t3])
wk = Workflow.from_template(wkt, path=tempfile.gettempdir(), name=wkt.name)

# Add further elements to ts1 from a three-value sequence for p1, requesting
# propagation to ts2 and ts3 with the nesting orders given for each.
wk.tasks.ts1.add_elements(
    sequences=[ValueSequence('inputs.p1', values=[102, 103, 104], nesting_order=1)],
    propagate_to=[
        ElementPropagation(
            task=wk.tasks.ts2,
            nesting_order={'inputs.p2': 0, 'inputs.p3': 1}
        ),
        ElementPropagation(
            task=wk.tasks.ts3,
            nesting_order={'inputs.p3': 0, 'inputs.p4': 1},
        )
    ],
)

In [None]:
def build_element_graph(workflow):
    """Build a directed graph of the element dependencies in a workflow.

    Every element of every task becomes a node, keyed by the element's
    ``global_index`` and carrying a ``task`` attribute set to the index of the
    task it belongs to. An edge is then added from each element to every entry
    of its ``dependent_elements``.

    Parameters
    ----------
    workflow
        Object providing ``tasks`` (each with ``index`` and ``elements``) and
        ``elements`` (each with ``global_index`` and ``dependent_elements``).

    Returns
    -------
    networkx.DiGraph
        The element dependency graph.

    Notes
    -----
    Entries of ``dependent_elements`` are used directly as edge targets, so
    they are presumably element global indices -- TODO confirm against the
    hpcflow API.
    """
    graph = networkx.DiGraph()

    # Add all nodes first so each one carries its `task` attribute, which the
    # layout step relies on.
    for task in workflow.tasks:
        for element in task.elements:
            graph.add_node(element.global_index, task=task.index)

    # Iterate the workflow that was passed in, not the notebook-level `wk`.
    for element in workflow.elements:
        for dependent in element.dependent_elements:
            graph.add_edge(element.global_index, dependent)

    return graph

def _prepare_element_graph(G):
    """Compute per-node colours and a layout for an element graph.

    Parameters
    ----------
    G
        Graph whose nodes each carry an integer ``task`` attribute.

    Returns
    -------
    tuple
        ``(node_colours, pos)``: a list with one colour name per node, in the
        iteration order of ``G.nodes``, and the result of
        ``networkx.multipartite_layout`` with nodes grouped by ``task``.
    """
    task_colours = ('blue', 'green', 'red')
    # Cycle through the palette so graphs with more tasks than colours do not
    # raise IndexError; the first three tasks keep their original colours.
    node_colours = [
        task_colours[data["task"] % len(task_colours)]
        for _, data in G.nodes(data=True)
    ]
    pos = networkx.multipartite_layout(G, subset_key='task')
    return node_colours, pos

def show_element_graph_matplotlib(G):
    """Draw the element graph with ``networkx.draw``.

    Nodes are labelled and coloured, using the colours and positions returned
    by ``_prepare_element_graph``. Returns nothing.
    """
    colours, layout = _prepare_element_graph(G)
    networkx.draw(G, pos=layout, node_color=colours, with_labels=True)

def show_element_graph_plotly(G):
    """Build a plotly figure of the element graph.

    Edges are drawn as one line trace and nodes as one marker trace. Nodes are
    coloured by task, matching the matplotlib version, and show their node key
    and task index on hover.

    Parameters
    ----------
    G
        Graph whose nodes each carry a ``task`` attribute.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure containing the edge trace followed by the node trace.
    """
    node_colours, pos = _prepare_element_graph(G)

    # All edges share a single trace; a None entry after each pair of end
    # points separates one edge's segment from the next.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        hoverinfo='none',
        mode='lines',
    )

    # Same iteration order as used for `node_colours`, so colours line up
    # with the node coordinates.
    node_x = []
    node_y = []
    node_text = []
    for node, data in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(f"element {node} (task {data['task']})")

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers',
        hoverinfo='text',
        # hoverinfo='text' shows nothing unless `text` is supplied.
        text=node_text,
        marker={'color': node_colours},
    )

    return go.Figure(data=[edge_trace, node_trace])

In [None]:
# Build the element dependency graph for the workflow created above.
G = build_element_graph(wk)

# Render the same graph with both back ends. The plotly call returns a figure;
# as the last expression in the cell it is presumably what the notebook
# displays as the cell output.
show_element_graph_matplotlib(G)
show_element_graph_plotly(G)