In [1]:
import zipfile
import os
import glob
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import random

# Paths
zip_path = "/content/edge2shoes.zip"  # The path to your zip file on Google Colab
extract_path = "/content/extracted_edge2shoes/"  # Directory to extract the dataset

# Unzip the file
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Dataset extracted successfully.")

# Check if the directory exists and list some files
if os.path.exists(extract_path):
    print("Files in extracted folder:")
    print(os.listdir(extract_path)[:10])  # Display the first 10 files or folders
else:
    print("Extraction path does not exist.")

# Define transformations for the dataset images
transform = transforms.Compose([
    transforms.Resize((256, 256)),     # Resize to 256x256
    transforms.ToTensor(),             # Convert images to PyTorch tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Paired Dataset Class
class PairedEdge2ShoesDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (str): Directory path with images (each containing paired information).
            transform (callable, optional): Transform to be applied on both images.
        """
        self.image_paths = sorted(glob.glob(os.path.join(data_dir, "*.jpg")))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the paired image
        paired_image = Image.open(self.image_paths[idx]).convert("RGB")

        # Split the image into two halves (assuming the left half is edge and right half is the real image)
        width, height = paired_image.size
        edge_image = paired_image.crop((0, 0, width // 2, height))
        real_image = paired_image.crop((width // 2, 0, width, height))

        # Apply transformations
        if self.transform:
            edge_image = self.transform(edge_image)
            real_image = self.transform(real_image)

        return edge_image, real_image

# Unpaired Dataset Class (edge and real images loaded independentl
class UnpairedEdge2ShoesDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (str): Directory path with images.
            transform (callable, optional): Transform to be applied on images.
        """
        self.image_paths = sorted(glob.glob(os.path.join(data_dir, "*.jpg")))
        self.edge_paths = self.image_paths.copy()  # For edge images
        self.real_paths = self.image_paths.copy()  # For real images
        self.transform = transform

        # Shuffle the paths for unpaired loading
        random.shuffle(self.edge_paths)
        random.shuffle(self.real_paths)

    def __len__(self):
        return min(len(self.edge_paths), len(self.real_paths))

    def __getitem__(self, idx):
        # Load unpaired edge image
        edge_image_full = Image.open(self.edge_paths[idx]).convert("RGB")
        width, height = edge_image_full.size
        edge_image = edge_image_full.crop((0, 0, width // 2, height))

        # Load unpaired real image
        real_image_full = Image.open(self.real_paths[idx]).convert("RGB")
        real_image = real_image_full.crop((width // 2, 0, width, height))

        # Apply transformations
        if self.transform:
            edge_image = self.transform(edge_image)
            real_image = self.transform(real_image)

        return edge_image, real_image

# Create paired and unpaired datasets
paired_dataset = PairedEdge2ShoesDataset(data_dir=extract_path, transform=transform)
unpaired_dataset = UnpairedEdge2ShoesDataset(data_dir=extract_path, transform=transform)

# Define batch size
batch_size = 32

# Define DataLoader with increased batch size if GPU memory permits
paired_loader = DataLoader(paired_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
unpaired_loader = DataLoader(unpaired_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

print("Paired and unpaired data loaders are ready.")
print("Number of paired samples:", len(paired_dataset))
print("Number of unpaired samples:", len(unpaired_dataset))

Dataset extracted successfully.
Files in extracted folder:
['332_AB.jpg', '689_AB.jpg', '797_AB.jpg', '841_AB.jpg', '440_AB.jpg', '1144_AB.jpg', '520_AB.jpg', '772_AB.jpg', '785_AB.jpg', '849_AB.jpg']
Paired and unpaired data loaders are ready.
Number of paired samples: 400
Number of unpaired samples: 400


In [2]:
import torch
import torch.nn as nn

class HybridUNetGenerator(nn.Module):
    def __init__(self, in_channels, out_channels, use_dropout=False):
        super(HybridUNetGenerator, self).__init__()

        # Use InstanceNorm instead of BatchNorm
        norm_layer = nn.InstanceNorm2d

        # Shared Encoder Layers
        self.down1 = self.conv_block(in_channels, 64, norm_layer, use_dropout)
        self.down2 = self.conv_block(64, 128, norm_layer, use_dropout)
        self.down3 = self.conv_block(128, 256, norm_layer, use_dropout)
        self.down4 = self.conv_block(256, 512, norm_layer, use_dropout)
        self.down5 = self.conv_block(512, 512, norm_layer, use_dropout)
        self.down6 = self.conv_block(512, 512, norm_layer, use_dropout)

        # Decoder for cGAN (Edge-to-Image)
        self.up1_cgan = self.upconv_block(512, 512, norm_layer)
        self.up2_cgan = self.upconv_block(1024, 512, norm_layer)
        self.up3_cgan = self.upconv_block(1024, 256, norm_layer)
        self.up4_cgan = self.upconv_block(512, 128, norm_layer)
        self.up5_cgan = self.upconv_block(256, 64, norm_layer)
        self.final_cgan = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        # Decoder for CycleGAN (Cycle Consistency)
        self.up1_cyclegan = self.upconv_block(512, 512, norm_layer)
        self.up2_cyclegan = self.upconv_block(1024, 512, norm_layer)
        self.up3_cyclegan = self.upconv_block(1024, 256, norm_layer)
        self.up4_cyclegan = self.upconv_block(512, 128, norm_layer)
        self.up5_cyclegan = self.upconv_block(256, 64, norm_layer)
        self.final_cyclegan = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def conv_block(self, in_channels, out_channels, norm_layer, use_dropout=False):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels, norm_layer):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            norm_layer(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Shared Encoder forward pass with down-sampling
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)

        # cGAN Decoder forward pass (Edge-to-Image)
        u1_cgan = self.up1_cgan(d6)
        u2_cgan = self.up2_cgan(torch.cat([u1_cgan, d5], dim=1))
        u3_cgan = self.up3_cgan(torch.cat([u2_cgan, d4], dim=1))
        u4_cgan = self.up4_cgan(torch.cat([u3_cgan, d3], dim=1))
        u5_cgan = self.up5_cgan(torch.cat([u4_cgan, d2], dim=1))
        output_cgan = self.final_cgan(torch.cat([u5_cgan, d1], dim=1))

        # CycleGAN Decoder forward pass (Cycle Consistency)
        u1_cyclegan = self.up1_cyclegan(d6)
        u2_cyclegan = self.up2_cyclegan(torch.cat([u1_cyclegan, d5], dim=1))
        u3_cyclegan = self.up3_cyclegan(torch.cat([u2_cyclegan, d4], dim=1))
        u4_cyclegan = self.up4_cyclegan(torch.cat([u3_cyclegan, d3], dim=1))
        u5_cyclegan = self.up5_cyclegan(torch.cat([u4_cyclegan, d2], dim=1))
        output_cyclegan = self.final_cyclegan(torch.cat([u5_cyclegan, d1], dim=1))

        return output_cgan, output_cyclegan


# Discriminator stays the same as PatchGAN
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels, use_dropout=False):
        super(PatchGANDiscriminator, self).__init__()

        # Use InstanceNorm instead of BatchNorm
        norm_layer = nn.InstanceNorm2d

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            self.conv_block(64, 128, norm_layer, use_dropout),
            self.conv_block(128, 256, norm_layer, use_dropout),
            self.conv_block(256, 512, norm_layer, use_dropout),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def conv_block(self, in_channels, out_channels, norm_layer, use_dropout=False):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


# Loss Functions
adversarial_loss = nn.MSELoss()  # for discriminator's feedback
l1_loss = nn.L1Loss()  # for cycle consistency and paired data loss

In [3]:
import torch.optim as optim
import torch.nn.init as init

# Instantiate the hybrid generator with shared encoder and separate decoders for cGAN and CycleGAN
gen_hybrid = HybridUNetGenerator(in_channels=3, out_channels=3)

# Instantiate the discriminators (PatchGAN for both cGAN and CycleGAN)
disc_cGAN = PatchGANDiscriminator(in_channels=3)
disc_cycle = PatchGANDiscriminator(in_channels=3)

# Weight initialization function (same as before)
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)

# Apply weight initialization to all models
gen_hybrid.apply(weights_init)
disc_cGAN.apply(weights_init)
disc_cycle.apply(weights_init)

# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_hybrid, disc_cGAN, disc_cycle = gen_hybrid.to(device), disc_cGAN.to(device), disc_cycle.to(device)

# Define optimizers for both cGAN and CycleGAN
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

# Optimizers for the generators (cGAN and CycleGAN) and discriminators (PatchGAN)
optimizer_g = optim.Adam(gen_hybrid.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_d_cGAN = optim.Adam(disc_cGAN.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_d_cycle = optim.Adam(disc_cycle.parameters(), lr=lr, betas=(beta1, beta2))

# Define learning rate schedulers for both the cGAN and CycleGAN components
g_scheduler = optim.lr_scheduler.StepLR(optimizer_g, step_size=10, gamma=0.5)  # Same scheduler for both cGAN and CycleGAN
d_scheduler_cGAN = optim.lr_scheduler.StepLR(optimizer_d_cGAN, step_size=10, gamma=0.5)
d_scheduler_cycle = optim.lr_scheduler.StepLR(optimizer_d_cycle, step_size=10, gamma=0.5)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.amp as amp
import os
import matplotlib.pyplot as plt

# Set environment variable to handle memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize loss functions
criterion_gan = nn.BCEWithLogitsLoss()
criterion_l1 = nn.L1Loss()

# Initialize the GradScaler for mixed precision
scaler = amp.GradScaler(enabled=True)

# Loss function wrappers
def gan_loss(output, target, criterion):
    return criterion(output, target)

def cycle_consistency_loss(recovered, real, criterion, lambda_cycle=10):
    return criterion(recovered, real) * lambda_cycle

# Number of epochs
num_epochs = 100

# Initialize lists to store losses for plotting
g_losses = []
d_losses_cGAN = []
d_losses_cycle = []

# Create a directory for checkpoints
os.makedirs("checkpoints", exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    gen_hybrid.train()
    disc_cGAN.train()
    disc_cycle.train()

    # Initialize epoch losses
    g_loss_epoch = 0.0
    d_loss_cGAN_epoch = 0.0
    d_loss_cycle_epoch = 0.0
    num_batches = 0

    for i, (real_A, real_B) in enumerate(zip(paired_loader, unpaired_loader)):
        real_A, real_B = real_A[0].to(device), real_B[0].to(device)

        # ============================
        # Update the Discriminators
        # ============================

        # Train Discriminator for cGAN
        optimizer_d_cGAN.zero_grad()
        with amp.autocast(device_type='cuda'):
            fake_B_cgan, _ = gen_hybrid(real_A)
            disc_real_B = disc_cGAN(real_B)
            disc_fake_B = disc_cGAN(fake_B_cgan.detach())

            # Discriminator cGAN Loss
            d_loss_cGAN = (gan_loss(disc_real_B, torch.ones_like(disc_real_B).to(device), criterion_gan) +
                           gan_loss(disc_fake_B, torch.zeros_like(disc_fake_B).to(device), criterion_gan)) / 2

        # Backpropagation and optimization for discriminator cGAN
        scaler.scale(d_loss_cGAN).backward()
        scaler.step(optimizer_d_cGAN)
        scaler.update()

        # Train Discriminator for CycleGAN
        optimizer_d_cycle.zero_grad()
        with amp.autocast(device_type='cuda'):
            _, fake_A = gen_hybrid(real_B)
            disc_real_A = disc_cycle(real_A)
            disc_fake_A = disc_cycle(fake_A.detach())

            # Discriminator CycleGAN Loss
            d_loss_cycle = (gan_loss(disc_real_A, torch.ones_like(disc_real_A).to(device), criterion_gan) +
                            gan_loss(disc_fake_A, torch.zeros_like(disc_fake_A).to(device), criterion_gan)) / 2

        # Backpropagation and optimization for discriminator CycleGAN
        scaler.scale(d_loss_cycle).backward()
        scaler.step(optimizer_d_cycle)
        scaler.update()

        # ============================
        # Update the Generator
        # ============================
        optimizer_g.zero_grad()
        with amp.autocast(device_type='cuda'):
            # Generator cGAN Loss
            g_loss_cGAN = gan_loss(disc_cGAN(fake_B_cgan), torch.ones_like(disc_real_B).to(device), criterion_gan)

            # Cycle Consistency Loss
            _, fake_A = gen_hybrid(real_B)
            recovered_A, _ = gen_hybrid(fake_A)
            cycle_loss = cycle_consistency_loss(recovered_A, real_A, criterion_l1)

            # Total Generator Loss
            g_loss = g_loss_cGAN + cycle_loss

        # Backpropagation and optimization for generator
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_g)
        scaler.update()

        # Accumulate epoch losses for averaging
        g_loss_epoch += g_loss.item()
        d_loss_cGAN_epoch += d_loss_cGAN.item()
        d_loss_cycle_epoch += d_loss_cycle.item()
        num_batches += 1

        # Print losses
        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}], "
                  f"g_loss: {g_loss.item():.4f}, "
                  f"d_loss_cGAN: {d_loss_cGAN.item():.4f}, "
                  f"d_loss_cycle: {d_loss_cycle.item():.4f}")

    # Calculate average losses for the epoch
    g_losses.append(g_loss_epoch / num_batches)
    d_losses_cGAN.append(d_loss_cGAN_epoch / num_batches)
    d_losses_cycle.append(d_loss_cycle_epoch / num_batches)

    # Update schedulers
    g_scheduler.step()
    d_scheduler_cGAN.step()
    d_scheduler_cycle.step()

    # Save checkpoints periodically
    if (epoch + 1) % 10 == 0:
        torch.save(gen_hybrid.state_dict(), f"checkpoints/gen_hybrid_epoch_{epoch+1}.pth")
        torch.save(disc_cGAN.state_dict(), f"checkpoints/disc_cGAN_epoch_{epoch+1}.pth")
        torch.save(disc_cycle.state_dict(), f"checkpoints/disc_cycle_epoch_{epoch+1}.pth")

# Plot the Training Losses
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='Generator Loss')
plt.plot(d_losses_cGAN, label='Discriminator cGAN Loss')
plt.plot(d_losses_cycle, label='Discriminator Cycle Loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Losses Over Epochs')
plt.legend()
plt.grid(True)
plt.show()



Epoch [1/100], Step [0], g_loss: 10.4695, d_loss_cGAN: 0.7248, d_loss_cycle: 0.7309
