# In [1]:
# Import the necessary libraries for the script
# ipywidgets for creating interactive widgets
# IPython.display for displaying output in the notebook
# torch for using PyTorch
# diffusers for using the Stable Diffusion model
# PIL for image processing
# cv2 for image processing
# numpy for numerical computations
# io for handling bytes streams
# tqdm for creating progress bars
from ipywidgets import FileUpload, Text, Button, Output, HBox, VBox
from IPython.display import display
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler
from PIL import Image
import cv2
import numpy as np
import io
from tqdm import tqdm

# --- Notebook UI controls -------------------------------------------------

# Single-file upload control; the uploaded picture is the conditioning input
# for the generation step.
image_upload = FileUpload(
    description="Upload an image:",
    multiple=False,
)

# Free-text field holding the prompt that steers the generated image.
prompt_text = Text(
    description="Prompt:",
    placeholder="Enter your prompt",
)

# Button that triggers generation from the uploaded image and the prompt.
generate_button = Button(description="Generate Image")

# Output area that receives the preview, progress display, and final image.
output = Output()

# Pipelines already loaded in this session, keyed by device name, so the
# model weights are fetched and moved to the device only once instead of on
# every button click.
_PIPELINE_CACHE = {}


def _read_uploaded_image(size=(512, 512)):
    """Return the first uploaded file as an RGB PIL image resized to *size*.

    Returns None when nothing has been uploaded yet.
    """
    uploads = image_upload.value
    # ipywidgets 7 exposes a dict keyed by filename, ipywidgets 8 a tuple of
    # per-file dicts; normalise both to a sequence of per-file dicts.
    if isinstance(uploads, dict):
        uploads = list(uploads.values())
    if not uploads:
        return None

    # The raw file bytes live under the "content" key of the upload entry.
    image_bytes = uploads[0]["content"]
    image = Image.open(io.BytesIO(image_bytes))

    # Normalise to 8-bit RGB so palette / RGBA / 16-bit uploads do not break
    # the edge detector, then shrink to the working resolution.
    return image.convert("RGB").resize(size)


def _make_canny_image(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel PIL image holding the Canny edge map of *image*."""
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    # Canny yields a single channel; replicate it into three identical
    # channels to build the conditioning image.
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges)


def _get_pipeline(device="xpu"):
    """Return the ControlNet Stable Diffusion pipeline on *device*, loading it on first use."""
    pipe = _PIPELINE_CACHE.get(device)
    if pipe is None:
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )
        # Swap the default scheduler for DDIM, reusing the existing config.
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(device)
        _PIPELINE_CACHE[device] = pipe
    return pipe


def generate_image(b):
    """Button callback: generate an image from the uploaded picture and the prompt.

    Reads the uploaded file and the prompt from the module-level widgets,
    builds a Canny edge map of the upload, runs the ControlNet pipeline on
    it, and renders the result into the ``output`` widget.

    Args:
        b: The button instance passed by ``Button.on_click``; unused.
    """
    with output:
        output.clear_output(wait=True)

        image = _read_uploaded_image()
        if image is None:
            # Clicking before uploading used to raise IndexError; tell the
            # user what is missing instead.
            print("Please upload an image before generating.")
            return

        # Show the resized upload while the model loads.
        display(image)

        canny_image = _make_canny_image(image)
        user_prompt = prompt_text.value
        pipe = _get_pipeline()

        # wait=True defers the clear until new output arrives, so the preview
        # stays visible until the progress display replaces it.
        output.clear_output(wait=True)

        # The context manager closes the bar even if the pipeline raises.
        with tqdm(total=1, desc="Generating image") as pbar:
            with torch.inference_mode():
                output_image = pipe(user_prompt, image=canny_image).images[0]
            pbar.update(1)

        display(output_image)

# Register generate_image as the click handler of the generate button.
generate_button.on_click(generate_image)

# Render all controls stacked vertically: upload, prompt, button, and the
# output area underneath.
display(
    VBox(
        children=[
            image_upload,
            prompt_text,
            generate_button,
            output,
        ]
    )
)

# Output: VBox(children=(FileUpload(value=(), description='Upload an image:'), Text(value='', description='Prompt:', pla…