Skip to content

Commit

Permalink
Standardize on using image argument in all pipelines (huggingface#1361
Browse files Browse the repository at this point in the history
)

* feat: switch core pipelines to use image arg

* test: update tests for core pipelines

* feat: switch examples to use image arg

* docs: update docs to use image arg

* style: format code using black and doc-builder

* fix: deprecate use of init_image in all pipelines
  • Loading branch information
fboulnois committed Dec 1, 2022
1 parent 1f517a5 commit d4f0742
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 107 deletions.
2 changes: 1 addition & 1 deletion pipelines/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images

images[0].save("fantasy_landscape.png")
```
Expand Down
39 changes: 22 additions & 17 deletions pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,26 +435,26 @@ def get_timesteps(self, num_inference_steps, strength, device):

return timesteps, num_inference_steps - t_start

def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
image = image.to(device=device, dtype=dtype)
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
" your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
Expand All @@ -472,7 +472,7 @@ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_promp
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
Expand All @@ -484,22 +484,23 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
Expand Down Expand Up @@ -540,6 +541,10 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image

# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)

Expand All @@ -557,8 +562,8 @@ def __call__(
)

# 4. Preprocess image
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
if isinstance(image, PIL.Image.Image):
image = preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand All @@ -567,7 +572,7 @@ def __call__(

# 6. Prepare latent variables
latents = self.prepare_latents(
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION
from ...utils import PIL_INTERPOLATION, deprecate


def preprocess(image):
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
@torch.no_grad()
def __call__(
self,
init_image: Union[torch.Tensor, PIL.Image.Image],
image: Union[torch.Tensor, PIL.Image.Image],
batch_size: Optional[int] = 1,
num_inference_steps: Optional[int] = 100,
eta: Optional[float] = 0.0,
Expand All @@ -77,7 +77,7 @@ def __call__(
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
init_image (`torch.Tensor` or `PIL.Image.Image`):
image (`torch.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
batch_size (`int`, *optional*, defaults to 1):
Expand All @@ -102,20 +102,21 @@ def __call__(
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image

if isinstance(init_image, PIL.Image.Image):
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(init_image, torch.Tensor):
batch_size = init_image.shape[0]
elif isinstance(image, torch.Tensor):
batch_size = image.shape[0]
else:
raise ValueError(
f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
)
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")

if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
if isinstance(image, PIL.Image.Image):
image = preprocess(image)

height, width = init_image.shape[-2:]
height, width = image.shape[-2:]

# in_channels should be 6: 3 for latents, 3 for low resolution image
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
Expand All @@ -128,7 +129,7 @@ def __call__(
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

init_image = init_image.to(device=self.device, dtype=latents_dtype)
image = image.to(device=self.device, dtype=latents_dtype)

# set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
Expand All @@ -148,7 +149,7 @@ def __call__(

for t in self.progress_bar(timesteps_tensor):
# concat latents and low resolution image in the channel dimension.
latents_input = torch.cat([latents, init_image], dim=1)
latents_input = torch.cat([latents, image], dim=1)
latents_input = self.scheduler.scale_model_input(latents_input, t)
# predict the noise residual
noise_pred = self.unet(latents_input, t).sample
Expand Down
4 changes: 2 additions & 2 deletions pipelines/stable_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ prompt = "An astronaut riding an elephant"
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
Expand All @@ -164,7 +164,7 @@ torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
Expand Down
39 changes: 22 additions & 17 deletions pipelines/stable_diffusion/pipeline_cycle_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,26 +477,26 @@ def get_timesteps(self, num_inference_steps, strength, device):

return timesteps, num_inference_steps - t_start

def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
image = image.to(device=device, dtype=dtype)
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
" your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
Expand All @@ -516,7 +516,7 @@ def __call__(
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
Expand All @@ -528,22 +528,23 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
Expand Down Expand Up @@ -584,6 +585,10 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image

# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)

Expand All @@ -602,8 +607,8 @@ def __call__(
)

# 4. Preprocess image
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
if isinstance(image, PIL.Image.Image):
image = preprocess(image)

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand All @@ -612,7 +617,7 @@ def __call__(

# 6. Prepare latent variables
latents, clean_latents = self.prepare_latents(
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
)
source_latents = latents

Expand Down

0 comments on commit d4f0742

Please sign in to comment.