In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from src.model_unet_no_attention import Unet as Unet_No_Att
from src.model_unet import *
from src.eval_base import *

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def load_model(image_size, channels, path):
    """
    Build a U-Net, load trained weights from a checkpoint, and move it to `device`.

    The architecture is picked from the checkpoint path. Paths containing
    "NoAttention" use the attention-free U-Net; every other path uses the
    default U-Net.

    Args:
        image_size: Passed to the U-Net constructor as its `dim` argument.
        channels: Number of image channels.
        path: Checkpoint file containing a 'model_state_dict' entry.

    Returns:
        The model with the loaded weights, placed on `device`.
    """
    model_cls = Unet_No_Att if "NoAttention" in path else Unet
    model = model_cls(
        dim=image_size,
        channels=channels,
        dim_mults=(1, 2, 4),
    )

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    # (and vice versa) instead of failing on device deserialization.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    return model.to(device)


def prepare_data_loaders(train_data, train_labels, val_data, val_labels, test_data, test_labels, batch_size):
    """
    Wrap each (data, labels) split in a TensorDataset and a DataLoader.

    Only the training loader shuffles. The validation and test loaders iterate
    in dataset order. All three use the same batch size.

    Returns:
        A tuple (train_loader, val_loader, test_loader).
    """
    split_specs = (
        (train_data, train_labels, True),
        (val_data, val_labels, False),
        (test_data, test_labels, False),
    )
    return tuple(
        DataLoader(TensorDataset(inputs, targets), batch_size=batch_size, shuffle=shuffle)
        for inputs, targets, shuffle in split_specs
    )

def load_dataset(dataset_name, batch_size):
    """
    Load a prepared dataset, rescale each split to [-1, 1], and wrap the splits in DataLoaders.

    Args:
        dataset_name: "flowers" or "celeba".
        batch_size: Batch size used by every returned DataLoader.

    Returns:
        A tuple (train_loader, val_loader, test_loader).

    Raises:
        ValueError: if `dataset_name` is not one of the known datasets.
    """
    if dataset_name not in ("flowers", "celeba"):
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected 'flowers' or 'celeba'"
        )

    base_dir = "../data/prepared_datasets"

    def _load_split(split):
        # Each split is stored as two files: <split>_<name>.pt and <split>_<name>_labels.pt.
        data = torch.load(f"{base_dir}/{split}_{dataset_name}.pt")
        labels = torch.load(f"{base_dir}/{split}_{dataset_name}_labels.pt")
        return data, labels

    def _to_signed_unit_range(x):
        # Min-max scale to [0, 1], then map to [-1, 1].
        # NOTE(review): each split is scaled by its own min/max, not by shared
        # statistics — confirm this is intended.
        return (x - x.min()) / (x.max() - x.min()) * 2 - 1

    # NOTE(review): the on-disk "train" file feeds the test loader and the
    # on-disk "test" file feeds the training loader. This looks deliberate
    # (e.g. the "test" file being the larger split) — confirm before changing.
    test_data, test_labels = _load_split("train")
    val_data, val_labels = _load_split("val")
    train_data, train_labels = _load_split("test")

    train_data = _to_signed_unit_range(train_data)
    val_data = _to_signed_unit_range(val_data)
    test_data = _to_signed_unit_range(test_data)

    return prepare_data_loaders(
        train_data, train_labels,
        val_data, val_labels,
        test_data, test_labels,
        batch_size,
    )

def _save_denoising_gif(samples, gif_path, sample_index=1):
    """Render one batch element's trajectory over all sampling steps into an animated GIF.

    Args:
        samples: Sequence of per-step batches as returned by `sample`. Each batch
            is assumed to be indexed [batch, channel, height, width] (inferred
            from the transpose below — TODO confirm against `sample`).
        gif_path: Output file. Its parent directory is created if missing.
        sample_index: Which batch element to animate.
    """
    fig, ax = plt.subplots()
    try:
        ax.axis('off')          # hide axes
        fig.patch.set_alpha(0)  # transparent figure background

        frames = []
        for step in samples:
            img = np.transpose(step[sample_index].cpu().numpy(), (1, 2, 0))
            img = (img - img.min()) / (img.max() - img.min())
            frames.append([ax.imshow(img, animated=True)])

        anim = animation.ArtistAnimation(fig, frames, interval=5, blit=True, repeat_delay=1000)

        gif_dir = os.path.dirname(gif_path)
        if gif_dir:
            os.makedirs(gif_dir, exist_ok=True)
        anim.save(gif_path)
    finally:
        # Always release the figure so repeated UI clicks do not accumulate open figures.
        plt.close(fig)


def generate_images_and_gif(model_name, dataset_name):
    """
    Sample a batch from a trained model and return a 3x3 image grid plus a denoising GIF.

    Args:
        model_name: File name of a checkpoint inside ../models.
        dataset_name: Dataset selected in the UI.
            NOTE(review): sampling here does not take the dataset into account
            (`sample` is called without any dataset input); the argument is kept
            for the Gradio interface.

    Returns:
        (grid_image, gif_path): a numpy image grid of the final samples, and the
        path of the written GIF.
    """
    img_size = 64
    channels = 3
    batch_size = 32
    timesteps = 200

    model = load_model(img_size, channels, f'../models/{model_name}')

    betas = DiffusionSchedule.linear_beta_schedule(timesteps).clone()
    diffusion_params = DiffusionSchedule.compute_diffusion_parameters(betas)
    samples = sample(
        model,
        image_size=img_size,
        diffusion_params=diffusion_params,
        batch_size=batch_size,
        channels=channels,
    )

    grid_image = create_image_grid(samples, img_size)

    # NOTE(review): a fixed output path means concurrent requests overwrite each other's GIF.
    gif_path = '../gifs/diffusion.gif'
    _save_denoising_gif(samples, gif_path)

    return grid_image, gif_path

def create_image_grid(samples, img_size, padding=5, bg_color=0.9):
    """
    Create a grid of generated images with padding
    
    Args:
        samples: Generated image samples
        img_size: Size of each image
        padding: Pixel width of padding between images
        bg_color: Background color for padding (0.5 is neutral gray)
    """
    # Convert samples to numpy for visualization
    samples_np = samples[-1].cpu().numpy()
    
    # Create a grid of 9 images (3x3)
    grid_size = 3
    
    # Calculate total grid size with padding
    total_size = img_size * grid_size + padding * (grid_size + 1)
    grid_image = np.full((total_size, total_size, 3), bg_color)
    
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            img = samples_np[idx]
            img = np.transpose(img, (1, 2, 0))
            img = (img - img.min()) / (img.max() - img.min())
            
            # Calculate positioning with padding
            start_y = padding + i * (img_size + padding)
            start_x = padding + j * (img_size + padding)
            
            grid_image[
                start_y:start_y+img_size, 
                start_x:start_x+img_size
            ] = img
    
    return grid_image

def create_gradio_interface():
    """
    Build the Gradio UI for the demo.

    The UI has a model dropdown (every file in ../models), a dataset dropdown,
    a generate button, and two image outputs (grid and GIF), all wired to
    `generate_images_and_gif`.

    Returns:
        The gr.Blocks app, not yet launched.
    """
    models_dir = '../models'
    # Sorted so the dropdown order is stable; os.listdir order is arbitrary.
    models = sorted(
        f for f in os.listdir(models_dir)
        if os.path.isfile(os.path.join(models_dir, f))
    )
    datasets = ['flowers', 'celeba']

    with gr.Blocks() as demo:
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=models, label="Select Model")
            dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dataset")

        generate_btn = gr.Button("Generate Images")

        grid_output = gr.Image(label="Image Grid")
        gif_output = gr.Image(label="Diffusion GIF")

        generate_btn.click(
            fn=generate_images_and_gif,
            inputs=[model_dropdown, dataset_dropdown],
            outputs=[grid_output, gif_output],
        )

    return demo

# Launch the interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    # Gradio only serves files from explicitly allowed directories. The GIF is
    # written to the cwd-relative ../gifs, so allow its resolved location as
    # well as the container path.
    # NOTE(review): share=True publishes a public *.gradio.live URL.
    interface.launch(
        share=True,
        allowed_paths=["/app/gifs", os.path.abspath("../gifs")],
    )

  from .autonotebook import tqdm as notebook_tqdm


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://6af2ab2ce20be5d3c6.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)



ampling loop time step: 100%|██████████| 200/200 [00:20<00:00,  9.95it/s]