In [1]:
import numpy as np
import math

# ============================== Torch Imports =====================================

import torch
import torch.utils.data as data_utils
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn

# ================================ Dataset =========================================

from WSI_dataloader import collate, BreastEmbeddingDataset

#================================ Visualization ====================================

import plotly.express as px
import pandas as pd
from sklearn.manifold import TSNE



1024 30


In [2]:
# =============================================== Initializations ===============================================



In [15]:
DATASET_HDF5 = "/media/mdastorage/breast_5x_aug_1.h5"

# Inclusive window of batch indices that is projected and plotted.
# The loader below uses batch_size=1, so each batch index corresponds to one bag.
FIRST_BAG_IDX = 350
LAST_BAG_IDX = 400

# ===================================== Dataset to Pandas ===============================

dataset = BreastEmbeddingDataset(DATASET_HDF5)

dataloader = data_utils.DataLoader(dataset, batch_size=1, num_workers=0, pin_memory=True, shuffle=False, collate_fn=collate)

# Seed frame; the per-instance rows are concatenated in front of it further down.
# NOTE(review): the "embedding" column is never populated (rows carry embedding_x/y/z
# instead); it is kept only so the resulting frame has the same columns as before.
embedding_df = pd.DataFrame(columns=["embedding", "coords", "bag", "label"])

tsne = TSNE(n_components=3, random_state=0)

# Single pass over the loader: keep every selected bag's embedding matrix together
# with its metadata, so the loader does not have to be walked a second time and the
# rows of `projections` are aligned with the bags by construction.
# Bags before FIRST_BAG_IDX are still loaded by the loader, they are just skipped here.
bag_embeddings = []
bag_metadata = []
for batch_idx, (data, coords, label, path) in enumerate(dataloader):
    if batch_idx < FIRST_BAG_IDX:
        continue
    if batch_idx > LAST_BAG_IDX:
        break
    bag = np.array(data[0, :, :])  # drop the batch axis -> one row per instance
    bag_embeddings.append(bag)
    bag_metadata.append((bag.shape[0], coords, label, path))

# Stack all instances of all selected bags once; row order is bag order, then
# instance order inside the bag.
embeddings = np.concatenate(bag_embeddings, axis=0)
print(embeddings.shape)
projections = tsne.fit_transform(embeddings)
print(projections.shape)


# One small frame per instance, built exactly as before, but collected in a list and
# concatenated once at the end instead of re-copying the whole frame on every instance.
# NOTE(review): every instance row receives the bag-level `coords`, `path` and `label`
# objects as returned by `collate` -- confirm that `coords` is not meant to be sliced
# per instance.
instance_frames = []
instance_index = 0
for num_instances, coords, label, path in bag_metadata:
    bag_projections = projections[instance_index:instance_index + num_instances]
    for projection in bag_projections:
        instance_frames.append(
            pd.DataFrame({
                "embedding_x": projection[0],
                "embedding_y": projection[1],
                "embedding_z": projection[2],
                "coords": coords,
                "bag": path,
                "label": label,
            })
        )
    instance_index += num_instances

# The previous implementation prepended each new row, so the frame ended up newest
# first; the list is reversed to keep that row order (and with it the order in which
# bags/labels first appear in the frame).
embedding_df = pd.concat(instance_frames[::-1] + [embedding_df], ignore_index=True)



# Interactive 3-D scatter of the t-SNE projections: colour encodes the bag,
# marker symbol encodes the label.
fig = px.scatter_3d(
    embedding_df,
    x="embedding_x",
    y="embedding_y",
    z="embedding_z",
    color="bag",
    symbol="label",
    symbol_sequence=["circle", "x"],
)
fig.update_traces(marker_size=3)

# Two "restyle" buttons that set the `visible` attribute of the traces:
# "legendonly" for "Deselect All", True for "Select All".
deselect_all_button = dict(
    label="Deselect All",
    method="restyle",
    args=["visible", "legendonly"],
)
select_all_button = dict(
    label="Select All",
    method="restyle",
    args=["visible", True],
)

# Button row anchored by its top-right corner at x=1, y=1.1 in layout coordinates.
visibility_menu = dict(
    type="buttons",
    direction="left",
    buttons=[deselect_all_button, select_all_button],
    pad={"r": 10, "t": 10},
    showactive=False,
    x=1,
    xanchor="right",
    y=1.1,
    yanchor="top",
)

fig.update_layout(updatemenus=[visibility_menu])
fig.show()


740 740 740
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
Dataset Fetched!
(2341, 1024)
(2341, 3)
