In [14]:
import os
import numpy as np
from PIL import Image

def is_jpeg(filename):
    return any(filename.endswith(extension) for extension in [".jpg", ".jpeg", ".png"])

def get_subdirs(directory):
    subdirs = sorted([os.path.join(directory, name) for name in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, name))])
    return subdirs

class ExternalInputIterator:
    def __init__(self, imageset_dir, batch_size, random_shuffle=False):
        self.imageset_dir = imageset_dir
        self.batch_size = batch_size

        # Get subdirectories (assuming "pose" and "frontal" folders exist)
        #self.pose_dirs = get_subdirs(os.path.join(imageset_dir, "pose"))
        self.pose_dirs = os.path.join(imageset_dir, "pose")
        self.frontal_dir = os.path.join(imageset_dir, "frontal")
        print(self.frontal_dir)
        print(self.pose_dirs)

        # Collect profile image paths
        self.profile_files = []
        #for pose_dir in self.pose_dirs:
        profile_files = [os.path.join(self.pose_dirs, file) for file in sorted(os.listdir(self.pose_dirs)) if is_jpeg(file)]
        self.profile_files.extend(profile_files)
        print(len(self.profile_files))

        # Collect frontal image paths
        self.frontal_files = [os.path.join(self.frontal_dir, file) for file in sorted(os.listdir(self.frontal_dir)) if is_jpeg(file)]
        print(len(self.frontal_files))

        # Shuffle if necessary
        if random_shuffle:
            np.random.shuffle(self.profile_files)
            np.random.shuffle(self.frontal_files)

        self.i = 0
        self.n = len(self.profile_files)

    def __iter__(self):
        return self

    def __next__(self):
        profiles = []
        frontals = []

        for _ in range(self.batch_size):
            profile_filename = self.profile_files[self.i]
            frontal_filename = self.match_frontal_image(profile_filename)

            with Image.open(profile_filename) as profile_img:
                profiles.append(np.array(profile_img))
            with Image.open(frontal_filename) as frontal_img:
                frontals.append(np.array(frontal_img))

            self.i = (self.i + 1) % self.n

        return (profiles, frontals)

    def match_frontal_image(self, profile_filename):
        profile_name = os.path.basename(profile_filename).split("_")[0]
        for frontal_file in self.frontal_files:
            if profile_name in frontal_file:
                return frontal_file
        return None

class ImagePipeline:
    def __init__(self, imageset_dir, image_size=128, random_shuffle=False, batch_size=64,device_id = 0):
        self.eii = ExternalInputIterator(imageset_dir, batch_size, random_shuffle)
        self.iterator = iter(self.eii)
        self.num_inputs = len(self.eii.profile_files)
        self.image_size = image_size

    def epoch_size(self, name=None):
        return self.num_inputs

    def __len__(self):
        return self.num_inputs

    def __iter__(self):
        return self

    def __next__(self):
        (images, targets) = next(self.iterator)

        # Perform resizing and normalization using NumPy
        resized_images = np.array([np.array(Image.fromarray(img).resize((self.image_size, self.image_size))) for img in images])
        resized_targets = np.array([np.array(Image.fromarray(target).resize((self.image_size, self.image_size))) for target in targets])

        # Calculate mean and standard deviation for each channel separately
        #mean = np.array([0.5, 0.5, 0.5])  # Assuming RGB images have pixel values in [0, 255] range
        #std = np.array([0.5, 0.5, 0.5])   # Assuming RGB images have pixel values in [0, 255] range
        
        # Normalize each channel independently
        #normalized_images = (resized_images / 255.0 - mean) / std
        #normalized_targets = (resized_targets / 255.0 - mean) / std


        # Normalize using mean and standard deviation
        normalized_images = (resized_images - 128.0) / 128.0
        normalized_targets = (resized_targets - 128.0) / 128.0

        return (normalized_images, normalized_targets)

    def __getitem__(self, index):
        # Advance the iterator to the desired index
        for _ in range(index):
            next(self.iterator)

        # Return the next batch
        return next(self)


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    classname = m.__class__.__name__
    
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [16]:

''' Generator network for 128x128 RGB images '''
class G(nn.Module):
    
    def __init__(self):
        super(G, self).__init__()
        
        self.main = nn.Sequential(
            # Input HxW = 128x128
            nn.Conv2d(1, 16, 4, 2, 1), # Output HxW = 64x64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2
            nn.MaxPool2d((2,2)),
            # At this point, we arrive at our low D representation vector, which is 512 dimensional.

            nn.ConvTranspose2d(512, 256, 4, 1, 0, bias = False), # Output HxW = 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False), # Output HxW = 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), # Output HxW = 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False), # Output HxW = 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias = False), # Output HxW = 64x64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 1, 4, 2, 1, bias = False), # Output HxW = 128x128
            nn.Tanh()
        )

    
    def forward(self, input):
        output = self.main(input)
        return output


In [17]:
import torch.nn as nn

class RelativeAvgDiscriminator(nn.Module):
    def __init__(self):
        super(RelativeAvgDiscriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 16, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 2, 1, bias=False),
            # nn.Sigmoid()
        )

    def forward(self, real, fake):
        #print(real.shape)
        #print(fake.shape)
        # Concatenate real and fake data along a new dimension
        #input_concat = torch.cat((real, fake), dim=1)
        input_concat = torch.cat((real.squeeze(0), fake.squeeze(0)), dim=0)
        #print(input_concat.shape)
        output = self.main(input_concat)
        
        return output.view(-1)


**Edited RAD**

In [19]:
import torch
from torch import nn

class RelativeAvgDiscriminator(nn.Module):
  def __init__(self):
    super(RelativeAvgDiscriminator, self).__init__()

    # Separate feature extraction for real and generated data
    self.conv_real = nn.Sequential(
        nn.Conv2d(1, 16, 4, 2, 1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(16, 32, 4, 2, 1),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(32, 64, 4, 2, 1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.conv_generated = nn.Sequential(
        nn.Conv2d(1, 16, 4, 2, 1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(16, 32, 4, 2, 1),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(32, 64, 4, 2, 1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Relative Average Pooling
    self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    # Remaining convolutional layers (modified for combined features)
    self.post_pool = nn.Sequential(
        nn.Conv2d(128, 128, 4, 2, 1),  # Input channels changed to 128
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(128, 256, 4, 2, 1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
        # ... rest of the layers (same as original code)
    )

    # Output layer with sigmoid activation
    self.output = nn.Sigmoid()

  def forward(self, real, fake):
    # Extract features from real and generated data
    real_features = self.conv_real(real)
    generated_features = self.conv_generated(fake)

    # Concatenate features before pooling
    combined_features = torch.cat([real_features, generated_features], dim=1)

    # Relative Average Pooling
    features = self.avgpool(combined_features)

    # Process features with remaining layers
    output = self.post_pool(features)

    # Probability score
    #probability = self.output(logits)

    return output


In [23]:
from __future__ import print_function
import time
import math
import random
import os
from os import listdir
from os.path import join
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

#from nvidia.dali.plugin.pytorch import DALIGenericIterator

#from data import ImagePipeline
#import network

np.random.seed(42)
random.seed(10)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(999)
torch.cuda.manual_seed(999)
# Where is your training dataset at?
datapath =r"C:\Users\zed\Dataset\\Grayscale_Dataset"

# You can also choose which GPU you want your model to be trained on below:
gpu_id = 0
device = torch.device("cuda", gpu_id)

checkpoint_dir = "checkpoints"

"""train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=30, device_id=gpu_id)
train_pipe.build()
m_train = train_pipe.epoch_size()
print("Size of the training set: ", m_train)
train_pipe_loader = DALIGenericIterator(train_pipe, ["profiles", "frontals"], m_train)"""
# Assuming you have the modified ImagePipeline class from the previous responses
train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=32, device_id=gpu_id)
# No need to call build() without DALI

# Use a standard PyTorch DataLoader instead of DALIGenericIterator
#train_pipe_loader = DataLoader(train_pipe, batch_size=train_pipe.batch_size)
m_train = train_pipe.epoch_size()
#train_pipe_loader = DataLoader(train_pipe, batch_size=32,)
train_pipe_loader = DataLoader(train_pipe,)

criterion = nn.BCEWithLogitsLoss()

C:\Users\zed\Dataset\\Grayscale_Dataset\frontal
C:\Users\zed\Dataset\\Grayscale_Dataset\pose
200
50


In [24]:
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from tqdm import tqdm
from torch.autograd import Variable
from skimage.metrics import structural_similarity as ssim

# Define a function to calculate PSNR
def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()

# Define a function to calculate SSIM
# Define a function to calculate SSIM
def calculate_ssim(img1, img2):
    # Ensure tensors are on the same device
    if img1.device != img2.device:
        raise ValueError("Input tensors must be on the same device")

    # Calculate SSIM directly on GPU tensors
    img1 = img1.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    img2 = img2.detach().squeeze().clamp(0, 1).cpu().numpy()  # Ensure pixel values are in [0, 1] range
    return ssim(img1.transpose(1, 2, 0), img2.transpose(1, 2, 0), multichannel=True, data_range=1)


# Define lists to store PSNR and SSIM values for each epoch
psnr_values = []
ssim_values = []

In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
from tqdm import tqdm
import os
import time

netG = G().to(device)
netG.apply(weights_init)

netD = RelativeAvgDiscriminator().to(device)
netD.apply(weights_init)

L1_factor = 0
L2_factor = 1
GAN_factor = 0.005

#criterion = nn.BCEWithLogitsLoss()

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-8)

try:
    os.mkdir('output_RAD')
except OSError:
    pass
    
checkpoint_dir = "checkpoints"

start_time = time.time()

import torch.nn.functional as F

# Training loop
for epoch in range(3):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
    print("Starting epoch", epoch + 1)
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch+1}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(32, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(32, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            errG_L1 = F.l1_loss(generated, frontal.float())
            errG_L2 = F.mse_loss(generated, frontal.float())
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            #errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

    # Print loss values and metrics
    print(f'[{epoch+1}/3] Training absolute losses: L1 {loss_L1/m_train:.7f}; L2 {loss_L2/m_train:.7f}; BCE {loss_gan/m_train:.7f}')
    print(f'[{epoch+1}/3] Average PSNR: {avg_psnr:.2f}, Average SSIM: {avg_ssim:.4f}')

    # Save images and models
    vutils.save_image(profile.data, f'output_RAD/{epoch:03d}_input.jpg', normalize=True)
    vutils.save_image(frontal.data, f'output_RAD/{epoch:03d}_real.jpg', normalize=True)
    vutils.save_image(generated.data, f'output_RAD/{epoch:03d}_generated.jpg', normalize=True)
    torch.save(netG.state_dict(), f'output_RAD/netG_{epoch}.pt')
    torch.save(netD.state_dict(), f'output_RAD/netD_{epoch}.pt')



Starting epoch 1


Epoch 1: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:20<00:00,  2.80s/it]


[1/3] Training absolute losses: L1 0.1610392; L2 0.0519637; BCE 4.7296138
[1/3] Average PSNR: 14.86, Average SSIM: 0.5734
Starting epoch 2


Epoch 2: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:22<00:00,  2.81s/it]


[2/3] Training absolute losses: L1 0.0788424; L2 0.0107688; BCE 6.2342537
[2/3] Average PSNR: 19.71, Average SSIM: 0.8646
Starting epoch 3


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:21<00:00,  2.81s/it]

[3/3] Training absolute losses: L1 0.0737215; L2 0.0092869; BCE 7.2429958
[3/3] Average PSNR: 20.34, Average SSIM: 0.8879





**Training with checkpointing**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
from tqdm import tqdm
import os
import time

netG = G().to(device)
netG.apply(weights_init)

netD = RelativeAvgDiscriminator().to(device)
netD.apply(weights_init)

L1_factor = 0
L2_factor = 1
GAN_factor = 0.005

#criterion = nn.BCEWithLogitsLoss()

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-8)

try:
    os.mkdir('output_RAD')
except OSError:
    pass
    
checkpoint_dir = "checkpoints"

start_time = time.time()

import torch.nn.functional as F

# Training loop
for epoch in range(3):  # Assuming 3 epochs for demonstration
    
    # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
    print("Starting epoch", epoch + 1)
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(32, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(32, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            errG_L1 = F.l1_loss(generated, frontal.float())
            errG_L2 = F.mse_loss(generated, frontal.float())
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            #errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

      # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      #'losses_L1': losses_L1,
      #'losses_L2': losses_L2,
      #'losses_gan': losses_gan,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth"))

    # Print loss values and metrics
    print(f'[{epoch+1}/3] Training absolute losses: L1 {loss_L1/m_train:.7f}; L2 {loss_L2/m_train:.7f}; BCE {loss_gan/m_train:.7f}')
    print(f'[{epoch+1}/3] Average PSNR: {avg_psnr:.2f}, Average SSIM: {avg_ssim:.4f}')

    # Save images and models
    vutils.save_image(profile.data, f'output_RAD/{epoch:03d}_input.jpg', normalize=True)
    vutils.save_image(frontal.data, f'output_RAD/{epoch:03d}_real.jpg', normalize=True)
    vutils.save_image(generated.data, f'output_RAD/{epoch:03d}_generated.jpg', normalize=True)
    torch.save(netG.state_dict(), f'output_RAD/netG_{epoch}.pt')
    torch.save(netD.state_dict(), f'output_RAD/netD_{epoch}.pt')



Starting epoch 1


Epoch 0: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:18<00:00,  2.79s/it]


[1/3] Training absolute losses: L1 0.1587943; L2 0.0515081; BCE 0.6418217
[1/3] Average PSNR: 15.18, Average SSIM: 0.7265
Starting epoch 2


Epoch 1:  36%|█████████████████████████▉                                              | 72/200 [01:20<04:11,  1.97s/it]

**Loading checkpoint**

In [28]:
latest_epoch = 2
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{latest_epoch}.pth")
checkpoint = torch.load(checkpoint_path)

# Load model and optimizer states
netG.load_state_dict(checkpoint['netG_state_dict'])
netD.load_state_dict(checkpoint['netD_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])

# Load training progress
loss_L1 = checkpoint['loss_L1']
loss_L2 = checkpoint['loss_L2']
loss_gan = checkpoint['loss_gan']
psnr_values = checkpoint['psnr_values']
ssim_values = checkpoint['ssim_values']
#losses_L1 = checkpoint['losses_L1']
#losses_L2 = checkpoint['losses_L2']
#losses_gan = checkpoint['losses_gan']

# Start training from the loaded epoch
start_epoch = checkpoint['epoch'] + 1
# Let's train for 30 epochs (meaning, we go through the entire training set 30 times):
for epoch in range(start_epoch,5):
   # Track loss values for each epoch
    loss_L1 = 0
    loss_L2 = 0
    loss_gan = 0
    total_psnr = 0
    total_ssim = 0
    
    print("Starting epoch", epoch + 1)
    with tqdm(total=len(train_pipe_loader), desc=f"Epoch {epoch}") as pbar:
        for i, data in enumerate(train_pipe_loader, 0):
            profile = data[0].view(32, 1, 128, 128).to(device)  # Reshape and move to device
            frontal = data[1].view(32, 1, 128, 128).to(device)  # Reshape and move to device

            # TRAINING THE DISCRIMINATOR
            netD.zero_grad()
            optimizerD.zero_grad()

            real = Variable(frontal).type('torch.FloatTensor').to(device)
            target = Variable(torch.ones(real.size()[0])).to(device)
            profile = Variable(profile).type('torch.FloatTensor').to(device)
            
            real_output = netD(real,real)  # Discriminator output for real images
            generated = netG(profile)  # Generate images from profile
            fake_output = netD(profile, generated.detach())  # Discriminator output for fake images

            # Concatenate real and fake outputs along a new dimension
            concatenated = torch.cat((real_output, fake_output), dim=0)

            # Create labels for real and fake images
            target_real = torch.ones_like(real_output)
            target_fake = torch.zeros_like(fake_output)
            targets = torch.cat((target_real, target_fake), dim=0)

            # Calculate BCE loss for the concatenated outputs
            #errD = F.binary_cross_entropy_with_logits(concatenated, targets)

            errD = criterion(concatenated, targets.float())
            errD.backward()
            optimizerD.step()

            # TRAINING THE GENERATOR
            netG.zero_grad()
            optimizerG.zero_grad()
            generated = netG(profile)
            output = netD(profile, generated)

            # G wants to have the synthetic images be accepted by D
            errG_GAN = criterion(output, torch.ones_like(output).float())

            # Calculate L1 and L2 loss between generated and real images
            errG_L1 = F.l1_loss(generated, frontal.float())
            errG_L2 = F.mse_loss(generated, frontal.float())
           
            #errG_L1 = torch.mean(torch.abs(real - generated))
            #errG_L2 = torch.mean(torch.pow((real - generated), 2))
            
            # Total generator loss
            errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
            errG.backward()
            optimizerG.step()

            # Update loss values
            loss_L1 += errG_L1.item()
            loss_L2 += errG_L2.item()
            loss_gan += errG_GAN.item()

            # Calculate PSNR for each generated image and accumulate
            psnr = calculate_psnr(generated, frontal)
            total_psnr += psnr

            # Calculate SSIM for each generated image and accumulate
            ssim_val = calculate_ssim(generated, frontal)
            total_ssim += ssim_val

            pbar.update(1)

    # Calculate average PSNR and SSIM for this epoch
    avg_psnr = total_psnr / len(train_pipe_loader)
    avg_ssim = total_ssim / len(train_pipe_loader)

      # Save checkpoint after each epoch
    checkpoint_state = {
      'epoch': epoch,
      'netG_state_dict': netG.state_dict(),
      'netD_state_dict': netD.state_dict(),
      'optimizerG_state_dict': optimizerG.state_dict(),
      'optimizerD_state_dict': optimizerD.state_dict(),
      'loss_L1': loss_L1,
      'loss_L2': loss_L2,
      'loss_gan': loss_gan,
      'psnr_values': psnr_values,
      'ssim_values': ssim_values,
      #'losses_L1': losses_L1,
      #'losses_L2': losses_L2,
      #'losses_gan': losses_gan,
    }
    torch.save(checkpoint_state, os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth"))

    # Print loss values and metrics
    print(f'[{epoch+1}/3] Training absolute losses: L1 {loss_L1/m_train:.7f}; L2 {loss_L2/m_train:.7f}; BCE {loss_gan/m_train:.7f}')
    print(f'[{epoch+1}/3] Average PSNR: {avg_psnr:.2f}, Average SSIM: {avg_ssim:.4f}')

    # Save images and models
    vutils.save_image(profile.data, f'output_RAD/{epoch:03d}_input.jpg', normalize=True)
    vutils.save_image(frontal.data, f'output_RAD/{epoch:03d}_real.jpg', normalize=True)
    vutils.save_image(generated.data, f'output_RAD/{epoch:03d}_generated.jpg', normalize=True)
    torch.save(netG.state_dict(), f'output_RAD/netG_{epoch}.pt')
    torch.save(netD.state_dict(), f'output_RAD/netD_{epoch}.pt')


Starting epoch 4


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:20<00:00,  2.80s/it]


[4/3] Training absolute losses: L1 0.0356334; L2 0.0025022; BCE 0.7002112
[4/3] Average PSNR: 26.07, Average SSIM: 0.9635
Starting epoch 5


Epoch 4: 100%|███████████████████████████████████████████████████████████████████████| 200/200 [09:18<00:00,  2.79s/it]


[5/3] Training absolute losses: L1 0.0318828; L2 0.0019760; BCE 0.7046735
[5/3] Average PSNR: 27.08, Average SSIM: 0.9714
