## Multi-Stage Blended Diffusion

Before using this notebook, install the conda environment specified in `environment.yaml`:
```
conda env create -f environment.yaml
```
Then download the Stable Diffusion and Real-ESRGAN models as described in `README.md`.

In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
import logging
import time

import ipycanvas  # import Canvas, hold_canvas
import ipywidgets  # import AppLayout, ColorPicker, HBox, Image, IntSlider, link
import numpy as np
import torch
from PIL import Image, ImageDraw
from einops import rearrange
from io import BytesIO
import requests
from IPython.display import clear_output

def tensor_to_pil(arr):
    """Convert the first image of a batched tensor with values in [-1, 1] to a PIL image.

    ``arr`` is indexed with ``[0]`` and rearranged from ``c h w`` to ``h w c``,
    so it is expected to be laid out as (batch, channels, height, width).
    Values are mapped from [-1, 1] to [0, 255] and clipped *before* the uint8
    cast, so slightly out-of-range outputs saturate instead of wrapping around.
    """
    # detach(): allow tensors that still require grad; float(): numpy has no
    # bfloat16, so normalise the dtype before converting.
    arr = arr.detach().float().cpu().numpy()[0]
    arr = (arr + 1.0) / 2.0
    arr = 255.0 * rearrange(arr, 'c h w -> h w c')
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))

In [None]:
%%time
from msbd.MSBDGenerator import MSBDGenerator

# Build the generator once; every RUN in the UI below reuses it.
generator = MSBDGenerator(stable_diffusion=True, use_fp16=True)

# Clear whatever output construction produced, then confirm completion.
clear_output(wait=True)
print('done')

In [None]:
# Root logger at DEBUG so that debug/info messages from this notebook's
# callbacks (and, presumably, the msbd package) are shown.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.DEBUG)
# Pillow emits a DEBUG record for every PNG chunk it parses. This notebook
# decodes PNGs frequently (input images, saved masks, and presumably the
# canvas image data), so keep Pillow at INFO to avoid flooding the output.
logging.getLogger("PIL").setLevel(logging.INFO)

In [None]:
# Default input shown when the UI starts: a local path, an http(s) URL, or
# NEW<width>x<height> for a blank grey canvas (see get_img_from_url).
url = "inputs/marunouchi.png"

def _fit_long_edge(size, max_edge):
    """Return ``[w, h]`` scaled so the longer edge equals *max_edge*, keeping the aspect ratio.

    Landscape sizes (w > h) are scaled by width; square and portrait sizes by
    height. The shorter edge is truncated with ``int`` and never drops below 1 px.
    """
    w, h = size
    if w > h:
        scale_factor = w / max_edge
        return [max_edge, max(1, int(h / scale_factor))]
    scale_factor = h / max_edge
    return [max(1, int(w / scale_factor)), max_edge]


def get_img_from_url(url, rescaled_width=None, display_max_edge=1024):
    """Load an RGB image from a URL, a local path, or a blank-canvas spec.

    Args:
        url: an ``http://``/``https://`` URL, ``NEW<w>x<h>`` (e.g. ``NEW512x768``)
            for a blank grey canvas, or a local file path.
        rescaled_width: if given, downscale the image so its *longer* edge is at
            most this many pixels (despite the name, it bounds the longer edge).
            Working at a lower resolution generally gives better results.
        display_max_edge: longer edge, in pixels, of the on-screen preview size.

    Returns:
        ``(pil_img, original_size, display_size)``. ``original_size`` is the
        (width, height) of the possibly-rescaled image. ``display_size`` is a
        ``[width, height]`` list whose longer edge equals ``display_max_edge``.

    Raises:
        requests.HTTPError: for a non-success HTTP response.
        OSError: if the file cannot be opened or decoded.
        ValueError: for a malformed ``NEW`` spec.
    """
    if url.lower().startswith(("http://", "https://")):
        # A timeout keeps a dead server from hanging the kernel. raise_for_status()
        # reports HTTP errors directly instead of as an undecodable "image".
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        pil_img = Image.open(BytesIO(response.content)).convert("RGB")
    elif url.startswith('NEW'):
        new_w, new_h = url[3:].split('x')
        pil_img = Image.new('RGB', [int(new_w), int(new_h)], 'gray')
    else:
        pil_img = Image.open(url).convert("RGB")

    if rescaled_width is not None and max(pil_img.size) > rescaled_width:
        pil_img = pil_img.resize(_fit_long_edge(pil_img.size, rescaled_width))

    original_size = pil_img.size
    display_size = _fit_long_edge(original_size, display_max_edge)
    print("original size", original_size)
    return pil_img, original_size, display_size

# Load the default input once, so the UI cell below can size its canvas.
(pil_img,
 original_size,
 display_size) = get_img_from_url(url)

The next cell defines and starts the UI.

In [None]:
# Drawing method based on https://github.com/martinRenou/ipycanvas/blob/master/examples/hand_drawing.ipynb
width, height = display_size

# Three stacked layers:
#   canvas[0] - the input image
#   canvas[1] - the mask being drawn
#   canvas[2] - the merged preview that is displayed and receives mouse input
canvas = ipycanvas.MultiCanvas(3, width=width, height=height, sync_image_data=True)
for layer in canvas:
    # Required so get_image_data() can read each layer back into Python.
    layer.sync_image_data = True

# Start with an empty mask: a solid rectangle in the layer's default fill style
# (presumably black, the colour reset_mask uses for "not masked").
canvas[1].fill_rect(0, 0, width, height)


# --- Mask drawing / preview controls -------------------------------------
# Numeric traits get numeric values (not strings) so widget values have the
# expected type without relying on trait coercion.
alpha_slider = ipywidgets.FloatSlider(
    description="Preview Alpha:", value=0.3, min=0.0, max=1.0
)
repaint_slider = ipywidgets.IntSlider(
    description="Repaint Steps:", value=0, min=0, max=20
)
eraser_cb = ipywidgets.Checkbox(description="Erase", value=False)
refresh_button = ipywidgets.Button(description="Refresh Merged Layer")

# --- Generation settings (passed to generator.multi_scale_generation) ----
prompt_box = ipywidgets.Textarea(description="Prompt", value='Statue of Roman Emperor, Canon 5D Mark 3, 35mm, flickr')
# Free text; run_btn_cb converts it with int().
seed_box = ipywidgets.Textarea(description="Seed (int)", value='1234')
starttime_slider = ipywidgets.FloatSlider(
    description="StartTimestep:", value=1.0, min=0.0, max=1.0
)
margin_slider = ipywidgets.FloatSlider(
    description="Margin-Multiplier:", value=1.2, min=1.0, max=2.0
)
# min=10: a run with 0 diffusion steps is not meaningful (the slider moves in steps of 10).
step_slider = ipywidgets.IntSlider(
    description="Total Steps:", value=50, min=10, max=200, step=10,
)
decoder_opt_cbox = ipywidgets.Checkbox(value=False, description='Decoder Optimization')
clip_rerank_cbox = ipywidgets.Checkbox(value=True, description='Clip Reranking')
run_button = ipywidgets.Button(description="RUN")

# --- Mask management ------------------------------------------------------
reset_mask_btn = ipywidgets.Button(description="Reset Mask")
mask_everything_btn = ipywidgets.Button(description="Mask Everything")
load_mask_btn = ipywidgets.Button(description="Load Mask")
save_mask_btn = ipywidgets.Button(description="Save Mask")
upscaling_ddwn = ipywidgets.Dropdown(description="Upscaling", options=['bilinear', 'sharpen', 'esrgan'], value='esrgan')

url_textinput = ipywidgets.Text(value=url, description='Input (Path or URL)')



def load_img_cb(change):
    """Load the image named in the input box and show it on the canvas.

    Registered with ``url_textinput.on_submit``, which passes the Text widget
    itself, so ``change.value`` is the submitted path or URL. If loading fails,
    the error is logged and the current image and globals are left unchanged.
    """
    global pil_img, original_size, display_size, url, width, height
    new_url = change.value
    try:
        loaded = get_img_from_url(new_url)
    except (OSError, ValueError) as err:
        # OSError covers missing/undecodable files and requests' network and
        # HTTP errors (RequestException derives from IOError).
        # ValueError covers a malformed NEW<w>x<h> spec.
        logger.error('Could not load %r: %s', new_url, err)
        return
    url = new_url
    pil_img, original_size, display_size = loaded
    # Keep the drawing size in sync with the new image, so resizes based on
    # width/height match the new aspect ratio.
    width, height = display_size
    load_new_img()
    

def _blend_layers(base, overlay, alpha):
    """Return ``(1 - alpha) * base + alpha * overlay`` as a uint8 array.

    The blend is clipped to [0, 255] *before* the uint8 cast, so no value can
    wrap around during the conversion.
    """
    mixed = (1 - alpha) * base + overlay * alpha
    return np.clip(mixed, 0, 255).astype(np.uint8)


def update_merge():
    """Redraw the preview layer (canvas[2]) as a blend of the image and mask layers.

    ``alpha_slider.value`` sets how strongly the mask layer (canvas[1]) is
    overlaid on the image layer (canvas[0]).
    """
    merged = _blend_layers(
        canvas[0].get_image_data(), canvas[1].get_image_data(), alpha_slider.value
    )
    with BytesIO() as output:
        Image.fromarray(merged).save(output, format="PNG")
        ipyimg = ipywidgets.Image(value=output.getvalue())
    canvas[2].draw_image(ipyimg)


# Lasso state shared by the mouse handlers: whether a stroke is in progress,
# the last pointer position, and the outline points collected so far.
drawing, position, shape = False, None, []

def on_mouse_down(x, y):
    """Start a new lasso outline at (x, y)."""
    global drawing, position, shape
    position = (x, y)
    shape = [position]
    drawing = True

def on_mouse_move(x, y):
    """Extend the current lasso outline: draw a segment to (x, y) and record the point.

    Does nothing when no outline is in progress (``drawing`` is false).
    """
    global position

    if drawing:
        with ipycanvas.hold_canvas():
            canvas[2].stroke_line(*position, x, y)
            position = (x, y)
        shape.append(position)

def on_mouse_up(x, y):
    """Finish the lasso outline at (x, y) and fill the enclosed polygon.

    The polygon is filled on both the preview layer (canvas[2]) and the mask
    layer (canvas[1]), using each layer's current fill style. A mouse-up with
    no outline in progress (e.g. a drag that started outside the canvas) is
    ignored instead of failing on ``position`` being None.
    """
    global drawing, position, shape

    if not drawing:
        return
    drawing = False
    # Close the outline at the release point, matching the stroke drawn below.
    shape.append((x, y))
    with ipycanvas.hold_canvas():
        canvas[2].stroke_line(position[0], position[1], x, y)
        canvas[2].fill_polygon(shape)
        canvas[1].fill_polygon(shape)

    shape = []

# The displayed preview layer (canvas[2]) receives the lasso mouse events.
# set_canvas_cb_and_style(canvas), called when the UI is shown, registers the
# same handlers again and replaces this stroke colour with green.
canvas[2].stroke_style = "#749cb8"
canvas[2].on_mouse_down(on_mouse_down)
canvas[2].on_mouse_move(on_mouse_move)
canvas[2].on_mouse_up(on_mouse_up)

def set_eraser(event):
    """Switch the lasso between drawing and erasing when the Erase box is toggled.

    Erasing paints the mask layer (canvas[1]) black and the preview layer
    (canvas[2]) white; drawing paints both green.
    """
    if event.new:
        mask_colour, preview_colour = "#000000", "#ffffff"
    else:
        mask_colour = preview_colour = "#00ff00"

    mask_layer, preview_layer = canvas[1], canvas[2]
    mask_layer.stroke_style = mask_colour
    mask_layer.fill_style = mask_colour
    preview_layer.stroke_style = preview_colour
    preview_layer.fill_style = preview_colour


def refresh_button_cb(btn):
    """Button handler: recompute the merged preview layer (``btn`` is unused)."""
    return update_merge()


def reset_mask_cb(btn):
    """Button handler: clear the mask layer back to black (``btn`` is unused)."""
    return reset_mask()

def load_mask_cb(btn):
    """Button handler: load ``nbimg_mask.png`` into the mask layer (canvas[1]).

    ``nbimg_mask.png`` is the mask written by ``run_btn_cb`` on each RUN. It is
    resized to the canvas size and reduced to its green channel, so masked
    pixels become green and unmasked pixels black. The merged preview is then
    refreshed. If the file does not exist, the button flashes red and nothing
    changes.
    """
    # NOTE(review): "Save Mask" writes nb_save_mask.png, but this reads the mask
    # saved by the last RUN (nbimg_mask.png). Confirm which file is intended.
    mask_path = 'nbimg_mask.png'
    try:
        pil_mask = Image.open(mask_path).convert('RGB')
    except FileNotFoundError:
        logger.warning('No mask found at %s', mask_path)
        btn.style.button_color = "red"
        time.sleep(0.2)
        btn.style.button_color = None
        return
    # NEAREST keeps the mask binary, as when masks are saved. The default
    # (bicubic) filter would create grey edge pixels, and any non-zero pixel
    # is later treated as masked.
    pil_mask = pil_mask.resize([width, height], resample=Image.Resampling.NEAREST)
    np_mask = np.array(pil_mask)
    np_mask[:, :, 0] = 0
    np_mask[:, :, 2] = 0
    pil_mask = Image.fromarray(np_mask)
    display(pil_mask)
    with BytesIO() as output:
        pil_mask.save(output, format="PNG")
        ipyimg = ipywidgets.Image(value=output.getvalue())
    canvas[1].draw_image(ipyimg)
    logger.info('drew mask to canvas 1')
    time.sleep(0.3)  # give the front-end time to sync the new layer
    update_merge()

def save_mask_cb(btn):
    """Button handler: save the current mask to ``nb_save_mask.png``.

    Every pixel of the mask layer (canvas[1]) with a non-zero RGB value is
    treated as masked (255). The mask is resized to the image's original size
    with nearest-neighbour sampling. With an empty mask, the button flashes red
    and nothing is saved.
    """
    logger.info('Saving mask to nb_save_mask.png')
    time.sleep(0.3)  # wait for mask updates from the front-end
    rgb = canvas[1].get_image_data()[:, :, :3]
    # Test the channels with any() rather than summing them into uint8; the
    # sum would wrap totals of 256 or 512 to zero and drop those pixels.
    mask = np.where(rgb.any(axis=2), 255, 0).astype(np.uint8)
    if not mask.any():
        logger.warning('Attempted save with empty mask')
        btn.style.button_color = "red"
        time.sleep(0.2)
        btn.style.button_color = None
        return
    btn.style.button_color = "lightgreen"
    pil_mask = Image.fromarray(mask).resize(
        original_size, resample=Image.Resampling.NEAREST
    )
    pil_mask.save('nb_save_mask.png')
    time.sleep(0.3)  # keep the green confirmation visible briefly
    btn.style.button_color = None

    
def mask_everything_cb(btn):
    """Button handler: mark the whole image as masked (``btn`` is unused)."""
    reset_mask(fill_green=True)


def reset_mask(fill_green=False):
    """Fill the entire mask layer (canvas[1]) with one colour.

    Green (``fill_green=True``) masks everything; black clears the mask.
    The layer's previous fill style is restored afterwards.
    """
    mask_layer = canvas[1]
    previous_style = mask_layer.fill_style
    mask_layer.fill_style = "#00ff00" if fill_green else "#000000"
    mask_layer.fill_rect(0, 0, *display_size)
    mask_layer.fill_style = previous_style

    
def run_btn_cb(btn):
    """Button handler: run multi-stage blended diffusion on the masked region.

    Reads the mask from canvas[1] (any non-black pixel is masked) and saves
    the current image and mask to ``nbimg.png`` / ``nbimg_mask.png``. It then
    calls ``generator.multi_scale_generation`` with the UI settings, shows the
    result as the new image, and clears the mask.

    With an empty mask or a non-integer seed, the button flashes red and
    nothing runs. The button colour is reset even if generation raises.
    """
    global pil_img

    def flash_red():
        # Brief red flash to signal that the run was refused.
        btn.style.button_color = "red"
        time.sleep(0.2)
        btn.style.button_color = None

    logger.info('run button pressed')
    time.sleep(1)  # wait for mask updates from the front-end
    rgb = canvas[1].get_image_data()[:, :, :3]
    # Test the channels with any() rather than summing them into uint8; the
    # sum would wrap totals of 256 or 512 to zero and drop those pixels.
    mask = np.where(rgb.any(axis=2), 255, 0).astype(np.uint8)
    if not mask.any():
        logger.warning('Attempted run with empty mask')
        flash_red()
        return
    # Validate the seed before writing files or starting the (long) generation.
    try:
        seed = int(seed_box.value)
    except ValueError:
        logger.warning('Attempted run with non-integer seed %r', seed_box.value)
        flash_red()
        return

    btn.style.button_color = "lightgreen"
    try:
        pil_mask = Image.fromarray(mask).resize(
            original_size, resample=Image.Resampling.NEAREST
        )
        # Keep the exact inputs of this run; load_mask_cb reads nbimg_mask.png back.
        pil_img.save('nbimg.png')
        pil_mask.save('nbimg_mask.png')

        pil_result = generator.multi_scale_generation(
            pil_img,
            pil_mask,
            prompt_box.value,
            margin_slider.value,
            decoder_opt_cbox.value,
            starttime_slider.value,
            repaint_steps=repaint_slider.value,
            upscaling_start_step=0.4,
            tight_stitch=True,
            blended_upscale=True,
            ddim_steps=step_slider.value,
            clip_reranking=clip_rerank_cbox.value,
            straight_to_grid=False,
            seed=seed,
            upscaling_mode=upscaling_ddwn.value,
            debug_outputs=False,
        )
        pil_img = pil_result
        load_new_img()
        reset_mask()
    finally:
        # Do not leave the button green if generation failed or was interrupted.
        btn.style.button_color = None

def load_new_img():
    """Show the current ``pil_img`` on the image layer and the preview layer.

    Sizes the multi-canvas and every layer to ``display_size``. The image,
    scaled to that same size, is drawn on canvas[2] (preview) and canvas[0]
    (image).
    """
    with BytesIO() as output:
        # Scale to display_size, the size the canvas is set to below, so the
        # drawn image lines up with the canvas even after the aspect ratio
        # of the input has changed.
        pil_img.resize(tuple(display_size)).save(output, format="PNG")
        ipyimg = ipywidgets.Image(value=output.getvalue())

    # The MultiCanvas itself first, then each of its layers.
    for target in (canvas, *canvas):
        target.width = display_size[0]
        target.height = display_size[1]
        target.layout.width = "auto"
        target.layout.height = "30%"

    canvas[2].draw_image(ipyimg)
    canvas[0].draw_image(ipyimg)
    

def set_canvas_cb_and_style(canvas):
    """Attach the lasso mouse handlers to the preview layer and make the pen green.

    ``canvas`` is the three-layer MultiCanvas. The handlers go on layer 2 (the
    displayed preview); layers 1 and 2 get a green stroke and fill style.
    """
    preview_layer = canvas[2]
    for register, handler in (
        (preview_layer.on_mouse_down, on_mouse_down),
        (preview_layer.on_mouse_move, on_mouse_move),
        (preview_layer.on_mouse_up, on_mouse_up),
    ):
        register(handler)

    for layer in (canvas[1], canvas[2]):
        layer.stroke_style = "#00ff00"
        layer.fill_style = "#00ff00"

# set_canvas_cb_and_style(canvas)

# Wire every widget to its callback.
eraser_cb.observe(set_eraser, names="value")
# observe(names="value") instead of the deprecated on_trait_change(), which
# with no trait name fires for *every* trait change on the slider (including
# internal sync traits) and would repeat the expensive merge several times.
alpha_slider.observe(lambda _change: update_merge(), names="value")
refresh_button.on_click(refresh_button_cb)
run_button.on_click(run_btn_cb)
# NOTE(review): Text.on_submit is deprecated in ipywidgets 8; still functional.
url_textinput.on_submit(load_img_cb)
reset_mask_btn.on_click(reset_mask_cb)
mask_everything_btn.on_click(mask_everything_cb)
load_mask_btn.on_click(load_mask_cb)
save_mask_btn.on_click(save_mask_cb)

load_new_img()
set_canvas_cb_and_style(canvas)

# Input box, then the merged preview layer (the only layer shown), then the controls.
layout = ipywidgets.VBox([
    url_textinput,
    ipywidgets.HBox([canvas[2]]),
    ipywidgets.VBox([
        ipywidgets.HBox([eraser_cb, alpha_slider]),
        ipywidgets.HBox([
            refresh_button,
            reset_mask_btn,
            mask_everything_btn,
            load_mask_btn,
            save_mask_btn,
        ]),
        step_slider,
        ipywidgets.HBox([starttime_slider, repaint_slider, margin_slider]),
        ipywidgets.HBox([decoder_opt_cbox, clip_rerank_cbox]),
        ipywidgets.HBox([prompt_box, seed_box, upscaling_ddwn]),
        run_button,
    ]),
])
display(layout)


The UI supports masking in a "lasso"-like way: draw a closed outline around the area you want to mask, and the enclosed region is filled. If you check the "Erase" box, the enclosed area is removed from the mask instead.

To load a new image, enter its URL or file path in the `Input` field and press Enter.

If the aspect ratio is changed by loading a new image, please rerun the above cell to refresh the UI.

 * `StartTimestep` sets the SDEdit-like start timestep for the first stage, i.e. how much the output is based on the input *within* the masked region.
 * `Total Steps` sets the number of diffusion timesteps; decreasing it reduces the runtime but may affect the results.
 * `Repaint Steps` sets the number of repaint steps in the first stage, which usually improves blending at the cost of a longer runtime.
 * `Margin-Multiplier` sets the margin around the mask, as a multiplier of the original mask size.
 * `Decoder Optimization` enables or disables decoder optimization, which improves blending with the background at the cost of significantly longer runtimes at high resolutions.
 * `Clip Reranking` enables CLIP reranking after the first stage. This usually has little impact on the runtime.
 * `Upscaling` lets you choose between different upscaling modes at the start of each stage; "esrgan" usually works best.
 
 The UI sometimes lags behind, so press `Refresh Merged Layer` before pressing `RUN` to make sure the mask has been updated.
 If it has not been updated, press `Refresh Merged Layer` again.

In [None]:
# Helpful after cancelling a run with an interrupt.
import gc

# Collect first, so tensors kept alive only by reference cycles are actually
# freed. Then release the caching allocator's now-unoccupied CUDA memory.
# Emptying the cache before collecting would leave that memory cached.
gc.collect()
torch.cuda.empty_cache()