# Lower star filtration optimization

**Goal:** reduce ("degrade") the persistence of an $H_0$ feature of the lower-star filtration of an image by changing its pixel values.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import H1_optimizer
from numpy.linalg import lstsq
import importlib
from tqdm import tqdm
import cv2
import networkx as nx
import matplotlib.pyplot as plt
importlib.reload(H1_optimizer)
import gudhi as gd

In [None]:
def get_image_graph(width, height, queen=True):
    '''
    Build the pixel adjacency graph of a width x height grid.

    Nodes are (i, j) tuples with i in range(width) and j in range(height).
    The base graph is the 4-neighbour grid from nx.grid_2d_graph; with
    queen=True every pixel is additionally joined to its diagonal
    neighbours, giving 8-neighbour adjacency.
    '''
    G = nx.grid_2d_graph(width, height)
    if not queen:
        return G

    last_i = width - 1
    for i in range(width):
        # Both diagonals point towards column j + 1, so the last column
        # never starts a diagonal edge.
        for j in range(height - 1):
            if i < last_i:
                G.add_edge((i, j), (i + 1, j + 1))
            if i > 0:
                G.add_edge((i, j), (i - 1, j + 1))
    return G

In [None]:
def find_generators(image, PD, idx=None, image_graph=None, queen=True):
    '''
    Find the generators for a given persistence point

    Parameters
    ----------
    image : 2D array of pixel values (the filtration function).
    PD : sequence of (birth, death) pairs; death may be infinite.
    idx : index into PD of the point of interest, or None to process
        every point of PD.
    image_graph : pixel adjacency graph; built with get_image_graph when None.
    queen : forwarded to get_image_graph when image_graph is None.

    Returns
    -------
    (birth_pixels, merge_pixels, death_pixels), three lists with one entry
    per processed point:
      * birth pixel: the pixel whose value is closest to the birth value;
      * merge pixel: among all birth pixels with value <= this birth that
        lie in the same connected component of the graph thresholded at
        death + 1e-5, the one with the smallest value (None if there is
        no such pixel);
      * death pixel: the pixel whose value is closest to the death value,
        or None for an infinite bar.

    Pixels are located by value only (np.argmin of the absolute
    difference), so if several pixels share a value the first one in
    row-major order is returned.
    '''
    if image_graph is None:
        image_graph = get_image_graph(*image.shape, queen=queen)

    def closest_pixel(value):
        return np.unravel_index(np.argmin(np.abs(image - value)), image.shape)

    # One entry per diagram point, so both lists can be indexed with the
    # same index as PD regardless of where infinite bars sit.
    birth_pixels = []
    death_pixels = []
    for (b, d) in PD:
        birth_pixels.append(closest_pixel(b))
        death_pixels.append(closest_pixel(d) if np.isfinite(d) else None)

    idxs = list(range(len(PD))) if idx is None else [idx]

    merge_pixels = []
    for i in idxs:
        b1, d1 = PD[i]
        alive_nodes = [n for n in image_graph.nodes if image[n] <= d1 + 1e-5]
        thresholded_graph = nx.subgraph(image_graph, alive_nodes)
        # A single component lookup replaces one path search per candidate.
        component = nx.node_connected_component(thresholded_graph, birth_pixels[i])
        mergeable_pixels = [n for n in birth_pixels if image[n] <= b1 and n in component]
        if mergeable_pixels:
            merge_pixels.append(min(mergeable_pixels, key=lambda n: image[n]))
        else:
            merge_pixels.append(None)

    return [birth_pixels[i] for i in idxs], merge_pixels, [death_pixels[i] for i in idxs]

In [None]:
def get_image_persistence(image, dim=0, queen=True):
    '''
    Find the persistence intervals for an image

    Parameters
    ----------
    image : 2D array of pixel values.
    dim : homological dimension of the intervals to return.
    queen : if True, use a gudhi CubicalComplex with the pixels as
        top-dimensional cells; otherwise use the lower-star filtration of
        the 4-neighbour grid graph (vertices and edges only).

    Returns
    -------
    List of (birth, death) intervals sorted by decreasing persistence
    (death - birth), so infinite bars come first.
    '''
    if queen:
        scx = gd.CubicalComplex(top_dimensional_cells=image)
    else:
        # grid_2d_graph(m, n) labels nodes (i, j) with i in range(m) and
        # j in range(n), i.e. directly usable as indices into `image`.
        # NOTE(review): nx.grid_graph(image.shape) appears to order node
        # coordinates differently from the array axes, which breaks
        # non-square images - confirm against the installed networkx.
        graph = nx.grid_2d_graph(*image.shape)
        node_index = {node: k for k, node in enumerate(graph.nodes)}
        scx = gd.SimplexTree()
        for node, k in node_index.items():
            scx.insert([k], filtration=float(image[node]))
        # Lower-star rule: an edge appears with the larger of its endpoints.
        for (u, v) in graph.edges:
            scx.insert(
                [node_index[u], node_index[v]],
                filtration=float(max(image[u], image[v])),
            )

    scx.persistence()
    intervals = scx.persistence_intervals_in_dimension(dim)
    return sorted(intervals, key=lambda x: -x[1] + x[0])

In [None]:
def find_birth_cochain(b, d, birth_pixel, image, epsilon=0.1, boundary=None, image_graph=None, queen=True):
    '''
    Find the birth cochain for a given persistence point

    The cochain is the normalised indicator (entries sum to 1) of the
    connected component containing birth_pixel in the pixel graph
    restricted to pixels with value <= b + epsilon. It is returned as an
    array with the same shape as image.

    d and boundary are not used here.
    '''
    graph = image_graph
    if graph is None:
        graph = get_image_graph(*image.shape, queen=queen)

    threshold = b + epsilon
    kept_nodes = [node for node in graph.nodes if image[node] <= threshold]
    component = nx.node_connected_component(nx.subgraph(graph, kept_nodes), birth_pixel)

    # Mark every pixel of the component, then normalise to unit mass
    rows, cols = zip(*component)
    cochain = np.zeros(image.shape)
    cochain[list(rows), list(cols)] = 1
    return cochain / np.sum(cochain)

In [None]:
def find_death_cochain(b, d, birth_pixel, image, epsilon=0.1, boundary=None, image_graph=None, queen=True):
    '''
    Find the death cochain for a given persistence point

    Outline:
      1. Keep the pixels with value <= d + epsilon ("after death").
         Those with value < d - epsilon ("before death") are held fixed.
      2. f is the indicator, on the kept pixels, of the connected
         component of birth_pixel in the binary image `image < d`.
      3. Solve a least-squares problem for a correction x supported on
         the non-fixed pixels so that B (x + f) is as small as possible,
         where B is the oriented incidence matrix restricted to the kept
         pixels and the edges between them. y = |B (x + f)| is then
         normalised to unit L1 norm (left at zero if it vanishes).
      4. Each edge passes its weight y to the endpoint with the larger
         pixel value; the per-pixel sums form gradient_image.

    Parameters
    ----------
    b : birth value (unused here).
    d : death value.
    birth_pixel : (row, col) index of the birth pixel.
    image : 2D array of pixel values.
    epsilon : half-width of the band around d in which pixels stay free.
    boundary : (edges x nodes) oriented incidence matrix of image_graph;
        computed when None and returned so callers can reuse it.
    image_graph : pixel adjacency graph; built with get_image_graph when None.
    queen : selects 8- (True) or 4-connectivity (False), both for the
        graph built here and for the component labelling.

    Returns
    -------
    (y, gradient_image, boundary) where y has one entry per kept edge and
    gradient_image has the shape and dtype of image.
    '''
    if image_graph is None:
        image_graph = get_image_graph(*image.shape, queen=queen)
    if boundary is None:
        boundary = np.array(nx.incidence_matrix(image_graph, oriented=True).todense())
        boundary = boundary.T

    # Take the node order from the graph so that it matches the columns of
    # `boundary` (for get_image_graph this is row-major pixel order).
    graph_nodes = list(image_graph.nodes)
    after_death_nodes = [n for n in graph_nodes if image[n] <= d + epsilon]
    after_death_set = set(after_death_nodes)
    before_death_set = {n for n in after_death_nodes if image[n] < d - epsilon}

    # Component of the birth pixel just before the death value
    binarized_image = np.array(image < d, dtype=np.uint8)
    _, labels = cv2.connectedComponents(binarized_image, connectivity=8 if queen else 4)
    pixel_label = labels[birth_pixel]
    f = np.array([int(labels[n] == pixel_label) for n in after_death_nodes])

    # Restrict the incidence matrix to kept nodes and the edges between them
    edge_mask = np.array(
        [u in after_death_set and v in after_death_set for (u, v) in image_graph.edges],
        dtype=bool,
    )
    node_mask = np.array([n in after_death_set for n in graph_nodes], dtype=bool)
    boundary_operator = boundary[edge_mask][:, node_mask]

    # Only pixels within epsilon of the death value are free variables
    active_cols = [i for i, n in enumerate(after_death_nodes) if n not in before_death_set]
    restricted_boundary_operator = boundary_operator[:, active_cols]

    x = lstsq(restricted_boundary_operator, -boundary_operator@f)[0]
    extended_x = np.zeros(boundary_operator.shape[1])
    extended_x[active_cols] = x
    y = boundary_operator@(extended_x+f)
    y_norm = np.linalg.norm(y, ord=1)
    if y_norm == 0:
        # Nothing to distribute; avoid 0/0 turning the gradient into NaNs
        y = np.zeros_like(y)
    else:
        y = np.abs(y)/y_norm

    # For every edge pick the endpoint with the larger pixel value
    abs_incidence = np.abs(boundary_operator).T
    node_values = np.array([image[n] for n in after_death_nodes])
    abs_incidence_with_values = node_values[:, None] * abs_incidence
    abs_incidence_with_values[abs_incidence == 0] = -np.inf

    max_rows = np.argmax(abs_incidence_with_values, axis=0)
    only_maxes = np.zeros_like(abs_incidence_with_values)
    only_maxes[max_rows, np.arange(abs_incidence_with_values.shape[1])] = 1
    A = only_maxes @ y

    gradient_image = np.zeros_like(image)
    node_array = np.array(after_death_nodes)
    gradient_image[node_array[:, 0], node_array[:, 1]] = A

    return y, gradient_image, boundary


## Experiment with corrupted MNIST images

In [None]:
def _save_inverted_image(X, path):
    '''Save 1 - X as a grayscale image without axes and close the figure.'''
    plt.imshow(1-X, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()


def split_and_fuse(image, name, method='cochains', seed=2025):
    '''
    Corrupt an image with a horizontal band and optimise the band away.

    The image is inverted (X = 1 - image), rows 12-14 are set to 0.8,
    uniform noise in [0, 0.01) is added and the result is clipped at 0.
    Then, for at most 1000 steps, pixel values are lowered to shrink the
    persistence of the second most persistent H_0 class until it is at
    most 0.01:
      * method='simplices': lower the death pixel by gamma;
      * method='cochains': subtract gamma times the gradient image
        returned by find_death_cochain.

    Figures are written to figs/MNIST_{name}_raw.png,
    figs/MNIST_{name}_initial.png and figs/MNIST_{method}_{name}_final.png.

    Parameters
    ----------
    image : 2D array with values in [0, 1] (at least 15 rows).
    name : tag used in the output file names.
    method : 'simplices' or 'cochains'.
    seed : seed for np.random, or None to leave the global state untouched.

    Returns
    -------
    List with the persistence of the targeted class at every update step.

    Raises
    ------
    ValueError if method is not one of the two supported values.
    '''
    import os

    if method not in ('simplices', 'cochains'):
        raise ValueError(f"method must be 'simplices' or 'cochains', got {method!r}")
    if seed is not None:
        np.random.seed(seed)
    os.makedirs('figs', exist_ok=True)

    X = 1-image.copy()
    _save_inverted_image(X, f'figs/MNIST_{name}_raw.png')

    queen = True
    X[12:15, :] = 0.8
    X += np.random.random(X.shape)*0.01
    X = np.clip(X, 0, None)
    _save_inverted_image(X, f'figs/MNIST_{name}_initial.png')

    image_graph = get_image_graph(*image.shape, queen=queen)
    persistences = []

    boundary = None
    epsilons = [0.1]
    gamma = 1e-1
    # After sorting by decreasing persistence, index 0 is the most
    # persistent class and index 1 is the one being degraded.
    second_idx = 1

    with tqdm(range(1000)) as progress:
        for step in progress:
            PD = get_image_persistence(X, dim=0, queen=queen)
            PD = sorted(PD, key=lambda i: -i[1] + i[0])
            if len(PD) <= second_idx or PD[second_idx][1] - PD[second_idx][0] <= 0.01:
                # X would not change any more, so every later step would
                # recompute exactly the same diagram.
                break

            birth, death = PD[second_idx]
            birth_pixels, merge_pixels, death_pixels = find_generators(
                X, PD, second_idx, image_graph=image_graph, queen=queen
            )

            Xnew = X.copy()
            if method == 'simplices':
                Xnew[death_pixels[0]] -= gamma
            else:
                for epsilon in epsilons:
                    y, A, boundary = find_death_cochain(
                        birth, death, birth_pixels[0], X,
                        epsilon=epsilon, boundary=boundary, image_graph=image_graph, queen=queen
                    )
                    Xnew -= gamma*A/len(epsilons)

            X = np.clip(Xnew, 0, None)
            persistences.append(death - birth)

    _save_inverted_image(X, f'figs/MNIST_{method}_{name}_final.png')
    return persistences

In [None]:
from mnist_loader import load_data

train, validation, test = load_data()
# Hand-picked training indices, ordered by train[1][index]
# (presumably the digit label - verify against mnist_loader).
L = sorted([0, 1, 2, 3, 4, 5, 7, 13, 15, 17], key=lambda index: train[1][index])

In [None]:
# Five noisy trials per selected sample and per method. The generator is
# seeded once per sample and every call uses seed=None, so the noise
# differs from trial to trial.
# NOTE(review): both methods write the same *_raw.png / *_initial.png file
# names but see different noise, so the saved initial image is the one from
# the cochains run - confirm this is intended.
for i, sample in enumerate(train[0][L]):
    np.random.seed(2025)
    digit = sample.reshape(28, 28)
    for method in ('simplices', 'cochains'):
        for trial in range(5):
            split_and_fuse(digit, f'sample{i+1}_trial{trial+1}', method=method, seed=None)

In [None]:
# One run per selected sample and method, each seeded with 2025.
for method in ('simplices', 'cochains'):
    for number, sample in enumerate(train[0][L], start=1):
        split_and_fuse(sample.reshape(28, 28), f'sample{number}', method=method, seed=2025)
