In [1]:
from GWMT import *
import readMergeTree as rmt
import os
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt

# Load data
* The simplified merge tree
* The scalar field data

## Loading the merge trees and the scalar field data

In [2]:
# Names of the sub-directories of ./data whose merge-tree .txt files are loaded below.
# Previously analysed alternative:
# datasets = ["20180501_juelich", "20180623_juelich", "20190512_juelich"]
datasets = [
    "CPPin20230808_0percent",
]

In [3]:
maxima_only = True

# Threshold passed to `rmt.get_trees(..., threshold=...)` below and reused as
# `stop_saddle_val` by the volume simplification in the next cell.
# FIX(review): this assignment used to be commented out while the cell still read
# `value_thres`, so the notebook only ran thanks to a value left over in the kernel.
# 2.0 is the value from the previously commented-out line -- TODO confirm it is the
# intended threshold for the current dataset.
value_thres = 2.0

# thres_dict_by_time = {
#     "morning": 9,
#     "afternoon": 10,
#     "late-afternoon": 9
# }

# time_period = {
#     "morning": [600, 900],  # 0, 36
#     "afternoon": [901, 1500],  # 37, 108
#     "late-afternoon": [1501, 1800],  # 109, 
# }

# def get_time_str(hrtime):
#     int_hrtime = int(hrtime)
#     for key in time_period:
#         if (int_hrtime >= time_period[key][0]) and (int_hrtime <= time_period[key][1]):
#             return key

# def get_hrtime_by_filename(filename):
#     fn1 = filename.replace(".txt", "").replace(".npy", "")
#     datetime = fn1.split("_")[-1]
#     date, hrtime = datetime.split("t")
#     return hrtime

gwmt_list = []

# Accumulated over all datasets and files, index-aligned with each other.
mt_list = []
root_list = []
region_list = []
value_list = []

time_list = dict()
period_list = dict()


def key(s):
    """Sort key for a filename suffix.

    Returns the suffix as an int. A non-numeric suffix returns +inf, so such files
    always sort after every numbered file. (The earlier sentinel `len(files) + 1`
    misordered them whenever a file index exceeded the number of files.)
    """
    try:
        return int(s)
    except ValueError:
        return float("inf")


def isSegmentation(s: str):
    """Return True if the filename contains "segmentation"."""
    return "segmentation" in s


def endsWithTxt(s: str):
    """Return True if the filename ends with "txt"."""
    return s.endswith("txt")


def endsWithNpy(s: str):
    """Return True if the filename ends with "npy"."""
    return s.endswith("npy")


for dataset in datasets:
    dataset_path = os.path.join("data", dataset)
    for froot, _subdirs, files in os.walk(dataset_path):
        # Order the .txt files by the integer after the last "_" in the stem.
        txt_files = sorted(filter(endsWithTxt, files),
                           key=lambda x: key(x.split(".")[0].split("_")[-1]))

        for file in txt_files:
            # os.walk also descends into sub-directories, so the path must be built
            # from `froot` (the directory the file is in), not from `dataset_path`.
            file_path = os.path.join(froot, file)

            # You need to specify the root node type. Choices: ["minimum", "maximum"]
            # (Avoid specifying merge tree type to avoid confusion between split tree and join tree in different contexts)
            trees, roots = rmt.get_trees(file_path, root_type="minimum", threshold=value_thres)
            mt_list.extend(trees)
            root_list.extend(roots)

            regions, values = rmt.get_regions(file_path)
            region_list.extend(regions)
            value_list.extend(values)

    assert (len(root_list) == len(mt_list))
    assert (len(region_list) == len(mt_list))

Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree
Adding volume to the merge tree


## Apply area-based intra-cloud anchor point simplification

In [4]:
# Let's not oversimplify the merge tree, because now we need many nodes as anchor points

# This serves for removing the very-small cloud system from the results entirely
# This should be very small
disappear_volume_threshold = 10

# This is to reduce the number of anchor points (but not remove) for cloud systems
# This can be a bit large
volume_threshold = 30

# One simplified tree (and its root) per input tree, index-aligned with `mt_list`.
simplified_mt_list = [None] * len(mt_list)
simplified_root_list = [None] * len(mt_list)

for idx, mt in enumerate(mt_list):
    if volume_threshold > 0:
        # `stop_saddle_val` reuses the loading threshold from the previous cell
        # (presumably the saddle value at which simplification stops -- see GWMT).
        _, simp_mt = volume_simplify_mt(mt,
                                        vol_thres=volume_threshold,
                                        disappear_vol_thres=disappear_volume_threshold,
                                        vol_name="volume",
                                        stop_saddle_val=value_thres)
    else:
        # Simplification disabled: keep the original tree.
        simp_mt = mt
    simplified_mt_list[idx] = simp_mt
    simplified_root_list[idx] = simp_mt.root
    # Node count before and after simplification.
    print(mt.number_of_nodes(), simp_mt.number_of_nodes())

Initially removing 29699 leaves.
59184 3656
Initially removing 29371 leaves.
58576 3629
Initially removing 28604 leaves.
57042 3598
Initially removing 28060 leaves.
55993 3640
Initially removing 27419 leaves.
54724 3620
Initially removing 26814 leaves.
53514 3588
Initially removing 26235 leaves.
52405 3638
Initially removing 25765 leaves.
51469 3584
Initially removing 25368 leaves.
50692 3614
Initially removing 24939 leaves.
49826 3590
Initially removing 24502 leaves.
48981 3568
Initially removing 24021 leaves.
48025 3520
Initially removing 23759 leaves.
47499 3517
Initially removing 23533 leaves.
47029 3398
Initially removing 23239 leaves.
46466 3372
Initially removing 22844 leaves.
45654 3299
Initially removing 22508 leaves.
45035 3249
Initially removing 22244 leaves.
44500 3240
Initially removing 22177 leaves.
44405 3252
Initially removing 22328 leaves.
44649 3234
Initially removing 22101 leaves.
44170 3161
Initially removing 21994 leaves.
43998 3064
Initially removing 22117 leaves.

## Save the remaining critical points and tree edges to text files

- Fields written for each node: "x", "y", "z", "height", "type" (the critical-point type)

In [5]:
cp_info_root = "./simplified-merge-trees/"

# FIXME(review): every dataset directory receives *all* trees in
# `simplified_mt_list`. With more than one entry in `datasets`, trees from different
# datasets get mixed; splitting them requires per-dataset tree counts recorded when
# the trees are loaded.
for dataset in datasets:
    cp_info_dir = os.path.join(cp_info_root, dataset)
    os.makedirs(cp_info_dir, exist_ok=True)

    for em, mt in enumerate(simplified_mt_list):
        # Edges are written as pairs of node ids, so line k of the node file should
        # describe node k (presumably how consumers resolve edge endpoints -- confirm).
        # Graph node iteration order is not guaranteed to be ascending after
        # simplification, so nodes are written in sorted id order.
        node_ids = sorted(mt.nodes())

        # Node file: one line per node, "x y z height type".
        crit_pts = []
        for node in node_ids:
            attrs = mt.nodes[node]
            crit_pts.append([attrs["x"], attrs["y"], attrs["z"], attrs["height"], attrs["type"]])
        cp_file = os.path.join(cp_info_dir, "treeNodes_{}.txt".format(str(em).zfill(3)))
        with open(cp_file, "w") as outf:
            for x, y, z, height, tp in crit_pts:
                print(int(x), int(y), int(z), height, int(tp), file=outf)

        # Edge file: each edge once as [smaller id, larger id]
        # (assumes `mt` is undirected, so both endpoints list each other as neighbors).
        edges = []
        for node in node_ids:
            for neighbor in mt.neighbors(node):
                # Node ids double as line indices into the node file.
                assert node < mt.number_of_nodes()
                if neighbor > node:
                    edges.append([node, neighbor])
        edge_file = os.path.join(cp_info_dir, "treeEdges_{}.txt".format(str(em).zfill(3)))
        edges = np.asarray(edges, dtype=int)

        # A tree with n nodes has exactly n - 1 edges.
        assert len(edges) == len(crit_pts) - 1
        np.savetxt(edge_file, edges, fmt="%d")


## Collect the anchor points of each time step into a list for later use

In [6]:
# anchor_points_list = []

# for mt in simplified_mt_list:
#     anchor_pts = []
#     for node in mt.nodes():
#         if mt.nodes[node]["type"] == 2:
#             anchor_pts.append({"id":node, "x":mt.nodes[node]["x"], "y":mt.nodes[node]["y"]})
#     anchor_points_list.append(anchor_pts)

# Watershed Segmentation

In [7]:
# from skimage.segmentation import watershed

First, we visualize all the cloud areas above the superlevel set threshold

In [8]:
# # value_thres = 2.0
# time_i = 10

In [9]:
# scalar_field = value_list[time_i]
# anchor_points = anchor_points_list[time_i]

In [10]:
# binary_map = np.zeros(scalar_field.shape)
# binary_map[scalar_field >= value_thres] = 1

In [11]:
# plt.imshow(binary_map, interpolation="none")
# plt.colorbar()

In [12]:
# ## Watershed tests

# # markers are the coordinates from anchor points, but must be INT
# markers = np.asarray([[int(each["x"]), int(each["y"])] for each in anchor_points], dtype=int)
# print(markers.shape, np.max(markers[:, 0]), np.max(markers[:, 1]))
# print(scalar_field.shape)
                    
# plt.imshow(binary_map, interpolation="none")
# plt.scatter(x=markers[:, 1], 
#             y=markers[:, 0],
#             s=2)
# plt.colorbar()

In [13]:
# from scipy.ndimage import label
# markers_in_field = np.zeros(scalar_field.shape, dtype=bool)
# markers_in_field[tuple(markers.T)] = True
# markers_with_label, _ = label(markers_in_field, structure=np.asarray([[0, 0, 0],[0, 1, 0],[0, 0, 0]]))

# labels = watershed(-scalar_field, markers=markers_with_label, connectivity=2, mask=binary_map)

# plt.figure(figsize=(16, 10))
# plt.imshow(labels, interpolation="none")
# plt.colorbar()
# plt.scatter(x=markers[:, 1], 
#             y=markers[:, 0],
#             s=2,
#             c="r")

In [14]:
# plt.figure(figsize=(16, 10))
# plt.imshow(markers_with_label, interpolation="none")
# plt.colorbar()
# # plt.scatter(x=markers[:, 1], 
# #             y=markers[:, 0],
# #             s=2,
# #             c="r")

We briefly report some statistics for this segmentation.

In [15]:
# label_index, label_counts = np.unique(labels, return_counts=True)

# print(len(markers))
# print(np.sum(markers_in_field))
# print(np.max(markers_with_label))

In [16]:
# mlabel_index, mlabel_counts = np.unique(markers_with_label, return_counts=True)
# print(len(mlabel_counts))

### Histogram of the segmentation areas

Most cloud segments are small in area, but some regions exceed 6000 pixels.
This raises doubt about whether using the area as the probability is stable, because node probabilities can then span a very large range (e.g., a ratio of 6000:10).

In [17]:
# plt.hist(label_counts[1:], bins=200)
# plt.show()

### Unit-based probability distribution

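We distribute a total probability of 1 over all anchor points according to the areas of their segmentation regions.
The probability assigned to a node is therefore

$$p(\text{node}) = A_{\text{seg}} \cdot \frac{1}{A_{\text{total}}},$$

where $A_{\text{seg}}$ is the area of the node's segmentation region and $A_{\text{total}}$ is the total cloud area. In this scheme, $1/A_{\text{total}}$ is the unit of probability.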

Notes for total cloud area: 
* time_0: 281231
* time_1: 250313
* time_2: 245292

In [18]:
# print("Total cloud area:", np.sum(label_counts[1:]))

## Pairwise region distance

In [19]:
# time_j = time_i + 1
# print(time_i, time_j)

In [20]:
# def get_segmentation(t):
#     scalar_field = value_list[t]
#     anchor_points = anchor_points_list[t]
    
#     binary_map = np.zeros(scalar_field.shape)
#     binary_map[scalar_field >= value_thres] = 1
    
#     markers = np.asarray([[int(each["x"]), int(each["y"])] for each in anchor_points], dtype=int)
#     markers_in_field = np.zeros(scalar_field.shape, dtype=bool)
#     markers_in_field[tuple(markers.T)] = True
#     markers_with_label, _ = label(markers_in_field, structure=np.asarray([[0, 0, 0],[0, 1, 0],[0, 0, 0]]))

#     labels = watershed(-scalar_field, markers=markers_with_label, connectivity=2, mask=binary_map)
#     return labels

In [21]:
# seg_i = get_segmentation(time_i)
# seg_j = get_segmentation(time_j)

In [22]:
# dist_ij = segmentation_distance(seg_i, seg_j, max_dist=18, normalize_factor=932, max_workers=7)

In [23]:
# plt.imshow(dist_ij)

In [24]:
# ## Let's try some example filtering
# feature_i = np.random.randint(len(label_counts[1:])) + 1 # np.argmax(label_counts[1:]) + 1
# feature_i_map = np.zeros(seg_i.shape)
# feature_i_map[seg_i == feature_i] = 1

# feature_j_map = np.zeros(seg_j.shape)

# idx_i = feature_i - 1
# for idx_j in range(dist_ij.shape[1]):
#     feature_j = idx_j + 1
#     if dist_ij[idx_i, idx_j] < 932:
#         feature_j_map[seg_j == feature_j] = 10 + dist_ij[idx_i, idx_j]

# plt.figure(figsize=(18, 12))
# plt.subplot(211)
# plt.imshow(feature_i_map)

# plt.subplot(212)
# plt.imshow(feature_j_map)