In [15]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
class VarianceSchedule(nn.Module):

    def __init__(self, num_steps=100, s=0.01):
        super().__init__()
        T = num_steps
        t = torch.arange(0, num_steps+1, dtype=torch.float)
        f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2
        alpha_bars = f_t / f_t[0]

        betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
        betas = torch.cat([torch.zeros([1]), betas], dim=0)
        betas = betas.clamp_max(0.999)
 
        sigmas = torch.zeros_like(betas)
        for i in range(1, betas.size(0)):
            sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i]
        sigmas = torch.sqrt(sigmas)
        
        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('sigmas', sigmas)
        # calculate X0 = sqrt_recip_alphas_cumprod * Xt - sqrt_recipm1_alphas_cumprod * noise
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alpha_bars))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alpha_bars - 1))
        
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
        self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min =1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
        
    def to(self,device):
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                self.__dict__[k] = v.to(device)


In [10]:
class PositionTransition(nn.Module):

    def __init__(self, num_steps, var_sched_opt={}):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)
         
    def add_noise(self, p_0, mask_generate,  t):
        """
        Args:
            p_0:    (N, L, 3).
            mask_generate:    (N, L).
            t:  (N,).
        """
        alpha_bar = self.var_sched.alpha_bars[t]
        nb_non_batch_dims = len(p_0.shape) - 1
        c0 = torch.sqrt(alpha_bar).view(-1, *(nb_non_batch_dims*(1,)))
        c1 = torch.sqrt(1 - alpha_bar).view(-1, *(nb_non_batch_dims*(1,)))
        prior_p = torch.randn_like(p_0)
           
        p_noisy = c0*p_0 + c1*prior_p
        p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0)

        return p_noisy, prior_p

In [19]:
from torch_scatter import scatter_sum
p=torch.randn((10,128))
batch_id=torch.LongTensor([0,0,0,0,0,0,0,0,0,0])
lengths = scatter_sum(torch.ones_like(batch_id), batch_id)
t = torch.randint(0, 100, (1,), dtype=torch.long)
expanded_t = torch.cat([time.repeat(l) for time,l in zip(t,lengths)])

In [31]:
torch.full((10,),100, dtype=torch.long)

tensor([100, 100, 100, 100, 100, 100, 100, 100, 100, 100])

In [7]:
mask=torch.LongTensor([0,0,0,0,0,0,0,0,0,1]).bool()

In [16]:
trans_rot = PositionTransition(100)

In [23]:
trans_rot.add_noise(p,mask,expanded_t)

(tensor([[ 0.3485, -0.2922, -0.2746,  ..., -1.5693,  1.5754,  0.3972],
         [ 0.0674, -0.5424,  0.7440,  ..., -2.0407,  1.4795, -2.1597],
         [-1.0217,  0.5803,  0.9991,  ...,  1.1805,  0.3093, -0.0717],
         ...,
         [ 0.7005, -0.4601,  1.5785,  ...,  0.2578, -0.3659,  0.0597],
         [ 0.6066, -1.1699, -1.8528,  ..., -1.4484, -0.8706,  0.5774],
         [-0.5059, -1.0857, -1.2176,  ...,  3.6567,  0.4935,  1.5655]]),
 tensor([[ 4.5911e-01,  1.1102e+00, -2.6017e+00,  ...,  8.0807e-01,
           5.3372e-01,  3.0030e-01],
         [ 2.2121e-02, -1.0817e+00,  2.0692e-01,  ..., -9.6972e-01,
          -1.9992e-01, -1.3044e+00],
         [-1.0450e-02,  6.7107e-01,  1.0604e+00,  ..., -6.2740e-01,
           2.0962e+00,  1.0689e+00],
         ...,
         [-9.5038e-01, -9.5703e-01, -2.0748e-01,  ..., -9.6532e-01,
           5.7639e-01, -1.2554e+00],
         [-1.0567e-01, -6.2298e-01, -5.3075e-01,  ..., -2.5494e-03,
           1.0557e+00,  1.0229e-01],
         [-7.6742e-

In [24]:
p

tensor([[ 0.3485, -0.2922, -0.2746,  ..., -1.5693,  1.5754,  0.3972],
        [ 0.0674, -0.5424,  0.7440,  ..., -2.0407,  1.4795, -2.1597],
        [-1.0217,  0.5803,  0.9991,  ...,  1.1805,  0.3093, -0.0717],
        ...,
        [ 0.7005, -0.4601,  1.5785,  ...,  0.2578, -0.3659,  0.0597],
        [ 0.6066, -1.1699, -1.8528,  ..., -1.4484, -0.8706,  0.5774],
        [-0.5489, -1.1688, -1.9565,  ...,  3.7385,  0.7354,  0.9192]])