In [32]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Even output columns hold ``sin(t * w_k)`` and odd columns hold
    ``cos(t * w_k)``, with frequencies ``w_k = max_period ** (-k / dim)``
    for ``k = 0, 2, 4, ...``.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional. A 0-d (scalar) tensor is
                      treated as a batch of one.
    :param dim: the dimension of the output. May be odd, in which case the
                last column holds a sine with no matching cosine.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float32 Tensor of positional embeddings on
             ``timesteps.device``.
    """
    # Promote a scalar timestep to a batch of one; `size(0)` fails on 0-d tensors.
    if timesteps.dim() == 0:
        timesteps = timesteps.unsqueeze(0)

    pe = torch.zeros(timesteps.size(0), dim, device=timesteps.device)

    # The timestep values themselves act as the positions, shape (N, 1).
    position = timesteps.unsqueeze(1).float()

    # One frequency per even column: ceil(dim / 2) entries.
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float, device=timesteps.device)
        * (-math.log(max_period) / dim)
    )
    angles = position * div_term  # (N, ceil(dim / 2))

    # Sine goes to even columns.
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(dim / 2) odd columns, so trim the angles for odd `dim`.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

    return pe

class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep encoding followed by a two-layer SiLU MLP.

    ``forward`` maps N timesteps to a tensor of shape ``(N, 1, latent_dim)``.
    """

    def __init__(self, latent_dim, max_period=10000, device=None):
        """
        :param latent_dim: width of the sinusoidal encoding and of both MLP layers.
        :param max_period: forwarded to ``timestep_embedding``.
        :param device: device the MLP parameters are moved to.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.max_period = max_period

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        ).to(device)

    def forward(self, timesteps):
        """
        :param timesteps: a tensor, Python number, or sequence of timesteps.
                          A scalar is treated as a batch of one.
        :return: tensor of shape ``(N, 1, latent_dim)``.
        """
        weight = self.time_embed[0].weight
        # Accept Python scalars/lists as well as tensors. A scalar would produce a
        # 0-d tensor that the encoder cannot take `size(0)` of, so promote it to (1,).
        timesteps = torch.atleast_1d(
            torch.as_tensor(timesteps, dtype=torch.float, device=weight.device)
        )

        encodings = timestep_embedding(timesteps, self.latent_dim, self.max_period)
        # Match the MLP's dtype in case the module was cast (e.g. `.double()`).
        encodings = encodings.to(weight.dtype)
        # (N, latent_dim) -> MLP -> (N, latent_dim) -> (N, 1, latent_dim).
        return self.time_embed(encodings).unsqueeze(1)


In [33]:
# Later cells build their PositionalEncoding modules on the CPU and index them with
# these timesteps, so pin everything to the CPU.
device = "cpu"
latent_dim = 128
ts = 10  # Example timestep
SEED = 0

# The MLP weights are randomly initialised; seed so the printed output is reproducible.
torch.manual_seed(SEED)

# Create TimestepEmbedder instance
embed_timestep = TimestepEmbedder(latent_dim, device=device)

# Get the embeddings for the given timesteps
timesteps = torch.tensor([ts], dtype=torch.float, device=device)
embeddings = embed_timestep(timesteps)
assert embeddings.shape == (1, 1, latent_dim)
embeddings

tensor([[[ 0.1014,  0.0478,  0.0693, -0.1509, -0.0652,  0.0832, -0.1776,
          -0.1637, -0.0710, -0.0984, -0.2567,  0.1436, -0.2774,  0.2691,
          -0.0472, -0.1858, -0.0243,  0.0800,  0.2080,  0.2286, -0.1796,
           0.0021, -0.0815, -0.0229, -0.0348,  0.1422,  0.2072, -0.0084,
          -0.1027,  0.0097, -0.0389,  0.1433, -0.0008, -0.1679,  0.0298,
           0.2263, -0.0890, -0.0814,  0.0503, -0.0630,  0.1955, -0.2240,
          -0.2530, -0.0710, -0.0109, -0.0217,  0.1262, -0.1919,  0.0059,
          -0.0572, -0.0775, -0.0460, -0.0400, -0.1212,  0.0152,  0.0717,
           0.0334, -0.0348, -0.1932,  0.0103, -0.0496, -0.0836,  0.2242,
          -0.0530, -0.1115,  0.0179, -0.2222, -0.0946, -0.0136, -0.0535,
          -0.0059,  0.0627,  0.0156, -0.1373,  0.0382, -0.0188,  0.0893,
           0.0682, -0.0940, -0.0601,  0.1521, -0.0607,  0.0450,  0.1164,
           0.1199,  0.1075, -0.1289, -0.0208, -0.0291,  0.0644,  0.0605,
           0.1385, -0.1831,  0.1386,  0.0336, -0.03

In [37]:
# Explicit imports instead of `import *`. The project's TimestepEmbedder is aliased so
# it does not silently shadow the TimestepEmbedder class defined earlier in this
# notebook; the two have different constructor signatures.
from model.diffusion_utils import PositionalEncoding
from model.diffusion_utils import TimestepEmbedder as DiffusionTimestepEmbedder

ts = 10  # Example timestep

dropout = 0.1
latent_dim = 128
sequence_pos_encoder = PositionalEncoding(latent_dim, dropout, device="cpu")
embed_timestep = DiffusionTimestepEmbedder(latent_dim, sequence_pos_encoder, device="cpu")

# Integer timesteps on the same device as the modules above; presumably the project
# embedder uses them to index its positional table — confirm in diffusion_utils.
timesteps = torch.tensor([ts], dtype=torch.long, device="cpu")
print(embed_timestep(timesteps))

this is the fwd of TE tensor([10])
this is the pe of len 5000 
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.4030e-01,  7.6172e-01,  ...,  1.0000e+00,
           1.1548e-04,  1.0000e+00]],

        [[ 9.0930e-01, -4.1615e-01,  9.8705e-01,  ...,  1.0000e+00,
           2.3096e-04,  1.0000e+00]],

        ...,

        [[ 9.5625e-01, -2.9254e-01, -9.4916e-01,  ...,  7.8608e-01,
           5.4555e-01,  8.3808e-01]],

        [[ 2.7050e-01, -9.6272e-01, -8.5488e-01,  ...,  7.8599e-01,
           5.4565e-01,  8.3802e-01]],

        [[-6.6395e-01, -7.4778e-01, -1.5844e-01,  ...,  7.8591e-01,
           5.4574e-01,  8.3795e-01]]])
this is the encoded with PE
 tensor([[[-0.5440, -0.8391,  0.6926, -0.7213,  0.9376,  0.3476,  0.2091,
           0.9779, -0.6129,  0.7901, -0.9877,  0.1566, -0.8798, -0.4754,
          -0.4883, -0.8727, -0.0207, -0.9998,  0.3923, -0.9198,  0.6963,
          -0.7178,  0.8857, -0.4

In [34]:

# Look up the raw PositionalEncoding table row at the same timestep, for comparison
# with the embedder output above.
timesteps = torch.tensor([ts], dtype=torch.long, device="cpu")  # same device as `ori_pe`
print(timesteps)
ori_pe = PositionalEncoding(d_model=latent_dim, dropout=dropout, max_len=10000, device="cpu")
ori_pe.pe[timesteps]

tensor([10])


tensor([[[-0.5440, -0.8391,  0.6926, -0.7213,  0.9376,  0.3476,  0.2091,
           0.9779, -0.6129,  0.7901, -0.9877,  0.1566, -0.8798, -0.4754,
          -0.4883, -0.8727, -0.0207, -0.9998,  0.3923, -0.9198,  0.6963,
          -0.7178,  0.8857, -0.4642,  0.9786, -0.2060,  0.9995,  0.0309,
           0.9720,  0.2351,  0.9147,  0.4041,  0.8415,  0.5403,  0.7617,
           0.6479,  0.6816,  0.7318,  0.6047,  0.7965,  0.5332,  0.8460,
           0.4679,  0.8838,  0.4093,  0.9124,  0.3571,  0.9341,  0.3110,
           0.9504,  0.2704,  0.9627,  0.2349,  0.9720,  0.2039,  0.9790,
           0.1769,  0.9842,  0.1534,  0.9882,  0.1330,  0.9911,  0.1152,
           0.9933,  0.0998,  0.9950,  0.0865,  0.9963,  0.0749,  0.9972,
           0.0649,  0.9979,  0.0562,  0.9984,  0.0487,  0.9988,  0.0422,
           0.9991,  0.0365,  0.9993,  0.0316,  0.9995,  0.0274,  0.9996,
           0.0237,  0.9997,  0.0205,  0.9998,  0.0178,  0.9998,  0.0154,
           0.9999,  0.0133,  0.9999,  0.0115,  0.99

In [36]:
import torch
import torch.nn as nn
import math

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    NOTE(review): this re-defines `timestep_embedding` from an earlier cell and
    silently shadows it; keep the two copies identical or delete one.

    Even output columns hold ``sin(t * w_k)`` and odd columns hold
    ``cos(t * w_k)``, with frequencies ``w_k = max_period ** (-k / dim)``
    for ``k = 0, 2, 4, ...``.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional. A 0-d (scalar) tensor is
                      treated as a batch of one.
    :param dim: the dimension of the output. May be odd, in which case the
                last column holds a sine with no matching cosine.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float32 Tensor of positional embeddings on
             ``timesteps.device``.
    """
    # Promote a scalar timestep to a batch of one; `size(0)` fails on 0-d tensors.
    if timesteps.dim() == 0:
        timesteps = timesteps.unsqueeze(0)

    pe = torch.zeros(timesteps.size(0), dim, device=timesteps.device)

    # The timestep values themselves act as the positions, shape (N, 1).
    position = timesteps.unsqueeze(1).float()

    # One frequency per even column: ceil(dim / 2) entries.
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float, device=timesteps.device)
        * (-math.log(max_period) / dim)
    )
    angles = position * div_term  # (N, ceil(dim / 2))

    # Sine goes to even columns.
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(dim / 2) odd columns, so trim the angles for odd `dim`.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

    return pe
# Example usage: encode the same integer timestep with the standalone helper.
dim = 128
timesteps = torch.tensor([ts], dtype=torch.float).to(device).long()
embeddings = timestep_embedding(timesteps, dim)
print(embeddings)


tensor([[-0.5440, -0.8391,  0.6926, -0.7213,  0.9376,  0.3476,  0.2091,  0.9779,
         -0.6129,  0.7901, -0.9877,  0.1566, -0.8798, -0.4754, -0.4883, -0.8727,
         -0.0207, -0.9998,  0.3923, -0.9198,  0.6963, -0.7178,  0.8857, -0.4642,
          0.9786, -0.2060,  0.9995,  0.0309,  0.9720,  0.2351,  0.9147,  0.4041,
          0.8415,  0.5403,  0.7617,  0.6479,  0.6816,  0.7318,  0.6047,  0.7965,
          0.5332,  0.8460,  0.4679,  0.8838,  0.4093,  0.9124,  0.3571,  0.9341,
          0.3110,  0.9504,  0.2704,  0.9627,  0.2349,  0.9720,  0.2039,  0.9790,
          0.1769,  0.9842,  0.1534,  0.9882,  0.1330,  0.9911,  0.1152,  0.9933,
          0.0998,  0.9950,  0.0865,  0.9963,  0.0749,  0.9972,  0.0649,  0.9979,
          0.0562,  0.9984,  0.0487,  0.9988,  0.0422,  0.9991,  0.0365,  0.9993,
          0.0316,  0.9995,  0.0274,  0.9996,  0.0237,  0.9997,  0.0205,  0.9998,
          0.0178,  0.9998,  0.0154,  0.9999,  0.0133,  0.9999,  0.0115,  0.9999,
          0.0100,  0.9999,  