In [1]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Additional information about the CUDA device
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))


Using device: cuda
NVIDIA GeForce RTX 2070 SUPER


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    classname = m.__class__.__name__
    
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma * out + x
        return out

In [4]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_pool = self.avg_pool(x)
        max_pool = self.max_pool(x)
        avg_out = self.fc(avg_pool)
        max_out = self.fc(max_pool)
        out = avg_out + max_out
        return out * x


In [5]:
class G(nn.Module):
    def __init__(self):
        super(G, self).__init__()
        # Encoder
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),  # 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            ChannelAttention(64)  # Add channel attention
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),         # 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            ChannelAttention(128)  # Add channel attention
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),        # 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            ChannelAttention(256)  # Add channel attention
        )
        self.encoder4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),        # 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )
        self.encoder5 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),        # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )
        self.encoder6 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),        # 2x2
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            ChannelAttention(512)  # Add channel attention
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, 4, 2, 1),       # 1x1
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 2x2
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )

        # Decoder
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),   # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1),   # 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            SelfAttention(256)  # Add self-attention
        )
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),   # 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            SelfAttention(128)  # Add self-attention
        )
        self.decoder5 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),    # 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.decoder6 = nn.Sequential(
            nn.ConvTranspose2d(128, 1, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def forward(self, x):
        # Encoding
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)
        enc6 = self.encoder6(enc5)

        # Bottleneck
        bottleneck = self.bottleneck(enc6)

        # Decoding and adding skip connection
        dec1 = self.decoder1(torch.cat([bottleneck, enc6], dim=1))
        dec2 = self.decoder2(torch.cat([dec1, enc5], dim=1))
        dec3 = self.decoder3(torch.cat([dec2, enc4], dim=1))
        dec4 = self.decoder4(torch.cat([dec3, enc3], dim=1))
        dec5 = self.decoder5(torch.cat([dec4, enc2], dim=1))
        decoded = self.decoder6(torch.cat([dec5, enc1], dim=1))

        return decoded

# Example usage:
# generator = G()
# generator.apply(weights_init)
# print(generator)


In [6]:
import torch
from torch import nn

class RelativeAvgDiscriminator(nn.Module):
  def __init__(self):
    super(RelativeAvgDiscriminator, self).__init__()

    # Separate feature extraction for real and generated data
    self.conv_real = nn.Sequential(
        nn.Conv2d(1, 16, 4, 2, 1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(16, 32, 4, 2, 1),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(32, 64, 4, 2, 1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.conv_generated = nn.Sequential(
        nn.Conv2d(1, 16, 4, 2, 1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(16, 32, 4, 2, 1),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(32, 64, 4, 2, 1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Relative Average Pooling
    self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    # Remaining convolutional layers (modified for combined features)
    self.post_pool = nn.Sequential(
        nn.Conv2d(128, 128, 4, 2, 1),  # Input channels changed to 128
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(128, 256, 4, 2, 1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
       
    )

    # Output layer with sigmoid activation
    self.output = nn.Sigmoid()

  def forward(self, real, fake):
    # Extract features from real and generated data
    real_features = self.conv_real(real)
    generated_features = self.conv_generated(fake)

    # Concatenate features before pooling
    combined_features = torch.cat([real_features, generated_features], dim=1)

    # Relative Average Pooling
    features = self.avgpool(combined_features)

    # Process features with remaining layers
    output = self.post_pool(features)

    # Probability score
    #probability = self.output(logits)

    return output


In [7]:
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from tqdm import tqdm
from torch.autograd import Variable
from skimage.metrics import structural_similarity as ssim

# Define a function to calculate PSNR
def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()

# Define a function to calculate SSIM
# Define a function to calculate SSIM
def calculate_ssim(img1, img2):
    # Ensure tensors are on the same device
    if img1.device != img2.device:
        raise ValueError("Input tensors must be on the same device")

    # Calculate SSIM directly on GPU tensors
    img1 = img1.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    img2 = img2.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    return ssim(img1.transpose(1, 2, 0), img2.transpose(1, 2, 0), multichannel=True, data_range=1)


# Define lists to store PSNR and SSIM values for each epoch
psnr_values = []
ssim_values = []

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
from tqdm import tqdm
import os
import time

#gpu_id = 0
#device = torch.device("cuda", gpu_id)

netG = G().to(device)
netG.apply(weights_init)

netD = RelativeAvgDiscriminator().to(device)
netD.apply(weights_init)

L1_factor = 1
L2_factor = 1
GAN_factor = 0.005

#criterion = nn.BCEWithLogitsLoss()

optimizerD = optim.Adam(netD.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.002, betas=(0.5, 0.999), eps=1e-8)




In [9]:
# Lists to store the losses
losses_L1 = []
losses_L2 = []
losses_gan = []


In [10]:
import torch
import torch.nn.functional as F

def multi_scale_pixelwise_loss(fake_images, real_images, num_scales=3):
    loss = 0.0
    for scale in range(num_scales):
        fake_scaled = F.interpolate(fake_images, scale_factor=1 / (2 ** scale), mode='bilinear', align_corners=False)
        real_scaled = F.interpolate(real_images, scale_factor=1 / (2 ** scale), mode='bilinear', align_corners=False)
        pixel_loss = F.l1_loss(fake_scaled, real_scaled)
        loss += pixel_loss / (2 ** scale)
    return loss


In [11]:
# Initialize lists to store losses
generator_losses = []
discriminator_losses = []
multi_scale_losses = []

avg_generator_losses = []
avg_discriminator_losses = []
avg_multi_scale_losses = []

In [12]:
start_time = time.time()

In [13]:
import os
import numpy as np
from PIL import Image
import logging

# Set up logging
logging.basicConfig(filename='missing_files.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')

def is_jpeg(filename):
    return any(filename.endswith(extension) for extension in [".jpg", ".jpeg", ".png"])

def get_subdirs(directory):
    subdirs = sorted([os.path.join(directory, name) for name in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, name))])
    return subdirs

class ExternalInputIterator:
    def __init__(self, imageset_dir, batch_size, random_shuffle=False):
        self.imageset_dir = imageset_dir
        self.batch_size = batch_size

        # Get subdirectories (assuming "pose" and "frontal" folders exist)
        self.pose_dirs = os.path.join(imageset_dir, "pose")
        self.frontal_dir = os.path.join(imageset_dir, "frontal")
        print(self.frontal_dir)
        print(self.pose_dirs)

        # Collect profile image paths
        self.profile_files = [os.path.join(self.pose_dirs, file) for file in sorted(os.listdir(self.pose_dirs)) if is_jpeg(file)]
        print(len(self.profile_files))

        # Collect frontal image paths
        self.frontal_files = [os.path.join(self.frontal_dir, file) for file in sorted(os.listdir(self.frontal_dir)) if is_jpeg(file)]
        print(len(self.frontal_files))

        # Shuffle if necessary
        if random_shuffle:
            np.random.shuffle(self.profile_files)
            np.random.shuffle(self.frontal_files)

        self.i = 0
        self.n = len(self.profile_files)

    def __iter__(self):
        return self

    def __next__(self):
        profiles = []
        frontals = []

        for _ in range(self.batch_size):
            profile_filename = self.profile_files[self.i]
            frontal_filename = self.match_frontal_image(profile_filename)

            try:
                with Image.open(profile_filename) as profile_img:
                    profiles.append(np.array(profile_img))
            except FileNotFoundError:
                logging.error(f'Profile image not found: {profile_filename}')
                raise
            except Exception as e:
                logging.error(f'Error opening profile image {profile_filename}: {e}')
                raise

            if frontal_filename is None:
                logging.error(f'Matching frontal image not found for: {profile_filename}')
                raise FileNotFoundError(f'Matching frontal image not found for: {profile_filename}')
            try:
                with Image.open(frontal_filename) as frontal_img:
                    frontals.append(np.array(frontal_img))
            except FileNotFoundError:
                logging.error(f'Frontal image not found: {frontal_filename}')
                raise
            except Exception as e:
                logging.error(f'Error opening frontal image {frontal_filename}: {e}')
                raise

            self.i = (self.i + 1) % self.n

        return (profiles, frontals)

    def match_frontal_image(self, profile_filename):
        profile_name = os.path.basename(profile_filename).split("_")[1]
        for frontal_file in self.frontal_files:
            if profile_name in frontal_file:
                return frontal_file
        return None

class ImagePipeline:
    def __init__(self, imageset_dir, image_size=128, random_shuffle=False, batch_size=64, device=device):
        self.eii = ExternalInputIterator(imageset_dir, batch_size, random_shuffle)
        self.iterator = iter(self.eii)
        self.num_inputs = len(self.eii.profile_files)
        self.image_size = image_size

    def epoch_size(self, name=None):
        return self.num_inputs

    def __len__(self):
        return self.num_inputs

    def __iter__(self):
        return self

    def __next__(self):
        (images, targets) = next(self.iterator)

        # Perform resizing and normalization using NumPy
        resized_images = np.array([np.array(Image.fromarray(img).resize((self.image_size, self.image_size))) for img in images])
        resized_targets = np.array([np.array(Image.fromarray(target).resize((self.image_size, self.image_size))) for target in targets])

        # Normalize using mean and standard deviation
        normalized_images = (resized_images - 128.0) / 128.0
        normalized_targets = (resized_targets - 128.0) / 128.0

        return (normalized_images, normalized_targets)

    def __getitem__(self, index):
        # Advance the iterator to the desired index
        for _ in range(index):
            next(self.iterator)

        # Return the next batch
        return next(self)


In [14]:
from __future__ import print_function
import time
import math
import random
import os
from os import listdir
from os.path import join
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

#from nvidia.dali.plugin.pytorch import DALIGenericIterator

#from data import ImagePipeline
#import network

np.random.seed(42)
random.seed(10)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(999)
torch.cuda.manual_seed(999)
# Where is your training dataset at?
datapath =r"C:\Users\zed\Dataset\CAS_5000"

# You can also choose which GPU you want your model to be trained on below:
#gpu_id = 0
#device = torch.device("cuda", gpu_id)

#checkpoint_dir = "checkpoints"

"""train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=30, device_id=gpu_id)
train_pipe.build()
m_train = train_pipe.epoch_size()
print("Size of the training set: ", m_train)
train_pipe_loader = DALIGenericIterator(train_pipe, ["profiles", "frontals"], m_train)"""
# Assuming you have the modified ImagePipeline class from the previous responses
train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=1, device=device)
# No need to call build() without DALI

# Use a standard PyTorch DataLoader instead of DALIGenericIterator
#train_pipe_loader = DataLoader(train_pipe, batch_size=train_pipe.batch_size)
m_train = train_pipe.epoch_size()
#train_pipe_loader = DataLoader(train_pipe, batch_size=32,)
train_pipe_loader = DataLoader(train_pipe,batch_size=128,drop_last=True)

criterion = nn.BCEWithLogitsLoss().to(device)

C:\Users\zed\Dataset\CAS_5000\frontal
C:\Users\zed\Dataset\CAS_5000\pose
5248
323


In [15]:
checkpoint_dir1 = "FFRAD_CAS_67_Checkpoint"

In [None]:

for epoch in range(40):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/40] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 0: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:18:27<00:00, 290.42s/it]


First training epoch completed in  12311.57608294487  seconds
[1/40] Training absolute losses: L1 0.0033723 ; L2 0.0009225 BCE 0.0049726; Average PSNR: 9.52; Average SSIM: 0.2775


Epoch 1: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:11:25<00:00, 280.14s/it]


[2/40] Training absolute losses: L1 0.0026301 ; L2 0.0005976 BCE 0.0053839; Average PSNR: 11.18; Average SSIM: 0.4047


Epoch 2: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:12:02<00:00, 281.03s/it]


[3/40] Training absolute losses: L1 0.0024495 ; L2 0.0005318 BCE 0.0054906; Average PSNR: 11.68; Average SSIM: 0.4565


Epoch 3:  24%|█████████████████▌                                                      | 10/41 [12:09<54:41, 105.86s/it]

In [16]:
latest_epoch = 44
checkpoint_path = os.path.join(checkpoint_dir1, f"checkpoint_{latest_epoch}.pth")
checkpoint = torch.load(checkpoint_path)

# Load model and optimizer states
netG.load_state_dict(checkpoint['netG_state_dict'])
netD.load_state_dict(checkpoint['netD_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])

# Load training progress
loss_L1 = checkpoint['loss_L1']
loss_L2 = checkpoint['loss_L2']
loss_gan = checkpoint['loss_gan']
psnr_values = checkpoint['psnr_values']
ssim_values = checkpoint['ssim_values']
losses_L1 = checkpoint['losses_L1']
losses_L2 = checkpoint['losses_L2']
losses_gan = checkpoint['losses_gan']
discriminator_losses = checkpoint['discriminator_losses']
generator_losses = checkpoint['generator_losses']
multi_scale_losses = checkpoint['multi_scale_losses']
avg_generator_losses = checkpoint['avg_generator_losses']
avg_discriminator_losses = checkpoint['avg_discriminator_losses']
avg_multi_scale_losses = checkpoint['avg_multi_scale_losses']

# Start training from the loaded epoch
start_epoch = checkpoint['epoch'] + 1

In [None]:

for epoch in range(start_epoch,40):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/40] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 3: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:12:25<00:00, 281.61s/it]


[4/40] Training absolute losses: L1 0.0023149 ; L2 0.0004862 BCE 0.0055652; Average PSNR: 12.07; Average SSIM: 0.4877


Epoch 4: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:09:26<00:00, 277.23s/it]


[5/40] Training absolute losses: L1 0.0021452 ; L2 0.0004263 BCE 0.0055383; Average PSNR: 12.64; Average SSIM: 0.5357


Epoch 5: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:15:15<00:00, 285.74s/it]


[6/40] Training absolute losses: L1 0.0019310 ; L2 0.0003568 BCE 0.0055936; Average PSNR: 13.41; Average SSIM: 0.5890


Epoch 6:  88%|█████████████████████████████████████████████████████████████▍        | 36/41 [2:25:57<37:12, 446.51s/it]

In [None]:

for epoch in range(start_epoch,40):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/40] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 6: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:12:02<00:00, 281.05s/it]


[7/40] Training absolute losses: L1 0.0018586 ; L2 0.0003356 BCE 0.0055806; Average PSNR: 13.68; Average SSIM: 0.6132


Epoch 7:  78%|██████████████████████████████████████████████████████▋               | 32/41 [1:55:41<59:10, 394.48s/it]

In [None]:

for epoch in range(start_epoch,40):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/40] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 7: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:11:51<00:00, 280.76s/it]


[8/40] Training absolute losses: L1 0.0016755 ; L2 0.0002746 BCE 0.0055999; Average PSNR: 14.56; Average SSIM: 0.6612


Epoch 8: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:09:38<00:00, 277.51s/it]


[9/40] Training absolute losses: L1 0.0015052 ; L2 0.0002263 BCE 0.0055916; Average PSNR: 15.39; Average SSIM: 0.7046


Epoch 9: 100%|██████████████████████████████████████████████████████████████████████| 41/41 [3:09:09<00:00, 276.82s/it]


[10/40] Training absolute losses: L1 0.0013512 ; L2 0.0001849 BCE 0.0056376; Average PSNR: 16.27; Average SSIM: 0.7424


Epoch 10: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:33<00:00, 277.40s/it]


[11/40] Training absolute losses: L1 0.0012374 ; L2 0.0001582 BCE 0.0056452; Average PSNR: 16.95; Average SSIM: 0.7732


Epoch 11: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:11:58<00:00, 280.95s/it]


[12/40] Training absolute losses: L1 0.0011309 ; L2 0.0001343 BCE 0.0056330; Average PSNR: 17.67; Average SSIM: 0.7982


Epoch 12: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:11:09<00:00, 279.74s/it]


[13/40] Training absolute losses: L1 0.0010402 ; L2 0.0001155 BCE 0.0056815; Average PSNR: 18.32; Average SSIM: 0.8223


Epoch 13: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:16<00:00, 278.45s/it]


[14/40] Training absolute losses: L1 0.0009716 ; L2 0.0001012 BCE 0.0056465; Average PSNR: 18.89; Average SSIM: 0.8372


Epoch 14: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:38<00:00, 277.52s/it]


[15/40] Training absolute losses: L1 0.0008892 ; L2 0.0000875 BCE 0.0056663; Average PSNR: 19.52; Average SSIM: 0.8578


Epoch 15: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:43<00:00, 277.65s/it]


[16/40] Training absolute losses: L1 0.0008396 ; L2 0.0000779 BCE 0.0056431; Average PSNR: 20.03; Average SSIM: 0.8699


Epoch 16: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:20<00:00, 277.08s/it]


[17/40] Training absolute losses: L1 0.0007822 ; L2 0.0000693 BCE 0.0056704; Average PSNR: 20.53; Average SSIM: 0.8837


Epoch 17: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:24<00:00, 277.18s/it]


[18/40] Training absolute losses: L1 0.0007507 ; L2 0.0000635 BCE 0.0056558; Average PSNR: 20.92; Average SSIM: 0.8919


Epoch 18: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:23<00:00, 277.15s/it]


[19/40] Training absolute losses: L1 0.0006997 ; L2 0.0000568 BCE 0.0055507; Average PSNR: 21.41; Average SSIM: 0.9029


Epoch 19: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:51<00:00, 279.30s/it]


[20/40] Training absolute losses: L1 0.0006872 ; L2 0.0000534 BCE 0.0057500; Average PSNR: 21.68; Average SSIM: 0.9076


Epoch 20: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:01<00:00, 278.09s/it]


[21/40] Training absolute losses: L1 0.0006487 ; L2 0.0000486 BCE 0.0057599; Average PSNR: 22.08; Average SSIM: 0.9158


Epoch 21: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:02<00:00, 278.12s/it]


[22/40] Training absolute losses: L1 0.0006238 ; L2 0.0000450 BCE 0.0057495; Average PSNR: 22.42; Average SSIM: 0.9210


Epoch 22: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:03<00:00, 278.14s/it]


[23/40] Training absolute losses: L1 0.0006109 ; L2 0.0000428 BCE 0.0057326; Average PSNR: 22.64; Average SSIM: 0.9252


Epoch 23: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:05<00:00, 278.19s/it]


[24/40] Training absolute losses: L1 0.0005837 ; L2 0.0000393 BCE 0.0057399; Average PSNR: 23.01; Average SSIM: 0.9296


Epoch 24: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:27<00:00, 277.25s/it]


[25/40] Training absolute losses: L1 0.0005659 ; L2 0.0000370 BCE 0.0057262; Average PSNR: 23.27; Average SSIM: 0.9349


Epoch 25: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:57<00:00, 277.98s/it]


[26/40] Training absolute losses: L1 0.0005649 ; L2 0.0000359 BCE 0.0057331; Average PSNR: 23.42; Average SSIM: 0.9345


Epoch 26: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:06<00:00, 278.22s/it]


[27/40] Training absolute losses: L1 0.0005305 ; L2 0.0000327 BCE 0.0057221; Average PSNR: 23.81; Average SSIM: 0.9421


Epoch 27: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:17:06<00:00, 288.44s/it]


[28/40] Training absolute losses: L1 0.0005356 ; L2 0.0000319 BCE 0.0057093; Average PSNR: 23.93; Average SSIM: 0.9416


Epoch 28:   7%|█████▎                                                                   | 3/41 [01:05<15:48, 24.95s/it]

In [18]:

for epoch in range(start_epoch,40):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/40] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 28: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:27:08<00:00, 303.13s/it]


[29/40] Training absolute losses: L1 0.0005017 ; L2 0.0000293 BCE 0.0057040; Average PSNR: 24.31; Average SSIM: 0.9472


Epoch 29: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:13:24<00:00, 283.04s/it]


[30/40] Training absolute losses: L1 0.0004948 ; L2 0.0000281 BCE 0.0057237; Average PSNR: 24.46; Average SSIM: 0.9499


Epoch 30: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:14:50<00:00, 285.13s/it]


[31/40] Training absolute losses: L1 0.0004921 ; L2 0.0000273 BCE 0.0056990; Average PSNR: 24.60; Average SSIM: 0.9502


Epoch 31: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:12:05<00:00, 281.12s/it]


[32/40] Training absolute losses: L1 0.0004690 ; L2 0.0000255 BCE 0.0057066; Average PSNR: 24.90; Average SSIM: 0.9545


Epoch 32: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:12:21<00:00, 281.49s/it]


[33/40] Training absolute losses: L1 0.0004691 ; L2 0.0000249 BCE 0.0057332; Average PSNR: 25.01; Average SSIM: 0.9547


Epoch 33: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:11:44<00:00, 280.60s/it]


[34/40] Training absolute losses: L1 0.0004519 ; L2 0.0000236 BCE 0.0057327; Average PSNR: 25.24; Average SSIM: 0.9580


Epoch 34: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:15:08<00:00, 285.58s/it]


[35/40] Training absolute losses: L1 0.0004536 ; L2 0.0000232 BCE 0.0057067; Average PSNR: 25.32; Average SSIM: 0.9580


Epoch 35: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:14:26<00:00, 284.55s/it]


[36/40] Training absolute losses: L1 0.0004376 ; L2 0.0000221 BCE 0.0057132; Average PSNR: 25.53; Average SSIM: 0.9608


Epoch 36: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:13:28<00:00, 283.13s/it]


[37/40] Training absolute losses: L1 0.0004444 ; L2 0.0000220 BCE 0.0057289; Average PSNR: 25.54; Average SSIM: 0.9608


Epoch 37: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:12:10<00:00, 281.23s/it]


[38/40] Training absolute losses: L1 0.0004252 ; L2 0.0000207 BCE 0.0057273; Average PSNR: 25.81; Average SSIM: 0.9636


Epoch 38: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:12:09<00:00, 281.20s/it]


[39/40] Training absolute losses: L1 0.0004299 ; L2 0.0000206 BCE 0.0057002; Average PSNR: 25.82; Average SSIM: 0.9635


Epoch 39: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:11:57<00:00, 280.92s/it]


[40/40] Training absolute losses: L1 0.0004176 ; L2 0.0000198 BCE 0.0057145; Average PSNR: 26.01; Average SSIM: 0.9658


In [None]:

for epoch in range(start_epoch,50):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/50] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 40: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:16:22<00:00, 287.39s/it]


[41/50] Training absolute losses: L1 0.0004146 ; L2 0.0000193 BCE 0.0057004; Average PSNR: 26.11; Average SSIM: 0.9661


Epoch 41:  61%|████████████████████████████████████████▊                          | 25/41 [1:13:22<1:22:21, 308.87s/it]

In [None]:

for epoch in range(start_epoch,50):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/50] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 41: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:13:39<00:00, 283.39s/it]


[42/50] Training absolute losses: L1 0.0003961 ; L2 0.0000181 BCE 0.0056868; Average PSNR: 26.39; Average SSIM: 0.9680


Epoch 42: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:05<00:00, 276.71s/it]


[43/50] Training absolute losses: L1 0.0003984 ; L2 0.0000180 BCE 0.0057351; Average PSNR: 26.41; Average SSIM: 0.9690


Epoch 43: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:26:30<00:00, 302.21s/it]


[44/50] Training absolute losses: L1 0.0003907 ; L2 0.0000173 BCE 0.0057037; Average PSNR: 26.57; Average SSIM: 0.9696


Epoch 44: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:10:32<00:00, 278.83s/it]


[45/50] Training absolute losses: L1 0.0003866 ; L2 0.0000169 BCE 0.0057082; Average PSNR: 26.67; Average SSIM: 0.9710


Epoch 45:  63%|██████████████████████████████████████████▍                        | 26/41 [1:17:54<1:19:27, 317.82s/it]

In [17]:

for epoch in range(start_epoch,50):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
   
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(128, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(128, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()
             # Accumulate discriminator loss
            discriminator_losses.append(errD.item())

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            #errG_L1 = F.l1_loss(generated, frontal.float())
            #errG_L2 = F.mse_loss(generated, frontal.float())
            errG_L1 = multi_scale_pixelwise_loss(generated, real)  # Multi-scale pixel-wise loss
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

             #Accumulate generator loss
            generator_losses.append(errG.item())

            #Accumulate multi-scale pixel-wise loss
            multi_scale_losses.append(errG_L1.item())

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

   # Append the average losses to the respective lists


    avg_gen_loss = sum(generator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_disc_loss = sum(discriminator_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)
    avg_multi_loss = sum(multi_scale_losses[epoch * len(train_pipe_loader):(epoch + 1) * len(train_pipe_loader)]) / len(train_pipe_loader)

    avg_generator_losses.append(avg_gen_loss)
    avg_discriminator_losses.append(avg_disc_loss)
    avg_multi_scale_losses.append(avg_multi_loss)
    
    
    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Append the average PSNR and SSIM for this epoch to the respective lists
    psnr_values.append(avg_psnr)
    ssim_values.append(avg_ssim)
    
    if epoch == 0:
        print('First training epoch completed in ',(time.time() - start_time),' seconds')
    #if epoch > 0:
        #print(f"Epoch: {epoch} is starting..")
    # reset the DALI iterator
    #train_pipe_loader.reset()

    losses_L1.append(loss_L1 / m_train)
    losses_L2.append(loss_L2 / m_train)
    losses_gan.append(loss_gan / m_train)

     # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      'losses_L1': losses_L1,
      'losses_L2': losses_L2,
      'losses_gan': losses_gan,
      'discriminator_losses': discriminator_losses,
      'generator_losses': generator_losses,
      'multi_scale_losses': multi_scale_losses,
      'avg_generator_losses': avg_generator_losses,
      'avg_discriminator_losses': avg_discriminator_losses,
      'avg_multi_scale_losses': avg_multi_scale_losses,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir1, f"checkpoint_{epoch}.pth"))

    

    # Print the absolute values of three losses to screen:
    print('[%d/50] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f; Average PSNR: %.2f; Average SSIM: %.4f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train, avg_psnr, avg_ssim, ))

    # Print the PSNR and SSIM on each epoch
    #print('[%d/30] Average PSNR: %.2f, Average SSIM: %.4f' % (epoch + 1, avg_psnr, avg_ssim))

    # Save the inputs, outputs, and ground truth frontals to files:
    vutils.save_image(profile.data, 'FFRAD_CAS_67_Output/%03d_input.jpg' % epoch, normalize=True)
    vutils.save_image(real.data, 'FFRAD_CAS_67_Output/%03d_real.jpg' % epoch, normalize=True)
    vutils.save_image(generated.data, 'FFRAD_CAS_67_Output/%03d_generated.jpg' % epoch, normalize=True)

    

    # Save the pre-trained Generator as well
    torch.save(netG,'FFRAD_CAS_67_Output/netG_%d.pt' % epoch)

Epoch 45: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:19:09<00:00, 291.45s/it]


[46/50] Training absolute losses: L1 0.0003797 ; L2 0.0000164 BCE 0.0057218; Average PSNR: 26.80; Average SSIM: 0.9714


Epoch 46: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:08:41<00:00, 276.13s/it]


[47/50] Training absolute losses: L1 0.0003740 ; L2 0.0000159 BCE 0.0057250; Average PSNR: 26.92; Average SSIM: 0.9727


Epoch 47: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:09:11<00:00, 276.86s/it]


[48/50] Training absolute losses: L1 0.0003671 ; L2 0.0000155 BCE 0.0057053; Average PSNR: 27.06; Average SSIM: 0.9732


Epoch 48: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:08:35<00:00, 275.99s/it]


[49/50] Training absolute losses: L1 0.0003629 ; L2 0.0000151 BCE 0.0057157; Average PSNR: 27.16; Average SSIM: 0.9741


Epoch 49: 100%|█████████████████████████████████████████████████████████████████████| 41/41 [3:07:57<00:00, 275.07s/it]


[50/50] Training absolute losses: L1 0.0003583 ; L2 0.0000148 BCE 0.0056949; Average PSNR: 27.27; Average SSIM: 0.9744
