In [1]:
%load_ext autoreload
%autoreload 2
import rootutils
import os

root = rootutils.setup_root(
    os.path.abspath(""), indicator=".project-root", pythonpath=True
)

import torch
from src.modules.geometry_encoder import (
    load_geometry_encoder_pretrained,
)
from src.data.tiles_datamodule import TilesDataModule

model = load_geometry_encoder_pretrained(
    root / "src/models/pretrained/polygnn-ckpt-oct-01", torch.device("cpu")
)

  Referenced from: <0F9D4B2E-DD75-3BAC-BD55-6FA98E65FDBD> /opt/homebrew/anaconda3/envs/geojepa-ipynb-2.4/lib/python3.10/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.10/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.10/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.10/Python' (no such file), '/opt/homebrew/anaconda3/envs/geojepa/lib/python3.10/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/anaconda3/envs/geojepa/lib/python3.10/Python' (no such file), '/opt/homebrew/anaconda3/envs/geojepa/lib/python3.10/Python' (no such file)
  saved_dict = torch.load(


In [9]:
from tqdm import tqdm
from src.modules.tokenizer import split_feat_embs_to_batch

dataset = TilesDataModule(tile_dir="../data/tiles/sthlm/processed", num_workers=0)
dataset.setup()
loader = dataset.val_dataloader()
from src.data.components.tiles import Feature


def ftostr(feat: Feature):
    if feat.is_point:
        return "point"
    elif feat.is_line:
        return "line"
    elif feat.is_polygon:
        return "polygon"
    elif feat.is_relation_part:
        return "relation"
    else:
        return "UNK"


features = []
labels = []
model.eval()
max_features = 10000
with torch.no_grad():
    for batch in tqdm(loader):
        if len(features) >= max_features:
            break
        emb = model(
            batch.nodes, batch.intra_edges, batch.inter_edges, batch.node_to_feature
        )
        feature_embs = split_feat_embs_to_batch(emb, batch)
        offset = 0
        for tid, tile in enumerate(feature_embs):
            tile_feats = batch.tiles[tid].features
            for fid, f in enumerate(tile):
                if fid >= len(tile_feats):
                    continue
                features.append(f)
                labels.append(tile_feats[fid])
            offset += batch.feature_counts[tid]
        print(len(features))

Found 27595 tile_groups
Found 27595 tile_groups
Found 27595 tile_groups


  0%|          | 1/863 [00:00<02:52,  5.00it/s]

622


  0%|          | 2/863 [00:00<05:27,  2.63it/s]

3139


  0%|          | 3/863 [00:01<07:02,  2.04it/s]

5493


  1%|          | 5/863 [00:02<05:46,  2.47it/s]

6871
7377
len(tiles) > 16, (20)


  1%|          | 6/863 [00:02<05:51,  2.44it/s]

9032


  1%|          | 7/863 [00:02<05:55,  2.41it/s]

10545





In [11]:
# Apply UMAP to reduce to 2D
import umap.umap_ as umap

umap_model = umap.UMAP(n_components=2, random_state=42, n_neighbors=15)
umap_results = umap_model.fit_transform(features)

# Add UMAP results to DataFrame

  warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [13]:
import base64
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO

from dash import Dash
from dash import dcc, html, Input, Output, no_update
import plotly.graph_objects as go

from tqdm import tqdm


df = pd.DataFrame()
df["umap-2d-one"] = umap_results[:, 0]
df["umap-2d-two"] = umap_results[:, 1]


def ftocolor(feat):
    tc = 0
    if feat.is_line:
        tc = 10
    elif feat.is_polygon:
        tc = 20
    elif feat.is_relation_part:
        tc = 30
    return len(feat.points) + tc


df["colors"] = [ftocolor(label) for label in labels]


# Create the scatter plot
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df["umap-2d-one"],
        y=df["umap-2d-two"],
        mode="markers",
        marker=dict(size=10, color=df.colors),
    )
)


def create_geometry_image(feature):
    fig, ax = plt.subplots()
    coords = feature.points

    if feature.is_polygon:
        polygon = plt.Polygon(coords, closed=True, fill=True, color="lightblue")
        ax.add_patch(polygon)
    elif feature.is_line:
        x, y = zip(*coords)
        ax.plot(x, y, color="blue", linewidth=2)
    elif feature.is_relation_part:
        x, y = zip(*coords)
        ax.plot(x, y, color="blue")

    ax.set_aspect("equal")
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.axis("off")  # Hide axes

    # Save to a bytes buffer
    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    img_str = base64.b64encode(buf.read()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"


# Initialize the Dash app
app = Dash(__name__)

# Define the layout of the app
app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip-5", direction="bottom"),
    ],
)


# Update the hover display function to use the new images
@app.callback(
    Output("graph-tooltip-5", "show"),
    Output("graph-tooltip-5", "bbox"),
    Output("graph-tooltip-5", "children"),
    Input("graph-5", "hoverData"),
)
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update

    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    num = hover_data["pointNumber"]

    children = [
        html.Div(
            [
                html.P(
                    "Label: "
                    + ftostr(labels[num])
                    + f", {len(labels[num].points)} points",
                    style={"font-weight": "bold"},
                ),
                html.Img(
                    src=create_geometry_image(
                        labels[num]
                    ),  # Use the image generated for hover
                    style={
                        "width": "100px",
                        "height": "100px",
                        "display": "block",
                        "margin": "0 auto",
                    },
                ),
            ]
        )
    ]

    return True, bbox, children


# Run the app
if __name__ == "__main__":
    app.run_server(mode="inline", debug=False)