In [None]:
# Standard imports
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import tqdm

# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Custom modules
from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image


In [None]:
# Choose the compute device: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Build the sketch simplification network and restore its pretrained weights.

sketch_simplifier = SketchSimplificationNetwork().to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# that was saved from a GPU session can still be loaded when the notebook falls
# back to the CPU.
sketch_simplifier.load_state_dict(
    torch.load("models-checkpoints/model_gan.pth", map_location=device)
)

# Inference only: switch to eval mode and freeze all parameters.
sketch_simplifier.eval()
sketch_simplifier.requires_grad_(False)

In [None]:
# Hugging Face Hub repository id of the Stable Diffusion checkpoint that the
# next cell passes to StableDiffusionPipeline.from_pretrained.
stable_diffusion_1_5 = 'benjamin-paine/stable-diffusion-v1-5'

In [None]:
# Load the pretrained Stable Diffusion pipeline and pull out the sub-models the
# sketch-guided pipeline needs.
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    # Half precision is only appropriate on CUDA; diffusers does not support
    # running float16 pipelines on the CPU, so fall back to float32 there.
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    safety_checker=None  # Skip the safety checker if it's not required
)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

# Inference only: put every sub-model in eval mode ...
vae.eval()
unet.eval()
text_encoder.eval()

# ... and freeze their parameters. The VAE is frozen too, consistently with the
# U-Net and the text encoder, so no gradients are tracked for its weights.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

In [None]:
# Restore the latent edge predictor (LEP) from its training checkpoint.
# The checkpoint is deserialized onto the CPU first; load_state_dict then copies
# the weights into the model, which already lives on `device`.
checkpoint = torch.load(
    "models-checkpoints/unet_latent_edge_predictor_checkpoint.pt",
    map_location="cpu",
)

# The constructor arguments are defined by models.UNETLatentEdgePredictor
# (not shown in this notebook) and have to match the saved weights.
LEP_UNET = UNETLatentEdgePredictor(9320, 4, 9).to(device)
LEP_UNET.load_state_dict(checkpoint["model_state_dict"])

# Inference only: eval mode, frozen parameters.
LEP_UNET.eval().requires_grad_(False)

In [None]:
import numpy

# DDIM noise scheduler used for sampling; every option not listed here keeps
# the DDIMScheduler default.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    clip_sample=False,
)

In [None]:
# Diagnostic cell: print the version of NumPy that is active in this kernel.
import numpy
print(numpy.__version__)

In [None]:
# Assemble the sketch-guided text-to-image pipeline from the components
# prepared in the cells above.
import tqdm

pipeline = SketchGuidedText2Image(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    lep_unet=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device,
)

In [None]:
import itertools
import matplotlib.pyplot as plt
from PIL import Image

def tune_hyperparameters(
    pipeline,
    prompt,
    edge_maps,
    seed=None,
    *,
    num_inference_timesteps_list=(12, 15, 20),
    classifier_guidance_strength_list=(4, 6, 8),
    sketch_guidance_strength_list=(0.5, 1.0, 0.7, 1.6),
    guidance_steps_perc_list=(0.5,),
):
    """
    Performs a grid search over hyperparameters for the inference process.

    ``pipeline.Inference`` is called once for every combination in the
    Cartesian product of the four grids, with one image per prompt, no
    negative prompt and edge-map simplification disabled.

    Parameters:
      pipeline: an instance of your SketchGuidedText2Image pipeline (any object
        exposing a compatible ``Inference`` method works).
      prompt: list or string prompt.
      edge_maps: list of PIL Images for edge maps.
      seed: optional seed for reproducibility; the same value is forwarded to
        every run.
      num_inference_timesteps_list: values tried for ``num_inference_timesteps``.
      classifier_guidance_strength_list: values tried for
        ``classifier_guidance_strength``.
      sketch_guidance_strength_list: values tried for
        ``sketch_guidance_strength``.
      guidance_steps_perc_list: values tried for ``guidance_steps_perc``.

    The four grid arguments are keyword-only; their defaults reproduce the grid
    that used to be hard-coded in this function, in the same order.

    Returns:
      results: a list of dictionaries, each containing the hyperparameters and
      the output of the corresponding run, ordered with the number of steps
      varying slowest and the guidance percentage varying fastest. The list is
      empty when any of the grids is empty.
    """
    results = []

    # Loop over all combinations
    for (num_steps, cls_str, skt_str, gd_perc) in itertools.product(
            num_inference_timesteps_list,
            classifier_guidance_strength_list,
            sketch_guidance_strength_list,
            guidance_steps_perc_list):

        print(f"Testing: num_steps={num_steps}, classifier={cls_str}, sketch={skt_str}, guidance%={gd_perc}")

        # Call the Inference function from your pipeline
        output = pipeline.Inference(
            prompt=prompt,
            num_images_per_prompt=1,
            edge_maps=edge_maps,
            negative_prompt=None,
            num_inference_timesteps=num_steps,
            classifier_guidance_strength=cls_str,
            sketch_guidance_strength=skt_str,
            seed=seed,
            simplify_edge_maps=False,
            guidance_steps_perc=gd_perc,
        )

        results.append({
            "num_steps": num_steps,
            "classifier_guidance_strength": cls_str,
            "sketch_guidance_strength": skt_str,
            "guidance_steps_perc": gd_perc,
            "result": output  # output contains keys like "generated_image" and "edge_map"
        })

    return results

In [None]:
from tqdm import tqdm

# Make sure your pipeline instance is already created and configured as in your demo
# Inputs shared by the baseline run and the hyperparameter sweep.
prompt = ["Lego car"]
edge_maps = [Image.open("Lego_256x256/sketch/Car-1.jpg/sketchs32strokes_Car-1.jpg")]
seed = 1000

# Baseline run with 50 inference steps. It reuses `prompt` so that it is
# conditioned on exactly the same text as the sweep below.
inverse_diffusion = pipeline.Inference(
    prompt=prompt,
    num_images_per_prompt=1,
    edge_maps=edge_maps,
    negative_prompt=None,
    num_inference_timesteps=50,
    classifier_guidance_strength=8,
    sketch_guidance_strength=1.6,
    seed=seed,
    simplify_edge_maps=False,
    guidance_steps_perc=0.5,
)

# Run the hyperparameter tuner
tuning_results = tune_hyperparameters(pipeline, prompt, edge_maps, seed=seed)

# Visualize the outputs
n_results = len(tuning_results)
if n_results == 0:
    # plt.subplots cannot create a grid with zero columns.
    print("No tuning results to display.")
else:
    # Wrap the panels into a grid instead of one very wide row, which becomes
    # unreadable with dozens of results.
    n_cols = min(n_results, 4)
    n_rows = -(-n_results // n_cols)  # ceiling division
    # squeeze=False always yields a 2-D array of axes, including the 1x1 case.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 5 * n_rows), squeeze=False)
    panels = list(axs.flat)
    for ax in panels:
        # Also hides the unused panels of a partially filled last row.
        ax.axis("off")
    for ax, res in zip(panels, tuning_results):
        # Assume the generated image is under the key "generated_image" and is a list of PIL images
        generated_img = res["result"]["generated_image"][0]
        ax.imshow(generated_img)
        title = (f"Steps: {res['num_steps']}\n"
                 f"Cls: {res['classifier_guidance_strength']}\n"
                 f"Sketch: {res['sketch_guidance_strength']}\n"
                 f"Perc: {res['guidance_steps_perc']}")
        ax.set_title(title)
    fig.tight_layout()
    plt.show()