# Image Generation via Generative Adversarial Networks

## import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import random
import os
from torchvision.utils import make_grid

## load data

In [2]:
# Mount Google Drive (Colab only) so the dataset archive is reachable.
from google.colab import drive 
drive.mount('/content/drive/')

# NOTE(review): relative path -- assumes the current working directory is the
# folder containing the 'drive' mount point (Colab's default /content); verify.
directory_data  = './drive/MyDrive/Machine_Learning/'
filename_data   = 'assignment_12_data.npz'
# NOTE(review): np.load on an .npz returns an NpzFile archive that is never
# closed here; consider `with np.load(...) as data:` if `data` is not reused.
data            = np.load(os.path.join(directory_data, filename_data))

# Real training images as a float32 tensor (indexed along the first axis by the dataset below).
real            = torch.from_numpy(data['real_images']).float()

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## hyper-parameters

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# training schedule
number_epoch = 150
size_minibatch = 50

# model dimensions: length of the latent vector fed to the generator and the
# number of image channels seen by the discriminator
dim_latent = 70
dim_channel = 1

# Adam step sizes (identical for both networks)
learning_rate_discriminator = 0.001
learning_rate_generator = 0.001

In [4]:
import random
def affine(image, shear=0, scale=1, rate=(10, 10)):
    """Apply an independent random affine transform to every image in a batch.

    Parameters
    ----------
    image : indexable batch of images; each ``image[i]`` must be accepted by
        ``transforms.functional.to_pil_image`` (e.g. a 2-D float tensor).
    shear : shear angle passed straight to ``transforms.functional.affine``.
    scale : either a fixed scale factor, or a ``[low, high]`` pair of integers
        expressed in tenths -- a scale is then drawn uniformly from
        ``{low/10, (low+1)/10, ..., high/10}`` for each image.
    rate : ``[max_dx, max_dy]`` maximum absolute translation in pixels; each
        image is shifted by a random integer in ``[-max_d, max_d]`` per axis.
        ``[0, 0]`` disables translation.

    Returns
    -------
    numpy.ndarray (float32): the ``to_tensor`` outputs of all transformed
    images concatenated along axis 0 (shape (N, H, W) for single-channel
    images). An empty input yields an empty array.
    """
    func_pil = transforms.functional.to_pil_image
    func_affine = transforms.functional.affine
    func_tensor = transforms.functional.to_tensor

    transformed = []

    for i in range(len(image)):

        # Random translation. Checked on both axes, so e.g. rate=[0, 5] still
        # jitters vertically instead of applying a fixed 5-pixel shift.
        if any(rate):
            movement = [np.random.randint(-r, r + 1) for r in rate[:2]]
        else:
            movement = list(rate)

        # A [low, high] range (in tenths) means "draw a random scale per image".
        if isinstance(scale, (list, tuple)):
            rescale = np.random.randint(scale[0], scale[1] + 1) / 10
        else:
            rescale = scale

        pil_image = func_pil(image[i])
        pil_image = func_affine(pil_image, angle=0, shear=shear, scale=rescale, translate=movement)
        transformed.append(func_tensor(pil_image).numpy())

    if not transformed:
        # Nothing to transform: return an empty batch instead of failing on an
        # undefined result.
        return np.empty((0,) + tuple(np.shape(image))[1:], dtype=np.float32)

    # Concatenate once at the end rather than re-copying a growing array on
    # every iteration.
    return np.concatenate(transformed, axis=0)

In [5]:
# Augmented copies of every second real image (indices 0, 2, 4, ...):
#   affine_12     -- random rescale in {0.5, 0.6, ..., 1.2} (scale=[5, 12] is in tenths), no translation
#   affine_random -- scale 1, random translation of up to +/-3 pixels on each axis
# NOTE(review): neither array feeds the training dataset in the visible cells
# (the concatenation that would use them is commented out).
real_image = real[::2]
affine_12 = affine(real_image, scale=[5, 12], rate=[0, 0])
affine_random = affine(real_image, scale=1, rate=[3, 3])

## custom data loader for the PyTorch framework

In [6]:
class dataset(Dataset):
    """Map-style dataset over a stack of images indexed along the first axis.

    Each item ``data[index]`` is returned as a float32 tensor with a new
    leading (channel) axis, e.g. an (H, W) image becomes (1, H, W).
    ``data`` may be a numpy array or a torch tensor.
    """

    def __init__(self, data):

        self.data = data

    def __getitem__(self, index):

        # torch.as_tensor handles numpy arrays and tensors alike (unlike the
        # legacy torch.FloatTensor constructor) and only copies when needed.
        sample = torch.as_tensor(self.data[index], dtype=torch.float32)

        return sample.unsqueeze(dim=0)

    def __len__(self):

        return self.data.shape[0]

## construct the dataset and dataloader for training

In [7]:
# Training uses the raw real images only. The commented-out line below is an
# alternative that mixes odd-indexed real images with the augmented copies.
# image_train = np.concatenate([real[1::2], affine_12, affine_random], axis=0)
dataset_real    = dataset(real)
# reshuffled every epoch; drop_last=True keeps every minibatch at exactly size_minibatch samples
dataloader_real = DataLoader(dataset_real, batch_size=size_minibatch, shuffle=True, drop_last=True)

In [8]:
# image_train.shape

## shape of the data when using the data loader

In [9]:
# Sanity check: fetch one sample through the dataset to see the per-item shape
# (the leading 1 is the channel axis added by dataset.__getitem__).
image_real = dataset_real[0]
print('*******************************************************************')
print('shape of the image in the training dataset:', image_real.shape)
print('*******************************************************************')

*******************************************************************
shape of the image in the training dataset: torch.Size([1, 32, 32])
*******************************************************************


## classes for the neural networks (discriminator and generator)

In [10]:
class Discriminator(nn.Module):
    """Convolutional GAN discriminator that outputs one raw logit per image.

    Feature extractor: five 3x3 stride-2 convolutions with channels
    in_channel -> d -> 2d -> 4d -> 8d -> 16d (d = dim_feature), each followed by
    LeakyReLU(0.2) and batch norm. The result is flattened and passed through an
    MLP 16d -> 8d -> 4d -> 2d -> d -> out_channel with the same
    activation/normalisation pattern between layers. No sigmoid is applied; the
    output is meant for ``nn.BCEWithLogitsLoss``.

    The first Linear layer expects 16d input features, so the spatial size must
    shrink to 1x1 after the five stride-2 convolutions (true for 32x32 inputs).
    """

    def __init__(self, in_channel=1, out_channel=1, dim_feature=128):

        super(Discriminator, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.dim_feature = dim_feature
        threshold_ReLU = 0.2

        self.feature = nn.Sequential(
            # ================================================================================
            nn.Conv2d(in_channel, dim_feature * 1, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm2d(dim_feature * 1),
            # ================================================================================
            nn.Conv2d(dim_feature * 1, dim_feature * 2, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm2d(dim_feature * 2),
            # ================================================================================
            nn.Conv2d(dim_feature * 2, dim_feature * 4, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm2d(dim_feature * 4),
            # ================================================================================
            nn.Conv2d(dim_feature * 4, dim_feature * 8, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm2d(dim_feature * 8),
            # ================================================================================
            nn.Conv2d(dim_feature * 8, dim_feature * 16, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm2d(dim_feature * 16),
            # ================================================================================
        )

        # The classifier sees flattened (batch, features) tensors, so its
        # normalisation layers must be BatchNorm1d: BatchNorm2d only accepts
        # 4-D input and raises a ValueError on 2-D tensors.
        self.classifier = nn.Sequential(
            # ================================================================================
            nn.Linear(dim_feature * 16, dim_feature * 8, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm1d(dim_feature * 8),
            # ================================================================================
            nn.Linear(dim_feature * 8, dim_feature * 4, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm1d(dim_feature * 4),
            # ================================================================================
            nn.Linear(dim_feature * 4, dim_feature * 2, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm1d(dim_feature * 2),
            # ================================================================================
            nn.Linear(dim_feature * 2, dim_feature * 1, bias=True),
            nn.LeakyReLU(threshold_ReLU, inplace=True),
            nn.BatchNorm1d(dim_feature * 1),
            # ================================================================================
            nn.Linear(dim_feature * 1, out_channel, bias=True),
            # ================================================================================
        )

        self.network = nn.Sequential(
            self.feature,
            nn.Flatten(),
            self.classifier,
        )

        self.initialize_weight()

    def forward(self, x):
        """Return the discriminator logits for the image batch ``x``."""

        return self.network(x)

    def initialize_weight(self):
        """Initialise parameters in place.

        Conv/Linear weights are Xavier-uniform. Their biases, and the scale and
        shift of every batch-norm layer (1-D and 2-D), are set to the constant 1.
        """

        print('initialize model parameters :', 'xavier_uniform')

        for m in self.network.modules():

            if isinstance(m, (nn.Conv2d, nn.Linear)):

                nn.init.xavier_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 1)

            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):

                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 1)

In [11]:
class Generator(nn.Module):
    """Upsampling convolutional generator mapping a latent tensor to an image.

    There are five stages. Each stage doubles the spatial size with bilinear
    upsampling, then applies a 3x3 convolution (stride 1, padding 1) and batch
    norm. The first four stages end in LeakyReLU(0.2); after the last stage a
    Sigmoid squashes the output into (0, 1). Channel widths are
    in_channel -> 8d -> 4d -> 2d -> d -> out_channel, with d = dim_feature.
    """

    def __init__(self, in_channel=1, out_channel=1, dim_feature=8):
        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.dim_feature = dim_feature
        slope = 0.2

        widths = [in_channel] + [dim_feature * k for k in (8, 4, 2, 1)] + [out_channel]
        last_stage = len(widths) - 2

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True))
            layers.append(nn.BatchNorm2d(c_out))
            # every stage but the final one gets a LeakyReLU
            if stage != last_stage:
                layers.append(nn.LeakyReLU(slope, inplace=True))
        layers.append(nn.Sigmoid())

        self.network = nn.Sequential(*layers)

        self.initialize_weight()

    def forward(self, x):
        """Return the generated image batch for the latent batch ``x``."""
        return self.network.forward(x)

    def initialize_weight(self):
        """Initialise parameters in place.

        Conv/Linear weights are Xavier-uniform and their biases are 1; batch-norm
        scale and shift are both set to 1.
        """
        print('initialize model parameters :', 'xavier_uniform')

        for module in self.network.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 1)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 1)


## build network

In [12]:
# Generator: latent tensor with dim_latent channels (spatially 1x1) -> 1-channel image, dim_feature=128.
# Discriminator: dim_channel-channel image -> one logit per image, dim_feature=128.
generator       = Generator(dim_latent, 1, 128).to(device)
discriminator   = Discriminator(dim_channel, 1, 128).to(device)

# Separate Adam optimizers, both with beta1 = 0.5 (PyTorch's default is 0.9).
optimizer_generator     = torch.optim.Adam(generator.parameters(), lr=learning_rate_generator, betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate_discriminator, betas=(0.5, 0.999))

initialize model parameters : xavier_uniform
initialize model parameters : xavier_uniform


## compute the prediction

In [13]:
def compute_prediction(model, input):
    """Run ``model`` forward on ``input`` and hand back whatever it returns."""
    return model(input)

## compute the loss

In [14]:
def compute_loss_discriminator(generator, discriminator, latent, data_real):
    """Binary cross-entropy loss of the discriminator on one real/fake batch.

    Parameters
    ----------
    generator : model mapping ``latent`` to a batch of fake images.
    discriminator : model returning one logit per image.
    latent : latent batch fed to the generator.
    data_real : batch of real images.

    Returns
    -------
    Scalar tensor: the average of the BCE-with-logits loss on real samples
    (target 1) and on generated samples (target 0).
    """
    # Detach the fake batch: this loss only updates the discriminator, so
    # backpropagating into the generator would waste compute and fill the
    # generator's .grad buffers with values that are never used.
    data_fake       = compute_prediction(generator, latent).detach()
    prediction_real = compute_prediction(discriminator, data_real)
    prediction_fake = compute_prediction(discriminator, data_fake)

    criterion   = nn.BCEWithLogitsLoss()

    label_real  = torch.ones_like(prediction_real)
    label_fake  = torch.zeros_like(prediction_fake)

    loss_real = criterion(prediction_real, label_real)
    loss_fake = criterion(prediction_fake, label_fake)

    # average the two terms so the loss is on the scale of a single BCE term
    loss_discriminator = (loss_real + loss_fake) / 2.0

    return loss_discriminator

In [15]:
def compute_loss_generator(generator, discriminator, latent):
    """Non-saturating generator loss for one latent batch.

    The generated images are scored by the discriminator and compared, via
    BCE-with-logits, against an all-ones target. The loss is therefore low when
    the discriminator classifies the fakes as real.
    """
    logits_fake = compute_prediction(discriminator, compute_prediction(generator, latent))
    target_real = torch.ones_like(logits_fake)
    return nn.BCEWithLogitsLoss()(logits_fake, target_real)

## compute the accuracy

In [16]:
def get_center_index(binary_image):
    """Return the (x, y) centroid of a 2-D mask, weighting pixels by their value.

    x is the column coordinate and y the row coordinate. An all-zero mask makes
    the denominator 0 (NumPy then yields nan with a RuntimeWarning).
    """
    rows, cols = np.indices(binary_image.shape[:2], dtype=float)
    total = np.sum(binary_image)

    centre_x = np.sum(binary_image * cols) / total
    centre_y = np.sum(binary_image * rows) / total

    return (centre_x, centre_y)

In [17]:
def create_label(binary_images):
    """Build, for each mask in a batch, the ideal square with the same area.

    For every ``binary_images[i]``, the square has side ``round(sqrt(area))``
    (area = sum of the mask). It is centred on the mask's centroid
    (``get_center_index``) and clipped to the image bounds. Masks that are
    empty, or whose centroid lies outside the image, get an all-zero label; the
    rest of the batch is still processed.

    Parameters
    ----------
    binary_images : array of masks stacked along axis 0, presumably (N, H, W)
        with 0/1 values as produced by ``compute_accuracy``.

    Returns
    -------
    Array with the same shape and dtype as ``binary_images``.
    """
    label = np.zeros_like(binary_images)

    for i, binary_image in enumerate(binary_images):

        image_height = binary_image.shape[0]
        image_width = binary_image.shape[1]

        square_length = np.round(np.sqrt(np.sum(binary_image)))

        if square_length == 0:
            # Empty mask: keep label[i] all zeros and continue. Returning here
            # would discard the labels of every remaining image in the batch.
            continue

        (square_center_x, square_center_y) = get_center_index(binary_image)

        if square_center_x < 0 or square_center_x > image_width - 1 or square_center_y < 0 or square_center_y > image_height - 1:
            continue

        half = square_length / 2

        # inclusive pixel bounds of the square, clipped to the image
        top = max(int(np.ceil(square_center_y - half)), 0)
        bottom = min(int(np.floor(square_center_y + half)), image_height - 1)
        left = max(int(np.ceil(square_center_x - half)), 0)
        right = min(int(np.floor(square_center_x + half)), image_width - 1)

        label[i, top:bottom + 1, left:right + 1] = 1

    return label

In [18]:
def compute_accuracy(prediction):
    """Score generated images by how closely their foreground forms a square.

    ``prediction`` is squeezed on axis 1 (the channel axis) and thresholded at
    0.5. Each thresholded mask is compared with the equal-area square built by
    ``create_label`` via intersection-over-union. The mean IoU over the batch is
    returned as a percentage.
    """
    images = prediction.squeeze(axis=1)

    mask = (images >= 0.5).cpu().numpy().astype(int)
    target = create_label(mask).astype(int)

    overlap = (mask & target).sum(axis=(1, 2)).astype(float)
    combined = (mask | target).sum(axis=(1, 2)).astype(float)

    # eps keeps an empty union from dividing by zero (its IoU becomes 0)
    tiny = np.finfo(float).eps
    iou = overlap / (combined + tiny)

    return iou.mean() * 100.0

## variables for the learning curve

In [19]:
# Per-epoch statistics filled in by the training loop (index = epoch number):
# the mean / std over the minibatches of that epoch.
loss_generator_mean     = np.zeros(number_epoch)
loss_generator_std      = np.zeros(number_epoch)
loss_discriminator_mean = np.zeros(number_epoch)
loss_discriminator_std  = np.zeros(number_epoch)

# squareness accuracy (percent) of the generated images, see compute_accuracy
accuracy_mean   = np.zeros(number_epoch)
accuracy_std    = np.zeros(number_epoch)

## train

In [20]:
def _mean_std(values):
    """Summarise a list of per-batch scalars as ``{'mean': ..., 'std': ...}``."""
    return {'mean': np.mean(values), 'std': np.std(values)}


def train(generator, discriminator, dataloader):
    """Train the GAN for one pass over ``dataloader``.

    For each minibatch, the generator is updated first with the discriminator
    in eval mode. The discriminator is then updated with the generator in eval
    mode. Both steps use the same latent batch. The function relies on the
    module-level ``device``, ``dim_latent``, ``optimizer_generator`` and
    ``optimizer_discriminator``.

    Returns
    -------
    (loss_value_generator, loss_value_discriminator, accuracy_value), each a
    dict with keys ``'mean'`` and ``'std'`` computed over the epoch's minibatches.
    """
    loss_epoch_generator = []
    loss_epoch_discriminator = []
    accuracy_epoch = []

    for data_real in dataloader:

        size_batch = len(data_real)
        data_real = data_real.to(device)

        # latent codes shaped (batch, dim_latent, 1, 1) for the convolutional generator
        latent = torch.randn(size_batch, dim_latent, device=device)
        latent = torch.reshape(latent, [size_batch, dim_latent, 1, 1])

        # ---------------------------------------------------------------------------
        # update the generator (discriminator batch norms use running statistics)
        # ---------------------------------------------------------------------------
        generator.train()
        discriminator.eval()

        optimizer_generator.zero_grad()
        loss_generator = compute_loss_generator(generator, discriminator, latent)
        loss_generator.backward()
        optimizer_generator.step()

        # ---------------------------------------------------------------------------
        # update the discriminator (generator batch norms use running statistics)
        # ---------------------------------------------------------------------------
        generator.eval()
        discriminator.train()

        optimizer_discriminator.zero_grad()
        loss_discriminator = compute_loss_discriminator(generator, discriminator, latent, data_real)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Accuracy is only monitored, so this extra forward pass does not need
        # an autograd graph.
        with torch.no_grad():
            data_fake = compute_prediction(generator, latent)
        accuracy = compute_accuracy(data_fake)

        loss_epoch_generator.append(loss_generator.item())
        loss_epoch_discriminator.append(loss_discriminator.item())
        accuracy_epoch.append(accuracy)

    loss_value_generator = _mean_std(loss_epoch_generator)
    loss_value_discriminator = _mean_std(loss_epoch_discriminator)
    accuracy_value = _mean_std(accuracy_epoch)

    return loss_value_generator, loss_value_discriminator, accuracy_value


## training epoch

In [None]:
# ================================================================================
# 
# iterations for epochs
#
# ================================================================================
# Each iteration makes one full pass over dataloader_real via train() and
# stores that epoch's minibatch mean/std (losses and accuracy) at index i.
for i in tqdm(range(number_epoch)):
    
    # ================================================================================
    # 
    # training
    #
    # ================================================================================
    (loss_value_generator, loss_value_discriminator, accuracy_value) = train(generator, discriminator, dataloader_real)

    loss_generator_mean[i]      = loss_value_generator['mean']
    loss_generator_std[i]       = loss_value_generator['std']

    loss_discriminator_mean[i]  = loss_value_discriminator['mean']
    loss_discriminator_std[i]   = loss_value_discriminator['std']

    accuracy_mean[i]            = accuracy_value['mean']
    accuracy_std[i]             = accuracy_value['std']

    # Progress log: only the discriminator loss and the accuracy are printed.
    # NOTE(review): print() inside a tqdm loop interleaves with the progress
    # bar; tqdm.write(...) would keep the bar intact.
    print(f"epoch : {i}")
    print(f"\tloss_value_discriminator : {loss_value_discriminator['mean']}, acc mean : {accuracy_value['mean']}")

  1%|          | 1/150 [00:17<43:47, 17.64s/it]

epoch : 0
	loss_value_discriminator : 2.471686920454336, acc mean : 81.598036429561


  1%|▏         | 2/150 [00:35<43:27, 17.62s/it]

epoch : 1
	loss_value_discriminator : 0.8575951172861942, acc mean : 66.72604922594132


  2%|▏         | 3/150 [00:52<43:08, 17.61s/it]

epoch : 2
	loss_value_discriminator : 296104.0453656833, acc mean : 79.17949025999268


  3%|▎         | 4/150 [01:10<42:48, 17.59s/it]

epoch : 3
	loss_value_discriminator : 0.9088526649597618, acc mean : 88.61978507461203


  3%|▎         | 5/150 [01:27<42:28, 17.58s/it]

epoch : 4
	loss_value_discriminator : 100.00922195414124, acc mean : 85.70422093279035


  4%|▍         | 6/150 [01:45<42:11, 17.58s/it]

epoch : 5
	loss_value_discriminator : 20.21682912874175, acc mean : 72.03009098297524


  5%|▍         | 7/150 [02:03<41:54, 17.58s/it]

epoch : 6
	loss_value_discriminator : 50.96546620199847, acc mean : 68.98114646908347


  5%|▌         | 8/150 [02:20<41:36, 17.58s/it]

epoch : 7
	loss_value_discriminator : 35.38349077735789, acc mean : 63.76439383284413


  6%|▌         | 9/150 [02:38<41:19, 17.58s/it]

epoch : 8
	loss_value_discriminator : 11.285574621440366, acc mean : 57.75013027492389


  7%|▋         | 10/150 [02:55<41:01, 17.58s/it]

epoch : 9
	loss_value_discriminator : 4.274460797377208, acc mean : 55.87541204898138


  7%|▋         | 11/150 [03:13<40:43, 17.58s/it]

epoch : 10
	loss_value_discriminator : 60.51805138148652, acc mean : 47.75998211505892


  8%|▊         | 12/150 [03:31<40:25, 17.58s/it]

epoch : 11
	loss_value_discriminator : 21.556372406126627, acc mean : 57.00071987191002


  9%|▊         | 13/150 [03:48<40:08, 17.58s/it]

epoch : 12
	loss_value_discriminator : 11.294437943889545, acc mean : 62.815189044551474


  9%|▉         | 14/150 [04:06<39:50, 17.57s/it]

epoch : 13
	loss_value_discriminator : 5.857189278399013, acc mean : 50.71846436691359


 10%|█         | 15/150 [04:23<39:33, 17.58s/it]

epoch : 14
	loss_value_discriminator : 50.70414279842327, acc mean : 66.43569719130583


 11%|█         | 16/150 [04:41<39:15, 17.58s/it]

epoch : 15
	loss_value_discriminator : 4.684664128100013, acc mean : 63.89969482463467


 11%|█▏        | 17/150 [04:58<38:57, 17.57s/it]

epoch : 16
	loss_value_discriminator : 1.8598990230873051, acc mean : 66.76623784818416


 12%|█▏        | 18/150 [05:16<38:39, 17.57s/it]

epoch : 17
	loss_value_discriminator : 4.476341111206136, acc mean : 59.6504252936958


 13%|█▎        | 19/150 [05:34<38:22, 17.58s/it]

epoch : 18
	loss_value_discriminator : 83.39936551864888, acc mean : 64.59812205568053


 13%|█▎        | 20/150 [05:51<38:05, 17.58s/it]

epoch : 19
	loss_value_discriminator : 1.5533809989468885, acc mean : 67.90289342813604


 14%|█▍        | 21/150 [06:09<37:49, 17.59s/it]

epoch : 20
	loss_value_discriminator : 3.3851868624638675, acc mean : 70.87761254808673


 15%|█▍        | 22/150 [06:26<37:31, 17.59s/it]

epoch : 21
	loss_value_discriminator : 10.376336887289682, acc mean : 68.99188729370593


 15%|█▌        | 23/150 [06:44<37:14, 17.59s/it]

epoch : 22
	loss_value_discriminator : 1.4888877659477049, acc mean : 64.14983740339777


 16%|█▌        | 24/150 [07:02<36:55, 17.59s/it]

epoch : 23
	loss_value_discriminator : 2.228878461737235, acc mean : 66.53437409923444


 17%|█▋        | 25/150 [07:19<36:38, 17.59s/it]

epoch : 24
	loss_value_discriminator : 3.2129406413599995, acc mean : 65.19485615207572


 17%|█▋        | 26/150 [07:37<36:20, 17.58s/it]

epoch : 25
	loss_value_discriminator : 1.724350424284202, acc mean : 69.96657270945308


 18%|█▊        | 27/150 [07:54<36:02, 17.58s/it]

epoch : 26
	loss_value_discriminator : 25.063867902712676, acc mean : 68.31851728930957


 19%|█▊        | 28/150 [08:12<35:44, 17.58s/it]

epoch : 27
	loss_value_discriminator : 12.619341033329418, acc mean : 67.25742637173121


 19%|█▉        | 29/150 [08:29<35:26, 17.57s/it]

epoch : 28
	loss_value_discriminator : 2.0743610638450827, acc mean : 67.54373149265386


 20%|██        | 30/150 [08:47<35:09, 17.58s/it]

epoch : 29
	loss_value_discriminator : 2.039872380444901, acc mean : 67.09114426791521


 21%|██        | 31/150 [09:05<34:52, 17.58s/it]

epoch : 30
	loss_value_discriminator : 1.186914848265965, acc mean : 70.13712815161928


 21%|██▏       | 32/150 [09:22<34:34, 17.58s/it]

epoch : 31
	loss_value_discriminator : 1.5643174470053482, acc mean : 67.03079818976131


 22%|██▏       | 33/150 [09:40<34:16, 17.58s/it]

epoch : 32
	loss_value_discriminator : 0.6188778926019626, acc mean : 65.68100506633299


 23%|██▎       | 34/150 [09:57<33:59, 17.58s/it]

epoch : 33
	loss_value_discriminator : 1.27387720672273, acc mean : 70.22039259365977


 23%|██▎       | 35/150 [10:15<33:41, 17.58s/it]

epoch : 34
	loss_value_discriminator : 0.41058841853262856, acc mean : 72.3274597994621


 24%|██▍       | 36/150 [10:32<33:23, 17.57s/it]

epoch : 35
	loss_value_discriminator : 11.055376091096702, acc mean : 67.78353829785334


 25%|██▍       | 37/150 [10:50<33:05, 17.57s/it]

epoch : 36
	loss_value_discriminator : 1.833852757658081, acc mean : 77.10376704676234


 25%|██▌       | 38/150 [11:08<32:47, 17.57s/it]

epoch : 37
	loss_value_discriminator : 0.9403700211771084, acc mean : 67.21693405314582


 26%|██▌       | 39/150 [11:25<32:30, 17.57s/it]

epoch : 38
	loss_value_discriminator : 2.4632599814293124, acc mean : 63.17155916422004


 27%|██▋       | 40/150 [11:43<32:13, 17.57s/it]

epoch : 39
	loss_value_discriminator : 0.7580538201380703, acc mean : 62.934933545249635


 27%|██▋       | 41/150 [12:00<31:56, 17.59s/it]

epoch : 40
	loss_value_discriminator : 0.5040524015786487, acc mean : 70.4776507131845


 28%|██▊       | 42/150 [12:18<31:38, 17.58s/it]

epoch : 41
	loss_value_discriminator : 3353.1975807607423, acc mean : 68.1491368245209


 29%|██▊       | 43/150 [12:36<31:21, 17.59s/it]

epoch : 42
	loss_value_discriminator : 0.6223215284713143, acc mean : 67.41897180015522


 29%|██▉       | 44/150 [12:53<31:03, 17.58s/it]

epoch : 43
	loss_value_discriminator : 0.11675001859400852, acc mean : 67.03883416070836


 30%|███       | 45/150 [13:11<30:46, 17.58s/it]

epoch : 44
	loss_value_discriminator : 0.026908693875109864, acc mean : 65.62419212969607


 31%|███       | 46/150 [13:28<30:28, 17.58s/it]

epoch : 45
	loss_value_discriminator : 0.032899886714408466, acc mean : 64.10985938372228


 31%|███▏      | 47/150 [13:46<30:11, 17.59s/it]

epoch : 46
	loss_value_discriminator : 0.03686013643041517, acc mean : 62.131680278317226


 32%|███▏      | 48/150 [14:03<29:53, 17.58s/it]

epoch : 47
	loss_value_discriminator : 0.8178046400140337, acc mean : 57.76138089839138


 33%|███▎      | 49/150 [14:21<29:35, 17.58s/it]

epoch : 48
	loss_value_discriminator : 29.449246796584408, acc mean : 56.62376568335657


 33%|███▎      | 50/150 [14:39<29:17, 17.58s/it]

epoch : 49
	loss_value_discriminator : 5.9457812711881095, acc mean : 57.80144085196889


 34%|███▍      | 51/150 [14:56<29:00, 17.58s/it]

epoch : 50
	loss_value_discriminator : 1.5549455116290416, acc mean : 60.10418430457342


 35%|███▍      | 52/150 [15:14<28:43, 17.58s/it]

epoch : 51
	loss_value_discriminator : 0.05252918687789176, acc mean : 60.084120635479614


 35%|███▌      | 53/150 [15:31<28:24, 17.58s/it]

epoch : 52
	loss_value_discriminator : 1.939446044653545, acc mean : 60.35305364097102


 36%|███▌      | 54/150 [15:49<28:07, 17.58s/it]

epoch : 53
	loss_value_discriminator : 0.054309672132529706, acc mean : 61.49184181411835


 37%|███▋      | 55/150 [16:06<27:50, 17.58s/it]

epoch : 54
	loss_value_discriminator : 0.4310927072541279, acc mean : 61.68476798896202


 37%|███▋      | 56/150 [16:24<27:33, 17.59s/it]

epoch : 55
	loss_value_discriminator : 0.19327620265828446, acc mean : 61.79752954803316


 38%|███▊      | 57/150 [16:42<27:16, 17.60s/it]

epoch : 56
	loss_value_discriminator : 0.07017630412545174, acc mean : 62.14921328734661


 39%|███▊      | 58/150 [16:59<26:59, 17.60s/it]

epoch : 57
	loss_value_discriminator : 0.04346079291282346, acc mean : 62.91653356347738


 39%|███▉      | 59/150 [17:17<26:40, 17.59s/it]

epoch : 58
	loss_value_discriminator : 0.06396535115247791, acc mean : 63.426106642957166


 40%|████      | 60/150 [17:34<26:23, 17.59s/it]

epoch : 59
	loss_value_discriminator : 0.21116501792021036, acc mean : 64.81862385332632


 41%|████      | 61/150 [17:52<26:04, 17.58s/it]

epoch : 60
	loss_value_discriminator : 0.022072237409950735, acc mean : 65.92807400367661


 41%|████▏     | 62/150 [18:10<25:47, 17.59s/it]

epoch : 61
	loss_value_discriminator : 0.009219238507612741, acc mean : 67.48575976179859


 42%|████▏     | 63/150 [18:27<25:29, 17.58s/it]

epoch : 62
	loss_value_discriminator : 0.014047100258224863, acc mean : 68.6834829919157


 43%|████▎     | 64/150 [18:45<25:12, 17.59s/it]

epoch : 63
	loss_value_discriminator : 9.608138789794518, acc mean : 70.43086408749643


 43%|████▎     | 65/150 [19:02<24:54, 17.58s/it]

epoch : 64
	loss_value_discriminator : 0.6381921102050868, acc mean : 70.56532787652442


 44%|████▍     | 66/150 [19:20<24:37, 17.58s/it]

epoch : 65
	loss_value_discriminator : 0.5984185165379617, acc mean : 71.51121156131836


 45%|████▍     | 67/150 [19:38<24:18, 17.58s/it]

epoch : 66
	loss_value_discriminator : 0.005554994944325931, acc mean : 76.74061779125825


 45%|████▌     | 68/150 [19:55<24:01, 17.58s/it]

epoch : 67
	loss_value_discriminator : 5.400076185172936, acc mean : 78.1133375141892


 46%|████▌     | 69/150 [20:13<23:43, 17.57s/it]

epoch : 68
	loss_value_discriminator : 0.2837892278340553, acc mean : 67.20041570085026


 47%|████▋     | 70/150 [20:30<23:25, 17.56s/it]

epoch : 69
	loss_value_discriminator : 0.2863372043852969, acc mean : 61.28780648504036


 47%|████▋     | 71/150 [20:48<23:07, 17.56s/it]

epoch : 70
	loss_value_discriminator : 0.04290531105359975, acc mean : 68.87431435989666


 48%|████▊     | 72/150 [21:05<22:49, 17.56s/it]

epoch : 71
	loss_value_discriminator : 0.14936267625341507, acc mean : 79.7804967215199


 49%|████▊     | 73/150 [21:23<22:31, 17.55s/it]

epoch : 72
	loss_value_discriminator : 0.20204789297383477, acc mean : 81.330836772201


 49%|████▉     | 74/150 [21:40<22:13, 17.55s/it]

epoch : 73
	loss_value_discriminator : 0.056835300693809016, acc mean : 78.27643749173818


 50%|█████     | 75/150 [21:58<21:57, 17.56s/it]

epoch : 74
	loss_value_discriminator : 0.2920170573330684, acc mean : 84.81371385589664


 51%|█████     | 76/150 [22:16<21:39, 17.56s/it]

epoch : 75
	loss_value_discriminator : 0.069156943645942, acc mean : 79.53575590155108


 51%|█████▏    | 77/150 [22:33<21:21, 17.55s/it]

epoch : 76
	loss_value_discriminator : 0.05964259977209889, acc mean : 78.07734965341452


 52%|█████▏    | 78/150 [22:51<21:04, 17.56s/it]

epoch : 77
	loss_value_discriminator : 0.0875275296701909, acc mean : 76.35812019964568


 53%|█████▎    | 79/150 [23:08<20:46, 17.56s/it]

epoch : 78
	loss_value_discriminator : 0.043255911120330746, acc mean : 81.94638187741216


 53%|█████▎    | 80/150 [23:26<20:28, 17.55s/it]

epoch : 79
	loss_value_discriminator : 1.552993739674326, acc mean : 80.50156451508695


 54%|█████▍    | 81/150 [23:43<20:11, 17.56s/it]

epoch : 80
	loss_value_discriminator : 1.105815460335059, acc mean : 86.28956630083185


 55%|█████▍    | 82/150 [24:01<19:53, 17.55s/it]

epoch : 81
	loss_value_discriminator : 0.13079811486284937, acc mean : 73.79721248869885


 55%|█████▌    | 83/150 [24:18<19:35, 17.55s/it]

epoch : 82
	loss_value_discriminator : 0.12381213627025339, acc mean : 81.03521402410676


 56%|█████▌    | 84/150 [24:36<19:18, 17.55s/it]

epoch : 83
	loss_value_discriminator : 0.2025290408339568, acc mean : 74.69755703232164


 57%|█████▋    | 85/150 [24:54<19:00, 17.55s/it]

epoch : 84
	loss_value_discriminator : 0.021102888322825177, acc mean : 73.45311572401916


 57%|█████▋    | 86/150 [25:11<18:42, 17.55s/it]

epoch : 85
	loss_value_discriminator : 0.007191603424617761, acc mean : 76.8309394899925


 58%|█████▊    | 87/150 [25:29<18:25, 17.55s/it]

epoch : 86
	loss_value_discriminator : 0.019471527344268356, acc mean : 79.33041857581993


 59%|█████▊    | 88/150 [25:46<18:07, 17.55s/it]

epoch : 87
	loss_value_discriminator : 0.012030047241518502, acc mean : 82.01183752230965


 59%|█████▉    | 89/150 [26:04<17:50, 17.55s/it]

epoch : 88
	loss_value_discriminator : 0.008767409314777767, acc mean : 81.30957861673079


---

# functions for visualizing the results 

---

## plotting helpers (image grids and learning curves)

In [None]:
def plot_image_grid(data, nRow, nCol, filename=None):
    """Show the first ``nRow * nCol`` images of a tensor batch as a grayscale grid.

    ``data`` is a torch tensor indexed along axis 0. Each ``data[k]`` has its
    leading (channel) axis squeezed away before display, and values are shown
    on a fixed [0, 1] range. If ``filename`` is given, the figure is also
    written to that path.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps ``axes`` 2-D even when nRow or nCol is 1
    fig, axes = plt.subplots(nRow, nCol, squeeze=False, constrained_layout=True,
                             figsize=(nCol * size_col, nRow * size_row))

    data = data.detach().cpu()

    for i in range(nRow):
        for j in range(nCol):

            k = i * nCol + j
            image = np.squeeze(data[k], axis=0)

            ax = axes[i, j]
            ax.imshow(image, cmap='gray', vmin=0, vmax=1)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

    # Save before show(): a blocking GUI show() would otherwise delay writing
    # the file until the window is closed.
    if filename is not None:
        fig.savefig(filename)

    plt.show()

In [None]:
def plot_data_grid(data, index_data, nRow, nCol):
    """Plot ``data[index_data[k]]`` for ``k`` in ``0 .. nRow*nCol-1`` as a grid.

    Parameters
    ----------
    data : indexable
        Collection whose elements are images accepted by ``imshow``; they are
        drawn with a gray colormap clipped to [0, 1].
    index_data : sequence of int
        Indices into ``data``; must contain at least ``nRow * nCol`` entries.
    nRow, nCol : int
        Grid dimensions.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps `axes` two-dimensional even for a single row/column
    fig, axes = plt.subplots(nRow, nCol, constrained_layout=True,
                             figsize=(nCol * size_col, nRow * size_row),
                             squeeze=False)

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_data_tensor_grid(data, index_data, nRow, nCol):
    """Plot selected images from a tensor batch as an ``nRow x nCol`` grid.

    Parameters
    ----------
    data : torch.Tensor
        Image batch; axis 1 is squeezed away before plotting, so a layout such
        as ``(N, 1, H, W)`` is assumed -- TODO confirm.
    index_data : sequence of int
        Indices into ``data``; must contain at least ``nRow * nCol`` entries.
    nRow, nCol : int
        Grid dimensions.
    """
    size_col = 1.5
    size_row = 1.5

    # squeeze=False keeps `axes` two-dimensional even for a single row/column
    fig, axes = plt.subplots(nRow, nCol, constrained_layout=True,
                             figsize=(nCol * size_col, nRow * size_row),
                             squeeze=False)

    # detach from autograd, move to host memory and drop the channel axis
    data = data.detach().cpu().squeeze(axis=1)

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_curve_error(data_mean, data_std, x_label, y_label, title, filename=None):
    """Plot a mean curve with a shaded +/- one-std band.

    Parameters
    ----------
    data_mean, data_std : array-like
        Per-step mean and standard deviation, of equal length. They are
        converted with ``np.asarray`` so plain Python lists also work.
    x_label, y_label, title : str
        Axis labels and figure title.
    filename : str or None
        If given, the figure is also written to this path.
    """
    # np.asarray makes `mean - std` element-wise even for list inputs
    data_mean = np.asarray(data_mean)
    data_std = np.asarray(data_std)
    steps = np.arange(len(data_mean))

    fig = plt.figure(figsize=(8, 6))
    plt.title(title)

    alpha = 0.3

    plt.plot(steps, data_mean, '-', color='red')
    plt.fill_between(steps, data_mean - data_std, data_mean + data_std, facecolor='blue', alpha=alpha)

    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.tight_layout()

    # save before showing: with some backends calling savefig after
    # plt.show() writes an empty image
    if filename is not None:
        fig.savefig(filename)

    plt.show()

In [None]:
def print_curve(data, index):
    """Print one ``index = .., value = ..`` line for each position in ``index``.

    ``data[idx]`` is formatted with 10 decimal places; ``idx`` must be an
    integer index into ``data``.
    """
    for position in index:
        print('index = %2d, value = %12.10f' % (position, data[position]))

In [None]:
def get_data_last(data, index_start):
    """Return ``data`` from position ``index_start`` to the end.

    A negative ``index_start`` (e.g. ``-10``) selects the trailing elements.
    """
    return data[index_start:]

In [None]:
def get_max_last_range(data, index_start):
    """Return the maximum of ``data[index_start:]``.

    The tail must provide a ``.max()`` method (e.g. a NumPy array or tensor).
    """
    return get_data_last(data, index_start).max()

In [None]:
def get_min_last_range(data, index_start):
    """Return the minimum of ``data[index_start:]``.

    The tail must provide a ``.min()`` method (e.g. a NumPy array or tensor).
    """
    return get_data_last(data, index_start).min()

---

# functions for presenting the results

---

In [None]:
def function_result_01():
    """Print a header and plot an 8x6 grid of evenly spaced real images."""
    print('[plot examples of the real images]')
    print('') 

    nRow, nCol = 8, 6
    number_plot = nRow * nCol

    # sample indices spread evenly across the whole dataset
    number_data = len(dataset_real)
    step        = int(np.floor(number_data / number_plot))
    index_data  = np.arange(0, number_data, step)
    index_plot  = np.arange(0, number_plot)

    # NOTE(review): element [0] of the fancy-indexed result is plotted;
    # presumably the dataset returns a tuple whose first item is the image
    # stack -- confirm against the Dataset definition.
    samples = dataset_real[index_data][0]

    plot_data_grid(samples, index_plot, nRow, nCol)

In [None]:
def function_result_02():
    """Sample 48 latent vectors, generate fake images and plot/save them.

    Puts the global ``generator`` into eval mode (and leaves it there) and
    writes the grid to ``fake_image.png``.
    """
    print('[plot examples of the fake images]')
    print('') 

    nRow = 8
    nCol = 6
    number_latent = nRow * nCol

    latent  = torch.randn(number_latent, dim_latent, device=device)
    latent  = torch.reshape(latent, [number_latent, dim_latent, 1, 1])

    generator.eval()

    # inference only: no_grad avoids building an autograd graph for this batch
    with torch.no_grad():
        data_fake = generator(latent)
    filename    = 'fake_image.png'

    plot_image_grid(data_fake, nRow, nCol, filename)

In [None]:
def function_result_03():
    """Print a header and plot the generator loss curve (mean +/- std)."""
    header = '[plot the generator loss]'
    print(header)
    print('')
    plot_curve_error(loss_generator_mean, loss_generator_std,
                     'epoch', 'loss', 'generator loss', 'loss_generator.png')

In [None]:
def function_result_04():
    """Print a header and plot the discriminator loss curve (mean +/- std)."""
    header = '[plot the discriminator loss]'
    print(header)
    print('')
    plot_curve_error(loss_discriminator_mean, loss_discriminator_std,
                     'epoch', 'loss', 'discriminator loss', 'loss_discriminator.png')

In [None]:
def function_result_05():
    """Print a header and plot the training accuracy curve (mean +/- std)."""
    header = '[plot the accuracy]'
    print(header)
    print('')
    plot_curve_error(accuracy_mean, accuracy_std,
                     'epoch', 'accuracy', 'training accuracy', 'training_accuracy.png')

In [None]:
def function_result_06():
    """Print the generator loss for (up to) the last 10 epochs."""
    print('[print the generator loss at the last 10 epochs]')
    print('') 

    data_last = get_data_last(loss_generator_mean, -10)
    # index over what is actually there: fewer than 10 recorded epochs would
    # make a fixed arange(0, 10) raise IndexError
    index = np.arange(0, len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_07():
    """Print the discriminator loss for (up to) the last 10 epochs."""
    print('[print the discriminator loss at the last 10 epochs]')
    print('') 

    data_last = get_data_last(loss_discriminator_mean, -10)
    # index over what is actually there: fewer than 10 recorded epochs would
    # make a fixed arange(0, 10) raise IndexError
    index = np.arange(0, len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_08():
    """Print the training accuracy for (up to) the last 10 epochs."""
    print('[print the accuracy at the last 10 epochs]')
    print('') 

    data_last = get_data_last(accuracy_mean, -10)
    # index over what is actually there: fewer than 10 recorded epochs would
    # make a fixed arange(0, 10) raise IndexError
    index = np.arange(0, len(data_last))
    print_curve(data_last, index)

In [None]:
def function_result_09():
    """Print the highest mean accuracy among the last 10 epochs."""
    print('[print the best accuracy within the last 10 epochs]')
    print('')
    best = get_max_last_range(accuracy_mean, -10)
    print('best accuracy = %12.10f' % best)

---

# RESULTS

---

In [None]:
# Run every function_result_XX() in order, each preceded by a banner.
number_result = 9

for i in range(number_result):

    title           = '# RESULT # {:02d}'.format(i+1) 
    name_function   = 'function_result_{:02d}'.format(i+1)

    print('') 
    print('################################################################################')
    print('#') 
    print(title)
    print('#') 
    print('################################################################################')
    print('') 

    # look the function up by name and call it, instead of eval()-ing a
    # constructed source string
    globals()[name_function]()