In [17]:
import torch
import os
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
import gradio as gr
import numpy as np
from model import Generator
import pandas as pd
import torchvision.utils as vutils  # Import vutils
import matplotlib.pyplot as plt
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Define the function to generate and visualize images
def generate_and_visualize(image):
    """Run the generator on one face image, once per target label.

    The image is resized/cropped to a normalised 128-pixel tensor, passed
    through ``Generator`` five times (each time with a different one-hot
    target label), and every result is min-max rescaled to ``[0, 1]``.

    Args:
        image: A ``PIL.Image.Image`` (the Gradio input in this file uses
            ``type="pil"``).

    Returns:
        A list of five NumPy arrays, ordered by label index 0..4. Each one
        is the generator output with its axes reordered ``(1, 2, 0)``
        (assumes the generator returns a 1 x C x H x W tensor, so the result
        is H x W x C -- confirm against ``model.Generator``).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No input image provided.")

    checkpoint_path = "checkpoints/checkpoint.pth"

    # The generator has to live on the same device as the tensors fed to it;
    # leaving it on the CPU makes the forward pass fail on CUDA machines.
    G = Generator().to(device)

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        G.load_state_dict(checkpoint['model_G_state_dict'])
        print("Checkpoint loaded successfully.")
    else:
        print("No checkpoint found. Starting from scratch.")
    # NOTE(review): G is left in its default mode, exactly as before. Whether
    # G.eval() is appropriate depends on the normalisation layers inside
    # Generator, which are not visible here -- confirm before changing.

    # Normalize below is configured for three channels, so make sure
    # grayscale / RGBA / palette images are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Convert PIL image to tensor
    crop_size = 178
    image_size = 128
    steps = []
    if image.size != (178, 218):  # PIL reports size as (width, height)
        steps.append(T.Resize((218, 178)))  # torchvision expects (height, width)
    steps.extend([
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    x_real = T.Compose(steps)(image).unsqueeze(0).to(device)

    num_labels = 5  # Number of filter options (or classes)
    # Row k of the identity matrix is the one-hot label for the k-th class.
    target_labels = torch.eye(num_labels, device=device)

    generated_images = []
    with torch.no_grad():
        for k in range(num_labels):
            fake_image = G(x_real, target_labels[k:k + 1]).detach().cpu()
            generated_images.append(fake_image.squeeze(0))

    normalized_images = []
    for img in generated_images:
        # Min-max rescale to [0, 1]; a constant image has zero range, so map
        # it to zeros instead of dividing by zero and producing NaNs.
        lowest = img.min()
        value_range = img.max() - lowest
        if value_range > 0:
            scaled = (img - lowest) / value_range
        else:
            scaled = torch.zeros_like(img)
        # Move the leading (channel) axis last for display.
        normalized_images.append(np.transpose(scaled.numpy(), (1, 2, 0)))

    return normalized_images
     

# Description text shown under the app title. The text itself must not start
# or end with a quote character: anything inside the triple quotes is rendered
# verbatim in the UI.
desc = """Upload a face image to visualize different attributes such as hair color and gender.
Make sure the image has face zoomed and centered on it as the model is quite trivial.
The app will generate and display five images representing different features namely- Black Hair, Blonde Hair, Brown Hair, Gender and Age.
This app uses starGAN model for image translation which has been trained on celebFaces dataset which has around 40 facial attributes
"""

# Build the Gradio UI: one uploaded image in, five generated images out.
# The output components are created in the same order as the list returned
# by generate_and_visualize.
iface = gr.Interface(
    fn=generate_and_visualize,
    inputs=[gr.Image(type="pil", label="Input Image")],
    outputs=[
        gr.Image(type="pil", label=caption)
        for caption in ("Black Hair", "Blond Hair", "Brown Hair", "Male", "Young")
    ],
    title="Image to Image Translation using starGAN",
    description=desc,
)

# Start the web server only when executed as the main module.
if __name__ == "__main__":
    iface.launch()


Running on local URL:  http://127.0.0.1:7873

To create a public link, set `share=True` in `launch()`.


Checkpoint loaded successfully.
