In [None]:
import os
import math
import random
import shutil
import logging
import itertools
import argparse
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
# from torchvision import transforms
# import torchvision.transforms.functional as TF

from transformers import AutoTokenizer
from transformers import CLIPTextModel

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel)

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (much more slowly) instead of failing at `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory that holds the local model checkpoints loaded below.
MODELS = Path("/projects/p_scads_llm_secrets/models")

In [None]:
# Every Stable Diffusion v1.5 component is read from a sub-folder of the same
# checkpoint directory, so resolve that directory once.
sd15_dir = MODELS / "stable-diffusion-v1-5"

# Tokenizer, noise scheduler and the three networks of the base model.
tokenizer = AutoTokenizer.from_pretrained(sd15_dir, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(sd15_dir, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(sd15_dir, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd15_dir, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd15_dir, subfolder="unet")

# The ControlNet weights live in their own checkpoint directory.
controlnet = ControlNetModel.from_pretrained(MODELS / "sd-controlnet-scribble")

In [None]:
# Assemble the ControlNet pipeline from the components loaded above.
pipeline = StableDiffusionControlNetPipeline(
    # Base Stable Diffusion networks and text front-end.
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    # Conditioning network and sampler.
    controlnet=controlnet,
    scheduler=noise_scheduler,
    # Safety checker is how the model knows not to generate anything objectionable;
    # it is deliberately disabled here, so no feature extractor is needed either.
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
# Move every sub-module of the pipeline onto the selected device.
pipeline = pipeline.to(DEVICE)

## Canvas

In [None]:
from ipywidgets import HBox
from ipycanvas import RoughCanvas, hold_canvas

In [3]:
# Size of the square drawing surface, in pixels.
CANVAS_WIDTH = CANVAS_HEIGHT = 512

# sync_image_data=True is set because the pixels are read back from the
# kernel later via canvas.get_image_data().
canvas = RoughCanvas(
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
    sync_image_data=True,
)

# Stroke state shared by the mouse handlers: whether a stroke is in progress
# and the last point that was drawn to.
drawing, position = False, None

def on_mouse_down(x, y):
    """Start a new stroke at canvas coordinates (x, y).

    Marks the shared `drawing` flag as active and records (x, y) as the
    starting `position` for the next line segment.
    """
    global drawing, position
    drawing, position = True, (x, y)

def on_mouse_move(x, y):
    """Extend the active stroke to canvas coordinates (x, y).

    Does nothing unless a stroke is in progress. Otherwise draws a segment
    from the last recorded `position` to (x, y) and makes (x, y) the new
    `position`.
    """
    global position
    if drawing:
        start_x, start_y = position
        with hold_canvas():
            canvas.stroke_line(start_x, start_y, x, y)
            position = (x, y)

def on_mouse_up(x, y):
    """Finish the active stroke at canvas coordinates (x, y).

    If a stroke is in progress, draws the final segment from the last
    recorded `position` to (x, y) and clears the `drawing` flag.

    A release with no stroke in progress (for example the button was pressed
    outside the canvas and released over it) is ignored. Without this guard
    the handler would either fail on `position` being None or draw a stray
    line from the end of the previous stroke.
    """
    global drawing
    stroke_active = drawing and position is not None
    drawing = False
    if not stroke_active:
        return
    with hold_canvas():
        canvas.stroke_line(position[0], position[1], x, y)

# Hook each mouse event of the canvas up to its handler.
for _register, _handler in (
    (canvas.on_mouse_down, on_mouse_down),
    (canvas.on_mouse_move, on_mouse_move),
    (canvas.on_mouse_up, on_mouse_up),
):
    _register(_handler)

# Last expression of the cell: render the canvas widget so it can be drawn on.
HBox(children=(canvas,))

HBox(children=(RoughCanvas(height=512, sync_image_data=True, width=512),))

In [None]:
# get_image_data() returns the canvas pixels as an RGBA uint8 array, which is
# not a valid 3-channel conditioning image for the ControlNet pipeline.
# NOTE(review): assumes the default black strokes on an unfilled (transparent)
# canvas, so the alpha channel carries the drawing; using it as the image gives
# white lines on a black background — confirm this matches the ControlNet's
# expected scribble format.
sketch_rgba = canvas.get_image_data()
sketch = Image.fromarray(sketch_rgba[:, :, 3]).convert("RGB")

## Generate

In [None]:
# Generate one image conditioned on the sketch. The output size is taken from
# the canvas constants so the conditioning image and the result cannot drift
# apart if the canvas size is changed.
generation = pipeline(
    prompt="Drawing of something",
    image=sketch,
    num_inference_steps=50,
    height=CANVAS_HEIGHT,
    width=CANVAS_WIDTH,
    guidance_scale=7.5,
)
# Last expression of the cell: display the first (and only) generated image.
generation.images[0]