In [6]:
import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
# Useful function for later
def load_image(url, size=None):
    response = requests.get(url, timeout=0.2)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    if size is not None:
        img = img.resize(size)
    return img

In [8]:
# Sample function (regular DDIM)

def sample(
    prompt,
    start_step=0,
    start_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()


    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    for i in tqdm(range(start_step, num_inference_steps)):

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Normally we'd rely on the scheduler to handle the update step:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

        # Instead, let's do it ourselves:
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt

    # Post-processing
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images, intermediate_latents


In [9]:
## Inversion
def invert(
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=80,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


In [10]:
MASTER_TYPE = torch.float16
with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=MASTER_TYPE).to(device)

    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    #input_image = load_image("https://images.pexels.com/photos/8306128/pexels-photo-8306128.jpeg", size=(512, 512))
    input_image = Image.open("input_dog.jpeg").resize((512, 512))

    input_image_prompt = "Photograph of a puppy on the grass"

    # Encode with VAE
    with torch.no_grad():
        latent = pipe.vae.encode((tfms.functional.to_tensor(input_image).unsqueeze(0).to(device)* 2 - 1).half())
    latent = 0.18215 * latent.latent_dist.sample()

    inverted_latents = invert(latent, input_image_prompt, num_inference_steps=50)
    inverted_latents.shape


Couldn't connect to the Hub: 404 Client Error. (Request ID: Root=1-66e6ff56-5fa48a62638529ba1286655f;0caae47f-24fd-413d-96fb-005d07a53c4c)

Repository Not Found for url: https://huggingface.co/api/models/runwayml/stable-diffusion-v1-5.
Please make sure you specified the correct `repo_id` and `repo_type`.
If you are trying to access a private or gated repo, make sure you are authenticated..
Will try to load from local cache.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. 

CUDA call was originally invoked at:

['  File "<frozen runpy>", line 198, in _run_module_as_main\n', '  File "<frozen runpy>", line 88, in _run_code\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>\n    app.launch_new_instance()\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/traitlets/config/application.py", line 1043, in launch_instance\n    app.start()\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start\n    self.io_loop.start()\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start\n    self.asyncio_loop.run_forever()\n', '  File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever\n    self._run_once()\n', '  File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once\n    handle._run()\n', '  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run\n    self._context.run(self._callback, *self._args)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue\n    await self.process_one()\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one\n    await dispatch(*args)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell\n    await result\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request\n    reply_content = await reply_content\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute\n    res = shell.run_cell(\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell\n    return super().run_cell(*args, **kwargs)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell\n    result = self._run_cell(\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell\n    result = runner(coro)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner\n    coro.send(None)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async\n    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes\n    if await self.run_code(code, result, async_=asy):\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code\n    exec(code_obj, self.user_global_ns, self.user_ns)\n', '  File "/tmp/ipykernel_1949919/828364716.py", line 1, in <module>\n    import torch\n', '  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 940, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/torch/__init__.py", line 1146, in <module>\n    _C._initExtension(manager_path())\n', '  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 940, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/torch/cuda/__init__.py", line 197, in <module>\n    _lazy_call(_check_capability)\n', '  File "/home/vll/venv_pytorch2.0/lib/python3.11/site-packages/torch/cuda/__init__.py", line 195, in _lazy_call\n    _queued_calls.append((callable, traceback.format_stack()))\n']