## Samplers
- A sampler builds the noise schedule (the sigmas) and the timesteps for a given number of inference steps, and updates the latents at every step

In [None]:
## Necessary imports
import torch
import numpy as np
from tqdm import tqdm

In [None]:
def get_alphas_cumprod(beta_start = 0.00085, beta_end = 0.0120, n_training_steps = 1000):
    """
    Build the cumulative product of alphas for a beta schedule that is linear
    in sqrt-space.

    The betas are spaced evenly between sqrt(beta_start) and sqrt(beta_end)
    and then squared; each alpha is ``1 - beta``.

    Args:
        beta_start: First beta of the schedule. Must be non-negative.
        beta_end: Last beta of the schedule. Must be non-negative.
        n_training_steps: Number of betas (and returned values).

    Returns:
        A float32 numpy array of shape ``(n_training_steps,)`` holding the
        running product of the alphas.

    Raises:
        ValueError: If ``beta_start`` or ``beta_end`` is negative.
    """
    # A negative beta has no real square root: on Python floats `** 0.5`
    # yields a complex number whose imaginary part numpy would silently drop.
    if beta_start < 0 or beta_end < 0:
        raise ValueError(
            f"beta_start and beta_end must be non-negative, got {beta_start} and {beta_end}"
        )

    sqrt_betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype = np.float32)
    betas = sqrt_betas ** 2
    print(f"The betas have the shape {betas.shape}")
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis = 0)
    print(f"The alphas cumprod have the shape {alphas_cumprod.shape}")
    return alphas_cumprod

In [None]:
def get_time_embedding(timestep, dtype):
    """
    Build the sinusoidal embedding of a single timestep.

    160 frequencies ``10000 ** (-k / 160)`` for ``k = 0..159`` are multiplied
    by the timestep; the cosines of the resulting angles are followed by
    their sines.

    Args:
        timestep: Scalar timestep value.
        dtype: Torch dtype used for the frequencies and the result.

    Returns:
        A tensor of shape ``(1, 320)``: 160 cosine values, then 160 sine values.
    """
    exponents = -torch.arange(start = 0, end = 160, dtype = dtype) / 160
    frequencies = torch.pow(10000, exponents)
    # (1, 1) timestep column times (1, 160) frequency row -> (1, 160) angles
    timestep_column = torch.tensor([timestep], dtype = dtype)[:, None]
    angles = timestep_column * frequencies[None]
    cos_part = torch.cos(angles)
    sin_part = torch.sin(angles)
    return torch.cat([cos_part, sin_part], dim = 1)

In [None]:
class KEulerSampler():
    """
    Sampler that owns the noise levels (sigmas) and timesteps for a fixed
    number of inference steps and updates the latents one step at a time.

    NOTE(review): despite the class name, ``step`` combines up to ``lms_order``
    previous model outputs using integrated Lagrange-polynomial coefficients
    (a linear multistep update), not a single-output Euler update.

    Attributes:
        sigmas: Noise level per inference step, with a trailing 0 appended
            (length ``n_inference_steps + 1``).
        initial_scale: Sigma the starting latents should be multiplied by.
        timesteps: Training-timestep value for every remaining inference step.
        n_inference_steps: Number of inference steps in the full schedule.
        n_training_steps: Number of training steps of the underlying schedule.
        lms_order: Maximum number of past outputs used by ``step``.
        step_count: Index of the next inference step to run.
        outputs: Most recent model outputs, newest first.
    """

    def __init__(self, n_inference_steps = 50, n_training_steps = 1000, lms_order = 4):
        timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps)
        alphas_cumprod = get_alphas_cumprod(n_training_steps = n_training_steps)
        sigmas = self._sigmas_from_alphas_cumprod(alphas_cumprod)
        # Interpolate in log-space so the (possibly fractional) inference
        # timesteps get a sigma between their two neighbouring training steps.
        log_sigmas = np.interp(timesteps, range(n_training_steps), np.log(sigmas))
        # The extra trailing 0 is the noise level reached after the last step.
        sigmas = np.append(np.exp(log_sigmas), 0)

        self.sigmas = sigmas
        self.initial_scale = sigmas.max()
        self.timesteps = timesteps
        self.n_inference_steps = n_inference_steps
        self.n_training_steps = n_training_steps
        self.lms_order = lms_order
        self.step_count = 0
        self.outputs = []

    @staticmethod
    def _sigmas_from_alphas_cumprod(alphas_cumprod):
        """Convert cumulative alphas to noise levels: sigma = sqrt((1 - a) / a)."""
        alphas_cumprod = np.asarray(alphas_cumprod)
        # The square root must cover the whole ratio; `**` binds tighter than
        # `/`, so the parentheses around the division are required.
        return ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

    def get_input_scale(self, step_count = None):
        """
        Return ``1 / sqrt(sigma ** 2 + 1)`` for the given step.

        Args:
            step_count: Index into ``sigmas``; defaults to the current step.
        """
        if step_count is None:
            step_count = self.step_count
        sigma = self.sigmas[step_count]
        return 1 / (sigma ** 2 + 1) ** 0.5

    def set_strength(self, strength = 1):
        """
        Skip the first part of the schedule so that only a ``strength``
        fraction of the inference steps is run.

        Args:
            strength: Fraction of the inference steps to run, in ``[0, 1]``.

        Raises:
            ValueError: If ``strength`` is outside ``[0, 1]`` (a larger value
                would give a negative start step and slice from the wrong end).
        """
        if not 0 <= strength <= 1:
            raise ValueError(f"strength must be between 0 and 1, got {strength}")
        start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
        self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
        self.timesteps = self.timesteps[start_step:]
        self.initial_scale = self.sigmas[start_step]
        self.step_count = start_step

    def step(self, latents, output):
        """
        Advance the latents by one inference step.

        Args:
            latents: Current latents; updated in place.
            output: Model output for the current step.

        Returns:
            The same ``latents`` object after the in-place update.
        """
        t = self.step_count
        self.step_count += 1

        # Newest output first; keep at most `lms_order` of them.
        self.outputs = [output] + self.outputs[:self.lms_order - 1]
        order = len(self.outputs)

        # `np.trapz` was renamed to `np.trapezoid` in NumPy 2.0.
        trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
        # Integration grid from the current sigma to the next one (81 points);
        # it is the same for every stored output, so build it once.
        x = np.linspace(self.sigmas[t], self.sigmas[t + 1], 81)

        for i, past_output in enumerate(self.outputs):
            # Lagrange basis polynomial of the i-th stored output, evaluated on
            # the grid and integrated with the trapezoidal rule.
            y = np.ones(81)
            for j in range(order):
                if i == j:
                    continue
                y *= x - self.sigmas[t - j]
                y /= self.sigmas[t - i] - self.sigmas[t - j]
            lms_coeff = float(trapezoid(y = y, x = x))
            latents += lms_coeff * past_output
        return latents

### Using the sampler
- Get the sampler
- Set the strength of the sampler: the higher the value, the more inference steps are run (and the larger the initial noise scale), so the effect is stronger but inference is slower
- Multiply the latents with the sampler's initial scale

In [None]:
# Build a sampler with its default settings (50 inference steps, 1000 training steps).
sampler = KEulerSampler()
# With 50 inference steps, strength 0.8 skips the first 10 steps and starts at step 10.
sampler.set_strength(0.8)
# Inspect the type of the initial scale (an element taken from the sampler's sigma array).
type(sampler.initial_scale)

### Latents and Samplers
- Once we get the latents either from the encoder or our own (if input images are not given), we multiply them with the sampler's initial scale.
- Using the context from the CLIP Embedding, Time embedding and the input latents, obtain the output of the diffusion model

In [None]:
sampler.initial_scale
## Draw example latents from a standard normal and scale them, in place,
## by the sampler's initial scale
latents = torch.randn((1, 4, 64, 64))
latents.mul_(sampler.initial_scale)

In [None]:
## Walk over the sampler's timesteps (with a progress bar), computing the
## time embedding and the scaled input latents for each one
timesteps = tqdm(sampler.timesteps)
for timestep in timesteps:
    time_embedding = get_time_embedding(timestep, dtype = torch.float32)
    input_scale = sampler.get_input_scale()
    input_latents = latents * input_scale
## NOTE: everything below runs once, after the loop, on the input latents
## produced by the last iteration
assert input_latents.shape == (1, 4, 64, 64)
## Duplicate the latents along the batch dimension
## (done when the parameter "do conditional guidance" is enabled)
input_latents = torch.cat([input_latents, input_latents], dim = 0)
assert input_latents.shape == (2, 4, 64, 64)
assert input_latents.dtype == torch.float32
print(input_latents.min())
print(input_latents.max())