In [1]:
import numpy as np
import pandas as pd
import copy
import xml.etree.ElementTree as ET

## Setup: Base template

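This parses a blank ChemDraw CDXML template (one empty page plus the color and font tables) into `root`. `draw_network` deep-copies it for every diagram, so the template itself is never modified.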

In [2]:
# Blank ChemDraw CDXML template: document-level defaults, a color table, a font
# table and a single empty <page> carrying a few annotations.
# `root` is treated as read-only: draw_network() deep-copies it and appends all
# drawing elements to the copy's <page>, so one template serves several diagrams.
# NOTE(review): ElementTree does not preserve the <!DOCTYPE ...> line, so files
# written with ElementTree.write() lack it — confirm ChemDraw opens them as expected.
root = ET.fromstring('''<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE CDXML SYSTEM "http://www.cambridgesoft.com/xml/cdxml.dtd" >
<CDXML
 CreationProgram="ChemDraw 22.0.0.22"
 Name="blank.cdxml"
 BoundingBox="0 0 0 0"
 WindowPosition="0 0"
 WindowSize="-2147483648 0"
 WindowIsZoomed="yes"
 FractionalWidths="yes"
 InterpretChemically="yes"
 ShowAtomQuery="yes"
 ShowAtomStereo="no"
 ShowAtomEnhancedStereo="yes"
 ShowAtomNumber="no"
 ShowResidueID="no"
 ShowBondQuery="yes"
 ShowBondRxn="yes"
 ShowBondStereo="no"
 ShowTerminalCarbonLabels="no"
 ShowNonTerminalCarbonLabels="no"
 HideImplicitHydrogens="no"
 LabelFont="3"
 LabelSize="7"
 LabelFace="96"
 CaptionFont="3"
 CaptionSize="7"
 HashSpacing="1.69"
 MarginWidth="1.19"
 LineWidth="1.08"
 BoldWidth="1.55"
 BondLength="10.80"
 BondSpacing="18"
 ChainAngle="120"
 LabelJustification="Auto"
 CaptionJustification="Left"
 AminoAcidTermini="HOH"
 ShowSequenceTermini="yes"
 ShowSequenceBonds="yes"
 ShowSequenceUnlinkedBranches="no"
 ResidueWrapCount="40"
 ResidueBlockCount="10"
 CaptionLineHeight="7"
 PrintMargins="36 36 36 36"
 MacPrintInfo="00030000004800480000000003000240FFF4FFEE030C02520367052803FC00020000004800480000000003000240000100000064000000010001010100000001270F000100010000000000000000000000000002001901900000000000600000000000000000000100000000000000000000000000000000"
 ChemPropName=""
 ChemPropFormula="Chemical Formula: "
 ChemPropExactMass="Exact Mass: "
 ChemPropMolWt="Molecular Weight: "
 ChemPropMOverZ="m/z: "
 ChemPropAnalysis="Elemental Analysis: "
 ChemPropBoilingPt="Boiling Point: "
 ChemPropMeltingPt="Melting Point: "
 ChemPropCritTemp="Critical Temp: "
 ChemPropCritPres="Critical Pres: "
 ChemPropCritVol="Critical Vol: "
 ChemPropGibbs="Gibbs Energy: "
 ChemPropLogP="Log P: "
 ChemPropMR="MR: "
 ChemPropHenry="Henry&apos;s Law: "
 ChemPropEForm="Heat of Form: "
 ChemProptPSA="tPSA: "
 ChemPropCLogP="CLogP: "
 ChemPropCMR="CMR: "
 ChemPropLogS="LogS: "
 ChemPropPKa="pKa: "
 ChemPropID=""
 ChemPropFragmentLabel=""
 color="0"
 bgcolor="1"
 RxnAutonumberStart="1"
 RxnAutonumberConditions="no"
 RxnAutonumberStyle="Roman"
 RxnAutonumberFormat="(#)"
><colortable>
<color r="1" g="1" b="1"/>
<color r="0" g="0" b="0"/>
<color r="0.9020" g="0.2157" b="0.2706"/>
<color r="0" g="0" b="1"/>
<color r="0.4902" g="0.4353" b="0.6706"/>
<color r="0.9451" g="0.5608" b="0.6000"/>
<color r="0.7608" g="0.8314" b="0.9373"/>
<color r="0.4000" g="0.3569" b="0.6000"/>
<color r="0.1059" g="0.1059" b="0.1059"/>
<color r="0.2353" g="0.2353" b="0.2353"/>
<color r="0.5059" g="0.5059" b="0.5059"/>
<color r="0.6667" g="0.6667" b="0.6667"/>
<color r="0.8157" g="0.8157" b="0.8157"/>
<color r="0.8667" g="0.8667" b="0.8667"/>
<color r="0.9020" g="0.9020" b="0.9020"/>
<color r="0.9451" g="0.9451" b="0.9451"/>
<color r="0.0510" g="0.2706" b="0.6000"/>
<color r="0.1216" g="0.3451" b="0.6902"/>
<color r="0.1255" g="0.4314" b="0.6902"/>
<color r="0.2275" g="0.5608" b="1"/>
<color r="0.4902" g="0.7451" b="1"/>
<color r="0.5961" g="0.7412" b="0.9608"/>
<color r="0.7059" g="0.7059" b="0.7059"/>
<color r="0.2353" g="0.2471" b="0.2667"/>
<color r="0.6902" g="0.6902" b="0.6902"/>
<color r="0.2000" g="0.2000" b="0.2000"/>
<color r="0.5020" g="0.5020" b="0.5020"/>
<color r="0.5922" g="0.6510" b="0.9608"/>
<color r="0.7961" g="0.8314" b="0.9059"/>
<color r="0.4353" g="0.4353" b="0.4353"/>
<color r="0.8863" g="0.8863" b="0.8863"/>
<color r="0.8510" g="0.6000" b="0.9608"/>
<color r="0.9804" g="0.6000" b="0.6510"/>
<color r="1" g="0.8157" b="0.8510"/>
<color r="0.8353" g="0.8510" b="0.9922"/>
<color r="0.0902" g="0.0902" b="0.0902"/>
<color r="1" g="0" b="0"/>
<color r="0.9059" g="0.9255" b="0.9608"/>
<color r="0.0902" g="0.0902" b="0.1059"/>
<color r="0.6000" g="0.6000" b="0.6000"/>
<color r="0.9451" g="0.8157" b="0.8157"/>
<color r="0.6353" g="0.0706" b="0.1255"/>
<color r="0.6314" g="0.0706" b="0.1216"/>
<color r="0.6353" g="0.0706" b="0.1216"/>
<color r="0.3255" g="0.3255" b="0.3255"/>
<color r="0.5373" g="0.6118" b="0.9451"/>
<color r="0.7804" g="0.8157" b="0.9961"/>
</colortable><fonttable>
<font id="3" charset="iso-8859-1" name="Arial"/>
</fonttable><page
 id="2093301810"
 BoundingBox="0 0 3240 4320"
 HeaderPosition="36"
 FooterPosition="36"
 PrintTrimMarks="yes"
 HeightPages="6"
 WidthPages="6"
><annotation
 Keyword="Catalytic Cycle"
 Content="Catalytic Cycle template"
/><annotation
 Keyword="GUID"
 Content="841B63F5-73DC-4D2E-89E0-791ECB0A5A66"
/><annotation
 Keyword="Name"
 Content="Ass"
/><annotation
 Keyword="Type"
 Content="stnd."
/></page></CDXML>''')


# The <page> element is looked up inside draw_network(), on a deep copy of root.


## Helper functions to generate CDXML

In [3]:
def arrow_cdxml(head_pos, tail_pos, z, color=1, arrowhead_head=False, arrowhead_tail=False, dashed=False, hash_spacing=4, arrowhead_props=(716, 626, 261), linewidth=1.08):
    """
    Generate a straight CDXML <arrow> element with the specified properties.

    Parameters:
    - head_pos (tuple): (x, y) of the arrow's head end (written as Head3D).
    - tail_pos (tuple): (x, y) of the arrow's tail end (written as Tail3D).
    - z (int): The Z-depth (drawing order) of the arrow.
    - color (int | str): The color index of the arrow.
    - arrowhead_head (bool | str): Arrowhead style at the head end, written verbatim as
      ArrowheadHead (e.g. "Full"). True is shorthand for "Full"; a falsy value omits it.
    - arrowhead_tail (bool | str): Same as arrowhead_head, for the tail end.
    - dashed (bool): Whether the arrow should be dashed (adds LineType and HashSpacing).
    - hash_spacing (int): The spacing between dashes; only written when dashed is true.
    - arrowhead_props (tuple): (HeadSize, ArrowheadCenterSize, ArrowheadWidth).
    - linewidth (float): The width of the arrow's line.

    Returns:
    - ET.Element: The arrow CDXML element (no id attribute is set).
    """

    def _arrowhead_style(value):
        # A plain boolean True would be serialized as ArrowheadHead="True", which is
        # not a CDXML arrowhead style; treat it as the standard full arrowhead.
        if isinstance(value, (bool, np.bool_)):
            return "Full" if value else ""
        return value

    head_style = _arrowhead_style(arrowhead_head)
    tail_style = _arrowhead_style(arrowhead_tail)

    # Optional attributes collapse to empty strings (just whitespace inside the tag).
    arrowhead_head_str = f'ArrowheadHead="{head_style}"' if head_style else ''
    arrowhead_tail_str = f'ArrowheadTail="{tail_style}"' if tail_style else ''
    dashed_str = 'LineType="Dashed"' if dashed else ''
    dash_style = f'HashSpacing="{hash_spacing}"' if dashed else ''

    return ET.fromstring(f'''<arrow
    color="{color}"
    FillType="None"
    LineWidth="{linewidth}"
    ArrowheadType="Solid"
    HeadSize="{arrowhead_props[0]}"
    ArrowheadCenterSize="{arrowhead_props[1]}"
    ArrowheadWidth="{arrowhead_props[2]}"
    {dash_style}
    {arrowhead_head_str}
    {arrowhead_tail_str}
    {dashed_str}
    Head3D="{head_pos[0]} {head_pos[1]} 0"
    Tail3D="{tail_pos[0]} {tail_pos[1]} 0"
    Z="{z}"
    />''')

def networkbox_cdxml(head_pos, tail_pos, z, color):
    """
    Build one edge ("box side") of the network as a plain line.

    The line is an arrow element with no arrowheads, a solid line type, the
    default arrowhead geometry (716, 626, 261) and a line width of 1.08.

    Args:
        head_pos (tuple): (x, y) of one end of the edge.
        tail_pos (tuple): (x, y) of the other end of the edge.
        z (float): Drawing depth (Z attribute) of the element.
        color (int | str): Color index of the line.

    Returns:
        xml.etree.ElementTree.Element: The <arrow> element representing the edge.
    """
    plain_line = dict(
        arrowhead_head=False,
        arrowhead_tail=False,
        dashed=False,
        arrowhead_props=(716, 626, 261),
        linewidth=1.08,
    )
    return arrow_cdxml(head_pos, tail_pos, z, color=color, **plain_line)

def whitebox_cdxml(head_pos, tail_pos, z, color=1):
    """
    Build the thick line used as a white margin behind a pair of arrows.

    It is a bold (BoldWidth 10) arrow element between the two points; the
    Center3D / MajorAxisEnd3D / MinorAxisEnd3D values are fixed constants.

    Parameters:
    - head_pos (tuple): (x, y) of one end of the line.
    - tail_pos (tuple): (x, y) of the other end of the line.
    - z (float): Drawing depth (Z attribute) of the element.
    - color (int): Color index of the line (defaults to 1).

    Returns:
    - xml.etree.ElementTree.Element: The <arrow> element.
    """
    attributes = {
        "id": "2093172180",
        "color": color,
        "BoldWidth": "10",
        "LineType": "Bold",
        "FillType": "None",
        "ArrowheadType": "Solid",
        "HeadSize": "2000",
        "ArrowheadCenterSize": "1750",
        "ArrowheadWidth": "500",
        "Head3D": f"{head_pos[0]} {head_pos[1]} 0",
        "Tail3D": f"{tail_pos[0]} {tail_pos[1]} 0",
        "Center3D": "290 297.47 0",
        "MajorAxisEnd3D": "369.89 297.47 0",
        "MinorAxisEnd3D": "290 377.36 0",
        "Z": z,
    }
    markup = " ".join(f'{name}="{value}"' for name, value in attributes.items())
    return ET.fromstring(f"<arrow {markup} />")

def text_cdxml(text, position, z, color=0, y_offset=0, text_lowercase=False, size=7, vertical_shift=True):
    """
    Generate a CDXML <t> (text label) element with centre justification.

    Args:
        text (str): The text content of the label. Characters such as "&" or "<"
            are allowed; ElementTree escapes them when the document is written.
        position (tuple): The x and y coordinates of the label anchor.
        z (int): The z-index of the label.
        color (int, optional): The color index of the text. Defaults to 0.
        y_offset (float, optional): Vertical offset added to the y coordinate. Defaults to 0.
        text_lowercase (bool, optional): Whether to convert the text to lowercase. Defaults to False.
        size (float, optional): The font size of the label. Defaults to 7.
        vertical_shift (bool, optional): Whether to additionally add size / 4 to the
            y coordinate (on top of y_offset). Defaults to True.

    Returns:
        xml.etree.ElementTree.Element: The <t> element containing one <s> text run.
    """
    if text_lowercase:
        text = text.lower()

    # The automatic shift is added to the caller's offset; previously it replaced
    # y_offset, so an explicit offset was silently ignored when vertical_shift=True.
    if vertical_shift and size is not None:
        y_offset = y_offset + size / 4

    label = ET.fromstring(f'''<t
    id="2093301862"
    p="{position[0]} {position[1] + y_offset} 0"
    Warning="Chemical Interpretation is not possible for this label"
    CaptionJustification="Center"
    Justification="Center"
    LineHeight="7"
    Z="{z}"
    />''')

    # Attach the text through ElementTree instead of string interpolation so that
    # XML special characters in `text` cannot break the parse.
    run = ET.SubElement(
        label, "s", {"font": "3", "size": f"{size}", "color": f"{color}", "face": "1"}
    )
    run.text = text
    return label


def circle_cdxml(center, radius, color, z):
    """
    Build a filled-circle CDXML <graphic> element.

    The circle is described by its centre and two axis end points, each offset
    by `radius` along x (major axis) and y (minor axis).

    Parameters:
    - center: (x, y) coordinates of the circle's centre.
    - radius: Radius of the circle.
    - color: Color index of the circle (int or str).
    - z: Drawing depth (Z attribute) of the circle.

    Returns:
    - xml.etree.ElementTree.Element: The <graphic> element.
    """
    cx, cy = center[0], center[1]
    attributes = {
        "id": "2093172235",
        "color": color,
        "GraphicType": "Oval",
        "OvalType": "Circle Filled",
        "Center3D": f"{cx} {cy} 0",
        "MajorAxisEnd3D": f"{cx + radius} {cy} 0",
        "MinorAxisEnd3D": f"{cx} {cy + radius} 0",
        "Z": z,
    }
    markup = " ".join(f'{key}="{val}"' for key, val in attributes.items())
    return ET.fromstring(f"<graphic {markup} />")

## Helper math functions to generate network

In [4]:
def midpoint(a, b):
    """
    Return the point halfway between two 2D points.

    Parameters:
    a (tuple): First point (x, y).
    b (tuple): Second point (x, y).

    Returns:
    tuple: (x, y) of the midpoint.
    """
    mid_x = (a[0] + b[0]) / 2
    mid_y = (a[1] + b[1]) / 2
    return mid_x, mid_y

def shift(x, y, slope, distance):
    """
    Move a point a signed distance along a direction given by a slope.

    A positive distance moves towards +x; for a vertical direction (slope is
    NaN) the point moves along +y instead.

    Parameters:
    x (float): x-coordinate of the point.
    y (float): y-coordinate of the point.
    slope (float): Slope of the direction; NaN means vertical.
    distance (float): Signed distance to move.

    Returns:
    tuple: The shifted (x, y).
    """
    if np.isnan(slope):
        return x, y + distance
    norm = np.sqrt(1 + slope**2)
    return x + distance / norm, y + slope * distance / norm

def slope(x1, y1, x2, y2):
    """
    Slope (dy/dx) of the line through two points.

    Parameters:
    x1, y1 (float): First point.
    x2, y2 (float): Second point.

    Returns:
    float: The slope, or NaN when the line is vertical (equal x values).
    """
    run = x1 - x2
    if run == 0:
        return float('NaN')
    return (y1 - y2) / run

def distance(x1, y1, x2, y2):
    """
    Euclidean distance between two 2D points.

    Parameters:
    x1, y1 (float): First point.
    x2, y2 (float): Second point.

    Returns:
    float: The straight-line distance.
    """
    dx = x1 - x2
    dy = y1 - y2
    return np.sqrt(dx**2 + dy**2)

def intersection(x1, y1, x2, y2, x3, y3, x4, y4):
    """
    Calculate the intersection point of two lines, each defined by two points.

    Parameters:
    x1, y1 (float): The coordinates of the first point on the first line.
    x2, y2 (float): The coordinates of the second point on the first line.
    x3, y3 (float): The coordinates of the first point on the second line.
    x4, y4 (float): The coordinates of the second point on the second line.

    Returns:
    tuple: The coordinates of the intersection point. Returns (NaN, NaN) if the
    lines are parallel, including when both lines are vertical.
    """
    m1 = slope(x1, y1, x2, y2)
    m2 = slope(x3, y3, x4, y4)
    first_vertical = np.isnan(m1)
    second_vertical = np.isnan(m2)

    if first_vertical and second_vertical:
        # Two vertical lines are parallel (or identical): there is no single
        # intersection point. Previously this returned (x1, nan).
        return float('NaN'), float('NaN')
    if first_vertical:
        b2 = y3 - m2 * x3
        return x1, m2 * x1 + b2
    if second_vertical:
        b1 = y1 - m1 * x1
        return x3, m1 * x3 + b1
    if m1 == m2:
        return float('NaN'), float('NaN')

    b1 = y1 - m1 * x1
    b2 = y3 - m2 * x3
    x = (b2 - b1) / (m1 - m2)
    y = m1 * x + b1
    return x, y

## Function to create a CDXML file
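A rate-parameter DataFrame is passed to the function. Each row describes one edge of the network and has the following columns (example .csv files are provided for the alpha and beta networks):
- 's1' (str): The first sugar in the reaction.
- 's2' (str): The second sugar in the reaction.
- 'rate' (float): The base rate of the reaction. The favored direction is drawn with rate * sqrt(|sel|) and the other direction with rate / sqrt(|sel|). Arrow lengths are proportional to these directional rates, and the fastest one in the table is drawn at the maximum arrow length.
- 'sel' (float): The selectivity of the reaction, i.e. the ratio between the two directional rates (positive favors s2, negative favors s1). Use a magnitude of at least 1, because a magnitude below 1 reverses which direction is favored. The value must never be 0.
- 'arrow_color' (str): (Optional) The color index of the reaction arrows. Default is '23'.
- 'box_color' (str): (Optional) The color index of the box line for the corresponding edge. Default is '15'.
- 'dashed_arrow' (bool): (Optional) Whether the reaction arrows are drawn dashed. Default is False.
- 'direction' (str): (Optional) Whether to draw only the forward (s1 → s2) arrow, only the backward arrow, or both. Valid options are "forward", "backward", or "both" (the default). When only one arrow is drawn, it is centered on the edge.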


A settings dictionary is passed to the function. Every key is optional; a missing key falls back to the default listed below:
- 'coordinates' (tuple): The (x, y) page position of the left-most node (the third entry of 'sugar_list'). All other nodes are placed relative to it. Default is (700, 700).
- 'length' (float): The edge length of the cube-shaped network. Default is 100.
- 'max_length_part' (float): The maximum arrow length, as a fraction of 'length'. The fastest directional rate is drawn at this length. Default is 1 (the full edge length).
- 'min_length_part' (float): The minimum arrow length, as a fraction of 'length'. Shorter arrows are lengthened to it. Default is None (no minimum).
- 'spacing' (float): The perpendicular gap between the forward and backward arrows of an edge. Default is 5.5.
- 'margin' (float): How far the white line drawn behind the arrows extends past each end of the longer arrow. Default is 5.
- 'white_circle_size' (float): The radius of the white circle drawn at each node, as a fraction of 'length'. Default is 0.2.
- 'var_arrowhead' (bool): Whether to scale the arrowhead size with the arrow length, interpolating linearly between 'arrowhead_props_min' and 'arrowhead_props_max'. Default is False.
- 'arrowhead_props_max' (tuple): The arrowhead properties (head size, center size, width) for an arrow of maximum length. Default is (716, 626, 261).
- 'arrowhead_props_min' (tuple): The arrowhead properties (head size, center size, width) at zero arrow length. Default is (478, 418, 200).
- 'linewidth' (float): The line width of the reaction arrows. Default is 1.08.
- 'circle_points' (bool): Whether to draw a filled, colored circle on each node. Default is False.
- 'circle_points_text' (bool): Whether to label each node circle with its lowercased sugar name. Only used when 'circle_points' is True. Default is False.
- 'circle_points_text_size' (float): The font size of the node labels. Default is False, which uses the default label size of 7.
- 'circle_points_size' (float): The radius of the node circles, as a fraction of 'length'. Default is 0.1.
- 'circle_colors' (dict): The color index of each node circle, keyed by sugar name. By default every sugar gets '19', except 'Ido', which gets '23'.
- 'sugar_list' (list): The eight sugars, in the order they are assigned to the node positions. Edges touching the fourth entry are drawn in the background. Default is ['Man','Alt','Tal','Ido','Glc','All','Gal','Gul'].



In [5]:
def _row_option(row, key, default):
    """
    Read an optional per-edge column from one rates_df row.

    Returns `default` when the column is absent or the cell is empty (NaN), so an
    optional column may be filled in for only some of the rows.
    """
    if key not in row.index:
        return default
    value = row[key]
    return default if pd.isna(value) else value


def _as_color_index(value):
    """Return an integral float color (pandas yields e.g. 23.0 for a column with gaps) as an int."""
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return value


def _scaled_arrowhead_props(length, max_length, props_min, props_max):
    """Interpolate the (head size, center size, width) triple linearly by length / max_length."""
    ratio = length / max_length
    return tuple(
        int(lo + ratio * (hi - lo)) for lo, hi in zip(props_min[:3], props_max[:3])
    )


def _directed_arrow(end_a, end_b, target, z, color, dashed, arrowhead_props, linewidth):
    """Arrow with Head3D=end_a and Tail3D=end_b, arrowhead on whichever end is closer to `target`."""
    head_is_closer = distance(*target, *end_a) < distance(*target, *end_b)
    return arrow_cdxml(
        end_a,
        end_b,
        z,
        color=color,
        arrowhead_head="Full" if head_is_closer else False,
        arrowhead_tail=False if head_is_closer else "Full",
        dashed=dashed,
        arrowhead_props=arrowhead_props,
        linewidth=linewidth,
    )


def draw_network(root, rates_df, settings, filename):
    """
    Draw a cube-shaped rate network and save it as a CDXML file.

    The eight sugars of `sugar_list` are placed on a 2D projection of a cube. Each
    row of `rates_df` is one edge, drawn as a box line plus up to two arrows (one
    per direction) whose lengths are proportional to the directional rates.

    Parameters:
    - root (Element): The CDXML template root. It is deep-copied and not modified.
    - rates_df (DataFrame): Edge table with columns 's1', 's2', 'rate', 'sel' and the
      optional columns 'arrow_color', 'box_color', 'dashed_arrow', 'direction'.
      Empty (NaN) cells in an optional column fall back to that column's default.
    - settings (dict): Layout options; every key is optional.
    - filename (str): Path of the CDXML file to write.

    Returns:
    None

    Raises:
    - ValueError: if any 'sel' is 0 (the directional rates would be 0 and infinity),
      or if a 'direction' value is not "forward", "backward" or "both".
    """
    # work on a copy so the shared template stays blank for the next call
    root_cpy = copy.deepcopy(root)
    page = root_cpy.find("page")

    # ---- settings, with defaults for missing keys ----
    x, y = settings.get("coordinates", (700, 700))
    l = settings.get("length", 100)
    max_length = settings["max_length_part"] * l if "max_length_part" in settings else l
    min_length_part = settings.get("min_length_part")
    minlength = None if min_length_part is None else l * min_length_part
    spacing = settings.get("spacing", 5.5)
    margin = settings.get("margin", 5)
    circle_size = settings.get("white_circle_size", 0.2) * l

    var_arrowhead = settings.get("var_arrowhead", False)
    if var_arrowhead:
        arrowhead_props_max = settings.get("arrowhead_props_max", (716, 626, 261))
        arrowhead_props_min = settings.get("arrowhead_props_min", (478, 418, 200))

    linewidth = settings.get("linewidth", 1.08)

    circle_points = bool(settings.get("circle_points", False))
    circle_points_text = False
    if circle_points:
        circle_points_text = settings.get("circle_points_text", False)
        # a falsy size means "use text_cdxml's default size"
        circle_points_text_size = settings.get("circle_points_text_size", False)
        circle_points_size = settings.get("circle_points_size", 0.1) * l
        circle_colors = settings.get(
            "circle_colors",
            {
                "Alt": "19",
                "Gal": "19",
                "Glc": "19",
                "Gul": "19",
                "Ido": "23",
                "Man": "19",
                "Tal": "19",
                "All": "19",
            },
        )

    sugar_list = settings.get(
        "sugar_list", ["Man", "Alt", "Tal", "Ido", "Glc", "All", "Gal", "Gul"]
    )

    # ---- directional rates ----
    df = rates_df.copy()
    if (df["sel"] == 0).any():
        raise ValueError(
            "'sel' must be non-zero: it is the signed ratio between the two directional rates"
        )
    sel_root = np.sqrt(df["sel"].abs())
    # the favored direction gets rate * sqrt(|sel|), the other rate / sqrt(|sel|)
    df["rate_f"] = np.where(df["sel"] >= 0, df["rate"] * sel_root, df["rate"] / sel_root)
    df["rate_b"] = np.where(df["sel"] < 0, df["rate"] * sel_root, df["rate"] / sel_root)

    # the fastest directional rate is drawn at max_length
    maxrate = max(df.rate_f.max(), df.rate_b.max())

    # ---- node positions on a 2D projection of a cube ----
    sqrt2 = np.sqrt(2)
    coords = [
        (x + l / sqrt2, y - l / sqrt2),
        (x + l / sqrt2 + l, y - l / sqrt2),
        (x, y),
        (x + l, y),
        (x + 2 * l / sqrt2, y),
        (x + 2 * l / sqrt2 + l, y),
        (x + l / sqrt2, y + l / sqrt2),
        (x + l / sqrt2 + l, y + l / sqrt2),
    ]
    points = {sugar: coords[i] for i, sugar in enumerate(sugar_list)}

    # ---- z-depths: back boxes < intersection dots < back white boxes < front boxes
    #      < front white boxes < node circles < arrows / node dots / labels ----
    num_back_boxes = 3
    num_intersect = 2
    num_white_boxes_back = num_back_boxes
    num_forward_boxes = len(df) - num_back_boxes
    num_white_boxes_forward = num_forward_boxes
    num_sugar_circles = len(points)
    z_back_boxes = 1
    z_intersect = num_back_boxes + z_back_boxes
    z_white_boxes_back = num_intersect + z_intersect
    z_forward_boxes = num_white_boxes_back + z_white_boxes_back
    z_white_boxes_forward = num_forward_boxes + z_forward_boxes
    z_sugar_circles = num_white_boxes_forward + z_white_boxes_forward
    z_connections = num_sugar_circles + z_sugar_circles

    # ---- edges ----
    for _, row in df.iterrows():
        arrowcolor = _as_color_index(_row_option(row, "arrow_color", "23"))
        boxcolor = _as_color_index(_row_option(row, "box_color", "15"))
        dashed_arrow = _row_option(row, "dashed_arrow", False)
        direction = _row_option(row, "direction", "both")

        # edges touching the fourth sugar are drawn in the background
        back = row["s1"] == sugar_list[3] or row["s2"] == sugar_list[3]

        s1 = points[row["s1"]]
        s2 = points[row["s2"]]

        if back:
            page.append(networkbox_cdxml(s1, s2, z_back_boxes, boxcolor))
            z_back_boxes += 1
        else:
            page.append(networkbox_cdxml(s1, s2, z_forward_boxes, boxcolor))
            z_forward_boxes += 1

        len_rate_f = row["rate_f"] / maxrate * max_length
        len_rate_b = row["rate_b"] / maxrate * max_length

        # Enforce the minimum length BEFORE hiding a direction; clamping afterwards
        # would redraw the hidden arrow at `minlength`.
        if minlength is not None:
            if len_rate_f < minlength:
                len_rate_f = minlength
            if len_rate_b < minlength:
                len_rate_b = minlength

        # a single visible arrow is centered on the edge (no spacing); length 0 hides an arrow
        temp_spacing = spacing
        if direction == "forward":
            len_rate_b = 0
            temp_spacing = 0
        elif direction == "backward":
            len_rate_f = 0
            temp_spacing = 0
        elif direction != "both":
            raise ValueError("direction must be forward, backward, or both")

        arrow_slope = slope(s1[0], s1[1], s2[0], s2[1])
        # perpendicular slope -dx/dy; NaN for horizontal edges, which shift() treats as vertical
        arrow_inv_slope = -1 * slope(s1[1], s1[0], s2[1], s2[0])

        # arrow end points: each arrow is offset half the spacing to either side of the edge
        edge_mid = midpoint(s1, s2)
        midpoint_f = shift(*edge_mid, arrow_inv_slope, temp_spacing / 2)
        f_f = shift(*midpoint_f, arrow_slope, len_rate_f / 2)
        f_b = shift(*midpoint_f, arrow_slope, -len_rate_f / 2)
        midpoint_b = shift(*edge_mid, arrow_inv_slope, -temp_spacing / 2)
        b_f = shift(*midpoint_b, arrow_slope, -len_rate_b / 2)
        b_b = shift(*midpoint_b, arrow_slope, len_rate_b / 2)

        # white line behind the arrows: the longer arrow plus `margin` on each end
        if len_rate_f > len_rate_b:
            wb_f = shift(*midpoint_f, arrow_slope, len_rate_f / 2 + margin)
            wb_b = shift(*midpoint_f, arrow_slope, -len_rate_f / 2 - margin)
        else:
            wb_f = shift(*midpoint_b, arrow_slope, -len_rate_b / 2 - margin)
            wb_b = shift(*midpoint_b, arrow_slope, len_rate_b / 2 + margin)

        if back:
            page.append(whitebox_cdxml(wb_f, wb_b, z_white_boxes_back, color=1))
            z_white_boxes_back += 1
        else:
            page.append(whitebox_cdxml(wb_f, wb_b, z_white_boxes_forward, color=1))
            z_white_boxes_forward += 1

        # forward arrow points at s2, backward arrow points at s1; zero-length arrows are skipped
        for length, end_a, end_b, target in (
            (len_rate_f, f_f, f_b, s2),
            (len_rate_b, b_f, b_b, s1),
        ):
            if not length > 0:
                continue
            if var_arrowhead:
                arrowheadprops = _scaled_arrowhead_props(
                    length, max_length, arrowhead_props_min, arrowhead_props_max
                )
            else:
                arrowheadprops = (716, 626, 261)
            page.append(
                _directed_arrow(
                    end_a,
                    end_b,
                    target,
                    z_connections,
                    arrowcolor,
                    dashed_arrow,
                    arrowheadprops,
                    linewidth,
                )
            )
            z_connections += 1

    # ---- nodes ----
    for sugar, coord in points.items():
        # white circle at each node hides the line ends
        page.append(circle_cdxml(coord, circle_size, 1, z_sugar_circles))
        z_sugar_circles += 1

        if circle_points:
            page.append(
                circle_cdxml(coord, circle_points_size, circle_colors[sugar], z_connections)
            )
            z_connections += 1

            if circle_points_text:
                if circle_points_text_size:
                    label = text_cdxml(
                        sugar,
                        coord,
                        z_connections,
                        color=1,
                        text_lowercase=True,
                        size=circle_points_text_size,
                    )
                else:
                    label = text_cdxml(
                        sugar, coord, z_connections, color=1, text_lowercase=True
                    )
                page.append(label)
                z_connections += 1

    # ---- small white circles where back edges cross front edges in the projection ----
    intersect_circle_size = l * 0.02
    for a, b, c, d in ((0, 4, 1, 3), (4, 6, 3, 7)):
        coord = intersection(
            *points[sugar_list[a]],
            *points[sugar_list[b]],
            *points[sugar_list[c]],
            *points[sugar_list[d]],
        )
        page.append(circle_cdxml(coord, intersect_circle_size, 1, z_intersect))
        z_intersect += 1

    # save the file
    tree = ET.ElementTree(root_cpy)
    tree.write(filename, encoding="utf-8", xml_declaration=True)

## Draw networks

In [6]:
df = pd.read_csv('alpha.csv')

# The 0.95 factor shrinks the cube edge (130 -> 123.5). Dividing the minimum
# arrow length and white-circle radius by 0.95 keeps them (up to rounding) at
# their values for a 130-unit edge: 0.04 * 130 and 0.2 * 130.
settings = dict(
    coordinates=(600, 600),
    length=130*0.95,
    max_length_part=0.6,
    min_length_part=0.04/0.95,
    spacing=6,
    margin=4,
    white_circle_size=0.2/0.95,
    sugar_list=['Man', 'Alt', 'Tal', 'Ido', 'Glc', 'All', 'Gal', 'Gul'],
    var_arrowhead=True,
    arrowhead_props_max=(716, 626, 261),
    arrowhead_props_min=(478, 418, 200),
)

draw_network(root=root, rates_df=df, settings=settings, filename='alpha.cdxml')

In [7]:
df = pd.read_csv('beta.csv')

# Smaller cube with colored, labelled node circles; arrowheads scale with arrow length.
settings = dict(
    coordinates=(600, 600),
    length=70,
    max_length_part=0.6,
    min_length_part=0.075,
    spacing=6,
    margin=4,
    white_circle_size=0.12,
    circle_points=True,
    circle_points_text=True,
    circle_points_text_size=6.2,
    circle_points_size=0.1,
    # color index of each node circle
    circle_colors={
        "Alt": "12",
        "Gal": "12",
        "Glc": "12",
        "Gul": "12",
        "Ido": "13",
        "Man": "12",
        "Tal": "4",
        "All": "19",
    },
    sugar_list=['Man', 'Alt', 'Tal', 'Ido', 'Glc', 'All', 'Gal', 'Gul'],
    var_arrowhead=True,
    arrowhead_props_max=(716, 626, 261),
    arrowhead_props_min=(478, 418, 200),
)

draw_network(root=root, rates_df=df, settings=settings, filename='beta.cdxml')