Support for Multi-Controlnet, and using separate images for controlnet #53

Open · hzhou17 opened this issue Sep 25, 2023 · 7 comments

hzhou17 commented Sep 25, 2023

Thanks for the code, it's been working really well!

But I'd like more precision with multi-ControlNet. (I am using 3D software to create the video first.) So I want to use both Canny and Depth: Canny computed from the input video, and depth images that I rendered myself. How can I modify the code to do so?

hzhou17 commented Sep 25, 2023

I noticed the code below:

    if cfg.control_type == 'HED':
        model.load_state_dict(
            load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
    elif cfg.control_type == 'canny':
        model.load_state_dict(
            load_state_dict('./models/control_v11p_sd15_canny.pth',
                            location='cuda'))

How should I change the code to load both of them?

@williamyang1991 (Owner)

Yes, you need to create two ControlNet models, e.g. model1 and model2, and load a different saved .pth into each.

You will also need to extract the ControlNet features for each ControlNet model (you need to modify cond and the sample function):

    cond = {
        ...
    }

    samples, _ = ddim_v_sampler.sample(
        ...
    )

And to use the two ControlNets simultaneously, you need to add both sets of control features to the SD features rather than a single control feature (you need to modify the code):

https://github.com/lllyasviel/ControlNet/blob/ed85cd1e25a5ed592f7d8178495b4483de0331bf/cldm/cldm.py#L35

    h += control.pop()
    # change to:
    h += control1.pop() + control2.pop()

https://github.com/lllyasviel/ControlNet/blob/ed85cd1e25a5ed592f7d8178495b4483de0331bf/cldm/cldm.py#L41

    h = torch.cat([h, hs.pop() + control.pop()], dim=1)
    # change to:
    h = torch.cat([h, hs.pop() + control1.pop() + control2.pop()], dim=1)
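
For reference, a minimal sketch of loading two ControlNets side by side, mirroring the single-model loading code quoted above (model1/model2, the config path, and the depth checkpoint filename are illustrative assumptions, not code from this repository):

    # Sketch only: build two ControlNets and load different weights into each.
    # create_model is the standard ControlNet helper; load_state_dict is the
    # same helper used in the loading code above.
    model1 = create_model('./models/cldm_v15.yaml').cpu()
    model1.load_state_dict(
        load_state_dict('./models/control_v11p_sd15_canny.pth', location='cuda'))
    model1 = model1.cuda()

    model2 = create_model('./models/cldm_v15.yaml').cpu()
    model2.load_state_dict(
        load_state_dict('./models/control_v11f1p_sd15_depth.pth', location='cuda'))
    model2 = model2.cuda()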

hzhou17 commented Sep 26, 2023

@williamyang1991 Thank you very much for the reply! But I still need a bit more assistance...

I changed the cldm.py code as you listed above, and I changed the webUI code to the following:

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        # if self.detector_type == control_type:
        #     return

        if control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector1 = apply_canny

            midas = MidasDetector()

            def apply_midas(x):
                detected_map, _ = midas(x)
                return detected_map

            self.detector2 = apply_midas

    ......

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model

    ### Changed Here
    detector1 = global_state.detector1
    detector2 = global_state.detector2

    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13

    ......

    detected_map1 = detector1(img)
    detected_map1 = HWC3(detected_map1)

    control1 = torch.from_numpy(
        detected_map1.copy()).float().cuda() / 255.0
    control1 = torch.stack([control1 for _ in range(num_samples)], dim=0)
    control1 = einops.rearrange(control1, 'b h w c -> b c h w').clone()

    detected_map2 = detector2(img)
    detected_map2 = HWC3(detected_map2)

    control2 = torch.from_numpy(
        detected_map2.copy()).float().cuda() / 255.0
    control2 = torch.stack([control2 for _ in range(num_samples)], dim=0)
    control2 = einops.rearrange(control2, 'b h w c -> b c h w').clone()

    cond = {
        'c_concat': [control1, control2],
        'c_crossattn': [
            model.get_learned_conditioning(
                [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
        ]
    }
    un_cond = {
        'c_concat': [control1, control2],
        'c_crossattn':
        [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
    }

But I ran into this error:

    Traceback (most recent call last):
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/queueing.py", line 388, in call_prediction
        output = await route_utils.call_process_api(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/route_utils.py", line 219, in call_process_api
        output = await app.get_blocks().process_api(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/blocks.py", line 1437, in process_api
        result = await self.call_function(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/blocks.py", line 1109, in call_function
        prediction = await anyio.to_thread.run_sync(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/to_thread.py", line 33, in run_sync
        return await get_asynclib().run_sync_in_worker_thread(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
        return await future
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 807, in run
        result = context.run(func, *args)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/utils.py", line 650, in wrapper
        response = f(*args, **kwargs)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "webUI.py", line 367, in process1
        x_samples, x_samples_np = generate_first_img(img, first_strength)
      File "webUI.py", line 342, in generate_first_img
        samples, _ = ddim_v_sampler.sample(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 212, in sample
        samples, intermediates = self.ddim_sampling(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 329, in ddim_sampling
        outs = self.p_sample_ddim(
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 381, in p_sample_ddim
        model_t = self.model.apply_model(x, t, c)
      File "/home/heran/Rerender_A_Video/deps/ControlNet/cldm/cldm.py", line 337, in apply_model
        control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/heran/Rerender_A_Video/deps/ControlNet/cldm/cldm.py", line 288, in forward
        guided_hint = self.input_hint_block(hint, emb, context)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/heran/Rerender_A_Video/deps/ControlNet/ldm/modules/diffusionmodules/openaimodel.py", line 86, in forward
        x = layer(x)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 457, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[1, 6, 512, 704] to have 3 channels, but got 6 channels instead

The RuntimeError suggests that the two conditions are not properly combined, I guess. Would you please take a look and give some suggestions? I'd really appreciate it~

hzhou17 commented Sep 26, 2023

Sorry about the lengthy post... I just realized that I did not change this:

    samples, _ = ddim_v_sampler.sample(

What should I change here?

@williamyang1991 (Owner)

You need to change

    control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)

since you now have two control_models and two features in cond['c_concat'].
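
For example, a minimal sketch of that change (assuming the wrapper now stores the two ControlNets as control_model1 and control_model2, and that cond['c_concat'] holds [control1, control2] as in your code above; the attribute names are hypothetical):

    # Sketch: give each ControlNet its own 3-channel hint instead of
    # concatenating both hints into one 6-channel tensor.
    hint1, hint2 = cond['c_concat']
    control1 = self.control_model1(x=x_noisy, hint=hint1, timesteps=t, context=cond_txt)
    control2 = self.control_model2(x=x_noisy, hint=hint2, timesteps=t, context=cond_txt)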

hzhou17 commented Sep 26, 2023

@williamyang1991 Thank you very much for the reply. I really appreciate it...

I found that code in cldm.py, but I don't know how to change it. Would you be so kind as to show me? I browsed through https://github.com/Mikubill/sd-webui-controlnet, but could not find how multi-ControlNet is implemented there.

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            ### Here
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

@williamyang1991 (Owner)

I'm sorry, I'm under deadline pressure and cannot help you with every detail.
And I'm not familiar with https://github.com/Mikubill/sd-webui-controlnet.

The main idea is to track everywhere cond['c_concat'] is used and modify the corresponding code.
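
For illustration, here is a sketch of how the apply_model above could end up looking after those changes (control_model1/control_model2 are the two ControlNets from earlier, and the diffusion_model forward is assumed to have been modified to accept and sum two control feature lists as described above; this is an untested outline, not code from the repository):

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt,
                                  control=None, only_mid_control=self.only_mid_control)
        else:
            # One 3-channel hint per ControlNet; do not torch.cat the hints,
            # which is what produced the 6-channel RuntimeError above.
            hint1, hint2 = cond['c_concat']
            control1 = self.control_model1(
                x=x_noisy, hint=hint1, timesteps=t, context=cond_txt)
            control2 = self.control_model2(
                x=x_noisy, hint=hint2, timesteps=t, context=cond_txt)
            control1 = [c * scale for c, scale in zip(control1, self.control_scales)]
            control2 = [c * scale for c, scale in zip(control2, self.control_scales)]
            # Assumes the modified diffusion_model adds both feature lists to
            # the SD features (h += control1.pop() + control2.pop(), etc.).
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt,
                                  control=(control1, control2),
                                  only_mid_control=self.only_mid_control)

        return eps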
