## Example - Ptychography

In this example, we'll modify the BCDI setup so that instead of collecting and reconstructing one image, many diffraction patterns are collected over the sample area. Each sample image overlaps with the previous one, providing redundant information that aids the phase reconstruction and allows the imaging field to be scaled arbitrarily large. The field is then limited by stage movement, sample durability, and other physical factors rather than by the available size of an optical lens.

The overlapping images can be collected in several ways. The simplest is to step-and-repeat the collection of images over a pre-defined area, then do the post-processing offline. However, with large samples and/or small regions of interest, this may not be the most effective way to image. More advanced setups work in a "streaming" mode, where the reconstructed real-space image is updated at each step and the process can be steered with human feedback. To increase speed further, human feedback can be replaced with automated feature detection and search strategies.
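
As a rough illustration of the step-and-repeat approach, the sketch below generates a raster of stage positions in which neighboring exposures overlap by a fixed fraction. It does not use systemflow; the function name, the probe size of 0.01 mm, and the 20 x 20 grid are assumptions chosen for illustration only.

```python
import numpy as np

def raster_scan_positions(x_steps, y_steps, probe_size_mm, overlap):
    """Return an (x_steps * y_steps, 2) array of stage positions in mm.

    Hypothetical helper: adjacent positions are spaced so that neighboring
    illuminated regions share the given fractional overlap.
    """
    if not 0.0 <= overlap < 1.0:
        raise ValueError("overlap must be in [0, 1)")
    # a larger overlap means a smaller step between exposures
    step = probe_size_mm * (1.0 - overlap)
    xs = np.arange(x_steps) * step
    ys = np.arange(y_steps) * step
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
    return np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

# assumed values: 20 x 20 grid, 0.01 mm probe, 40% overlap
positions = raster_scan_positions(x_steps=20, y_steps=20,
                                  probe_size_mm=0.01, overlap=0.40)
print(positions.shape)
```

With these assumed values the stage advances 0.006 mm per step, so the scanned field grows with the number of steps rather than with the size of any single exposure.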

We'll start by looking at the simplest setup: offline processing.

In [1]:
import networkx as nx
import sys
import numpy as np
from abc import ABC, abstractmethod
from collections import namedtuple
from copy import deepcopy
from functools import reduce
from typing import Callable, Any
from itertools import accumulate

In [2]:
from systemflow.node import *
from systemflow.mutations import *
from systemflow.metrics import *
from systemflow.auxtypes import is_proportion
from systemflow.xrs import *

The basic setup for ptychography is the same as BCDI - a sample stage, far-field diffraction pattern sensor, and analysis computer.

In [3]:
sample_stage_mutations = [PositionSample(),]
vc_sample = collect_parameters(sample_stage_mutations)

sample_stage = Component("Sample Stage",
                    [PositionSample(),],
                    {vc_sample.last_position: [0.0, 0.0],
                     vc_sample.position: [0.0, 0.0],
                     vc_sample.move_rate: 100,
                     vc_sample.settle_time: 1e-3,},
                     {})

In [4]:
detector_mutations = [CollectImage(),]
vc_detector = collect_parameters(detector_mutations)

detector_host = Component("Image sensor",
                    [CollectImage(),],
                    parameters = {vc_detector.resolution: (2000, 2000),
                     vc_detector.bitdepth: 16,
                     vc_detector.readout: 1e-3,
                     vc_detector.pixelenergy: 1e-3,
                     vc_detector.sample_rate: 10e3,})

In [5]:
sampling_nodes = [sample_stage, detector_host]
sampling_links = [DefaultLink("Sample Stage -> Image sensor",
                              sample_stage.name,
                              detector_host.name),]
sampling_metrics = [TotalLatency(),]

sampling_exg = ExecutionGraph(name="Sampling Process",
                              nodes=sampling_nodes,
                              links=sampling_links,
                              metrics=sampling_metrics)

In [6]:
sampling_exg_2 = sampling_exg()

In [7]:
sampling_exg_2.metric_values

{'total latency (s)': np.float64(0.002)}

In [8]:
msg1 = sampling_exg_2.get_output_msg()

In [9]:
msg1.fields

{'movement latency (s)': np.float64(0.001),
 'position (mm,mm)': [0.0, 0.0],
 'relevancy (%)': 1.0,
 'image data (B)': np.float64(8000000.0),
 'readout latency (s)': 0.001}

In [10]:
msg1.properties

{'resolution (n,n)': (2000, 2000),
 'bitdepth (n)': 16,
 'sample rate (Hz)': 10000.0,
 'images (n)': 1}
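
The reported values can be checked by hand. The sketch below reproduces them from the component parameters, assuming that the image size is simply pixels times bytes per pixel, and that the movement latency reduces to the settle time when the stage does not travel (position equals last position). These are assumptions about the model, not a description of the library's internals.

```python
import math

resolution = (2000, 2000)   # pixels
bitdepth = 16               # bits per pixel
settle_time_s = 1e-3
readout_s = 1e-3

# 2000 * 2000 pixels at 2 bytes each
image_bytes = resolution[0] * resolution[1] * bitdepth / 8

# no travel in this run, so only the settle time is assumed to contribute
movement_s = settle_time_s
total_latency_s = movement_s + readout_s

print(image_bytes)
print(total_latency_s)
```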

In [11]:
#emulate the receipt of an image
network_mutations = [InputMessage(),]
vc_network = collect_parameters(network_mutations)

network_host = Component("Receive images", 
                         network_mutations,
                         {vc_network.input_message: msg1,})



In [12]:
storage_mutations = [StoreImage(),]
vc_store = collect_parameters(storage_mutations)
storage_host = Component("Store images", 
                         storage_mutations,
                         {vc_store.stored_images: 0,
                         vc_store.stored_data: 0,
                         vc_store.storage_rate: 1e9})

In [13]:
storage_nodes = [network_host, storage_host]
storage_links = [DefaultLink("Network -> Storage",
                              network_host.name,
                              storage_host.name),]
storage_metrics = [TotalLatency(),]

storage_exg = ExecutionGraph(name="Storage Process",
                              nodes=storage_nodes,
                              links=storage_links,
                              metrics=storage_metrics)

In [14]:
vc_store.storage_rate

'disk storage rate (B/s)'

In [15]:
storage_host.parameters

{'stored images (n)': 0,
 'stored data (B)': 0,
 'disk storage rate (B/s)': 1000000000.0}

In [16]:
storage_exg2 = storage_exg()

In [17]:
msg2 = storage_exg2.get_output_msg()

In [20]:
msg2.fields

{'movement latency (s)': np.float64(0.001),
 'position (mm,mm)': [0.0, 0.0],
 'relevancy (%)': 1.0,
 'image data (B)': np.float64(8000000.0),
 'readout latency (s)': 0.001,
 'stored data (B)': np.float64(8000000.0)}

In [21]:
msg2.properties

{'images (n)': 1,
 'sample rate (Hz)': 10000.0,
 'resolution (n,n)': (2000, 2000),
 'bitdepth (n)': 16,
 'stored images (n)': 1}

In [19]:
storage_exg2.root_node.properties

{'storage latency (s)': np.float64(0.008)}
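
The storage latency above is consistent with dividing the image size by the disk storage rate. The sketch below repeats that arithmetic and extends it to a hypothetical 20 x 20 scan; the scan size is an assumption for illustration, and the calculation ignores any file-system or network overhead.

```python
image_bytes = 8_000_000     # one 2000 x 2000, 16-bit image
storage_rate = 1e9          # disk storage rate (B/s)

# time to write a single image
storage_latency_s = image_bytes / storage_rate

# assumed scan of 20 x 20 positions, written sequentially
n_images = 20 * 20
scan_bytes = n_images * image_bytes
scan_storage_s = scan_bytes / storage_rate

print(storage_latency_s, scan_bytes, scan_storage_s)
```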

In [23]:
analysis_mutations = [InputMessage(), PhaseReconstruction3D(),]
vc_analysis = collect_parameters(analysis_mutations)
analysis_host = Component("offline analysis",
                          analysis_mutations,
                          {vc_analysis.op_latency: 1e-5,
                           vc_analysis.parallelism: 0.70,
                           vc_analysis.overlap: 0.40,
                           vc_analysis.iterations: 20,
                           vc_analysis.input_message: msg2})


In [24]:
analysis_exg = ExecutionGraph("Offline reconstruction",
                              [analysis_host,],
                              links = [],
                              metrics = [])

In [25]:
analysis_exg2 = analysis_exg()

AssertionError: Phase reconstruction transform's properties not found in incoming message: xy_images (n,n)

We've now set up the components necessary to collect a single image that will become part of the reconstruction. The analysis step fails because it expects a full grid of images, so collection becomes a multi-step process which we must model. We do this by incorporating a new element, a System, which controls the flow of the different execution graphs:

In [None]:
class OfflinePtychography(System):
    def __init__(self, name, x_steps: int = 20, y_steps: int = 20, iter = 0, execution_history = ...):
        exec_graphs = {"sampling": sampling_exg,
                        "storage": storage_exg,
                        "analysis": analysis_exg,}
        super().__init__(name, exec_graphs, iter, execution_history)
        self.x_steps = x_steps
        self.y_steps = y_steps

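    def flow_control(self):
        # Not yet implemented: this should step through the scan grid, running the
        # sampling and storage graphs at each position and the analysis graph last.
        raise NotImplementedError
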

In [23]:
# pickup - get control flow through System
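
Until the System's control flow is filled in, the intended offline sequence can be sketched in plain Python. The example below does not use the systemflow API; `offline_scan` and the three stand-in callables are hypothetical, and the per-image latencies are the single-image values reported earlier.

```python
from typing import Callable

def offline_scan(x_steps: int, y_steps: int,
                 sample: Callable[[int, int], dict],
                 store: Callable[[dict], dict],
                 analyze: Callable[[list], dict]) -> dict:
    """Collect and store every image first, then analyze them all at once."""
    stored = []
    for ix in range(x_steps):
        for iy in range(y_steps):
            msg = sample(ix, iy)        # move the stage and read the sensor
            stored.append(store(msg))   # write the image to disk
    return analyze(stored)              # offline reconstruction

# Stand-ins for the three execution graphs (hypothetical)
def fake_sample(ix, iy):
    return {"position": (ix, iy), "latency": 0.002}

def fake_store(msg):
    return {**msg, "latency": msg["latency"] + 0.008}

def fake_analyze(msgs):
    return {"images": len(msgs),
            "acquisition latency (s)": sum(m["latency"] for m in msgs)}

result = offline_scan(20, 20, fake_sample, fake_store, fake_analyze)
print(result)
```

The point of the sketch is the ordering: analysis only runs once the whole grid has been collected, which is what distinguishes the offline setup from a streaming one.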

In [None]:
cpu_mutations = [FlatFieldCorrection(),
                MaskCorrection(),
                PhaseReconstruction3D(),]

vc_cpu = collect_parameters(cpu_mutations)

In [None]:
cpu_host = Component("Processor",
                    cpu_mutations,
                    # op_latency was listed twice (1e-6, then 1e-7); only the last value took effect
                    parameters = {vc_cpu.op_latency: 1e-7,
                                  vc_cpu.parallelism: 0.75,
                                  vc_cpu.kernel_size: (0.02, 0.02),
                                  vc_cpu.mask_proportion: 0.05,
                                  vc_cpu.iterations: 20,
                                  vc_cpu.overlap: 0.40,})

In [20]:
cpu_host2 = cpu_host()

TypeError: Component.__call__() missing 1 required positional argument: 'exg'

In [11]:
# vc_detector holds the CollectImage parameter names collected earlier
ci_host = Component("Image sensor",
                    [CollectImage(),],
                    parameters = {vc_detector.resolution: (2000, 2000),
                     vc_detector.bitdepth: 16,
                     vc_detector.readout: 1e-3,
                     vc_detector.pixelenergy: 1e-3,
                     vc_detector.sample_rate: 10e3,})

In [None]:
nodes = [sample_stage, ci_host, cpu_host]

In [None]:
links = [DefaultLink("Sample Stage -> Image sensor",
                     "Sample Stage",
                     "Image sensor"),
        DefaultLink("Image sensor -> Processor",
                     "Image sensor",
                     "Processor"),]

We'll also define metrics that summarize the host properties and message fields predicted by the ExecutionGraph that we want to measure:

In [37]:
class ReconstructionPower(Metric):
    def __init__(self):
        super().__init__("Phase reconstruction power", 
                         [],
                         [PhaseReconstruction2D().outputs.host_properties.ops],)
        
    def metric(self, message: Message, properties: dict):
        matches = self.graph_matches(properties)
        power = np.prod(matches[0]) * 1e-8
        metrics = {"reconstruction power (W)": power,}
        
        return metrics
    

In [38]:
class TotalOps(Metric):
    def __init__(self):
        super().__init__("Total operations", 
                         [],
                         [Regex(r"ops \(n,n\)"),],)
        
    def metric(self, message: Message, properties: dict):
        matches = self.graph_matches(properties)
        ops = np.sum([np.prod(op) for op in matches])
        metrics = {"total ops (n)": ops,}
        
        return metrics
    

In [39]:
class TotalLatency(Metric):
    def __init__(self):
        super().__init__("Total latency", 
                         [Regex(r"latency \(s\)"),],
                         [],)
        
    def metric(self, message: Message, properties: dict):
        matches = self.message_matches(message)
        total = np.sum(matches)
        metrics = {"total latency (s)": total,}
        
        return metrics
    

In [40]:
bcdi_graph = ExecutionGraph("BCDI Experiment", nodes, links, [ReconstructionPower(), TotalOps(), TotalLatency()])

In [41]:
g2 = bcdi_graph()

In [42]:
g2

<systemflow.node.ExecutionGraph at 0x159ad2110>

In [43]:
g2.metric_values

{'reconstruction power (W)': np.float64(6.08164799306237),
 'total ops (n)': np.float64(612164799.306237),
 'total latency (s)': np.float64(0.0020335507624599796)}

In [44]:
g2.get_all_node_parameters()

{'Processor': {'op latency (s)': 1e-07,
  'parallelism (%)': 0.75,
  'kernel size (%,%)': (0.02, 0.02),
  'masking proportion (%)': 0.05,
  'iterations (n)': 20},
 'Image sensor': {'resolution (n,n)': (2000, 2000),
  'bit depth (n)': 16,
  'readout latency (s)': 0.001,
  'pixel energy (J)': 0.001,
  'sample rate (Hz)': 10000.0},
 'Sample Stage': {'position (mm,mm)': [0.0, 0.0],
  'last position (mm,mm)': [0.0, 0.0],
  'move rate (mm/s)': 100,
  'settle time (s)': 0.001}}

And we'll set up one experiment to sweep over the sensor resolution:

In [45]:
def sweep_resolution(resolution: tuple, exg: ExecutionGraph):
    # override the image sensor's resolution, leaving all other parameters unchanged
    new_params = {"Image sensor": {vc_detector.resolution: resolution,}}
    # we can simply call an existing graph with new parameters
    new_exg = exg.with_updated_parameters(new_params)()

    power = new_exg.metric_values["reconstruction power (W)"]
    ops = new_exg.metric_values["total ops (n)"]
    latency = new_exg.metric_values["total latency (s)"]
    return power, ops, latency

In [46]:
# hold the aspect ratio fixed and scale the resolution up and down
resolutions = [(s * np.array((2000, 1400))).astype(int) for s in np.linspace(start=0.8, stop=6.0, num=101)]
megapixels = [np.prod(r)/1e6 for r in resolutions]

In [47]:
metrics = [sweep_resolution(r, g2) for r in resolutions]

In [48]:
import plotly.graph_objects as go

Now we can predict the overall power and latency for a single image as a function of sensor size in megapixels: power goes up linearly, and with a parallel implementation of the algorithms, latency goes up sub-linearly:

In [49]:
fig = go.Figure()

fig.add_trace(go.Scatter(
    x = megapixels,
    y = [m[0] for m in metrics],))

fig.update_layout(
    title_text = "Total Processing Power by Sensor Pixels",
    xaxis_title="Sensor Pixels (MP)",
    yaxis_title="Total Power (W)",
)
fig.show()

In [50]:
fig = go.Figure()

fig.add_trace(go.Scatter(
    x = megapixels,
    y = [m[2] for m in metrics],
    name = "latency (s)",))

fig.update_layout(
    title_text = "Total Processing Latency by Sensor Pixels",
    xaxis_title="Sensor Pixels (MP)",
    yaxis_title="Total Latency (s)",
)
fig.show()
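
One way to check a linear or sub-linear trend numerically, rather than by eye, is to fit a straight line in log-log space and read off the slope. The sketch below uses synthetic data, not the sweep results above, and `scaling_exponent` is a hypothetical helper; a slope near 1 indicates linear growth and a slope below 1 indicates sub-linear growth.

```python
import numpy as np

def scaling_exponent(x, y):
    """Estimate k in y ~ c * x**k from a least-squares fit in log-log space."""
    slope, _intercept = np.polyfit(np.log(x), np.log(y), deg=1)
    return slope

# synthetic curves for illustration only
mp = np.linspace(2.0, 100.0, 50)
linear = 3.0 * mp            # grows linearly with pixel count
sublinear = 0.5 * mp ** 0.6  # grows more slowly than linearly

print(scaling_exponent(mp, linear))
print(scaling_exponent(mp, sublinear))
```

If a curve includes a fixed offset, such as constant movement and readout latency, subtracting that offset before fitting gives a cleaner estimate of the exponent.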

We'll expand on this functionality in the case where many images are used to do the reconstruction, as is the case in ptychography, tomography, and laminography.