Step 1: Mount Google Drive (required — the dataset, checkpoints, and saved models used below all live on Drive)

In [None]:
 from google.colab import drive
 # Mount Google Drive at /content/drive; force_remount=True remounts even if it is already mounted.
 drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


Step 2: Import Necessary Libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import numpy as np
from tqdm import tqdm


Step 3: Define the Custom Dataset
A DCGAN is unconditional: it learns to turn random noise vectors into images and never uses paired (raw/reference) data, so a single folder of raw images is all the training set requires.

In [None]:
class ImageDataset(Dataset):
    """Unpaired dataset that yields every image file found in a single directory.

    Args:
        image_dir: Folder containing the images.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sorted for a deterministic index -> file mapping. Sub-directories and hidden
        # entries (e.g. `.DS_Store`, `.ipynb_checkpoints`) are skipped: they cannot be
        # opened as images and would otherwise crash __getitem__ partway through an epoch.
        self.image_names = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(image_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_names[idx])
        # The context manager closes the file handle; convert() returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


Step 4: Define Data Transformations and Create DataLoader

In [None]:
# Define the transformations
# Resize to the 64x64 resolution the DCGAN layers are built for, convert to a [0, 1]
# tensor, then apply (x - 0.5) / 0.5 so real images land in [-1, 1], the tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Directory path
image_dir = '/content/drive/MyDrive/DIP_Project/raw-890/'

# Create the dataset and a dataloader that reshuffles every epoch
dataset = ImageDataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, num_workers=2)


Step 5: Define the Generator and Discriminator for DCGAN

In [None]:
class DCGANGenerator(nn.Module):
    """DCGAN generator: latent tensor (N, z_dim, 1, 1) -> image (N, output_nc, 64, 64).

    Four ConvTranspose2d + BatchNorm2d + ReLU stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32 while narrowing the channels ngf*8 -> ngf. A final
    ConvTranspose2d upsamples to 64x64, and Tanh squashes the values into [-1, 1].

    The layers are stored in `self.main` in a fixed order, so the state_dict keys
    (main.0.weight ... main.12.weight) match checkpoints saved from this class.
    """

    def __init__(self, z_dim=100, ngf=64, output_nc=3):
        super().__init__()
        channels = [z_dim, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # Stage 0 projects the 1x1 latent to 4x4 (stride 1, no padding);
            # later stages double the resolution (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(ngf, output_nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


In [None]:
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: 64x64 image (N, input_nc, 64, 64) -> score (N, 1, 1, 1).

    Stride-2 4x4 convolutions halve the resolution 64 -> 32 -> 16 -> 8 -> 4 while widening
    the channels ndf -> ndf*8. Every stage but the first has BatchNorm, and each ends in
    LeakyReLU(0.2). A final 4x4 unpadded convolution collapses 4x4 to 1x1, and Sigmoid
    maps the result into [0, 1].

    The layer order in `self.main` is fixed, so state_dict keys match saved checkpoints.
    """

    def __init__(self, input_nc=3, ndf=64):
        super().__init__()
        # The first stage has no BatchNorm.
        layers = [
            nn.Conv2d(input_nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        widths = [ndf, ndf * 2, ndf * 4, ndf * 8]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.extend([
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


Step 6: Initialize Models, Loss Functions, and Optimizers

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Generator (100-d latent -> image) is built first, then the discriminator.
netG, netD = DCGANGenerator(z_dim=100).to(device), DCGANDiscriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

# One Adam optimizer per network, with the DCGAN settings lr=2e-4 and beta1=0.5
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999)) for net in (netG, netD)
)


Using device: cpu


Step 7: Define Checkpointing Mechanism

In [None]:
checkpoint_dir = '/content/drive/MyDrive/DIP_Project/checkpoints_DCGAN/'
# Create the folder up front (a no-op if it already exists) so that both saving and the
# directory listing in load_checkpoint work on a first run.
os.makedirs(name=checkpoint_dir, exist_ok=True)

def save_checkpoint(state, filename):
    """Serialize `state` to `filename` with torch.save.

    `state` is whatever object the caller passes; in this notebook it is a dict holding
    the model/optimizer state_dicts and the epoch index.
    """
    torch.save(obj=state, f=filename)

def load_checkpoint(modelG, modelD, optimizer_G, optimizer_D, checkpoint_dir, map_location=None):
    """Restore the newest `checkpoint_<N>.pth` from `checkpoint_dir`, if one exists.

    Args:
        modelG, modelD: Generator / discriminator whose state_dicts are overwritten.
        optimizer_G, optimizer_D: Optimizers whose state_dicts are overwritten.
        checkpoint_dir: Directory searched for files named exactly `checkpoint_<N>.pth`.
            Other files, including other `.pth` files, are ignored.
        map_location: Device to load tensors onto. Defaults to the device of modelG's
            first parameter, so the models do not need to be moved after loading.

    Returns:
        The epoch index to resume from (the saved 'epoch' + 1), or 0 when the directory
        is missing or holds no matching checkpoint.
    """
    import re

    pattern = re.compile(r'checkpoint_(\d+)\.pth')
    # Only names that fully match the pattern are candidates. Any other `.pth` file (e.g. one
    # written with save_checkpoint's 'checkpoint.pth' default) would break the epoch parse.
    numbered = {}
    if os.path.isdir(checkpoint_dir):
        for name in os.listdir(checkpoint_dir):
            match = pattern.fullmatch(name)
            if match:
                numbered[name] = int(match.group(1))
    if not numbered:
        print("No checkpoints found, starting from scratch.")
        return 0

    # Highest epoch number wins (numeric, so checkpoint_100 beats checkpoint_90).
    latest_checkpoint = max(numbered, key=numbered.get)
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
    print(f"Loading checkpoint: {checkpoint_path}")
    if map_location is None:
        map_location = next(modelG.parameters()).device
    # The file holds only state_dicts and an int epoch, so the restricted (weights-only)
    # unpickler is sufficient and avoids executing arbitrary pickled code.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    modelG.load_state_dict(checkpoint['modelG_state_dict'])
    modelD.load_state_dict(checkpoint['modelD_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    # 'epoch' is saved as the 0-based index of the finished epoch, so resume at the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
    return start_epoch


Step 8: Load Checkpoint if Available

In [None]:
# 0 on a fresh run; otherwise the epoch after the last saved one.
start_epoch = load_checkpoint(
    modelG=netG,
    modelD=netD,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    checkpoint_dir=checkpoint_dir,
)

Loading checkpoint: /content/drive/MyDrive/DIP_Project/checkpoints_DCGAN/checkpoint_200.pth


  checkpoint = torch.load(checkpoint_path, map_location=device)


Resuming from epoch 200


Step 9: Define the Training Loop with Periodic Checkpointing

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os
import matplotlib.pyplot as plt

# Ensure that these components are already defined:
# - netG: The generator model
# - netD: The discriminator model
# - dataloader: The DataLoader for your training data
# - device: The device (e.g., 'cuda' or 'cpu')
# - checkpoint_dir: The directory where checkpoints are stored

# Adversarial loss: binary cross-entropy on the discriminator's sigmoid probabilities,
# averaged over the batch.
criterion_GAN = nn.BCELoss(reduction='mean')

# Pixel-wise L1 (mean absolute error) loss, averaged over all elements.
criterion_L1 = nn.L1Loss(reduction='mean')

# Define a function to save the model checkpoint
def save_checkpoint(state, filename='checkpoint.pth'):
    """Write `state` to `filename` with torch.save, then print where it was saved.

    NOTE(review): this redefines (and shadows) the save_checkpoint from Step 7, adding
    only the confirmation print. The default filename does not follow the
    `checkpoint_<epoch>.pth` pattern that load_checkpoint parses, so always pass an
    explicit filename.
    """
    torch.save(obj=state, f=filename)
    message = f"Checkpoint saved at {filename}"
    print(message)

# Function to visualize generated images
def show_generated_images(epoch, generator, num_images=5, z_dim=100):
    """Display `num_images` samples from `generator` in a single row.

    Args:
        epoch: Epoch number used in the figure title.
        generator: Model mapping (N, z_dim, 1, 1) latent tensors to (N, 3, H, W) images
            with values in [-1, 1] (the Tanh output range of DCGANGenerator).
        num_images: Number of samples to draw; 1 is supported.
        z_dim: Latent size; the default 100 matches DCGANGenerator(z_dim=100).

    The generator's previous train/eval mode is restored afterwards, so calling this from
    a training loop does not leave BatchNorm layers in eval mode.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Sample on whichever device the generator's weights live on.
        sample_device = next(generator.parameters()).device
        with torch.no_grad():
            z = torch.randn(num_images, z_dim, 1, 1).to(sample_device)  # Random latent vectors
            generated_images = generator(z).cpu()
        # squeeze=False keeps `axes` 2-D, so indexing also works when num_images == 1.
        fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
        for ax, img in zip(axes[0], generated_images):
            # CHW in [-1, 1] -> HWC in [0, 1] for imshow.
            ax.imshow((img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())
            ax.axis('off')
        fig.suptitle(f"Generated Images at Epoch {epoch}")
        plt.show()
        # Free the figure explicitly; this is called repeatedly during training.
        plt.close(fig)
    finally:
        generator.train(was_training)

# Total number of epochs and checkpoint saving interval
total_epochs = 200
save_every = 10  # Save a checkpoint (and preview samples) every 10 epochs

# Weight of the pixel-wise L1 term in the generator loss.
# The generator is unconditional: `fake` is decoded from fresh random noise and has no
# correspondence with the shuffled `real_images` batch. An L1 penalty between the two
# (the notebook previously used 100, a pix2pix-style setting) therefore pulls every sample
# toward the dataset's mean image instead of toward realistic images. A standard DCGAN
# uses the adversarial loss alone; set this back to 100 only to reproduce earlier runs.
lambda_L1 = 0.0

# `start_epoch` comes from load_checkpoint() in Step 8. optimizer_G / optimizer_D are the
# Step 6 optimizers whose state load_checkpoint() restored. They are deliberately NOT
# re-created here: fresh Adam instances would discard the restored moment estimates on resume.

# Training loop
for epoch in range(start_epoch, total_epochs):
    netG.train()
    netD.train()

    epoch_loss_D = 0.0
    epoch_loss_G = 0.0
    num_batches = 0

    # Progress bar for the current epoch
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)

    for real_images in progress_bar:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images should be scored as 1.
        pred_real = netD(real_images)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # Fake images should be scored as 0; detach so this step does not backprop into netG.
        z = torch.randn(batch_size, 100, 1, 1).to(device)
        fake = netG(z)
        pred_fake = netD(fake.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # The generator is rewarded when the just-updated discriminator scores its samples as real.
        pred_fake = netD(fake)
        loss_G = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        if lambda_L1 > 0:
            loss_G = loss_G + criterion_L1(fake, real_images) * lambda_L1
        loss_G.backward()
        optimizer_G.step()

        # Accumulate losses
        epoch_loss_D += loss_D.item()
        epoch_loss_G += loss_G.item()
        num_batches += 1

        # Update progress bar
        progress_bar.set_postfix({'Loss_D': loss_D.item(), 'Loss_G': loss_G.item()})

    # max(..., 1) guards against division by zero for an empty dataloader.
    avg_loss_D = epoch_loss_D / max(num_batches, 1)
    avg_loss_G = epoch_loss_G / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{total_epochs}] Loss D: {avg_loss_D:.4f}, Loss G: {avg_loss_G:.4f}")

    # Save a checkpoint and preview samples periodically and after the final epoch.
    if (epoch + 1) % save_every == 0 or (epoch + 1) == total_epochs:
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{epoch+1}.pth')
        save_checkpoint({
            'epoch': epoch,  # 0-based index of the finished epoch
            'modelG_state_dict': netG.state_dict(),
            'modelD_state_dict': netD.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, checkpoint_path)
        show_generated_images(epoch + 1, netG)

print("Training complete.")


Training complete.


In [None]:
# Save final models
# Only the state_dicts are written; rebuild the architecture and call load_state_dict to reuse them.
final_modelG_path = '/content/drive/MyDrive/DIP_Project/DC_GAN_netG_final.pth'
final_modelD_path = '/content/drive/MyDrive/DIP_Project/DC_GAN_netD_final.pth'

torch.save(obj=netG.state_dict(), f=final_modelG_path)
torch.save(obj=netD.state_dict(), f=final_modelD_path)

print("Final models saved.")


Final models saved.


In [None]:
# Install evaluation dependencies (scikit-image provides the skimage.metrics imports below).
# NOTE(review): torch/torchvision/numpy are presumably preinstalled on Colab, and piqa is not
# imported anywhere in the visible cells; consider dropping them and pinning versions.
!pip install torch torchvision numpy scikit-image piqa



In [None]:
# torchmetrics provides the MSE/PSNR/SSIM/FID metric classes imported below; its FID
# downloads Inception weights published by torch-fidelity (see the evaluation cell's output).
!pip install torch-fidelity
!pip install torchmetrics

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0
Collecting torchmetrics
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m926.4/926.4 kB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.8 torchmetrics-1.6.0


In [None]:
# Quote the requirement: an unquoted `torchmetrics[image]` is a shell glob pattern that can be
# expanded, or rejected (e.g. by zsh), before pip ever sees the extras specifier.
!pip install -U "torchmetrics[image]"



In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchmetrics.image.fid import FrechetInceptionDistance
from torch.nn.functional import mse_loss

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms
import numpy as np
import cv2
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance

# Evaluation runs on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Custom dataset class for loading raw and reference images
class UIEB_Dataset(Dataset):
    """Paired dataset: item i is (raw image, reference image) sharing the same file name.

    Args:
        raw_dir: Folder of raw images; its file names define the dataset.
        reference_dir: Folder expected to contain a file with each of those names.
        transform: Optional callable applied independently to both RGB PIL images.
    """

    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        # Sorted for a stable index -> file mapping. Sub-directories and hidden entries
        # (e.g. `.DS_Store`, `.ipynb_checkpoints`) are skipped: they cannot be opened as
        # images and would otherwise crash __getitem__ partway through evaluation.
        self.image_names = sorted(
            name
            for name in os.listdir(raw_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(raw_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        name = self.image_names[idx]
        raw_image = self._load_rgb(os.path.join(self.raw_dir, name))
        reference_image = self._load_rgb(os.path.join(self.reference_dir, name))

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

    @staticmethod
    def _load_rgb(path):
        # Open, convert to RGB, and close the underlying file handle immediately.
        with Image.open(path) as img:
            return img.convert('RGB')

# Define the transformations
# Resize to the generator's 64x64 output size and convert to tensors in [0, 1].
# There is no Normalize step: generated images are rescaled to [0, 1] before comparison.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Adjust size as needed
        transforms.ToTensor(),
    ]
)

# Directory paths
raw_dir = '/content/drive/MyDrive/DIP_Project/raw-890/'
reference_dir = '/content/drive/MyDrive/DIP_Project/reference-890/'

# Create the dataset and dataloader; shuffle=False keeps a fixed file order.
dataset = UIEB_Dataset(raw_dir=raw_dir, reference_dir=reference_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False, num_workers=2)

# Load the model (ensure this matches your DCGAN generator structure)
netG = DCGANGenerator(z_dim=100).to(device)
# The file is a plain state_dict, so the restricted weights-only unpickler is sufficient.
# Passing weights_only explicitly also avoids torch.load's FutureWarning.
netG.load_state_dict(
    torch.load(
        '/content/drive/MyDrive/DIP_Project/DC_GAN_netG_final.pth',
        map_location=device,
        weights_only=True,
    )
)
netG.eval()  # Set the model to evaluation mode (BatchNorm uses running statistics)

# Initialize metrics
# Both the generated tensors (after * 0.5 + 0.5) and the reference tensors (ToTensor) lie in
# [0, 1], so data_range is pinned rather than inferred by torchmetrics from observed values.
# NOTE(review): `psnr` / `ssim` shadow the skimage functions imported earlier under the same names.
mse = MeanSquaredError().to(device)
psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
fid = FrechetInceptionDistance().to(device)

# Function to convert image tensor for FID metric
def convert_to_uint8_for_fid(image_tensor):
    """Map a float image tensor in [0, 1] to uint8 levels in [0, 255].

    Values are rounded to the nearest level before the cast. A bare `.to(torch.uint8)`
    truncates, which biases pixels downward by about half a level on average.
    Out-of-range inputs are clamped.
    """
    return (image_tensor * 255).round().clamp(0, 255).to(torch.uint8)

# Function to calculate UICM (Colorfulness)
def calculate_uicm(image):
    """Colorfulness proxy: RMS chroma in OpenCV's CIELAB space.

    Args:
        image: HxWx3 RGB array accepted by cv2.cvtColor. For uint8 input, OpenCV stores
            a*/b* offset by +128, so the offset is removed to centre neutral colours at 0.

    Returns:
        sqrt(mean(a*^2) + mean(b*^2)).

    NOTE(review): this is a simplified stand-in for the UICM of Panetta et al.; values are
    not comparable with published UICM numbers.
    """
    lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    _, a_chan, b_chan = cv2.split(lab_image)
    # Promote to float64 BEFORE squaring: squaring uint8 arrays wraps modulo 256.
    a_chan = a_chan.astype(np.float64)
    b_chan = b_chan.astype(np.float64)
    if image.dtype == np.uint8:
        a_chan -= 128.0
        b_chan -= 128.0
    colorfulness = np.sqrt(np.mean(a_chan ** 2) + np.mean(b_chan ** 2))
    return colorfulness

# Function to calculate UISM (Sharpness)
def calculate_uism(image):
    """Sharpness proxy: variance of the Laplacian of the grayscale image.

    Args:
        image: RGB array accepted by cv2.cvtColor(..., cv2.COLOR_RGB2GRAY).

    Returns:
        Variance of the 64-bit float Laplacian response (higher means more edge energy).
    """
    laplacian = cv2.Laplacian(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), cv2.CV_64F)
    return laplacian.var()

# Function to calculate UIConM (Contrast)
def calculate_uiconm(image):
    """Contrast proxy: standard deviation of the grayscale intensities of an RGB image."""
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY).std()

# Function to calculate UIQM by combining UICM, UISM, and UIConM
def calculate_uiqm(image):
    """Weighted sum 0.5 * UICM + 0.3 * UISM + 0.2 * UIConM of the proxies defined above.

    Args:
        image: RGB array, passed unchanged to each component function.

    NOTE(review): these weights and component formulas differ from the published UIQM
    (Panetta et al., 2016, which weights its components by 0.0282, 0.2953 and 3.5753),
    so scores are not comparable with UIQM values reported in the literature.
    """
    weights = (0.5, 0.3, 0.2)
    components = (calculate_uicm(image), calculate_uism(image), calculate_uiconm(image))
    return sum(weight * value for weight, value in zip(weights, components))

# Function to process the image tensor and convert it to a numpy array
def tensor_to_numpy(image_tensor):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 NumPy image.

    Values are rounded to the nearest level (truncation would bias every pixel down by
    about half a level on average) and clipped to [0, 255].
    """
    image = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
    return np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)

# Iterate through the dataloader
# NOTE(review): DCGANGenerator is unconditional. It decodes random latent vectors and never
# sees `raw_images`, so the per-pixel MSE / PSNR / SSIM below compare unrelated image pairs.
# FID (distribution-level) and UIQM (no-reference) do not depend on any pairing.
uiqm_scores = []

# Dedicated, seeded RNG for the latent vectors: the reported metrics are reproducible across
# runs, and PyTorch's global random state is left untouched.
latent_rng = torch.Generator().manual_seed(0)

for batch in dataloader:
    raw_images, reference_images = batch
    # Only the batch size of raw_images is used, so it stays on the CPU.
    reference_images = reference_images.to(device)

    # Generate one sample per reference image
    with torch.no_grad():
        z = torch.randn(raw_images.size(0), 100, 1, 1, generator=latent_rng).to(device)
        generated_images = netG(z)

    # Scale from [-1, 1] (tanh) to [0, 1] for MSE, PSNR, SSIM
    generated_images_scaled = generated_images * 0.5 + 0.5

    # Update metrics (MSE, PSNR, SSIM)
    mse.update(generated_images_scaled, reference_images)
    psnr.update(generated_images_scaled, reference_images)
    ssim.update(generated_images_scaled, reference_images)

    # Convert images for FID metric
    generated_images_uint8 = convert_to_uint8_for_fid(generated_images_scaled)
    reference_images_uint8 = convert_to_uint8_for_fid(reference_images)

    # Update FID metric; real=True marks the reference distribution
    fid.update(generated_images_uint8, real=False)
    fid.update(reference_images_uint8, real=True)

    # Calculate UIQM for each generated image in the batch
    for img in generated_images_scaled:
        img_numpy = tensor_to_numpy(img)
        uiqm_value = calculate_uiqm(img_numpy)
        uiqm_scores.append(uiqm_value)

# Compute final scores
# Each torchmetrics object reduces everything it accumulated across batches to a scalar.
mse_score, psnr_score, ssim_score, fid_score = (
    metric.compute().item() for metric in (mse, psnr, ssim, fid)
)
uiqm_score = np.mean(uiqm_scores)  # mean of the per-image UIQM values

# Print results
print(f"MSE: {mse_score:.4f}")
print(f"PSNR: {psnr_score:.4f}")
print(f"SSIM: {ssim_score:.4f}")
print(f"FID: {fid_score:.4f}")
print(f"UIQM: {uiqm_score:.4f}")


  netG.load_state_dict(torch.load('/content/drive/MyDrive/DIP_Project/DC_GAN_netG_final.pth', map_location=device))
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 108MB/s]


MSE: 0.0647
PSNR: 11.8931
SSIM: 0.1725
FID: 238.7355
UIQM: 85.5271
