# pFGW Framework for Feature Matching & Tracking

In [1]:
import sys
import os
import networkx as nx
import gc
import time

sys.path.append("../GWMT/")

In [2]:
from GWMT import *
import readMergeTree as rmt

## Data Input 

This example shows how to read our preset datasets. Our merge tree data is split into two files: ``treeEdges_monoMesh_*.txt`` and ``treeNodes_monoMesh_*.txt``, which hold the tree edges and the tree nodes, respectively.

To read your own dataset, store each tree in an ``nx.Graph`` object. Every node needs attributes for its spatial coordinates (e.g., "x", "y"), its scalar value (e.g., "height"), and its critical type (0: minimum, 1: saddle, 2: maximum; this encoding is fixed). 
You also need to provide the node ID of the tree's root.

The tree data should be stored in a ``GWMergeTree`` object.

==========================================================

Our tree edge data format:

a_0, b_0  
a_1, b_1  
...  
a_{|E|-1}, b_{|E|-1}

Each row gives the indices of the two nodes connected by the edge. Edges are undirected.

==========================================================

Our tree node data format:

x_0, y_0, z_0, scalar_0, type_0  
x_1, y_1, z_1, scalar_1, type_1  
...  
x_{|V|-1}, y_{|V|-1}, z_{|V|-1}, scalar_{|V|-1}, type_{|V|-1}

Each row has five components: the "x", "y", "z" coordinates, the scalar value, and the critical point type for the node.


In [3]:
# Configuration for the Juelich cloud datasets loaded below.
# Track maxima only (forwarded to GWTracking as both `maxima_only` and `tracking_maxima_only`).
maxima_only = True

# Scalar threshold per time-of-day period. It is passed to `rmt.get_trees(..., threshold=...)`
# when loading, and used as `stop_saddle_val` during merge tree simplification.
thres_dict_by_time = {
    "morning": 9,
    "afternoon": 10,
    "late-afternoon": 9
}

# Inclusive [start, end] bounds, in HHMM, mapping a file's time stamp to a period name.
# The trailing comments appear to be the first/last timestep index of each period — TODO confirm.
time_period = {
    "morning": [600, 900],  # 0, 36
    "afternoon": [901, 1500],  # 37, 108
    "late-afternoon": [1501, 1800],  # 109, 
}

# Every period returned by get_time_str() is looked up in thres_dict_by_time, so the keys must agree.
assert set(thres_dict_by_time) == set(time_period), "thres_dict_by_time and time_period must have the same keys"

def get_time_str(hrtime, periods=None):
    """Map an ``HHMM`` time stamp to the name of the period containing it.

    Parameters
    ----------
    hrtime : str or int
        Time of day as ``HHMM`` (e.g. ``"0600"``); converted with ``int()``.
    periods : dict, optional
        Mapping ``name -> [start, end]`` of inclusive HHMM bounds. Defaults to the
        module-level ``time_period``.

    Returns
    -------
    str
        The first period name (in dict order) whose range contains ``hrtime``.

    Raises
    ------
    ValueError
        If ``hrtime`` lies outside every period. Raising here, instead of silently
        returning ``None``, avoids an opaque ``KeyError: None`` at the call site.
    """
    if periods is None:
        periods = time_period
    int_hrtime = int(hrtime)
    for name, bounds in periods.items():
        start, end = bounds[0], bounds[1]
        if start <= int_hrtime <= end:
            return name
    raise ValueError(
        "Time stamp {!r} is not covered by any period in {}".format(hrtime, list(periods))
    )

def get_hrtime_by_filename(filename):
    """Extract the ``HHMM`` time-of-day string from a data file name.

    The ``.txt``/``.npy`` extensions are stripped, and the last ``_``-separated token
    is read as ``<date>t<HHMM>``. The part after the ``t`` is returned unchanged,
    as a string. A ``ValueError`` is raised if that token does not contain exactly one ``t``.
    """
    stem = filename.replace(".txt", "").replace(".npy", "")
    stamp = stem.rsplit("_", 1)[-1]
    _date_part, hour_part = stamp.split("t")
    return hour_part

gwmt_list = dict()

# Per-dataset lists, index-aligned: entry i of each list describes the same merge tree.
mt_list = dict()
root_list = dict()

time_list = dict()
period_list = dict()

datasets = ["20180501_juelich"] # , "20180623_juelich", "20190512_juelich"]
for dataset in datasets:
    print("Working on", dataset)
    mt_list[dataset] = []
    root_list[dataset] = []
    
    time_list[dataset] = []
    period_list[dataset] = []
    
    dataset_path = os.path.join("data", dataset)
    for froot, di, files in os.walk(dataset_path):

        def key(s):
            """Sort key: integer suffixes sort by value; anything else sorts after them."""
            try:
                return int(s)
            except ValueError:
                return len(files) + 1

        txt_files = [f for f in files if f.endswith("txt")]
        # Tie-break on the file name. Suffixes such as "20180501t0600" are not integers, so
        # `key` alone maps them all to the same value and keeps os.walk's arbitrary order.
        # With a common prefix, zero-padded time stamps sort chronologically by name.
        txt_files.sort(key=lambda x: (key(x.split(".")[0].split("_")[-1]), x))

        # You need to specify the root node type. Choices: ["minimum", "maximum"]
        # (Avoid specifying merge tree type to avoid confusion between split tree and join tree in different contexts)
        for file in txt_files:
            hrtime = get_hrtime_by_filename(file)
            period = get_time_str(hrtime)
            value_thres = thres_dict_by_time[period]
            # Join with `froot`, the directory os.walk is visiting, so that files in
            # sub-directories resolve correctly.
            trees, roots = rmt.get_trees(os.path.join(froot, file), root_type="minimum", threshold=value_thres)
            mt_list[dataset].extend(trees)
            root_list[dataset].extend(roots)
            # One time stamp / period entry per tree keeps the lists aligned even when a
            # file yields several trees (or none).
            time_list[dataset].extend([hrtime] * len(trees))
            period_list[dataset].extend([period] * len(trees))

    assert (len(root_list[dataset]) == len(mt_list[dataset]))
    assert (len(time_list[dataset]) == len(mt_list[dataset]))
    assert (len(period_list[dataset]) == len(mt_list[dataset]))
    
# ================================================================== #

# Keep the merge trees fairly detailed: their nodes serve as anchor points for matching.

# Cloud systems below this volume are removed from the results entirely.
# Keep it very small.
disappear_volume_threshold = 1

# Cloud systems below this volume get fewer anchor points but are not removed.
# This one may be somewhat larger.
_VOLUME_THRES_BY_TIME = {
    "morning": 5,
    "afternoon": 5,
    "late-afternoon": 5,
}

def get_volume_thres(time_str):
    """Return the simplification volume threshold for a period name.

    Raises KeyError for a period that is not configured.
    """
    return _VOLUME_THRES_BY_TIME[time_str]

# Set to False to skip simplification and use the raw merge trees unchanged.
simplify_merge_trees = True

simplified_mt_list = dict()
simplified_root_list = dict()

# Benchmark statistics, accumulated over every simplified tree of every dataset.
max_raw_nodes = 0
max_simplified_nodes = 0
total_raw_nodes = 0
total_simplified_nodes = 0
n_runs = 0
mt_simplification_runtime = 0

for dataset in datasets:
    print("Working on", dataset)
    n_trees = len(mt_list[dataset])
    simplified_mt_list[dataset] = [None] * n_trees
    simplified_root_list[dataset] = [None] * n_trees

    for i, mt in enumerate(mt_list[dataset]):
        tstr = period_list[dataset][i]
        # period_list[i] was computed as get_time_str(time_list[i]) when loading, so this
        # is the threshold of the tree's own period.
        # NOTE: `value_thres` stays in the global namespace. The GWTracking cell below
        # reads it, and it then holds the threshold of the last tree processed here.
        value_thres = thres_dict_by_time[tstr]

        if simplify_merge_trees:
            # For benchmark purpose
            simplify_t0 = time.perf_counter()
            _, simp_mt = volume_simplify_mt(mt,
                                            vol_thres=get_volume_thres(tstr),
                                            disappear_vol_thres=disappear_volume_threshold,
                                            vol_name="volume",
                                            stop_saddle_val=value_thres)
            mt_simplification_runtime += time.perf_counter() - simplify_t0
            n_runs += 1

            max_raw_nodes = max(max_raw_nodes, mt.number_of_nodes())
            max_simplified_nodes = max(max_simplified_nodes, simp_mt.number_of_nodes())
            total_raw_nodes += mt.number_of_nodes()
            total_simplified_nodes += simp_mt.number_of_nodes()
        else:
            simp_mt = mt

        simplified_mt_list[dataset][i] = simp_mt
        simplified_root_list[dataset][i] = simp_mt.root

print("Benchmark information - Simplify Merge Tree:", max_raw_nodes, max_simplified_nodes, n_runs, mt_simplification_runtime)
print("Benchmark information - Simplify Merge Tree:", total_raw_nodes, total_simplified_nodes)
        
# Drop this notebook's references to the raw merge trees; only the simplified ones are used from here on.
for dataset in datasets:
    mt_list[dataset] = [None] * len(mt_list[dataset])

# Wrap each simplified tree, together with its root, in a GWMergeTree.
gwmt_list = dict()

for dataset in datasets:
    print("Working on", dataset)
    n_trees = len(mt_list[dataset])
    gwmt_list[dataset] = [
        GWMergeTree(simplified_mt_list[dataset][t], simplified_root_list[dataset][t])
        for t in range(n_trees)
    ]
    print(len(gwmt_list[dataset]))

# With the GWMergeTree objects built, drop the notebook's references to the simplified trees too.
for dataset in datasets:
    n_trees = len(simplified_mt_list[dataset])
    simplified_mt_list[dataset] = [None] * n_trees
    simplified_root_list[dataset] = [None] * n_trees

del mt_list

gc.collect()

Working on 20180501_juelich
Adding volume to the merge tree
Adding volume to the merge tree
Working on 20180501_juelich
Initially removing 6350 leaves.
Initially removing 6415 leaves.
Benchmark information - Simplify Merge Tree: 15346 4032 2 1.8365094000000006
Benchmark information - Simplify Merge Tree: 30667 7799
Working on 20180501_juelich
Working on 20180501_juelich
2


23167

## Parameter Initialization

We now specify the parameters to be passed to the pFGW framework, including the following:

*spacial_scale*: a float denoting the spatial extent of the scalar field (e.g., $\max(\text{height}, \text{width})$ for a 2D rectangular scalar field).  
         It is used to normalize the spatial coordinates of the nodes.  
         Its value is flexible, but it affects the optimal $\alpha$ value, so make sure the normalization is reasonable. After normalization, the maximum Euclidean distance between any pair of critical points should preferably lie in $[0.5, 2.0]$.
         

*labels*: a list of strings naming the spatial coordinate labels (e.g., ["x", "y"] for a 2D scalar field).

*scalar_name*: the name of the scalar field in GWMergeTree objects.

*edge_weight_name*: the name of the weight of edges in GWMergeTree objects.

*weight_mode*: the strategy to encode $W$. Choices: ["shortestpath", "lca"].

*prob_distribution*: the strategy to encode $p$. Choices: ["uniform", "ancestor"]

In [8]:
def get_spacial_scale(dataset_name):
    """Return the spatial normalization length configured for ``dataset_name``.

    The name is matched by substring, checking the dataset families in a fixed order.

    Raises
    ------
    ValueError
        If no known dataset family matches. Without this, ``spacial_scale`` would be
        left undefined, or silently stale from an earlier kernel run.
    """
    if "CloudDensity" in dataset_name:
        return 255
    if "CPPin202308" in dataset_name:
        return 932
    if "juelich" in dataset_name:
        return 427
    raise ValueError("No spacial_scale configured for dataset {!r}".format(dataset_name))

# Each dataset gets its own GWTracking object, but all of them receive this one
# `spacial_scale`, so every dataset in `datasets` must map to the same value.
_spacial_scales = {get_spacial_scale(name) for name in datasets}
if len(_spacial_scales) != 1:
    raise ValueError("All datasets must share one spacial_scale, got {}".format(sorted(_spacial_scales)))
spacial_scale = _spacial_scales.pop()

# NOTE(review): presumably an upper bound on the scalar ("height") values — TODO confirm.
max_value = 150
labels = ["x", "y", "z"]
scalar_name = "height"
edge_weight_name = "weight"
weight_mode = "shortestpath"
prob_distribution = "uniform"
fully_initialized = False

# Rescaling factor for the scalar/intrinsic term, chosen per weight mode
# (shortest paths may add up to twice `max_value`).
intrinsic_rescale = 1
if weight_mode == "lca":
    intrinsic_rescale = max_value
elif weight_mode == "shortestpath":
    intrinsic_rescale = max_value * 2

# One scale per coordinate label, followed by the scalar scale (L + 1 entries in total).
all_scales = [spacial_scale] * len(labels) + [intrinsic_rescale]

In [9]:
# Optional but recommended: validate every GWMergeTree against the configured coordinate
# labels, scalar name and edge-weight name to catch input-format mistakes early.
for dataset in datasets:
    print("Working on", dataset)
    for gw_tree in gwmt_list[dataset]:
        gw_tree.label_validation(labels, scalar_name, edge_weight_name)

Working on 20180501_juelich


## Framework Initialization

Initialize the pFGW framework with the given parameters and data input. The normalization step outputs $L+1$ elements. The first $L$ elements are $\frac{1}{rescaling\_factor}$ for the coordinate labels ($L$ coordinate labels in total), and the last element is $\frac{1}{rescaling\_factor}$ for the scalar function values.

In [10]:
pfgws = dict()

for dataset in datasets:
    print("Working on", dataset)
    # NOTE(review): `value_thres` is not defined in this cell. It holds whatever the
    # simplification loop assigned last (the threshold of the final tree's period),
    # so a dataset spanning several periods gets a single threshold — confirm this is intended.
    tracker_kwargs = dict(
        scalar_name=scalar_name,
        edge_weight_name=edge_weight_name,
        weight_mode=weight_mode,
        prob_distribution=prob_distribution,
        tracking_maxima_only=maxima_only,
        maxima_only=maxima_only,
        fully_initialized=fully_initialized,
        scalar_threshold=value_thres,
    )
    pfgws[dataset] = GWTracking(gwmt_list[dataset], spacial_scale, labels, **tracker_kwargs)

    pfgws[dataset].set_all_scales(all_scales)

Working on 20180501_juelich
Please provide the scale for normalization when intialize steps.


## Parameter Tuning

We present a demo for our parameter tuning process in this section.

This is not the only way to tune the parameters. Our primary goal is to optimize the one-to-one matching result.

In [11]:
import pandas as pd

# `amijo` is forwarded to the OT solver and controls how it converges, which can change the
# result (presumably an Armijo line search — confirm in GWTracking). Both settings fit the
# pipeline; they may just reach different local optima.
# NOTE(review): `max_dist_tuning` is not read by any visible cell.
amijo, max_dist_tuning = False, False

### Actual tracking with tuning the m parameter

In [12]:
# dataset name -> list of alpha values whose tuning already completed (re-runs skip them).
finished_runs = {}

In [13]:
# Automatic parameter tuning: for each alpha, binary-search the partial-transport mass m in `m_range`.
# Other cloud tracking tools impose a fixed search range, so the maximum matched distance
# is fixed here rather than tuned.
max_dist = 0.014
alpha_list = [0.2]
m_range = [0.60, 0.90]

max_workers = 7

for dataset in datasets:
    print("Working on", dataset)
    # At most the first 109 timesteps are tuned. The `time_period` comments suggest index 109
    # is where "late-afternoon" starts — TODO confirm.
    timesteps = list(range(len(pfgws[dataset].trees)))[:109]

    dataset_str = dataset
    parameter_tuning_path = os.path.join("binary-parameter-tuning", dataset_str)
    os.makedirs(parameter_tuning_path, exist_ok=True)
    oc_path = os.path.join(parameter_tuning_path, "ocs")
    os.makedirs(oc_path, exist_ok=True)

    tracker = pfgws[dataset]
    for alpha in alpha_list:
        if alpha in finished_runs.get(dataset, ()):
            continue
        print("alpha =", alpha)
        (best_ms, best_ocs, m_searchspace,
         dist_values, init_rt, pfgw_rt) = tracker.adaptive_m_tuning_binary_search(
            timesteps,
            alpha,
            max_dist,
            amijo,
            m_range,
            max_workers,
            metric="l2",
            prob_rubric="nonzero",
            benchmark=True,
        )

        print("Initialization average runtime:", init_rt)
        print("pFGW average runtime:", pfgw_rt)
        save_binary_parameter_tuning(parameter_tuning_path, oc_path, alpha, best_ms, best_ocs, m_searchspace, dist_values)

        finished_runs.setdefault(dataset, []).append(alpha)

Working on 20180501_juelich
alpha = 0.2
CList shape: (2017,)
Finished intializing timestep 0
CList shape: (1884,)
Finished intializing timestep 1
Init Runtime #0 106.5332406
Init Runtime #1 94.2633916
(ms, dists): [(0.805, 0.015710079467211653)]
Cannot find the possible m with the current max_dist!
pFGW runtime #0 380.1335666
The optimal m for timestep 0 vs. 1 is 0.805.
Finished releasing timetep 0
INIT runtime: [106.5332406, 94.2633916]
pFGW runtime: [380.1335666]
Initialization average runtime: 100.3983161
pFGW average runtime: 380.1335666


In [14]:
for dataset in datasets:
    print("Working on", dataset)
    # Must match the timestep selection used in the tuning cell.
    timesteps = list(range(len(pfgws[dataset].trees)))[:109]

    # Step 4. We save the best-tuned results into the results folder
    dataset_str = dataset
    parameter_tuning_path = os.path.join("binary-parameter-tuning", "{}".format(dataset_str))

    oc_path = os.path.join(parameter_tuning_path, "ocs")

    output_root = os.path.join("./initial-output/", dataset_str)
    os.makedirs(output_root, exist_ok=True)
    for alpha in alpha_list:
        # Export only the alphas tuned in this session. `.get` avoids a KeyError for a
        # dataset that has no finished runs.
        if alpha not in finished_runs.get(dataset, []):
            continue

        print("alpha = {}".format(str(alpha)))
        best_ms, best_ocs, _, _ = load_binary_parameter_tuning(parameter_tuning_path, oc_path, alpha)
        # With two timesteps there is a single m. If it comes back as a 0-d array (e.g. from
        # numpy.loadtxt on a single value), `enumerate` cannot iterate it, so promote it to 1-d.
        # NOTE(review): the recorded "TypeError: len() of unsized object" comes from a len()
        # call, presumably inside load_binary_parameter_tuning for the same reason. Confirm,
        # and fix it there too (e.g. numpy.loadtxt(..., ndmin=1)).
        best_ms = np.atleast_1d(best_ms)

        output_path = os.path.join(output_root, str(round(alpha, 1)))
        os.makedirs(output_path, exist_ok=True)

        # best_ms[em] / best_ocs[em] belong to the consecutive pair (timesteps[em], timesteps[em + 1]).
        for em, m in enumerate(best_ms):
            print('%.3f' % m)
            id1 = timesteps[em]
            id2 = timesteps[em + 1]

            oc = best_ocs[em]
            # Maxima filtering is currently disabled: the full coupling matrix is saved as-is.
            oc_maxima = oc # pfgw.filtered_oc(id1, id2, oc)
            np.savetxt(os.path.join(output_path, "oc_{}_{}.txt".format(str(id1), str(id2))), oc_maxima)


Working on 20180501_juelich
alpha = 0.2


TypeError: len() of unsized object