# **Controlled Generation in GANs**

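## **Dataset Example (CelebA)**

CelebA face images, each labelled with 40 binary attributes; the classifier below is trained to predict all of them.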

![Example](https://mmlab.ie.cuhk.edu.hk/projects/CelebA/overview.png)

# **Import Dependencies**

In [1]:
import torch
import torchvision
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one; the models and noise tensors below are placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device being used: {device}")

def show_tensor_images(image_tensor, num_images=16, size=(3,64,64), nrow=3):
    """Display the first ``num_images`` images of a batch as a single grid.

    Args:
        image_tensor: batch of images, presumably (N, C, H, W) with values in
            [-1, 1] (the generator ends in Tanh) — rescaled to [0, 1] for display.
        num_images: how many images from the front of the batch to show.
        size: unused; kept so existing callers keep working.
        nrow: number of images per row of the grid.
    """
    # Map [-1, 1] -> [0, 1] so matplotlib renders the intended colours.
    image_tensor = (image_tensor + 1) / 2
    # Drop the autograd graph and move to host memory before plotting.
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # make_grid yields channels-first; matplotlib expects channels-last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

Device being use: cuda


# **Generator**

In [6]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 64x64 image.

    Args:
        z_dim: dimensionality of the input noise vector.
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width; the blocks use 8x, 4x, 2x and 1x of it.
    """
    def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Spatial size per block (kernel 3 / stride 2, then kernel 4 / stride 2):
        # 1 -> 3 -> 7 -> 15 -> 31 -> 64. All five blocks are needed to reach the
        # 64x64 resolution used by the training transforms and the Classifier;
        # with only four blocks the output is 32x32, which is too small for the
        # Classifier's final kernel-4 convolution.
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Non-final stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh so outputs lie in [-1, 1].
        """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        """Generate images from noise of shape (n_samples, z_dim)."""
        # Treat each noise vector as a 1x1 feature map with z_dim channels.
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
    
def get_noise(n_samples, z_dim, device="cpu"):
    """Sample standard-normal noise vectors for the generator.

    Args:
        n_samples: number of noise vectors.
        z_dim: length of each vector.
        device: device to allocate on (defaults to "cpu" so two-argument
            calls still work).

    Returns:
        Tensor of shape (n_samples, z_dim).
    """
    # Allocate directly on the target device rather than creating on CPU and copying.
    return torch.randn(n_samples, z_dim, device=device)

# **Classifier**

In [3]:
class Classifier(nn.Module):
    """Convolutional multi-label classifier emitting one logit per class.

    Args:
        im_chan: number of input image channels.
        n_classes: number of output logits.
        hidden_dim: base channel width of the convolutional stack.
    """
    def __init__(self, im_chan=3, n_classes=2, hidden_dim=64):
        super().__init__()
        widths = [im_chan, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        # Three conv/BN/LeakyReLU stages (the third uses stride 3), then a bare
        # convolution producing the class logits.
        stages = [
            self.make_classifier_block(widths[0], widths[1]),
            self.make_classifier_block(widths[1], widths[2]),
            self.make_classifier_block(widths[2], widths[3], stride=3),
            self.make_classifier_block(widths[3], n_classes, final_layer=True),
        ]
        self.classifier = nn.Sequential(*stages)

    def make_classifier_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        """Build one Conv2d stage; non-final stages append BatchNorm2d and LeakyReLU(0.2)."""
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers += [
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, image):
        """Return the logits flattened per sample.

        For 64x64 inputs the final feature map is 1x1, so the result is
        (batch, n_classes).
        """
        logits = self.classifier(image)
        return logits.view(logits.size(0), -1)

## **Parameter Specifications**

In [4]:
# Length of each generator noise vector (latent dimension).
z_dim: int = 64
# Mini-batch size used when training the classifier.
batch_size: int = 128

# **Train Classifier**

In [None]:
def train_classifier(filename):
    """Train a 40-attribute CelebA classifier and save its weights.

    Args:
        filename: path the trained weights are written to, as
            ``{"classifier": state_dict}`` — the same layout the model-loading
            cell reads back with ``torch.load(...)["classifier"]``.

    Returns:
        List with the mean BCE-with-logits loss of each epoch.
    """
    # Target all the classes, so that's how many the classifier will learn
    label_indices = range(40)

    n_epochs = 10
    lr = 0.001
    beta_1 = 0.5
    beta_2 = 0.999
    image_size = 64

    # Resize/crop to 64x64 and scale pixels to [-1, 1], matching the range of
    # the generator's Tanh output so the classifier sees comparable inputs.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataloader = DataLoader(
        CelebA(".", split="train", download=True, transform=transform),
        batch_size=batch_size,
        shuffle=True
    )

    classifier = Classifier(n_classes=len(label_indices)).to(device)
    class_opt = torch.optim.Adam(classifier.parameters(), lr=lr, betas=(beta_1, beta_2))
    # Multi-label task: each attribute is an independent binary target.
    loss = nn.BCEWithLogitsLoss()

    classifier_losses = []
    classifier.train()

    for epoch in range(n_epochs):
        curr_epoch_loss = 0.0
        for real, labels in tqdm(dataloader):
            real = real.to(device)
            labels = labels[:, label_indices].to(device).float()

            class_opt.zero_grad()
            pred_labels = classifier(real)
            classifier_loss = loss(pred_labels, labels)
            classifier_loss.backward()
            class_opt.step()

            curr_epoch_loss += classifier_loss.item()

        classifier_losses.append(curr_epoch_loss / len(dataloader))
        print(f"Current epoch: {epoch + 1}/{n_epochs}, mean loss: {classifier_losses[-1]:.4f}")
        # Checkpoint after every epoch so progress survives an interrupted run.
        torch.save({"classifier": classifier.state_dict()}, filename)

    return classifier_losses

# **Load Pretrained Generator and Classifier**

In [7]:
# Pretrained checkpoints: dicts whose "gen" / "classifier" entries hold state dicts.
gen_path = '/kaggle/input/gan-data/pretrained_celeba.pth'
class_path = '/kaggle/input/gan-data/pretrained_classifier.pth'

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source. An "invalid load key, '<'" UnpicklingError usually means the file is
# an HTML page (e.g. a failed download), not a real checkpoint — verify the path.
gen = Generator(z_dim).to(device)
gen_dict = torch.load(gen_path, map_location=torch.device(device))["gen"]
gen.load_state_dict(gen_dict)
gen.eval()

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(class_path, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()

print("Loaded Generator and Classifier successfully!")

# In the cells shown this optimizer is never stepped; it is only used via
# opt.zero_grad() to clear the classifier's parameter gradients.
opt = torch.optim.Adam(classifier.parameters(), lr=0.01)

UnpicklingError: invalid load key, '<'.

# **Training Controlled Generation**

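## **Gradient Ascent on the Noise Vector**

Each step nudges the noise $z$ in the direction that increases the classifier's output (logit) for the target feature:

$$z_{\text{new}} = z_{\text{old}} + w \cdot \nabla_{z}\, \mathrm{score}\big(G(z_{\text{old}})\big)$$

where $G$ is the generator and $w$ is the step weight.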

In [None]:
def calculate_updated_noise(noise, weight):
    """Take one gradient-ascent step on ``noise``.

    Args:
        noise: leaf tensor whose ``.grad`` was populated by a prior backward pass.
        weight: step size multiplying the gradient.

    Returns:
        A new tensor equal to ``noise + noise.grad * weight``; ``noise`` itself
        is not modified.
    """
    step = noise.grad * weight
    return noise + step

In [None]:
# Check that it works for generated images: one small ascent step on feature 0
# should raise that feature's score for every sample in the batch.
opt.zero_grad()
noise = get_noise(32, z_dim, device).requires_grad_()
fake = gen(noise)
fake_classes = classifier(fake)[:, 0]
fake_classes.mean().backward()
noise.data = calculate_updated_noise(noise, 0.01)
fake = gen(noise)
fake_classes_new = classifier(fake)[:, 0]
assert torch.all(fake_classes_new > fake_classes)
print("Success!")

In [None]:
# Ascent settings: images per batch, number of ascent steps, and show every `skip`-th step.
n_images = 8
grad_steps = 10
skip = 2
fake_image_history = []

# Class Names — presumably CelebA's attribute-column order (classifier output i <-> name i).
feature_names = [
    # 0-9
    "5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald",
    "Bangs", "BigLips", "BigNose", "BlackHair", "BlondHair",
    # 10-19
    "Blurry", "BrownHair", "BushyEyebrows", "Chubby", "DoubleChin",
    "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones",
    # 20-29
    "Male", "MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard",
    "OvalFace", "PaleSkin", "PointyNose", "RecedingHairline", "RosyCheeks",
    # 30-39
    "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings",
    "WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Young",
]

# You can change this
target_indices = feature_names.index("BlondHair")

noise = get_noise(n_images, z_dim, device).requires_grad_()
for i in range(grad_steps):
    opt.zero_grad()
    # `noise` is not managed by `opt`, so clear its gradient explicitly;
    # otherwise backward() accumulates into noise.grad and each step would use
    # the sum of all previous gradients instead of the current one.
    noise.grad = None
    fake_img = gen(noise)
    # Record the image from the *current* noise; detach so the history does not
    # keep every step's autograd graph alive.
    fake_image_history += [fake_img.detach()]
    fake_classes_score = classifier(fake_img)[:, target_indices].mean()
    fake_classes_score.backward()
    noise.data = calculate_updated_noise(noise, 1/grad_steps)

plt.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
show_tensor_images(torch.cat(fake_image_history[::skip], dim=2), num_images=n_images, nrow=n_images)

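# **Entanglement and Regularization**

Pushing one feature often changes others too, because the features are entangled. To counter this, `get_score` subtracts a penalty: the L2 distance between the current and original classifier outputs for all non-target features, averaged over the batch and scaled by `penalty_weight`.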

In [None]:
def get_score(current_classifications, original_classifications, target_indices, other_indices, penalty_weight):
    """Score to maximize: target-feature output minus a drift penalty on other features.

    Args:
        current_classifications: classifier outputs for the current images,
            indexed as (batch, n_classes).
        original_classifications: classifier outputs for the starting images,
            same layout.
        target_indices: column index (or indices) of the feature to increase.
        other_indices: columns to keep unchanged (index list or boolean mask).
        penalty_weight: multiplier on the drift penalty.

    Returns:
        Scalar tensor: mean target output minus
        ``penalty_weight * mean_over_batch(||current - original||_2)`` over ``other_indices``.
    """
    other_distances = current_classifications[:, other_indices] - original_classifications[:, other_indices]
    # Per-sample L2 drift, averaged over the batch; negated so drift lowers the score.
    other_class_penalty = -torch.norm(other_distances, dim=1).mean() * penalty_weight
    target_score = current_classifications[:, target_indices].mean()

    return target_score + other_class_penalty

In [None]:
fake_image_history = []
target_indices = feature_names.index("WearingHat")
# Boolean mask selecting every feature except the target one.
other_indices = [cur_idx != target_indices for cur_idx, _ in enumerate(feature_names)]
noise = get_noise(n_images, z_dim, device).requires_grad_()
# Baseline outputs that the penalty in get_score measures drift against.
original_classifications = classifier(gen(noise)).detach()

for i in range(grad_steps):
    opt.zero_grad()
    # `noise` is not managed by `opt`; clear its gradient so each step uses only
    # the current gradient rather than the sum of all previous ones.
    noise.grad = None
    fake_image = gen(noise)
    # Detach so the history does not keep every step's autograd graph alive.
    fake_image_history += [fake_image.detach()]
    fake_score = get_score(
        classifier(fake_image),
        original_classifications,
        target_indices,
        other_indices,
        penalty_weight=0.1
    )
    fake_score.backward()
    noise.data = calculate_updated_noise(noise, 1/grad_steps)