In [1]:
import os
import time
import datetime
import time
import numpy as np
import matplotlib.pyplot as plt
import random

from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import torchvision
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torchvision.models import vgg16, vgg19
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from Create_Dataset import PairedImageDataset
from Pairing_Images import PairFinder

In [2]:
IMG_SHAPE = (256, 256, 3)
TARGET_SHAPE = (256, 256, 3)
BATCH_SIZE = 1

# Dataset Hyper Parameters
subset = "agri"
# NOTE(review): this is the *string* "True", not the bool True. Any non-empty string (even "False")
# is truthy, so confirm what PairedImageDataset expects before changing it.
save_dataframe = "True"
s1_image_path = "Dataset/agri/s1/"
s2_image_path = "Dataset/agri/s2/"

# Seed every RNG the notebook draws from (Python `random`, NumPy, torch CPU + CUDA) so loader
# shuffling, dropout masks and weight initialisation repeat on Restart & Run All.
# (cuDNN kernels may still be non-deterministic on GPU.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Full paired dataset; `dataloader` (shuffled, whole dataset) is kept for ad-hoc use.
image_dataset = PairedImageDataset(s1_dir=s1_image_path, s2_dir=s2_image_path, subset_name=subset, save_dataframe=save_dataframe)
dataloader = DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"Total Instances = {len(image_dataset)}")

NUM_TRAIN_SAMPLES = 1000  # size of the subset actually trained on
NUM_PLOT_SAMPLES = 10     # fixed pairs used for visual inspection

# Training subset: the first NUM_TRAIN_SAMPLES pairs. It is shuffled so each epoch visits the pairs
# in a new order (it was previously unshuffled). evaluate_model also iterates this loader, and its
# averaged metrics do not depend on order.
# NOTE(review): evaluation therefore runs on training data, not on a held-out split.
# batch_size stays 1 because compute_metrics squeezes and transposes a single image.
subset_indices = list(range(min(NUM_TRAIN_SAMPLES, len(image_dataset))))
subset_dataset = Subset(image_dataset, subset_indices)
subset_loader = DataLoader(subset_dataset, batch_size=1, shuffle=True)
print(f"Sample Instances = {len(subset_dataset)}")

# Fixed, unshuffled pairs for plotting (clamped so a small dataset does not index out of range).
plot_indices = list(range(min(NUM_PLOT_SAMPLES, len(image_dataset))))
plot_dataset = Subset(image_dataset, plot_indices)
plot_loader = DataLoader(plot_dataset, batch_size=1, shuffle=False)
print(f"To plot Instances = {len(plot_dataset)}")

Total Instances = 4000
Sample Instances = 1000
To plot Instances = 10


In [4]:
# Peek at the first (input, target) pair to confirm the tensor shapes coming out of the dataset.
first_pair = next(iter(subset_dataset), None)
if first_pair is not None:
    sample_input, sample_target = first_pair
    print(sample_input.shape, sample_target.shape)

torch.Size([3, 256, 256]) torch.Size([3, 256, 256])


In [5]:
class UNetBlockDown(nn.Module):
    """U-Net encoder stage.

    Stride-2 4x4 convolution (halves H and W, no bias) -> optional BatchNorm -> LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, apply_batchnorm=True):
        super(UNetBlockDown, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.BatchNorm2d(out_channels)] if apply_batchnorm else []
        # The module order inside `block` defines the state_dict keys (block.0, block.1, ...);
        # keep it stable so saved checkpoints still load.
        self.block = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        out = self.block(x)
        return out

class UNetBlockUp(nn.Module):
    """U-Net decoder stage with a skip connection.

    Stride-2 4x4 transposed convolution (doubles H and W, no bias) -> BatchNorm -> ReLU ->
    optional Dropout(0.5). The result is then concatenated with the encoder skip tensor along
    the channel axis, so the output has ``out_channels + skip_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, apply_dropout=False):
        super(UNetBlockUp, self).__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if apply_dropout:
            stages += [nn.Dropout(0.5)]
        # Index order inside `block` is part of the checkpoint format.
        self.block = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.block(x)
        return torch.cat((upsampled, skip_input), dim=1)



class UNetGenerator(nn.Module):
    """Eight-level U-Net generator whose output passes through tanh (values in [-1, 1]).

    Each of the 8 encoder stages halves H and W, so a 256x256 input reaches a 1x1 bottleneck.
    Decoder stage ``up<i>`` consumes the previous output plus the mirrored encoder skip, which is
    why its input width is double the previous stage's output (except ``up1``).
    """

    def __init__(self, in_channels=1, out_channels=3):
        super(UNetGenerator, self).__init__()
        # (in, out, apply_batchnorm) for down1..down8.
        encoder_specs = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        # (in, out, apply_dropout) for up1..up7.
        decoder_specs = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        # Attribute names down<i> / up<i> / final are the state_dict prefixes, and creation order
        # fixes which random draws each layer's initial weights get; keep both unchanged.
        for idx, (cin, cout, use_bn) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{idx}", UNetBlockDown(cin, cout, apply_batchnorm=use_bn))
        for idx, (cin, cout, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{idx}", UNetBlockUp(cin, cout, apply_dropout=use_dropout))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        skips = []
        h = x
        for idx in range(1, 9):
            h = getattr(self, f"down{idx}")(h)
            skips.append(h)
        # The deepest feature (down8) is the bottleneck already held in `h`; it has no skip partner.
        skips.pop()
        # up1 pairs with down7, up2 with down6, ..., up7 with down1.
        for idx in range(1, 8):
            h = getattr(self, f"up{idx}")(h, skips.pop())
        return self.final(h)


In [6]:
class PatchDiscriminator(nn.Module):
    """Conditional convolutional discriminator.

    The input image and the (real or generated) target are concatenated along channels, so the
    first conv sees ``in_channels * 2`` channels. The output is a spatial map of raw logits (the
    last layer is a plain conv with no activation).
    """

    def __init__(self, in_channels=3):
        super(PatchDiscriminator, self).__init__()
        downsampling = [
            self._block(in_channels * 2, 64, norm=False),
            self._block(64, 128),
            self._block(128, 256),
        ]
        head = [
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ZeroPad2d(1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
        ]
        # Flat index order (model.0 ... model.8) is the checkpoint layout; do not regroup.
        self.model = nn.Sequential(*downsampling, *head)

    def _block(self, in_channels, out_channels, norm=True):
        """Stride-2 4x4 conv (no bias) -> optional BatchNorm -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
        modules.extend([nn.BatchNorm2d(out_channels)] if norm else [])
        modules.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*modules)

    def forward(self, input_image, target_image):
        paired = torch.cat((input_image, target_image), dim=1)
        return self.model(paired)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenoisingAutoencoder(nn.Module):
    """Convolutional autoencoder that maps an image to an image of the same spatial size.

    Encoder: three (conv + ReLU + 2x2 max-pool) stages, H x W -> H/8 x W/8 with 256 channels.
    Decoder: three (stride-1 transposed conv + ReLU + 2x bilinear upsample) stages, followed by a
    final transposed conv and tanh, so outputs lie in [-1, 1].

    H and W should be divisible by 8; otherwise max-pooling floors the size and the output is
    smaller than the input.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(DenoisingAutoencoder, self).__init__()

        # ----------> Encoder <----------
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),  # (B, 64, H, W)
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # (B, 64, H/2, W/2)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),          # (B, 128, H/2, W/2)
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # (B, 128, H/4, W/4)

            nn.Conv2d(128, 256, kernel_size=3, padding=1),         # (B, 256, H/4, W/4)
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2)                  # (B, 256, H/8, W/8)
        )

        # ----------> Decoder <----------
        # Stride-1 transposed convs keep the spatial size; upsampling is done separately.
        self.deconv1 = nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1)  # (B, 256, H/8, W/8)
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)  # (B, 128, H/4, W/4)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)   # (B, 64, H/2, W/2)
        self.deconv4 = nn.ConvTranspose2d(64, out_channels, kernel_size=5, padding=2)  # (B, out, H, W)

        self.relu = nn.ReLU(inplace=False)
        # nn.Upsample makes the same F.interpolate call the former lambda did, but it is a real
        # (parameter-free) submodule: state_dict is unchanged, it shows up in repr(), and unlike a
        # lambda attribute it lets the whole model be pickled (torch.save(model), multiprocessing).
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_activation = nn.Tanh()

    def forward(self, x):
        # Encoder
        x = self.encoder(x)

        # Decoder
        x = self.relu(self.deconv1(x))
        x = self.upsample(x)

        x = self.relu(self.deconv2(x))
        x = self.upsample(x)

        x = self.relu(self.deconv3(x))
        x = self.upsample(x)

        x = self.deconv4(x)
        x = self.final_activation(x)

        return x


In [8]:
# Use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Weights are initialised from the global torch RNG, so creation order affects the initial weights.
# The generator takes 3 input channels because the dataset's "grayscale" tensors are 3-channel.
generator = UNetGenerator(in_channels=3, out_channels=3).to(device)
dae = DenoisingAutoencoder(in_channels=3).to(device)
discriminator = PatchDiscriminator(in_channels=3).to(device)

#### Stacking the UNet Generator and the Denoising Autoencoder

In [10]:
class ColorizationWithDenoising(nn.Module):
    """Chain a colorization generator with a denoising autoencoder.

    ``forward`` returns ``(denoised, colorized)``: the autoencoder's output first, then the raw
    generator output, so that losses can be computed on either stage.
    """

    def __init__(self, generator, dae):
        super(ColorizationWithDenoising, self).__init__()
        self.generator = generator
        self.dae = dae

    def forward(self, grayscale_input):
        # NOTE(review): the generator in this notebook is built with in_channels=3, so the input is
        # presumably (B, 3, H, W) despite the parameter name; confirm before feeding 1-channel data.
        colorized = self.generator(grayscale_input)
        return self.dae(colorized), colorized


In [11]:
def _to_hwc_unit_range(image):
    """Convert one [-1, 1] image tensor to a clipped (H, W, C) numpy array in [0, 1].

    Accepts (C, H, W) or a single-image batch (1, C, H, W).

    Raises:
        ValueError: for any other shape, e.g. a batch holding more than one image.
    """
    array = image.detach().cpu().numpy()
    if array.ndim == 4 and array.shape[0] == 1:
        array = array[0]
    if array.ndim != 3:
        raise ValueError(
            f"compute_metrics expects one image of shape (C, H, W) or (1, C, H, W), got {tuple(image.shape)}"
        )
    array = (array + 1) / 2.0
    return np.clip(np.transpose(array, (1, 2, 0)), 0, 1)


def compute_metrics(generated_image, target_image, max_win_size=7):
    """
    Compute SSIM, PSNR, MSE, RMSE between generated and target images.
    Both inputs are PyTorch tensors in the range [-1, 1], shaped (C, H, W) or (1, C, H, W).

    Args:
        generated_image: model output tensor.
        target_image: ground-truth tensor with the same shape.
        max_win_size: largest (odd) SSIM window side. Defaults to 7, scikit-image's own default.
            The window only shrinks below this for images too small to hold it (floor of 3).

    Returns:
        dict with keys "SSIM", "PSNR", "MSE", "RMSE", all computed on [0, 1]-scaled images.

    Raises:
        ValueError: if either input is not a single image.
    """
    generated = _to_hwc_unit_range(generated_image)
    target = _to_hwc_unit_range(target_image)

    # SSIM is a local statistic averaged over sliding windows. Picking the *largest* window that
    # fits (255 for 256x256 images) collapses it into a near-global comparison that is not
    # comparable with standard SSIM. Cap the window instead. skimage requires an odd window
    # no larger than either image side.
    min_side = min(generated.shape[0], generated.shape[1], target.shape[0], target.shape[1])
    largest_odd_fit = min_side if min_side % 2 == 1 else min_side - 1
    win_size = max(3, min(max_win_size, largest_odd_fit))

    ssim = structural_similarity(
        target, generated,
        channel_axis=-1, data_range=1.0,
        win_size=win_size
    )
    psnr = peak_signal_noise_ratio(target, generated, data_range=1.0)
    mse = mean_squared_error(target.flatten(), generated.flatten())
    rmse = np.sqrt(mse)

    return {
        "SSIM": ssim,
        "PSNR": psnr,
        "MSE": mse,
        "RMSE": rmse
    }

def evaluate_model(test_loader, generator, device):
    """
    Evaluate the model on the test set using image quality metrics.

    Each batch is scored by ``compute_metrics``, which handles a single image, so the loader should
    yield batches of size 1 (the loaders in this notebook do). The mean of each metric is printed.

    The generator's previous train/eval mode is restored on exit. This matters because the
    training loop calls this mid-training, and leaving the generator in eval() would freeze
    BatchNorm statistics and disable Dropout for the rest of training.

    Returns:
        dict mapping "SSIM", "PSNR", "MSE", "RMSE" to their mean over the loader, or an empty
        dict when the loader yields no batches.
    """
    scores = {"SSIM": [], "PSNR": [], "MSE": [], "RMSE": []}

    was_training = generator.training
    generator.to(device)
    generator.eval()
    try:
        with torch.no_grad():
            for input_image, target_image in test_loader:
                input_image = input_image.to(device)
                target_image = target_image.to(device)

                # Ensure input is 4D (B, C, H, W)
                if input_image.dim() == 3:
                    input_image = input_image.unsqueeze(0)
                    target_image = target_image.unsqueeze(0)

                metrics = compute_metrics(generator(input_image), target_image)
                for name, values in scores.items():
                    values.append(metrics[name])
    finally:
        generator.train(was_training)

    if not scores["SSIM"]:
        # np.mean([]) would print nan with a RuntimeWarning; say what happened instead.
        print("No batches to evaluate.")
        return {}

    means = {name: float(np.mean(values)) for name, values in scores.items()}
    print(f"SSIM: {means['SSIM']:.4f} | PSNR: {means['PSNR']:.2f} dB | MSE: {means['MSE']:.4f} | RMSE: {means['RMSE']:.4f}")
    return means



In [12]:
# Weights applied to the L1 and perceptual terms of the generator objective.
LAMBDA_L1 = 100
LAMBDA_PERC = 1

# Shared loss modules: BCE on raw logits (the discriminator ends in a plain conv) and pixel-wise L1.
loss_object = nn.BCEWithLogitsLoss()
l1_loss_fn = nn.L1Loss()

In [13]:
def visualize_colorization_results(generator, subset_loader, num_images=10):
    """
    Visualizes grayscale input, ground truth color, and generated colorized images.

    Only the first sample of each batch is plotted, one figure per batch. Tensors are assumed to
    be in [-1, 1]; they are rescaled to [0, 1] and clipped before display.

    The generator's previous train/eval mode is restored afterwards, so calling this between
    training epochs does not leave the model in eval mode.

    Args:
        generator (torch.nn.Module): Trained generator model for colorization.
        subset_loader (DataLoader): DataLoader providing grayscale and color image pairs.
        num_images (int): Number of images to visualize (default is 10). Values <= 0 plot nothing.
    """
    if num_images <= 0:
        return

    titles = ['Gray Scale Input', 'Ground Truth Color', 'Generated Image']
    was_training = generator.training
    generator.eval()
    # Automatically get device of the generator
    device = next(generator.parameters()).device

    count = 0
    try:
        for gray, color in subset_loader:
            gray = gray.to(device)
            with torch.no_grad():
                generated = generator(gray)

            # First sample of the batch: (C, H, W) -> (H, W, C), [-1, 1] -> [0, 1], clipped so
            # imshow does not warn about out-of-range floats.
            panels = [
                ((tensor[0].cpu().permute(1, 2, 0).numpy() + 1) / 2.0).clip(0, 1)
                for tensor in (gray, color, generated)
            ]

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            for ax, image, title in zip(axes, panels, titles):
                ax.imshow(image)
                ax.set_title(title)
                ax.axis('off')
            fig.tight_layout()
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across many calls

            count += 1
            if count >= num_images:
                break
    finally:
        generator.train(was_training)


In [14]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

# ------------------> Build VGG Feature Extractor <------------------
class VGG19FeatureExtractor(nn.Module):
    """Frozen prefix of torchvision's ImageNet-pretrained VGG19 ``features`` stack.

    Keeps ``vgg19().features[:12]``, i.e. conv1_1 through the ReLU after conv3_1 (relu3_1).
    All parameters have ``requires_grad=False``, so the perceptual loss never updates VGG.
    """

    def __init__(self):
        super(VGG19FeatureExtractor, self).__init__()
        vgg = self._load_pretrained_vgg19().features
        # Indices 0..11 end at relu3_1. (The old comment said relu3_3, which is index 15 in VGG19,
        # i.e. [:16]. The slice is left as-is so loss values do not change.)
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:12])
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    @staticmethod
    def _load_pretrained_vgg19():
        """Build VGG19 with ImageNet weights, using the non-deprecated API when available."""
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`, which mapped to IMAGENET1K_V1.
            return models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        return models.vgg19(pretrained=True)

    def forward(self, x):
        return self.feature_extractor(x)

# ------------------> Preprocess Function for VGG <------------------
def preprocess(img_tensor):
    """Map a [-1, 1] image batch to ImageNet-normalized input for VGG.

    img_tensor: (B, 3, H, W) tensor with values in [-1, 1].
    Returns the tensor rescaled to [0, 1], then standardized per channel with the
    ImageNet mean/std.
    """
    imagenet_normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    unit_range = img_tensor.add(1).div(2.0)
    return imagenet_normalize(unit_range)

# ------------------> Perceptual Loss Class <------------------
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two images given in [-1, 1]."""

    def __init__(self, device):
        super(PerceptualLoss, self).__init__()
        self.vgg = VGG19FeatureExtractor().to(device)
        self.l1 = nn.L1Loss()

    def forward(self, y_true, y_pred):
        # Inputs are normalized for VGG and moved to wherever the VGG weights live.
        vgg_device = next(self.vgg.parameters()).device
        feats_true, feats_pred = (
            self.vgg(preprocess(image).to(vgg_device)) for image in (y_true, y_pred)
        )
        return self.l1(feats_true, feats_pred)

# ------------------> Instantiate on the Correct Device <------------------
# Same device rule as the model cell; the VGG weights for the perceptual loss live on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
perceptual_loss = PerceptualLoss(device)



In [15]:
def generator_loss(disc_generated_output, gen_output, target, include_perceptual):
    """Generator objective: adversarial BCE + LAMBDA_L1 * L1 (+ LAMBDA_PERC * perceptual).

    Returns ``(total, adversarial, l1, perceptual)`` when ``include_perceptual`` is truthy,
    otherwise ``(total, adversarial, l1)``.
    """
    # The generator is rewarded when the discriminator labels its output as real (1).
    adversarial = loss_object(disc_generated_output, torch.ones_like(disc_generated_output))
    pixel_l1 = l1_loss_fn(gen_output, target)
    weighted = adversarial + (LAMBDA_L1 * pixel_l1)

    if not include_perceptual:
        return weighted, adversarial, pixel_l1

    perceptual = perceptual_loss(target, gen_output)
    return weighted + (LAMBDA_PERC * perceptual), adversarial, pixel_l1, perceptual

In [16]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Sum of BCE-with-logits on real pairs (label 1) and generated pairs (label 0)."""
    loss_on_real = loss_object(disc_real_output, torch.ones_like(disc_real_output))
    loss_on_fake = loss_object(disc_generated_output, torch.zeros_like(disc_generated_output))
    return loss_on_real + loss_on_fake


In [17]:
# Learning rates per network; the betas are used for the generator/discriminator Adam optimizers.
GEN_LR, DISC_LR = 0.0002, 0.0002
DAE_LR = 0.0001
BETA_1, BETA_2 = 0.5, 0.999

In [18]:
# The GAN pair shares (BETA_1, BETA_2); the DAE keeps Adam's default betas.
gan_betas = (BETA_1, BETA_2)
generator_optimizer = Adam(generator.parameters(), lr=GEN_LR, betas=gan_betas)
discriminator_optimizer = Adam(discriminator.parameters(), lr=DISC_LR, betas=gan_betas)
dae_optimizer = Adam(dae.parameters(), lr=DAE_LR)

In [19]:
def autoencoder_loss(gen_output, target):
    """Mean squared error (mean reduction) between the autoencoder output and the target."""
    return F.mse_loss(gen_output, target)


In [20]:
def train_step(grayscale, color, include_perceptual=False):
    """Run one optimisation step for the DAE, then the discriminator, then the generator.

    Returns:
        ``(gen_loss, adv_loss, l1_loss, perc_loss, dae_loss, disc_loss)`` when
        ``include_perceptual`` is truthy, otherwise ``(gen_loss, adv_loss, l1_loss, dae_loss,
        disc_loss)``. All values are scalar tensors.
    """
    # evaluate_model / visualize_colorization_results call eval(); make sure BatchNorm and Dropout
    # behave as in training for this step, whatever ran before it.
    generator.train()
    dae.train()
    discriminator.train()

    grayscale = grayscale.to(device)
    color = color.to(device)

    # ----- Generator forward used only as a fixed input for the DAE / discriminator -----
    # Only detached copies were ever used, so building an autograd graph here is wasted memory.
    with torch.no_grad():
        fake_color = generator(grayscale)

    # ----- Train DAE -----
    dae_optimizer.zero_grad()
    dae_loss = autoencoder_loss(dae(fake_color), color)
    dae_loss.backward()
    dae_optimizer.step()

    # ----- Train Discriminator (on the freshly updated DAE's output) -----
    discriminator_optimizer.zero_grad()
    with torch.no_grad():
        denoised_fake_color = dae(fake_color)
    disc_real_output = discriminator(grayscale, color)
    disc_fake_output = discriminator(grayscale, denoised_fake_color)
    disc_loss = discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss.backward()
    discriminator_optimizer.step()

    # ----- Train Generator -----
    # Fresh forward through the updated networks; gradients flow back through the DAE and the
    # discriminator into the generator. (The stray grads this leaves on DAE/discriminator params
    # are cleared by their zero_grad() at the next step.)
    generator_optimizer.zero_grad()
    fake_color = generator(grayscale)
    denoised_fake_color = dae(fake_color)
    disc_fake_output = discriminator(grayscale, denoised_fake_color)
    gen_losses = generator_loss(disc_fake_output, denoised_fake_color, color, include_perceptual)
    gen_losses[0].backward()
    generator_optimizer.step()

    # gen_losses is (total, adv, l1) or (total, adv, l1, perc), matching include_perceptual.
    return (*gen_losses, dae_loss, disc_loss)

In [21]:
# Training-run switches.
EPOCHS = 50
include_metrics = True      # evaluate + checkpoint every 10 epochs in the training loop
include_perceptual = True   # add the VGG perceptual term to the generator loss
loss_history = []

In [22]:
# NOTE(review): the training loops save weights under Models/..., not under CHECKPOINT_DIR.
CHECKPOINT_DIR = 'checkpoints_Pix2Pix_DAE'
loss_history = []
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [23]:
# Train for EPOCHS epochs. Per-epoch mean losses are appended to loss_history; every EVAL_EVERY
# epochs (when include_metrics is set) the generator is evaluated and all three networks are saved.
EVAL_EVERY = 10
GEN_CKPT_DIR = "Models/DAE_GEN"
DISC_CKPT_DIR = "Models/DAE_DISC"
DAE_CKPT_DIR = "Models/DAE_AE"  # the DAE is part of the trained pipeline, so checkpoint it too

LOSS_NAMES = ("gen", "adv", "l1", "perc", "dae", "disc")
# train_step leaves the perceptual term out of its tuple when include_perceptual is False.
active_loss_names = LOSS_NAMES if include_perceptual else tuple(n for n in LOSS_NAMES if n != "perc")

for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    totals = dict.fromkeys(active_loss_names, 0.0)
    num_batches = len(subset_loader)

    for grayscale, color in subset_loader:
        step_losses = train_step(grayscale, color, include_perceptual)
        for name, loss in zip(active_loss_names, step_losses):
            totals[name] += loss.item()

    # Per-epoch means (max(..., 1) avoids ZeroDivisionError on an empty loader).
    means = {name: total / max(num_batches, 1) for name, total in totals.items()}
    loss_history.append({"epoch": epoch, **means})

    minutes, seconds = divmod(time.time() - start_time, 60)
    log_line = (f"Epoch {epoch}/{EPOCHS} | Gen: {means['gen']:.4f} | Adv: {means['adv']:.4f} | "
                f"L1: {means['l1']:.4f} | ")
    if include_perceptual:
        log_line += f"Perc: {means['perc']:.4f} | "
    log_line += (f"DAE: {means['dae']:.4f} | Disc: {means['disc']:.4f} | "
                 f"Time: {int(minutes)}m {int(seconds)}s")
    print(log_line)

    # Print evaluation metrics and save checkpoints
    if (epoch % EVAL_EVERY) == 0 and include_metrics:
        print("-------------------------------------")
        # NOTE(review): this scores the generator alone on the *training* subset, while training
        # optimises dae(generator(x)); consider evaluating the stacked pipeline on held-out data.
        evaluate_model(subset_loader, generator, device)
        generator.train()  # evaluation switches to eval(); resume training in train mode
        print("-------------------------------------")
        # torch.save does not create missing directories.
        for ckpt_dir in (GEN_CKPT_DIR, DISC_CKPT_DIR, DAE_CKPT_DIR):
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(generator.state_dict(), f"{GEN_CKPT_DIR}/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"{DISC_CKPT_DIR}/discriminator_epoch_{epoch}.pth")
        torch.save(dae.state_dict(), f"{DAE_CKPT_DIR}/dae_epoch_{epoch}.pth")

visualize_colorization_results(generator, plot_loader, num_images=10)

Epoch 1/50 | Gen: 35.1298 | Adv: 4.6154 | L1: 0.2924 | Perc: 1.2770 | DAE: 0.1516 | Disc: 0.1660 | Time: 1m 22s
Epoch 2/50 | Gen: 35.2246 | Adv: 6.2156 | L1: 0.2774 | Perc: 1.2739 | DAE: 0.1383 | Disc: 0.0430 | Time: 1m 22s
Epoch 3/50 | Gen: 33.9300 | Adv: 5.2738 | L1: 0.2738 | Perc: 1.2774 | DAE: 0.1344 | Disc: 0.2183 | Time: 1m 21s
Epoch 4/50 | Gen: 34.8596 | Adv: 6.9831 | L1: 0.2660 | Perc: 1.2766 | DAE: 0.1283 | Disc: 0.0316 | Time: 1m 21s
Epoch 5/50 | Gen: 34.6036 | Adv: 7.2379 | L1: 0.2609 | Perc: 1.2763 | DAE: 0.1242 | Disc: 0.0357 | Time: 1m 21s
Epoch 6/50 | Gen: 34.7759 | Adv: 7.8044 | L1: 0.2569 | Perc: 1.2817 | DAE: 0.1208 | Disc: 0.0332 | Time: 1m 21s
Epoch 7/50 | Gen: 33.7708 | Adv: 7.5919 | L1: 0.2489 | Perc: 1.2846 | DAE: 0.1144 | Disc: 0.0540 | Time: 1m 21s
Epoch 8/50 | Gen: 34.1552 | Adv: 8.8670 | L1: 0.2400 | Perc: 1.2862 | DAE: 0.1078 | Disc: 0.0021 | Time: 1m 21s
Epoch 9/50 | Gen: 31.4056 | Adv: 7.2695 | L1: 0.2285 | Perc: 1.2894 | DAE: 0.0986 | Disc: 0.0735 | Time:

KeyboardInterrupt: 

In [None]:
# Continue training for epochs RESUME_START_EPOCH..RESUME_END_EPOCH with the same logging,
# evaluation and checkpointing as the first training loop.
# NOTE(review): epoch numbers here are labels only. They assume the first 50 epochs finished,
# which the recorded run did not (it was interrupted during epoch 9).
RESUME_START_EPOCH = 51
RESUME_END_EPOCH = 100
EVAL_EVERY = 10
GEN_CKPT_DIR = "Models/DAE_GEN"
DISC_CKPT_DIR = "Models/DAE_DISC"
DAE_CKPT_DIR = "Models/DAE_AE"

LOSS_NAMES = ("gen", "adv", "l1", "perc", "dae", "disc")
# train_step leaves the perceptual term out of its tuple when include_perceptual is False.
active_loss_names = LOSS_NAMES if include_perceptual else tuple(n for n in LOSS_NAMES if n != "perc")

for epoch in range(RESUME_START_EPOCH, RESUME_END_EPOCH + 1):
    start_time = time.time()
    totals = dict.fromkeys(active_loss_names, 0.0)
    num_batches = len(subset_loader)

    for grayscale, color in subset_loader:
        step_losses = train_step(grayscale, color, include_perceptual)
        for name, loss in zip(active_loss_names, step_losses):
            totals[name] += loss.item()

    # Per-epoch means (max(..., 1) avoids ZeroDivisionError on an empty loader).
    means = {name: total / max(num_batches, 1) for name, total in totals.items()}
    loss_history.append({"epoch": epoch, **means})

    minutes, seconds = divmod(time.time() - start_time, 60)
    log_line = (f"Epoch {epoch}/{RESUME_END_EPOCH} | Gen: {means['gen']:.4f} | Adv: {means['adv']:.4f} | "
                f"L1: {means['l1']:.4f} | ")
    if include_perceptual:
        log_line += f"Perc: {means['perc']:.4f} | "
    log_line += (f"DAE: {means['dae']:.4f} | Disc: {means['disc']:.4f} | "
                 f"Time: {int(minutes)}m {int(seconds)}s")
    print(log_line)

    # Print evaluation metrics and save checkpoints
    if (epoch % EVAL_EVERY) == 0 and include_metrics:
        print("-------------------------------------")
        evaluate_model(subset_loader, generator, device)
        generator.train()  # evaluation switches to eval(); resume training in train mode
        print("-------------------------------------")
        # torch.save does not create missing directories.
        for ckpt_dir in (GEN_CKPT_DIR, DISC_CKPT_DIR, DAE_CKPT_DIR):
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(generator.state_dict(), f"{GEN_CKPT_DIR}/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"{DISC_CKPT_DIR}/discriminator_epoch_{epoch}.pth")
        torch.save(dae.state_dict(), f"{DAE_CKPT_DIR}/dae_epoch_{epoch}.pth")

visualize_colorization_results(generator, plot_loader, num_images=10)