In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # For Jupyter notebooks

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Drop cached CUDA allocations left over from earlier runs in this kernel.
torch.cuda.empty_cache()


# Hyperparameters
BATCH_SIZE = 4
SHAPE = (256, 256, 3)
EPOCHS = 200
LR = 0.0002

# Constant test image for tracking progress.
# Forward slashes are accepted on Windows as well as Linux/macOS; a backslash
# path would only resolve on Windows.
TEST_IMAGE_PATH = 'data/contentimage/2013-11-10 07_43_23.jpg'
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# `test_image` is always bound: a (1, 3, 256, 256) tensor on `device` when the
# file loads, otherwise None, so later cells can test for it instead of
# hitting a NameError (or a half-processed PIL image).
test_image = None
try:
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(TEST_IMAGE_PATH) as raw_test_image:
        test_image = test_transform(raw_test_image.convert('RGB')).unsqueeze(0).to(device)
    print("Test image loaded successfully.")
except Exception as e:
    # Best effort: the preview image is optional, so report and carry on.
    print(f"Error loading test image: {e}")

In [None]:
# Data loading
class CycleGANDataset(Dataset):
    """Dataset yielding (photo, style) image pairs for CycleGAN training.

    Both directories are scanned (non-recursively) for files whose name ends
    in ``.jpg`` (case-insensitive). The file lists are sorted so that the
    index -> file mapping is reproducible across runs and platforms
    (``os.listdir`` returns entries in arbitrary order).

    The dataset length is the size of the smaller directory; item ``i`` pairs
    the i-th photo with the i-th style image, so files beyond that length in
    the larger directory are never returned.

    Args:
        photo_dir (str): Directory containing the photo ``.jpg`` files.
        style_dir (str): Directory containing the style ``.jpg`` files.
        transform (callable, optional): Applied independently to each loaded
            RGB PIL image before it is returned.

    Raises:
        FileNotFoundError: If either path is not an existing directory.
        ValueError: If either directory contains no ``.jpg`` files.
    """

    def __init__(self, photo_dir, style_dir, transform=None):
        # isdir (rather than exists) also rejects a path that points at a file.
        if not os.path.isdir(photo_dir):
            raise FileNotFoundError(f"Photo directory '{photo_dir}' does not exist.")
        if not os.path.isdir(style_dir):
            raise FileNotFoundError(f"Style directory '{style_dir}' does not exist.")

        self.photo_files = self._list_jpgs(photo_dir)
        self.style_files = self._list_jpgs(style_dir)

        print(f"Found {len(self.photo_files)} photo files in '{photo_dir}'")
        print(f"Found {len(self.style_files)} style files in '{style_dir}'")

        if not self.photo_files:
            raise ValueError(f"No .jpg files found in photo directory '{photo_dir}'")
        if not self.style_files:
            raise ValueError(f"No .jpg files found in style directory '{style_dir}'")

        self.transform = transform
        # Both lists are non-empty here, so min_len is at least 1.
        self.min_len = min(len(self.photo_files), len(self.style_files))

    @staticmethod
    def _list_jpgs(directory):
        """Return the sorted full paths of the ``.jpg`` files in ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.jpg')
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the file."""
        with Image.open(path) as img:
            # convert() decodes the pixels and returns a new image, so the
            # result stays valid after the file handle is closed.
            return img.convert('RGB')

    def __len__(self):
        return self.min_len

    def __getitem__(self, idx):
        """Return the ``(photo, style)`` pair at ``idx``.

        Raises:
            RuntimeError: If either image cannot be opened or decoded; the
                original exception is chained as ``__cause__``.
        """
        photo_path = self.photo_files[idx % len(self.photo_files)]
        style_path = self.style_files[idx % len(self.style_files)]

        try:
            photo = self._load_rgb(photo_path)
            style = self._load_rgb(style_path)
        except Exception as e:
            raise RuntimeError(
                f"Error loading image: {e} (Photo: {photo_path}, Style: {style_path})"
            ) from e

        if self.transform:
            photo = self.transform(photo)
            style = self.transform(style)

        return photo, style

# Data transformation: resize to 256x256, convert to a tensor and map each
# channel from [0, 1] to [-1, 1] (mean 0.5, std 0.5).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Dataset and DataLoader
photo_dir = 'data/contentimage'
style_dir = 'data/vangogh_painting'

try:
    dataset = CycleGANDataset(photo_dir, style_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just produces a warning.
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Dataset length: {len(dataset)}")
    print(f"Dataloader length: {len(dataloader)}")
except Exception as e:
    print(f"Error initializing dataset or dataloader: {e}")
    # Re-raise instead of calling exit(): inside a notebook exit() acts on the
    # kernel rather than stopping just this cell, and it hides the traceback.
    raise

In [None]:
# Generator Architecture
class Generator(nn.Module):
    """Encoder / bottleneck / decoder image-to-image generator.

    Layout (all built in ``__init__``):
      * ``enc1``: 7x7 conv, 3 -> 64 channels, no normalisation.
      * ``enc2`` / ``enc3``: 3x3 stride-2 convs, 64 -> 128 -> 256 channels.
      * ``res_blocks``: six conv-norm-ReLU-conv-norm blocks at 256 channels.
      * ``dec1`` / ``dec2``: stride-2 transposed convs, 256 -> 128 -> 64.
      * ``dec3``: 7x7 conv to 3 channels followed by ``tanh`` (output in [-1, 1]).

    For inputs whose height and width are multiples of 4 the output has the
    same spatial size as the input.

    Args:
        use_skip_connections (bool): Controls how ``res_blocks`` is applied.
            ``False`` (default) runs the six blocks as a plain sequential
            stack, i.e. WITHOUT the identity shortcut that the name
            "residual block" implies. This is the original behaviour and is
            what weights saved from the original definition expect.
            ``True`` applies each block as ``x + block(x)``. The parameter
            names and shapes are identical in both modes, so state dicts are
            interchangeable, but weights trained in one mode will not produce
            meaningful output in the other.
    """

    def __init__(self, use_skip_connections=False):
        super().__init__()
        self.use_skip_connections = use_skip_connections
        self.enc1 = self.conv_block(3, 64, 7, stride=1, padding=3, instance_norm=False)
        self.enc2 = self.conv_block(64, 128, 3, stride=2, padding=1)
        self.enc3 = self.conv_block(128, 256, 3, stride=2, padding=1)
        self.res_blocks = nn.Sequential(*[self.residual_block(256) for _ in range(6)])
        self.dec1 = self.deconv_block(256, 128, 3, stride=2, padding=1, output_padding=1)
        self.dec2 = self.deconv_block(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.dec3 = nn.Sequential(nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3), nn.Tanh())

    def conv_block(self, in_channels, out_channels, kernel_size, stride=1, padding=0, instance_norm=True):
        """Reflect-padded conv, optional InstanceNorm2d, then in-place ReLU."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, padding_mode='reflect')]
        if instance_norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def residual_block(self, channels):
        """Return the conv-norm-ReLU-conv-norm body of one bottleneck block.

        Only the body is built here; whether the identity shortcut is added
        is decided in ``forward`` via ``use_skip_connections``.
        """
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels)
        )

    def deconv_block(self, in_channels, out_channels, kernel_size, stride=2, padding=0, output_padding=0):
        """Transposed conv (upsampling), InstanceNorm2d, then in-place ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding=output_padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Translate a batch of images ``x`` (N, 3, H, W) to the target domain."""
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)

        if self.use_skip_connections:
            r = e3
            for block in self.res_blocks:
                # Each block starts with a conv, so `r` itself is never
                # modified in place and can safely be added back.
                r = r + block(r)
        else:
            # Legacy path: blocks chained without identity shortcuts.
            r = self.res_blocks(e3)

        d1 = self.dec1(r)
        d2 = self.dec2(d1)
        out = self.dec3(d2)
        return out

In [None]:
# Discriminator Architecture
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a grid of real/fake scores.

    Four 4x4 conv stages (3 -> 64 -> 128 -> 256 -> 512 channels; the first
    three use stride 2, the fourth stride 1) are followed by a final 4x4
    stride-2 conv down to a single channel. The first stage has no
    normalisation; the others use InstanceNorm2d. Every stage uses
    LeakyReLU(0.2). No activation is applied to the final output.
    """

    def __init__(self):
        super().__init__()
        # Stages are created in network order so parameter initialisation
        # consumes the RNG in the same sequence as the layers appear.
        stages = [
            self.conv_block(3, 64, 4, stride=2, padding=1, instance_norm=False),
            self.conv_block(64, 128, 4, stride=2, padding=1),
            self.conv_block(128, 256, 4, stride=2, padding=1),
            self.conv_block(256, 512, 4, stride=1, padding=1),
        ]
        score_head = nn.Conv2d(512, 1, 4, stride=2, padding=1)
        self.model = nn.Sequential(*stages, score_head)

    def conv_block(self, in_channels, out_channels, kernel_size, stride=1, padding=0, instance_norm=True):
        """Conv2d, optional InstanceNorm2d, then in-place LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        activation = nn.LeakyReLU(0.2, inplace=True)
        if instance_norm:
            return nn.Sequential(conv, nn.InstanceNorm2d(out_channels), activation)
        return nn.Sequential(conv, activation)

    def forward(self, x):
        """Return the raw (un-activated) score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

In [None]:
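# # CycleGAN Model (superseded draft, kept commented out for reference only; the active CycleGAN class is defined in the next cell)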
# class CycleGAN(nn.Module):
#     def __init__(self):
#         super(CycleGAN, self).__init__()
#         self.generatorS = Generator().to(device)
#         self.generatorP = Generator().to(device)
#         self.discriminatorS = Discriminator().to(device)
#         self.discriminatorP = Discriminator().to(device)

#         self.g_optimizer = optim.Adam(list(self.generatorS.parameters()) + list(self.generatorP.parameters()), lr=LR, betas=(0.5, 0.999))
#         self.d_optimizer = optim.Adam(list(self.discriminatorS.parameters()) + list(self.discriminatorP.parameters()), lr=LR, betas=(0.5, 0.999))

#         self.gan_loss = nn.MSELoss()
#         self.cycle_loss = nn.L1Loss()
#         self.identity_loss = nn.L1Loss()

#     def train_step(self, real_photo, real_style):
#         fake_style = self.generatorS(real_photo)
#         fake_photo = self.generatorP(real_style)
#         cycled_photo = self.generatorP(fake_style)
#         cycled_style = self.generatorS(fake_photo)
#         same_photo = self.generatorP(real_photo)
#         same_style = self.generatorS(real_style)
        
#         disc_real_photo = self.discriminatorP(real_photo)
#         disc_fake_photo = self.discriminatorP(fake_photo)
#         disc_real_style = self.discriminatorS(real_style)
#         disc_fake_style = self.discriminatorS(fake_style)

#         gen_photo_loss = self.gan_loss(disc_fake_photo, torch.ones_like(disc_fake_photo))
#         gen_style_loss = self.gan_loss(disc_fake_style, torch.ones_like(disc_fake_style))
#         # cycle_photo_loss = self.cycle_loss(cycled_photo, real_photo) * 10.0
#         # cycle_style_loss = self.cycle_loss(cycled_style, real_style) * 10.0
#         # id_photo_loss = self.identity_loss(same_photo, real_photo) * 5.0
#         # id_style_loss = self.identity_loss(same_style, real_style) * 5.0


#         cycle_photo_loss = self.cycle_loss(cycled_photo, real_photo) * 3.0
#         cycle_style_loss = self.cycle_loss(cycled_style, real_style) * 3.0
#         id_photo_loss = self.identity_loss(same_photo, real_photo) * 0.5
#         id_style_loss = self.identity_loss(same_style, real_style) * 0.5
        
#         total_gen_loss = (gen_photo_loss + gen_style_loss + cycle_photo_loss + cycle_style_loss + id_photo_loss + id_style_loss)

#         real_photo_loss = self.gan_loss(disc_real_photo, torch.ones_like(disc_real_photo))
#         fake_photo_loss = self.gan_loss(disc_fake_photo.detach(), torch.zeros_like(disc_fake_photo))
#         real_style_loss = self.gan_loss(disc_real_style, torch.ones_like(disc_real_style))
#         fake_style_loss = self.gan_loss(disc_fake_style.detach(), torch.zeros_like(disc_fake_style))
        
#         total_disc_loss = 0.5 * (real_photo_loss + fake_photo_loss + real_style_loss + fake_style_loss)

#         self.g_optimizer.zero_grad()
#         total_gen_loss.backward()
#         self.g_optimizer.step()

#         self.d_optimizer.zero_grad()
#         total_disc_loss.backward()
#         self.d_optimizer.step()

#         return total_gen_loss.item(), total_disc_loss.item()

#     def save(self, path):
#         torch.save({
#             'generatorS': self.generatorS.state_dict(),
#             'generatorP': self.generatorP.state_dict(),
#             'discriminatorS': self.discriminatorS.state_dict(),
#             'discriminatorP': self.discriminatorP.state_dict(),
#             'g_optimizer': self.g_optimizer.state_dict(),
#             'd_optimizer': self.d_optimizer.state_dict()
#         }, path)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class CycleGAN(nn.Module):
    """Two generators, two discriminators and their optimizers for CycleGAN.

    ``generatorS`` maps photos to the style domain and ``generatorP`` maps
    style images to the photo domain; ``discriminatorS`` / ``discriminatorP``
    score images of the respective domain. Adversarial terms use a
    least-squares objective (MSE against all-ones / all-zeros targets);
    cycle-consistency and identity terms use L1.

    Args:
        device: Device the four sub-networks are moved to.
        lr (float, optional): Adam learning rate for both optimizers.
            Defaults to the module-level ``LR`` constant when ``None``.
        lambda_cycle (float): Weight of each cycle-consistency L1 term.
        lambda_identity (float): Weight of each identity L1 term.
    """

    def __init__(self, device, lr=None, lambda_cycle=3.0, lambda_identity=0.5):
        super().__init__()
        self.device = device  # Store device information
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity
        if lr is None:
            lr = LR  # notebook-wide default learning rate

        self.generatorS = Generator().to(self.device)
        self.generatorP = Generator().to(self.device)
        self.discriminatorS = Discriminator().to(self.device)
        self.discriminatorP = Discriminator().to(self.device)

        # One optimizer drives both generators, another both discriminators.
        self.g_optimizer = optim.Adam(
            list(self.generatorS.parameters()) + list(self.generatorP.parameters()),
            lr=lr, betas=(0.5, 0.999)
        )
        self.d_optimizer = optim.Adam(
            list(self.discriminatorS.parameters()) + list(self.discriminatorP.parameters()),
            lr=lr, betas=(0.5, 0.999)
        )

        self.gan_loss = nn.MSELoss()
        self.cycle_loss = nn.L1Loss()
        self.identity_loss = nn.L1Loss()

    def train_step(self, real_photo, real_style):
        """Run one generator update followed by one discriminator update.

        Args:
            real_photo: Batch of photo-domain images.
            real_style: Batch of style-domain images.
            Both are assumed to already live on ``self.device`` -- this method
            does not move them (TODO confirm against the caller).

        Returns:
            tuple: ``(total_gen_loss, total_disc_loss)`` as scalar tensors
            that are still attached to the autograd graph. Call ``.item()``
            (or ``.detach()``) on them before accumulating or logging so the
            graph is not kept alive.
        """
        # ---- Generator update -------------------------------------------
        self.g_optimizer.zero_grad()

        fake_style = self.generatorS(real_photo)
        fake_photo = self.generatorP(real_style)
        cycled_photo = self.generatorP(fake_style)
        cycled_style = self.generatorS(fake_photo)
        same_photo = self.generatorP(real_photo)
        same_style = self.generatorS(real_style)

        disc_fake_photo = self.discriminatorP(fake_photo)
        disc_fake_style = self.discriminatorS(fake_style)

        # Generators want the discriminators to score fakes as real (1).
        gen_photo_loss = self.gan_loss(disc_fake_photo, torch.ones_like(disc_fake_photo))
        gen_style_loss = self.gan_loss(disc_fake_style, torch.ones_like(disc_fake_style))

        cycle_photo_loss = self.cycle_loss(cycled_photo, real_photo) * self.lambda_cycle
        cycle_style_loss = self.cycle_loss(cycled_style, real_style) * self.lambda_cycle

        id_photo_loss = self.identity_loss(same_photo, real_photo) * self.lambda_identity
        id_style_loss = self.identity_loss(same_style, real_style) * self.lambda_identity

        total_gen_loss = gen_photo_loss + gen_style_loss + cycle_photo_loss + cycle_style_loss + id_photo_loss + id_style_loss

        total_gen_loss.backward()
        self.g_optimizer.step()

        # ---- Discriminator update ---------------------------------------
        # zero_grad also discards the discriminator gradients that the
        # generator backward pass above accumulated.
        self.d_optimizer.zero_grad()

        disc_real_photo = self.discriminatorP(real_photo)
        disc_real_style = self.discriminatorS(real_style)
        # detach(): the discriminator loss must not back-propagate into the
        # generators.
        disc_fake_photo_detached = self.discriminatorP(fake_photo.detach())
        disc_fake_style_detached = self.discriminatorS(fake_style.detach())

        real_photo_loss = self.gan_loss(disc_real_photo, torch.ones_like(disc_real_photo))
        fake_photo_loss = self.gan_loss(disc_fake_photo_detached, torch.zeros_like(disc_fake_photo_detached))
        real_style_loss = self.gan_loss(disc_real_style, torch.ones_like(disc_real_style))
        fake_style_loss = self.gan_loss(disc_fake_style_detached, torch.zeros_like(disc_fake_style_detached))

        total_disc_loss = 0.5 * (real_photo_loss + fake_photo_loss + real_style_loss + fake_style_loss)

        total_disc_loss.backward()
        self.d_optimizer.step()

        return total_gen_loss, total_disc_loss

    def save(self, path, epoch):
        """Write all network and optimizer state dicts plus ``epoch`` to ``path``.

        The checkpoint is a dict with keys ``generatorS``, ``generatorP``,
        ``discriminatorS``, ``discriminatorP``, ``g_optimizer``,
        ``d_optimizer`` and ``epoch``.
        """
        torch.save({
            'generatorS': self.generatorS.state_dict(),
            'generatorP': self.generatorP.state_dict(),
            'discriminatorS': self.discriminatorS.state_dict(),
            'discriminatorP': self.discriminatorP.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'epoch': epoch  # stored so training can be resumed from here
        }, path)


In [None]:
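# # Training Loop (initial from-scratch version, disabled). NOTE: the model.save(save_path) call below predates the `epoch` argument of CycleGAN.save and must be updated before this cell is re-enabled.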
# model = CycleGAN(device)

# for epoch in range(EPOCHS):
#     print(f"Epoch {epoch+1}/{EPOCHS}")
#     progress_bar = tqdm(
#         dataloader,
#         desc=f"Epoch {epoch+1}/{EPOCHS}",
#         total=len(dataloader),
#         unit="batch",
#         leave=True
#     )
#     total_gen_loss, total_disc_loss = 0, 0
#     batches = 0

#     try:
#         for i, (real_photo, real_style) in enumerate(progress_bar):
#             real_photo, real_style = real_photo.to(device), real_style.to(device)
#             gen_loss, disc_loss = model.train_step(real_photo, real_style)
            
#             total_gen_loss += gen_loss
#             total_disc_loss += disc_loss
#             batches += 1
            
#             progress_bar.set_postfix({'Gen Loss': f'{gen_loss:.4f}', 'Disc Loss': f'{disc_loss:.4f}'})
#     except Exception as e:
#         print(f"Error during training: {e}")
#         break

#     if batches == 0:
#         print("No batches processed. Check dataloader!")
#         break

#     avg_gen_loss = total_gen_loss / batches
#     avg_disc_loss = total_disc_loss / batches
#     print(f"Avg Gen Loss: {avg_gen_loss:.4f}, Avg Disc Loss: {avg_disc_loss:.4f}")

#     with torch.no_grad():
#         generated_monet = model.generatorS(test_image)
    
#     fig, axes = plt.subplots(1, 2, figsize=(10, 5))
#     test_img_np = test_image.cpu().squeeze(0).numpy().transpose(1, 2, 0) * 0.5 + 0.5
#     gen_img_np = generated_monet.cpu().squeeze(0).numpy().transpose(1, 2, 0) * 0.5 + 0.5
    
#     axes[0].imshow(test_img_np)
#     axes[0].set_title("Original Photo")
#     axes[0].axis('off')
    
#     axes[1].imshow(gen_img_np)
#     axes[1].set_title(f"vangogh Style (Epoch {epoch+1})")
#     axes[1].axis('off')
    
#     plt.show()

#     save_path = f'CycleGAN_vangogh_epoch_new.pth'
#     model.save(save_path)
#     print(f"Model saved to {save_path}")

#     progress_bar.close()

# CONT: the next cell (currently commented out) resumes training from a saved checkpoint.

In [None]:
# import torch
# import os
# import matplotlib.pyplot as plt
# from tqdm import tqdm

# # Ensure the model is defined before loading
# model = CycleGAN(device).to(device)
# model.train()  # ✅ Ensure model is in training mode

# EPOCHS = 250
# checkpoint_path = "CycleGAN_vangogh_epoch.pth"

# # 🔹 Load Checkpoint
# if os.path.exists(checkpoint_path):
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     model.generatorS.load_state_dict(checkpoint['generatorS'])
#     model.generatorP.load_state_dict(checkpoint['generatorP'])
#     model.discriminatorS.load_state_dict(checkpoint['discriminatorS'])
#     model.discriminatorP.load_state_dict(checkpoint['discriminatorP'])
#     model.g_optimizer.load_state_dict(checkpoint['g_optimizer'])
#     model.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
#     start_epoch = checkpoint.get('epoch', 58)  # Default to 58 if missing
#     print(f"✅ Checkpoint loaded! Resuming training from epoch {start_epoch}.")
# else:
#     start_epoch = 1
#     print("❌ No checkpoint found. Starting training from scratch.")

# # ✅ Ensure all model parameters require gradients
# for param in model.parameters():
#     param.requires_grad = True

# for param_group in model.g_optimizer.param_groups:
#     for param in param_group['params']:
#         param.requires_grad = True

# for param_group in model.d_optimizer.param_groups:
#     for param in param_group['params']:
#         param.requires_grad = True

# # ✅ Reset optimizer gradients before training
# model.g_optimizer.zero_grad(set_to_none=True)
# model.d_optimizer.zero_grad(set_to_none=True)

# # 🔥 Training Loop
# for epoch in range(start_epoch, EPOCHS + 1):
#     print(f"Epoch {epoch}/{EPOCHS}")
#     progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{EPOCHS}", total=len(dataloader), unit="batch", leave=True)

#     total_gen_loss, total_disc_loss = 0, 0
#     batches = 0

#     try:
#         for i, (real_photo, real_style) in enumerate(progress_bar):
#             real_photo, real_style = real_photo.to(device), real_style.to(device)
            
#             # 🔹 Forward Pass
#             gen_loss, disc_loss = model.train_step(real_photo, real_style)

#             # ✅ Ensure losses are part of the computation graph
#             gen_loss = gen_loss.mean()
#             disc_loss = disc_loss.mean()

#             assert gen_loss.requires_grad, "❌ Generator loss is not part of the computation graph!"
#             assert disc_loss.requires_grad, "❌ Discriminator loss is not part of the computation graph!"

#             total_gen_loss += gen_loss.item()
#             total_disc_loss += disc_loss.item()
#             batches += 1
            
#             progress_bar.set_postfix({'Gen Loss': f'{gen_loss:.4f}', 'Disc Loss': f'{disc_loss:.4f}'})
    
#     except Exception as e:
#         print(f"❌ Error during training: {e}")
#         break

#     if batches == 0:
#         print("❌ No batches processed. Check dataloader!")
#         break

#     avg_gen_loss = total_gen_loss / batches
#     avg_disc_loss = total_disc_loss / batches
#     print(f"✅ Avg Gen Loss: {avg_gen_loss:.4f}, Avg Disc Loss: {avg_disc_loss:.4f}")

#     # 🔹 Generate Sample Output
#     with torch.no_grad():
#         generated_monet = model.generatorS(test_image.to(device))

#     fig, axes = plt.subplots(1, 2, figsize=(10, 5))
#     test_img_np = test_image.cpu().squeeze(0).numpy().transpose(1, 2, 0) * 0.5 + 0.5
#     gen_img_np = generated_monet.cpu().squeeze(0).numpy().transpose(1, 2, 0) * 0.5 + 0.5
    
#     axes[0].imshow(test_img_np)
#     axes[0].set_title("Original Photo")
#     axes[0].axis('off')
    
#     axes[1].imshow(gen_img_np)
#     axes[1].set_title(f"Van Gogh Style (Epoch {epoch})")
#     axes[1].axis('off')
    
#     plt.show()

#     # 🔹 Save Model Checkpoint
#     # save_path = f'CycleGAN_vangogh_epoch.pth'
#     # torch.save({
#     #     'generatorS': model.generatorS.state_dict(),
#     #     'generatorP': model.generatorP.state_dict(),
#     #     'discriminatorS': model.discriminatorS.state_dict(),
#     #     'discriminatorP': model.discriminatorP.state_dict(),
#     #     'g_optimizer': model.g_optimizer.state_dict(),
#     #     'd_optimizer': model.d_optimizer.state_dict(),
#     #     'epoch': epoch
#     # }, save_path)
#     # print(f"💾 Model saved to {save_path}")
#     save_path = 'CycleGAN_vangogh_epoch.pth'
#     model.save(save_path, epoch)  # ✅ Now using the model's save function
#     print(f"💾 Model saved to {save_path}")

#     progress_bar.close()

# print("✅ Training complete!")


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

def load_model_and_generate(image_path, model_path, generator_class, device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Loads a CycleGAN generator from a saved checkpoint, runs it on one image
    and returns the (resized) input together with the stylized output.

    Autograd is disabled only for the duration of the forward pass, so calling
    this function does not affect gradient tracking for code that runs later
    in the same process (e.g. resuming training in another cell).

    Args:
        image_path (str): Path to the input image.
        model_path (str): Path to the saved .pth checkpoint; it must be a dict
            containing a "generatorS" state dict.
        generator_class (torch.nn.Module): The generator class; it is
            instantiated with no arguments.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        (PIL.Image, PIL.Image): Tuple ``(original_image, generated_image)``.
        ``original_image`` is the input after resizing to 256x256, not the
        file at its native resolution.

    Raises:
        KeyError: If the checkpoint has no "generatorS" entry.
    """

    # Load the generator model architecture
    generator = generator_class().to(device)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    if "generatorS" not in checkpoint:
        raise KeyError(
            f"Checkpoint '{model_path}' has no 'generatorS' entry; "
            f"available keys: {list(checkpoint.keys())}"
        )
    generator.load_state_dict(checkpoint["generatorS"])
    generator.eval()

    # Same preprocessing as the training transform: resize, to tensor,
    # then map every channel from [0, 1] to [-1, 1].
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # The context manager closes the file once the pixels have been decoded.
    with Image.open(image_path) as image:
        input_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)  # Add batch dimension

    # Scoped no_grad instead of torch.set_grad_enabled(False): the latter
    # switches autograd off globally and would silently break later training.
    with torch.no_grad():
        generated_tensor = generator(input_tensor)

    def _to_image(batch):
        # Drop the batch dimension, undo the [-1, 1] normalisation and clamp
        # away floating-point overshoot before the conversion to 8-bit.
        unnormalized = (batch.squeeze(0).cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
        return transforms.ToPILImage()(unnormalized)

    original_image = _to_image(input_tensor)
    generated_image = _to_image(generated_tensor)

    return original_image, generated_image

# Example usage:
# from model import Generator  # Ensure you have the Generator class defined
# original, generated = load_model_and_generate("input.jpg", "CycleGAN.pth", Generator)
# original.show()  # Display original image
# generated.show()  # Display generated image


In [None]:
import matplotlib.pyplot as plt

# Call the function with the input image and the saved model.
# The image path uses forward slashes so it resolves on Windows, Linux and
# macOS alike (a backslash path only works on Windows).
original, generated = load_model_and_generate(
    image_path="data/contentimage/2013-11-18 06_23_04.jpg",  # Path to your input image
    model_path="CycleGAN_vangogh_epoch.pth",  # Path to your saved model file
    generator_class=Generator  # Pass the Generator class
)

# Display the input and the generated image side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

for ax, image, title in zip(axes, (original, generated), ("Original Image", "Generated Image")):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

plt.show()
