Commit

ENSD fixes.
* use the eta seed for subsequent torch.randn() calls instead of the first one.
* ENSD only applies to ancestral samplers or dpm_fast and dpm_adaptive.
shiimizu committed Nov 25, 2023
1 parent a602f51 commit 417d544
Showing 2 changed files with 33 additions and 11 deletions.
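
For background on the two bullets in the commit message: ENSD (eta noise seed delta), as in the A1111 webui setting this extension mirrors, offsets the seed used for the extra noise injected during sampling by the listed samplers (ancestral, dpm_fast, dpm_adaptive), while the initial latent noise keeps the base seed. Below is a minimal, standalone sketch of that seed split; split_seeds and make_generator are illustrative helpers for this note, not functions from the repository.

    import torch

    def split_seeds(seed: int, eta_noise_seed_delta: int):
        # The base seed drives the first (initial latent) noise; the ENSD-offset
        # seed drives the per-step noise. A non-positive delta leaves ENSD disabled.
        initial_seed = seed
        eta_seed = seed if eta_noise_seed_delta <= 0 else min(seed + eta_noise_seed_delta, 0xffffffffffffffff)
        return initial_seed, eta_seed

    def make_generator(seed: int, device: str = 'cpu') -> torch.Generator:
        # Mirrors the diff's get_generator() for the non-"nv" randn source.
        return torch.Generator(device=device).manual_seed(seed)

    initial_seed, eta_seed = split_seeds(seed=42, eta_noise_seed_delta=31337)
    gen_initial = make_generator(initial_seed)  # used once, for the starting latent noise
    gen_eta = make_generator(eta_seed)          # used for subsequent randn()/randn_like() calls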
nodes.py (17 additions, 0 deletions)
@@ -182,6 +182,23 @@ def run(self, *args, **kwargs):
         _prepare_noise = partial(prepare_noise, device=device.type)
         comfy.sample.prepare_noise = _prepare_noise
 
+        import nodes
+        if not hasattr(nodes, 'common_ksampler_orig'):
+            nodes.common_ksampler_orig = nodes.common_ksampler
+        def common_ksampler(*args, **kwargs):
+            try:
+                sampler_name = kwargs['sampler_name'] if 'sampler_name' in kwargs else args[4]
+                uses_ensd = ['ancestral', 'dpm_fast', 'dpm_adaptive']
+                # disable ENSD if it's not a valid sampler
+                if not any([x in sampler_name for x in uses_ensd]):
+                    opts.eta_noise_seed_delta = abs(opts.eta_noise_seed_delta) * -1
+                else:
+                    opts.eta_noise_seed_delta = abs(opts.eta_noise_seed_delta)
+            except:
+                print("[smZNodes] Error getting sampler_name for ENSD")
+            return nodes.common_ksampler_orig(*args, **kwargs)
+        nodes.common_ksampler = common_ksampler
+
         return (_any,)
 
 # A dictionary that contains all nodes you want to export with their names
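
A quick, standalone illustration of the substring check added above, using typical ComfyUI sampler names (ensd_applies is a hypothetical helper for this example, not part of the commit):

    uses_ensd = ['ancestral', 'dpm_fast', 'dpm_adaptive']

    def ensd_applies(sampler_name: str) -> bool:
        # Same test as in the hijacked common_ksampler: a match keeps eta_noise_seed_delta
        # positive (enabled); no match flips it negative (disabled) before sampling.
        return any(x in sampler_name for x in uses_ensd)

    for name in ['euler', 'euler_ancestral', 'dpmpp_2s_ancestral', 'dpm_fast', 'dpm_adaptive', 'dpmpp_2m']:
        print(f'{name:20s} ENSD {"enabled" if ensd_applies(name) else "disabled"}')
    # euler and dpmpp_2m end up with a negative (disabled) delta; the other four keep ENSD enabled.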
smZNodes.py (16 additions, 11 deletions)
@@ -540,7 +540,7 @@ def randn_without_seed(x, generator=None, randn_source="cpu"):
     Use either randn() or manual_seed() to initialize the generator."""
     if randn_source == "nv":
-        return torch.asarray((generator or nv_rng).randn(x.size()), device=x.device)
+        return torch.asarray(generator.randn(x.size()), device=x.device)
     else:
         if generator is not None and generator.device.type == "cpu":
             return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=devices.cpu, generator=generator).to(device=x.device)
@@ -579,24 +579,29 @@ def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
"""
from .modules.shared import opts
from comfy.sample import np
if opts.eta_noise_seed_delta != 0:
def get_generator(seed):
nonlocal device
nonlocal opts
_generator = torch.Generator(device=device)
generator = _generator.manual_seed(seed)
if opts.randn_source == 'nv':
generator = rng_philox.Generator(seed)
return generator
generator = generator_eta = get_generator(seed)

if opts.eta_noise_seed_delta > 0:
seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
_generator = torch.Generator(device=device)
generator = _generator.manual_seed(seed)

if opts.randn_source == 'nv':
global nv_rng
rng = nv_rng = rng_philox.Generator(seed)
generator = None
generator_eta = get_generator(seed)


# hijack randn_like
import comfy.k_diffusion.sampling
comfy.k_diffusion.sampling.torch = TorchHijack(generator, opts.randn_source)
comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)

if noise_inds is None:
shape = latent_image.size()
if opts.randn_source == 'nv':
return torch.asarray(rng.randn(shape), device=devices.cpu)
return torch.asarray(generator.randn(shape), device=devices.cpu)
else:
return torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)

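TorchHijack itself is not part of this diff. As a rough sketch of what handing it generator_eta achieves (an assumption about its behavior, with TorchHijackSketch as a hypothetical stand-in): the k-diffusion samplers call torch.randn_like() through the hijacked module attribute, so their per-step noise now comes from the eta generator, while the noise returned by prepare_noise still comes from the base-seed generator.

    import torch

    class TorchHijackSketch:
        # Hypothetical stand-in for the repo's TorchHijack: attribute lookups fall
        # through to the real torch module, except randn_like, which is redirected
        # to the stored (eta) generator.
        def __init__(self, generator, randn_source='cpu'):
            self.generator = generator
            self.randn_source = randn_source

        def __getattr__(self, name):
            return getattr(torch, name)

        def randn_like(self, x):
            # Per-step ancestral noise drawn from the eta seed, not the base seed.
            return torch.randn(x.size(), dtype=x.dtype, layout=x.layout,
                               device=self.generator.device, generator=self.generator).to(x.device)

    generator_eta = torch.Generator(device='cpu').manual_seed(42 + 31337)
    hijacked = TorchHijackSketch(generator_eta)
    noise = hijacked.randn_like(torch.empty(1, 4, 8, 8))  # uses the eta generator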
