In [3]:
%load_ext autoreload
%autoreload 2
import rootutils
import os

root = rootutils.setup_root(
    os.path.abspath(""), indicator=".project-root", pythonpath=True
)

import torch
from src.modules.sat_img_encoder import SatImgEncoder


encoder = SatImgEncoder(freeze=True)



In [9]:
%autoreload 2
from tqdm import tqdm
from time import time
from src.data.tiles_datamodule import TilesDataModule

dataset = TilesDataModule(tile_dir="../data/tiles/sthlm/processed", num_workers=0)
dataset.setup()
loader = dataset.val_dataloader()


images = []
embs = []
i = 0
for batch in tqdm(loader):
    if i > 5:
        break
    i += 1
    print(batch.SAT_imgs.shape)
    start = time()
    cls, patches = encoder(batch.SAT_imgs)
    print(f"Time: {(time() - start) * 1e3 / 32} ms per image")
    for i, token in enumerate(cls):
        embs.append(token)
        images.append(batch.SAT_imgs[i])
    print(cls.shape)
    print(patches.shape)

Found 27977 tile_groups
Found 27977 tile_groups
Found 27977 tile_groups


  0%|          | 0/875 [00:00<?, ?it/s]

torch.Size([32, 3, 224, 224])


  0%|          | 1/875 [00:04<1:06:10,  4.54s/it]

Time: 137.89112865924835 ms per image
torch.Size([32, 1024])
torch.Size([32, 196, 1024])
1 features without min_box





In [10]:
# Apply UMAP to reduce the CLS embeddings to 2D for visualisation.
import torch
import umap.umap_ as umap

assert embs, "no embeddings collected — run the embedding cell first"

# Stack the per-sample embeddings into one (num_samples, dim) float32 array on
# the CPU explicitly, rather than relying on numpy's implicit conversion of a
# list of tensors (which fails for GPU or grad-tracking tensors).
emb_matrix = torch.stack([e.detach().cpu() for e in embs]).float().numpy()

# random_state makes the layout reproducible; umap-learn then forces n_jobs=1.
umap_model = umap.UMAP(n_components=2, random_state=42, n_neighbors=10)
umap_results = umap_model.fit_transform(emb_matrix)

  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


In [13]:
import base64
import pandas as pd
from io import BytesIO

from dash import dcc, html, Input, Output, no_update, Dash
import plotly.graph_objects as go

from tqdm import tqdm
import torchvision.transforms as transforms


# One row per embedded tile, holding its 2-D UMAP coordinates.
df = pd.DataFrame(
    {
        "umap-2d-one": umap_results[:, 0],
        "umap-2d-two": umap_results[:, 1],
    }
)

# Scatter plot of the embedding. Point numbers follow row order, which is what
# the hover callback uses to look up the matching image.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df["umap-2d-one"],
        y=df["umap-2d-two"],
        mode="markers",
        marker={"size": 10},
    )
)


def create_geometry_image(image: torch.Tensor):
    """Encode an image as a base64 PNG data URI suitable for an ``<img src=...>``.

    Args:
        image: anything ``torchvision.transforms.ToPILImage`` accepts; here
            presumably a (C, H, W) tile tensor from the dataloader (TODO confirm).

    Returns:
        A ``"data:image/png;base64,..."`` string.
    """
    if isinstance(image, torch.Tensor) and image.is_floating_point():
        # ToPILImage converts floats via ``mul(255).byte()``; values outside
        # [0, 1] (e.g. normalized model inputs) would overflow the uint8 cast
        # and render as noise. Clamping is a no-op for in-range values.
        image = image.clamp(0.0, 1.0)

    img = transforms.ToPILImage()(image)

    # Serialize to PNG in memory.
    with BytesIO() as buffered:
        img.save(buffered, format="PNG")
        png_bytes = buffered.getvalue()

    img_base64 = base64.b64encode(png_bytes).decode("utf-8")
    return f"data:image/png;base64,{img_base64}"


# Dash app: the UMAP scatter plus a tooltip that the hover callback populates.
app = Dash(__name__)

app.layout = html.Div(
    [
        dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip-5", direction="bottom"),
    ],
    className="container",
)


# Hover callback: show the image tile behind the hovered scatter point.
@app.callback(
    Output("graph-tooltip-5", "show"),
    Output("graph-tooltip-5", "bbox"),
    Output("graph-tooltip-5", "children"),
    Input("graph-5", "hoverData"),
)
def display_hover(hoverData):
    """Return ``(show, bbox, children)`` for the tooltip.

    When nothing is hovered (``hoverData is None``), the tooltip is hidden and
    its bbox/children are left unchanged. Otherwise the tooltip is placed at
    the hovered point's bbox and shows ``images[pointNumber]`` as a thumbnail.
    """
    if hoverData is None:
        return False, no_update, no_update

    point = hoverData["points"][0]
    bbox = point["bbox"]
    tile = images[point["pointNumber"]]

    thumbnail = html.Img(
        src=create_geometry_image(tile),
        style={
            "width": "100px",
            "height": "100px",
            "display": "block",
            "margin": "0 auto",
        },
    )
    return True, bbox, [html.Div([thumbnail])]


# Run the app inline in the notebook. Dash >= 2.11 has built-in Jupyter support
# selected via `jupyter_mode`; `run_server` is deprecated (removed in Dash 3) and
# `mode=` is not one of its options.
if __name__ == "__main__":
    app.run(jupyter_mode="inline", debug=False)

In [None]:
# create embeddings for all images in dataset.