# Synthetic Data Quality Evaluation

In [1]:
# import libraries
import os
os.chdir("..")
import torch
import pandas as pd
import numpy as np
import PIL
from PIL import Image
from torchvision.utils import make_grid
from torchvision import transforms
from matplotlib import pyplot as plt
from utils.data_utils import ImageDataset
import seaborn as sns

## Load Data

In [2]:
# CSV manifests: one for the generated images, one for the real training split.
synthetic_dataset_path = "/home/zchayav/projects/syntheye/synthetic_datasets/stylegan2_synthetic_100perclass/generated_examples.csv"
real_dataset_path = "/home/zchayav/projects/syntheye/datasets/eye2gene_new_filepaths/all_baf_valid_50deg_filtered_train_0_edited.csv"

# Shared preprocessing: resize to 512x512, convert to grayscale, convert to a tensor.
tr = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# Both datasets are built with the same class list, transform and class mapping,
# reading the image path from the "file.path" column and the label from "gene".
dataset_options = dict(
    class_vals="classes.txt",
    transforms=tr,
    class_mapping="classes_mapping.json",
)
synthetic_dataset = ImageDataset(synthetic_dataset_path, "file.path", "gene", **dataset_options)
real_dataset = ImageDataset(real_dataset_path, "file.path", "gene", **dataset_options)

# Only the batch size is set; every other DataLoader argument keeps its default.
loader_options = {"batch_size": 64}
synth_dataloader = torch.utils.data.DataLoader(synthetic_dataset, **loader_options)
real_dataloader = torch.utils.data.DataLoader(real_dataset, **loader_options)

In [None]:
# Preview the first 25 synthetic images as a 5x5 grid.
# Element 2 of a dataset item is taken as the image tensor; the assert below
# only passes when each one is a single-channel 512x512 tensor.
sample_images = torch.stack([synthetic_dataset[idx][2] for idx in range(25)])
assert sample_images.shape == (25, 1, 512, 512)

grid = make_grid(sample_images, nrow=5)

fig, ax = plt.subplots(figsize=(6, 6))
# make_grid returns channels-first data; imshow expects channels-last.
ax.imshow(grid.permute(1, 2, 0).numpy())
ax.axis('off')
plt.show()

In [None]:
# Preview the first 25 real images as a 5x5 grid.
# Element 2 of a dataset item is taken as the image tensor; the assert below
# only passes when each one is a single-channel 512x512 tensor.
sample_images = torch.stack([real_dataset[idx][2] for idx in range(25)])
assert sample_images.shape == (25, 1, 512, 512)

grid = make_grid(sample_images, nrow=5)

fig, ax = plt.subplots(figsize=(6, 6))
# make_grid returns channels-first data; imshow expects channels-last.
ax.imshow(grid.permute(1, 2, 0).numpy())
ax.axis('off')
plt.show()

## Method 1: UMAP Analysis

In [3]:
import umap
from babyplots import Babyplot
from tqdm import tqdm

In [None]:
# Fit one 2-component UMAP reducer per dataset on flattened pixel vectors.
#
# UMAP.fit() learns a completely new embedding from the array it is given, so
# calling it once per mini-batch leaves the reducer fitted on the LAST batch
# only. Every batch is therefore gathered first and fit() is called exactly
# once per dataset.
#
# NOTE: this keeps the whole flattened dataset in memory at once (one row of
# C*H*W values per image). Sub-sample the dataset if that does not fit.
print("Fitting Real Dataset")
real_pixels = np.concatenate(
    [x.view(len(x), -1).numpy() for _, _, x, _ in tqdm(real_dataloader)]
)
real_reducer = umap.UMAP(random_state=1399, n_components=2)
real_reducer.fit(real_pixels)
del real_pixels  # release the large array before loading the synthetic set

print("Fitting Synthetic Dataset")
synth_pixels = np.concatenate(
    [x.view(len(x), -1).numpy() for _, _, x, _ in tqdm(synth_dataloader)]
)
synth_reducer = umap.UMAP(random_state=1399, n_components=2)
synth_reducer.fit(synth_pixels)
del synth_pixels

In [None]:
# Project each dataset through its own fitted reducer, one batch at a time,
# and join the per-batch results into a single array per dataset.
# Real and synthetic images go through *different* reducers, so the two sets
# of coordinates do not share a common space.
real_embed = np.concatenate(
    [real_reducer.transform(x.view(len(x), -1)) for _, _, x, _ in real_dataloader]
)

synth_embed = np.concatenate(
    [synth_reducer.transform(x.view(len(x), -1)) for _, _, x, _ in synth_dataloader]
)

In [None]:
# Table of the real-image embedding: two coordinate columns plus the label.
# The coordinates and labels are kept as separate columns instead of being
# concatenated into one NumPy array, because a single array forces one dtype
# on all three columns (string labels would turn the coordinates into strings,
# integer labels would become floats).
df = pd.DataFrame(real_embed, columns=["C1", "C2"])
# np.asarray drops any index the labels might carry, so rows pair up purely by
# position. This relies on the dataloader yielding items in dataset order.
df["Label"] = np.asarray(real_dataset.img_labels)

In [None]:
# Quick static look at the first two columns of the embedding table.
coords = df.values
fig, ax = plt.subplots()
ax.scatter(coords[:, 0], coords[:, 1])
plt.show()

### 2D feature space plot

In [None]:
# Interactive point cloud of the real-image UMAP embedding, one colour per label.
real_cloud_options = {
    "colorScale": "Set1",
    "showLegend": True,
    "folded": True,
    "foldedEmbedding": real_embed.tolist(),
    "size": 5,
}

bp = Babyplot()
bp.add_plot(
    real_embed.tolist(),
    "pointCloud",
    "categories",
    real_dataset.img_labels,
    real_cloud_options,
)
bp

### 3D feature space plot

In [None]:
# Interactive point cloud of the real-image UMAP embedding with all three axes shown.
# NOTE(review): the reducer is created with n_components=2, so each point only
# has two coordinates here - confirm whether a 3-component embedding was intended.
real_cloud_options = {
    "colorScale": "Set1",
    "showLegend": True,
    "folded": True,
    "foldedEmbedding": real_embed.tolist(),
    "showAxes": [True, True, True],
    "size": 5,
}

bp = Babyplot()
bp.add_plot(
    real_embed.tolist(),
    "pointCloud",
    "categories",
    real_dataset.img_labels,
    real_cloud_options,
)
bp

In [None]:
# Interactive point cloud of the synthetic-image UMAP embedding with all three axes shown.
# NOTE(review): the reducer is created with n_components=2, so each point only
# has two coordinates here - confirm whether a 3-component embedding was intended.
synth_cloud_options = {
    "colorScale": "Set1",
    "showLegend": True,
    "folded": True,
    "foldedEmbedding": synth_embed.tolist(),
    "showAxes": [True, True, True],
    "size": 5,
}

bp = Babyplot()
bp.add_plot(
    synth_embed.tolist(),
    "pointCloud",
    "categories",
    synthetic_dataset.img_labels,
    synth_cloud_options,
)
bp