## Noise scheduler
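The noise scheduler defines a noise level $\beta_t$ for every time step $t$ and derives the corresponding $\alpha_t = 1 - \beta_t$ and the cumulative product $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$.  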
 
This is the original (linear) noise scheduler used in the DDPM paper. See the paper notes repo for a more detailed mathematical breakdown. 

In [None]:
import torch 
import numpy as np 

In [None]:
class LinearNoiseScheduler:
    """Linear beta schedule for the DDPM forward and reverse processes.

    Precomputes, for every time step t in [0, num_timesteps):
      * betas                 -- beta_t, linearly spaced in [beta_start, beta_end]
      * alpha                 -- alpha_t = 1 - beta_t
      * alpha_prod            -- alpha_bar_t = prod_{s <= t} alpha_s
      * sqrt_alpha_prod       -- sqrt(alpha_bar_t)
      * sqrt_one_minus_alpha  -- sqrt(1 - alpha_bar_t)

    All tables are 1-D tensors of shape (num_timesteps,).
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        """Build the schedule tables.

        Args:
            num_timesteps: number of diffusion steps T.
            beta_start: beta value at the first step.
            beta_end: beta value at the last step.
        """
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # beta_t increases linearly; shape (num_timesteps,)
        self.betas = torch.linspace(beta_start, beta_end, steps=num_timesteps)

        # alpha_t = 1 - beta_t; shape (num_timesteps,)
        self.alpha = 1 - self.betas

        # alpha_bar_t: running product of alpha, same shape as alpha
        self.alpha_prod = torch.cumprod(self.alpha, dim=0)
        # sqrt(alpha_bar_t) scales the clean signal x0
        self.sqrt_alpha_prod = torch.sqrt(self.alpha_prod)
        # sqrt(1 - alpha_bar_t) scales the Gaussian noise. It must be computed
        # from alpha_bar_t (not its square root) so that the two coefficients
        # satisfy a^2 + b^2 = 1.
        self.sqrt_one_minus_alpha = torch.sqrt(1 - self.alpha_prod)

    def add_noise(self, original, noise, t):
        """Forward process: sample x_t from q(x_t | x_0).

        Computes sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise.

        Args:
            original: clean input x_0 with the batch on dimension 0.
            noise: noise tensor with the same shape as ``original``.
            t: time-step index. Either one index per batch element (tensor of
                shape (batch,)) or a single index (int / 0-dim tensor) that is
                applied to the whole batch.

        Returns:
            Noised tensor x_t with the same shape as ``original``.
        """
        device = original.device
        # (n, 1, 1, ...) with n = number of indices in t, so each coefficient
        # broadcasts over every non-batch dimension of `original`.
        broadcast_shape = (-1,) + (1,) * (original.dim() - 1)

        signal_scale = self.sqrt_alpha_prod.to(device)[t].reshape(broadcast_shape)
        noise_scale = self.sqrt_one_minus_alpha.to(device)[t].reshape(broadcast_shape)

        return signal_scale * original + noise_scale * noise

    def sample_prev_time(self, xt, noise_pred, t):
        """Reverse process: sample x_{t-1} given x_t and the predicted noise.

        Args:
            xt: current noisy sample x_t.
            noise_pred: noise predicted by the model for ``xt``; same shape.
            t: a single time-step index (int or 0-dim tensor); it is compared
                against 0 to decide whether noise is added.

        Returns:
            A tuple ``(x_prev, x0)`` where ``x_prev`` is the sample at step
            t-1 (just the posterior mean when ``t == 0``) and ``x0`` is the
            estimate of the clean input, clamped to [-1, 1].
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alpha.to(device)[t]
        alpha_prod_t = self.alpha_prod.to(device)[t]
        sqrt_alpha_prod_t = self.sqrt_alpha_prod.to(device)[t]
        sqrt_one_minus_t = self.sqrt_one_minus_alpha.to(device)[t]

        # Invert the forward formula:
        #   x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        x0 = (xt - sqrt_one_minus_t * noise_pred) / sqrt_alpha_prod_t
        x0 = torch.clamp(x0, -1, 1)

        # Posterior mean:
        #   (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (xt - beta_t * noise_pred / sqrt_one_minus_t) / torch.sqrt(alpha_t)

        if t == 0:
            # Final step: return the mean without adding noise.
            return mean, x0

        # Posterior variance:
        #   beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        # (using sigma^2 = beta_t directly is the other common choice).
        alpha_prod_prev = self.alpha_prod.to(device)[t - 1]
        variance = beta_t * (1 - alpha_prod_prev) / (1 - alpha_prod_t)
        sigma = variance ** 0.5
        z = torch.randn_like(xt)
        return mean + sigma * z, x0