In [1]:
import networkx as nx
import sys
import numpy as np
from abc import ABC, abstractmethod
from collections import namedtuple
from copy import deepcopy
from functools import reduce
from typing import Callable, Any
from itertools import accumulate

In [2]:
from systemflow.node import *

In [3]:
from systemflow.auxtypes import is_proportion

In [4]:
class PositionSample(Mutate):
    def __init__(self, name: str = "PositionSample", relevancy_f: Callable = lambda x: 1.0): 
        #"Secret" sample function which determines which locations are of interest
        self.relevancy_f = relevancy_f

        #Originator node, no input fields
        #Input message fields
        msg_fields = VarCollection()
    
        #Input message properties
        msg_properties = VarCollection()

        #Input host parameters
        host_parameters = VarCollection(position = "position (mm,mm)",
                                        last_position = "last position (mm,mm)",
                                        move_rate = "move rate (mm/s)",
                                        dwell_time = "dwell time (s)",)
        
        inputs = MutationInputs(msg_fields, msg_properties, host_parameters)

        #Output message fields
        msg_fields = VarCollection(relevancy = "relevancy (%)",
                                   position = "position (mm,mm)",
                                   move_latency = "movement latency (s)",)

        #Output message properties
        msg_properties = VarCollection()

        #Output host properties
        host_properties = VarCollection(last_positioin = "last_position (mm,mm)",)
        outputs = MutationOutputs(msg_fields, msg_properties, host_properties)

        super().__init__(name, inputs, outputs)

    def transform(self, message: Message, component: Component) -> tuple[dict, dict, dict]:
        #predict the relevancy of the collected data based on position
        position = component.parameters[self.inputs.host_parameters.position]
        last_position = component.parameters[self.inputs.host_parameters.last_position]
        relevancy = self.relevancy_f(position)

        x_vec = position[0] - last_position[0]
        y_vec = position[1] - last_position[1]
        distance = np.linalg.norm([x_vec, y_vec])
        movement_time = distance / component.parameters[self.inputs.host_parameters.move_rate]
        latency = movement_time + component.parameters[self.inputs.host_parameters.dwell_time]

        msg_fields = {self.outputs.msg_fields.relevancy: relevancy,
                    self.outputs.msg_fields.position: position,
                    self.outputs.msg_fields.move_latency: latency,}
        
        msg_props = {}
        
        host_props = {self.inputs.host_parameters.last_position: position,}

        
        return msg_fields, msg_props, host_props


In [5]:
ps = PositionSample()

In [6]:
msg0 = Message({}, {})

In [7]:
sample_stage = Component("Sample Stage",
                    [PositionSample(),],
                    {"position (mm,mm)": [0.0, 0.0],
                     "last position (mm,mm)": [0.0, 0.0],
                     "move rate (mm/s)": 100,
                     "dwell time (s)": 1e-3,},
                     {})

In [8]:
msg1, _ = ps(msg0, sample_stage)

In [9]:
msg1

Message(fields={'relevancy (%)': 1.0, 'position (mm,mm)': [0.0, 0.0], 'movement latency (s)': np.float64(0.001)}, properties={})

In [10]:
ci = CollectImage()

In [11]:
vc = collect_parameters([CollectImage()])

ci_host = Component("Image sensor",
                    [CollectImage(),],
                    parameters = {vc.resolution: (4000, 6000),
                     vc.bitdepth: 16,
                     vc.readout: 1e-3,
                     vc.pixelenergy: 1e-3,})

In [12]:
msg2, _ = ci(msg1, ci_host)

In [13]:
msg2

Message(fields={'relevancy (%)': 1.0, 'position (mm,mm)': [0.0, 0.0], 'movement latency (s)': np.float64(0.001), 'image data (B)': np.float64(48000000.0), 'readout latency (s)': 0.001}, properties={'resolution (n,n,n)': (4000, 6000, 16)})

In [14]:
from systemflow.metrics import serial_parallel_ops

In [15]:
class FlatFieldCorrection(Mutate):
    """
    Apply flat-field correction
    """
    def __init__(self, name: str = "Flat-field Correction"):
        #Input message fields
        msg_fields = VarCollection()
    
        #Input message properties
        msg_properties = VarCollection(resolution = "resolution (n,n,n)",)

        #Input host parameters
        host_parameters = VarCollection(op_latency = "op latency (s)",
                                        parallelism = "parallelism (%)",)
        
        inputs = MutationInputs(msg_fields, msg_properties, host_parameters)

        #Output message fields
        msg_fields = VarCollection(ff_latency = "flatfield latency (s)",)

        #Output message properties
        msg_properties = VarCollection()

        #Output host properties
        host_properties = VarCollection(ops = "flatfield ops (n,n)",)

        outputs = MutationOutputs(msg_fields, msg_properties, host_properties)

        super().__init__(name, inputs, outputs)

    def transform(self, message: Message, component: 'Component') -> tuple[dict, dict, dict]:
        #access the required fields/properties/parameters
        resolution = message.properties[self.inputs.msg_properties.resolution]
        ops = np.prod(resolution)
        parallelism = component.parameters[self.inputs.host_parameters.parallelism]
        serial_ops, parallel_ops = serial_parallel_ops(ops, parallelism)
        op_latency = component.parameters[self.inputs.host_parameters.op_latency]
        latency = serial_ops * op_latency
       
        #create the new fields in the message
        msg_fields = {self.outputs.msg_fields.ff_latency: latency}
        msg_props = {}

        #create the new properties in the host
        resolution = message.properties[self.inputs.msg_properties.resolution]
        host_props = {self.outputs.host_properties.ops: (serial_ops, parallel_ops),}

        return msg_fields, msg_props, host_props

In [16]:
ffc = FlatFieldCorrection()

In [17]:
vc = collect_parameters([ffc])

In [18]:
print(list(vc.__dict__.keys()))

['op_latency', 'parallelism']


In [19]:
ffc_host = Component("Preprocessor 1",
                    [FlatFieldCorrection(),],
                    parameters = {vc.op_latency: 1e-6,
                                  vc.parallelism: 0.75,})

In [20]:
msg3, _ = ffc(msg2, ffc_host)

In [21]:
class MaskCorrection(Mutate):
    """
    Correct pixel values in masked areas
    """
    def __init__(self, name: str = "Mask Correction"):
        #Input message fields
        msg_fields = VarCollection(image_data = "image data (B)")
    
        #Input message properties
        msg_properties = VarCollection(resolution = "resolution (n,n,n)",)

        #Input host parameters
        host_parameters = VarCollection(mask_proportion = "masking proportion (%)",
                                        op_latency = "op latency (s)",
                                        parallelism = "parallelism (%)",
                                        kernel_size = "kernel size (%,%)",)
        
        inputs = MutationInputs(msg_fields, msg_properties, host_parameters)

        #Output message fields
        msg_fields = VarCollection(mask_correction = "masking corrections (B)",
                                   mask_latency = "masking latency (s)",)

        #Output message properties
        msg_properties = VarCollection()

        #Output host properties
        host_properties = VarCollection(ops = "masking operations (n,n)",)
        outputs = MutationOutputs(msg_fields, msg_properties, host_properties)

        super().__init__(name, inputs, outputs)

    def transform(self, message: Message, component: 'Component') -> tuple[dict, dict, dict]:
        #access the required fields/properties/parameters
        image_data = message.fields[self.inputs.msg_fields.image_data]
        resolution = message.properties[self.inputs.msg_properties.resolution]
        resolution_x, resolution_y, _ = resolution
        
        masking = component.parameters[self.inputs.host_parameters.mask_proportion]
        op_latency = component.parameters[self.inputs.host_parameters.op_latency]
        parallelism = component.parameters[self.inputs.host_parameters.parallelism]
        kernel_x, kernel_y = component.parameters[self.inputs.host_parameters.kernel_size]

        is_proportion(masking)
        is_proportion(kernel_x)
        is_proportion(kernel_y)
        #calculate new fields and properties
        #use an X,Y kernel around every masked section (assumption)
        kernel_x = int(resolution_x * kernel_x)
        kernel_y = int(resolution_y * kernel_y)
        center_pixels = int(masking * np.prod(resolution))
        ops = kernel_x * kernel_y * center_pixels
        serial_ops, parallel_ops = serial_parallel_ops(ops, parallelism)
        latency = serial_ops * op_latency

        mask_data = image_data * masking
       
        #create the new fields in the message
        msg_fields = {self.outputs.msg_fields.mask_correction: mask_data,
                      self.outputs.msg_fields.mask_latency: latency,}
        msg_props = {}

        #create the new properties in the host
        host_props = {self.outputs.host_properties.ops: (serial_ops, parallel_ops),}

        return msg_fields, msg_props, host_props

In [22]:
mc = MaskCorrection()

In [23]:
vc = collect_parameters([mc])

In [24]:
print(list(vc.__dict__.keys()))

['mask_proportion', 'op_latency', 'parallelism', 'kernel_size']


In [25]:
mc_host = Component("Preprocessor 2",
                    [MaskCorrection(),],
                    parameters = {vc.op_latency: 1e-6,
                                  vc.parallelism: 0.75,
                                  vc.kernel_size: (0.02, 0.02),
                                  vc.mask_proportion: 0.05,})

In [26]:
msg3.properties

{'resolution (n,n,n)': (4000, 6000, 16)}

In [27]:
msg4, _ = mc(msg3, mc_host)

In [28]:
msg4.fields

{'relevancy (%)': 1.0,
 'position (mm,mm)': [0.0, 0.0],
 'movement latency (s)': np.float64(0.001),
 'image data (B)': np.float64(48000000.0),
 'readout latency (s)': 0.001,
 'flatfield latency (s)': np.float64(0.00013998542046322332),
 'masking corrections (B)': np.float64(2400000.0),
 'masking latency (s)': np.float64(0.000655229007050176)}

In [29]:
fft = FourierTransform()

In [30]:
vc = collect_parameters([fft])

In [31]:
print(list(vc.__dict__.keys()))

['parallelism', 'op_latency']


In [32]:
fft_host = Component("Processor",
                    [FourierTransform(),],
                    parameters = {vc.parallelism: 0.75,
                                  vc.op_latency: 1e-7,})

In [33]:
msg4.properties

{'resolution (n,n,n)': (4000, 6000, 16)}

In [50]:
msg5, f2 = fft(msg4, fft_host)

In [53]:
f2

{'fft ops (n,n)': (np.float64(115.36392103498072),
  np.float64(1535359.3065493205))}

In [37]:
msg5.fields

{'relevancy (%)': 1.0,
 'position (mm,mm)': [0.0, 0.0],
 'movement latency (s)': np.float64(0.001),
 'image data (B)': np.float64(48000000.0),
 'readout latency (s)': 0.001,
 'flatfield latency (s)': np.float64(0.00013998542046322332),
 'masking corrections (B)': np.float64(2400000.0),
 'masking latency (s)': np.float64(0.000655229007050176),
 'frequency data (B)': 384000000,
 'fft latency (s)': np.float64(1.1536392103498072e-05)}

In [38]:
msg5.properties

{'resolution (n,n,n)': (4000, 6000, 16)}

In [40]:
nodes = [sample_stage, ci_host, ffc_host, mc_host, fft_host]

In [41]:
links = [DefaultLink("Sample Stage -> Image sensor",
                     "Sample Stage",
                     "Image sensor"),
        DefaultLink("Image sensor -> Preprocessor 1",
                     "Image sensor",
                     "Preprocessor 1"),
        DefaultLink("Preprocessor 1 -> Preprocessor 2", 
                    "Preprocessor 1", 
                    "Preprocessor 2"),
        DefaultLink("Preprocessor 2 -> Processor",
                    "Preprocessor 2",
                    "Processor"),]

In [60]:
class FFTPower(Metric):
    def __init__(self):
        super().__init__("FFT Power", 
                         {},
                         {},
                         {"Processor": FourierTransform().outputs.host_properties.ops},)
        
    def metric(self, graph: ExecutionGraph):
        host_props = graph.get_all_node_properties()
        ops = host_props["Processor"][FourierTransform().outputs.host_properties.ops]
        power = np.prod(ops) * 1e-12

        metrics = {"power": power,}
        return metrics
    

In [61]:
bcdi_graph = ExecutionGraph("BCDI Experiment", nodes, links, [FFTPower(),])

In [62]:
g2 = bcdi_graph(True)

Executing on node  Processor
Executing on node  Preprocessor 2
Executing on node  Preprocessor 1
Executing on node  Image sensor
Executing on node  Sample Stage


In [63]:
g2

<systemflow.node.ExecutionGraph at 0x125fe7af0>

In [66]:
g2.metric_values[0]["power"]

np.float64(0.00017712506980107855)

In [46]:
g2.root_node.output_msg.fields

{'masking corrections (B)': np.float64(2400000.0),
 'position (mm,mm)': [0.0, 0.0],
 'relevancy (%)': 1.0,
 'masking latency (s)': np.float64(0.000655229007050176),
 'movement latency (s)': np.float64(0.001),
 'flatfield latency (s)': np.float64(0.00013998542046322332),
 'readout latency (s)': 0.001,
 'image data (B)': np.float64(48000000.0),
 'frequency data (B)': 384000000,
 'fft latency (s)': np.float64(1.1536392103498072e-05)}

In [48]:
all_params = g2.get_all_node_parameters()

In [49]:
all_params

{'Processor': {'parallelism (%)': 0.75, 'op latency (s)': 1e-07},
 'Preprocessor 2': {'op latency (s)': 1e-06,
  'parallelism (%)': 0.75,
  'kernel size (%,%)': (0.02, 0.02),
  'masking proportion (%)': 0.05},
 'Preprocessor 1': {'op latency (s)': 1e-06, 'parallelism (%)': 0.75},
 'Image sensor': {'resolution (n,n)': (4000, 6000),
  'bit depth (n)': 16,
  'readout latency (s)': 0.001,
  'pixel energy (J)': 0.001},
 'Sample Stage': {'position (mm,mm)': [0.0, 0.0],
  'last position (mm,mm)': [0.0, 0.0],
  'move rate (mm/s)': 100,
  'dwell time (s)': 0.001}}