# On Exact Inversion of DPM-Solvers

## CVPR anonymous submission 3894

This notebook demonstrates exact inversion of image samples generated by DDIM and DPM-Solver++(2M) in Stable Diffusion.

The inversion is carried out with Algorithms 1 and 2.

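```python
# our_functions is this demo's helper module; it provides generate(), exact_inversion(),
# stable_diffusion_pipe(), and plot_recon_result()
from our_functions import *

import textwrap                    # wraps long prompts in plot titles
import matplotlib.pyplot as plt    # displays the generated images
```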

## Generation

To generate an image, call `generate()` with either an arbitrary `image_num` or a specific `prompt`. It returns the image, the prompt that was used, and the initial latent.

Choose the number of steps (`num_inference_steps`) and the order of the DPM-Solver (`solver_order`).
The default setting is 50 steps at order 1 (DDIM); order 2 corresponds to DPM-Solver++(2M).
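
For intuition about what `solver_order` changes, here is a minimal pure-Python sketch of the two deterministic update rules on a one-dimensional toy problem. Everything in it is an assumption for illustration: the schedule $\alpha_t=\cos(\pi t/2)$, $\sigma_t=\sin(\pi t/2)$, toy data that are $+1$ or $-1$ with equal probability (for which the exact denoiser is $\tanh(\alpha_t x/\sigma_t^2)$), and the names `alpha_sigma`, `denoise`, `lam`, and `sample`. It does not call Stable Diffusion or this demo's `generate()`. Order 1 is the DDIM step written in log-SNR form; order 2 follows the published DPM-Solver++(2M) multistep update, which reuses the previous step's data prediction.

```python
import math

def alpha_sigma(t):
    """Toy variance-preserving schedule (an assumption): alpha^2 + sigma^2 = 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

def denoise(x, t):
    """Exact E[x0 | x_t] for toy data that are +1 or -1 with equal probability."""
    a, s = alpha_sigma(t)
    return math.tanh(a * x / s ** 2)

def lam(t):
    """Half log-SNR, lambda_t = log(alpha_t / sigma_t)."""
    a, s = alpha_sigma(t)
    return math.log(a / s)

def sample(z, num_steps, order):
    """Deterministic sampling from noise z: order 1 = DDIM, order 2 = DPM-Solver++(2M)."""
    ts = [0.99 - 0.98 * i / num_steps for i in range(num_steps + 1)]
    x, prev_d, prev_h = z, None, None
    for i in range(num_steps):
        t, s = ts[i], ts[i + 1]
        (_, st), (as_, ss) = alpha_sigma(t), alpha_sigma(s)
        h = lam(s) - lam(t)                  # step size in log-SNR, > 0
        d = denoise(x, t)                    # data prediction at the current latent
        if order == 2 and prev_d is not None:
            r = prev_h / h
            # 2M correction: finite-difference use of the previous data prediction
            d_eff = (1 + 1 / (2 * r)) * d - (1 / (2 * r)) * prev_d
        else:
            d_eff = d                        # first-order (DDIM) update
        x = (ss / st) * x - as_ * math.expm1(-h) * d_eff
        prev_d, prev_h = d, h
    return x

print(sample(1.5, num_steps=50, order=1))  # DDIM, 50 steps
print(sample(1.5, num_steps=10, order=2))  # DPM-Solver++(2M), 10 steps
```

With a single step both orders coincide, because the multistep correction needs a previous prediction; a 2M solver likewise uses a first-order update for its first step.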

```python
image_num = 96
image, prompt, _ = generate(image_num=image_num, num_inference_steps=50, solver_order=1)

plt.imshow(image)
plt.title("\n".join(textwrap.wrap(prompt)), wrap=True)
plt.axis('off')
plt.show()
```

```python
prompt = 'small medieval village next to a forest'
image, prompt, _ = generate(prompt=prompt, num_inference_steps=10, solver_order=2)

plt.imshow(image)
plt.title("\n".join(textwrap.wrap(prompt)), wrap=True)
plt.axis('off')
plt.show()
```

## Reconstruction

Generate `orig_image` ($x_0$) from the initial latent `orig_noise` ($x_T$).

Run `exact_inversion()`, which inverts `orig_image` ($x_0$) back to a latent `recon_noise` ($x_T$).

- The number of inversion steps (`test_num_inference_steps`) and the inversion order (`inv_order`) can be chosen.
- `inv_order` values of 0, 1, and 2 select naive DDIM inversion, Algorithm 1, and Algorithm 2, respectively.

Finally, regenerate `recon_image` ($x_0$) from `recon_noise` ($x_T$).
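
Mapping `orig_image` back to a latent first requires undoing the VAE decoder; the log messages in the cells below describe this stage as "decoder inversion via gradient descent 100 steps". The following is a toy sketch of that idea under clearly stated assumptions: a made-up decoder $\tanh(Wz)$ with a hypothetical 3-by-2 weight matrix stands in for the Stable Diffusion VAE, and the learning rate, initialization, and names (`decode`, `invert_decoder`) are illustrative. It minimizes $\tfrac12\lVert \mathrm{dec}(z) - x\rVert^2$ with plain gradient descent for 100 steps.

```python
import math

# Hypothetical toy decoder: a 2-D latent z is mapped to a 3-pixel "image" by tanh(W z).
W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

def decode(z):
    return [math.tanh(sum(w * zj for w, zj in zip(row, z))) for row in W]

def invert_decoder(x, z_init, steps=100, lr=0.5):
    """Gradient descent on 0.5 * ||decode(z) - x||^2, starting from z_init."""
    z = list(z_init)
    for _ in range(steps):
        y = decode(z)
        # Gradient of the loss: W^T [(y - x) * (1 - y^2)]  (chain rule through tanh)
        r = [(yi - xi) * (1.0 - yi * yi) for yi, xi in zip(y, x)]
        grad = [sum(W[i][j] * r[i] for i in range(len(W))) for j in range(len(z))]
        z = [zj - lr * gj for zj, gj in zip(z, grad)]
    return z

z_true = [0.3, -0.2]
x = decode(z_true)                       # the "image" to be inverted
z_hat = invert_decoder(x, z_init=[0.0, 0.0])
print(z_hat)
```

The real decoder is a deep network, so its gradient would come from automatic differentiation rather than the hand-written chain rule used here.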

### Algorithm 1

```python
image_num = 96

pipe = stable_diffusion_pipe(solver_order=1)

# Generation
print("@@@ Generation via DDIM 50 steps")
orig_image, prompt, orig_noise = generate(image_num=image_num,
                                          num_inference_steps=50,
                                          solver_order=1, pipe=pipe)

# Inversion
print("@@@ Decoder inversion via gradient descent 100 steps,")
print("@@@ and the backward Euler 50 steps")
recon_noise = exact_inversion(orig_image,
                              prompt,
                              test_num_inference_steps=50,
                              inv_order=1, pipe=pipe)

# Re-generation
print("@@@ Re-generation via DDIM 50 steps")
recon_image, _, _ = generate(prompt=prompt,
                             init_latents=recon_noise,
                             num_inference_steps=50,
                             solver_order=1, pipe=pipe)

# Plot
plot_recon_result(orig_noise, orig_image, recon_noise, recon_image, error_scale=1, pipe=pipe)
```
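
The log output above describes Algorithm 1 as a "backward Euler" inversion. The idea can be illustrated on a toy problem: instead of approximating each DDIM step in reverse, solve the implicit equation $\mathrm{DDIM}(x_t) = x_s$ for the unknown noisier latent $x_t$, so that regenerating from the recovered noise reproduces the sample up to solver tolerance. In this sketch the schedule, the $\pm 1$ toy data, the function names, and the use of bisection as the root finder are all assumptions; bisection is valid here only because the scalar toy step is strictly increasing in $x_t$, and the paper's Algorithm 1 may solve the equation differently.

```python
import math

def alpha_sigma(t):
    # Toy variance-preserving schedule (an assumption): alpha^2 + sigma^2 = 1.
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

def denoise(x, t):
    # Exact E[x0 | x_t] for toy data that are +1 or -1 with equal probability.
    a, s = alpha_sigma(t)
    return math.tanh(a * x / s ** 2)

def ddim_step(x, t, s):
    """One deterministic DDIM step from time t to the less noisy time s < t."""
    (at, st), (as_, ss) = alpha_sigma(t), alpha_sigma(s)
    x0 = denoise(x, t)
    return as_ * x0 + ss * (x - at * x0) / st

def invert_ddim_step(x_s, t, s, tol=1e-13):
    """Backward Euler: solve ddim_step(x_t, t, s) == x_s for the unknown x_t.

    The toy step is strictly increasing in x_t, so bisection converges to the
    unique root; this solver is illustrative, not necessarily Algorithm 1's."""
    lo, hi = -1.0, 1.0
    while ddim_step(lo, t, s) > x_s:     # widen the bracket until it contains the root
        lo *= 2
    while ddim_step(hi, t, s) < x_s:
        hi *= 2
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ddim_step(mid, t, s) < x_s:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def time_grid(n):
    # n + 1 times from 0.99 (noisy) down to 0.01 (nearly clean)
    return [0.99 - 0.98 * i / n for i in range(n + 1)]

def toy_generate(z, n):
    ts = time_grid(n)
    x = z
    for t, s in zip(ts[:-1], ts[1:]):
        x = ddim_step(x, t, s)
    return x

def toy_exact_invert(x0, n):
    ts = time_grid(n)
    x = x0
    for t, s in reversed(list(zip(ts[:-1], ts[1:]))):
        x = invert_ddim_step(x, t, s)
    return x

z = 0.8
x0 = toy_generate(z, 50)
z_rec = toy_exact_invert(x0, 50)
print(abs(z_rec - z), abs(toy_generate(z_rec, 50) - x0))
```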

### Algorithm 2

```python
image_num = 96

pipe = stable_diffusion_pipe(solver_order=2)

# Generation
print("@@@ Generation via DPM-Solver++(2M) 10 steps")
orig_image, prompt, orig_noise = generate(image_num=image_num,
                                          num_inference_steps=10,
                                          solver_order=2, pipe=pipe)

# Inversion
print("@@@ Decoder inversion via gradient descent 100 steps,")
print("@@@ and the backward Euler w/ high-order term approximation 10 steps")
recon_noise = exact_inversion(orig_image,
                              prompt,
                              test_num_inference_steps=10,
                              inv_order=2, pipe=pipe)

# Re-generation
print("@@@ Re-generation via DPM-Solver++(2M) 10 steps")
recon_image, _, _ = generate(prompt=prompt,
                             init_latents=recon_noise,
                             num_inference_steps=10,
                             solver_order=2, pipe=pipe)

# Plot
plot_recon_result(orig_noise, orig_image, recon_noise, recon_image, error_scale=1, pipe=pipe)
```
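
Each DPM-Solver++(2M) step combines the data prediction at the current latent with the one made at the previous, noisier time. When inverting from $x_0$ towards the noise, that noisier latent has not been recovered yet, which is presumably what the "high-order term approximation" in the log message refers to. The toy sketch below (hypothetical schedule and $\pm 1$ data; names such as `step_2m` and `invert_step_2m` are illustrative) takes the earlier prediction as an input `d_prev_approx` and solves the remaining backward Euler equation for one step by bisection. How Algorithm 2 actually forms that approximation is not shown here.

```python
import math

def alpha_sigma(t):
    # Toy variance-preserving schedule (an assumption).
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

def denoise(x, t):
    # Exact E[x0 | x_t] for toy data that are +1 or -1 with equal probability.
    a, s = alpha_sigma(t)
    return math.tanh(a * x / s ** 2)

def lam(t):
    a, s = alpha_sigma(t)
    return math.log(a / s)

def step_2m(x_t, d_prev, t_prev, t, s):
    """DPM-Solver++(2M) step t -> s; d_prev is the data prediction made at t_prev > t."""
    (_, st), (as_, ss) = alpha_sigma(t), alpha_sigma(s)
    h, h_prev = lam(s) - lam(t), lam(t) - lam(t_prev)
    r = h_prev / h
    d = (1 + 1 / (2 * r)) * denoise(x_t, t) - (1 / (2 * r)) * d_prev
    return (ss / st) * x_t - as_ * math.expm1(-h) * d

def invert_step_2m(x_s, d_prev_approx, t_prev, t, s, tol=1e-13):
    """Backward Euler for one 2M step: solve step_2m(x_t, d_prev_approx, ...) == x_s.

    d_prev_approx stands in for the unknown high-order term; bisection works
    because the step is strictly increasing in x_t for this toy."""
    f = lambda x: step_2m(x, d_prev_approx, t_prev, t, s) - x_s
    lo, hi = -1.0, 1.0
    while f(lo) > 0:                     # widen the bracket until it contains the root
        lo *= 2
    while f(hi) < 0:
        hi *= 2
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if f(mid) < 0 else (lo, mid)
    return 0.5 * (lo + hi)

t_prev, t, s = 0.7, 0.5, 0.3
d_prev = denoise(0.9, t_prev)            # prediction actually used during generation
x_s = step_2m(0.6, d_prev, t_prev, t, s)
print(invert_step_2m(x_s, d_prev, t_prev, t, s))         # exact high-order term
print(invert_step_2m(x_s, d_prev + 0.01, t_prev, t, s))  # slightly wrong approximation
```

If `d_prev_approx` equals the prediction used during generation, the step is inverted to solver precision; a mismatch shifts the recovered latent roughly in proportion to the mismatch, so the quality of the high-order approximation bounds the accuracy of this kind of inversion.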

### Naive DDIM inversion

```python
image_num = 96

pipe = stable_diffusion_pipe(solver_order=1)

# Generation
print("@@@ Generation via DDIM 50 steps")
orig_image, prompt, orig_noise = generate(image_num=image_num,
                                          num_inference_steps=50,
                                          solver_order=1, pipe=pipe)

# Inversion
print("@@@ Decoder inversion via gradient descent 100 steps,")
print("@@@ and the naive DDIM inversion 50 steps")
recon_noise = exact_inversion(orig_image,
                              prompt,
                              test_num_inference_steps=50,
                              inv_order=0, pipe=pipe)

# Re-generation
print("@@@ Re-generation via DDIM 50 steps")
recon_image, _, _ = generate(prompt=prompt,
                             init_latents=recon_noise,
                             num_inference_steps=50,
                             solver_order=1, pipe=pipe)

# Plot
plot_recon_result(orig_noise, orig_image, recon_noise, recon_image, error_scale=1, pipe=pipe)
```
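
Naive DDIM inversion, as commonly described, runs the same DDIM update in the opposite direction, so each step evaluates the denoiser at the already known, less noisy latent rather than at the unknown noisier one. Because the two differ, regenerating from the recovered noise does not reproduce the original sample exactly. The sketch below uses a hypothetical toy setup (cosine-style schedule, $\pm 1$ data, names such as `toy_naive_invert`); `exact_inversion(..., inv_order=0)` in this demo may differ in details such as which timestep the noise prediction is conditioned on.

```python
import math

def alpha_sigma(t):
    # Toy variance-preserving schedule (an assumption).
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

def denoise(x, t):
    # Exact E[x0 | x_t] for toy data that are +1 or -1 with equal probability.
    a, s = alpha_sigma(t)
    return math.tanh(a * x / s ** 2)

def ddim_step(x, t_from, t_to):
    """DDIM update from time t_from to time t_to (in either direction)."""
    (af, sf), (at_, st_) = alpha_sigma(t_from), alpha_sigma(t_to)
    x0 = denoise(x, t_from)              # data prediction at the *starting* point
    return at_ * x0 + st_ * (x - af * x0) / sf

def time_grid(n):
    # n + 1 times from 0.99 (noisy) down to 0.01 (nearly clean)
    return [0.99 - 0.98 * i / n for i in range(n + 1)]

def toy_generate(z, n):
    ts = time_grid(n)
    x = z
    for t, s in zip(ts[:-1], ts[1:]):
        x = ddim_step(x, t, s)           # noisy -> clean
    return x

def toy_naive_invert(x0, n):
    ts = time_grid(n)
    x = x0
    for t, s in reversed(list(zip(ts[:-1], ts[1:]))):
        x = ddim_step(x, s, t)           # clean -> noisy, denoiser evaluated at x_s
    return x

z = 0.8
x0 = toy_generate(z, 50)
z_naive = toy_naive_invert(x0, 50)
print(abs(z_naive - z))  # generally nonzero: the naive step is not the exact inverse
```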