In [1]:
from time import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.collections import PathCollection
from matplotlib.legend import Legend
from ripser import ripser
#from gudhi import SimplexTree
from persim import plot_diagrams
from tqdm import tqdm

# from yao_utils import *
from atlas_general import atlas_general, load_atlas

# data_dir = "data/klein_synthetic"

# # Create atlas-graph representation
# patches_pos = np.load(data_dir+"/patches_pos_razor_big.npy")
# patches_neg = np.load(data_dir+"/patches_neg_razor_big.npy")

# vecs_pos = []
# vecs_neg = []
# for j in range(patches_pos.shape[0]):
#     vecs_pos.append(patches_pos[j, :, :].reshape(9))
#     vecs_neg.append(patches_neg[j, :, :].reshape(9))
# Vecs_pos = np.vstack(vecs_pos)
# Vecs_neg = np.vstack(vecs_neg)

# ka = klein_atlas(Vecs_pos, Vecs_neg, dist=1.0,
#             kernel_fun=lambda x: 1.0)

# ka = load_atlas("random_klein", 2, 9)
ka = load_atlas("random_atlas", 2, 3)

In [2]:
for j in range(ka.n_charts):
    x_0, L, M, h_mat, A, b, c = ka.chart_dict[j]
    s, _ = np.linalg.eigh(A)
    print(s)

[ 1.55471372  1.56116406 15.03919145]
[ 1.10272986  2.296672   13.43069609]
[ 1.24348864  2.09108694 15.02368829]
[ 1.21312252  1.58730242 10.92521273]
[ 1.10715831  2.27448013 13.34603433]
[ 1.09301692  1.73056237 10.07821365]


In [3]:
# Sample points from charts
# THE FOLLOWING IS GOOD FOR GETTING SALIENT H1 features
pts_per_chart = 60

# X_pre = []
# start = time()
# for j_theta in tqdm(range(ka.n_flags)):
#     for j_phi in range(ka.n_flags):
#         chart = (j_theta, j_phi)
#         rad = ka.rad_dict[chart]
#         boundary_fun = ka.boundary_fun_dict[chart]
#         for j_per in range(pts_per_chart):
#             cond = True
#             while cond:
#                 xi = rad * np.random.randn(2)
#                 f = boundary_fun(xi)
#                 #if np.linalg.norm(xi) <= rad:
#                 if f < 0:
#                     cond = False
#             x = ka.xi_chart_to_ambient(xi, chart)
#             X_pre.append(x)
# print("Time elapsed: "+str(time() - start))

k = 6
zeros_xi = np.zeros(2)
for j in range(k):
    boundary_fun = ka.boundary_fun_dict[j]
    b_val = boundary_fun(zeros_xi)
    print(b_val)

X_pre = []
for j in tqdm(range(k)):
    X_big = ka.sample_uniformly_from_chart_by_ind(j)
    n_big, _ = X_big.shape
    inds_smol = np.random.choice(n_big, pts_per_chart, replace=False)
    X_smol = X_big[inds_oi, :]
    if np.prod(X_smol.shape) != 0:
        X_pre.append(X_smol)

X = np.vstack(X_pre)

0.823903212982465
0.8260651654021185
0.8526590218937926
0.8125621341536409
0.8559533904036698
0.8164342179750967


  0%|                                                                                          | 0/6 [00:00<?, ?it/s]


ValueError: Cannot take a larger sample than population when 'replace=False'

# Get polar coordinates for points
thetas = []
phis = []

for j in tqdm(range(X.shape[0])):
    x = X[j, :]
    theta, phi = find_closest_theta_phi_brute(x)
    #theta, phi = find_closest_theta_phi(x)
    thetas.append(theta)
    phis.append(phi)

In [None]:
#rips_dict = ripser(X, maxdim=2)
rips_dict = ripser(X, maxdim=1, do_cocycles=True)

In [None]:
#plot_diagrams(rips_dict["dgms"])

#plt.show()

for key in rips_dict.keys():
    print(key)
    print(type(rips_dict[key]))
    print("\n")

In [None]:
dgms = rips_dict["dgms"]
dgms_0 = dgms[0]
dgms_1 = dgms[1]

ccs = rips_dict["cocycles"]
ccs_1 = ccs[1]

lts_1 = []
for birth, death in dgms_1:
    lts_1.append(death - birth)

sorted_inds = list(range(len(lts_1)))
sorted_inds.sort(key=lambda x: lts_1[x],
                reverse=True)

ccs_1_sorted = [ccs_1[ind] for ind in sorted_inds]
dgms_0_sorted = np.array([dgms_0[ind] for ind in sorted_inds])
dgms_1_sorted = np.array([dgms_1[ind] for ind in sorted_inds])
lts_1_sorted = [lts_1[ind] for ind in sorted_inds]

dgms_0_top = dgms_0_sorted[:5]
dgms_1_top = dgms_1_sorted[:5]
ccs_1_top = ccs_1_sorted[:5]

In [None]:
fig = plt.figure(figsize=(3, 3))
gs_grid_len = 10
gs = GridSpec(gs_grid_len, gs_grid_len)
ax = fig.add_subplot(gs[1:-1, 1:-1])
plot_diagrams(dgms, legend=False)

"""
# Change legend fontsize
ax = plt.gca()
legend = [c for c in ax.get_children() if isinstance(c, Legend)][0]
texts = legend.get_texts()
for text in texts:
    label = text.get_text()
    text.set_text(label)
    text.set_fontsize(12)
"""

# Set title
ax.set_title("Atlas graph of\nKlein bottle")

# Change axis fontsizes
letters = ["x", "y"]
for let in letters:
    label = getattr(ax, "get_"+let+"label")()
    getattr(ax, "set_"+let+"label")(label, fontsize=12)
    ticks = getattr(ax, "get_"+let+"ticks")()
    getattr(ax, "set_"+let+"ticklabels")([str(tick) for tick in ticks], fontsize=12)

# Draw border of dots
path_collections = [item for item in ax.get_children() if isinstance(item, PathCollection)]
for pc in path_collections:
    pc.set_edgecolor("k")
    
fig = plt.gcf()
fig.savefig("graphics/klein/homology/h1_atlas.jpg")

plt.show()

In [None]:
plot_diagrams([dgms_0_top, dgms_1_top])

plt.show()

In [None]:
print(ccs_1_top)
print(type(ccs_1_top))
print(len(ccs_1_top))
print(type(ccs_1_top[0]))
print(ccs_1_top[0].shape)

In [None]:
def draw_feature(arr):
    theta_list = []
    phi_list = []
    for row in arr:
        for ind in row:
            theta = thetas[ind]
            phi = phis[ind]
            theta_list.append(theta)
            phi_list.append(phi)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot()
    ax.scatter(theta_list, phi_list)
    
    plt.show()

In [None]:
draw_feature(ccs_1_top[0])

In [None]:
# Sample points from charts
# THE FOLLOWING IS GOOD FOR GETTING SALIENT H2 feature
pts_per_chart = 10

X_pre = []
start = time()
for j_theta in range(ka.n_flags):
    for j_phi in range(ka.n_flags):
        chart = (j_theta, j_phi)
        rad = ka.rad_dict[chart]
        boundary_fun = ka.boundary_fun_dict[chart]
        for j_per in range(pts_per_chart):
            cond = True
            while cond:
                xi = rad * np.random.randn(2)
                f = boundary_fun(xi)
                #if np.linalg.norm(xi) <= rad:
                if f < 0:
                    cond = False
            x = ka.xi_chart_to_ambient(xi, chart)
            X_pre.append(x)
print("Time elapsed: "+str(time() - start))

X = np.vstack(X_pre)