#0: Update SD model-perf test to run e2e test
(cherry picked from commit a4ed4a1)
mtatsumiTT authored and AleksKnezevic committed May 4, 2024
1 parent 7a3e3ed commit ff57870
Showing 1 changed file with 134 additions and 187 deletions.
@@ -18,6 +18,8 @@
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    StableDiffusionPipeline,
    LMSDiscreteScheduler,
)
from models.utility_functions import (
    skip_for_grayskull,
@@ -27,6 +29,7 @@
    disable_persistent_kernel_cache,
)
from ttnn.model_preprocessing import preprocess_model_parameters
from ttnn.operations.core import unsqueeze_to_4D
from models.experimental.functional_stable_diffusion.sd_helper_funcs import TtLMSDiscreteScheduler
from models.experimental.functional_stable_diffusion.custom_preprocessing import custom_preprocessor
from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_unet_2d_condition_model import (
@@ -42,13 +45,33 @@
from models.utility_functions import profiler, enable_persistent_kernel_cache, skip_for_grayskull


def constant_prop_time_embeddings(timesteps, sample, time_proj):
def ttnn_to_torch(input):
    input = ttnn.to_layout(input, ttnn.ROW_MAJOR_LAYOUT)
    input = ttnn.from_device(input)
    input = ttnn.to_torch(input)
    return input


def constant_prop_time_embeddings(timesteps, batch_size, time_proj):
    timesteps = timesteps[None]
    timesteps = timesteps.expand(sample.shape[0])
    timesteps = timesteps.expand(batch_size)
    t_emb = time_proj(timesteps)
    return t_emb
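
Since every timestep is known before the denoising loop runs, this helper constant-propagates the sinusoidal time embedding on host instead of recomputing it on device each step. A minimal sketch of what it produces, assuming the same diffusers UNet the test loads (its time_proj is the sinusoidal Timesteps projection):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
t_emb = constant_prop_time_embeddings(torch.tensor(981), batch_size=2, time_proj=unet.time_proj)
print(t_emb.shape)  # torch.Size([2, 320]) for SD 1.4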


def unsqueeze_all_params_to_4d(params):
    if isinstance(params, dict):
        for key in params.keys():
            params[key] = unsqueeze_all_params_to_4d(params[key])
    elif isinstance(params, ttnn.ttnn.model_preprocessing.ParameterList):
        for i in range(len(params)):
            params[i] = unsqueeze_all_params_to_4d(params[i])
    elif isinstance(params, ttnn.Tensor):
        params = unsqueeze_to_4D(params)

    return params
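
The recursive walk above pads the rank of every parameter tensor to 4 via unsqueeze_to_4D, which prepends size-1 dimensions. A torch analogue of that rank padding (an illustrative helper, shown only for intuition):

import torch

def unsqueeze_to_4d_torch(tensor: torch.Tensor) -> torch.Tensor:
    # Prepend size-1 dims until rank 4, e.g. [320] -> [1, 1, 1, 320].
    while tensor.dim() < 4:
        tensor = tensor.unsqueeze(0)
    return tensor

assert unsqueeze_to_4d_torch(torch.randn(320, 1280)).shape == (1, 1, 320, 1280)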


def tt_guide(noise_pred, guidance_scale): # will return latents
    noise_pred_uncond, noise_pred_text = ttnn.split(noise_pred, noise_pred.shape[0] // 2, dim=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
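
tt_guide applies classifier-free guidance: the batch packs the unconditional and text-conditioned noise predictions, and the blend extrapolates away from the unconditional prediction by guidance_scale. The same arithmetic in plain torch, as a reference sketch (torch_guide is an illustrative name):

import torch

def torch_guide(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Batch dim holds [uncond, text]; split and extrapolate toward the text branch.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

assert torch_guide(torch.randn(2, 4, 64, 64), 7.5).shape == (1, 4, 64, 64)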
@@ -79,199 +102,123 @@ def lms_derivative(tau):
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.parametrize(
"num_prompts, num_inference_steps, image_size, expected_compile_time, expected_inference_time",
((1, 2, (512, 512), 3600, 1.8),),
"batch_size, num_inference_steps, expected_compile_time, expected_inference_time",
[
(2, 2, 3600, 1.8),
],
)
def test_stable_diffusion_perf(
    device, num_prompts, num_inference_steps, image_size, expected_compile_time, expected_inference_time
):
def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected_compile_time, expected_inference_time):
    disable_persistent_kernel_cache()

    # Clear global profiler state before starting measurements
    profiler.clear()

    # 0. Load a sample prompt from the dataset
    dataset = load_dataset("poloclub/diffusiondb", "2m_random_1k")
    data_1k = dataset["train"]
    height, width = image_size

    for i in range(num_prompts):
        experiment_name = f"diffusiondb_{i}__{height}x{width}"
        input_prompt = [f"{data_1k['prompt'][i]}"]
        logger.info(f"input_prompts: {input_prompt}")

        image = np.array(data_1k["image"][i])
        ref_images = Image.fromarray(image)
        ref_img_path = f"{experiment_name}_ref.png"
        ref_images.save(ref_img_path)

        # 1. Load the autoencoder model which will be used to decode the latents into image space.
        vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

        # 2. Load the tokenizer and text encoder to tokenize and encode the text.
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

        # 3. The UNet model for generating the latents.
        unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

        # 4. load the K-LMS scheduler with some fitting parameters.
        ttnn_scheduler = TtLMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )

        torch_device = "cpu"
        vae.to(torch_device)
        text_encoder.to(torch_device)
        unet.to(torch_device)

        guidance_scale = 7.5 # Scale for classifier-free guidance
        generator = torch.manual_seed(174) # 10233 Seed generator to create the initial latent noise
        batch_size = len(input_prompt)

        ## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model.
        # Tokenizer and Text Encoder
        text_input = tokenizer(
            input_prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

        # For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings),
        # and another with the unconditional embeddings (uncond_embeddings).
        # In practice, we can concatenate both into a single batch to avoid doing two forward passes.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        ttnn_text_embeddings = torch.nn.functional.pad(text_embeddings, (0, 0, 0, 19))
        ttnn_text_embeddings = ttnn.from_torch(
            ttnn_text_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
        )

        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
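        # (vae_scale_factor is 8 for SD 1.4, so 512x512 images correspond to 64x64 latents)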
        # Initial random noise
        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
            generator=generator,
        )
        latents = latents.to(torch_device)

        ttnn_scheduler.set_timesteps(num_inference_steps)

        latents = latents * ttnn_scheduler.init_noise_sigma
        ttnn_latents = torch.tensor(latents)
        ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    # setup envvar if testing on N300
    wh_arch_yaml_org = None
    if device.core_grid.y == 7:
        if ("WH_ARCH_YAML" not in os.environ) or (
            os.environ["WH_ARCH_YAML"] != "wormhole_b0_80_arch_eth_dispatch.yaml"
        ):
            pytest.skip("SD unet2d only works for 8x8 grid size")

    # setup the configs
    ttnn.CONFIG.throw_exception_on_fallback = True
    in_channels = 4
    input_height = 64
    input_width = 64

    # setup pytorch model
    torch.manual_seed(0)
    model_name = "CompVis/stable-diffusion-v1-4"
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32)
    model = pipe.unet
    model.eval()
    config = model.config

    # setup scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    scheduler.set_timesteps(1)

        config = unet.config
        parameters = preprocess_model_parameters(
            initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
        )
        input_height = 64
        input_width = 64
        reader_patterns_cache = {} if height == 512 and width == 512 else None

        time_step_list = []
        ttnn_sigma = []
        ttnn_step_index = []
        timesteps_bkp = ttnn.to_torch(ttnn_scheduler.timesteps)
        sigma_tensor = ttnn.to_torch(ttnn_scheduler.sigmas)[0]
        step_index = (timesteps_bkp[0] == timesteps_bkp[0][0]).nonzero().item()
        ttnn_latent_model_input = tt_latent_expansion(
            ttnn_latents, ttnn_scheduler, float(sigma_tensor[step_index]), device
        )
    parameters = preprocess_model_parameters(
        model_name=model_name, initialize_model=lambda: model, custom_preprocessor=custom_preprocessor, device=device
    )

        for t in timesteps_bkp[0]:
            _t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
            _t = _t.unsqueeze(0).unsqueeze(0)
            _t = _t.permute(2, 0, 1, 3)
            _t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
            time_step_list.append(_t)
            step_index = (timesteps_bkp[0] == t).nonzero().item()
            ttnn_step_index.append(step_index)
            ttnn_sigma.append(sigma_tensor[step_index])

        orders = 4
        order_list = []
        ttnn_lms_coeff = []
        lms_coeff = []
        for step_index in ttnn_step_index:
            order = min(step_index + 1, orders)
            order_list.append(order)
            lms_coeffs = [
                get_lms_coefficient(order, step_index, curr_order, sigma_tensor) for curr_order in range(order)
            ]
            lms_coeff.append(lms_coeffs)

        for lms in lms_coeff:
            ttnn_lms_tensor = None
            for value in lms:
                lms_tensor = ttnn.full(
                    (1, 4, 64, 64), fill_value=value, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16
                )
                if ttnn_lms_tensor is not None:
                    ttnn_lms_tensor = ttnn.concat([ttnn_lms_tensor, lms_tensor], dim=0)
                else:
                    ttnn_lms_tensor = lms_tensor

            ttnn_lms_coeff.append(ttnn_lms_tensor)

        model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)

        # Denoising loop
        for i in range(len(time_step_list)):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            ttnn_latent_model_input = tt_latent_expansion(ttnn_latents, ttnn_scheduler, float(ttnn_sigma[i]), device)

            # predict the noise residual
            with torch.no_grad():
                profiler.start(f"model_run_for_inference_{i}")
                ttnn_noise_pred = model(
                    sample=ttnn_latent_model_input,
                    timestep=time_step_list[i],
                    encoder_hidden_states=ttnn_text_embeddings,
                    config=config,
                )
                profiler.end(f"model_run_for_inference_{i}")

            # perform guidance
            noise_pred = tt_guide(ttnn_noise_pred, guidance_scale)

            ttnn_latents = ttnn_scheduler.step(
                model_output=noise_pred,
                sample=ttnn_latents,
                sigma=float(ttnn_sigma[i]),
                lms_coeffs=ttnn_lms_coeff[i],
                device=device,
                order=order_list[i],
            ).prev_sample
        enable_persistent_kernel_cache()
        latents = ttnn.to_torch(ttnn_latents).to(torch.float32)

        # printout the perf
        profiler.print()
        comment = f"diffusiondb__{height}x{width}"
        ref_model_run_for_inference = profiler.get("ref_model_run_for_inference_0")
        first_iter_time = profiler.get("model_run_for_inference_0")
        second_iter_time = profiler.get(f"model_run_for_inference_{num_inference_steps-1}")

        logger.info("Call prep-perf-report")
        prep_perf_report(
            model_name=f"StableDiffusion",
            batch_size=batch_size,
            inference_and_compile_time=first_iter_time,
            inference_time=second_iter_time,
            expected_compile_time=expected_compile_time,
            expected_inference_time=expected_inference_time,
            comments=comment,
    # unsqueeze weight tensors to 4D for generating perf dump
    parameters = unsqueeze_all_params_to_4d(parameters)

    timestep_shape = [1, 1, 2, 320]
    encoder_hidden_states_shape = [1, 2, 77, 768]
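    # (320 is the SD 1.4 time-embedding width; 77 is CLIP's max token length and 768 its hidden size)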
    class_labels = None
    attention_mask = None
    cross_attention_kwargs = None
    return_dict = True

    hidden_states_shape = [batch_size, in_channels, input_height, input_width]

    input = torch.randn(hidden_states_shape)
    timestep = [i for i in tqdm(scheduler.timesteps)][0]
    ttnn_timestep = constant_prop_time_embeddings(timestep, batch_size, model.time_proj)
    ttnn_timestep = ttnn_timestep.unsqueeze(0).unsqueeze(0)
    encoder_hidden_states = torch.randn(encoder_hidden_states_shape)

    torch_output = model(input, timestep=timestep, encoder_hidden_states=encoder_hidden_states.squeeze(0)).sample
    input = ttnn.from_torch(input, ttnn.bfloat16)
    input = ttnn.to_device(input, device, memory_config=ttnn.L1_MEMORY_CONFIG)
    input = ttnn.to_layout(input, ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)

    ttnn_timestep = ttnn_timestep.permute(2, 0, 1, 3) # pre-permute temb
    ttnn_timestep = ttnn.from_torch(ttnn_timestep, ttnn.bfloat16)
    ttnn_timestep = ttnn.to_device(ttnn_timestep, device, memory_config=ttnn.L1_MEMORY_CONFIG)
    ttnn_timestep = ttnn.to_layout(ttnn_timestep, ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b)

    encoder_hidden_states = torch.nn.functional.pad(encoder_hidden_states, (0, 0, 0, 19))
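    # (pads the 77-token dim up to 96, a multiple of the 32-element tile size TILE_LAYOUT works in)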
    encoder_hidden_states = ttnn.from_torch(
        encoder_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device
    )
    encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device, memory_config=ttnn.L1_MEMORY_CONFIG)

    # define model
    reader_patterns_cache = {}
    model = UNet2D(device, parameters, batch_size, input_height, input_width, reader_patterns_cache)

    # run inference iterations
    for i in range(num_inference_steps):
        profiler.start(f"model_run_for_inference_{i}")
        ttnn_output = model(
            input,
            timestep=ttnn_timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=return_dict,
            config=config,
        )
    logger.info("Exit SD perf test")
        profiler.end(f"model_run_for_inference_{i}")
        ttnn_output = ttnn_to_torch(ttnn_output)

    # printout the perf
    profiler.print()
    comment = f"diffusiondb_512x512"
    first_iter_time = profiler.get("model_run_for_inference_0")
    second_iter_time = profiler.get(f"model_run_for_inference_{num_inference_steps-1}")

    logger.info("Call prep-perf-report")
    prep_perf_report(
        model_name=f"StableDiffusion",
        batch_size=batch_size,
        inference_and_compile_time=first_iter_time,
        inference_time=second_iter_time,
        expected_compile_time=expected_compile_time,
        expected_inference_time=expected_inference_time,
        comments=comment,
    )
    logger.info("Exit SD perf test")


@skip_for_grayskull()
@@ -285,7 +232,7 @@ def test_stable_diffusion_device_perf(expected_perf):
    margin = 0.02
    batch = 1
    iterations = 1
    command = f"pytest tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py::test_unet_2d_condition_model_512x512[batch_size=2-in_channels=4-input_height=64-input_width=64-device_l1_small_size=32768]"
    command = f"pytest tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py::test_unet_2d_condition_model512x512[batch_size=2-in_channels=4-input_height=64-input_width=64-device_l1_small_size=32768]"
    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"