In [1]:
# Import necessary libraries
import io
import base64
import numpy as np
from dash import Dash, dcc, html, Input, Output, no_update, callback
import plotly.graph_objects as go
from PIL import Image
from sklearn.manifold import TSNE
import torch
from tqdm import tqdm
import random
import umap

from anomaly_detection.models.cvae3d_flex import CVAE3D
# from anomaly_detection.models.cvae3d import CVAE3D
from anomaly_detection.data.data_loader import get_data_loader
from anomaly_detection.config.config_handler import get_config
from anomaly_detection.training.train import train_model

# Helper function to convert numpy array to base64 image
def np_image_to_base64(im_matrix):
    """Encode an image array as a base64 PNG data URL.

    Args:
        im_matrix: Array passed straight to ``Image.fromarray``.

    Returns:
        A string starting with ``data:image/png;base64, `` followed by the
        base64-encoded PNG bytes, usable as the ``src`` of an image element.
    """
    # Serialize the image to PNG entirely in memory.
    with io.BytesIO() as png_stream:
        Image.fromarray(im_matrix).save(png_stream, format="png")
        png_bytes = png_stream.getvalue()

    payload = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64, " + payload

# Load your configuration
# NOTE(review): absolute, machine-specific path -- consider deriving it from a
# configurable project root so the notebook runs elsewhere.
config = get_config('/home/ssulta24/Desktop/VCAE_new/anomaly_detection/config/config.yaml')

# Set device: second GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the model on the selected device. Moving it to `device` (rather than a
# hard-coded "cuda:1") keeps the CPU fallback above working.
model = CVAE3D(
    input_shape=(24, 24, 240),
    latent_dim=config['latent_dim'],
    hidden_dims=config['hidden_dims'],
).to(device)

# Load your pre-trained weights.
# NOTE(review): the recorded output of this cell is a state_dict key mismatch
# for this checkpoint; presumably it was saved from a different CVAE3D
# definition (possibly the commented-out `anomaly_detection.models.cvae3d`
# import) -- TODO confirm. If loading raises, `model` stays bound to the
# freshly initialised network, so later cells would run on untrained weights.
# model.load_state_dict(torch.load('/home/ssulta24/Desktop/VCAE_new/wandb/run-20240823_154222-zcnsqis6/files/final_model.pth', map_location=device))
model.load_state_dict(
    torch.load(
        '/home/ssulta24/Desktop/VCAE_new/wandb/run-20240830_155434-okinkke3/files/best_model.pth',
        map_location=device,
        # Only tensors are needed from the checkpoint; avoid unpickling
        # arbitrary objects.
        weights_only=True,
    )
)

# model = train_model(config)
model.eval()

Using device: cuda:1


  model.load_state_dict(torch.load('/home/ssulta24/Desktop/VCAE_new/wandb/run-20240830_155434-okinkke3/files/best_model.pth', map_location=device))


RuntimeError: Error(s) in loading state_dict for CVAE3D:
	Missing key(s) in state_dict: "encoder.0.0.weight", "encoder.0.0.bias", "encoder.1.0.weight", "encoder.1.0.bias", "encoder.2.0.weight", "encoder.2.0.bias", "fc_mu.weight", "fc_mu.bias", "fc_var.weight", "fc_var.bias", "decoder_input.weight", "decoder_input.bias", "decoder.0.0.weight", "decoder.0.0.bias", "decoder.1.0.weight", "decoder.1.0.bias", "decoder.2.0.weight", "decoder.2.0.bias". 
	Unexpected key(s) in state_dict: "encoder.4.weight", "encoder.4.bias", "encoder.6.weight", "encoder.6.bias", "encoder.9.weight", "encoder.9.bias", "encoder.11.weight", "encoder.11.bias", "encoder.0.weight", "encoder.0.bias", "encoder.2.weight", "encoder.2.bias", "decoder.5.weight", "decoder.5.bias", "decoder.7.weight", "decoder.7.bias", "decoder.9.weight", "decoder.9.bias", "decoder.11.weight", "decoder.11.bias", "decoder.0.weight", "decoder.0.bias", "decoder.2.weight", "decoder.2.bias". 

In [4]:
# Get your data loader; samples are read directly from its underlying dataset.
data_loader = get_data_loader(config)
dataset = data_loader.dataset

# Calculate total number of samples
total_samples = len(dataset)

# Randomly select at most `max_samples` distinct dataset indices.
# A dedicated, seeded RNG makes the selection reproducible across kernel
# restarts without touching the global `random` state.
max_samples = 64
sample_seed = 42
random_indices = random.Random(sample_seed).sample(
    range(total_samples), min(max_samples, total_samples)
)

# Collect the encoder means and the matching inputs, one sample at a time.
latent_means = []
original_data = []

with torch.no_grad():
    for idx in tqdm(random_indices, desc="Collecting random samples"):
        # Get the specific sample from the dataset
        x = dataset[idx]
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = x.unsqueeze(0).to(device)  # Add batch dimension and move to device

        # Only the first value returned by encode() is kept as the embedding.
        mean, _ = model.encode(x)
        latent_means.append(mean.cpu().numpy())
        original_data.append(x.cpu().numpy())

# Stack the per-sample arrays along the first axis.
latent_means = np.vstack(latent_means)
original_data = np.vstack(original_data)

print(len(latent_means))

# No second subsampling pass is needed: at most `max_samples` indices were
# drawn above, so (assuming encode() yields one row per sample) no more than
# `max_samples` rows were collected.

# Perform t-SNE down to 3 components. The result is stored in `embeddings`,
# while the raw latent means stay available in `latent_means`, so this step
# can be re-run without feeding a previous projection back in.
tsne = TSNE(n_components=3, random_state=42, verbose=1)
embeddings = tsne.fit_transform(latent_means)  # Slow for large sample counts

# # Perform UMAP
# reducer = umap.UMAP(n_components=3, random_state=42, n_neighbors=15, min_dist=0.1)
# embeddings = reducer.fit_transform(latent_means)



Collecting random samples: 100%|██████████| 64/64 [00:00<00:00, 152.49it/s]


64
[t-SNE] Computing 63 nearest neighbors...
[t-SNE] Indexed 64 samples in 0.000s...
[t-SNE] Computed neighbors for 64 samples in 0.063s...
[t-SNE] Computed conditional probabilities for sample 64 / 64
[t-SNE] Mean sigma: 0.000002
[t-SNE] KL divergence after 250 iterations with early exaggeration: 113.118515
[t-SNE] KL divergence after 1000 iterations: 1.010657


In [6]:
# Name of the projection that produced `embeddings`. The notebook currently
# runs t-SNE (the UMAP code is commented out), so the labels must say so;
# change this single value if the reducer is switched.
projection_name = "t-SNE"

# Create 3D scatter plot: one marker per sample, positioned by the three
# projected coordinates.
fig = go.Figure(data=[go.Scatter3d(
    x=embeddings[:, 0],
    y=embeddings[:, 1],
    z=embeddings[:, 2],
    mode='markers',
    marker=dict(
        size=3,
        # Colour each point by the sum of its sample over axes 1-4
        # (this requires `original_data` to be 5-dimensional).
        color=np.sum(original_data, axis=(1, 2, 3, 4)),
        colorscale='Viridis',
        opacity=0.8
    )
)])

fig.update_layout(
    # NOTE(review): the epoch shown is a hard-coded literal, not read from the
    # loaded checkpoint.
    title=f"3D {projection_name} of Latent Space (Epoch {1})",
    scene=dict(
        xaxis_title=f"{projection_name} 1",
        yaxis_title=f"{projection_name} 2",
        zaxis_title=f"{projection_name} 3",
    ),
)

# Turn off Plotly's built-in hover labels; the Dash tooltip below is used
# instead.
fig.update_traces(
    hoverinfo="none",
    hovertemplate=None,
)

# Create Dash app
app = Dash(__name__)

# Layout: the 3D graph plus a tooltip component filled in by the hover
# callback.
app.layout = html.Div([
    dcc.Graph(id="graph-3d-plot", figure=fig, clear_on_unhover=True),
    dcc.Tooltip(id="graph-tooltip", direction='bottom'),
])

@app.callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("graph-3d-plot", "hoverData"),
)
def display_hover(hoverData):
    """Build the image tooltip for the point currently hovered in the 3D plot.

    Args:
        hoverData: The graph's ``hoverData`` value; ``None`` is treated as
            "nothing hovered".

    Returns:
        A ``(show, bbox, children)`` tuple for the tooltip component. When
        nothing is hovered the tooltip is hidden and the other two outputs
        are left unchanged.
    """
    if hoverData is None:
        return False, no_update, no_update

    # Only the first hovered point is used.
    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    num = hover_data["pointNumber"]

    # Create EELS image: drop singleton axes and sum over the last axis.
    eels_data = original_data[num].squeeze().sum(axis=-1)

    # Min-max scale to [0, 1]. A constant image has zero range; dividing by it
    # would produce NaNs and an undefined uint8 cast, so render it as all
    # zeros instead (this also covers a NaN range).
    lowest = eels_data.min()
    value_range = eels_data.max() - lowest
    if value_range > 0:
        eels_data = (eels_data - lowest) / value_range
    else:
        eels_data = np.zeros_like(eels_data)

    eels_image = (eels_data * 255).astype(np.uint8)
    im_url = np_image_to_base64(eels_image)

    children = [
        html.Div([
            html.Img(
                src=im_url,
                style={"width": "100px", 'display': 'block', 'margin': '0 auto'},
            ),
            html.P(f"Sample {num}", style={'font-weight': 'bold', 'text-align': 'center'})
        ])
    ]

    return True, bbox, children

# Start the Dash server only when this code runs as the main module.
if __name__ == "__main__":
    app.run(
        jupyter_mode="external",
        jupyter_height=2000,
        jupyter_width="200%",
    )

Dash app running on http://127.0.0.1:8050/
