In [1]:
#|default_exp app_v1

## Reimplementing DiffEdit

In this notebook we're going to reimplement the semantic image editing process illustrated in the [DiffEdit](https://arxiv.org/abs/2210.11427) paper. In the paper, the authors proposed using text input to create a mask of the queried object, and essentially using an img2img type of processing, such that changes could be made to the object without making changes to the context of the image. 

## Example

<center><img alt="DiffEdit Workflow" width="1000" src="imgs/diffusion_method2.jpg" /></center>

### Imports

In [2]:
#| export
from base64 import b64encode

import os
import numpy
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login
from fastai.vision.all import *

from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


import tsensor
from lolviz import *


import logging

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)
logging.disable(logging.WARNING)
torch.manual_seed(1);

In [3]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

## Load Autoencoder, VAE, ClipTokenizer, Clip Text Encoder, and Scheduler

In [2]:
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

vae = vae.to(torch_device).half()
text_encoder = text_encoder.to(torch_device).half()
unet = unet.to(torch_device).half();

NameError: name 'AutoencoderKL' is not defined

## Define Functions for Imamge -> Latent and Latent -> Image Conversion

In [5]:
def pil_to_latent(input_im):
    # Single image -> single latent in a batch, size (1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device).half()*2-1)
    return 0.18215 * latent.latent_dist.sample()

In [6]:
def latent_to_pil(latents):
    # batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

In [7]:
def get_prompt_embedding(prompt):
    max_length = tokenizer.model_max_length
    tokens = tokenizer(prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(tokens.input_ids.to(torch_device))[0]
    return text_embeddings

### Import Image

In [8]:
#| export
im = Image.open('imgs/IMG_4104_512.jpg')
im

## Set up the Scheduler

Set the number of sampling steps

In [21]:
prompt = ["a wolf staring at the viewer, by Howard Arkley"]

Convert the prompt to a text embedding

In [3]:
text = text_encoder(prompt)

NameError: name 'text_encoder' is not defined

In [14]:
scheduler.set_timesteps(15)

Settings

Prep Text

In [22]:
prompt_embeds = get_prompt_embedding(prompt)
uncond_embeds = get_prompt_embedding('')
text_embeds = torch.cat([uncond_embeds, prompt_embeds])

Prep Scheduler

In [23]:
scheduler.set_timesteps(num_inference_steps)

In [None]:
offset = scheduler.config.get()

In [24]:
get_timesteps(scheduler, num_inference_steps, 0.5, torch_device)

NameError: name 'get_timesteps' is not defined

Prep Latents

In [25]:
import nbdev
nbdev.export.nb_export('diffedit.ipynb', 'app_v1')
print("export successful")

export successful
