# CycleGAN for Image-to-Image Translation

<!--
Project: CycleGAN for Image-to-Image Translation
Author: RSK World
Website: https://rskworld.in
Email: help@rskworld.in
Phone: +91 93305 39277
Description: CycleGAN for unpaired image-to-image translation using cycle-consistent adversarial networks
-->

This notebook demonstrates the CycleGAN model for unpaired image-to-image translation.

## Features
- Unpaired image translation
- Cycle consistency loss
- Style transfer capabilities
- No need for paired training data


In [None]:
# Import necessary libraries
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import os
import sys

# Put the parent of the current working directory on the import path.
sys.path.append(os.path.dirname(os.getcwd()))

# Report the runtime environment; the device name is only queried when CUDA
# is actually usable.
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    print("CUDA available:", True)
    print("CUDA device:", torch.cuda.get_device_name(0))
else:
    print("CUDA available:", False)


## 1. Load the CycleGAN Model


In [None]:
# Load CycleGAN model
# Author: RSK World
# Website: https://rskworld.in

from models import CycleGANModel
from models.networks import define_G
import argparse

# Create a simple config object
class Config:
    """
    Fixed set of options used by this notebook.

    Every instance starts with the same values; ``gpu_ids`` is the only
    field that depends on the machine (``[0]`` when CUDA is available,
    otherwise an empty list).
    """

    def __init__(self):
        cuda_ready = torch.cuda.is_available()

        # Options that the notebook passes to define_G when building the
        # two generators.
        self.input_nc = self.output_nc = 3
        self.ngf = 64
        self.netG = 'resnet_9blocks'
        self.norm = 'instance'
        self.no_dropout = True
        self.init_type = 'normal'
        self.init_gain = 0.02
        # GPU 0 when CUDA is available, otherwise an empty list.
        self.gpu_ids = [0] if cuda_ready else []

        # The remaining options are not read by the notebook cells shown
        # here; presumably they are consumed by the project's model and
        # training code (TODO confirm).
        self.ndf = 64
        self.netD = 'basic'
        self.n_layers_D = 3
        self.isTrain = False
        self.lambda_A = self.lambda_B = 10.0
        self.lambda_identity = 0.5
        self.pool_size = 50
        self.checkpoints_dir = './checkpoints'
        self.name = 'cyclegan_experiment'
        self.epoch = 'latest'

# Instantiate the settings and pick the device the generators will run on.
config = Config()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Both generators receive the same define_G arguments apart from the
# channel pair, which is swapped for the reverse direction.
_generator_args = (
    config.ngf,
    config.netG,
    config.norm,
    not config.no_dropout,
    config.init_type,
    config.init_gain,
    config.gpu_ids,
)
netG_A2B = define_G(config.input_nc, config.output_nc, *_generator_args)
netG_B2A = define_G(config.output_nc, config.input_nc, *_generator_args)

# Move each generator to the chosen device and switch it to eval mode.
# NOTE(review): no checkpoint is loaded in this cell, so despite the message
# below the generators hold whatever weights define_G initialised them with.
for _generator in (netG_A2B, netG_B2A):
    _generator.to(device)
    _generator.eval()

print("Generators loaded successfully!")


## 2. Image Preprocessing Functions


In [None]:
# Image preprocessing functions
# Author: RSK World
# Website: https://rskworld.in

def preprocess_image(image_path, size=256):
    """
    Load an image from disk and turn it into a CycleGAN input tensor.

    The image is converted to RGB, resized so that its shorter side equals
    ``size``, center-cropped to ``size`` x ``size``, converted to a tensor and
    normalized per channel with mean 0.5 and std 0.5.

    Args:
        image_path: Path to the image
        size: Side length (in pixels) of the square output

    Returns:
        Preprocessed tensor with a leading batch dimension of 1
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Open through a context manager so the file handle is released
    # deterministically, even if decoding fails or the file is multi-frame
    # (Pillow keeps such files open after loading). convert() returns an
    # independent in-memory image, so it stays valid after the file closes.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # unsqueeze(0) adds the batch dimension in front of the image tensor.
    return transform(image).unsqueeze(0)

def tensor_to_image(tensor):
    """
    Turn a generator output tensor back into a PIL Image.

    The first dimension is squeezed away when it has size 1, the data is
    moved to the CPU, values are mapped with ``(x + 1) / 2`` (the inverse of
    the mean-0.5 / std-0.5 normalization) and clamped to ``[0, 1]``.

    Args:
        tensor: Input tensor

    Returns:
        PIL Image
    """
    on_cpu = tensor.squeeze(0).cpu()
    unit_range = ((on_cpu + 1) / 2.0).clamp(0, 1)
    return transforms.ToPILImage()(unit_range)

print("Preprocessing functions defined!")


## 3. Image Translation Function


In [None]:
# Image translation function
# Author: RSK World
# Website: https://rskworld.in

def translate_image(image_path, direction='AtoB'):
    """
    Translate an image using CycleGAN.

    Args:
        image_path: Path to the input image
        direction: Translation direction ('AtoB' or 'BtoA')

    Returns:
        Translated image as PIL Image

    Raises:
        ValueError: If ``direction`` is neither 'AtoB' nor 'BtoA'
    """
    # Resolve the generator first so a misspelled direction fails fast
    # instead of silently running the B -> A generator.
    if direction == 'AtoB':
        generator = netG_A2B
    elif direction == 'BtoA':
        generator = netG_B2A
    else:
        raise ValueError(
            f"direction must be 'AtoB' or 'BtoA', got {direction!r}"
        )

    # Preprocess the image and move it to the device the generators live on.
    image_tensor = preprocess_image(image_path).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        translated = generator(image_tensor)

    # Convert back to image
    return tensor_to_image(translated)

print("Translation function defined!")


## 4. Visualization Function


In [None]:
# Visualization function
# Author: RSK World
# Website: https://rskworld.in

def visualize_translation(image_path, direction='AtoB'):
    """
    Show an input image next to its translation.

    The left panel shows the image as read from disk (converted to RGB); the
    right panel shows the result of ``translate_image`` for the same path.

    Args:
        image_path: Path to the input image
        direction: Translation direction, forwarded to ``translate_image``
    """
    # Open through a context manager so the file handle is released
    # deterministically; convert() yields an independent in-memory copy.
    with Image.open(image_path) as source:
        original = source.convert('RGB')

    translated = translate_image(image_path, direction)

    # One row, two panels: original on the left, translation on the right.
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (original, 'Original Image'),
        (translated, f'Translated Image ({direction})'),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

print("Visualization function defined!")


## 5. Example Usage

To use this notebook:

1. Prepare your dataset in the following structure:
   ```
   datasets/
   └── your_dataset/
       ├── trainA/  # Domain A images
       ├── trainB/  # Domain B images
       ├── testA/   # Test images from domain A
       └── testB/   # Test images from domain B
   ```

2. Train the model using:
   ```bash
   python train.py --dataroot ./datasets/your_dataset --name experiment_name
   ```

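3. Load the trained generator weights into `netG_A2B` and `netG_B2A` (the cell in section 1 only constructs the generators; it does not load a checkpoint), then test: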
   ```python
   # Example: Translate an image
   image_path = './datasets/your_dataset/testA/image1.jpg'
   visualize_translation(image_path, direction='AtoB')
   ```


## 6. Cycle Consistency Demonstration

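This section demonstrates the cycle consistency property of CycleGAN: translating an image from domain A to domain B and back again (A → B → A) should approximately reconstruct the original image.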


In [None]:
# Cycle consistency demonstration
# Author: RSK World
# Website: https://rskworld.in

def demonstrate_cycle_consistency(image_path):
    """
    Demonstrate cycle consistency: A -> B -> A should be close to original A.

    Runs the image through ``netG_A2B`` and feeds the result to ``netG_B2A``,
    then plots the image as read from disk, the translation and the
    reconstruction side by side.

    Args:
        image_path: Path to the input image
    """
    # Open through a context manager so the file handle is released
    # deterministically; convert() yields an independent in-memory copy.
    with Image.open(image_path) as source:
        original = source.convert('RGB')

    image_tensor = preprocess_image(image_path).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        fake_B = netG_A2B(image_tensor)  # A -> B
        rec_A = netG_B2A(fake_B)         # B -> A (reconstruction)

    # One row, three panels: input, translation, reconstruction.
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (original, 'Original A'),
        (tensor_to_image(fake_B), 'Translated to B'),
        (tensor_to_image(rec_A), 'Reconstructed A (B -> A)'),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Confirm the definition, then remind the reader what is still needed to get
# meaningful results; one message per output line.
print(
    "Cycle consistency function defined!",
    "\nNote: This is a demonstration notebook. To use with actual data:",
    "1. Train the model on your dataset",
    "2. Load the trained weights",
    "3. Use the translation functions above",
    sep="\n",
)
