# Image Generation via Generative Adversarial Networks

## import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import random
import os
from torchvision.utils import make_grid

## load data

In [2]:
from google.colab import drive 
drive.mount('/content/drive/')

directory_data  = './drive/MyDrive/Machine_Learning/'
filename_data   = 'assignment_12_data.npz'
data            = np.load(os.path.join(directory_data, filename_data))

real            = torch.from_numpy(data['real_images']).float()

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## hyper-parameters

In [3]:
device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

number_epoch    = 150
size_minibatch  = 50
dim_latent      = 50
dim_channel     = 1
learning_rate_discriminator = 0.001
learning_rate_generator     = 0.001

In [4]:
import random
def affine(image, shear=0, scale=1, rate=[10, 10]):

    func_plt = transforms.functional.to_pil_image
    func_affine = transforms.functional.affine
    func_tensor = transforms.functional.to_tensor

    for i in range(len(image)):

        # random movement
        if rate[0] != 0:
            rate_1 = np.random.randint(-rate[0], rate[0]+1)
            rate_2 = np.random.randint(-rate[1], rate[1]+1)
            movement = [rate_1, rate_2]
        else:
            movement = rate

        if isinstance(scale, list):
            rescale = np.random.randint(scale[0], scale[1]+1) / 10
        else:
            rescale = scale


        trans_image = func_plt(image[i])
        trans_image = func_affine(trans_image, angle=0, shear=shear, scale=rescale, translate=movement)
        trans_image = func_tensor(trans_image)
        trans_image = trans_image.numpy()

        if i == 0:
            image_list = trans_image
        else:
            image_list = np.concatenate([image_list, trans_image], axis=0)

    return image_list

In [5]:
real_image = real[::2]
affine_12 = affine(real_image, scale=[10, 12], rate=[0, 0])
affine_random = affine(real_image, scale=1, rate=[3, 3])

## custom data loader for the PyTorch framework

In [6]:
class dataset (Dataset):
    def  __init__(self, data):

        self.data = data

    def __getitem__(self, index):

        data = self.data[index]
        data = torch.FloatTensor(data).unsqueeze(dim=0)

        return data
  
    def __len__(self):
        
        return self.data.shape[0]

## construct datasets and dataloaders for training and testing

In [7]:
# image_train = np.concatenate([real[1::2], affine_12, affine_random], axis=0)
dataset_real    = dataset(real)
dataloader_real = DataLoader(dataset_real, batch_size=size_minibatch, shuffle=True, drop_last=True)

In [8]:
# image_train.shape

## shape of the data when using the data loader

In [9]:
image_real = dataset_real[0]
print('*******************************************************************')
print('shape of the image in the training dataset:', image_real.shape)
print('*******************************************************************')

*******************************************************************
shape of the image in the training dataset: torch.Size([1, 32, 32])
*******************************************************************


## class for the neural network 

In [10]:
class Discriminator(nn.Module): 

	def __init__(self, in_channel=1, out_channel=1, dim_feature=128):
        
		super(Discriminator, self).__init__()

		self.in_channel 	= in_channel
		self.out_channel	= out_channel
		self.dim_feature	= dim_feature
		threshold_ReLU 		= 0.2
		
		self.feature = nn.Sequential(
			# ================================================================================
			nn.Conv2d(in_channel, dim_feature * 1, kernel_size=3, stride=2, padding=1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Conv2d(dim_feature * 1, dim_feature * 2, kernel_size=3, stride=2, padding=1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Conv2d(dim_feature * 2, dim_feature * 4, kernel_size=3, stride=2, padding=1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Conv2d(dim_feature * 4, dim_feature * 8, kernel_size=3, stride=2, padding=1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Conv2d(dim_feature * 8, dim_feature * 16, kernel_size=3, stride=2, padding=1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
		)	
		
		self.classifier = nn.Sequential(
			# ================================================================================
			nn.Linear(dim_feature * 16, dim_feature * 8, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Linear(dim_feature * 8, dim_feature * 4, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Linear(dim_feature * 4, dim_feature * 2, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Linear(dim_feature * 2, dim_feature * 1, bias=True),
			nn.LeakyReLU(threshold_ReLU, inplace=True),			
			# ================================================================================
			nn.Linear(dim_feature * 1, out_channel, bias=True),
			# ================================================================================
		) 

		self.network = nn.Sequential(
			self.feature,
			nn.Flatten(),
			self.classifier,
		)

		self.initialize_weight()

		# *********************************************************************
		# forward propagation
		# *********************************************************************
	def forward(self, x):

		y = self.network.forward(x)

		return y

	def initialize_weight(self):
	
		print('initialize model parameters :', 'xavier_uniform')

		for m in self.network.modules():
			
			if isinstance(m, nn.Conv2d):
				
				nn.init.constant_(m.weight, 1)
				
				if m.bias is not None:

					nn.init.constant_(m.bias, 1)
					pass
					
			elif isinstance(m, nn.BatchNorm2d):
				
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 1)
				
			elif isinstance(m, nn.Linear):
				
				nn.init.constant_(m.weight, 1)

				if m.bias is not None:
					
					nn.init.constant_(m.bias, 1)
					pass

In [11]:
class Generator(nn.Module): 

	def __init__(self, in_channel=1, out_channel=1, dim_feature=8):
        
		super(Generator, self).__init__()

		self.in_channel 	= in_channel
		self.out_channel	= out_channel
		self.dim_feature	= dim_feature
		threshold_ReLU 		= 0.2

		self.network = nn.Sequential(
			# ================================================================================
			nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
			nn.Conv2d(in_channel, dim_feature * 8, kernel_size=3, stride=1, padding=1, bias=False),
			nn.BatchNorm2d(dim_feature * 8),
			nn.LeakyReLU(threshold_ReLU, inplace=True),
			# ================================================================================
			nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
			nn.Conv2d(dim_feature * 8, dim_feature * 4, kernel_size=3, stride=1, padding=1, bias=False),
			nn.BatchNorm2d(dim_feature * 4),
			nn.LeakyReLU(threshold_ReLU, inplace=True),
			# ================================================================================
			nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
			nn.Conv2d(dim_feature * 4, dim_feature * 2, kernel_size=3, stride=1, padding=1, bias=False),
			nn.BatchNorm2d(dim_feature * 2),
			nn.LeakyReLU(threshold_ReLU, inplace=True),
			# ================================================================================
			nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
			nn.Conv2d(dim_feature * 2, dim_feature * 1, kernel_size=3, stride=1, padding=1, bias=False),
			nn.BatchNorm2d(dim_feature * 1),
			nn.LeakyReLU(threshold_ReLU, inplace=True),
			# ================================================================================
			nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
			nn.Conv2d(dim_feature * 1, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
			nn.BatchNorm2d(out_channel),
			# ================================================================================
			nn.Sigmoid(),
			# ================================================================================
		) 			

		self.initialize_weight()
		
		# *********************************************************************
		# forward propagation
		# *********************************************************************
	def forward(self, x):

		y = self.network.forward(x)

		return y

	def initialize_weight(self):
	
		print('initialize model parameters :', 'xavier_uniform')

		for m in self.network.modules():
			
			if isinstance(m, nn.Conv2d):
				
				nn.init.xavier_uniform_(m.weight)
				
				if m.bias is not None:

					nn.init.constant_(m.bias, 1)
					pass
					
			elif isinstance(m, nn.BatchNorm2d):
				
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 1)
				
			elif isinstance(m, nn.Linear):
				
				nn.init.xavier_uniform_(m.weight)

				if m.bias is not None:
					
					nn.init.constant_(m.bias, 1)
					pass


## build network

In [12]:
generator       = Generator(dim_latent, 1, 128).to(device)
discriminator   = Discriminator(dim_channel, 1, 128).to(device)

optimizer_generator     = torch.optim.Adam(generator.parameters(), lr=learning_rate_generator, betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate_discriminator, betas=(0.5, 0.999))

initialize model parameters : xavier_uniform
initialize model parameters : xavier_uniform


## compute the prediction

In [13]:
def compute_prediction(model, input):

    prediction = model(input)

    return prediction

## compute the loss

In [14]:
def compute_loss_discriminator(generator, discriminator, latent, data_real):

    data_fake       = compute_prediction(generator, latent)
    prediction_real = compute_prediction(discriminator, data_real)
    prediction_fake = compute_prediction(discriminator, data_fake)

    criterion   = nn.BCEWithLogitsLoss()
    
    label_real  = torch.ones_like(prediction_real)
    label_fake  = torch.zeros_like(prediction_fake)

    # ==================================================
    # fill up the blank
    #    
    loss_real = criterion(prediction_real, label_real)
    loss_fake = criterion(prediction_fake, label_fake)
    # 
    # ==================================================

    loss_discriminator = (loss_real + loss_fake) / 2.0

    return loss_discriminator

In [15]:
def compute_loss_generator(generator, discriminator, latent):

    data_fake       = compute_prediction(generator, latent)
    prediction_fake = compute_prediction(discriminator, data_fake)

    criterion       = nn.BCEWithLogitsLoss()

    label_real      = torch.ones_like(prediction_fake)

    # ==================================================
    # fill up the blank
    #    
    loss_generator  = criterion(prediction_fake, label_real)
    # 
    # ==================================================

    return loss_generator

## compute the accuracy

In [16]:
def get_center_index(binary_image):
    
    area_square = np.sum(binary_image)

    height = binary_image.shape[0]
    width = binary_image.shape[1]

    x = np.linspace(0, width - 1, width)
    y = np.linspace(0, height - 1, height)
    indices_X, indices_Y = np.meshgrid(x, y)

    x_mean = np.sum(binary_image * indices_X) / area_square
    y_mean = np.sum(binary_image * indices_Y) / area_square

    return (x_mean, y_mean)

In [17]:
# create ideal square image which has the same area to the input image
def create_label(binary_images):
    
    label = np.zeros_like(binary_images)
    
    for i, binary_image in enumerate(binary_images):
        
        image_height = binary_image.shape[0]
        image_width = binary_image.shape[1]

        square_image = np.zeros((image_height, image_width))
        square_length = np.round(np.sqrt(np.sum(binary_image)))

        if square_length == 0:
            # when there is no square
            return square_image

        (square_center_x, square_center_y) = get_center_index(binary_image)

        if square_center_x < 0 or square_center_x > image_width - 1 or square_center_y < 0 or square_center_y > image_height - 1:
            return square_image

        top = np.ceil(square_center_y - square_length / 2)
        bottom = np.floor(square_center_y + square_length / 2)
        left = np.ceil(square_center_x - square_length / 2)
        right = np.floor(square_center_x + square_length / 2)

        top = int(top) if top >= 0 else 0
        bottom = int(bottom) if bottom <= image_height - 1 else image_height - 1
        left = int(left) if left >= 0 else 0
        right = int(right) if right <= image_width - 1 else image_width - 1

        square_image[top : bottom + 1, left : right + 1] = 1
        
        label[i] = square_image
        
    return label

In [18]:
def compute_accuracy(prediction):

    prediction  = prediction.squeeze(axis=1)
    
    prediction_binary   = (prediction >= 0.5).cpu().numpy().astype(int)
    label               = create_label(prediction_binary).astype(int)
    
    region_intersection = prediction_binary & label
    region_union        = prediction_binary | label

    area_intersection   = region_intersection.sum(axis=1).sum(axis=1).astype(float)
    area_union          = region_union.sum(axis=1).sum(axis=1).astype(float)

    eps         = np.finfo(float).eps
    correct     = area_intersection / (area_union + eps)
    accuracy    = correct.mean() * 100.0
    
    return accuracy

## variables for the learning curve

In [19]:
loss_generator_mean     = np.zeros(number_epoch)
loss_generator_std      = np.zeros(number_epoch)
loss_discriminator_mean = np.zeros(number_epoch)
loss_discriminator_std  = np.zeros(number_epoch)

accuracy_mean   = np.zeros(number_epoch)
accuracy_std    = np.zeros(number_epoch)

## train

In [20]:
def train(generator, discriminator, dataloader):

    loss_epoch_generator      = []
    loss_epoch_discriminator  = []
    accuracy_epoch = []
    
    for index_batch, data_real in enumerate(dataloader):

        size_batch  = len(data_real)
        data_real   = data_real.to(device)
        
        latent  = torch.randn(size_batch, dim_latent, device=device)
        latent  = torch.reshape(latent, [size_batch, dim_latent, 1, 1])

        # ---------------------------------------------------------------------------
        #  
        # update the generator
        #  
        # ---------------------------------------------------------------------------
        generator.train()
        discriminator.eval()

        optimizer_generator.zero_grad()
        loss_generator = compute_loss_generator(generator, discriminator, latent)
        loss_generator.backward()
        optimizer_generator.step()

        # ---------------------------------------------------------------------------
        #  
        # update the discriminator
        #  
        # ---------------------------------------------------------------------------
        generator.eval()
        discriminator.train()

        optimizer_discriminator.zero_grad()
        loss_discriminator = compute_loss_discriminator(generator, discriminator, latent, data_real)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        data_fake   = compute_prediction(generator, latent)
        accuracy    = compute_accuracy(data_fake)

        loss_epoch_generator.append(loss_generator.item())
        loss_epoch_discriminator.append(loss_discriminator.item())
        accuracy_epoch.append(accuracy)

    loss_generator_mean_epoch       = np.mean(loss_epoch_generator)
    loss_generator_std_epoch        = np.std(loss_epoch_generator)
    
    loss_discriminator_mean_epoch   = np.mean(loss_epoch_discriminator)
    loss_discriminator_std_epoch    = np.std(loss_epoch_discriminator)

    accuracy_mean_epoch             = np.mean(accuracy_epoch)
    accuracy_std_epoch              = np.std(accuracy_epoch)

    loss_value_generator        = {'mean' : loss_generator_mean_epoch, 'std' : loss_generator_std_epoch}
    loss_value_discriminator    = {'mean' : loss_discriminator_mean_epoch, 'std' : loss_discriminator_std_epoch}
    accuracy_value              = {'mean' : accuracy_mean_epoch, 'std' : accuracy_std_epoch} 

    return loss_value_generator, loss_value_discriminator, accuracy_value


## training epoch

In [None]:
# ================================================================================
# 
# iterations for epochs
#
# ================================================================================
for i in tqdm(range(number_epoch)):
    
    # ================================================================================
    # 
    # training
    #
    # ================================================================================
    (loss_value_generator, loss_value_discriminator, accuracy_value) = train(generator, discriminator, dataloader_real)

    loss_generator_mean[i]      = loss_value_generator['mean']
    loss_generator_std[i]       = loss_value_generator['std']

    loss_discriminator_mean[i]  = loss_value_discriminator['mean']
    loss_discriminator_std[i]   = loss_value_discriminator['std']

    accuracy_mean[i]            = accuracy_value['mean']
    accuracy_std[i]             = accuracy_value['std']

    print(f"epoch : {i}")
    print(f"\tloss_value_discriminator : {loss_value_discriminator['mean']}, acc mean : {accuracy_value['mean']}")

  1%|          | 1/150 [00:17<43:12, 17.40s/it]

epoch : 0
	loss_value_discriminator : 3.787472832741694e+27, acc mean : 79.92073339719686


  1%|▏         | 2/150 [00:34<42:53, 17.39s/it]

epoch : 1
	loss_value_discriminator : 3.8035652094314066e+27, acc mean : 78.47367714126277


  2%|▏         | 3/150 [00:52<42:33, 17.37s/it]

epoch : 2
	loss_value_discriminator : 3.793951161094714e+27, acc mean : 78.14854700893792


  3%|▎         | 4/150 [01:09<42:16, 17.38s/it]

epoch : 3
	loss_value_discriminator : 3.8037876548572454e+27, acc mean : 78.47878571245688


  3%|▎         | 5/150 [01:26<41:58, 17.37s/it]

epoch : 4
	loss_value_discriminator : 3.797497737258254e+27, acc mean : 78.32559169319208


  4%|▍         | 6/150 [01:44<41:41, 17.37s/it]

epoch : 5
	loss_value_discriminator : 3.7971870700643856e+27, acc mean : 78.12280326593363


  5%|▍         | 7/150 [02:01<41:22, 17.36s/it]

epoch : 6
	loss_value_discriminator : 3.802781605470965e+27, acc mean : 78.3079545107816


  5%|▌         | 8/150 [02:18<41:04, 17.36s/it]

epoch : 7
	loss_value_discriminator : 3.798495512207332e+27, acc mean : 78.12529338465868


  6%|▌         | 9/150 [02:36<40:48, 17.36s/it]

epoch : 8
	loss_value_discriminator : 3.7952502065520274e+27, acc mean : 78.03716553343259


  7%|▋         | 10/150 [02:53<40:30, 17.36s/it]

epoch : 9
	loss_value_discriminator : 3.8021780520285403e+27, acc mean : 78.45901970452051


  7%|▋         | 11/150 [03:11<40:12, 17.36s/it]

epoch : 10
	loss_value_discriminator : 3.796645048096286e+27, acc mean : 77.87473927385354


  8%|▊         | 12/150 [03:28<39:54, 17.35s/it]

epoch : 11
	loss_value_discriminator : 3.7995992315265595e+27, acc mean : 78.14675859646671


  9%|▊         | 13/150 [03:45<39:38, 17.36s/it]

epoch : 12
	loss_value_discriminator : 3.802845971737961e+27, acc mean : 78.36370879022097


  9%|▉         | 14/150 [04:03<39:20, 17.35s/it]

epoch : 13
	loss_value_discriminator : 3.799107429317721e+27, acc mean : 78.01411222880671


 10%|█         | 15/150 [04:20<39:02, 17.35s/it]

epoch : 14
	loss_value_discriminator : 3.7980098223056925e+27, acc mean : 78.10894702767304


 11%|█         | 16/150 [04:37<38:45, 17.36s/it]

epoch : 15
	loss_value_discriminator : 3.799341835097624e+27, acc mean : 78.37989938668267


 11%|█▏        | 17/150 [04:55<38:30, 17.38s/it]

epoch : 16
	loss_value_discriminator : 3.799735775184181e+27, acc mean : 78.2075001590864


 12%|█▏        | 18/150 [05:12<38:12, 17.37s/it]

epoch : 17
	loss_value_discriminator : 3.7955166461754993e+27, acc mean : 78.13009507667893


 13%|█▎        | 19/150 [05:30<37:58, 17.40s/it]

epoch : 18
	loss_value_discriminator : 3.7991058849391474e+27, acc mean : 78.12131120359561


 13%|█▎        | 20/150 [05:47<37:40, 17.39s/it]

epoch : 19
	loss_value_discriminator : 3.801358827834286e+27, acc mean : 78.32067218670976


 14%|█▍        | 21/150 [06:04<37:21, 17.37s/it]

epoch : 20
	loss_value_discriminator : 3.796165964702989e+27, acc mean : 78.1284974146758


 15%|█▍        | 22/150 [06:22<37:01, 17.36s/it]

epoch : 21
	loss_value_discriminator : 3.801717559521317e+27, acc mean : 78.1302627588273


 15%|█▌        | 23/150 [06:39<36:44, 17.36s/it]

epoch : 22
	loss_value_discriminator : 3.8016126207132165e+27, acc mean : 78.34027453854871


 16%|█▌        | 24/150 [06:56<36:27, 17.36s/it]

epoch : 23
	loss_value_discriminator : 3.799239326111812e+27, acc mean : 78.25048587588003


 17%|█▋        | 25/150 [07:14<36:07, 17.34s/it]

epoch : 24
	loss_value_discriminator : 3.7992857192441633e+27, acc mean : 78.16683665210479


 17%|█▋        | 26/150 [07:31<35:49, 17.34s/it]

epoch : 25
	loss_value_discriminator : 3.7949379435003e+27, acc mean : 77.72213882664161


 18%|█▊        | 27/150 [07:48<35:31, 17.33s/it]

epoch : 26
	loss_value_discriminator : 3.798214665247744e+27, acc mean : 78.06175989520128


 19%|█▊        | 28/150 [08:06<35:13, 17.33s/it]

epoch : 27
	loss_value_discriminator : 3.7997563119872576e+27, acc mean : 78.26486561961656


 19%|█▉        | 29/150 [08:23<34:56, 17.33s/it]

epoch : 28
	loss_value_discriminator : 3.8016557294671346e+27, acc mean : 78.23909634568427


 20%|██        | 30/150 [08:40<34:39, 17.33s/it]

epoch : 29
	loss_value_discriminator : 3.797225593729916e+27, acc mean : 78.29650239403443


 21%|██        | 31/150 [08:57<34:21, 17.32s/it]

epoch : 30
	loss_value_discriminator : 3.7993989462172763e+27, acc mean : 78.06288571167391


 21%|██▏       | 32/150 [09:15<34:02, 17.31s/it]

epoch : 31
	loss_value_discriminator : 3.799015470153543e+27, acc mean : 78.4214882652731


 22%|██▏       | 33/150 [09:32<33:45, 17.31s/it]

epoch : 32
	loss_value_discriminator : 3.8014197552849913e+27, acc mean : 78.27898707920836


 23%|██▎       | 34/150 [09:49<33:28, 17.32s/it]

epoch : 33
	loss_value_discriminator : 3.8013532028643257e+27, acc mean : 78.38033417126753


 23%|██▎       | 35/150 [10:07<33:11, 17.31s/it]

epoch : 34
	loss_value_discriminator : 3.8002510313556243e+27, acc mean : 78.11809037537824


 24%|██▍       | 36/150 [10:24<32:53, 17.32s/it]

epoch : 35
	loss_value_discriminator : 3.797984823964514e+27, acc mean : 78.37082304467297


 25%|██▍       | 37/150 [10:41<32:36, 17.32s/it]

epoch : 36
	loss_value_discriminator : 3.8045236095907625e+27, acc mean : 78.64069143283503


 25%|██▌       | 38/150 [10:59<32:19, 17.31s/it]

epoch : 37
	loss_value_discriminator : 3.795948138685062e+27, acc mean : 78.07137966324991


 26%|██▌       | 39/150 [11:16<32:01, 17.31s/it]

epoch : 38
	loss_value_discriminator : 3.7993466638546305e+27, acc mean : 78.59627849672037


 27%|██▋       | 40/150 [11:33<31:44, 17.31s/it]

epoch : 39
	loss_value_discriminator : 3.8021052465906275e+27, acc mean : 78.25139371028116


 27%|██▋       | 41/150 [11:51<31:27, 17.32s/it]

epoch : 40
	loss_value_discriminator : 3.8011771780264576e+27, acc mean : 78.29286115105914


 28%|██▊       | 42/150 [12:08<31:09, 17.31s/it]

epoch : 41
	loss_value_discriminator : 3.796972535288797e+27, acc mean : 78.36226347853771


 29%|██▊       | 43/150 [12:25<30:51, 17.31s/it]

epoch : 42
	loss_value_discriminator : 3.8021542514387444e+27, acc mean : 78.38821345472299


 29%|██▉       | 44/150 [12:43<30:34, 17.31s/it]

epoch : 43
	loss_value_discriminator : 3.8032412296942343e+27, acc mean : 78.60501874602392


 30%|███       | 45/150 [13:00<30:16, 17.30s/it]

epoch : 44
	loss_value_discriminator : 3.797978011539028e+27, acc mean : 78.44166162489304


 31%|███       | 46/150 [13:17<29:59, 17.30s/it]

epoch : 45
	loss_value_discriminator : 3.797427155725487e+27, acc mean : 78.04098902411246


 31%|███▏      | 47/150 [13:34<29:42, 17.30s/it]

epoch : 46
	loss_value_discriminator : 3.799447559822821e+27, acc mean : 78.2341110424681


 32%|███▏      | 48/150 [13:52<29:24, 17.30s/it]

epoch : 47
	loss_value_discriminator : 3.801929836072237e+27, acc mean : 78.31670825789337


 33%|███▎      | 49/150 [14:09<29:07, 17.30s/it]

epoch : 48
	loss_value_discriminator : 3.799567177091276e+27, acc mean : 78.45399578854145


 33%|███▎      | 50/150 [14:26<28:50, 17.31s/it]

epoch : 49
	loss_value_discriminator : 3.8007939799508684e+27, acc mean : 78.55137944434945


 34%|███▍      | 51/150 [14:44<28:33, 17.31s/it]

epoch : 50
	loss_value_discriminator : 3.7956033681803364e+27, acc mean : 78.1878268166173


 35%|███▍      | 52/150 [15:01<28:17, 17.32s/it]

epoch : 51
	loss_value_discriminator : 3.8031309439043166e+27, acc mean : 78.36833217465976


 35%|███▌      | 53/150 [15:18<28:00, 17.33s/it]

epoch : 52
	loss_value_discriminator : 3.8007230311991964e+27, acc mean : 78.383281317684


 36%|███▌      | 54/150 [15:36<27:42, 17.32s/it]

epoch : 53
	loss_value_discriminator : 3.800720954868003e+27, acc mean : 78.22365299963593


 37%|███▋      | 55/150 [15:53<27:25, 17.32s/it]

epoch : 54
	loss_value_discriminator : 3.800070301311035e+27, acc mean : 78.24881781399691


 37%|███▋      | 56/150 [16:10<27:08, 17.32s/it]

epoch : 55
	loss_value_discriminator : 3.795889562121741e+27, acc mean : 78.08033964469504


 38%|███▊      | 57/150 [16:28<26:50, 17.32s/it]

epoch : 56
	loss_value_discriminator : 3.8019528129934594e+27, acc mean : 78.32524997279167


 39%|███▊      | 58/150 [16:45<26:33, 17.32s/it]

epoch : 57
	loss_value_discriminator : 3.7983394407406346e+27, acc mean : 77.85147760443259


 39%|███▉      | 59/150 [17:02<26:16, 17.32s/it]

epoch : 58
	loss_value_discriminator : 3.7982968124600504e+27, acc mean : 78.31947817483797


 40%|████      | 60/150 [17:20<25:59, 17.32s/it]

epoch : 59
	loss_value_discriminator : 3.799323607998503e+27, acc mean : 78.32046447415296


 41%|████      | 61/150 [17:37<25:42, 17.33s/it]

epoch : 60
	loss_value_discriminator : 3.797314529344042e+27, acc mean : 78.14087540497425


 41%|████▏     | 62/150 [17:54<25:24, 17.32s/it]

epoch : 61
	loss_value_discriminator : 3.800891941599769e+27, acc mean : 78.45353706852649


 42%|████▏     | 63/150 [18:12<25:07, 17.33s/it]

epoch : 62
	loss_value_discriminator : 3.8017161215332675e+27, acc mean : 78.39660798006953


 43%|████▎     | 64/150 [18:29<24:50, 17.33s/it]

epoch : 63
	loss_value_discriminator : 3.79790004787668e+27, acc mean : 78.10313420779033


 43%|████▎     | 65/150 [18:46<24:32, 17.32s/it]

epoch : 64
	loss_value_discriminator : 3.7993544200670227e+27, acc mean : 78.35368444580325


 44%|████▍     | 66/150 [19:04<24:14, 17.32s/it]

epoch : 65
	loss_value_discriminator : 3.801544005689167e+27, acc mean : 78.174462464402


 45%|████▍     | 67/150 [19:21<23:57, 17.32s/it]

epoch : 66
	loss_value_discriminator : 3.797629819377774e+27, acc mean : 78.24511420007806


 45%|████▌     | 68/150 [19:38<23:40, 17.32s/it]

epoch : 67
	loss_value_discriminator : 3.7989449332361577e+27, acc mean : 78.10207221736135


 46%|████▌     | 69/150 [19:55<23:22, 17.32s/it]

epoch : 68
	loss_value_discriminator : 3.7960775747692926e+27, acc mean : 78.17389470221292


 47%|████▋     | 70/150 [20:13<23:05, 17.32s/it]

epoch : 69
	loss_value_discriminator : 3.8041713814535023e+27, acc mean : 78.6309329263675


 47%|████▋     | 71/150 [20:30<22:47, 17.31s/it]

epoch : 70
	loss_value_discriminator : 3.7987592920677046e+27, acc mean : 78.10271424361306


 48%|████▊     | 72/150 [20:47<22:30, 17.31s/it]

epoch : 71
	loss_value_discriminator : 3.796773705127325e+27, acc mean : 78.02850547315595


 49%|████▊     | 73/150 [21:05<22:12, 17.31s/it]

epoch : 72
	loss_value_discriminator : 3.8012199367212325e+27, acc mean : 78.09452952536601


 49%|████▉     | 74/150 [21:22<21:55, 17.31s/it]

epoch : 73
	loss_value_discriminator : 3.7996899140044493e+27, acc mean : 78.0841446205514


---

# functions for visualizing the results 

---

## plot curve

In [None]:
def plot_image_grid(data, nRow, nCol, filename=None):

    size_col = 1.5
    size_row = 1.5

    fig, axes = plt.subplots(nRow, nCol, constrained_layout=True, figsize=(nCol * size_col, nRow * size_row))
    
    data = data.detach().cpu()

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            image   = np.squeeze(data[k], axis=0)

            axes[i, j].imshow(image, cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

    if filename is not None:

        fig.savefig(filename)
        pass

In [None]:
def plot_data_grid(data, index_data, nRow, nCol):
    
    size_col = 1.5
    size_row = 1.5

    fig, axes = plt.subplots(nRow, nCol, constrained_layout=True, figsize=(nCol * size_col, nRow * size_row))

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_data_tensor_grid(data, index_data, nRow, nCol):
    
    size_col = 1.5
    size_row = 1.5

    fig, axes = plt.subplots(nRow, nCol, constrained_layout=True, figsize=(nCol * size_col, nRow * size_row))

    data = data.detach().cpu().squeeze(axis=1)

    for i in range(nRow):
        for j in range(nCol):

            k       = i * nCol + j
            index   = index_data[k]

            axes[i, j].imshow(data[index], cmap='gray', vmin=0, vmax=1)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    plt.show()

In [None]:
def plot_curve_error(data_mean, data_std, x_label, y_label, title, filename=None):

    fig = plt.figure(figsize=(8, 6))
    plt.title(title)

    alpha = 0.3
    
    plt.plot(range(len(data_mean)), data_mean, '-', color = 'red')
    plt.fill_between(range(len(data_mean)), data_mean - data_std, data_mean + data_std, facecolor = 'blue', alpha = alpha) 
    
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.tight_layout()
    plt.show()

    if filename is not None:

        fig.savefig(filename)
        pass

In [None]:
def print_curve(data, index):
    
    for i in range(len(index)):

        idx = index[i]
        val = data[idx]

        print('index = %2d, value = %12.10f' % (idx, val))

In [None]:
def get_data_last(data, index_start):

    data_last = data[index_start:]

    return data_last

In [None]:
def get_max_last_range(data, index_start):

    data_range = get_data_last(data, index_start)
    value = data_range.max()

    return value

In [None]:
def get_min_last_range(data, index_start):

    data_range = get_data_last(data, index_start)
    value = data_range.min()

    return value

---

# functions for presenting the results

---

In [None]:
def function_result_01():

    print('[plot examples of the real images]')
    print('') 

    nRow = 8
    nCol = 6

    number_data = len(dataset_real)
    step        = int(np.floor(number_data / (nRow * nCol)))
    index_data  = np.arange(0, number_data, step)
    index_plot  = np.arange(0, nRow * nCol)

    data = dataset_real[index_data]
    data = data[0]
    
    plot_data_grid(data, index_plot, nRow, nCol)

In [None]:
def function_result_02():

    print('[plot examples of the fake images]')
    print('') 

    nRow = 8
    nCol = 6
    number_latent = nRow * nCol

    latent  = torch.randn(number_latent, dim_latent, device=device)
    latent  = torch.reshape(latent, [number_latent, dim_latent, 1, 1])

    generator.eval()

    data_fake   = generator(latent)
    filename    = 'fake_image.png'

    plot_image_grid(data_fake, nRow, nCol, filename)

In [None]:
def function_result_03():

    print('[plot the generator loss]')
    print('') 

    plot_curve_error(loss_generator_mean, loss_generator_std, 'epoch', 'loss', 'generator loss', 'loss_generator.png')

In [None]:
def function_result_04():
    
    print('[plot the discriminator loss]')
    print('') 
    
    plot_curve_error(loss_discriminator_mean, loss_discriminator_std, 'epoch', 'loss', 'discriminator loss', 'loss_discriminator.png')

In [None]:
def function_result_05():
    
    print('[plot the accuracy]')
    print('') 
    
    plot_curve_error(accuracy_mean, accuracy_std, 'epoch', 'accuracy', 'training accuracy', 'training_accuracy.png')

In [None]:
def function_result_06():
    
    print('[print the generator loss at the last 10 epochs]')
    print('') 

    data_last = get_data_last(loss_generator_mean, -10)
    index = np.arange(0, 10)
    print_curve(data_last, index)

In [None]:
def function_result_07():
    
    print('[print the discriminator loss at the last 10 epochs]')
    print('') 

    data_last = get_data_last(loss_discriminator_mean, -10)
    index = np.arange(0, 10)
    print_curve(data_last, index)

In [None]:
def function_result_08():
    
    print('[print the accuracy at the last 10 epochs]')
    print('') 

    data_last = get_data_last(accuracy_mean, -10)
    index = np.arange(0, 10)
    print_curve(data_last, index)

In [None]:
def function_result_09():
    
    print('[print the best accuracy within the last 10 epochs]')
    print('') 
    
    value = get_max_last_range(accuracy_mean, -10)
    print('best accuracy = %12.10f' % (value))

---

# RESULTS

---

In [None]:
number_result = 9

for i in range(number_result):

    title           = '# RESULT # {:02d}'.format(i+1) 
    name_function   = 'function_result_{:02d}()'.format(i+1)

    print('') 
    print('################################################################################')
    print('#') 
    print(title)
    print('#') 
    print('################################################################################')
    print('') 

    eval(name_function)