### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

torch.cuda.set_device(1)  # set the GPU device



#### Model Construction

In [2]:
# Note that you may add your Hugging Face token to get access to the models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model_path = "xyn-ai/anything-v4.0"
# model_path = "CompVis/stable-diffusion-v1-4"
# model_path = "runwayml/stable-diffusion-v1-5"
# model_path = '/h3cstore_ns/DatasetDM/SD_2.1'
model_path = '/h3cstore_ns/ydchen/code/DatasetDM/weights/SD1.4'
# model_path = "/h3cstore_ns/DatasetDM/SDXL"
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)

#### Real editing with MasaCtrl

In [8]:
from masactrl.masactrl import MutualSelfAttentionControl
from torchvision.io import read_image


def load_image(image_path, device):
    image = read_image(image_path)
    # deal with 1 channel image
    if image.shape[0] == 1:
        image = image.repeat(3, 1, 1)
    image = image[:3].unsqueeze_(0).float() / 127.5 - 1.  # [-1, 1]
    image = F.interpolate(image, (512, 512))
    image = image.to(device)
    return image


seed = 42
seed_everything(seed)

out_dir = "/data/ydchen/VLP/MasaCtrl/edit_mask_0807/"
os.makedirs(out_dir, exist_ok=True)
sample_count = len(os.listdir(out_dir))
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

# source image
# SOURCE_IMAGE_PATH = "/data/ydchen/VLP/MasaCtrl/gradio_app/images/corgi.jpg"
SOURCE_IMAGE_PATH = "/data/ydchen/VLP/MasaCtrl/图片1.png"
source_image = load_image(SOURCE_IMAGE_PATH, device)

source_prompt = ""
# target_prompt = "a photo of a running corgi"
# target_prompt = "A binary image of A square birdcage"
# target_prompt = "A binary image of Two long park chairs placed opposite each other"
target_prompt = 'A chair with a rear-facing back'
prompts = [source_prompt, target_prompt]
with torch.no_grad():
    # invert the source image
    start_code, latents_list = model.invert(source_image,
                                            source_prompt,
                                            guidance_scale=7.5,
                                            num_inference_steps=80,
                                            return_intermediates=True)
    start_code = start_code.expand(len(prompts), -1, -1, -1)

    # results of direct synthesis
    editor = AttentionBase()
    regiter_attention_editor_diffusers(model, editor)
    image_fixed = model([target_prompt],
                        latents=start_code[-1:],
                        num_inference_steps=80,
                        guidance_scale=7.5)

    # inference the synthesized image with MasaCtrl
    STEP = 4
    LAYPER = 10

    # hijack the attention module
    editor = MutualSelfAttentionControl(STEP, LAYPER)
    regiter_attention_editor_diffusers(model, editor)

    # inference the synthesized image
    image_masactrl = model(prompts,
                        latents=start_code,
                        guidance_scale=7.5)
    # Note: querying the inversion intermediate features latents_list
    # may obtain better reconstruction and editing results
    # image_masactrl = model(prompts,
    #                        latents=start_code,
    #                        guidance_scale=7.5,
    #                        ref_intermediate_latents=latents_list)

    # save the synthesized image
    out_image = torch.cat([source_image * 0.5 + 0.5,
                        image_masactrl[0:1],
                        image_fixed,
                        image_masactrl[-1:]], dim=0)
    save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYPER}.png"))
    save_image(out_image[0], os.path.join(out_dir, f"source_step{STEP}_layer{LAYPER}.png"))
    save_image(out_image[1], os.path.join(out_dir, f"reconstructed_source_step{STEP}_layer{LAYPER}.png"))
    save_image(out_image[2], os.path.join(out_dir, f"without_step{STEP}_layer{LAYPER}.png"))
    save_image(out_image[3], os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYPER}.png"))

    print("Syntheiszed images are saved in", out_dir)

Global seed set to 42


input text embeddings : torch.Size([1, 77, 768])
latents shape:  torch.Size([1, 4, 64, 64])
Valid timesteps:  tensor([  1,  13,  25,  37,  49,  61,  73,  85,  97, 109, 121, 133, 145, 157,
        169, 181, 193, 205, 217, 229, 241, 253, 265, 277, 289, 301, 313, 325,
        337, 349, 361, 373, 385, 397, 409, 421, 433, 445, 457, 469, 481, 493,
        505, 517, 529, 541, 553, 565, 577, 589, 601, 613, 625, 637, 649, 661,
        673, 685, 697, 709, 721, 733, 745, 757, 769, 781, 793, 805, 817, 829,
        841, 853, 865, 877, 889, 901, 913, 925, 937, 949])


DDIM Inversion:   0%|          | 0/80 [00:00<?, ?it/s]

DDIM Inversion: 100%|██████████| 80/80 [00:11<00:00,  7.09it/s]


input text embeddings : torch.Size([1, 77, 768])
latents shape:  torch.Size([1, 4, 64, 64])


DDIM Sampler: 100%|██████████| 80/80 [00:11<00:00,  7.08it/s]


MasaCtrl at denoising steps:  [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler:  80%|████████  | 40/50 [00:09<00:02,  4.34it/s]

In [4]:
1024/768

1.3333333333333333

In [5]:
1.333**2

1.776889

In [1]:
from glob import glob

len_a = len(glob('/h3cstore_ns/ydchen/mask_editting3090_binary_0326/reconstructed/*'))
len_b = len(glob('/h3cstore_ns/ydchen/mask_editting3090_binary_0326/source/*'))
len_c = len(glob('/h3cstore_ns/ydchen/mask_editting3090_binary_0326/with_masactrl/*'))
len_d = len(glob('/h3cstore_ns/ydchen/mask_editting3090_binary_0326/without_masactrl/*'))
len_e = len(glob('/h3cstore_ns/ydchen/mask_editting3090_binary_0326/concat_image/*'))
print(len_a, len_b, len_c, len_d, len_e)

5470 5470 5470 5470 5470


In [None]:
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
    
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.attention = Attention(dim, heads = heads, dim_head = dim_head)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attention(x)
        x = x + self.mlp(x)
        return x
    
class WindowAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, window_size = 7, relative_pos_embedding = False):
        super().__init__()
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.window_size = window_size

        if relative_pos_embedding:
            self.relative_position_bias = nn.Parameter(torch.randn((2 * window_size - 1) ** 2, heads))

    def forward(self, x):
        b, n, _, h, window_size = *x.shape, self.heads, self.window_size
        assert n >= window_size, 'window size must be less than or equal to sequence length'

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if hasattr(self, 'relative_position_bias'):
            rel_pos = self.relative_position_bias
            rel_pos = rearrange(rel_pos, 'n d -> () () n d')
            rel_pos = F.pad(rel_pos, (0, 0, 0, n - window_size), value = 0)
            rel_pos = rearrange(rel_pos, '() () (h d) -> 1 1 h d', h = h, d = 2 * window_size - 1)
            dots += rel_pos

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class SwinTransformer(nn.Module):
    def __init__(self, *, num_classes, num_blocks, dim, heads, dim_head, window_size, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0., patch_size = 4, relative_pos_embedding = False):
        super().__init__()
        self.num_classes = num_classes

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size = patch_size, stride = patch_size),
            Rearrange('b c h w -> b (h w) c')
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, (h := (512 // patch_size)) ** 2, dim))
        self.patch_size = patch_size

        self.dropout = nn.Dropout(emb_dropout)

        self.layers = nn.ModuleList([])
        for _ in range(num_blocks):
            self.layers.append(nn.ModuleList([
                TransformerBlock(dim, heads, dim_head, mlp_dim, dropout = dropout),
                WindowAttention(dim, heads = heads, dim_head = dim_head, window_size = window_size, relative_pos_embedding = relative_pos_embedding)
            ]))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        pos_embedding = repeat(self.pos_embedding, '() n d -> b n d', b = b)
        x += pos_embedding
        x = self.dropout(x)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        x = x.mean(dim = 1) # global average pooling
        return self.mlp_head(x)