In [5]:
import os
import numpy as np
from PIL import Image

def is_jpeg(filename):
    return any(filename.endswith(extension) for extension in [".jpg", ".jpeg", ".png"])

def get_subdirs(directory):
    subdirs = sorted([os.path.join(directory, name) for name in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, name))])
    return subdirs

class ExternalInputIterator:
    def __init__(self, imageset_dir, batch_size, random_shuffle=False):
        self.imageset_dir = imageset_dir
        self.batch_size = batch_size

        # Get subdirectories (assuming "pose" and "frontal" folders exist)
        #self.pose_dirs = get_subdirs(os.path.join(imageset_dir, "pose"))
        self.pose_dirs = os.path.join(imageset_dir, "pose")
        self.frontal_dir = os.path.join(imageset_dir, "frontal")
        print(self.frontal_dir)
        print(self.pose_dirs)

        # Collect profile image paths
        self.profile_files = []
        #for pose_dir in self.pose_dirs:
        profile_files = [os.path.join(self.pose_dirs, file) for file in sorted(os.listdir(self.pose_dirs)) if is_jpeg(file)]
        self.profile_files.extend(profile_files)
        print(len(self.profile_files))

        # Collect frontal image paths
        self.frontal_files = [os.path.join(self.frontal_dir, file) for file in sorted(os.listdir(self.frontal_dir)) if is_jpeg(file)]
        print(len(self.frontal_files))

        # Shuffle if necessary
        if random_shuffle:
            np.random.shuffle(self.profile_files)
            np.random.shuffle(self.frontal_files)

        self.i = 0
        self.n = len(self.profile_files)

    def __iter__(self):
        return self

    def __next__(self):
        profiles = []
        frontals = []

        for _ in range(self.batch_size):
            profile_filename = self.profile_files[self.i]
            frontal_filename = self.match_frontal_image(profile_filename)

            with Image.open(profile_filename) as profile_img:
                profiles.append(np.array(profile_img))
            with Image.open(frontal_filename) as frontal_img:
                frontals.append(np.array(frontal_img))

            self.i = (self.i + 1) % self.n

        return (profiles, frontals)

    def match_frontal_image(self, profile_filename):
        profile_name = os.path.basename(profile_filename).split("_")[0]
        for frontal_file in self.frontal_files:
            if profile_name in frontal_file:
                return frontal_file
        return None

class ImagePipeline:
    def __init__(self, imageset_dir, image_size=128, random_shuffle=False, batch_size=64,device_id = 0):
        self.eii = ExternalInputIterator(imageset_dir, batch_size, random_shuffle)
        self.iterator = iter(self.eii)
        self.num_inputs = len(self.eii.profile_files)
        self.image_size = image_size

    def epoch_size(self, name=None):
        return self.num_inputs

    def __len__(self):
        return self.num_inputs

    def __iter__(self):
        return self

    def __next__(self):
        (images, targets) = next(self.iterator)

        # Perform resizing and normalization using NumPy
        resized_images = np.array([np.array(Image.fromarray(img).resize((self.image_size, self.image_size))) for img in images])
        resized_targets = np.array([np.array(Image.fromarray(target).resize((self.image_size, self.image_size))) for target in targets])

        # Calculate mean and standard deviation for each channel separately
        #mean = np.array([0.5, 0.5, 0.5])  # Assuming RGB images have pixel values in [0, 255] range
        #std = np.array([0.5, 0.5, 0.5])   # Assuming RGB images have pixel values in [0, 255] range
        
        # Normalize each channel independently
        #normalized_images = (resized_images / 255.0 - mean) / std
        #normalized_targets = (resized_targets / 255.0 - mean) / std


        # Normalize using mean and standard deviation
        normalized_images = (resized_images - 128.0) / 128.0
        normalized_targets = (resized_targets - 128.0) / 128.0

        return (normalized_images, normalized_targets)

    def __getitem__(self, index):
        # Advance the iterator to the desired index
        for _ in range(index):
            next(self.iterator)

        # Return the next batch
        return next(self)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    classname = m.__class__.__name__
    
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [3]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_pool = self.avg_pool(x)
        max_pool = self.max_pool(x)
        avg_out = self.fc(avg_pool)
        max_out = self.fc(max_pool)
        out = avg_out + max_out
        return out * x

class G(nn.Module):
    def __init__(self):
        super(G, self).__init__()
        # Encoder
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),  # 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            ChannelAttention(64)  # Add channel attention
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),         # 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            ChannelAttention(128)  # Add channel attention
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),        # 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            ChannelAttention(256)  # Add channel attention
        )
        self.encoder4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),        # 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )
        self.encoder5 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),        # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )
        self.encoder6 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),        # 2x2
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, 4, 2, 1),       # 1x1
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 2x2
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )

        # Decoder
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),   # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1),   # 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),   # 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.decoder5 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),    # 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.decoder6 = nn.Sequential(
            nn.ConvTranspose2d(128, 1, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def forward(self, x):
        # Encoding
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)
        enc6 = self.encoder6(enc5)

        # Bottleneck
        bottleneck = self.bottleneck(enc6)

        # Decoding and adding skip connection
        dec1 = self.decoder1(torch.cat([bottleneck, enc6], dim=1))
        dec2 = self.decoder2(torch.cat([dec1, enc5], dim=1))
        dec3 = self.decoder3(torch.cat([dec2, enc4], dim=1))
        dec4 = self.decoder4(torch.cat([dec3, enc3], dim=1))
        dec5 = self.decoder5(torch.cat([dec4, enc2], dim=1))
        decoded = self.decoder6(torch.cat([dec5, enc1], dim=1))

        return decoded

# Example usage:
# generator = G()
# generator.apply(weights_init)
# print(generator)


In [6]:
''' Discriminator network for 128x128 RGB images '''
class D(nn.Module):
    
    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
                                  nn.Conv2d(1, 16, 4, 2, 1),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(16, 32, 4, 2, 1),
                                  nn.BatchNorm2d(32),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(32, 64, 4, 2, 1),
                                  nn.BatchNorm2d(64),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(64, 128, 4, 2, 1),
                                  nn.BatchNorm2d(128),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(128, 256, 4, 2, 1),
                                  nn.BatchNorm2d(256),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(256, 512, 4, 2, 1),
                                  nn.BatchNorm2d(512),
                                  nn.LeakyReLU(0.2, inplace = True),
                                  nn.Conv2d(512, 1, 4, 2, 1, bias = False),
                                  #nn.Sigmoid()
                                  )
    
    
    def forward(self, input):
        output = self.main(input)
        return output.view(-1)

In [15]:
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from tqdm import tqdm
from torch.autograd import Variable
from skimage.metrics import structural_similarity as ssim

# Define a function to calculate PSNR
def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()

# Define a function to calculate SSIM
# Define a function to calculate SSIM
def calculate_ssim(img1, img2):
    # Ensure tensors are on the same device
    if img1.device != img2.device:
        raise ValueError("Input tensors must be on the same device")

    # Calculate SSIM directly on GPU tensors
    img1 = img1.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    img2 = img2.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    return ssim(img1.transpose(1, 2, 0), img2.transpose(1, 2, 0), multichannel=True, data_range=1)


# Define lists to store PSNR and SSIM values for each epoch
psnr_values = []
ssim_values = []

In [16]:
from __future__ import print_function
import time
import math
import random
import os
from os import listdir
from os.path import join
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

#from nvidia.dali.plugin.pytorch import DALIGenericIterator

#from data import ImagePipeline
#import network

np.random.seed(42)
random.seed(10)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(999)
torch.cuda.manual_seed(999)
# Where is your training dataset at?
datapath =r"C:\Users\zed\Dataset\\Grayscale_Dataset"

# You can also choose which GPU you want your model to be trained on below:
gpu_id = 0
device = torch.device("cuda", gpu_id)

"""train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=30, device_id=gpu_id)
train_pipe.build()
m_train = train_pipe.epoch_size()
print("Size of the training set: ", m_train)
train_pipe_loader = DALIGenericIterator(train_pipe, ["profiles", "frontals"], m_train)"""
# Assuming you have the modified ImagePipeline class from the previous responses
train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=32, device_id=gpu_id)
# No need to call build() without DALI

# Use a standard PyTorch DataLoader instead of DALIGenericIterator
#train_pipe_loader = DataLoader(train_pipe, batch_size=train_pipe.batch_size)
m_train = train_pipe.epoch_size()
#train_pipe_loader = DataLoader(train_pipe, batch_size=32,)
train_pipe_loader = DataLoader(train_pipe,)
# Generator:
#netG = network.G().to(device)
#netG.apply(network.weights_init)
netG = G().to(device)
netG.apply(weights_init)

# Discriminator:
#netD = network.D().to(device)
#netD.apply(network.weights_init)
netD = D().to(device)
netD.apply(weights_init)

# Here is where you set how important each component of the loss function is:
L1_factor = 0
L2_factor = 1
GAN_factor = 0.005

#criterion = nn.BCELoss() # Binary cross entropy loss
criterion = nn.BCEWithLogitsLoss()

# Optimizers for the generator and the discriminator (Adam is a fancier version of gradient descent with a few more bells and whistles that is used very often):
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999), eps = 1e-8)

# Create a directory for the output files
try:
    os.mkdir('output')
except OSError:
    pass

start_time = time.time()

# Let's train for 30 epochs (meaning, we go through the entire training set 30 times):
for epoch in range(3):
    
    # Lets keep track of the loss values for each epoch:
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0  # Total PSNR for this epoch
    total_ssim = 0  # Total SSIM for this epoch
    
    # Your train_pipe_loader will load the images one batch at a time
    # The inner loop iterates over those batches:
    print("starting...")
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch+1}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
        
        # These are your images from the current batch:
        #profile = data[0]['profiles']
        #frontal = data[0]['frontals']
            profile = data[0].view(32, 1, 128, 128)
            frontal = data[1].view(32, 1, 128, 128)
        
        
        # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()
            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            output = netD(real)
        # D should accept the GT images
            errD_real = criterion(output.squeeze(), target)
        
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            generated = netG(profile)
            target = Variable(torch.zeros(real.size()[0])).to(device)
            output = netD(generated.detach()) # detach() because we are not training G here

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, real)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, real)
            total_ssim += ssim_val

            pbar.update(1)
        
        # D should reject the synthetic images
            errD_fake = criterion(output, target)
        
            errD = errD_real + errD_fake
            errD.backward()
        # Update D
            optimizerD.step()
        
        # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            target = Variable(torch.ones(real.size()[0])).to(device)
            output = netD(generated)
        
        # G wants to :
        # (a) have the synthetic images be accepted by D (= look like frontal images of people)
            errG_GAN = criterion(output, target)
        
        # (b) have the synthetic images resemble the ground truth frontal image
            errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
        
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
        
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()
        
            errG.backward()
        # Update G
            optimizerG.step()

    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    if epoch > 0:
        print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    # Print the absolute values of three losses to screen:
    print('[%d/30] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train,))

    # Print the PSNR and SSIM on each epoch
    print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'output/netG_%d.pt' % epoch)

C:\Users\zed\Dataset\\Grayscale_Dataset\frontal
C:\Users\zed\Dataset\\Grayscale_Dataset\pose
200
50
starting...


Epoch 1: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [13:16<00:00,  3.98s/it]


First training epoch completed in  796.5197060108185  seconds
[1/30] Training absolute losses: L1 0.1347296 ; L2 0.0400709 BCE 1.9535212
[1/30] Average PSNR: 16.61, Average SSIM: 0.7150
starting...


Epoch 2: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [10:15<00:00,  3.08s/it]


Epoch: 1 is starting..
[2/30] Training absolute losses: L1 0.0567451 ; L2 0.0053083 BCE 1.4513911
[2/30] Average PSNR: 22.90, Average SSIM: 0.9314
starting...


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:12<00:00,  2.76s/it]


Epoch: 2 is starting..
[3/30] Training absolute losses: L1 0.0450300 ; L2 0.0034333 BCE 1.7253960
[3/30] Average PSNR: 24.79, Average SSIM: 0.9558
