# In[ ]:
import torch
import math
from diffusers import StableDiffusion3Pipeline
import debugpy

# debugpy.listen(('0.0.0.0', 5678))

# print("Waiting for debugger attach")
# debugpy.wait_for_client()

# pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, cache_dir='/purestorage/project/tyk/tmp')
# pipe = pipe.to("cuda")
# print(pipe.diffuers.config)


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    For k in ``0 .. embedding_dim // 2 - 1`` the k-th frequency is
    ``max_period ** (-k / (embedding_dim // 2 - downscale_freq_shift))``.
    Each timestep is multiplied by every frequency and by ``scale``; the sines
    of those products are concatenated with their cosines.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be
            fractional; they are cast to float32.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or
            `sin, cos` (if False).
        downscale_freq_shift (float):
            Subtracted from ``embedding_dim // 2`` in the frequency exponent's
            denominator; controls the spacing between frequencies.
        scale (float):
            Factor applied to the sin/cos arguments (timestep * frequency).
        max_period (int):
            Controls the lowest frequency of the embeddings (frequency 0 is always 1).

    Returns:
        torch.Tensor: a float32 ``[N x embedding_dim]`` tensor on
        ``timesteps.device``. When ``embedding_dim`` is odd, the last column is
        zero padding.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D.
        ValueError: if ``embedding_dim // 2 - downscale_freq_shift`` is zero
            while at least one frequency is requested; dividing by it would
            silently fill the embedding with NaN/inf.
    """
    # Kept as an assert (rather than ValueError) so the exception type callers
    # may already rely on is unchanged.
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    denominator = half_dim - downscale_freq_shift
    # e.g. embedding_dim=2 with the default shift of 1: 0/0 -> NaN embeddings.
    if half_dim > 0 and denominator == 0:
        raise ValueError(
            "embedding_dim // 2 - downscale_freq_shift must be non-zero "
            f"(got embedding_dim={embedding_dim}, "
            f"downscale_freq_shift={downscale_freq_shift})"
        )

    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / denominator

    # Outer product: (N, 1) timesteps times (1, half_dim) frequencies.
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale the sin/cos arguments
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad the final column so odd embedding_dim is honored
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


# Example usage: embed timesteps 1..10 into 8 dimensions and plot each
# embedding dimension as one curve over the timesteps.
import matplotlib.pyplot as plt  # `plt` was used below but never imported

timesteps = torch.arange(1, 11, dtype=torch.float32)  # 1.0, 2.0, ..., 10.0
embedding_dim = 8

embeddings = get_timestep_embedding(
    timesteps,
    embedding_dim,
    flip_sin_to_cos=False,
    downscale_freq_shift=1,
    scale=1.0,
    max_period=10000,
)

# Convert embeddings to numpy for plotting (detach/cpu in case of grad or GPU tensors)
embeddings_np = embeddings.detach().cpu().numpy()
timesteps_np = timesteps.numpy()

# Plot the embeddings: column i of embeddings_np is embedding dimension i.
fig, ax = plt.subplots(figsize=(12, 6))
for dim_index, column in enumerate(embeddings_np.T):
    ax.plot(timesteps_np, column, label=f'Dimension {dim_index + 1}')
ax.set_title('Timestep Embeddings')
ax.set_xlabel('Timesteps')
ax.set_ylabel('Embedding Value')
ax.legend()
fig.savefig('timestep_embeddings.png')  # Save the figure as a PNG file
plt.close(fig)  # close this specific figure to release its memory

print("The timestep embeddings plot has been saved as 'timestep_embeddings.png'.")