Skip to content

Commit

Permalink
v-prediction training support (huggingface#1455)
Browse files Browse the repository at this point in the history
* add get_velocity

* add v prediction for training

* fix saving

* add revision arg

* fix saving

* save checkpoints dreambooth

* fix saving embeds

* add instruction in readme

* quality

* noise_pred -> model_pred
  • Loading branch information
patil-suraj committed Nov 28, 2022
1 parent 64e95c0 commit 7970ff2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
20 changes: 20 additions & 0 deletions schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,5 +355,25 @@ def add_noise(
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def __len__(self):
return self.config.num_train_timesteps
20 changes: 20 additions & 0 deletions schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,5 +345,25 @@ def add_noise(
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def __len__(self):
return self.config.num_train_timesteps

0 comments on commit 7970ff2

Please sign in to comment.