### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing

In [18]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, ControlNetModel, StableDiffusionControlNetPipeline
from typing import Optional

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers

from PIL import Image



from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

torch.cuda.set_device(0)  # set the GPU device

In [24]:
class MasaCtrlControlNetPipeline(StableDiffusionControlNetPipeline):
    """
    ControlNet-enabled pipeline that retains prompt-to-prompt attention hijacking.
    """

    def enable_attention_control(self, editor):
        """Register and reset an AttentionStore editor on the UNet"""
        editor.reset()
        regiter_attention_editor_diffusers(self, editor)
        return self

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        control_image,
        batch_size=1,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        eta=0.0,
        editor: Optional[AttentionBase] = None,
        return_intermediates=False,
        **kwargs
    ):
        # Hook attention editor if provided
        if editor is not None:
            self.enable_attention_control(editor)

        # Delegate to parent to handle ControlNet conditioning and denoising
        return super().__call__(
            prompt=prompt,
            control_image=control_image,
            batch_size=batch_size,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            return_intermediates=return_intermediates,
            **kwargs
        )


#### Model Construction

In [20]:
# Note that you may add your Hugging Face token to get access to the models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model_path = "xyn-ai/anything-v4.0"
# model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
model_path = "./dreambooth_elmo"
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler, cross_attention_kwargs={"scale": 0.5}).to(device)
# model.load_textual_inversion("textual_inversion_elmo")
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose"
).to("cuda")

model_edit = MasaCtrlControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
).to("cuda")


Keyword arguments {'cross_attention_kwargs': {'scale': 0.5}} are not expected by MasaCtrlPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 13.59it/s]
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 19.81it/s]
You have disabled the safety checker for <class '__main__.MasaCtrlControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


#### Consistent synthesis with MasaCtrl

In [25]:
from masactrl.masactrl import MutualSelfAttentionControl


seed = 42
# seed_everything(seed)

out_dir = "./workdir/masactrl_exp/"
os.makedirs(out_dir, exist_ok=True)
sample_count = len(os.listdir(out_dir))
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

# control_image_path = "/mnt/hdd/hbchoe/workspace/MasaCtrl/dataset/poses/dance_01.png"
control_image_pose = Image.open("/mnt/hdd/hbchoe/workspace/MasaCtrl/dataset/poses/dance_01.png")


# prompts = [
#     "photo of single red <elmo> standng",  # source prompt
#     "photo of single red <elmo> sitting"  # target prompt
# ]
# prompts = [
#     "photo of single red sks doll standng",  # source prompt
#     "photo of single red sks doll sitting"  # target prompt
# ]

prompts = [
    "photo of a boy standng",  # source prompt
    "photo of a boy sitting"  # target prompt
]

# initialize the noise map
start_code = torch.randn([1, 4, 64, 64], device=device)
start_code = start_code.expand(len(prompts), -1, -1, -1)

# inference the synthesized image without MasaCtrl
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
image_ori = model(prompts, latents=start_code, guidance_scale=7.5)

# inference the synthesized image with MasaCtrl
STEP = 4
LAYPER = 10

# hijack the attention module
editor = MutualSelfAttentionControl(STEP, LAYPER)
regiter_attention_editor_diffusers(model, editor)

# inference the synthesized image
image_masactrl = model_edit(prompts, control_image=control_image_pose, latents=start_code, guidance_scale=7.5, editor=editor)[-1:]

# save the synthesized image
out_image = torch.cat([image_ori, image_masactrl], dim=0)
save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYPER}.png"))
save_image(out_image[0], os.path.join(out_dir, f"source_step{STEP}_layer{LAYPER}.png"))
save_image(out_image[1], os.path.join(out_dir, f"without_step{STEP}_layer{LAYPER}.png"))
save_image(out_image[2], os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYPER}.png"))

print("Syntheiszed images are saved in", out_dir)

  latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)


input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:18<00:00,  2.69it/s]

MasaCtrl at denoising steps:  [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]





TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>