# Image Generation via Generative Adversarial Networks

## import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import random
import os
from torchvision.utils import make_grid

## load data

In [2]:
# Mount Google Drive inside the Colab runtime so the assignment data is reachable.
from google.colab import drive 
drive.mount('/content/drive/')

# NOTE(review): the drive is mounted at the absolute path '/content/drive/' but the
# data directory below is relative ('./drive/...'); this presumably relies on the
# working directory being '/content' -- confirm.
directory_data  = './drive/MyDrive/Machine_Learning/'
filename_data   = 'assignment_12_data.npz'
data            = np.load(os.path.join(directory_data, filename_data))

# Real training images: the 'real_images' array of the archive as a float32 tensor.
real            = torch.from_numpy(data['real_images']).float()

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## hyper-parameters

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of passes over the training data.
number_epoch    = 120
# Number of images per mini-batch handed to the DataLoader.
size_minibatch  = 50
# Size of the random latent vector sampled per generated image.
dim_latent      = 20
# Number of input channels expected by the discriminator.
dim_channel     = 1
# Adam step sizes for the two networks.
learning_rate_discriminator = 0.001
learning_rate_generator     = 0.001

In [4]:
import random
def affine(image, shear=0, scale=1, rate=(10, 10)):
    """Apply an affine transform (translation, rescale, shear) to every image.

    Each image is converted to a PIL image, transformed with
    ``torchvision.transforms.functional.affine`` (rotation angle fixed to 0),
    converted back with ``to_tensor`` and finally to a numpy array.

    Args:
        image: Iterable batch of images; every element must be accepted by
            ``transforms.functional.to_pil_image``.
        shear: Shear value forwarded unchanged to the affine transform.
        scale: Either a fixed scale factor, or a two-element list/tuple
            ``[low, high]`` of integers. In the latter case a scale is drawn
            independently per image as ``randint(low, high + 1) / 10``.
        rate: Two-element sequence controlling the translation. When
            ``rate[0]`` is non-zero, each image is shifted by random integers
            drawn from ``[-rate[0], rate[0]]`` and ``[-rate[1], rate[1]]``.
            When ``rate[0]`` is zero, ``rate`` itself is used as the (fixed)
            translation.

    Returns:
        numpy.ndarray: the per-image ``to_tensor`` results concatenated along
        axis 0.

    Raises:
        ValueError: if ``image`` is empty (there is nothing to concatenate).
    """
    to_pil = transforms.functional.to_pil_image
    apply_affine = transforms.functional.affine
    to_tensor = transforms.functional.to_tensor

    transformed = []

    for single_image in image:

        # Random translation; the draws happen in the same order as before
        # (x shift, y shift, then scale) so seeded runs stay reproducible.
        if rate[0] != 0:
            movement = [
                np.random.randint(-rate[0], rate[0] + 1),
                np.random.randint(-rate[1], rate[1] + 1),
            ]
        else:
            movement = list(rate)

        if isinstance(scale, (list, tuple)):
            rescale = np.random.randint(scale[0], scale[1] + 1) / 10
        else:
            rescale = scale

        pil_image = to_pil(single_image)
        pil_image = apply_affine(pil_image, angle=0, shear=shear, scale=rescale, translate=movement)
        transformed.append(to_tensor(pil_image).numpy())

    # Concatenate once at the end instead of re-copying the growing array on
    # every iteration.
    return np.concatenate(transformed, axis=0)

In [5]:
# Use every other real image (even indices) as the source for augmentation.
real_image = real[::2]
# Rescale only: per-image scale drawn from {0.5, 0.6, ..., 1.2}; rate=[0, 0] disables translation.
affine_12 = affine(real_image, scale=[5, 12], rate=[0, 0])
# Translate only: random integer shift in [-3, 3] per axis at the original scale.
affine_random = affine(real_image, scale=1, rate=[3, 3])

## custom data loader for the PyTorch framework

In [6]:
class dataset(Dataset):
    """Map-style dataset serving single images with an added leading axis."""

    def __init__(self, data):
        # Only a reference to the image stack is stored; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return sample ``index`` as a float tensor with a new axis at dim 0."""
        sample = torch.FloatTensor(self.data[index])
        return sample.unsqueeze(dim=0)

## construct the dataset and dataloader for training

In [7]:
# image_train = np.concatenate([real[1::2], affine_12, affine_random], axis=0)
# Wrap the real images so that each sample gets a leading channel axis.
dataset_real    = dataset(real)
# Shuffled mini-batches; drop_last=True discards a final incomplete batch so
# every batch has exactly size_minibatch samples.
dataloader_real = DataLoader(dataset_real, batch_size=size_minibatch, shuffle=True, drop_last=True)

In [8]:
# image_train.shape

## shape of the data when using the data loader

In [9]:
# Fetch one sample through the dataset wrapper to inspect the per-sample shape.
image_real = dataset_real[0]
print('*******************************************************************')
print('shape of the image in the training dataset:', image_real.shape)
print('*******************************************************************')

*******************************************************************
shape of the image in the training dataset: torch.Size([1, 32, 32])
*******************************************************************


## class for the neural network 

In [10]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs raw scores (no final activation).

    The feature extractor is five 3x3 stride-2 convolutions (each halving the
    spatial size) with channel widths ``dim_feature * (1, 2, 4, 8, 16)``, each
    followed by LeakyReLU(0.2). The classifier is an MLP narrowing from
    ``dim_feature * 16`` down to ``out_channel``.

    The first linear layer expects exactly ``dim_feature * 16`` features after
    flattening, i.e. a 1x1 feature map, which is what five halvings of a
    32x32 input produce.

    Args:
        in_channel: Number of channels of the input images.
        out_channel: Number of output scores per image.
        dim_feature: Base channel width.
    """

    def __init__(self, in_channel=1, out_channel=1, dim_feature=128):

        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.dim_feature = dim_feature
        threshold_ReLU = 0.2

        # Channel widths along the convolutional stack, input first.
        conv_widths = [in_channel] + [dim_feature * m for m in (1, 2, 4, 8, 16)]
        feature_layers = []
        for width_in, width_out in zip(conv_widths[:-1], conv_widths[1:]):
            feature_layers.append(
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=2, padding=1, bias=True)
            )
            feature_layers.append(nn.LeakyReLU(threshold_ReLU, inplace=True))
        self.feature = nn.Sequential(*feature_layers)

        # Hidden widths of the MLP head; the last layer has no activation.
        linear_widths = [dim_feature * m for m in (16, 8, 4, 2, 1)]
        classifier_layers = []
        for width_in, width_out in zip(linear_widths[:-1], linear_widths[1:]):
            classifier_layers.append(nn.Linear(width_in, width_out, bias=True))
            classifier_layers.append(nn.LeakyReLU(threshold_ReLU, inplace=True))
        classifier_layers.append(nn.Linear(dim_feature * 1, out_channel, bias=True))
        self.classifier = nn.Sequential(*classifier_layers)

        self.network = nn.Sequential(
            self.feature,
            nn.Flatten(),
            self.classifier,
        )

        self.initialize_weight()

    def forward(self, x):
        """Return the discriminator scores for the image batch ``x``."""
        # Call the module (not ``.forward``) so hooks registered on
        # ``self.network`` are honoured.
        return self.network(x)

    def initialize_weight(self):
        """Xavier-uniform weights for conv/linear layers; constant-1 biases.

        BatchNorm2d layers (none are built here, kept for parity with the
        generator) get weight 1 and bias 1.
        """
        # NOTE(review): biases are initialised to 1 rather than the more usual
        # 0 -- confirm this is intentional.
        print('initialize model parameters :', 'xavier_uniform')

        for m in self.network.modules():

            if isinstance(m, (nn.Conv2d, nn.Linear)):

                nn.init.xavier_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 1)

            elif isinstance(m, nn.BatchNorm2d):

                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 1)

In [11]:
class Generator(nn.Module):
    """Upsampling generator mapping a latent tensor to images in [0, 1].

    The network is five stages of ``Upsample(x2, bilinear) -> Conv2d(3x3,
    stride 1) -> BatchNorm2d``. The first four stages end with LeakyReLU(0.2)
    and use channel widths ``dim_feature * (8, 4, 2, 1)``; the last stage maps
    to ``out_channel`` channels and ends with a Sigmoid. Every stage doubles
    the spatial size, so a 1x1 latent input yields a 32x32 output.

    Args:
        in_channel: Number of channels of the latent input.
        out_channel: Number of channels of the generated images.
        dim_feature: Base channel width.
    """

    def __init__(self, in_channel=1, out_channel=1, dim_feature=8):

        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.dim_feature = dim_feature
        threshold_ReLU = 0.2

        # Channel widths along the stack: latent input, hidden stages, output.
        widths = [in_channel] + [dim_feature * m for m in (8, 4, 2, 1)] + [out_channel]
        index_last_stage = len(widths) - 2

        layers = []
        for stage, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
            layers.append(
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=1, padding=1, bias=True)
            )
            layers.append(nn.BatchNorm2d(width_out))
            if stage < index_last_stage:
                layers.append(nn.LeakyReLU(threshold_ReLU, inplace=True))
            else:
                # The output stage squashes pixel values into [0, 1].
                layers.append(nn.Sigmoid())

        self.network = nn.Sequential(*layers)

        self.initialize_weight()

    def forward(self, x):
        """Return the generated image batch for the latent batch ``x``."""
        # Call the module (not ``.forward``) so hooks registered on
        # ``self.network`` are honoured.
        return self.network(x)

    def initialize_weight(self):
        """Xavier-uniform weights for conv/linear layers; constant-1 biases.

        BatchNorm2d layers get weight 1 and bias 1.
        """
        # NOTE(review): biases are initialised to 1 rather than the more usual
        # 0 -- confirm this is intentional.
        print('initialize model parameters :', 'xavier_uniform')

        for m in self.network.modules():

            if isinstance(m, (nn.Conv2d, nn.Linear)):

                nn.init.xavier_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 1)

            elif isinstance(m, nn.BatchNorm2d):

                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 1)

## build network

In [12]:
# Generator: dim_latent input channels -> 1 output channel, base width 128.
generator       = Generator(dim_latent, 1, 128).to(device)
# Discriminator: dim_channel input channels -> 1 output score, base width 128.
discriminator   = Discriminator(dim_channel, 1, 128).to(device)

# One Adam optimizer per network, each with its own learning rate; both use
# betas=(0.5, 0.999) instead of the Adam default first-moment decay of 0.9.
optimizer_generator     = torch.optim.Adam(generator.parameters(), lr=learning_rate_generator, betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate_discriminator, betas=(0.5, 0.999))

initialize model parameters : xavier_uniform
initialize model parameters : xavier_uniform


## compute the prediction

In [13]:
def compute_prediction(model, input):
    """Call ``model`` on ``input`` and return its output unchanged."""
    return model(input)

## compute the loss

In [14]:
def compute_loss_discriminator(generator, discriminator, latent, data_real):
    """Return the discriminator loss for one batch.

    The loss is the average of two binary cross-entropy terms computed on the
    raw discriminator scores (logits): real images against the label 1 and
    generated images against the label 0.

    Args:
        generator: Callable mapping ``latent`` to a batch of fake images.
        discriminator: Callable mapping an image batch to raw scores.
        latent: Latent batch fed to the generator.
        data_real: Batch of real images.

    Returns:
        Scalar tensor ``(loss_real + loss_fake) / 2``.
    """
    # The generator is not optimised by this loss, so its forward pass is run
    # without autograd: backward() then stops at the fake images instead of
    # also back-propagating through (and writing gradients into) the generator.
    with torch.no_grad():
        data_fake = generator(latent)

    prediction_real = discriminator(data_real)
    prediction_fake = discriminator(data_fake)

    criterion = nn.BCEWithLogitsLoss()

    label_real = torch.ones_like(prediction_real)
    label_fake = torch.zeros_like(prediction_fake)

    loss_real = criterion(prediction_real, label_real)
    loss_fake = criterion(prediction_fake, label_fake)

    loss_discriminator = (loss_real + loss_fake) / 2.0

    return loss_discriminator

In [15]:
def compute_loss_generator(generator, discriminator, latent):
    """Return the generator loss for one batch.

    Images generated from ``latent`` are scored by the discriminator, and the
    raw scores are compared against the "real" label (all ones) with binary
    cross-entropy on logits, so the loss is small when the fakes are scored
    as real.
    """
    scores_fake = discriminator(generator(latent))
    target_real = torch.ones_like(scores_fake)

    bce_with_logits = nn.BCEWithLogitsLoss()

    return bce_with_logits(scores_fake, target_real)

## compute the accuracy

In [16]:
def get_center_index(binary_image):
    """Return the centroid ``(x_mean, y_mean)`` of a 2-D image.

    Every pixel contributes its column index (x) and row index (y) weighted
    by its value; the weighted sums are divided by the sum of all pixel
    values. An all-zero image therefore yields a 0/0 division.
    """
    n_rows = binary_image.shape[0]
    n_cols = binary_image.shape[1]

    total = np.sum(binary_image)

    # Broadcast a row of column indices / a column of row indices over the
    # image instead of materialising full coordinate grids.
    col_index = np.arange(n_cols, dtype=float)[np.newaxis, :]
    row_index = np.arange(n_rows, dtype=float)[:, np.newaxis]

    center_x = np.sum(binary_image * col_index) / total
    center_y = np.sum(binary_image * row_index) / total

    return (center_x, center_y)

In [17]:
# create ideal square image which has the same area to the input image
def create_label(binary_images):
    """Build, for every binary image, an ideal axis-aligned square label.

    For each image the square has a side of ``round(sqrt(area))`` pixels
    (``area`` being the sum of the image) and is centred on the image's
    centroid, clipped to the image bounds.

    Images for which no square can be built (zero side length, or a centroid
    outside the image) get an all-zero label. Previously such an image made
    the function return a single 2-D zero image for the whole batch, which
    discarded the labels of every other image in the batch.

    Args:
        binary_images: Array of shape ``(batch, height, width)`` with 0/1
            entries.

    Returns:
        Array with the same shape and dtype as ``binary_images``.
    """
    label = np.zeros_like(binary_images)

    for i, binary_image in enumerate(binary_images):

        image_height = binary_image.shape[0]
        image_width = binary_image.shape[1]

        square_length = np.round(np.sqrt(np.sum(binary_image)))

        if square_length == 0:
            # No square in this image: keep its label empty, go on with the rest.
            continue

        (square_center_x, square_center_y) = get_center_index(binary_image)

        if square_center_x < 0 or square_center_x > image_width - 1 or square_center_y < 0 or square_center_y > image_height - 1:
            continue

        half_length = square_length / 2

        # Inclusive pixel bounds of the square, clipped to the image.
        top = max(int(np.ceil(square_center_y - half_length)), 0)
        bottom = min(int(np.floor(square_center_y + half_length)), image_height - 1)
        left = max(int(np.ceil(square_center_x - half_length)), 0)
        right = min(int(np.floor(square_center_x + half_length)), image_width - 1)

        label[i, top : bottom + 1, left : right + 1] = 1

    return label

In [18]:
def compute_accuracy(prediction):
    """Return the mean intersection-over-union (in percent) against ideal squares.

    The prediction is thresholded at 0.5, an ideal square label is built for
    it with ``create_label``, and the per-image IoU between the two is
    averaged over the batch.

    Args:
        prediction: Tensor whose axis 1 is squeezed away before thresholding
            (presumably shaped (batch, 1, height, width) -- confirm with caller).
    """
    binary = (prediction.squeeze(axis=1) >= 0.5).cpu().numpy().astype(int)
    target = create_label(binary).astype(int)

    # Per-image pixel counts of the overlap and of the combined region.
    overlap = (binary & target).sum(axis=1).sum(axis=1).astype(float)
    combined = (binary | target).sum(axis=1).sum(axis=1).astype(float)

    # eps keeps the ratio finite when both prediction and label are empty.
    iou = overlap / (combined + np.finfo(float).eps)

    return iou.mean() * 100.0

## variables for the learning curve

In [19]:
# Per-epoch mean and standard deviation of the mini-batch losses (one slot per epoch).
loss_generator_mean     = np.zeros(number_epoch)
loss_generator_std      = np.zeros(number_epoch)
loss_discriminator_mean = np.zeros(number_epoch)
loss_discriminator_std  = np.zeros(number_epoch)

# Per-epoch mean and standard deviation of the mini-batch accuracies.
accuracy_mean   = np.zeros(number_epoch)
accuracy_std    = np.zeros(number_epoch)

## train

In [20]:
def train(generator, discriminator, dataloader):
    """Run one training epoch and return per-epoch statistics.

    For every mini-batch the generator is updated first, then the
    discriminator, both with the same latent batch. Afterwards the generator
    output for that latent batch is scored with ``compute_accuracy``.

    Relies on the module-level ``device``, ``dim_latent``,
    ``optimizer_generator`` and ``optimizer_discriminator``.

    Args:
        generator: Generator network (updated in place).
        discriminator: Discriminator network (updated in place).
        dataloader: Iterable yielding batches of real images.

    Returns:
        Tuple ``(loss_value_generator, loss_value_discriminator,
        accuracy_value)``; each is a dict with the keys ``'mean'`` and
        ``'std'`` computed over the mini-batches of the epoch.
    """
    loss_epoch_generator = []
    loss_epoch_discriminator = []
    accuracy_epoch = []

    for data_real in dataloader:

        size_batch = len(data_real)
        data_real = data_real.to(device)

        # One latent vector per real image, shaped (batch, dim_latent, 1, 1).
        latent = torch.randn(size_batch, dim_latent, device=device)
        latent = torch.reshape(latent, [size_batch, dim_latent, 1, 1])

        # ---------------------------------------------------------------------------
        #
        # update the generator
        #
        # ---------------------------------------------------------------------------
        generator.train()
        discriminator.eval()

        optimizer_generator.zero_grad()
        loss_generator = compute_loss_generator(generator, discriminator, latent)
        loss_generator.backward()
        optimizer_generator.step()

        # ---------------------------------------------------------------------------
        #
        # update the discriminator
        #
        # ---------------------------------------------------------------------------
        generator.eval()
        discriminator.train()

        optimizer_discriminator.zero_grad()
        loss_discriminator = compute_loss_discriminator(generator, discriminator, latent, data_real)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Pure evaluation: no autograd graph is needed for the accuracy.
        with torch.no_grad():
            data_fake = compute_prediction(generator, latent)
            accuracy = compute_accuracy(data_fake)

        loss_epoch_generator.append(loss_generator.item())
        loss_epoch_discriminator.append(loss_discriminator.item())
        accuracy_epoch.append(accuracy)

    loss_value_generator = {
        'mean': np.mean(loss_epoch_generator),
        'std': np.std(loss_epoch_generator),
    }
    loss_value_discriminator = {
        'mean': np.mean(loss_epoch_discriminator),
        'std': np.std(loss_epoch_discriminator),
    }
    accuracy_value = {
        'mean': np.mean(accuracy_epoch),
        'std': np.std(accuracy_epoch),
    }

    return loss_value_generator, loss_value_discriminator, accuracy_value


## training epoch

In [None]:
# ================================================================================
# 
# iterations for epochs
#
# ================================================================================
for i in tqdm(range(number_epoch)):
    
    # ================================================================================
    # 
    # training
    #
    # ================================================================================
    # One full pass over dataloader_real; returns {'mean', 'std'} dicts.
    (loss_value_generator, loss_value_discriminator, accuracy_value) = train(generator, discriminator, dataloader_real)

    # Record the statistics of epoch i for the learning curves.
    loss_generator_mean[i]      = loss_value_generator['mean']
    loss_generator_std[i]       = loss_value_generator['std']

    loss_discriminator_mean[i]  = loss_value_discriminator['mean']
    loss_discriminator_std[i]   = loss_value_discriminator['std']

    accuracy_mean[i]            = accuracy_value['mean']
    accuracy_std[i]             = accuracy_value['std']

    # Progress report: only the discriminator loss and the accuracy are printed.
    print(f"epoch : {i}")
    print(f"\tloss_value_discriminator : {loss_value_discriminator['mean']}, acc mean : {accuracy_value['mean']}")

  1%|          | 1/120 [00:17<34:46, 17.54s/it]

epoch : 0
	loss_value_discriminator : 3886.9562423672787, acc mean : 66.13629949757905


  2%|▏         | 2/120 [00:34<34:14, 17.41s/it]

epoch : 1
	loss_value_discriminator : 1.626686996325504, acc mean : 80.3515734325044


  2%|▎         | 3/120 [00:52<33:55, 17.40s/it]

epoch : 2
	loss_value_discriminator : 0.2327246782478205, acc mean : 76.94044807551619


  3%|▎         | 4/120 [01:09<33:36, 17.38s/it]

epoch : 3
	loss_value_discriminator : 1.6890033477007649, acc mean : 74.69580036417833


  4%|▍         | 5/120 [01:27<33:20, 17.39s/it]

epoch : 4
	loss_value_discriminator : 0.09095498574031301, acc mean : 71.88487143012553


  5%|▌         | 6/120 [01:44<33:03, 17.40s/it]

epoch : 5
	loss_value_discriminator : 0.14938022363064593, acc mean : 71.24624894949122


  6%|▌         | 7/120 [02:01<32:44, 17.39s/it]

epoch : 6
	loss_value_discriminator : 0.11451147992683705, acc mean : 72.13486538990009


  7%|▋         | 8/120 [02:19<32:28, 17.40s/it]

epoch : 7
	loss_value_discriminator : 0.19899309153646924, acc mean : 73.66086525806232


  8%|▊         | 9/120 [02:36<32:10, 17.39s/it]

epoch : 8
	loss_value_discriminator : 0.250618357817794, acc mean : 77.51929499146658


  8%|▊         | 10/120 [02:53<31:53, 17.39s/it]

epoch : 9
	loss_value_discriminator : 0.11901200081893178, acc mean : 81.04918254593669


  9%|▉         | 11/120 [03:11<31:36, 17.40s/it]

epoch : 10
	loss_value_discriminator : 0.14225253733524748, acc mean : 83.70778266132025


 10%|█         | 12/120 [03:28<31:17, 17.39s/it]

epoch : 11
	loss_value_discriminator : 0.24424392604377382, acc mean : 79.86062506585556


 11%|█         | 13/120 [03:46<30:59, 17.38s/it]

epoch : 12
	loss_value_discriminator : 0.9606133965385515, acc mean : 76.81299949863829


 12%|█▏        | 14/120 [04:03<30:41, 17.37s/it]

epoch : 13
	loss_value_discriminator : 0.076957902502875, acc mean : 79.22468813962303


 12%|█▎        | 15/120 [04:20<30:26, 17.39s/it]

epoch : 14
	loss_value_discriminator : 0.12402580669799516, acc mean : 77.60569153923375


 13%|█▎        | 16/120 [04:38<30:11, 17.42s/it]

epoch : 15
	loss_value_discriminator : 0.2138391560766586, acc mean : 79.32940329241309


 14%|█▍        | 17/120 [04:55<29:54, 17.43s/it]

epoch : 16
	loss_value_discriminator : 0.15970399245879677, acc mean : 74.99522526554406


 15%|█▌        | 18/120 [05:13<29:35, 17.41s/it]

epoch : 17
	loss_value_discriminator : 0.06297099971494009, acc mean : 73.77085603203646


 16%|█▌        | 19/120 [05:30<29:17, 17.41s/it]

epoch : 18
	loss_value_discriminator : 0.16279557347297668, acc mean : 71.72771795707244


 17%|█▋        | 20/120 [05:47<28:59, 17.40s/it]

epoch : 19
	loss_value_discriminator : 0.1180404844412277, acc mean : 74.41355918351333


 18%|█▊        | 21/120 [06:05<28:41, 17.39s/it]

epoch : 20
	loss_value_discriminator : 0.07436837706454964, acc mean : 70.08774125151739


 18%|█▊        | 22/120 [06:22<28:24, 17.39s/it]

epoch : 21
	loss_value_discriminator : 0.07521872118461964, acc mean : 72.97589524536558


 19%|█▉        | 23/120 [06:40<28:05, 17.38s/it]

epoch : 22
	loss_value_discriminator : 20533021.67226394, acc mean : 82.81287669134291


 20%|██        | 24/120 [06:57<27:49, 17.39s/it]

epoch : 23
	loss_value_discriminator : 1260.1574138375215, acc mean : 69.50452137795112


 21%|██        | 25/120 [07:14<27:31, 17.38s/it]

epoch : 24
	loss_value_discriminator : 1042.8662750443748, acc mean : 65.85203912207864


 22%|██▏       | 26/120 [07:32<27:14, 17.39s/it]

epoch : 25
	loss_value_discriminator : 962.5017516779345, acc mean : 60.64236694830584


 22%|██▎       | 27/120 [07:49<26:56, 17.39s/it]

epoch : 26
	loss_value_discriminator : 1181.0118583413057, acc mean : 56.908200158141746


 23%|██▎       | 28/120 [08:06<26:38, 17.37s/it]

epoch : 27
	loss_value_discriminator : 907.2323825747468, acc mean : 57.6614279085163


 24%|██▍       | 29/120 [08:24<26:20, 17.37s/it]

epoch : 28
	loss_value_discriminator : 572.0425834877547, acc mean : 57.81576626898641


 25%|██▌       | 30/120 [08:41<26:02, 17.36s/it]

epoch : 29
	loss_value_discriminator : 317.94625707282574, acc mean : 58.399508078350685


 26%|██▌       | 31/120 [08:59<25:45, 17.37s/it]

epoch : 30
	loss_value_discriminator : 201.87828682744225, acc mean : 59.714406212846114


 27%|██▋       | 32/120 [09:16<25:28, 17.37s/it]

epoch : 31
	loss_value_discriminator : 127.29202716849571, acc mean : 61.42712101224879


 28%|██▊       | 33/120 [09:33<25:12, 17.39s/it]

epoch : 32
	loss_value_discriminator : 219.4166196113409, acc mean : 58.48478923455934


 28%|██▊       | 34/120 [09:51<24:55, 17.39s/it]

epoch : 33
	loss_value_discriminator : 236.34960106638974, acc mean : 56.594786420542206


 29%|██▉       | 35/120 [10:08<24:37, 17.38s/it]

epoch : 34
	loss_value_discriminator : 181.6439048634019, acc mean : 55.84064904567939


 30%|███       | 36/120 [10:26<24:21, 17.40s/it]

epoch : 35
	loss_value_discriminator : 154.8509325315786, acc mean : 55.87137355808781


 31%|███       | 37/120 [10:43<24:06, 17.43s/it]

epoch : 36
	loss_value_discriminator : 386.9482779308807, acc mean : 55.3044494267388


 32%|███▏      | 38/120 [11:00<23:48, 17.41s/it]

epoch : 37
	loss_value_discriminator : 69.47824632289797, acc mean : 57.92934786432869


 32%|███▎      | 39/120 [11:18<23:29, 17.40s/it]

epoch : 38
	loss_value_discriminator : 9000.797371198965, acc mean : 62.395547479113354


 33%|███▎      | 40/120 [11:35<23:13, 17.42s/it]

epoch : 39
	loss_value_discriminator : 2384.7579463581706, acc mean : 67.82935947438447


 34%|███▍      | 41/120 [11:53<22:56, 17.42s/it]

epoch : 40
	loss_value_discriminator : 3716.1726974803346, acc mean : 66.7460740532268


 35%|███▌      | 42/120 [12:10<22:37, 17.40s/it]

epoch : 41
	loss_value_discriminator : 312.5090971547504, acc mean : 66.29901683446857


 36%|███▌      | 43/120 [12:27<22:19, 17.39s/it]

epoch : 42
	loss_value_discriminator : 331.3027938576632, acc mean : 65.98152529373066


 37%|███▋      | 44/120 [12:45<22:00, 17.38s/it]

epoch : 43
	loss_value_discriminator : 705.71518379034, acc mean : 66.00882430184797


 38%|███▊      | 45/120 [13:02<21:44, 17.39s/it]

epoch : 44
	loss_value_discriminator : 680.6507372301678, acc mean : 66.62290220452178


 38%|███▊      | 46/120 [13:20<21:25, 17.37s/it]

epoch : 45
	loss_value_discriminator : 411.0822413466698, acc mean : 65.36441874630096


 39%|███▉      | 47/120 [13:37<21:09, 17.39s/it]

epoch : 46
	loss_value_discriminator : 320.33658373633097, acc mean : 59.42999139608662


 40%|████      | 48/120 [13:54<20:51, 17.38s/it]

epoch : 47
	loss_value_discriminator : 3168.7330289718716, acc mean : 57.284696071671995


 41%|████      | 49/120 [14:12<20:33, 17.38s/it]

epoch : 48
	loss_value_discriminator : 1329.241632239763, acc mean : 56.30434326279051


 42%|████▏     | 50/120 [14:29<20:17, 17.39s/it]

epoch : 49
	loss_value_discriminator : 657.2752918722325, acc mean : 54.458515590061296


 42%|████▎     | 51/120 [14:46<19:59, 17.38s/it]

epoch : 50
	loss_value_discriminator : 283.3150239101676, acc mean : 53.57793301318051


 43%|████▎     | 52/120 [15:04<19:43, 17.41s/it]

epoch : 51
	loss_value_discriminator : 78.8895913057549, acc mean : 52.933863102608


 44%|████▍     | 53/120 [15:21<19:26, 17.41s/it]

epoch : 52
	loss_value_discriminator : 125.0821383276651, acc mean : 60.44021776800636


 45%|████▌     | 54/120 [15:39<19:08, 17.40s/it]

epoch : 53
	loss_value_discriminator : 500.81420831902085, acc mean : 60.27600046012376


 46%|████▌     | 55/120 [15:56<18:51, 17.40s/it]

epoch : 54
	loss_value_discriminator : 370.4803660858509, acc mean : 57.46964802593039


 47%|████▋     | 56/120 [16:14<18:32, 17.38s/it]

epoch : 55
	loss_value_discriminator : 205.72836462287015, acc mean : 57.1387964159601


 48%|████▊     | 57/120 [16:31<18:15, 17.39s/it]

epoch : 56
	loss_value_discriminator : 130.94772268450538, acc mean : 57.17405196321582


 48%|████▊     | 58/120 [16:48<17:59, 17.41s/it]

epoch : 57
	loss_value_discriminator : 487.57063616153806, acc mean : 57.58185810698476


 49%|████▉     | 59/120 [17:06<17:41, 17.40s/it]

epoch : 58
	loss_value_discriminator : 1355.8804021436115, acc mean : 58.112893959783015


 50%|█████     | 60/120 [17:23<17:23, 17.40s/it]

epoch : 59
	loss_value_discriminator : 765.5404089994208, acc mean : 49.1742183668564


 51%|█████     | 61/120 [17:41<17:06, 17.41s/it]

epoch : 60
	loss_value_discriminator : 118018.18057541514, acc mean : 41.46549412822758


 52%|█████▏    | 62/120 [17:58<16:48, 17.39s/it]

epoch : 61
	loss_value_discriminator : 805.2762497261513, acc mean : 52.02254788779534


 52%|█████▎    | 63/120 [18:15<16:31, 17.40s/it]

epoch : 62
	loss_value_discriminator : 2303.470531829568, acc mean : 56.627760597745564


 53%|█████▎    | 64/120 [18:33<16:13, 17.39s/it]

epoch : 63
	loss_value_discriminator : 1791.592176348664, acc mean : 57.15590323892439


 54%|█████▍    | 65/120 [18:50<15:56, 17.39s/it]

epoch : 64
	loss_value_discriminator : 656.4566430768301, acc mean : 57.17399305170972


 55%|█████▌    | 66/120 [19:07<15:38, 17.38s/it]

epoch : 65
	loss_value_discriminator : 804.0072633554769, acc mean : 57.20804135201774


 56%|█████▌    | 67/120 [19:25<15:21, 17.38s/it]

epoch : 66
	loss_value_discriminator : 383.2977557692844, acc mean : 57.15783893064913


 57%|█████▋    | 68/120 [19:42<15:03, 17.38s/it]

epoch : 67
	loss_value_discriminator : 340.0190785749014, acc mean : 58.91712425309841


 57%|█████▊    | 69/120 [20:00<14:45, 17.37s/it]

epoch : 68
	loss_value_discriminator : 301.154246543729, acc mean : 64.33012341838023


 58%|█████▊    | 70/120 [20:17<14:28, 17.37s/it]

epoch : 69
	loss_value_discriminator : 201.82171744801278, acc mean : 65.45120910078619


 59%|█████▉    | 71/120 [20:34<14:11, 17.37s/it]

epoch : 70
	loss_value_discriminator : 121.75305361803188, acc mean : 70.03159561270117


 60%|██████    | 72/120 [20:52<13:53, 17.36s/it]

epoch : 71
	loss_value_discriminator : 701.5015442426815, acc mean : 69.05846984089516


 61%|██████    | 73/120 [21:09<13:36, 17.38s/it]

epoch : 72
	loss_value_discriminator : 278.4674811709759, acc mean : 68.12877028422797


 62%|██████▏   | 74/120 [21:26<13:19, 17.38s/it]

epoch : 73
	loss_value_discriminator : 93.79197249578876, acc mean : 67.8798236520004


 62%|██████▎   | 75/120 [21:44<13:01, 17.37s/it]

epoch : 74
	loss_value_discriminator : 1024.5363771083744, acc mean : 70.75741975167938


 63%|██████▎   | 76/120 [22:01<12:45, 17.40s/it]

epoch : 75
	loss_value_discriminator : 88.1182742895082, acc mean : 69.51734853997728


 64%|██████▍   | 77/120 [22:19<12:27, 17.39s/it]

epoch : 76
	loss_value_discriminator : 102.45260678335677, acc mean : 67.83626941202618


 65%|██████▌   | 78/120 [22:36<12:10, 17.38s/it]

epoch : 77
	loss_value_discriminator : 263.06307403154153, acc mean : 66.13618252903387


 66%|██████▌   | 79/120 [22:53<11:52, 17.37s/it]

epoch : 78
	loss_value_discriminator : 113.17486210479292, acc mean : 64.57051246289564


 67%|██████▋   | 80/120 [23:11<11:34, 17.37s/it]

epoch : 79
	loss_value_discriminator : 168.99068004031514, acc mean : 66.7420356992087


 68%|██████▊   | 81/120 [23:28<11:17, 17.37s/it]

epoch : 80
	loss_value_discriminator : 46.9141335525485, acc mean : 71.48747409415007


 68%|██████▊   | 82/120 [23:45<11:00, 17.38s/it]

epoch : 81
	loss_value_discriminator : 295.40107798021893, acc mean : 72.78655917757114


 69%|██████▉   | 83/120 [24:03<10:42, 17.37s/it]

epoch : 82
	loss_value_discriminator : 10974.487085112305, acc mean : 66.74198982155562


 70%|███████   | 84/120 [24:20<10:25, 17.39s/it]

epoch : 83
	loss_value_discriminator : 28.61305664583694, acc mean : 48.632257923818585


 71%|███████   | 85/120 [24:38<10:09, 17.40s/it]

epoch : 84
	loss_value_discriminator : 15.975981317287268, acc mean : 47.206990934870426


 72%|███████▏  | 86/120 [24:55<09:51, 17.39s/it]

epoch : 85
	loss_value_discriminator : 16.70514684846235, acc mean : 47.830554562710475


 72%|███████▎  | 87/120 [25:12<09:33, 17.39s/it]

epoch : 86
	loss_value_discriminator : 10.077673008101016, acc mean : 50.804059833335664


 73%|███████▎  | 88/120 [25:30<09:16, 17.39s/it]

epoch : 87
	loss_value_discriminator : 9.3249207693477, acc mean : 58.656415280634455


 74%|███████▍  | 89/120 [25:47<08:58, 17.37s/it]

epoch : 88
	loss_value_discriminator : 284.5543991086658, acc mean : 69.2344050536732


 75%|███████▌  | 90/120 [26:05<08:41, 17.37s/it]

epoch : 89
	loss_value_discriminator : 92.33795133876333, acc mean : 72.16176218745254


 76%|███████▌  | 91/120 [26:22<08:23, 17.38s/it]

epoch : 90
	loss_value_discriminator : 61.04395771026611, acc mean : 73.75099278600989


 77%|███████▋  | 92/120 [26:39<08:06, 17.39s/it]

epoch : 91
	loss_value_discriminator : 72.73896316525548, acc mean : 74.59343432412601


 78%|███████▊  | 93/120 [26:57<07:49, 17.37s/it]

epoch : 92
	loss_value_discriminator : 1706.7728833861129, acc mean : 74.71022977460183


 78%|███████▊  | 94/120 [27:14<07:31, 17.38s/it]

epoch : 93
	loss_value_discriminator : 486.9416712827461, acc mean : 54.5598540631018


 79%|███████▉  | 95/120 [27:31<07:14, 17.36s/it]

epoch : 94
	loss_value_discriminator : 2886.800564144933, acc mean : 45.81022782600939


 80%|████████  | 96/120 [27:49<06:56, 17.37s/it]

epoch : 95
	loss_value_discriminator : 561.9468511647956, acc mean : 28.959072297707984


 81%|████████  | 97/120 [28:06<06:39, 17.37s/it]

epoch : 96
	loss_value_discriminator : 328.2048782803291, acc mean : 36.8904742077564


 82%|████████▏ | 98/120 [28:24<06:22, 17.38s/it]

epoch : 97
	loss_value_discriminator : 191.79112675607334, acc mean : 52.32297949982677


 82%|████████▎ | 99/120 [28:41<06:04, 17.37s/it]

epoch : 98
	loss_value_discriminator : 116.7495331708775, acc mean : 69.45827523515814


 83%|████████▎ | 100/120 [28:58<05:47, 17.38s/it]

epoch : 99
	loss_value_discriminator : 59.93946657722253, acc mean : 71.28416935870874


 84%|████████▍ | 101/120 [29:16<05:30, 17.37s/it]

epoch : 100
	loss_value_discriminator : 54.15423074364662, acc mean : 73.5705200094567


 85%|████████▌ | 102/120 [29:33<05:12, 17.37s/it]

epoch : 101
	loss_value_discriminator : 136.42490994791652, acc mean : 76.16276739817049


 86%|████████▌ | 103/120 [29:50<04:55, 17.36s/it]

epoch : 102
	loss_value_discriminator : 263.8513992442641, acc mean : 75.39412621561874


 87%|████████▋ | 104/120 [30:08<04:38, 17.38s/it]

epoch : 103
	loss_value_discriminator : 54.48567278142517, acc mean : 76.4648632363888


 88%|████████▊ | 105/120 [30:25<04:20, 17.39s/it]

epoch : 104
	loss_value_discriminator : 367.6731976759295, acc mean : 76.76591206572267


 88%|████████▊ | 106/120 [30:43<04:03, 17.37s/it]

epoch : 105
	loss_value_discriminator : 96.2542292417222, acc mean : 76.33668693501943


 89%|████████▉ | 107/120 [31:00<03:45, 17.38s/it]

epoch : 106
	loss_value_discriminator : 99.54994020905606, acc mean : 74.9792247208708


 90%|█████████ | 108/120 [31:17<03:28, 17.37s/it]

epoch : 107
	loss_value_discriminator : 87.8971877264422, acc mean : 75.57123608904524


 91%|█████████ | 109/120 [31:35<03:11, 17.38s/it]

epoch : 108
	loss_value_discriminator : 10.113550892662863, acc mean : 75.8006656957219


 92%|█████████▏| 110/120 [31:52<02:53, 17.36s/it]

epoch : 109
	loss_value_discriminator : 14.282144139214678, acc mean : 75.98963453060567


 92%|█████████▎| 111/120 [32:09<02:36, 17.37s/it]

epoch : 110
	loss_value_discriminator : 9.795620093165441, acc mean : 86.77139955433171


 93%|█████████▎| 112/120 [32:27<02:18, 17.35s/it]

epoch : 111
	loss_value_discriminator : 2.8200442541477293, acc mean : 90.87010452047055


 94%|█████████▍| 113/120 [32:44<02:01, 17.38s/it]

epoch : 112
	loss_value_discriminator : 52.37600050088278, acc mean : 84.572840291541


 95%|█████████▌| 114/120 [33:02<01:44, 17.38s/it]

epoch : 113
	loss_value_discriminator : 156.24064328937337, acc mean : 86.24797840586074


---

# functions for visualizing the results 

---

## plot curve

In [None]:
def plot_image_grid(data, nRow, nCol, filename=None):
    """Display the first ``nRow * nCol`` images of a batch on a grid.

    Args:
        data: tensor indexed along its first axis. It is detached and moved
            to the CPU, and the leading axis of every ``data[k]`` is squeezed
            before display (assumes a ``(N, 1, H, W)`` batch -- TODO confirm
            against callers).
        nRow: number of grid rows.
        nCol: number of grid columns.
        filename: optional output path; when given, the figure is also
            written to this file.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column, so the axes[i, j] indexing below is always valid.
    fig, axes = plt.subplots(
        nRow,
        nCol,
        squeeze=False,
        constrained_layout=True,
        figsize=(nCol * size_col, nRow * size_row),
    )

    data = data.detach().cpu()

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            image   = data[k].squeeze(0)

            # Gray-scale display with a fixed [0, 1] intensity range.
            axes[i, j].imshow(image, cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    # Write the file before plt.show(): the figure may no longer be available
    # for saving once it has been shown.
    if filename is not None:
        fig.savefig(filename)

    plt.show()

In [None]:
def plot_data_grid(data, index_data, nRow, nCol):
    """Display selected entries of ``data`` on an ``nRow`` x ``nCol`` grid.

    Args:
        data: indexable collection of images; ``data[index]`` is passed
            straight to ``imshow`` (presumably 2-D gray-scale arrays --
            verify against caller).
        index_data: sequence holding at least ``nRow * nCol`` indices into
            ``data``; entry ``i * nCol + j`` is drawn at grid cell (i, j).
        nRow: number of grid rows.
        nCol: number of grid columns.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column, so the axes[i, j] indexing below is always valid.
    fig, axes = plt.subplots(
        nRow,
        nCol,
        squeeze=False,
        constrained_layout=True,
        figsize=(nCol * size_col, nRow * size_row),
    )

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            # Gray-scale display with a fixed [0, 1] intensity range.
            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_data_tensor_grid(data, index_data, nRow, nCol):
    """Display selected images of a tensor batch on an ``nRow`` x ``nCol`` grid.

    Args:
        data: tensor that is detached, moved to the CPU and has axis 1
            squeezed before display (assumes a ``(N, 1, H, W)`` batch --
            TODO confirm against callers).
        index_data: sequence holding at least ``nRow * nCol`` indices into
            the batch; entry ``i * nCol + j`` is drawn at grid cell (i, j).
        nRow: number of grid rows.
        nCol: number of grid columns.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column, so the axes[i, j] indexing below is always valid.
    fig, axes = plt.subplots(
        nRow,
        nCol,
        squeeze=False,
        constrained_layout=True,
        figsize=(nCol * size_col, nRow * size_row),
    )

    data = data.detach().cpu().squeeze(axis=1)

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            # Gray-scale display with a fixed [0, 1] intensity range.
            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_curve_error(data_mean, data_std, x_label, y_label, title, filename=None):
    """Plot a curve with a shaded band of one ``data_std`` around it.

    Args:
        data_mean: sequence of values drawn as a red line against their
            position index.
        data_std: sequence of the same length; the band spans
            ``data_mean - data_std`` to ``data_mean + data_std``.
        x_label: label of the x axis.
        y_label: label of the y axis.
        title: figure title.
        filename: optional output path; when given, the figure is also
            written to this file.
    """
    # Coerce to arrays so that plain Python lists work with the element-wise
    # arithmetic used for the band below.
    data_mean = np.asarray(data_mean)
    data_std = np.asarray(data_std)

    fig = plt.figure(figsize=(8, 6))
    plt.title(title)

    alpha = 0.3
    x = range(len(data_mean))

    plt.plot(x, data_mean, '-', color = 'red')
    plt.fill_between(x, data_mean - data_std, data_mean + data_std, facecolor = 'blue', alpha = alpha)

    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.tight_layout()

    # Write the file before plt.show(): the figure may no longer be available
    # for saving once it has been shown.
    if filename is not None:
        fig.savefig(filename)

    plt.show()

In [None]:
def print_curve(data, index):
    """Print ``data[idx]`` for every ``idx`` in ``index``, one line each.

    Args:
        data: indexable collection of numeric values.
        index: iterable of integer positions into ``data``.
    """
    for position in index:
        print('index = %2d, value = %12.10f' % (position, data[position]))

In [None]:
def get_data_last(data, index_start):
    """Return the slice of ``data`` from ``index_start`` to the end.

    A negative ``index_start`` counts from the end, as with any slice.
    """
    return data[index_start:]

In [None]:
def get_max_last_range(data, index_start):
    """Return the maximum of ``data`` from ``index_start`` to the end.

    The tail is obtained with ``get_data_last`` and must provide a
    ``.max()`` method.
    """
    tail = get_data_last(data, index_start)
    return tail.max()

In [None]:
def get_min_last_range(data, index_start):
    """Return the minimum of ``data`` from ``index_start`` to the end.

    The tail is obtained with ``get_data_last`` and must provide a
    ``.min()`` method.
    """
    tail = get_data_last(data, index_start)
    return tail.min()

---

# functions for presenting the results

---

In [None]:
def function_result_01():
    """Plot an 8 x 6 grid of evenly spaced samples from ``dataset_real``.

    Reads the module-level ``dataset_real``. It is indexed with an array of
    indices and element 0 of the result is plotted (presumably the image
    batch -- verify against the dataset's ``__getitem__``).
    """
    print('[plot examples of the real images]', '', sep='\n')

    nRow = 8
    nCol = 6
    cell_count = nRow * nCol

    total = len(dataset_real)
    # Spacing between picked samples so the picks spread over the dataset.
    stride = int(np.floor(total / cell_count))
    picked = np.arange(0, total, stride)

    samples = dataset_real[picked][0]
    grid_index = np.arange(0, cell_count)

    plot_data_grid(samples, grid_index, nRow, nCol)

In [None]:
def function_result_02():
    """Generate an 8 x 6 grid of fake images and save it as 'fake_image.png'.

    Reads the module-level ``generator``, ``dim_latent`` and ``device``.
    Latent vectors are drawn from a standard normal distribution and shaped
    ``(N, dim_latent, 1, 1)`` before being fed to the generator, which is
    switched to evaluation mode first.
    """
    print('[plot examples of the fake images]')
    print('')

    nRow = 8
    nCol = 6
    number_latent = nRow * nCol

    latent  = torch.randn(number_latent, dim_latent, device=device)
    latent  = torch.reshape(latent, [number_latent, dim_latent, 1, 1])

    generator.eval()

    # Inference only: no gradients are needed, so skip building the autograd
    # graph for the forward pass.
    with torch.no_grad():
        data_fake = generator(latent)

    filename    = 'fake_image.png'

    plot_image_grid(data_fake, nRow, nCol, filename)

In [None]:
def function_result_03():
    """Plot the generator loss curve and save it as 'loss_generator.png'.

    Reads the module-level ``loss_generator_mean`` and ``loss_generator_std``.
    """
    print('[plot the generator loss]', '', sep='\n')

    plot_curve_error(
        data_mean=loss_generator_mean,
        data_std=loss_generator_std,
        x_label='epoch',
        y_label='loss',
        title='generator loss',
        filename='loss_generator.png',
    )

In [None]:
def function_result_04():
    """Plot the discriminator loss curve and save it as 'loss_discriminator.png'.

    Reads the module-level ``loss_discriminator_mean`` and
    ``loss_discriminator_std``.
    """
    print('[plot the discriminator loss]', '', sep='\n')

    plot_curve_error(
        data_mean=loss_discriminator_mean,
        data_std=loss_discriminator_std,
        x_label='epoch',
        y_label='loss',
        title='discriminator loss',
        filename='loss_discriminator.png',
    )

In [None]:
def function_result_05():
    """Plot the training accuracy curve and save it as 'training_accuracy.png'.

    Reads the module-level ``accuracy_mean`` and ``accuracy_std``.
    """
    print('[plot the accuracy]', '', sep='\n')

    plot_curve_error(
        data_mean=accuracy_mean,
        data_std=accuracy_std,
        x_label='epoch',
        y_label='accuracy',
        title='training accuracy',
        filename='training_accuracy.png',
    )

In [None]:
def function_result_06():
    """Print the last (up to) 10 entries of ``loss_generator_mean``.

    Reads the module-level ``loss_generator_mean`` (presumably one value per
    epoch -- verify against the training loop).
    """
    print('[print the generator loss at the last 10 epochs]')
    print('')

    data_last = get_data_last(loss_generator_mean, -10)
    # Index only what the slice actually holds, so a run with fewer than
    # 10 recorded epochs does not raise an IndexError.
    index = np.arange(len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_07():
    """Print the last (up to) 10 entries of ``loss_discriminator_mean``.

    Reads the module-level ``loss_discriminator_mean`` (presumably one value
    per epoch -- verify against the training loop).
    """
    print('[print the discriminator loss at the last 10 epochs]')
    print('')

    data_last = get_data_last(loss_discriminator_mean, -10)
    # Index only what the slice actually holds, so a run with fewer than
    # 10 recorded epochs does not raise an IndexError.
    index = np.arange(len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_08():
    """Print the last (up to) 10 entries of ``accuracy_mean``.

    Reads the module-level ``accuracy_mean`` (presumably one value per
    epoch -- verify against the training loop).
    """
    print('[print the accuracy at the last 10 epochs]')
    print('')

    data_last = get_data_last(accuracy_mean, -10)
    # Index only what the slice actually holds, so a run with fewer than
    # 10 recorded epochs does not raise an IndexError.
    index = np.arange(len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_09():
    """Print the largest value among the last 10 entries of ``accuracy_mean``.

    Reads the module-level ``accuracy_mean``.
    """
    print('[print the best accuracy within the last 10 epochs]', '', sep='\n')

    best = get_max_last_range(accuracy_mean, -10)
    print('best accuracy = %12.10f' % (best))

---

# RESULTS

---

In [None]:
number_result = 9

# Run function_result_01 .. function_result_09 in order, each under a banner.
for i in range(number_result):

    title           = '# RESULT # {:02d}'.format(i+1)
    name_function   = 'function_result_{:02d}'.format(i+1)

    print('')
    print('################################################################################')
    print('#')
    print(title)
    print('#')
    print('################################################################################')
    print('')

    # Look the function up by name and call it directly instead of
    # eval()-ing a string of source code.
    globals()[name_function]()