# Import libraries

In [None]:
# Work from the parent directory (presumably the repository root) so that the
# `utils/`, `models/`, `datasets/` and `checkpoints/` paths used below resolve.
# NOTE: re-running this cell moves up one more directory level each time.
import os
os.chdir(os.pardir)

In [None]:
import os
import torch
import pandas as pd
import numpy as np
from utils.data_utils import get_noise, get_one_hot_labels, combine_vectors, ImageDataset, show_tensor_images
from utils.evaluation_utils import Interpolate
from matplotlib import pyplot as plt

In [None]:
# Prefer the second GPU (cuda:1) as in the original setup, but fall back to the
# first GPU on single-GPU machines (where "cuda:1" is an invalid ordinal) and to
# the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the original dataset

In [None]:
from torchvision import transforms

# Preprocessing: resize to 512x512, collapse to one grayscale channel, and
# convert to a tensor.
image_transforms = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# CSV listing image paths ("file.path") and labels ("gene"), plus the file that
# lists the class values.
data_file = "datasets/syntheye/faf_dataset_cleaned.csv"
classes = "classes.txt"
dataset = ImageDataset(
    data_file=data_file,
    fpath_col_name="file.path",
    lbl_col_name="gene",
    class_vals=classes,
    transforms=image_transforms,
)

In [None]:
# Class distribution of the dataset labels.
classes, class_rep = np.unique(dataset.img_labels, return_counts=True)
# Keep the raw counts in "Count" and store an actual percentage in "%", so the
# column matches its name (and sums to 100).
class_rep_df = pd.DataFrame({
    'Class': classes,
    'Count': class_rep,
    '%': 100.0 * class_rep / class_rep.sum(),
})
plt.figure(figsize=(6, 12))
plt.pie(class_rep, labels=classes, autopct="%.1f%%", rotatelabels=True)
plt.show()

In [None]:
# Order the classes from most to least represented, then show the total of
# the "%" column as a quick sanity check.
class_rep_df = class_rep_df.sort_values("%", ascending=False)
class_rep_df["%"].sum()

# Load model and weights

In [None]:
# Model / checkpoint configuration.
model_name = "cmsggan"
n_classes = dataset.n_classes   # width of the one-hot class conditioning
z_dim = 512                     # length of the noise part of the latent vector
resolution = 512
# Passed as `depth` to MSG_GAN; log2(512) - 1 = 8 for the 512x512 setup.
depth = int(np.log2(resolution) - 1)
weights_dir = (
    "checkpoints/"
    "data:faf_dataset_cleaned.csv_classes:classes.txt_trans:512-1-1_"
    "mod:cmsggan1-512-512_tr:1000-RAHinge-32-1-0.003-0.003-0.0-0.99/"
)
weights_file = "model_ema_state_106.pth"

In [None]:
# Build the conditional MSG-GAN on the CPU and keep only its `gen_shadow`
# generator; the checkpoint weights are loaded into it afterwards.
from models.msggan import conditional_msggan

gan_kwargs = dict(
    latent_size=z_dim,
    mode="grayscale",
    depth=depth,
    n_classes=n_classes,
    device="cpu",
)
gan_model = conditional_msggan.MSG_GAN(**gan_kwargs).gen_shadow

In [None]:
# Load the checkpoint weights.
# The state dict is loaded through a DataParallel wrapper (its keys presumably
# carry DataParallel's "module." prefix - TODO confirm). DataParallel requires
# the wrapped module's parameters to live on device_ids[0] before forward, and
# the generator was built on the CPU, so move it to `device` first.
gan_model = torch.nn.DataParallel(gan_model.to(device), device_ids=[device])
gan_model.load_state_dict(
    torch.load(os.path.join(weights_dir, weights_file), map_location=device)
)
# gan_model.eval()

In [None]:
# Sanity check: generate one image for class index 1 from a fixed seed.
torch.manual_seed(100)
noise = get_noise(1, z_dim).to(device)
class_enc = get_one_hot_labels(torch.tensor([1]), n_classes).to(device)
combined = combine_vectors(noise, class_enc)
# combined = (combined / combined.norm()) * (combined.shape[-1] ** 0.5)
with torch.no_grad():  # inference only, no graph needed
    images = gan_model(combined)

# The generator returns a list of outputs; the last entry is the one used as
# the final image throughout this notebook.
plt.imshow(images[-1].cpu().numpy().squeeze())
plt.show()

## Image Encoder (GAN inversion: fit a noise vector so the generator reproduces a target image)

In [None]:
from torch.nn import Module


class ImageEncoder_v2(Module):
    """Work-in-progress second version of the image encoder.

    Currently it only initialises the ``torch.nn.Module`` base class.
    """

    def __init__(self):
        # `self` must be passed: `super(ImageEncoder_v2)` alone is an unbound
        # super object, so Module.__init__ never ran and the parameter/submodule
        # registries were never created.
        super(ImageEncoder_v2, self).__init__()

In [None]:
from torch.nn import Module

# class NoiseLayer(Module):
#     def __init__(self, latent_size, idx=1, n_classes=36):
#         super(NoiseLayer, self).__init__()
#         self.latent_size = latent_size
#         self.class_encoding = get_one_hot_labels(torch.as_tensor([idx]), n_classes)
#         self.w = torch.nn.Parameter(torch.randn(1, self.latent_size), requires_grad=True)
        
#     def forward(self, x):
#         noise = torch.multiply(x, self.w)
#         latent = combine_vectors(noise, self.class_encoding)
#         return latent

class ImageEncoder:
    """Invert the conditional generator by optimising a noise vector.

    For a fixed class ``idx``, the noise vector ``self.noise`` is updated with
    Adam so that the (range-adjusted) generator output matches a target image
    under an MSE loss. The generator itself is not trained.
    """

    def __init__(self, latent_size=512, idx=None, n_classes=36, decoder_model=None, device=None):
        """
        Args:
            latent_size: length of the optimised noise vector.
            idx: class index used for the one-hot conditioning (required).
            n_classes: number of classes in the one-hot encoding.
            decoder_model: generator that maps a [noise, one-hot] latent to a
                list of images (the last one is used). Its parameters are
                frozen IN PLACE, which also affects every other user of the
                same model object.
            device: stored as ``self.device``; not otherwise used by this class.

        Raises:
            ValueError: if ``idx`` or ``decoder_model`` is not provided.
        """
        if idx is None:
            raise ValueError("ImageEncoder requires a class index `idx`.")
        if decoder_model is None:
            raise ValueError("ImageEncoder requires a `decoder_model`.")
        self.device = device

        # noise vector being optimised (a leaf tensor so Adam can update it)
        self.noise = torch.nn.Parameter(torch.randn(1, latent_size), requires_grad=True)

        # fixed one-hot class conditioning (not optimised)
        self.class_encoding = get_one_hot_labels(torch.tensor([idx]), n_classes)

        # the generator acts as a fixed decoder: freeze its weights
        self.decoder_model = decoder_model
        for param in self.decoder_model.parameters():
            param.requires_grad = False

        # maps generator output values to the range the target images use
        self.image_adjustor = conditional_msggan.Generator.adjust_dynamic_range

    def plot_results(self, losses, expected_image, predicted_image):
        """Plot the loss curve, the target image and the current prediction side by side."""
        plt.figure(figsize=(20, 6))

        plt.subplot(1, 3, 1)
        plt.plot(np.arange(len(losses)), losses)
        plt.title("Loss change over iterations")
        plt.xlabel("Iterations")
        plt.ylabel("Reconstruction loss")

        plt.subplot(1, 3, 2)
        plt.imshow(expected_image.detach().cpu().numpy(), plt.cm.gray)
        plt.title("Target Image")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(predicted_image.detach().cpu().numpy(), plt.cm.gray)
        plt.title("Generated Image")
        plt.axis("off")
        plt.show()

    def __call__(self, expected_im, lr=0.01, iterations=10, plot_losses=True, steps=1):
        """Fit ``self.noise`` to ``expected_im`` and return it.

        Args:
            expected_im: target image tensor. It is compared (MSE) with the
                squeezed, range-adjusted last generator output and should have
                the same shape.
            lr: Adam learning rate.
            iterations: number of optimisation steps.
            plot_losses: if True, plot progress every ``steps`` iterations.
            steps: plotting interval in iterations (must be >= 1 when plotting).

        Returns:
            The optimised noise ``torch.nn.Parameter`` of shape (1, latent_size).

        Raises:
            ValueError: if ``plot_losses`` is True and ``steps`` < 1.
        """
        if plot_losses and steps < 1:
            raise ValueError("`steps` must be >= 1 when `plot_losses` is True.")

        # initialize optimizer and reconstruction loss
        self.optim = torch.optim.Adam([self.noise], lr=lr)
        self.criterion = torch.nn.MSELoss()

        losses = []
        target = None

        for i in range(1, iterations + 1):
            # predict image from the current noise + fixed class encoding
            combined_latent = combine_vectors(self.noise, self.class_encoding)
            predicted_im = self.decoder_model(combined_latent)[-1].squeeze()
            predicted_im = self.image_adjustor(predicted_im)

            if target is None:
                # The target is a constant. Detach it so backward() never reaches
                # into a graph that produced it (e.g. a generator output), which
                # would fail on the second iteration once that graph is freed.
                # Move it once, to the device the decoder outputs on.
                target = expected_im.detach().to(predicted_im.device)

            loss = self.criterion(predicted_im, target)
            losses.append(loss.item())

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            if plot_losses and i % steps == 0:
                self.plot_results(losses, expected_im, predicted_im)

        return self.noise

In [None]:
# Invert a real dataset image: fit a noise vector whose generated image, for the
# image's own class, reconstructs it.
torch.manual_seed(1399)
_, _, expect_im, idx = dataset[500]
# expect_im = transforms.RandomHorizontalFlip(p=1)(expect_im)
encoder = ImageEncoder(latent_size=z_dim, idx=idx, n_classes=n_classes,
                       decoder_model=gan_model, device=device)
z1 = encoder(expect_im.squeeze(), lr=0.01, iterations=200, plot_losses=True, steps=10)
assert z1.shape == (1, z_dim)

# use the fitted noise to create the reconstructed image (inference only)
class_encoding = get_one_hot_labels(torch.as_tensor([idx]), n_classes)
image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
with torch.no_grad():
    combined_latent = combine_vectors(z1, class_encoding)
    im_reconstruction = gan_model(combined_latent)[-1].squeeze()
    im_reconstruction = image_adjustor(im_reconstruction)

# plot image vs its gan reconstruction
# plt.figure()
# plt.subplot(1, 2, 1)
# plt.imshow(expect_im.numpy().squeeze(), plt.cm.gray)
# plt.title("Target Image"), plt.axis("off")

# plt.subplot(1, 2, 2)
# plt.imshow(im_reconstruction.detach().cpu().numpy(), plt.cm.gray)
# plt.title("Generated Image"), plt.axis("off")
# plt.show()

In [None]:
# Same inversion, but on the horizontally flipped version of the image.
torch.manual_seed(1399)
_, _, expect_im, idx = dataset[500]
from torchvision import transforms
expect_im = transforms.RandomHorizontalFlip(p=1)(expect_im)  # p=1: always flip
encoder = ImageEncoder(latent_size=z_dim, idx=idx, n_classes=n_classes,
                       decoder_model=gan_model, device=device)
z2 = encoder(expect_im.squeeze(), lr=0.05, iterations=2000, plot_losses=True, steps=250)
assert z2.shape == (1, z_dim)

# use the fitted noise to create the reconstructed image (inference only)
class_encoding = get_one_hot_labels(torch.as_tensor([idx]), n_classes)
image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
with torch.no_grad():
    combined_latent = combine_vectors(z2, class_encoding)
    im_reconstruction = gan_model(combined_latent)[-1].squeeze()
    im_reconstruction = image_adjustor(im_reconstruction)

## Can the GAN reconstruct its own image?

In [None]:
# Invert an image produced by the generator itself, for which an exact
# solution (noise_true) is known to exist.
# torch.manual_seed(1399)

idx = 10

# ground truth: a random noise vector and the image it generates
noise_true = torch.randn(1, z_dim)
class_encoding = get_one_hot_labels(torch.tensor([idx]), n_classes)
image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
with torch.no_grad():
    synthetic_im = gan_model(combine_vectors(noise_true, class_encoding))[-1]
    # The encoder compares the target with range-adjusted generator output,
    # so the target must go through the same adjustment.
    synthetic_im = image_adjustor(synthetic_im)

# reconstruct
encoder = ImageEncoder(latent_size=z_dim, idx=idx, n_classes=n_classes,
                       decoder_model=gan_model, device=device)
noise_expected = encoder(synthetic_im.squeeze(), lr=0.01, iterations=1000, plot_losses=True, steps=250)

## Differentiate left vs right eyes (analysis continues under "Exploring the latent space")

## Generate some random images

In [None]:
# Sample a batch of images of one class (ABCA4) from random noise vectors.
# NOTE(review): no seed is set here, so re-running this cell draws a different
# batch, while the hand-picked right/left eye indices in "Exploring the latent
# space" refer to one specific draw.
n_samples = 50
noise = get_noise(n_samples, z_dim)
labels = torch.tensor(n_samples * [dataset.class2idx["ABCA4"]])
one_hot_labels = get_one_hot_labels(labels, dataset.n_classes)
noise_and_labels = combine_vectors(noise, one_hot_labels)
with torch.no_grad():
    generated_images = gan_model(noise_and_labels)[-1]

image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
generated_images = image_adjustor(generated_images)

In [None]:
# Display the whole generated batch as a grid. normalize=False because the
# images were already range-adjusted when they were generated.
grid_kwargs = {"normalize": False, "show_image": True}
plt.figure(figsize=(12, 12))
_ = show_tensor_images(generated_images, **grid_kwargs)

In [None]:
# save_folder = "results/gif_imgs/"
# os.makedirs(save_folder, exist_ok=True)
# for i in range(len(generated_images)):
#     plt.figure(figsize=(12, 6))
#     plt.axis('off')
#     _ = show_tensor_images(generated_images[i], normalize=False, show_image=True, save_path=save_folder+"res_{}".format(i))

## Exploring the latent space

In [None]:
from scipy.stats import ttest_ind

# Indices into `generated_images` / `noise_and_labels` of samples showing a
# right eye and a left eye, respectively (presumably labelled by visual
# inspection of the generated batch).
right_eyes = [
    2, 7, 9, 10, 12, 15, 17, 20, 21, 23, 27, 30, 36, 38, 46, 47,
]
left_eyes = [
    0, 1, 3, 4, 8, 11, 13, 16, 18, 26, 29, 32, 33, 37, 40, 41, 42, 45, 48,
]

In [None]:
# Show only the samples listed as left eyes, to confirm the manual labels.
left_eye_images = generated_images[left_eyes]
_ = show_tensor_images(left_eye_images, normalize=False, show_image=True)

In [None]:
# Compare the noise part of the right-eye and left-eye latents component-wise
# with an independent two-sample t-test. The "laterality" direction keeps only
# the component(s) with the smallest p-value (no significance threshold is applied).
latents_right = noise_and_labels[right_eyes]
latents_left = noise_and_labels[left_eyes]
print(latents_right.shape, latents_left.shape)
_, pval = ttest_ind(latents_right[:, :z_dim], latents_left[:, :z_dim], axis=0)
most_sig = pval == np.min(pval)
# Mask: 1 for the selected noise component(s), 0 elsewhere, including every
# class (one-hot) column so the class conditioning is never modified.
n_label_cols = latents_right.shape[1] - z_dim
sig_dif_idxs = np.concatenate((most_sig.astype(int), np.zeros(n_label_cols)))[None, :]
laterality = (torch.mean(latents_left, dim=0)[None, :]
              - torch.mean(latents_right, dim=0)[None, :]) * sig_dif_idxs
laterality = laterality.type(torch.FloatTensor)
print("Components that affect laterality = {}".format(np.where(most_sig)[0]))

In [None]:
# Bar plots of the mean right-eye and left-eye latents and their difference.
# The component(s) selected by `sig_dif_idxs` are overlaid (red in the
# difference plot).
avg_right = torch.mean(latents_right, dim=0)
avg_left = torch.mean(latents_left, dim=0)
avg_diff = avg_left - avg_right
sig_mask = sig_dif_idxs.squeeze()
bar_x = np.arange(len(avg_right))
# x range covers every latent component (noise + class columns)
x_limits = (-10, len(avg_right))

plt.figure(figsize=(20, 6))
for panel, (avg, panel_title) in enumerate(
        ((avg_right, "Right Eye Average Latent"), (avg_left, "Left Eye Average Latent")),
        start=1):
    plt.subplot(1, 3, panel)
    plt.bar(x=bar_x, height=avg)
    plt.bar(x=bar_x, height=avg * sig_mask)
    plt.xlim(*x_limits)
    plt.title(panel_title)

plt.subplot(1, 3, 3)
plt.bar(x=bar_x, height=avg_diff, color="black")
plt.bar(x=bar_x, height=avg_diff * sig_mask, color="red")
plt.title("Difference in Right vs Left Latent")
plt.xlim(*x_limits)
plt.show()

In [None]:
# Shift a left-eye sample along the laterality direction and compare the images.
# laterality = (mean_left - mean_right) on the selected component(s), so a
# negative factor moves those components in the right-eye direction.
sample_left = 0
factor = -100
transformed_latent = noise_and_labels[sample_left][None, :] + factor * laterality
with torch.no_grad():
    tranformed_image = gan_model(transformed_latent)[-1]
# Apply the same dynamic-range adjustment used for `generated_images`, so both
# panels are on the same value scale.
tranformed_image = conditional_msggan.Generator.adjust_dynamic_range(tranformed_image)

# visualize difference
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(generated_images[sample_left].cpu().numpy().squeeze(), plt.cm.gray)
plt.title("Original Image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(tranformed_image.cpu().numpy().squeeze(), plt.cm.gray)
plt.title("Image after adding feature")
plt.axis("off")
plt.show()

## Interpolate between classes

In [None]:
# Interpolator over the class conditioning (slerp, 6 interpolation steps);
# presumably the noise is held fixed - see the call in the next cell.
interpolator = Interpolate(
    gan_model,
    interp="slerp",
    component="classes",
    n_classes=dataset.n_classes,
    n_interpolation=6,
    device=device,
)

In [None]:
# Interpolate between two classes for one random noise vector.
noise = get_noise(1, z_dim, device)
class1 = 'ABCA4'
class2 = 'USH2A'
class1_idx = torch.tensor([dataset.class2idx[class1]])
class2_idx = torch.tensor([dataset.class2idx[class2]])
interpolations = interpolator(noise, class1_idx, class2_idx)
image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
# only the last entry of `interpolations` is adjusted and displayed
interpolations[-1] = image_adjustor(interpolations[-1])

plt.figure(figsize=(20, 6))
_ = plt.axis('off')
_ = show_tensor_images(interpolations[-1], n_rows=9, normalize=False, show_image=True)

In [None]:
# save_folder = "results/gif_imgs_2/"
# os.makedirs(save_folder, exist_ok=True)
# for i in range(len(interpolations[-1])):
# #     plt.figure(figsize=(6, 6))
#     plt.axis('off')
#     _ = show_tensor_images(interpolations[-1][i], normalize=False, show_image=True, save_path=save_folder+"res_{}".format(i))

## Interpolate between noise vectors

In [None]:
# Interpolator over the noise (latent) component (slerp, 6 interpolation steps).
interp_kwargs = dict(
    interp="slerp",
    component="latents",
    n_classes=dataset.n_classes,
    n_interpolation=6,
    device=device,
)
interpolator = Interpolate(gan_model, **interp_kwargs)

In [None]:
# Interpolate between two random noise vectors for a single class.
label = 'ABCA4'
class_idx = torch.tensor([dataset.class2idx[label]])
latent1, latent2 = (get_noise(1, z_dim, device) for _ in range(2))
interpolations = interpolator(latent1, latent2, class_idx)

# only the last entry of `interpolations` is adjusted and displayed
image_adjustor = conditional_msggan.Generator.adjust_dynamic_range
final_images = image_adjustor(interpolations[-1])
interpolations[-1] = final_images

plt.figure(figsize=(20, 6))
_ = plt.axis('off')
_ = show_tensor_images(final_images, n_rows=9, normalize=False, show_image=True)