# Point cloud optimization

**Goal:** "promote" (i.e. increase) the persistence of an $H_1$ feature of a point cloud by moving its points.


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib import cm
plt.rcParams['text.usetex'] = True
from ripser import ripser
from H1_optimizer import move_points_enhance_H1

## Visualize the birth and death cochains for a small example

In [None]:
def create_small_dataset(npoints=10, noise_level=0.1, min_persistence=1e-3):
    """Sample a small noisy circle whose H1 diagram has a non-trivial feature.

    Points are drawn at uniformly random angles on the unit circle and
    perturbed by Gaussian noise. The sample is redrawn until the most
    persistent H1 feature computed by ``ripser`` reaches ``min_persistence``.
    The function draws from NumPy's global random state, so calling
    ``np.random.seed`` beforehand makes the result reproducible.

    Parameters
    ----------
    npoints : int
        Number of points in the cloud.
    noise_level : float
        Standard deviation of the Gaussian noise added to each coordinate.
    min_persistence : float
        Resample until the largest H1 persistence (death - birth) is at least
        this value. Defaults to the previously hard-coded threshold ``1e-3``.

    Returns
    -------
    numpy.ndarray of shape (npoints, 2)
        The point cloud, rescaled so the point farthest from the origin has
        Euclidean norm 1.
    """
    persistence = 0
    while persistence < min_persistence:
        angles = np.random.random(size=npoints) * 2 * np.pi
        X = np.array([np.cos(angles), np.sin(angles)]).T
        X += np.random.randn(*X.shape) * noise_level
        PD = ripser(X, maxdim=1)['dgms'][1]
        # An empty H1 diagram means no loop was detected, so resample.
        persistence = np.max(PD[:, 1] - PD[:, 0]) if len(PD) > 0 else 0
    return X / np.max(np.linalg.norm(X, axis=1))

In [None]:
# Output directory for every figure written by this notebook.
os.makedirs('figs', exist_ok=True)

# Reproducible 10-point noisy circle as the starting configuration.
np.random.seed(2025)
X_old = create_small_dataset(npoints=10, noise_level=0.1)

# Options for the optimizer run whose per-iteration log is plotted below.
optimizer_options = dict(
    log=True,
    max_iter=1000,
    epsilon=0.05,
    gamma=0.02,
    verbose=False,
    save_plots=True,
    penalty=True,
    return_losses=True,
)
X, log = move_points_enhance_H1(X_old, **optimizer_options)

In [None]:
# Objective (persistence content minus penalty), one value per logged iteration.
# It is computed once and reused for both the curve and the markers.
objective = np.asarray(log['persistence_content']) - np.asarray(log['penalty'])

plt.subplots(figsize=(15, 5))
plt.plot(objective, c='tab:blue', label='Objective (persistence content - penalty)', linewidth=1)
# Highlight a few iterations. Indices the run never reached (e.g. if the
# optimizer stopped early) are dropped instead of raising IndexError.
points = [p for p in [0, 100, 200, 300, 400] if p < len(objective)]
plt.scatter(points, objective[points], marker='o', s=50)
plt.legend(fontsize=16, loc='lower right')
plt.yticks(fontsize=16)
plt.ylim(0.5, 2.2)
plt.xlim(-50, 500)
plt.xticks(points, fontsize=16)
plt.savefig('figs/objective.png', dpi=300, bbox_inches='tight')

## Test different learning rates ($\gamma$) and epsilons ($\epsilon$) on a single small point cloud

In [None]:
# Sweep learning rates (gamma) and epsilons on one fixed 10-point cloud.
# Compare the cochain-based optimizer with the 'single_edges' variant and
# write all figures to figs/.
np.random.seed(2025)
X_old = create_small_dataset(npoints=10, noise_level=0.1)
dgm1 = ripser(X_old, maxdim=1)['dgms'][1]  # H1 diagram: rows of (birth, death)
initial_persistence = np.max(dgm1[:, 1] - dgm1[:, 0])

gammas = [0.01, 0.02, 0.03, 0.04]
epsilons = [0.03, 0.04, 0.05, 0.06]
n_iter = 1000  # renamed from `iter`, which shadowed the builtin
relative_epsilon = True
# True forces the (slow) optimizations to re-run. False reuses the logs from a
# previous run in this session. On a fresh kernel the logs are computed anyway,
# instead of the plotting code below failing with NameError.
compute = False


def _plot_point_movement(X_start, X_end, color, gamma, path):
    """Scatter the start cloud (black) and end cloud (`color`) on a fresh figure, save to `path`, and close it."""
    plt.figure()
    plt.scatter(X_start[:, 0], X_start[:, 1], c='black', zorder=100, label='Start')
    plt.scatter(X_end[:, 0], X_end[:, 1], color=color, label='End')
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlim(-2.0, 2.0)
    plt.ylim(-2.0, 2.0)
    plt.xticks([-2, -1, 0, 1, 2], fontsize=30)
    plt.yticks([-2, -1, 0, 1, 2], fontsize=30)
    plt.title(r'$\gamma = {:.2f}$'.format(gamma), fontsize=30)
    plt.savefig(path, bbox_inches='tight', dpi=150)
    plt.close()


def _plot_normalized_persistence(logs, gammas, path):
    """Plot logged (death - birth) divided by the cloud's norm per iteration, one curve per gamma; save to `path` and show."""
    plt.subplots(figsize=(8, 4))
    for j, (gamma, run_log) in enumerate(zip(gammas, logs)):
        bd = np.asarray(run_log['bd'])
        # axis=(1, 2) gives one matrix (Frobenius) norm per logged cloud;
        # assumes run_log['X'] stacks clouds as (iterations, npoints, 2).
        cloud_norms = np.linalg.norm(run_log['X'], axis=(1, 2))
        plt.plot((bd[:, 1] - bd[:, 0]) / cloud_norms,
                 label=r'$\gamma = {:.2f}$'.format(gamma),
                 color=cm.tab20(2 * j), zorder=100 - 10 * j)
    plt.ylim(0, 0.5)
    plt.xticks([0, 200, 400, 600, 800, 1000], fontsize=20)
    plt.yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5], fontsize=20)
    plt.ylabel(r'Persistence/$||X||_2$', fontsize=20)
    plt.legend(loc='lower right', fontsize=20)
    plt.savefig(path, bbox_inches='tight', dpi=150)
    plt.show()


if compute or 'log1_storage' not in globals() or 'log2_storage' not in globals():
    log1_storage = defaultdict(list)  # eps -> logs of the cochain method, one per gamma
    log2_storage = []                 # logs of the 'single_edges' method, one per gamma
    for eps in epsilons:
        # When relative_epsilon is False, rescale by the initial persistence
        # here. Presumably the optimizer does this itself when it is True.
        epsilon = eps if relative_epsilon else eps * initial_persistence
        for gamma in gammas:
            print('epsilon = {:.3f}, gamma = {:.3f}'.format(epsilon, gamma))
            X, log1 = move_points_enhance_H1(X_old, log=True, max_iter=n_iter, epsilon=epsilon, gamma=gamma, penalty=True, tol=1e-10, relative_epsilon=relative_epsilon)
            log1_storage[eps].append(log1)
    for gamma in gammas:
        X, log2 = move_points_enhance_H1(X_old, method='single_edges', log=True, max_iter=n_iter, gamma=gamma, penalty=True)
        log2_storage.append(log2)

os.makedirs('figs', exist_ok=True)

# Start vs. end clouds for the cochain method, one file per (gamma, epsilon).
for eps in epsilons:
    for j, gamma in enumerate(gammas):
        _plot_point_movement(
            X_old, log1_storage[eps][j]['X'][-1], cm.tab20(2 * j), gamma,
            'figs/point_movement_harmonic_gamma{:.3f}_epsilon{:.3f}.png'.format(gamma, eps),
        )

# Start vs. end clouds for the single-edge method, one file per gamma.
for j, gamma in enumerate(gammas):
    _plot_point_movement(
        X_old, log2_storage[j]['X'][-1], cm.tab20(2 * j), gamma,
        'figs/point_movement_simplices_gamma{:.3f}.png'.format(gamma),
    )

# Normalized persistence over iterations, one figure per epsilon (cochain method).
for eps in epsilons:
    _plot_normalized_persistence(
        log1_storage[eps], gammas,
        'figs/death_minus_birth_harmonic_epsilon{:.3f}.png'.format(eps),
    )

# Normalized persistence over iterations for the single-edge method.
_plot_normalized_persistence(log2_storage, gammas, 'figs/death_minus_birth_simplices.png')

## Multi-cochains: averaging over multiple epsilons

In [None]:
# Compare the multi-cochain method (a list of epsilons) against the
# single-edge ('single_edges') method over many random clouds of varying size.
repeats = 10                      # random clouds per point count
gamma = 0.02                      # learning rate used by both methods
n_iter = 1000                     # renamed from `iter`, which shadowed the builtin
epsilon = [0.01, 0.05, 0.1]       # several epsilons -> multi-cochain variant
npoints_list = np.arange(10, 21)  # cloud sizes 10, 11, ..., 20
relative_epsilon = True

np.random.seed(2025)
npoints_data = []        # cloud size of each trial (colors the scatter plot)
bds_cochains_multi = []  # last logged (death - birth), multi-cochain method
bds_simplices = []       # last logged (death - birth), single-edge method
Xs_cochains_multi = []   # final clouds, multi-cochain method
Xs_simplices = []        # final clouds, single-edge method
Xs_initial = []          # starting clouds
for npoints in npoints_list:
    for rep in range(repeats):
        npoints_data.append(npoints)
        X_old = create_small_dataset(npoints=npoints, noise_level=0.1)
        Xs_initial.append(X_old)
        print('Trial {} with {} points'.format(rep, npoints))
        X, log1 = move_points_enhance_H1(X_old, log=True, max_iter=n_iter, epsilon=epsilon, gamma=gamma, penalty=True, tol=0, relative_epsilon=relative_epsilon)
        bds_cochains_multi.append(log1['bd'][-1][1] - log1['bd'][-1][0])
        Xs_cochains_multi.append(X)
        # NOTE(review): only the cochain run passes tol=0; this run uses the
        # optimizer's default tol. Confirm the asymmetry is intended.
        X, log2 = move_points_enhance_H1(X_old, method='single_edges', log=True, penalty=True, max_iter=n_iter, gamma=gamma)
        bds_simplices.append(log2['bd'][-1][1] - log2['bd'][-1][0])
        Xs_simplices.append(X)


In [None]:
# Final persistence divided by the (Frobenius) norm of the final cloud, one
# point per trial: x = single-edge method, y = multi-cochain method.
norm_pers_simplices = np.asarray(bds_simplices) / np.array([np.linalg.norm(x) for x in Xs_simplices])
norm_pers_cochains = np.asarray(bds_cochains_multi) / np.array([np.linalg.norm(x) for x in Xs_cochains_multi])

plt.scatter(norm_pers_simplices, norm_pers_cochains, alpha=0.7, c=npoints_data, cmap='viridis')
# Diagonal y = x: points above it are trials where multi-cochains did better.
plt.plot([0, 0.5], [0, 0.5], c='gray', linestyle='--')
plt.gca().set_aspect('equal', adjustable='box')
plt.xlabel(r'Simplices', fontsize=20)
plt.xticks([0, 0.1, 0.2, 0.3, 0.4, 0.5], fontsize=20)
plt.yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5], fontsize=20)
plt.ylabel(r'Multi-cochains', fontsize=20)
plt.title(r'Persistence/$||X||_2$', fontsize=20)
cbar = plt.colorbar()
# '\#' escapes '#' for LaTeX (text.usetex is enabled). The raw string keeps the
# same value without Python's invalid-escape-sequence warning.
cbar.ax.set_ylabel(r'\# points', fontsize=20)
cbar.ax.tick_params(labelsize=20)
plt.savefig('figs/final_persistence_scatter_combo.png', bbox_inches='tight', dpi=150)
plt.show()

In [None]:
# Trial where the multi-cochain method beats the single-edge method by the
# widest margin in norm-normalized final persistence.
gain = (
    bds_cochains_multi / np.array([np.linalg.norm(x) for x in Xs_cochains_multi])
    - bds_simplices / np.array([np.linalg.norm(x) for x in Xs_simplices])
)
m = np.argmax(gain)

# Overlay both optimized clouds with the shared starting cloud drawn on top.
for cloud, scatter_kwargs in (
    (Xs_simplices[m], dict(label='Simplices')),
    (Xs_cochains_multi[m], dict(label='Multi-cochains')),
    (Xs_initial[m], dict(c='black', label='Initial data', zorder=100)),
):
    plt.scatter(cloud[:, 0], cloud[:, 1], **scatter_kwargs)

tick_values = [-1, -0.5, 0, 0.5, 1]
plt.yticks(tick_values, fontsize=20)
plt.xticks(tick_values, fontsize=20)
plt.title('Largest difference', fontsize=20)
plt.gca().set_aspect('equal', adjustable='box')
plt.legend(fontsize=18, bbox_to_anchor=(1, 1))
plt.savefig('figs/best_improvement_combo.png', bbox_inches='tight', dpi=150)
plt.show()