# Mutual Self-Attention + ControlNet
This notebook demonstrates pose-conditioned image generation with the Mutual Self-Attention (MasaCtrl) + ControlNet pipeline proposed in the report.

The model that supports ControlNet, `MasaCtrlControlNetPipeline`, is implemented in `diffuser_utils.py`.


The code works as follows:

First, register a `MutualSelfAttentionControl` editor with the `MasaCtrlControlNetPipeline` using the `regiter_attention_editor_diffusers` function (the name is spelled this way in `MasaCtrl.masactrl.masactrl_utils`).
Then generate images from a prompt pair in the format `['source prompt', 'edited prompt']` together with a conditioning (pose) image. The results are saved in a new `workdir/exp/sample_<N>` subdirectory. (The Batch Generation section at the end writes to `final_test22` instead.)


# Install Dependencies

In [None]:
# Install the Python packages listed in the project's requirements.txt.
!pip install -r requirements.txt

In [None]:
# Install PyTorch wheels from the CUDA 11.8 (cu118) index.
# NOTE(review): this runs after requirements.txt, so it may replace a torch build installed there.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

### Mutual Self-Attention + ControlNet

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, ControlNetModel

from MasaCtrl.masactrl.diffuser_utils import MasaCtrlPipeline, MasaCtrlControlNetPipeline
from MasaCtrl.masactrl.masactrl_utils import AttentionBase
from MasaCtrl.masactrl.masactrl_utils import regiter_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

# Pin CUDA work to GPU 0, but only when CUDA exists: torch.cuda.set_device fails on
# CPU-only machines, which would defeat the CPU fallback used when picking `device`.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # set the GPU device

#### Model Construction

In [None]:
# Use CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stable Diffusion v1.5 base weights (Hugging Face Hub repo id).
model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# DDIM sampler with a scaled-linear beta schedule (0.00085 -> 0.012),
# sample clipping disabled and the final alpha not forced to 1.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# ControlNet checkpoint for OpenPose conditioning (per its repo id).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

# Assemble the MasaCtrl + ControlNet pipeline from the pieces above and move it to `device`.
# NOTE(review): `cross_attention_kwargs` is usually a per-call pipeline argument in diffusers;
# confirm MasaCtrlControlNetPipeline's constructor consumes it, otherwise it may be ignored here.
model = MasaCtrlControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    scheduler=scheduler,
    cross_attention_kwargs={"scale": 0.5},
).to(device)

#### Consistent synthesis with MasaCtrl

In [None]:
from MasaCtrl.masactrl.masactrl import MutualSelfAttentionControl
from PIL import Image


# Seed the random number generators via pytorch_lightning's seed_everything so
# later random draws (such as the initial noise map) are reproducible.
seed = 42
seed_everything(seed=seed)

# Each run writes into its own numbered sub-directory under workdir/exp.
out_dir = "./workdir/exp/"
os.makedirs(out_dir, exist_ok=True)
# Start from the number of existing entries, then skip any index that is already
# taken (e.g. after an earlier sample_* folder was deleted, or when stray files
# are present) so a previous run's images are never silently overwritten.
sample_count = len(os.listdir(out_dir))
while os.path.exists(os.path.join(out_dir, f"sample_{sample_count}")):
    sample_count += 1
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

prompts = [
    "1boy, casual, outdoors, standing",  # source prompt
    "1boy, casual, outdoors, dancing"  # target prompt
]

condition_image = "dataset/poses/dance_03.png"
# load the condition image as a float tensor scaled to [0, 1]
condition_image = read_image(condition_image).float() / 255.0
# Normalize to exactly 3 channels (the layout the RGBA branch always produced):
# drop a trailing alpha channel (RGBA -> RGB, grey+alpha -> grey), then
# replicate a single grey channel three times.
if condition_image.shape[0] in (2, 4):
    condition_image = condition_image[:-1]
if condition_image.shape[0] == 1:
    condition_image = condition_image.repeat(3, 1, 1)
# resize to 512x512 (F.interpolate's bilinear mode needs a batch dimension)
condition_image = F.interpolate(condition_image.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
condition_image = condition_image.to(device)
# Row 0 (source prompt) gets an all-zero condition, row 1 (target prompt) gets the pose.
# NOTE(review): assumes the pipeline pairs condition rows with prompts in order — confirm.
zero_condition = torch.zeros_like(condition_image)
condition = torch.cat([zero_condition, condition_image], dim=0)

# initialize the noise map: one 4x64x64 latent, expanded (shared) across both prompts
start_code = torch.randn([1, 4, 64, 64], device=device)
start_code = start_code.expand(len(prompts), -1, -1, -1)



In [None]:
# Baseline ("without MasaCtrl"): register the plain AttentionBase editor,
# then sample both prompts from the shared noise map and pose condition.
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
image_ori = model(
    prompts,
    controlnet_conditioning=condition,
    latents=start_code,
    guidance_scale=7.5,
)

In [None]:
# inference the synthesized image with MasaCtrl
# STEP / LAYER are handed to MutualSelfAttentionControl; presumably the denoising step
# and U-Net layer index from which mutual self-attention kicks in — confirm in masactrl.py.
STEP = 4
LAYER = 10

# hijack the attention module
editor = MutualSelfAttentionControl(STEP, LAYER)
regiter_attention_editor_diffusers(model, editor)

# Keep only the last image, i.e. the target (edited) prompt's result.
image_masactrl = model(prompts, controlnet_conditioning=condition, latents=start_code, guidance_scale=7.5)[-1:]

# Save a grid of [source, target without MasaCtrl, target with MasaCtrl],
# then each of those three images on its own.
out_image = torch.cat([image_ori, image_masactrl], dim=0)
save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYER}.png"))
for tag, image in zip(("source", "without", "masactrl"), out_image):
    save_image(image, os.path.join(out_dir, f"{tag}_step{STEP}_layer{LAYER}.png"))

print("Synthesized images are saved in", out_dir)

# Batch Generation

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, ControlNetModel

from MasaCtrl.masactrl.diffuser_utils import MasaCtrlPipeline, MasaCtrlControlNetPipeline
from MasaCtrl.masactrl.masactrl_utils import AttentionBase
from MasaCtrl.masactrl.masactrl_utils import regiter_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

# Pin CUDA work to GPU 0, but only when CUDA exists: torch.cuda.set_device fails on
# CPU-only machines, which would defeat the CPU fallback used when picking `device`.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # set the GPU device

In [None]:
from MasaCtrl.masactrl.masactrl import MutualSelfAttentionControl
from PIL import Image


# Fixed seed (applied through pytorch_lightning) for reproducible batch runs.
seed = 42
seed_everything(seed=seed)



In [None]:
# Prefer the GPU; fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # SD v1.5 Hub repo id

# Deterministic DDIM sampler: scaled-linear betas from 0.00085 to 0.012,
# no sample clipping, final alpha not forced to 1.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# OpenPose-conditioned ControlNet (per its repo id).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

# NOTE(review): confirm MasaCtrlControlNetPipeline's constructor uses
# `cross_attention_kwargs`; in stock diffusers it is a call-time argument.
model = MasaCtrlControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    scheduler=scheduler,
    cross_attention_kwargs={"scale": 0.5},
).to(device)

In [None]:
prompts = [
    "highly detailed, 1boy, standing, facing camera, full body portrait, full-length portrait",  # source prompt
    "highly detailed, 1boy, dancing, facing camera, full body portrait, full-length portrait"  # target prompt
]

condition_image = "dataset/poses/dance_01.png"
# load the condition image as a float tensor scaled to [0, 1]
condition_image = read_image(condition_image).float() / 255.0
# Normalize to exactly 3 channels (the layout the RGBA branch always produced):
# drop a trailing alpha channel (RGBA -> RGB, grey+alpha -> grey), then
# replicate a single grey channel three times.
if condition_image.shape[0] in (2, 4):
    condition_image = condition_image[:-1]
if condition_image.shape[0] == 1:
    condition_image = condition_image.repeat(3, 1, 1)
# resize to 512x512 (F.interpolate's bilinear mode needs a batch dimension)
condition_image = F.interpolate(condition_image.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
condition_image = condition_image.to(device)
# Row 0 (source prompt) gets an all-zero condition, row 1 (target prompt) gets the pose.
# NOTE(review): assumes the pipeline pairs condition rows with prompts in order — confirm.
zero_condition = torch.zeros_like(condition_image)
condition = torch.cat([zero_condition, condition_image], dim=0)

# initialize the noise map: one 4x64x64 latent, expanded (shared) across both prompts
start_code = torch.randn([1, 4, 64, 64], device=device)
start_code = start_code.expand(len(prompts), -1, -1, -1)


In [None]:
# Reference run without MasaCtrl: AttentionBase is registered as the editor
# before sampling both prompts with the shared latents and condition.
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
image_ori = model(
    prompts,
    controlnet_conditioning=condition,
    latents=start_code,
    guidance_scale=7.5,
)

In [None]:
from torchvision.transforms import ToPILImage

# Save the two baseline (no MasaCtrl) samples: index 0 is the source prompt,
# index 1 the target prompt.
# PIL's Image.save does not create missing directories, so make sure it exists.
os.makedirs("final_test22", exist_ok=True)
# Convert the PyTorch tensor to PIL image before saving
ToPILImage()(image_ori[0].cpu()).save("final_test22/final_test_zero_cond_original.png")
ToPILImage()(image_ori[1].cpu()).save("final_test22/final_test_zero_cond_without.png")


In [None]:
import glob
import os
from torchvision.utils import save_image

STEP = 4
LAYER = 10
# sequential generation: for each pose image, run MasaCtrl with the same prompts
# and starting noise, and save the edited (target-prompt) image.

# NOTE(review): machine-specific absolute path; adjust to your checkout.
folder_path = "/mnt/hdd/hbchoe/workspace/MasaCtrl/dataset/poses"
output_folder = "final_test22"
max_files = 5  # only the first few pose files (sorted by name) are processed
control_image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))
# save_image does not create missing folders.
os.makedirs(output_folder, exist_ok=True)
if not control_image_files:
    print("No .png pose images found in", folder_path)


def _load_condition(path):
    """Load a pose image and build the [zero, pose] ControlNet condition batch.

    Returns a (2, 3, 512, 512) float tensor on `device`: row 0 is all zeros
    (paired with the source prompt), row 1 is the pose scaled to [0, 1].
    """
    image = read_image(path).float() / 255.0
    # Normalize to 3 channels: drop a trailing alpha channel, then replicate grey.
    if image.shape[0] in (2, 4):
        image = image[:-1]
    if image.shape[0] == 1:
        image = image.repeat(3, 1, 1)
    # resize to 512x512 (F.interpolate's bilinear mode needs a batch dimension)
    image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
    image = image.to(device)
    return torch.cat([torch.zeros_like(image), image], dim=0)


for file in control_image_files[:max_files]:
    condition = _load_condition(file)

    # Fresh editor per image. Presumably MutualSelfAttentionControl keeps per-run
    # state (e.g. step counters), so re-registering avoids carrying it over — confirm.
    editor = MutualSelfAttentionControl(STEP, LAYER)
    regiter_attention_editor_diffusers(model, editor)

    # inference the synthesized image; keep only the last (target-prompt) image
    image_masactrl = model(prompts, controlnet_conditioning=condition, latents=start_code, guidance_scale=7.5)[-1:]
    # Save the edited image, named after the pose file (with attention hijack)
    file_name = os.path.splitext(os.path.basename(file))[0]
    save_image(image_masactrl, os.path.join(output_folder, f"final_test_{file_name}.png"))
