# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import vgg16
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
from torchvision.utils import save_image
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from sklearn.metrics import mean_squared_error


# Optimized U-Net Generator
class UNetGenerator(nn.Module):
    """Convolutional encoder-decoder image generator.

    The encoder is four stride-2 convolution stages (64, 128, 256, 512
    channels); the decoder is three stride-2 transposed-convolution stages
    (256, 128, 64 channels) followed by a final transposed convolution to
    ``out_channels`` and a ``Tanh``. ``forward`` simply chains encoder and
    decoder; no skip connections are used.

    Args:
        in_channels: Number of channels of the input image tensor.
        out_channels: Number of channels of the generated image tensor.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        encoder_widths = (in_channels, 64, 128, 256, 512)
        decoder_widths = (512, 256, 128, 64)

        # Stages are created in encoder-then-decoder order so that parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.encoder = nn.Sequential(*[
            self._downsample(src, dst)
            for src, dst in zip(encoder_widths[:-1], encoder_widths[1:])
        ])

        decoder_stages = [
            self._upsample(src, dst)
            for src, dst in zip(decoder_widths[:-1], decoder_widths[1:])
        ]
        # Output head: last upsampling step, squashed by Tanh.
        decoder_stages.append(
            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        )
        decoder_stages.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_stages)

    @staticmethod
    def _downsample(in_c, out_c):
        """Build one encoder stage: strided conv -> instance norm -> LeakyReLU."""
        conv = nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
        norm = nn.InstanceNorm2d(out_c, affine=True)
        act = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, act)

    @staticmethod
    def _upsample(in_c, out_c):
        """Build one decoder stage: transposed conv -> instance norm -> ReLU."""
        deconv = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
        norm = nn.InstanceNorm2d(out_c, affine=True)
        act = nn.ReLU(inplace=True)
        return nn.Sequential(deconv, norm, act)

    def forward(self, x):
        """Encode ``x`` and decode the result back to an image tensor."""
        latent = self.encoder(x)
        return self.decoder(latent)


# PatchGAN Discriminator
class PatchGANDiscriminator(nn.Module):
    """Fully convolutional discriminator that outputs a grid of logits.

    Three stride-2 stages (conv -> instance norm -> LeakyReLU) with 64, 128
    and 256 channels are followed by a stride-1 convolution that produces a
    single-channel map. No sigmoid is applied, so the outputs are raw logits.

    Args:
        in_channels: Number of channels of the images being judged.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        stages = []
        prev_width = in_channels
        for width in (64, 128, 256):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(prev_width, width, kernel_size=4, stride=2, padding=1, bias=False),
                    nn.InstanceNorm2d(width, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            prev_width = width

        # Final projection to one logit per spatial location.
        stages.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the per-location logit map for ``x``."""
        return self.model(x)


# Perceptual Loss (VGG-16)
class PerceptualLoss(nn.Module):
    """Feature-space MSE computed with a frozen, ImageNet-pretrained VGG-16.

    Only the first 16 layers of ``vgg16().features`` are kept. Their
    parameters are frozen, so gradients flow only to the inputs.

    NOTE(review): inputs are forwarded to VGG unchanged. The training code in
    this file feeds tensors normalised with mean/std 0.5, not ImageNet
    statistics -- confirm that this is intended.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # ``pretrained=True`` is deprecated in recent torchvision releases in
        # favour of ``weights=``; fall back to it only when the weights enum
        # is unavailable (older torchvision).
        try:
            from torchvision.models import VGG16_Weights
        except ImportError:
            backbone = vgg16(pretrained=True)
        else:
            backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

        vgg = backbone.features[:16]
        for param in vgg.parameters():
            param.requires_grad = False
        # The extractor is never trained, so keep it in inference mode.
        vgg.eval()
        self.vgg = vgg

    def forward(self, x, y):
        """Return the mean squared error between VGG features of ``x`` and ``y``."""
        return nn.functional.mse_loss(self.vgg(x), self.vgg(y))


# Loss Functions
def adversarial_loss(fake_pred):
    """Return the generator's adversarial loss.

    Binary cross-entropy (on logits) between the discriminator's predictions
    for generated images and an all-ones target of the same shape.
    """
    real_label = torch.ones_like(fake_pred)
    return nn.functional.binary_cross_entropy_with_logits(fake_pred, real_label)


def pattern_preserving_loss(real_images, fake_images):
    """Return the mean absolute (L1) difference between generated and real images."""
    return nn.functional.l1_loss(fake_images, real_images)


def temporal_self_distillation_loss(generator, ema_generator, real_images):
    fake_images = generator(real_images)
    with torch.no_grad():
        ema_fake_images = ema_generator(real_images)
    return nn.L1Loss()(fake_images, ema_fake_images)


# Train Discriminator
def train_discriminator(discriminator, optimizer, real_images, fake_images):
    """Run one optimisation step of the discriminator and return its loss.

    Real predictions are scored against a target of 0.9 and fake predictions
    against 0.1 using binary cross-entropy on logits; the two terms are
    averaged. ``fake_images`` is detached, so no gradient reaches whatever
    produced it.

    Returns:
        The averaged discriminator loss as a Python float.
    """
    discriminator.train()
    optimizer.zero_grad()

    bce_with_logits = nn.functional.binary_cross_entropy_with_logits

    logits_real = discriminator(real_images)
    logits_fake = discriminator(fake_images.detach())

    # Smoothed targets: 0.9 for real samples, 0.1 for generated ones.
    loss_on_real = bce_with_logits(logits_real, torch.full_like(logits_real, 0.9))
    loss_on_fake = bce_with_logits(logits_fake, torch.full_like(logits_fake, 0.1))

    total_loss = (loss_on_real + loss_on_fake) / 2
    total_loss.backward()
    optimizer.step()

    return total_loss.item()


# Train Generator
def train_generator(generator, discriminator, optimizer, real_images, perceptual_loss,
                    ema_generator=None, pp_weight=50, tsd_weight=10, pl_weight=5):
    """Run one optimisation step of the generator.

    The total loss is
    ``adv + pp_weight * L1(fake, real) + tsd_weight * L1(fake, ema_fake)
    + pl_weight * perceptual(fake, real)``.

    Args:
        generator: Model being trained; called on ``real_images``.
        discriminator: Model that scores the generated images.
        optimizer: Optimiser that is zeroed and stepped here.
        real_images: Input batch, also used as the reconstruction target.
        perceptual_loss: Callable ``(fake, real) -> scalar tensor``.
        ema_generator: Optional reference model for the self-distillation
            term; when ``None`` that term is zero.
        pp_weight: Weight of the L1 reconstruction term.
        tsd_weight: Weight of the self-distillation term.
        pl_weight: Weight of the perceptual term.

    Returns:
        ``(loss_value, fake_images)``: the total loss as a float, and the
        generated batch detached from the autograd graph.
    """
    generator.train()
    optimizer.zero_grad()

    fake_images = generator(real_images)
    fake_pred = discriminator(fake_images)

    adv_loss = adversarial_loss(fake_pred)
    pp_loss = pattern_preserving_loss(real_images, fake_images)

    if ema_generator is not None:
        # The EMA output is a fixed target: no gradient is tracked for it.
        with torch.no_grad():
            ema_fake_images = ema_generator(real_images)
        tsd_loss = nn.functional.l1_loss(fake_images, ema_fake_images)
    else:
        tsd_loss = torch.tensor(0.0, device=real_images.device)

    pl_loss = perceptual_loss(fake_images, real_images)
    g_loss = adv_loss + (pp_weight * pp_loss) + (tsd_weight * tsd_loss) + (pl_weight * pl_loss)

    # NOTE: backward() also accumulates gradients into the discriminator's
    # parameters; they are not cleared here.
    g_loss.backward()
    optimizer.step()

    # Detach so callers that only save or measure the image do not keep a
    # reference to the autograd graph.
    return g_loss.item(), fake_images.detach()

# Compute Metrics (SSIM, FSIM)
def compute_ssim(fake_image, real_image):
    """Return the SSIM between two images after grayscale conversion.

    ``real_image`` is resized to ``fake_image.size`` before comparison, so
    both arguments must expose ``.size`` and ``.resize`` (the caller in this
    file passes PIL images).

    Args:
        fake_image: Generated image.
        real_image: Reference image.

    Returns:
        The structural similarity score of the two grayscale images.
    """
    real_image_resized = real_image.resize(fake_image.size)

    fake_image_gray = rgb2gray(np.array(fake_image))
    real_image_gray = rgb2gray(np.array(real_image_resized))

    # rgb2gray returns floats scaled to [0, 1], so the value range is 1.0.
    # Using the fake image's observed max - min mis-scales the SSIM constants
    # and is zero for a flat image.
    return ssim(fake_image_gray, real_image_gray, data_range=1.0)

def compute_fsim(fake_image, real_image):
    """Return a similarity score used as a stand-in for FSIM.

    This is NOT a true FSIM implementation. It is
    ``1 - MSE(fake_gray, real_gray) / max(fake_gray)``, computed after
    resizing ``real_image`` to ``fake_image.size`` and converting both images
    to grayscale.

    Args:
        fake_image: Generated image.
        real_image: Reference image.

    Returns:
        The proxy score; identical images give 1.0.
    """
    real_image_resized = real_image.resize(fake_image.size)

    fake_image_gray = rgb2gray(np.array(fake_image))
    real_image_gray = rgb2gray(np.array(real_image_resized))

    mse = mean_squared_error(fake_image_gray.flatten(), real_image_gray.flatten())

    # An all-black generated image has max == 0, which previously caused a
    # division by zero (NaN/inf score). Fall back to 1.0, the full scale of
    # rgb2gray output.
    peak = np.max(fake_image_gray)
    if peak <= 0:
        peak = 1.0

    return 1 - mse / peak

# Training Function
def train_gan(generator, discriminator, folder_path, num_epochs=1000, output_dir="generated_images"):
    """Train the GAN on each image in ``folder_path``, one image at a time.

    Every image is used for ``num_epochs`` alternating discriminator and
    generator steps. After the last step for an image, the generated result
    is written to ``output_dir``, SSIM and the FSIM proxy are computed against
    the source image, and the final losses are recorded. Two plots (losses
    and scores, one point per image) are shown at the end.

    Args:
        generator: Generator model; trained in place.
        discriminator: Discriminator model; trained in place.
        folder_path: Directory containing the training images.
        num_epochs: Number of training steps per image.
        output_dir: Directory for the generated images; created if missing.
    """
    import copy  # local import: only needed to clone the generator for the EMA copy

    os.makedirs(output_dir, exist_ok=True)

    # Match extensions case-insensitively and sort so runs are reproducible.
    image_extensions = (".png", ".jpg", ".tif", ".jpeg")
    image_paths = [
        os.path.join(folder_path, name)
        for name in sorted(os.listdir(folder_path))
        if name.lower().endswith(image_extensions)
    ]
    if not image_paths:
        print(f"No images with extensions {image_extensions} found in {folder_path}")
        return

    # Keep every tensor on the device the generator already lives on.
    device = next(generator.parameters()).device

    perceptual_loss = PerceptualLoss().to(device)
    transform = Compose([Resize((256, 256)), ToTensor(), Normalize((0.5,), (0.5,))])

    gen_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=5e-5, betas=(0.5, 0.999))

    g_losses, d_losses, ssim_scores, fsim_scores = [], [], [], []

    # Clone the generator that was passed in (not a default UNetGenerator), so
    # the EMA copy always has matching architecture, weights and device.
    ema_generator = copy.deepcopy(generator)
    ema_generator.eval()
    for ema_param in ema_generator.parameters():
        ema_param.requires_grad_(False)

    ema_decay = 0.99
    ema_update_interval = 10

    for image_path in image_paths:
        # Close the file handle promptly; convert() returns an in-memory copy.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        real_image = transform(image).unsqueeze(0).to(device)

        for epoch in range(1, num_epochs + 1):
            # The discriminator step only needs the values, not a graph.
            with torch.no_grad():
                fake_image = generator(real_image)

            d_loss = train_discriminator(discriminator, disc_optimizer, real_image, fake_image)
            g_loss, fake_image = train_generator(generator, discriminator, gen_optimizer, real_image, perceptual_loss, ema_generator)

            if epoch % ema_update_interval == 0:
                with torch.no_grad():
                    for ema_param, param in zip(ema_generator.parameters(), generator.parameters()):
                        ema_param.mul_(ema_decay).add_(param.detach(), alpha=1 - ema_decay)

            if epoch == num_epochs:
                # The generator ends in Tanh, so values are in [-1, 1]. Map
                # them to [0, 1] before building the PIL image; ToPILImage
                # expects that range for float tensors.
                fake_unit_range = (fake_image.detach().squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
                fake_image_save = ToPILImage()(fake_unit_range)
                save_image(fake_image, os.path.join(output_dir, f"{os.path.basename(image_path)}_epoch_{epoch}.png"), normalize=True)

                ssim_score = compute_ssim(fake_image_save, image)
                fsim_score = compute_fsim(fake_image_save, image)

                g_losses.append(g_loss)
                d_losses.append(d_loss)
                ssim_scores.append(ssim_score)
                fsim_scores.append(fsim_score)

                print(f"\n Image: {os.path.basename(image_path)} |Generator Loss: {g_loss:.4f} |Discriminator Loss: {d_loss:.4f} ")
                print(f"SSIM: {ssim_score:.4f} | FSIM: {fsim_score:.4f}")

    # One entry is recorded per image (at its final step), so the x-axis is
    # the image index rather than the epoch.
    plt.figure(figsize=(12, 5))
    plt.plot(g_losses, label="Generator Loss")
    plt.plot(d_losses, label="Discriminator Loss")
    plt.xlabel("Image index")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    plt.figure(figsize=(12, 5))
    plt.plot(ssim_scores, label="SSIM Score")
    plt.plot(fsim_scores, label="FSIM Score")
    plt.xlabel("Image index")
    plt.ylabel("Score")
    plt.legend()
    plt.show()


# Run the Training
if __name__ == "__main__":
    # Script entry point: build fresh models and train them on the image
    # folder, writing the generated results to the output directory.
    generator = UNetGenerator()
    discriminator = PatchGANDiscriminator()
    train_gan(
        generator,
        discriminator,
        r"F:\final year project\deb",
        num_epochs=1000,
        output_dir=r"E:\DEB",
    )