
# Week 6: Image Generation
## Objective:
In this notebook, we explore image generation and introduce its two main flavors, Image2Image and Text2Image. We implement a simple GAN (Generative Adversarial Network) trained on MNIST and look at how models like Pix2Pix are used for image translation tasks.



## 6.1 Theory: What is Image Generation?
Image generation refers to the process of creating new images using machine learning models. Two popular types of image generation are:
- **Image2Image Translation**: Converts an input image from one domain to another.
- **Text2Image Generation**: Creates an image based on a text description.

### Generative Adversarial Networks (GANs)
A GAN is a generative model that trains two networks against each other: a **generator**, which maps random noise to images, and a **discriminator**, which tries to distinguish real images from generated ones. The generator is updated to fool the discriminator, so as training proceeds its outputs become harder to tell apart from real data.
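
The two-player game described above is commonly written as a minimax objective. The symbols are introduced here for illustration and do not appear elsewhere in this notebook: $G$ is the generator, $D$ the discriminator, $p_{\text{data}}$ the distribution of real images, and $p_z$ the noise distribution.

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice the generator is often trained to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$. This is what computing binary cross-entropy between the discriminator's output on fake images and "real" labels amounts to, as the training loop below does.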

### Image2Image Translation with Pix2Pix
Pix2Pix is a conditional GAN for image translation: its generator receives an input image and produces a corresponding image in the target domain. For example, it can convert sketches into photorealistic images or grayscale images into color images.


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input).view(-1, 1, 28, 28)

# Define the discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, input):
        return self.main(input.view(-1, 784))

# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
epochs = 20

# Prepare the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training the GAN
for epoch in range(epochs):
    for i, (images, _) in enumerate(train_loader):
        # Train the discriminator
        # Size labels and noise from the actual batch: the final MNIST batch
        # holds fewer than batch_size images, which would otherwise break BCELoss.
        n = images.size(0)
        real_images = images
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)
        
        outputs = discriminator(real_images)
        loss_real = criterion(outputs, real_labels)
        
        noise = torch.randn(n, 100)
        fake_images = generator(noise)
        # detach() keeps the discriminator update from backpropagating into the generator
        outputs = discriminator(fake_images.detach())
        loss_fake = criterion(outputs, fake_labels)
        
        loss_d = loss_real + loss_fake
        optimizer_d.zero_grad()
        loss_d.backward()
        optimizer_d.step()

        # Train the generator: label fakes as real so it is rewarded for fooling the discriminator
        noise = torch.randn(n, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        loss_g = criterion(outputs, real_labels)
        
        optimizer_g.zero_grad()
        loss_g.backward()
        optimizer_g.step()
    
    # Report the losses from the last batch of the epoch
    print(f"Epoch [{epoch+1}/{epochs}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

print("Finished Training the GAN")
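
Once training finishes, a natural next step is to draw samples from the generator and look at them. The sketch below is one way to do that, not part of the original notebook: `sample_images` and `demo_generator` are hypothetical names, and `demo_generator` is an untrained stand-in with the same input and output shapes as the `Generator` class above so that the example runs on its own.

```python
import torch
import torch.nn as nn


def sample_images(generator: nn.Module, n: int = 16, latent_dim: int = 100) -> torch.Tensor:
    """Draw n samples from a generator and rescale them from [-1, 1] to [0, 1]."""
    generator.eval()
    with torch.no_grad():  # no gradients are needed for sampling
        noise = torch.randn(n, latent_dim)
        images = generator(noise)
    return (images + 1) / 2  # undo the Tanh output range for display


# Untrained stand-in (assumption): 100-dim noise in, (1, 28, 28) image out
demo_generator = nn.Sequential(
    nn.Linear(100, 784),
    nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),
)

samples = sample_images(demo_generator, n=16)
print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

With the trained model, passing `generator` instead of `demo_generator` should work unchanged. The resulting batch can be written out as a grid with `torchvision.utils.save_image(samples, 'samples.png', nrow=4)`, where the file name is a placeholder.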



## 6.2 Image2Image Translation with Pix2Pix
Pix2Pix is a conditional GAN used for image translation tasks. For example, it can convert sketches into realistic images or grayscale images into color. Unlike an unconditional GAN, which generates images from noise alone, Pix2Pix is trained on paired data: each input image has a corresponding target image, and the generator learns to map one to the other.
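
To make the role of paired data concrete, the sketch below shows how a Pix2Pix-style generator loss can combine an adversarial term with an L1 distance to the paired target, and how the discriminator input is conditioned on the source image. It is a minimal illustration under stated assumptions, not a full implementation: the function name, the dummy tensors, and the 30x30 logit map are placeholders, and the L1 weight of 100 follows the value reported in the Pix2Pix paper.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # adversarial loss on raw discriminator logits
l1 = nn.L1Loss()              # pixel-wise distance to the paired target


def pix2pix_generator_loss(disc_logits_fake, fake_target, real_target, lambda_l1=100.0):
    """Adversarial term (fool the discriminator) plus a weighted L1 term."""
    adversarial = bce(disc_logits_fake, torch.ones_like(disc_logits_fake))
    reconstruction = l1(fake_target, real_target)
    return adversarial + lambda_l1 * reconstruction


# Dummy tensors standing in for one training pair (assumption: 3-channel 256x256 images)
source = torch.randn(1, 3, 256, 256)       # e.g. a sketch
real_target = torch.randn(1, 3, 256, 256)  # the paired photo
fake_target = torch.randn(1, 3, 256, 256)  # what a generator would output for `source`

# The discriminator is conditioned on the input: it sees source and candidate stacked on channels
disc_input = torch.cat([source, fake_target], dim=1)
print(disc_input.shape)  # torch.Size([1, 6, 256, 256])

# Placeholder for the patch-wise logits a discriminator would return for disc_input
disc_logits_fake = torch.randn(1, 1, 30, 30)
loss = pix2pix_generator_loss(disc_logits_fake, fake_target, real_target)
print(loss.item())
```

The L1 term is what the paired target makes possible: without a known correct output for each input, only the adversarial term would be available.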

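The cell below outlines how a pre-trained Pix2Pix model would be applied to an Image2Image task. The model-loading and inference lines are left commented out, so as written the cell only preprocesses the input image.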


In [None]:

# Note: In practice, you'd download a pre-trained Pix2Pix model or train your own model on a dataset.

import torch
import torchvision.transforms as transforms
from PIL import Image

# Load a pre-trained Pix2Pix model (for example purposes)
# model = torch.load('pix2pix_model.pth')

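# Load an image, force three RGB channels, and apply the necessary transforms
image = Image.open('input_image.jpg').convert('RGB')
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Scale each RGB channel from [0, 1] to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_image = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 256, 256)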

# Generate the translated image (eval mode and no gradients, since this is inference only)
# model.eval()
# with torch.no_grad():
#     output_image = model(input_image)

# For demonstration, we'll just print the input shape
print(f"Input image shape: {input_image.shape}")



## 6.3 Exercises:
- Modify the GAN code to generate higher resolution images by increasing the size of the generator and discriminator networks.
- Experiment with different image translation tasks using Pix2Pix, such as converting night images to day images.
- Train a GAN on a custom dataset and evaluate the quality of the generated images.
