In [None]:
import os
import torch
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from concurrent.futures import ProcessPoolExecutor
import cv2
from PIL import Image

In [None]:
# Directory of per-gene distance-matrix CSVs; each file name starts with the gene
# followed by "_" (the gene is parsed from that prefix below).
results_path = "../results/512_D_CMSGGAN/MSE_with_image_standardized/"
# os.listdir returns entries in arbitrary order; sort so the row order of `metrics`
# (and therefore the gene order seaborn infers for the x-axis) is reproducible.
distance_matrix_files = sorted(
    name for name in os.listdir(results_path) if name.endswith("distance_matrix.csv")
)

## Similarity scores (MSE between generated and real images)

In [None]:
# Per-pair MSE between generated and real images, read from every distance-matrix
# CSV; the gene is the file-name prefix before the first "_".
metrics_columns = ["gene", "gen_image_path", "real_image_path", "MSE"]

# Collect per-file frames and concatenate once: `DataFrame.append` was removed in
# pandas 2.0, and growing a frame inside a loop copies it on every iteration.
mse_frames = []
for path in distance_matrix_files:
    print(path)
    distances = pd.read_csv(os.path.join(results_path, path), index_col="Unnamed: 0")
    # assign() returns a new frame (no SettingWithCopyWarning); the column
    # selection keeps only the needed columns and puts "gene" first.
    mse_frames.append(distances.assign(gene=path.split("_")[0])[metrics_columns])

# Fall back to an empty frame with the expected columns when no files were found
# (pd.concat raises on an empty list).
metrics = (
    pd.concat(mse_frames, ignore_index=True)
    if mse_frames
    else pd.DataFrame(columns=metrics_columns)
)

In [None]:
# Distribution of per-pair MSE (generated vs. real image) for each gene.
fig, ax = plt.subplots(figsize=(12, 6))
# Cast to float so seaborn treats MSE as numeric even if it arrived as object dtype.
sns.violinplot(data=metrics.astype({"MSE": float}), x="gene", y="MSE", ax=ax)
ax.set(title="MSE between generated and real images by gene", xlabel="Gene", ylabel="MSE")
ax.tick_params(axis="x", labelrotation=45)
plt.show()

## Image quality evaluation (BRISQUE)

In [None]:
# BRISQUE scores of the synthetic images: one CSV per file in this directory, with
# the gene taken from the file-name prefix before the first "_".
# A dedicated name (instead of reusing `results_path`) keeps the MSE cell above
# correct if it is re-run after this one.
brisque_results_path = "../results/512_D_CMSGGAN/BRISQUE/"
# Read only CSV files (os.listdir also returns entries such as `.ipynb_checkpoints`)
# and sort, because os.listdir order is arbitrary.
brisque_scores_path = sorted(
    name for name in os.listdir(brisque_results_path) if name.endswith(".csv")
)

# Collect per-file frames and concatenate once: `DataFrame.append` was removed in
# pandas 2.0, and growing a frame inside a loop copies it on every iteration.
synthetic_frames = []
for path in brisque_scores_path:
    print(path)
    scores = pd.read_csv(os.path.join(brisque_results_path, path), index_col="Unnamed: 0")
    synthetic_frames.append(scores.assign(gene=path.split("_")[0]))

brisque_scores_synthetic = (
    (
        pd.concat(synthetic_frames, ignore_index=True)
        if synthetic_frames
        else pd.DataFrame(columns=["gene", "Synthetic image path", "Quality Score"])
    )
    .assign(image_type="synthetic")
    # Match the real-image frame's column names so the two can be concatenated.
    .rename(columns={"Synthetic image path": "file.path", "Quality Score": "brisq.score"})
)

In [None]:
# Preview the first rows rather than rendering the whole frame.
brisque_scores_synthetic.head()

In [None]:
# Real images from the cleaned dataset CSV. Rows with fold == -1 are excluded
# (presumably images outside the train/val folds — TODO confirm).
real_results_path = "../datasets/syntheye/faf_dataset_cleaned.csv"
faf_raw = pd.read_csv(real_results_path)
# A single .loc[rows, cols] plus assign() produces a fresh frame, so adding the
# image_type column cannot trigger SettingWithCopyWarning.
real_df = (
    faf_raw.loc[faf_raw["fold"] != -1, ["gene", "file.path", "brisq.score"]]
    .assign(image_type="real")
    .reset_index(drop=True)
)

In [None]:
from brisque import BRISQUE
from tqdm.auto import tqdm

# Size real images are resized to before scoring — presumably chosen to match the
# 512 px synthetic images under `512_D_CMSGGAN` (TODO confirm).
BRISQUE_IMAGE_SIZE = (512, 512)
brisq = BRISQUE()


def _brisque_score(image_path, size=BRISQUE_IMAGE_SIZE):
    """Return the BRISQUE score of the image at `image_path` after resizing it to `size`.

    The file is opened in a `with` block so its handle is released as soon as the
    resized copy exists, even if resizing raises.
    """
    with Image.open(image_path) as img:
        resized = img.resize(size)
    return brisq.get_score(np.array(resized))


# Recompute the real-image scores, overwriting the `brisq.score` values read from the
# CSV. Building the array positionally from the column does not depend on real_df
# having a 0..n-1 index (the original wrote into new_brisq[i] using index labels).
real_df["brisq.score"] = np.array(
    [_brisque_score(path) for path in tqdm(real_df["file.path"], desc="BRISQUE (real)")],
    dtype=float,
)

In [None]:
# Preview the first rows rather than rendering the whole frame.
real_df.head()

## BRISQUE score distributions: real vs. synthetic

In [None]:
# Stack synthetic and real scores into one long-format frame for plotting.
# Both inputs carry a 0..n-1 index, so ignore_index=True avoids duplicate index
# labels, which some seaborn versions reject when they reindex the data.
# The float cast guarantees a numeric score column for the box plot.
combined_df = (
    pd.concat([brisque_scores_synthetic, real_df], axis=0, ignore_index=True)
    .astype({"brisq.score": float})
)
# Per-source summary instead of dumping every row.
combined_df.groupby("image_type")["brisq.score"].describe()

In [None]:
# Gene order for the x-axis: one gene name per line of classes.txt.
with open("../classes.txt", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) so they don't become empty categories.
    classes = [line.strip() for line in f if line.strip()]

# `order=` silently drops genes it does not list; report them instead.
unlisted_genes = sorted(set(combined_df["gene"]) - set(classes))
if unlisted_genes:
    print(f"Genes missing from classes.txt (not plotted): {unlisted_genes}")

fig, ax = plt.subplots(figsize=(15, 6))
sns.boxplot(data=combined_df, y="brisq.score", x="gene", hue="image_type", order=classes, ax=ax)
ax.set(title="BRISQUE score by gene: real vs. synthetic", xlabel="Gene", ylabel="BRISQUE score")
ax.tick_params(axis="x", labelrotation=45)
plt.show()