# Workflow diagram

In [None]:
import pkg_resources
import json
import os
from glob import glob

import matplotlib.pyplot as plt
import networkx as nx
import pydot
import numpy as np
from IPython.display import Image, display
import math

In [None]:
# Process names used to build the log file names looked up by get_processes().
# Order matters: get_processes() reports the 1-based position of each entry,
# and create_graph() maps those positions to labels through its proc_dict.
PROCESSES = ["motioncor2", "ctffind", "ctfsim", "aretomo_align", "imod_align", "imod_recon", "aretomo_recon", "savu_recon", "rlf_deconv", "exclude_bad_tilts"]


In [None]:
def read_ipynb(filename):
    """Load a file bundled in ``RepoTemp.templates`` and parse it as JSON.

    Args:
        filename: Resource name inside the ``RepoTemp.templates`` package.

    Returns:
        The object produced by ``json.load`` for the file's contents.
    """
    fn = pkg_resources.resource_filename("RepoTemp.templates", filename)

    # The Jupyter notebook format is UTF-8 encoded JSON; pin the encoding so
    # parsing does not depend on the platform's locale default.
    with open(fn, encoding="utf-8") as f:
        return json.load(f)

def get_processes(plist):
    """Return the 1-based positions of processes that have a log file.

    For each name in ``plist`` the file ``o2r_<name>.log`` is looked up in the
    current working directory.

    Args:
        plist: Sequence of process names.

    Returns:
        A list of 1-based indices into ``plist`` (in order) for which the
        corresponding log file exists.
    """
    cwd = os.getcwd()
    # Test the literal path directly instead of globbing it: glob would treat
    # characters such as "[" or "*" in the working directory as a pattern and
    # miss log files that are actually there.
    return [
        position
        for position, curr_proc in enumerate(plist, start=1)
        if os.path.isfile(os.path.join(cwd, f"o2r_{curr_proc}.log"))
    ]

def create_graph(plist: list):
    """Draw the workflow diagram for the given process ids and display it.

    A left-to-right directed graph is built with one ``box3d`` node per id in
    ``plist`` and one edge for every known workflow link whose two endpoints
    are both in ``plist``. The graph is rendered to PNG and shown with
    IPython's ``display``.

    Args:
        plist: Process ids (keys of the label table below, 1-10).

    Raises:
        KeyError: If ``plist`` contains an id with no entry in the label table.
    """
    labels = {
        1: "Motion Corr.",
        2: "CTF estimation",
        3: "CTF simulation",
        4: "TS alignment (AreTomo)",
        5: "TS alignment (IMOD)",
        6: "Reconstruction (IMOD)",
        7: "Reconstruction (AreTomo)",
        8: "Reconstruction (Savu)",
        9: "Deconvolution (RedLionfish)",
        10: "ExcludeBadTilts"
    }

    # All possible (source id, destination id) links between processes.
    workflow_links = [
        (1, 2), (2, 3), (3, 9), (6, 9), # CTFSim + RLF
        (1, 4), (4, 7),                 # AreTomo align + recon
        (1, 5), (5, 6),                 # IMOD align + recon
        (5, 8),                         # IMOD align + Savu recon
        (5, 7),                         # IMOD align + Aretomo recon
        (4, 8),                         # Aretomo align + Savu recon
        (1, 10), (10, 4), (10, 5)       # Exclude bad tilts
    ]

    # Resolve every requested id up front so an unknown id fails before any
    # graph object is created.
    node_labels = [labels[pid] for pid in plist]

    graph = pydot.Dot(graph_type="digraph", rankdir="LR")
    for label in node_labels:
        graph.add_node(pydot.Node(label, shape="box3d"))

    # Keep only the links whose two endpoints were both requested.
    requested = set(plist)
    for src, dst in workflow_links:
        if src in requested and dst in requested:
            graph.add_edge(pydot.Edge(labels[src], labels[dst]))

    display(Image(graph.create_png()))

In [None]:
# Find which processes left a log file in the working directory, then draw
# the workflow diagram restricted to those processes.
nodes = get_processes(PROCESSES)
create_graph(nodes)