In [1]:
from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, State, no_update
import plotly.graph_objects as go
import pandas as pd
import base64
from io import BytesIO
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
from sklearn.manifold import TSNE
from tqdm import tqdm

gpu_id = 1
device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")

# Initialize the model and dataset

# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# model.to(device)
# model.eval()

# from r3m import load_r3m
# model = load_r3m("resnet18") # resnet18, resnet34, resnet50
# model.to(device)
# model.eval()

# import clip
# model, preprocess = clip.load('ViT-B/32', device=device)

# def get_features_and_images_clip(model, dataloader):
#     features = []
#     labels = []
#     images = []
#     model.eval()
#     with torch.no_grad():
#         for inputs, batch_labels in dataloader:
#             # Convert tensor to PIL for CLIP preprocessing
#             pil_images = [transforms.ToPILImage()(img_tensor).convert("RGB") for img_tensor in inputs]
#             # Preprocess images and move to the correct device
#             clip_inputs = torch.stack([preprocess(img) for img in pil_images]).to(device)
#             outputs = model.encode_image(clip_inputs)
#             features.append(outputs)
#             labels += batch_labels.tolist()
#             # Use original images for display (already converted to PIL above)
#             for img in pil_images:
#                 buffered = BytesIO()
#                 img.save(buffered, format="PNG")
#                 encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
#                 images.append(f"data:image/png;base64,{encoded_image}")
#     return torch.cat(features).cpu().numpy(), labels, images


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
full_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
subset_indices = list(range(100))  # take only first 100 samples
dataset = Subset(full_dataset, subset_indices)
dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

# Extract features, labels, and images
def get_features_and_images(model, dataloader):
    features = []
    labels = []
    images = []
    with torch.no_grad():
        for inputs, batch_labels in tqdm(dataloader):
            outputs = model(inputs.to(device))
            features.append(outputs.detach().cpu())
            labels += batch_labels.tolist()
            for img_tensor in inputs:
                img = transforms.ToPILImage()(img_tensor).convert("RGBA")
                buffered = BytesIO()
                img.save(buffered, format="PNG")
                encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
                images.append(f"data:image/png;base64,{encoded_image}")
    return torch.cat(features).cpu().numpy(), labels, images

features, labels, images = get_features_and_images(model, dataloader)

# App layout
app = JupyterDash(__name__)
app.layout = html.Div([
    dcc.Graph(id="tsne-graph"),
    dcc.Slider(
        id='perplexity-slider',
        min=5,
        max=50,
        value=30,  # Default value
        marks={str(i): str(i) for i in range(1, 51, 1)},
        step=1
    ),
    dcc.Tooltip(id="graph-tooltip"),
])

# Callback to update t-SNE plot based on perplexity slider
@app.callback(
    Output("tsne-graph", "figure"),
    Input("perplexity-slider", "value")
)
def update_tsne(perplexity):
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    X_reduced = tsne.fit_transform(features)

    df = pd.DataFrame({
        "x": X_reduced[:, 0],
        "y": X_reduced[:, 1],
        "label": labels,
        "image": images
    })

    fig = go.Figure(data=[
        go.Scatter(
            x=df["x"],
            y=df["y"],
            mode="markers",
            marker=dict(colorscale='Viridis', color=df["label"], size=10, opacity=0.8),
            hoverinfo='none'
        )
    ])
    fig.update_layout(plot_bgcolor='rgba(255,255,255,0.1)')
    return fig

# Callback to display hover tooltip
@app.callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("tsne-graph", "hoverData"),
)
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update
    pt = hoverData['points'][0]
    bbox = pt["bbox"]
    num = pt["pointIndex"]
    img_src = images[num]

    children = [
        html.Div([
            html.Img(src=img_src, style={"width": "100%"}),
            html.H3(f"Label: {labels[num]}", style={"color": "darkblue"})
        ], style={'width': '200px', 'white-space': 'normal'})
    ]

    return True, bbox, children

# Run the Dash app within a Jupyter notebook
app.run_server(mode='inline')


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Files already downloaded and verified
Dash is running on http://127.0.0.1:8050/



Address already in use
Port 8050 is in use by another program. Either identify and stop that program, or start the server with a different port.


OSError: Address 'http://127.0.0.1:8050' already in use.
    Try passing a different port to run_server.