In [10]:
import torch
import torch.nn as nn
import math 
from celldreamer.models.diffusion.distributions import x0_to_xt

In [11]:
# 1-D tensor of integer timesteps 0..999, matching the "1-D Tensor of N indices"
# contract of the embedding functions below (no extra trailing axis to squeeze away).
test_tensor = torch.arange(1000)

**Reference implementation (`SinusoidalPosEmb`)**

In [12]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal position/timestep embedding module (reference implementation).

    Each input value is multiplied by ``self.dim // 2`` geometrically spaced
    frequencies; the output concatenates the sines of those products followed
    by their cosines along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        # Requested embedding width; an odd value yields ``dim - 1`` output features.
        self.dim = dim

    def forward(self, x):
        """Embed ``x``.

        :param x: tensor of positions/timesteps (at least 1-D); a new axis is
                  inserted at index 1 before broadcasting against the frequencies.
        :return: tensor whose last axis has ``2 * (self.dim // 2)`` entries,
                 sines first, then cosines.
        """
        n_freqs = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000 across the bank.
        step = math.log(10000) / (n_freqs - 1)
        freqs = (torch.arange(n_freqs, device=x.device) * -step).exp()
        angles = x[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [13]:
# Reference embedding module producing 100-dimensional embeddings.
sin = SinusoidalPosEmb(dim=100)

In [16]:
# Reference embeddings for every test timestep, with size-1 axes dropped for display.
torch.squeeze(sin(test_tensor))

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.8415,  0.7370,  0.6339,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.9093,  0.9963,  0.9806,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [-0.8980,  0.0819, -0.2747,  ...,  0.9895,  0.9928,  0.9950],
        [-0.8555, -0.6792,  0.3971,  ...,  0.9895,  0.9928,  0.9950],
        [-0.0265, -1.0000,  0.8889,  ...,  0.9894,  0.9927,  0.9950]])

**Our implementation (`timestep_embedding`)**: it should reproduce the reference output above exactly.

In [22]:
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000):
    """
    Create sinusoidal timestep embeddings.

    :param t: tensor of timesteps, typically a 1-D tensor of N indices (one per
              batch element); values may be fractional. A new axis is inserted
              at index 1 before broadcasting against the frequencies.
    :param dim: requested embedding width; must be >= 4. An odd ``dim`` yields
                ``dim - 1`` output features.
    :param max_period: controls the minimum frequency of the embeddings; the
                       frequencies decay geometrically from 1 to ``1 / max_period``.
    :return: for a 1-D ``t`` of length N, an ``[N x 2 * (dim // 2)]`` tensor
             (``[N x dim]`` for even ``dim``) with sines in the first half of the
             last axis and cosines in the second half.
    :raises ValueError: if ``dim < 4``, since the frequency spacing divides by
                        ``dim // 2 - 1``.
    """
    half_dim = dim // 2
    if half_dim < 2:
        raise ValueError(f"dim must be >= 4 for sinusoidal embeddings, got {dim}")
    exponent = math.log(max_period) / (half_dim - 1)
    # Build the frequencies on t's device so GPU inputs don't hit a device mismatch.
    freqs = torch.exp(torch.arange(half_dim, device=t.device) * -exponent)
    args = t[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)

In [23]:
# Verify programmatically that our implementation matches the reference module
# instead of comparing two truncated printouts by eye.
our_emb = timestep_embedding(test_tensor, 100)
torch.testing.assert_close(our_emb, sin(test_tensor))
our_emb.squeeze()

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.8415,  0.7370,  0.6339,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.9093,  0.9963,  0.9806,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [-0.8980,  0.0819, -0.2747,  ...,  0.9895,  0.9928,  0.9950],
        [-0.8555, -0.6792,  0.3971,  ...,  0.9895,  0.9928,  0.9950],
        [-0.0265, -1.0000,  0.8889,  ...,  0.9894,  0.9927,  0.9950]])