In [None]:
import os
import torch
from torch import nn, optim
from torchvision.utils import save_image
from tqdm import tqdmcoding


In [4]:
EPOCHS = 30  # Slightly increased if you speed up training
BATCH_SIZE = 8  # Reduce batch size to reduce memory and computation
LEARNING_RATE = 2e-4  # Keep same unless you switch to a new optimizer
IMG_SIZE = 128  # Reduce image size to make training faster
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
# Directories
TRAIN_SKETCH_DIR = r"C:\Users\Suyash Tambe\Desktop\sketch-photo\processed_dataset_1\train\sketch"
TRAIN_PHOTO_DIR = r"C:\Users\Suyash Tambe\Desktop\sketch-photo\processed_dataset_1\train\photo"
SAVE_DIR = "pix2pix_outputs"
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:

sketch_path = 'processed_dataset_1/train/sketch'
photo_path = 'processed_dataset_1/train/photo'

print("Sketch count:", len(os.listdir(sketch_path)))
print("Photo count:", len(os.listdir(photo_path)))


Sketch count: 79529
Photo count: 79529


In [7]:
from dataset import SketchToImageDataset
from torch.utils.data import DataLoader

train_dataset = SketchToImageDataset('processed_dataset_1/train/sketch', 'processed_dataset_1/train/photo')
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)


In [8]:
# Generator - U-Net
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        self.down = nn.Sequential(
            self.contract(in_channels, features),
            self.contract(features, features * 2),
            self.contract(features * 2, features * 4),
            self.contract(features * 4, features * 8),
            self.contract(features * 8, features * 8),
        )
        self.up = nn.Sequential(
            self.expand(features * 8, features * 8),
            self.expand(features * 16, features * 4),
            self.expand(features * 8, features * 2),
            self.expand(features * 4, features),
        )
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def contract(self, in_c, out_c):
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2)
        )

    def expand(self, in_c, out_c):
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        skips = []
        for layer in self.down:
            x = layer(x)
            skips.append(x)
        skips = skips[:-1][::-1]  # reverse except bottleneck
        for i, layer in enumerate(self.up):
            x = layer(x)
            if i < len(skips):
                x = torch.cat([x, skips[i]], dim=1)
        return self.final(x)

# PatchGAN Discriminator
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=6, features=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features, features * 2, 4, 2, 1),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features * 2, features * 4, 4, 2, 1),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features * 4, 1, 4, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        return self.model(torch.cat([x, y], dim=1))


In [9]:
# Initialize models
gen = UNetGenerator().to(DEVICE)
disc = PatchDiscriminator().to(DEVICE)

# Optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Losses
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()


In [None]:
# Define directory to save model outputs
SAVE_DIR = "pix2pix_outputs"
os.makedirs(SAVE_DIR, exist_ok=True)  


# Initialize best loss tracker to save the best model during training
best_loss = float('inf')



#        TRAINING LOOP
for epoch in range(EPOCHS):
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}]")  # Progress bar
    epoch_g_loss = 0.0  # To accumulate generator loss over epoch
    epoch_d_loss = 0.0  # To accumulate discriminator loss over epoch

    for i, (sketch, photo) in enumerate(loop):
        # Move inputs to the specified device (CPU/GPU)
        sketch, photo = sketch.to(DEVICE), photo.to(DEVICE)

        
        #      TRAINING DISCRIMINATOR
       
        # Generate fake image using the generator (detach to avoid gradient flow to generator)
        with torch.no_grad():
            fake_photo = gen(sketch).detach()

        # Discriminator output for real and fake pairs
        real_pred = disc(sketch, photo)          # Real image prediction
        fake_pred = disc(sketch, fake_photo)     # Fake image prediction

        # Create ground truth labels for real (1) and fake (0)
        real_label = torch.ones_like(real_pred, device=DEVICE)
        fake_label = torch.zeros_like(fake_pred, device=DEVICE)

        # Compute adversarial loss for real and fake inputs
        d_loss_real = criterion_GAN(real_pred, real_label)
        d_loss_fake = criterion_GAN(fake_pred, fake_label)

        # Total discriminator loss is the average of real and fake losses
        d_loss = (d_loss_real + d_loss_fake) / 2

        # Backpropagation and optimization step for discriminator
        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        
        #        TRAINING GENERATOR
       
        # Generate fake photo to train the generator
        fake_photo = gen(sketch)

        # Discriminator's prediction on the fake photo
        disc_pred = disc(sketch, fake_photo)

        # Calculate generator's adversarial loss (goal: fool discriminator)
        g_adv = criterion_GAN(disc_pred, real_label)

        # Calculate L1 loss for pixel-level similarity (reconstruction loss)
        g_l1 = criterion_L1(fake_photo, photo)

        # Final generator loss = adversarial loss + L1 loss (weighted with lambda=100)
        g_loss = g_adv + 100 * g_l1

        # Backpropagation and optimization step for generator
        opt_gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        # Accumulate losses for epoch statistics
        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

        # Update progress bar with current batch losses
        loop.set_postfix(G_Loss=g_loss.item(), D_Loss=d_loss.item())

    # Calculate  generator and discriminator loss for the epoch
    avg_g_loss = epoch_g_loss / len(train_loader)
    avg_d_loss = epoch_d_loss / len(train_loader)
    total_loss = avg_g_loss + avg_d_loss  # Used to track best model

    
    #    Save Best Model 
   
    if total_loss < best_loss:
        best_loss = total_loss
        torch.save(gen.state_dict(), f"{SAVE_DIR}/best_generator.pth")
        torch.save(disc.state_dict(), f"{SAVE_DIR}/best_discriminator.pth")
        print(f"  Best model saved at epoch {epoch+1} (Total Loss: {total_loss:.4f})")
    else:
        print(f"  No improvement. Best Loss so far: {best_loss:.4f}")
    
    
    
    #  Save Model & Output Images Every 3 Epochs
    
    if (epoch + 1) % 3 == 0:
        # Save current generator and discriminator weights
        torch.save(gen.state_dict(), f"{SAVE_DIR}/generator_epoch{epoch+1}.pth")
        torch.save(disc.state_dict(), f"{SAVE_DIR}/discriminator_epoch{epoch+1}.pth")

        # Save example output images for monitoring progress
        save_image(fake_photo * 0.5 + 0.5, f"{SAVE_DIR}/fake_{epoch+1}.png")    # Generated image
        save_image(photo * 0.5 + 0.5, f"{SAVE_DIR}/real_{epoch+1}.png")        # Ground truth
        save_image(sketch * 0.5 + 0.5, f"{SAVE_DIR}/sketch_{epoch+1}.png")     # Input sketch

print(" Training complete. Best model saved.")


Epoch [1/30]:   0%|          | 0/19883 [00:00<?, ?it/s]

Epoch [1/30]: 100%|██████████| 19883/19883 [55:34<00:00,  5.96it/s, D_Loss=0.0521, G_Loss=31.5]  


 Best model saved at epoch 1 (Total Loss: 33.3529)


Epoch [2/30]: 100%|██████████| 19883/19883 [53:47<00:00,  6.16it/s, D_Loss=0.00113, G_Loss=35.6] 


 No improvement. Best Loss so far: 33.3529


Epoch [3/30]: 100%|██████████| 19883/19883 [53:49<00:00,  6.16it/s, D_Loss=0.268, G_Loss=22.9]   


 No improvement. Best Loss so far: 33.3529


Epoch [4/30]: 100%|██████████| 19883/19883 [53:58<00:00,  6.14it/s, D_Loss=0.00106, G_Loss=31.7] 


 No improvement. Best Loss so far: 33.3529


Epoch [5/30]: 100%|██████████| 19883/19883 [54:34<00:00,  6.07it/s, D_Loss=0.014, G_Loss=48.3]   


 No improvement. Best Loss so far: 33.3529


Epoch [6/30]: 100%|██████████| 19883/19883 [55:31<00:00,  5.97it/s, D_Loss=4.69e-5, G_Loss=51.5]  


 No improvement. Best Loss so far: 33.3529


Epoch [7/30]: 100%|██████████| 19883/19883 [55:23<00:00,  5.98it/s, D_Loss=0.0224, G_Loss=29.7]   


 No improvement. Best Loss so far: 33.3529


Epoch [8/30]: 100%|██████████| 19883/19883 [56:05<00:00,  5.91it/s, D_Loss=1.92e-5, G_Loss=57.8] 


 No improvement. Best Loss so far: 33.3529


Epoch [9/30]: 100%|██████████| 19883/19883 [57:17<00:00,  5.78it/s, D_Loss=0.378, G_Loss=38.8]    


 No improvement. Best Loss so far: 33.3529


Epoch [10/30]: 100%|██████████| 19883/19883 [56:59<00:00,  5.81it/s, D_Loss=0.526, G_Loss=22.3]     


 No improvement. Best Loss so far: 33.3529


Epoch [11/30]: 100%|██████████| 19883/19883 [56:37<00:00,  5.85it/s, D_Loss=0.0375, G_Loss=31.2]  


 No improvement. Best Loss so far: 33.3529


Epoch [12/30]: 100%|██████████| 19883/19883 [56:38<00:00,  5.85it/s, D_Loss=0.000811, G_Loss=35.4]


 No improvement. Best Loss so far: 33.3529


Epoch [13/30]:  13%|█▎        | 2589/19883 [07:24<49:29,  5.82it/s, D_Loss=0.000572, G_Loss=35.8]  


KeyboardInterrupt: 

In [None]:

# # Load the Trained Models


# # Reinitialize generator and discriminator models 
# gen = UNetGenerator().to(DEVICE)
# disc = PatchDiscriminator().to(DEVICE)

# # Load the saved weights of the best-performing generator and discriminator
# gen.load_state_dict(torch.load(f"{SAVE_DIR}/best_generator.pth", map_location=DEVICE))
# disc.load_state_dict(torch.load(f"{SAVE_DIR}/best_discriminator.pth", map_location=DEVICE))

# # Set both models to evaluation mode (important for inference – disables dropout, etc.)
# gen.eval()
# disc.eval()


  gen.load_state_dict(torch.load(f"{SAVE_DIR}/best_generator.pth", map_location=DEVICE))
  disc.load_state_dict(torch.load(f"{SAVE_DIR}/best_discriminator.pth", map_location=DEVICE))


PatchDiscriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): Sigmoid()
  )
)

## Testing 

In [None]:
import os
import glob
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image


In [None]:
# 1) Paths & Device
SAVE_DIR     = "pix2pix_outputs"  # Directory to save the generated photo outputs
MODEL_WEIGHTS = os.path.join(SAVE_DIR, "generator_epoch3.pth")  # Path to saved generator model weights
TEST_DIR     = "processed_dataset_1/test_1/sketch"  # Directory containing test sketch images
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available


os.makedirs(SAVE_DIR, exist_ok=True)

# 2) Generator architecture (U-Net-based)
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        # Downsampling path (encoder)
        self.down = nn.Sequential(
            self.contract(in_channels, features),            # 256 → 128
            self.contract(features, features * 2),           # 128 → 64
            self.contract(features * 2, features * 4),       # 64 → 32
            self.contract(features * 4, features * 8),       # 32 → 16
            self.contract(features * 8, features * 8),       # 16 → 8 (bottleneck)
        )
        # Upsampling path (decoder)
        self.up = nn.Sequential(
            self.expand(features * 8, features * 8),         # 8 → 16
            self.expand(features * 16, features * 4),        # concat with skip, 16 → 32
            self.expand(features * 8, features * 2),         # 32 → 64
            self.expand(features * 4, features),             # 64 → 128
        )
        # Final upsampling to get to output size 256x256
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),  # 128 → 256
            nn.Tanh()  # Output values between [-1, 1]
        )

    def contract(self, in_c, out_c):
        # Encoder block: Conv + BatchNorm + LeakyReLU
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2)
        )

    def expand(self, in_c, out_c):
        # Decoder block: ConvTranspose + BatchNorm + ReLU
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        # Encoder path with skip connections
        skips = []
        for layer in self.down:
            x = layer(x)
            skips.append(x)
        skips = skips[:-1][::-1]  # Reverse skips, excluding the bottleneck

        # Decoder path with skip connections
        for i, layer in enumerate(self.up):
            x = layer(x)
            if i < len(skips):
                x = torch.cat([x, skips[i]], dim=1)  # Concatenate skip features

        return self.final(x)  # Final upsampling to restore original size

# 3) Build model & load trained weights
gen = UNetGenerator()  # Initialize generator model
state_dict = torch.load(MODEL_WEIGHTS, map_location="cpu")  # Load state dict (trained weights)
gen.load_state_dict(state_dict)  # Load weights into the model
gen.to(DEVICE)  # Move model to GPU or CPU
gen.eval()  # Set model to evaluation mode (disable dropout, batchnorm update)

# 4) Preprocessing pipeline for input sketches
transform = transforms.Compose([
    transforms.Resize((256, 256)),              # Resize sketch to 256x256
    transforms.ToTensor(),                      # Convert to tensor in [0, 1]
    transforms.Normalize([0.5], [0.5]),         # Normalize to [-1, 1] to match training
    transforms.Lambda(lambda x: x.repeat(3,1,1))  # Convert 1-channel (grayscale) → 3-channel
])

# 5) Inference loop on test sketches
for img_path in glob.glob(os.path.join(TEST_DIR, "*")):
    # Skip files that aren't images
    if not img_path.lower().endswith((".png",".jpg",".jpeg")):
        continue

    fname = os.path.basename(img_path)  # Get image filename

    # Load image in grayscale mode and convert to PIL Image
    sketch = Image.open(img_path).convert("L")

    # Apply preprocessing pipeline and add batch dimension: shape = [1, 3, 256, 256]
    x = transform(sketch).unsqueeze(0).to(DEVICE)

    # Generate image using the trained generator (no gradient needed)
    with torch.no_grad():
        fake = gen(x)  # Output in range [-1, 1]

    # Convert generated image back to [0, 1] for saving
    fake = (fake * 0.5 + 0.5).clamp(0,1)

    # Save the output image
    out_path = os.path.join(SAVE_DIR, f"out_{fname}")
    save_image(fake, out_path)  # torchvision.utils.save_image
    print(f"Saved {out_path}")  # Log the save


  state_dict = torch.load(MODEL_WEIGHTS, map_location="cpu")


Saved pix2pix_outputs\out_10.png
Saved pix2pix_outputs\out_175.png
Saved pix2pix_outputs\out_2.png
Saved pix2pix_outputs\out_222.png
Saved pix2pix_outputs\out_3.png
Saved pix2pix_outputs\out_70.png
Saved pix2pix_outputs\out_80.png
Saved pix2pix_outputs\out_9.png


In [None]:
from PIL import Image
import numpy as np

# Function to load and convert image to numpy array
def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = np.array(image)
    return image

# Paths 
generated_image_path = "pix2pix_outputs/out_222.png"  
ground_truth_image_path = "processed_dataset_1/test_1/883.jpg"  

generated_image = load_image(generated_image_path)
ground_truth_image = load_image(ground_truth_image_path)


In [None]:
from skimage.metrics import structural_similarity as ssim , peak_signal_noise_ratio as psnr, mean_squared_error as mse

def calculate_ssim(img1, img2):
    # Ensure the images are the same shape
    if img1.shape != img2.shape:
        raise ValueError("Images must have the same dimensions")
    
    # Calculate SSIM
    return ssim(img1, img2, data_range=img1.max() - img1.min(), win_size=3)

# Calculate SSIM
ssim_value = calculate_ssim(generated_image, ground_truth_image)
print(f"SSIM between generated image and ground truth: {ssim_value:.4f}")


SSIM between generated image and ground truth: 0.1167


In [None]:

# Calculate PSNR
psnr_value = psnr(generated_image, ground_truth_image)
print(f"PSNR between generated image and ground truth: {psnr_value:.4f}")


PSNR between generated image and ground truth: 6.6368


In [None]:

# Calculate MSE
mse_value = mse(generated_image.flatten(), ground_truth_image.flatten())
print(f"MSE between generated image and ground truth: {mse_value:.4f}")


MSE between generated image and ground truth: 107.2621
