解构稳定扩散流水线
1. 获得 UNet, scheduler, tokenizer, text_encoder, vae
2. 创造噪声
3. 提示词文本嵌入
4. 迭代降噪
5. 显示图片

In [None]:
# Install diffusers and transformers (both imported by the next cell) together with accelerate.
!pip install diffusers accelerate transformers

获得 UNet, scheduler, tokenizer, text_encoder, vae

In [None]:
from diffusers import UNet2DConditionModel, UniPCMultistepScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

# torch is needed here to probe for a GPU; importing it locally keeps this
# cell runnable on its own.
import torch

# Every component is loaded from its own subfolder of the same checkpoint.
model_id = "CompVis/stable-diffusion-v1-4"

# UNet: used by the denoising loop to predict the noise residual.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_safetensors=True)
# Scheduler: provides the timesteps and the per-step latent update.
scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# VAE: decodes the final latents into an image.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_safetensors=True)
# Text encoder: turns token ids into the embeddings that condition the UNet.
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", use_safetensors=True)
# Tokenizer: converts the prompt strings into token ids.
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

# Use the GPU when one is available; fall back to the CPU instead of crashing.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)
vae.to(torch_device)
text_encoder.to(torch_device)

参数列表

In [None]:
import torch

# Text prompts to render; one image is generated per entry.
prompt = ["a photograph of an astronaut riding a horse"]
batch_size = len(prompt)

# Output resolution in pixels (512 is the Stable Diffusion default).
height, width = 512, 512

# Sampling controls.
num_inference_steps = 25  # number of denoising iterations
guidance_scale = 7.5  # classifier-free guidance strength

# Seeded generator used to draw the initial latent noise.
generator = torch.manual_seed(0)

创造噪声

In [None]:
# Starting latents: one Gaussian noise tensor per prompt, with the UNet's
# input channel count and 1/8 of the requested height and width.
latent_shape = (batch_size, unet.config.in_channels, height // 8, width // 8)

# Draw with the seeded generator first, then move the result to the compute device.
latents = torch.randn(latent_shape, generator=generator).to(torch_device)

提示词文本嵌入

In [None]:
# Tokenize the prompt, padding/truncating to the tokenizer's maximum length.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

# Classifier-free guidance also needs unconditional embeddings. They are built
# from empty prompts padded to the same sequence length as the conditional
# input so both embedding tensors have matching shapes.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")

# Inference only: run BOTH encoder passes without autograd tracking so no
# computation graph is kept alive by the embeddings.
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

# Unconditional first, conditional second: the denoising loop splits the UNet
# output with chunk(2) in exactly this order.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

迭代降噪

In [None]:
from tqdm.auto import tqdm

# Configure the timestep schedule BEFORE reading init_noise_sigma: for some
# schedulers that value may depend on the configured schedule, so scaling
# first could use a stale sigma.
scheduler.set_timesteps(num_inference_steps)

# Scale the initial noise by the scheduler's starting sigma.
latents = latents * scheduler.init_noise_sigma

for t in tqdm(scheduler.timesteps):
    # Duplicate the latents so the unconditional and conditional predictions
    # come out of a single UNet forward pass instead of two.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # Predict the noise residual (inference only, no gradients needed).
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # Classifier-free guidance: move from the unconditional prediction toward
    # the text-conditioned one, scaled by guidance_scale.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Compute the previous noisy sample x_t -> x_t-1.
    latents = scheduler.step(noise_pred, t, latents).prev_sample

显示图片

In [None]:
from PIL import Image

# Undo the VAE latent scaling, then decode the latents into image space.
# The factor is read from the VAE config rather than hard-coded
# (presumably 0.18215 for this checkpoint — confirm against the VAE config).
latents = latents / vae.config.scaling_factor
with torch.no_grad():
    image = vae.decode(latents).sample

# Shift/scale the decoder output into [0, 1] and keep the first image of the
# batch (indexing instead of squeeze() so batch_size > 1 does not break the
# permute below).
image = (image / 2 + 0.5).clamp(0, 1)[0]
# Channels-first -> channels-last, then to 8-bit; round() avoids the downward
# bias that plain truncation to uint8 would introduce.
image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
# Last expression of the cell: the notebook renders the PIL image.
image