In [None]:
import yaml
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from utils.plot import draw_graph
from data.dataset import PolygonDataset

# Plot styling. `sns.set` is only an alias of `sns.set_theme`, and every call
# resets any argument it is not given to its default (style="darkgrid",
# palette="deep"). Theme, palette and rc overrides must therefore be passed
# together in a single call, otherwise the later call clobbers the earlier one.
sns.set_theme(style="dark",
              palette="pastel",
              rc={"figure.dpi": 300,
                  "savefig.dpi": 300})

# Experiment configuration; the keys used below point at pickled DataFrames.
with open('cfg/gae.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# NOTE: read_pickle can execute arbitrary code on load — only point the
# config at trusted files.
glyph_df = pd.read_pickle(cfg['test'])
osm_df = pd.read_pickle(cfg['osm'])
melb_df = pd.read_pickle(cfg['melb'])
melb_sim_df = pd.read_pickle(cfg['melb_sim'])

# Keep only rows whose `trans` value is 'o' (presumably the untransformed
# shapes — TODO confirm), sorted by name. Sort *before* resetting the index so
# the index is 0..n-1 in sorted order, whether the dataset indexes rows by
# label or by position.
# (The `glyh_` spelling is kept because later cells may reference it.)
glyh_df_o = (
    glyph_df[glyph_df.trans == 'o']
    .sort_values('name')
    .reset_index(drop=True)
)
osm_df_r = osm_df[osm_df.trans == 'r'].reset_index(drop=True)
osm_df_o = osm_df[osm_df.trans == 'o'].reset_index(drop=True)

glyph_set = PolygonDataset(glyh_df_o)

osm_set_r = PolygonDataset(osm_df_r)
osm_set_o = PolygonDataset(osm_df_o)

melb_set = PolygonDataset(melb_df)
melb_sim_set = PolygonDataset(melb_sim_df)

In [None]:
from gtda.homology import VietorisRipsPersistence

# Indices into `melb_set` of the query polygons whose topology we inspect.
query_id = [
    26258, 9012, 5811, 28789, 30772,
    32358, 18329, 9968, 9787, 11576,
]

# One point cloud per query polygon: its `pos` tensor converted to NumPy.
pcd = [melb_set[idx].pos.numpy() for idx in query_id]

# Homology dimensions to track: 0 = connected components, 1 = loops.
dims = [0, 1]

# Vietoris–Rips persistence on Euclidean distances; edge collapsing is a
# speed-up option (see gtda docs), and the work is spread over 6 jobs.
persistence = VietorisRipsPersistence(
    homology_dimensions=dims,
    metric="euclidean",
    collapse_edges=True,
    n_jobs=6,
)

# One persistence diagram per point cloud, in the same order as `query_id`.
dig = persistence.fit_transform(pcd)

In [None]:
import math

from plotly.subplots import make_subplots
from gtda.plotting import plot_diagram

N_COLS = 2            # subplots per row
ROW_HEIGHT_PX = 300   # figure height per subplot row

# plot_diagram(d) builds a complete figure for one diagram; keep only its
# traces. Per gtda, trace 0 is the dashed diagonal and the following traces are
# the points of each homology dimension in order (here H0, then H1).
pds = [plot_diagram(d).data for d in dig]

# NOTE: do not name the loop variable `pd` — that shadows the pandas alias
# imported in the first cell and breaks any later cell using pandas.
for i, traces in enumerate(pds):
    traces[1]['marker'] = dict(color="lightseagreen")
    traces[2]['marker'] = dict(color="lightsalmon")
    # Show each homology dimension in the legend once (from the first
    # subplot) and group the rest with it so toggling affects all subplots.
    for trace in traces[1:]:
        trace.update(legendgroup=trace.name, showlegend=(i == 0))

# Grid size follows the number of diagrams instead of a hard-coded 5x2.
n_rows = max(1, math.ceil(len(pds) / N_COLS))
fig = make_subplots(
    rows=n_rows,
    cols=N_COLS,
    subplot_titles=[f"Query {qid}" for qid in query_id],
)

# Fill the grid row by row: diagram i goes to row i // N_COLS, col i % N_COLS.
for i, traces in enumerate(pds):
    row, col = divmod(i, N_COLS)
    fig.add_traces(traces, rows=row + 1, cols=col + 1)

# Axis titles from plot_diagram's layout are lost when only `.data` is kept.
fig.update_xaxes(title_text="Birth")
fig.update_yaxes(title_text="Death")
fig.update_layout(height=ROW_HEIGHT_PX * n_rows, width=1500,
                  title_text="Persistence Diagrams")
fig.show()