# Data Exploration

In [None]:
import os
# Move the kernel's working directory up one level. Later cells use relative
# paths ("datasets/...", "classes.txt") and import from `utils`, so this
# presumably places the kernel at the project root — confirm.
# NOTE(review): the path is relative, so re-running this cell moves up one
# more level each time.
os.chdir("..")

In [None]:
from utils.data_utils import ImageDataset
from torchvision import transforms
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
import cv2

In [None]:
def remove_duplicates(dataset):
    """Keep only the first row for every repeated ``scan.id``.

    Args:
        dataset: DataFrame with a ``scan.id`` column.

    Returns:
        A new DataFrame in which each non-missing ``scan.id`` occurs once
        (the first occurrence in row order is kept), re-indexed from 0.
        Rows whose ``scan.id`` is missing are never treated as duplicates
        of one another and are all kept. The input frame is not modified.

    Raises:
        KeyError: if ``dataset`` has no ``scan.id`` column.
    """
    scan_ids = dataset["scan.id"]
    # One vectorised pass instead of filtering the whole frame once per id.
    # Missing ids are excluded from the mask so they are always kept.
    repeated = scan_ids.duplicated(keep="first") & scan_ids.notna()
    # Select positionally (plain boolean array) rather than dropping by index
    # label, so frames whose index contains repeated labels are handled too.
    keep = (~repeated).to_numpy()
    return dataset[keep].reset_index(drop=True)

def visualize_duplicates(dataset):
    """Plot, side by side, every group of rows that share a ``scan.id``.

    For each ``scan.id`` that occurs more than once, prints the id and shows
    one figure with the group's images in a single row, each titled with that
    row's ``scan.number`` and ``date``.

    Args:
        dataset: DataFrame with ``scan.id``, ``scan.number``, ``date`` and
            ``file.path`` columns; each ``file.path`` is opened with PIL.

    Returns:
        None. Output is printed text and matplotlib figures.
    """
    # sort=False keeps the groups in order of first appearance; rows with a
    # missing scan.id are left out by groupby's default handling of NaN keys.
    for scan_id, group in dataset.groupby("scan.id", sort=False):
        if len(group) < 2:
            continue
        print("Scan ID : {}".format(scan_id))
        plt.figure(figsize=(15, 6))
        # Walk the group's own rows. Looking rows up in `dataset` by position
        # using the group's index *labels* selects the wrong rows whenever the
        # index is not the default 0..n-1 range (e.g. after filtering).
        for panel, (_, row) in enumerate(group.iterrows(), start=1):
            plt.subplot(1, len(group), panel)
            plt.title("Scan #: {}, Date: {}".format(row["scan.number"],
                                                    row["date"]))
            img = np.array(Image.open(row["file.path"]))
            plt.imshow(img, cmap=plt.cm.gray)
            plt.axis("off")
        plt.show()

def visualize_sizes(df):
    """Show a bar chart of the number of rows per gene, largest class first.

    Args:
        df: DataFrame with a ``gene`` column.
    """
    labels, counts = np.unique(df.gene, return_counts=True)
    # argsort is ascending; reading it backwards yields descending counts.
    descending = np.argsort(counts)[::-1]

    plt.figure(figsize=(12, 4))
    plt.bar(list(labels[descending]), list(counts[descending]))
    plt.xticks(rotation=45)
    plt.show()

# Inspect Dataset Size and Balancing

In [None]:
# Presumably the validation split of the eye2gene data (judging by the file
# name). The bare expression on the last line makes the notebook display the
# number of rows.
test = pd.read_csv("datasets/eye2gene/all_baf_valid_50deg_filtered_val_0_edited.csv")
len(test)

In [None]:
# Real images only (training split, judging by the file name).
real = pd.read_csv("datasets/eye2gene/all_baf_valid_50deg_filtered_train_0.csv")
# Real images combined with StyleGAN2 samples. The suffixes presumably refer
# to the amount of synthetic data added (1800 / 3600 images, or a
# class-rebalancing top-up) — inferred from the file names, confirm.
real_plus_synthetic1800 = pd.read_csv("datasets/syntheye/real+stylegan2_1800.csv")
real_plus_synthetic3600 = pd.read_csv("datasets/syntheye/real+stylegan2_3600.csv")
real_plus_syntheticRebalanced = pd.read_csv("datasets/syntheye/real+stylegan2_rebalanced.csv")
# Synthetic-only indexes; the folder names indicate 50, 100 and "-1" images
# per class respectively.
# NOTE(review): the "-1perclass" CSV is re-read from another location and
# rewritten by the following cell.
synthetic1800 = pd.read_csv("synthetic_datasets/stylegan2_synthetic_50perclass/generated_examples.csv")
synthetic3600 = pd.read_csv("synthetic_datasets/stylegan2_synthetic_100perclass/generated_examples.csv")
syntheticRebalanced = pd.read_csv("synthetic_datasets/stylegan2_synthetic_-1perclass/generated_examples.csv")

In [None]:
# Re-create the local index of the rebalanced synthetic set from the copy in
# the stylegan2-ada-pytorch folder, rewriting its image paths on the way.
# NOTE(review): the source path below points into one user's home directory.
syntheticRebalanced = pd.read_csv(
    "/home/zchayav/projects/stylegan2-ada-pytorch/stylegan2_synthetic_-1perclass/generated_examples.csv"
)
# Prefix every stored path with the local folder, then make it absolute
# (resolved against the current working directory).
_relocated_paths = "synthetic_datasets/" + syntheticRebalanced["file.path"]
syntheticRebalanced["file.path"] = [os.path.abspath(p) for p in _relocated_paths]
# synthetic1800.to_csv("synthetic_datasets/stylegan2_synthetic_50perclass/generated_examples.csv")
syntheticRebalanced.to_csv(
    "synthetic_datasets/stylegan2_synthetic_-1perclass/generated_examples.csv",
    index=False,
)

In [None]:
# Restrict every dataframe to the genes listed in classes.txt (one class name
# per line), then plot the per-gene row counts of each dataframe.
with open("classes.txt", "r") as f:
    classes = f.read().splitlines()

real = real[real.gene.isin(classes)]
real_plus_synthetic1800 = real_plus_synthetic1800[real_plus_synthetic1800.gene.isin(classes)]
real_plus_synthetic3600 = real_plus_synthetic3600[real_plus_synthetic3600.gene.isin(classes)]
real_plus_syntheticRebalanced = real_plus_syntheticRebalanced[real_plus_syntheticRebalanced.gene.isin(classes)]
synthetic1800 = synthetic1800[synthetic1800.gene.isin(classes)]
synthetic3600 = synthetic3600[synthetic3600.gene.isin(classes)]
# The mask must come from the `gene` column, as for the other frames: calling
# `.isin` on the whole DataFrame gives an element-wise mask, and indexing with
# that blanks non-matching cells to NaN instead of selecting rows.
syntheticRebalanced = syntheticRebalanced[syntheticRebalanced.gene.isin(classes)]

# One bar chart per dataframe, in the order they were filtered above.
for frame in (
    real,
    real_plus_synthetic1800,
    real_plus_synthetic3600,
    real_plus_syntheticRebalanced,
    synthetic1800,
    synthetic3600,
    syntheticRebalanced,
):
    visualize_sizes(frame)

# dataset_df = pd.read_csv("datasets/syntheye/")
# dataset_df = dataset_df[dataset_df.gene.isin(classes)]
# # dataset_df = dataset_df.drop(["Unnamed: 0", "Unnamed: 0.1", "Unnamed: 0.1.1"], axis=1)
# dataset_df = dataset_df.reset_index(drop=True)
# dataset_df

In [None]:
# Summary of `dataset_df`: rows with fold == -1 are reported as the test set,
# every other row as the training set.
# NOTE(review): `dataset_df` is not assigned in any visible cell (its only
# assignment, in the cell above, is commented out) — define it before running.
print("Dataset size = {}".format(len(dataset_df)))
in_test_fold = dataset_df.fold == -1
print("Training set size = {}".format(len(dataset_df[~in_test_fold])))
print("Test set size = {}".format(len(dataset_df[in_test_fold])))
print("Dataframe cols : {}".format(list(dataset_df.columns)))
patient_ids = dataset_df["patient.number"].unique()
print("Number of unique patients : {}".format(len(patient_ids)))

## Visualize duplicate images in dataset (images with same scan ID)

In [None]:
# Show one figure per scan.id that appears on more than one row.
visualize_duplicates(dataset_df)

## Omit duplicates

In [None]:
# The dataframe contains duplicate images: several rows can share the same
# "scan.id". remove_duplicates keeps the first row of each such group and
# returns a re-indexed frame.
cleaned_df = remove_duplicates(dataset_df)

In [None]:
# Same summary as for the raw dataframe, now on the de-duplicated one: rows
# with fold == -1 are reported as the test set, all others as training.
print("Dataset size = {}".format(len(cleaned_df)))
in_test_fold = cleaned_df.fold == -1
print("Training set size = {}".format(len(cleaned_df[~in_test_fold])))
print("Test set size = {}".format(len(cleaned_df[in_test_fold])))
print("Dataframe cols : {}".format(list(cleaned_df.columns)))
patient_ids = cleaned_df["patient.number"].unique()
print("Number of unique patients : {}".format(len(patient_ids)))

In [None]:
# Per-gene breakdown: total number of images, then how many images each
# patient contributes.
# NOTE(review): this rebinds `classes`, which an earlier cell fills from
# classes.txt.
classes, counts = np.unique(dataset_df["gene"], return_counts=True)
for gene, n_images in zip(classes, counts):
    print("Class: {}, Total : {} images\n".format(gene, n_images))
    images_per_patient = dataset_df.loc[dataset_df.gene == gene, "patient.number"].value_counts()
    # Number of distinct (non-missing) patient numbers for this gene.
    print(len(images_per_patient))
    print(images_per_patient)
    print("\n")

## Inspect Images

In [None]:
# Settings for the image-inspection cells below.
# CSV passed to ImageDataset as its first argument.
data_file= "datasets/faf_dataset_cleaned.csv"
# Column names handed to ImageDataset; presumably the image-path column and
# the label column respectively — confirm against utils.data_utils.
filenames_col= "file.path"
labels_col= "gene"
# NOTE(review): not referenced by any of the visible cells.
train_classes= "classes.txt"

In [None]:
# Preprocessing pipeline: resize to a resize_dim x resize_dim square, convert
# to grayscale, convert to a tensor, then normalize with mean 0.5 and std 0.5.
resize_dim = 512
image_transforms = transforms.Compose([
    transforms.Resize((resize_dim, resize_dim)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# Build the dataset from the CSV and column names configured above.
# NOTE(review): the last argument presumably restricts the dataset to the
# "ABCA4" class — confirm against utils.data_utils.ImageDataset.
dataset = ImageDataset(data_file, filenames_col, labels_col, ["ABCA4"])

In [None]:
# Smoothing mode: "averaging" (7x7 box blur), "gaussian" (5x5 Gaussian blur);
# any other value leaves the image unfiltered.
filter_type = None
# Threshold mode: "global" uses a fixed cut-off of 50; any other value uses
# Gaussian adaptive thresholding.
thresholding = "adaptive"

# First ten dataset items; element 2 of each item is used as the image.
# NOTE(review): cv2.adaptiveThreshold needs 8-bit single-channel input —
# confirm that is what ImageDataset returns at index 2.
img = [np.array(dataset[i][2]) for i in range(10)]

# apply smoothing filter followed by thresholding
for im in img:

    # apply filter
    if filter_type == "averaging":
        im_filtered = cv2.blur(im, (7, 7))
    elif filter_type == "gaussian":
        im_filtered = cv2.GaussianBlur(im, (5, 5), 0)
    else:
        im_filtered = im

    # apply thresholding — both modes work on the filtered image, so the
    # smoothing choice above takes effect in adaptive mode as well
    if thresholding == "global":
        thresholding_map = cv2.threshold(im_filtered, 50, 1, cv2.THRESH_BINARY_INV)[1]
    else:
        thresholding_map = cv2.adaptiveThreshold(
            im_filtered, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
        )

    # Keep the original pixel wherever the map is set and zero elsewhere.
    # Selecting instead of multiplying gives the same result for the 0/1
    # global map and avoids 8-bit wrap-around with the 0/255 adaptive map.
    masked = np.where(thresholding_map > 0, im, 0)

    panels = [
        ("Original", im),
        ("Filtered", im_filtered),
        ("Threshold map", thresholding_map),
        ("Original * Threshold", masked),
    ]
    plt.figure(figsize=(12, 6))
    for position, (title, panel) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(panel, cmap=plt.cm.gray)
        plt.axis("off")
        plt.title(title)

    plt.show()
    
#     plt.figure(figsize=(12, 6))
#     x, y = np.indices((768, 768))
#     plt.subplot(1, 2, 1)
#     plt.scatter(x.ravel(), im.ravel())
#     plt.subplot(1, 2, 2)
#     plt.scatter(y.ravel(), im.ravel())
#     plt.show()