Merge pull request #13 from orionaskatu/conflicts
changes for inpainting for #35
orionaskatu committed Sep 1, 2022
2 parents 51ac0d6 + 379d8f8 commit 3800e17
Showing 2 changed files with 76 additions and 53 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -71,10 +71,10 @@ Run the command to start web ui:
python stable-diffusion-webui/webui.py
```

If you have a 4GB video card, run the command with `--lowvram` argument:
If you have a 4GB video card, run the command with either `--lowvram` or `--medvram` argument:

```
python stable-diffusion-webui/webui.py --lowvram
python stable-diffusion-webui/webui.py --medvram
```

After a while, you will get a message like this:
@@ -280,17 +280,18 @@ print("Seed was: " + str(processed.seed))
display(processed.images, processed.seed, processed.info)
```

### `--lowvram`
### 4GB videocard support
Optimizations for GPUs with low VRAM. This should make it possible to generate 512x512 images on videocards with 4GB memory.

The original idea of those optimizations is by basujindal: https://github.com/basujindal/stable-diffusion. Model is separated into modules,
and only one module is kept in GPU memory; when another module needs to run, the previous is removed from GPU memory.

It should be obvious but the nature of those optimizations makes the processing run slower -- about 10 times slower
`--lowvram` is a reimplementation of an optimization idea by [basujindal](https://github.com/basujindal/stable-diffusion).
The model is separated into modules, and only one module is kept in GPU memory; when another module needs to run, the previous
one is removed from GPU memory. The nature of this optimization makes processing slower -- about 10 times slower
compared to normal operation on my RTX 3090.

This is an independent implementation that does not require any modification to original Stable Diffusion code, and
with all code concentrated in one place rather than scattered around the program.
`--medvram` is another optimization that should significantly reduce VRAM usage by not processing the conditional and
unconditional denoising passes in the same batch.

This optimization does not require any modification to the original Stable Diffusion code.
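
As a rough illustration of the `--medvram` idea (a simplified sketch with placeholder names, not the actual webui code), the conditional and unconditional denoising can either be batched together or run as two separate passes:

```
import torch

def guided_denoise(model, x, sigma, cond, uncond, cond_scale, batch_cond_uncond=True):
    # `model` stands in for the wrapped UNet; cond/uncond are the prompt conditionings
    if batch_cond_uncond:
        # default path: one doubled batch -- faster, but roughly twice the activation memory
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        out_uncond, out_cond = model(x_in, sigma_in, cond=cond_in).chunk(2)
    else:
        # low-VRAM path: two smaller passes, trading speed for a lower peak VRAM footprint
        out_uncond = model(x, sigma, cond=uncond)
        out_cond = model(x, sigma, cond=cond)
    return out_uncond + (out_cond - out_uncond) * cond_scale
```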

### Inpainting
In the img2img tab, draw a mask over a part of the image, and that part will be in-painted.
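
Under the hood (a simplified sketch with illustrative names, not the exact webui code), in-painting works in latent space: the original latent is kept outside the drawn mask, and only the masked region is replaced with freshly denoised content at each step:

```
import torch

def blend_inpaint(denoised, init_latent, nmask):
    # nmask is 1 where the user painted (the region to repaint), 0 elsewhere
    mask = 1.0 - nmask
    return init_latent * mask + nmask * denoised
```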
110 changes: 66 additions & 44 deletions webui.py
@@ -7,7 +7,10 @@
sd_path = os.path.dirname(script_path)

# add parent directory to path; this is where the Stable Diffusion repo should be
path_dirs = [(sd_path, 'ldm', 'Stable Diffusion'), ('../src/taming-transformers', 'taming', 'Taming Transformers')]
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
('../src/taming-transformers', 'taming', 'Taming Transformers')
]
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
@@ -41,15 +44,10 @@
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -68,14 +66,21 @@
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--lowvram", action='store_true', help="enamble stable diffusion model optimizations for low vram")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")

parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
cmd_opts = parser.parse_args()

cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)

if not cmd_opts.share:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
@@ -297,21 +302,25 @@ def first_stage_model_decode_wrap(self, decoder, z):
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
diff_model = sd_model.model.diffusion_model
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
if cmd_opts.medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model

# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
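
`send_me_to_gpu` itself is defined outside this hunk. A minimal sketch of what such a forward pre-hook might look like (an assumption about its shape, not the commit's actual code; it mirrors the `parents` map and the `cpu`/`gpu` devices defined earlier in webui.py):

```
import torch

cpu, gpu = torch.device("cpu"), torch.device("cuda")  # as defined near the top of webui.py
parents = {}          # submodule -> owning model, filled during setup_for_low_vram (assumption)
module_in_gpu = None  # the single (sub)module currently resident on the GPU


def send_me_to_gpu(module, _inputs):
    # forward pre-hook: before a hooked module runs, evict whatever module is currently
    # on the GPU and move this one (or its registered parent) there, so only one large
    # piece of the model occupies VRAM at a time
    global module_in_gpu
    module = parents.get(module, module)
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module
```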


def create_random_tensors(shape, seeds):
@@ -863,7 +872,7 @@ def __init__(self, constructor):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)

# existing code fail with certain step counts, like 9
# existing code fails with certain step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
except Exception:
@@ -890,13 +899,26 @@ class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None

# classifier-free guidance: evaluate the model for the unconditional and conditional prompts and
# extrapolate from the unconditional result towards the conditional one by cond_scale
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else:
uncond = self.inner_model(x, sigma, cond=uncond)
cond = self.inner_model(x, sigma, cond=cond)
denoised = uncond + (cond - uncond) * cond_scale

# inpainting: self.mask / self.nmask are complementary latent-space masks built in init();
# keep init_latent where mask is 1 and the freshly denoised result in the painted (nmask) region
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised

return denoised


class KDiffusionSampler:
@@ -913,19 +935,13 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):

xi = x + noise

if p.mask is not None:
if p.inpainting_fill == 2:
xi = xi * p.mask + noise * p.nmask
elif p.inpainting_fill == 3:
xi = xi * p.mask

# img2img: keep only the tail of the noise schedule, so just t_enc denoising steps run on the noised init latent
sigma_sched = sigmas[p.steps - t_enc - 1:]

def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask

return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent

return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)

def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
@@ -935,7 +951,7 @@ def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_co
return samples_ddim


Processed = namedtuple('Processed', ['images','seed', 'info'])
Processed = namedtuple('Processed', ['images', 'seed', 'info'])


def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -1319,7 +1335,6 @@ def init(self):
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')


if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask
mask = self.image_mask.convert('L')
@@ -1387,6 +1402,13 @@ def init(self):
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)

def sample(self, x, conditioning, unconditional_conditioning):

# seed the to-be-inpainted latent region before sampling: inpainting_fill == 2 fills it with
# fresh per-image noise, inpainting_fill == 3 zeroes it out (other fill modes leave x untouched here)
if self.mask is not None:
if self.inpainting_fill == 2:
x = x * self.mask + create_random_tensors(x.shape[1:], [self.seed + x + 1 for x in range(x.shape[0])]) * self.nmask
elif self.inpainting_fill == 3:
x = x * self.mask

samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

if self.mask is not None:
@@ -1837,10 +1859,10 @@ def Nvidiasmi():
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())

if not cmd_opts.lowvram:
sd_model = sd_model.to(device)
else:
if cmd_opts.lowvram or cmd_opts.medvram:
setup_for_low_vram(sd_model)
else:
sd_model = sd_model.to(device)

model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)
@@ -1887,5 +1909,5 @@ def template_response(*args, **kwargs):
inject_gradio_html(javascript)

demo.queue(concurrency_count=1)
demo.launch()
demo.launch(share=cmd_opts.share)
