In [1]:
%matplotlib inline

import os

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances

from atlas_yao import *

In [2]:
# Load the Klein data sample and the positive / negative patch sets.
data_dir = "data/klein"

Vecs = np.load(data_dir + "/klein_uniform_1000.npy")
# Smaller sample for quick experiments:
# Vecs = np.load(data_dir + "/klein_uniform_100.npy")
patches_pos = np.load(data_dir + "/patches_pos_razor_big.npy")
patches_neg = np.load(data_dir + "/patches_neg_razor_big.npy")

# Flatten every 2-D patch into one row vector: (n_patches, a, b) -> (n_patches, a*b).
# Each array is flattened using its OWN patch count; the previous loop indexed
# patches_neg with range(patches_pos.shape[0]), which silently truncated (or
# raised IndexError) whenever the two sets had different sizes.
# Note: reshape returns a view of the loaded arrays; nothing below mutates them.
Vecs_pos = patches_pos.reshape(patches_pos.shape[0], -1)
Vecs_neg = patches_neg.reshape(patches_neg.shape[0], -1)

# Positive and negative vectors are compared with euclidean_distances later,
# so they must live in the same ambient dimension.
assert Vecs_pos.shape[1] == Vecs_neg.shape[1], "pos/neg patches differ in size"

print(Vecs_pos.shape)
print(Vecs_neg.shape)

(1000, 9)
(1000, 9)


In [3]:
# Atlas hyper-parameters.
# The chart count (64) was chosen based on results from a previous notebook.
grid_len = 30         # forwarded to atlas_yao as grid_len
km_max_iter = 1000    # forwarded to atlas_yao as km_max_iter (presumably a k-means iteration cap — TODO confirm)
n_charts = 64

# Build the atlas over the Klein sample together with the labelled patch vectors.
# NOTE(review): load_dist_mat=True presumably reuses a stored distance matrix;
# confirm against atlas_yao — the recorded run still built a large graph (~6.5 min).
ka = atlas_yao(
    Vecs,
    Vecs_pos,
    Vecs_neg,
    n_charts,
    km_max_iter=km_max_iter,
    grid_len=grid_len,
    load_dist_mat=True,
)

Getting graph as sparse matrix...
Done
Getting graph from sparse matrix...
Done
Done
Constructing enormous graph for brute-force geodesic approximation.
Trying eps = 0.6


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 58.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41360/41360 [06:36<00:00, 104.41it/s]


In [4]:
# Initial points for the Riemannian principal boundary (RPB) search.
# How the negative start point is chosen:
#   "closest" - the negative vector nearest (Euclidean, ambient space) to x_pos,
#               which is what this cell's distance computation was written for;
#   "random"  - a uniformly random negative vector (the behaviour of the earlier
#               version of this cell, reproducible with the same seed).
neg_init = "closest"

np.random.seed(600)

### Random positive start point
n_pos = Vecs_pos.shape[0]
ind_pos = np.random.randint(n_pos)
x_pos = Vecs_pos[ind_pos, :]

### Negative start point
X_pos = x_pos.reshape((1, -1))  # euclidean_distances expects a 2-D (n_samples, n_features) input
dist_vec = euclidean_distances(X_pos, Vecs_neg)[0, :]
n_neg = Vecs_neg.shape[0]
if neg_init == "closest":
    ind_neg = int(np.argmin(dist_vec))
elif neg_init == "random":
    ind_neg = np.random.randint(n_neg)
else:
    raise ValueError("neg_init must be 'closest' or 'random', got %r" % (neg_init,))
x_neg = Vecs_neg[ind_neg, :]

# Ingest x_pos, x_neg into atlas graph coordinates (local coordinates + chart id).
xi_pos, chart_pos = ka.ingest_ambient_point(x_pos)
xi_neg, chart_neg = ka.ingest_ambient_point(x_neg)

In [5]:
# Run the Riemannian principal boundary (RPB) search from the start points
# ingested in the previous cell.
# NOTE(review): the recorded run took roughly 170 s per iteration, and its
# progress bar counted to 2000 although n_iters = 200. Confirm how
# riemannian_principal_boundary interprets n_iters before launching a full run.
n_iters = 200
dist_max = 2.0
stepsize = 0.1

# Alternative: start from fixed charts instead of the ingested points.
# output = ka.riemannian_principal_boundary_mod(chart_init_pos=(4, 0),
#                                               chart_init_neg=(4, 4),
#                                               n_iters=n_iters,
#                                               dist_max=dist_max)
output = ka.riemannian_principal_boundary(xi_pos,
                                          chart_pos,
                                          xi_neg,
                                          chart_neg,
                                          stepsize=stepsize,
                                          n_iters=n_iters,
                                          dist_max=dist_max)

  0%|                                                                                                                                                                                                                | 0/2000 [00:00<?, ?it/s]

Neg time: 2.0485284328460693
Pos time: 2.3479063510894775
Pos transport time: 4.75810432434082


  0%|                                                                                                                                                                                                      | 1/2000 [00:13<7:44:24, 13.94s/it]

Neg transport time: 4.783691167831421
Boundary fun time: 6.365776062011719e-05
Iterate storage time: 0.0001552104949951172
Neg time: 81.04838061332703
Pos time: 84.48388695716858
Pos transport time: 4.863436698913574


  0%|▏                                                                                                                                                                                                   | 2/2000 [03:09<60:20:47, 108.73s/it]

Neg transport time: 4.690964937210083
Boundary fun time: 0.0003197193145751953
Iterate storage time: 5.9604644775390625e-06
Neg time: 80.29518818855286
Pos time: 78.38463997840881
Pos transport time: 5.293054819107056


  0%|▎                                                                                                                                                                                                   | 3/2000 [05:57<75:32:13, 136.17s/it]

Neg transport time: 4.848647594451904
Boundary fun time: 6.508827209472656e-05
Iterate storage time: 6.198883056640625e-06
Neg time: 86.29346776008606
Pos time: 73.60803508758545
Pos transport time: 5.017688512802124


  0%|▍                                                                                                                                                                                                   | 4/2000 [08:47<82:51:09, 149.43s/it]

Neg transport time: 4.844358682632446
Boundary fun time: 0.0002155303955078125
Iterate storage time: 5.0067901611328125e-06
Neg time: 92.43582725524902
Pos time: 68.75523042678833
Pos transport time: 4.952892780303955


  0%|▍                                                                                                                                                                                                   | 5/2000 [11:38<87:07:06, 157.21s/it]

Neg transport time: 4.843424081802368
Boundary fun time: 6.103515625e-05
Iterate storage time: 7.3909759521484375e-06
Neg time: 95.9594783782959
Pos time: 69.50079894065857
Pos transport time: 5.520999908447266


  0%|▌                                                                                                                                                                                                   | 6/2000 [14:34<90:35:02, 163.54s/it]

Neg transport time: 4.8573408126831055
Boundary fun time: 0.00019431114196777344
Iterate storage time: 5.9604644775390625e-06


  0%|▌                                                                                                                                                                                                   | 6/2000 [15:13<84:19:09, 152.23s/it]


KeyboardInterrupt: 

In [None]:
# Persist each of the 11 lists returned by riemannian_principal_boundary to
# temp_save_<n_iters>/<name>.npy.
# NOTE(review): `output` exists only if the previous cell ran to completion; the
# recorded run was stopped with KeyboardInterrupt, so this cell would raise NameError.

# Unpacking also checks that exactly 11 lists came back.
(xi_pos_list, chart_pos_list, xi_prime_pos_list,
 xi_neg_list, chart_neg_list, xi_prime_neg_list,
 xi_bou_list, chart_bou_list, xi_prime_bou_list,
 xi_prime_bou_pos_list, xi_prime_bou_neg_list) = output

names = ["xi_pos_list", "chart_pos_list", "xi_prime_pos_list",
         "xi_neg_list", "chart_neg_list", "xi_prime_neg_list",
         "xi_bou_list", "chart_bou_list", "xi_prime_bou_list",
         "xi_prime_bou_pos_list", "xi_prime_bou_neg_list"]

temp_dir = "temp_save_" + str(n_iters)
# np.save does not create missing directories; without this the first save
# fails with FileNotFoundError on a fresh checkout.
os.makedirs(temp_dir, exist_ok=True)
for name, item in zip(names, output):
    # NOTE(review): np.save converts each list with np.asanyarray; a ragged list
    # (items of differing shapes) fails on recent NumPy — confirm the shapes.
    np.save(os.path.join(temp_dir, name + ".npy"), item)