Skip to content

Commit

Permalink
Rewrite schedulers to avoid nonzero() and item() calls
Browse files Browse the repository at this point in the history
  • Loading branch information
ssusie committed Aug 11, 2023
1 parent cdc8d29 commit 0243d2e
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -10,6 +10,7 @@ __pycache__/

# tests and logs
tests/fixtures/cached_*_text.txt
examples/text_to_image/logs_*
logs/
lightning_logs/
lang_code_data/
Expand Down Expand Up @@ -173,4 +174,4 @@ tags
# ruff
.ruff_cache

wandb
wandb
Expand Up @@ -686,7 +686,8 @@ def __call__(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs, return_dict=False)[0]
xm.mark_step()
# # call the callback, if provided
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
Expand Down
Expand Up @@ -840,13 +840,8 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
with xp.Trace('pipe_sched'):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs, return_dict=False)[0]
xm.mark_step()
# # call the callback, if provided
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
# progress_bar.update()
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)

with xp.Trace('pipe_vae'):
# make sure the VAE is in float32 mode, as it overflows in float16
Expand Down
12 changes: 3 additions & 9 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Expand Up @@ -633,7 +633,8 @@ def multistep_dpm_solver_third_order_update(
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
step_index: int,
# timestep: int,
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
Expand All @@ -658,21 +659,14 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)

timestep = self.timesteps[step_index]
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
Expand Down
8 changes: 5 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Expand Up @@ -307,7 +307,8 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
step_index: int,
# timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
Expand Down Expand Up @@ -338,7 +339,7 @@ def step(
`tuple`. When returning a tuple, the first element is the sample tensor.
"""

timestep = self.timesteps[step_index]
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
Expand All @@ -361,7 +362,8 @@ def step(
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

step_index = (self.timesteps == timestep).nonzero().item()
# step_index = (self.timesteps == timestep).nonzero().item()

sigma = self.sigmas[step_index]

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
Expand Down

0 comments on commit 0243d2e

Please sign in to comment.