In [None]:
import torch
import time
import sys
sys.path.append("/append/MINDiff/path")
from pipeline_stable_diffusion_mask import StableDiffusionMaskPipeline
from set_attn_proc import set_mask_attn
from mask_attention_processor import MaskAttnProcessor2_0

In [None]:
# Load the mask-aware Stable Diffusion pipeline in half precision and move it to the GPU.
pipeline = StableDiffusionMaskPipeline.from_pretrained(
    "/Input/model/path",
    torch_dtype=torch.float16,
).to("cuda")

# Configure mask attention on the pipeline's UNet via the project helper `set_mask_attn`.
# NOTE(review): the returned module is kept in `unet` but never assigned back to
# `pipeline.unet`; this assumes set_mask_attn modifies the UNet in place — confirm.
unet = set_mask_attn(unet=pipeline.unet, mask_resolution=16)

# Vocabulary id of the placeholder token "sks". The "</w>" suffix is presumably the
# CLIP BPE end-of-word marker — confirm against the tokenizer in use.
mask_token_id = pipeline.tokenizer.convert_tokens_to_ids("sks</w>")

In [None]:
def measure_sd_latency(pipe, prompt, pre_prompt, num_images=1, repeat=10, token_id=None):
    """Measure average wall-clock latency and peak GPU memory of the mask pipeline.

    One untimed warm-up call is made first. Then ``repeat`` calls are timed
    together with a monotonic clock, and the total is divided by ``repeat``.
    Before every call ``MaskAttnProcessor2_0.mask`` is reset to ``None``. After
    every call any ``MaskAttnProcessor2_0.mask_buffer`` is deleted and the CUDA
    cache is emptied.

    Args:
        pipe: The ``StableDiffusionMaskPipeline`` (or a compatible callable) to benchmark.
        prompt: Generation prompt.
        pre_prompt: Forwarded to the pipeline as ``pre_prompt``.
        num_images: Forwarded as ``num_images_per_prompt``.
        repeat: Number of timed runs; must be >= 1.
        token_id: Forwarded as ``mask_token_id``. Defaults to the notebook-level
            ``mask_token_id`` defined in the setup cell.

    Returns:
        ``(avg_latency_seconds, peak_memory_mib)``. Peak memory is
        ``torch.cuda.max_memory_allocated()`` (bytes) / 1024**2, measured since the
        reset that follows the warm-up.

    Raises:
        ValueError: If ``repeat`` < 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    if token_id is None:
        token_id = mask_token_id  # notebook global from the setup cell

    def _run_once():
        # One generation from a clean mask state, then drop the buffer it left behind.
        MaskAttnProcessor2_0.mask = None
        pipe(prompt,
             num_images_per_prompt=num_images,
             mask_token_id=token_id,
             pre_prompt=pre_prompt,
             attn_scale=1.0)
        if hasattr(MaskAttnProcessor2_0, 'mask_buffer'):
            del MaskAttnProcessor2_0.mask_buffer
            torch.cuda.empty_cache()

    _run_once()  # warm-up, not timed
    torch.cuda.synchronize()  # make sure the warm-up's GPU work has finished
    torch.cuda.reset_peak_memory_stats()  # peak memory should cover only the timed runs

    # NOTE(review): the mask_buffer cleanup and empty_cache() in _run_once are
    # inside the timed region, as in the original protocol.
    start = time.perf_counter()  # monotonic, high resolution (time.time() can jump)
    for _ in range(repeat):
        _run_once()
    torch.cuda.synchronize()  # wait for outstanding GPU work before stopping the clock
    end = time.perf_counter()

    avg_latency = (end - start) / repeat
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2
    return avg_latency, peak_memory

In [None]:
# Wall-clock benchmark. measure_sd_latency returns seconds, so scale to ms for display.
latency, memory = measure_sd_latency(
    pipeline,
    "a sks backpack on a wooden table",  # prompt
    "a sks backpack",                    # pre_prompt
    num_images=1,
    repeat=10,
)
print(f"Avg latency: {latency * 1000:.2f} ms")
print(f"Peak GPU memory: {memory:.2f} MB")

In [None]:
def measure_sd_latency_event(pipe, prompt, pre_prompt, num_images=1, repeat=10, token_id=None):
    """Measure per-run latency with CUDA events, plus peak GPU memory.

    One untimed warm-up call is made first. Then each of ``repeat`` calls is
    bracketed by a pair of ``torch.cuda.Event`` records, and the elapsed times
    are averaged. ``MaskAttnProcessor2_0.mask`` is reset to ``None`` before every
    call. Any ``MaskAttnProcessor2_0.mask_buffer`` is released after every call,
    outside the timed window. This matches the cleanup done by the other
    benchmarks, so no stale buffer carries over between runs or into later cells.

    Args:
        pipe: The ``StableDiffusionMaskPipeline`` (or a compatible callable) to benchmark.
        prompt: Generation prompt.
        pre_prompt: Forwarded to the pipeline as ``pre_prompt``.
        num_images: Forwarded as ``num_images_per_prompt``.
        repeat: Number of timed runs; must be >= 1.
        token_id: Forwarded as ``mask_token_id``. Defaults to the notebook-level
            ``mask_token_id`` defined in the setup cell.

    Returns:
        ``(avg_latency_ms, peak_memory_mib)``. Latency is in milliseconds, because
        ``Event.elapsed_time`` reports ms. Peak memory is
        ``torch.cuda.max_memory_allocated()`` (bytes) / 1024**2, measured since the
        reset that follows the warm-up.

    Raises:
        ValueError: If ``repeat`` < 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    if token_id is None:
        token_id = mask_token_id  # notebook global from the setup cell

    def _generate():
        # One generation from a clean mask state.
        MaskAttnProcessor2_0.mask = None
        pipe(prompt,
             num_images_per_prompt=num_images,
             mask_token_id=token_id,
             pre_prompt=pre_prompt,
             attn_scale=1.0)

    def _release_mask_buffer():
        # Drop the buffer left by the previous run so it cannot leak into the next one.
        if hasattr(MaskAttnProcessor2_0, 'mask_buffer'):
            del MaskAttnProcessor2_0.mask_buffer
            torch.cuda.empty_cache()

    _generate()  # warm-up, not timed
    _release_mask_buffer()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # peak memory should cover only the timed runs

    # Events can be re-recorded, so one pair is reused for every iteration.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    total_time = 0.0
    for _ in range(repeat):
        start_event.record()
        _generate()
        end_event.record()
        torch.cuda.synchronize()  # both events must have completed before elapsed_time
        total_time += start_event.elapsed_time(end_event)
        _release_mask_buffer()  # cleanup happens after the timed window

    avg_latency = total_time / repeat
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2
    return avg_latency, peak_memory

In [None]:
# CUDA-event benchmark. measure_sd_latency_event already reports milliseconds
# (torch.cuda.Event.elapsed_time is in ms), so no unit conversion is applied.
latency, memory = measure_sd_latency_event(
    pipeline,
    "a sks backpack on a wooden table",  # prompt
    "a sks backpack",                    # pre_prompt
    num_images=1,
    repeat=30,
)
print(f"Avg latency: {latency:.2f} ms")
print(f"Peak GPU memory: {memory:.2f} MB")

In [None]:
# Batch-size sweep: latency of one generation call at several batch sizes.
# Each measurement reuses measure_sd_latency with repeat=1: one untimed warm-up,
# then one timed run, with the mask reset before each call and any mask_buffer
# released after it. This replaces the previous copy of that logic.
batch_sizes = [1, 2, 4, 8]
prompt = "a sks backpack on a wooden table"
pre_prompt = "a sks backpack"

for batch_size in batch_sizes:
    print(f"\nBatch size: {batch_size}")

    # This cell also clears `stack` before the warm-up; measure_sd_latency does not.
    MaskAttnProcessor2_0.stack = None

    # NOTE(review): a single timed run per batch size is noisy; raise `repeat`
    # for more stable numbers.
    elapsed, _ = measure_sd_latency(
        pipeline, prompt, pre_prompt, num_images=batch_size, repeat=1)

    elapsed_ms = elapsed * 1000
    fps = batch_size / elapsed
    per_image_time = elapsed_ms / batch_size

    print(f"Time: {elapsed_ms:.2f}ms | FPS: {fps:.2f} | Per image: {per_image_time:.2f}ms")