# Improvement on Prompt-to-Prompt Image Editing

In [None]:
# Install pinned versions of diffusers and transformers for this notebook.
!pip install diffusers==0.6.0
!pip install transformers==4.24.0 -i https://pypi.python.org/simple
# !pip install accelerate

In [3]:
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_utils
import seq_aligner
from tqdm.notebook import tqdm

In [None]:
import os

# Hugging Face access token. A real token must never be committed to the
# notebook: read it from the HF_TOKEN environment variable, and fall back to
# the credentials cached by `huggingface-cli login` (use_auth_token=True).
MY_TOKEN = os.environ.get("HF_TOKEN") or True
# Forwarded to the sampling helpers as `low_resource` and consulted by
# AttentionControl when deciding which attention calls to edit.
LOW_RESOURCE = True
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
# Length of the token axis used when building per-word selectors.
MAX_NUM_WORDS = 77
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=MY_TOKEN).to(device)
tokenizer = ldm_stable.tokenizer

## Google's Code

In [5]:
class LocalBlend:
    """Keep an edit inside the image regions attended to by selected words.

    For every prompt a one-hot selector over token positions is built from
    ``words``. At call time the 16x16 cross-attention maps are reduced with
    that selector, thresholded into a mask, and every latent in the batch is
    replaced by the first (base) latent outside its mask.
    """

    def __init__(self, prompts: List[str], words: List[Union[str, List[str]]], threshold=.3):
        """Build the per-prompt word selectors.

        Args:
            prompts: One prompt per batch entry.
            words: For each prompt, a word or a list of words whose attention
                defines the editable region. Entries are paired with
                ``prompts`` by position.
            threshold: Cut-off applied to the max-normalised attention mask.
        """
        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
                alpha_layers[i, :, :, :, :, ind] = 1
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold

    def __call__(self, x_t, attention_store):
        """Blend ``x_t`` with its first entry outside the attended region.

        Args:
            x_t: Latent batch; entry 0 is treated as the base sample.
            attention_store: Mapping with ``"down_cross"`` and ``"up_cross"``
                lists of attention tensors whose middle axis has 16 * 16
                entries.

        Returns:
            The blended latent batch, same shape as ``x_t``.
        """
        k = 1
        batch = self.alpha_layers.shape[0]
        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
        # Take the token-axis length from the tensor itself rather than from a
        # module-level constant.
        maps = [item.reshape(batch, -1, 1, 16, 16, item.shape[-1]) for item in maps]
        maps = torch.cat(maps, dim=1)
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = mask.gt(self.threshold)
        # Union of the base mask with each entry's own mask. Keeping the full
        # batch axis (instead of mask[1:]) lets the mask broadcast against x_t
        # for any number of prompts; for two prompts the result is unchanged.
        mask = (mask[:1] + mask).float()
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t


class AttentionControl(abc.ABC):
    """Base class for hooks called once per attention layer.

    Tracks the current diffusion step and the index of the attention layer
    within that step. ``num_att_layers`` starts at -1 and is not assigned
    anywhere in this class (presumably set when the controller is registered
    on the model -- confirm against ptp_utils).
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_att_layer = 0
        self.cur_step = 0

    def step_callback(self, x_t):
        """Hook for editing the latent; the default returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run after the last attention layer of a step; no-op here."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading attention calls left untouched per step."""
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Return the (possibly edited) attention for one layer."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Apply ``forward`` where appropriate and advance the counters."""
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Only the second half of the batch axis is edited.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            # A full pass over the layers is done: start the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

class EmptyControl(AttentionControl):
    """Controller that leaves every attention map untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity: hand the attention back as received.
        return attn
    
    
class AttentionStore(AttentionControl):
    """Controller that accumulates attention maps over the diffusion steps.

    ``step_store`` collects the maps of the step in progress; at the end of
    each step they are added into ``attention_store``.
    """

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def reset(self):
        """Reset the counters and drop everything recorded so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per (place, kind) key."""
        return {
            f"{place}_{kind}": []
            for kind in ("cross", "self")
            for place in ("down", "mid", "up")
        }

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it unchanged."""
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            kind = 'cross' if is_cross else 'self'
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        """Fold the finished step into the running totals."""
        if not self.attention_store:
            # First step: adopt the step store as the accumulator.
            self.attention_store = self.step_store
        else:
            for key, totals in self.attention_store.items():
                for i in range(len(totals)):
                    totals[i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of steps run."""
        average_attention = {}
        for key, totals in self.attention_store.items():
            average_attention[key] = [item / self.cur_step for item in totals]
        return average_attention

        
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base class for controllers that inject the base prompt's attention.

    Batch entry 0 is the base sample. The remaining entries have their
    cross-attention blended with a subclass-defined replacement, and their
    self-attention replaced during a configurable window of steps.
    """

    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        super().__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
            prompts, num_steps, cross_replace_steps, tokenizer
        ).to(device)
        # A bare float means "from the first step up to this fraction".
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        first_step = int(num_steps * self_replace_steps[0])
        last_step = int(num_steps * self_replace_steps[1])
        self.num_self_replace = first_step, last_step
        self.local_blend = local_blend

    def step_callback(self, x_t):
        """Apply the local blend to the latent when one is configured."""
        if self.local_blend is None:
            return x_t
        return self.local_blend(x_t, self.attention_store)

    def replace_self_attention(self, attn_base, att_replace):
        """Use the base self-attention for maps of at most 16*16 entries."""
        if att_replace.shape[2] > 16 ** 2:
            return att_replace
        return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Return the cross-attention to inject into the edited entries."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record the attention, then overwrite the edited batch entries."""
        super().forward(attn, is_cross, place_in_unet)
        in_self_window = self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
        if not (is_cross or in_self_window):
            return attn

        heads = attn.shape[0] // (self.batch_size)
        attn = attn.reshape(self.batch_size, heads, *attn.shape[1:])
        base, edited = attn[0], attn[1:]
        if is_cross:
            # Per-step, per-word weight between injected and own attention.
            alpha_words = self.cross_replace_alpha[self.cur_step]
            attn[1:] = self.replace_cross_attention(base, edited) * alpha_words + (1 - alpha_words) * edited
        else:
            attn[1:] = self.replace_self_attention(base, edited)
        return attn.reshape(self.batch_size * heads, *attn.shape[2:])

class AttentionReplace(AttentionControlEdit):
    """Edit controller that maps the base cross-attention through a mapper."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        replacement_mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer)
        self.mapper = replacement_mapper.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Project the base attention's token axis through ``self.mapper``."""
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
        

class AttentionRefine(AttentionControlEdit):
    """Edit controller that mixes re-indexed base attention with the edit's own."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
        alphas = alphas.to(device)
        self.mapper = mapper.to(device)
        # Reshape so the weights broadcast over the two middle axes.
        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])

    def replace_cross_attention(self, attn_base, att_replace):
        """Blend gathered base attention and ``att_replace`` using ``alphas``."""
        gathered = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
        return gathered * self.alphas + att_replace * (1 - self.alphas)


class AttentionReweight(AttentionControlEdit):
    """Edit controller that scales the base attention per token."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
                local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.prev_controller = controller
        self.equalizer = equalizer.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Scale the base attention by ``equalizer`` along the token axis.

        When a previous controller was given, its replacement is applied to
        the base attention first.
        """
        chained = self.prev_controller
        if chained is not None:
            attn_base = chained.replace_cross_attention(attn_base, att_replace)
        return attn_base[None, :, :, :] * self.equalizer[:, None, None, :]

from PIL import Image

def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of one resolution for one prompt.

    Args:
        attention_store: Store providing ``get_average_attention()``.
        res: Side length of the maps to keep (maps with ``res ** 2`` pixels).
        from_where: Store locations to read, e.g. ``("up", "down")``.
        is_cross: Select cross-attention (True) or self-attention (False).
        select: Index of the prompt within the global ``prompts`` list.

    Returns:
        CPU tensor holding the mean over all selected maps.
    """
    kind = 'cross' if is_cross else 'self'
    averaged = attention_store.get_average_attention()
    num_pixels = res ** 2
    selected = [
        item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
        for location in from_where
        for item in averaged[f"{location}_{kind}"]
        if item.shape[1] == num_pixels
    ]
    stacked = torch.cat(selected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()


def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one labelled cross-attention heat map per prompt token.

    Args:
        attention_store: Store with the accumulated attention.
        res: Side length of the attention maps to aggregate.
        from_where: Store locations to read.
        select: Index of the prompt within the global ``prompts`` list.
    """
    token_ids = tokenizer.encode(prompts[select])
    decode = tokenizer.decode
    maps = aggregate_attention(attention_store, res, from_where, True, select)
    panels = []
    for idx, token_id in enumerate(token_ids):
        heat = maps[:, :, idx]
        # Scale so the strongest response becomes 255, then make it RGB.
        heat = 255 * heat / heat.max()
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)
        pixels = heat.numpy().astype(np.uint8)
        pixels = np.array(Image.fromarray(pixels).resize((256, 256)))
        panels.append(ptp_utils.text_under_image(pixels, decode(int(token_id))))
    ptp_utils.view_images(np.stack(panels, axis=0))
    

def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    """Display the leading SVD components of the averaged self-attention.

    Args:
        attention_store: Store with the accumulated attention.
        res: Side length of the attention maps to aggregate.
        from_where: Store locations to read.
        max_com: Number of right-singular vectors to render.
        select: Index of the prompt within the global ``prompts`` list.
    """
    flat = aggregate_attention(attention_store, res, from_where, False, select).numpy()
    flat = flat.reshape((res ** 2, res ** 2))
    centered = flat - np.mean(flat, axis=1, keepdims=True)
    _, _, vh = np.linalg.svd(centered)
    tiles = []
    for i in range(max_com):
        comp = vh[i].reshape(res, res)
        # Shift and scale the component into the 0..255 range.
        comp = comp - comp.min()
        comp = 255 * comp / comp.max()
        rgb = np.repeat(np.expand_dims(comp, axis=2), 3, axis=2).astype(np.uint8)
        tiles.append(np.array(Image.fromarray(rgb).resize((256, 256))))
    ptp_utils.view_images(np.concatenate(tiles, axis=1))
    
def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None):
    """Sample images for ``prompts`` under ``controller`` and display them.

    With ``run_baseline`` set, an uncontrolled run (EmptyControl) is done
    first and the latent it returns seeds the controlled run.

    Returns:
        ``(images, x_t)`` where ``images`` is whatever
        ``ptp_utils.view_images`` returns and ``x_t`` is the latent returned
        by the sampler.
    """
    if run_baseline:
        print("w.o. prompt-to-prompt")
        _, latent = run_and_display(
            prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator
        )
        print("with prompt-to-prompt")

    sampled, x_t = ptp_utils.text2image_ldm_stable(
        ldm_stable,
        prompts,
        controller,
        latent=latent,
        num_inference_steps=NUM_DIFFUSION_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
        low_resource=LOW_RESOURCE,
    )
    shown = ptp_utils.view_images(sampled)
    return shown, x_t

### Basic generation 

Generate an image and visualize the cross-attention maps for each word in the text prompt.

In [None]:
# Seeded generator for the sampling call below.
g_cpu = torch.Generator().manual_seed(8888)
prompts = ["a fantasy landscape with a pine forest"]
# AttentionStore records the attention maps while sampling.
controller = AttentionStore()
# x_t is passed as `latent=` by later cells in this notebook.
image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
show_cross_attention(controller, res=16, from_where=("up", "down"))

## Our Code

### Attention Remove

In [7]:
def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
                  Tuple[float, ...]]):
    """Build per-token scaling rows for the selected words of ``text``.

    Args:
        text: Prompt the words are looked up in.
        word_select: A single word (or word index) or a tuple of them.
        values: One scale per output row; written at the token positions of
            every selected word.

    Returns:
        Tensor of shape ``(len(values), 77)``, all ones except at the
        selected token positions.
    """
    if type(word_select) in (int, str):
        word_select = (word_select,)
    equalizer = torch.ones(len(values), 77)
    scales = torch.tensor(values, dtype=torch.float32)
    for selected in word_select:
        token_inds = ptp_utils.get_word_inds(text, selected, tokenizer)
        equalizer[:, token_inds] = scales
    return equalizer

class AttentionRemove(AttentionControlEdit):
    """Edit controller that suppresses one word by negating its attention.

    The cross-attention of the base sample is multiplied by an equalizer that
    is 1 everywhere except at the token positions of ``remove_word``, where
    it is -5.
    """

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, remove_word,
                local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
        """Set up the controller.

        Args:
            prompts: Prompt batch; ``prompts[1]`` is used to locate
                ``remove_word``.
            num_steps: Number of diffusion steps.
            cross_replace_steps: Forwarded to AttentionControlEdit.
            self_replace_steps: Forwarded to AttentionControlEdit.
            remove_word: Word whose attention is suppressed.
            local_blend: Optional LocalBlend applied after each step.
            controller: Optional controller whose replacement is applied to
                the base attention first.
        """
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.remove_word = remove_word
        self.prev_controller = controller
        # Build the equalizer once, from the prompts given to this controller
        # (not the notebook-global `prompts`), instead of rebuilding and
        # re-uploading it on every attention call.
        self.equalizer = get_equalizer(prompts[1], (remove_word,), (-5,)).to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Return the base attention scaled by the removal equalizer."""
        if self.prev_controller is not None:
            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
        return attn_replace

In [None]:
# Same prompt twice: entry 0 is the base sample, entry 1 has "dog" suppressed.
prompts = ["A kid riding a bicycle with a dog"] * 2

controller = AttentionRemove(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=.8, remove_word="dog",
                               self_replace_steps=.4)
# x_t is the latent returned by the basic-generation cell.
_ = run_and_display(prompts, controller, latent=x_t, run_baseline=False)

In [None]:
# Same prompt twice: entry 0 is the base sample, entry 1 has "fantasy" suppressed.
prompts = ["A fantasy landscape with a pine forest"] * 2

controller = AttentionRemove(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=.8, remove_word="fantasy",
                               self_replace_steps=.4)
# x_t is the latent returned by the basic-generation cell.
_ = run_and_display(prompts, controller, latent=x_t, run_baseline=False)

In [None]:
# Same prompt twice: entry 0 is the base sample, entry 1 has "tofu" suppressed.
prompts = ["Tofu soup with croutons"] * 2

controller = AttentionRemove(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=.8, remove_word="tofu",
                               self_replace_steps=.4)
# x_t is the latent returned by the basic-generation cell.
_ = run_and_display(prompts, controller, latent=x_t, run_baseline=False)

In [None]:
# Same prompt twice: entry 0 is the base sample, entry 1 has "house" suppressed.
prompts = ["A photo of a house on a mountain"] * 2

controller = AttentionRemove(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=.8, remove_word="house",
                               self_replace_steps=.4)
# x_t is the latent returned by the basic-generation cell.
_ = run_and_display(prompts, controller, latent=x_t, run_baseline=False)

### Multi-seeds Image Editing

In [12]:
'''

text to image function, which accepts multiple random seeds instead of just one.


modified by Team 8
based on the original function text2image_ldm_stable()

'''

# t8 is for Team 8
@torch.no_grad()
def text2image_ldm_stable_t8(
    model,
    prompt_batch: List[str],
    seed_batch: List[int],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator_batch: Optional[List[torch.Generator]] = None,
    low_resource: bool = False,
):
    """Sample one 512x512 image per prompt, each from its own random seed.

    Args:
        model: Pipeline exposing ``tokenizer``, ``text_encoder``, ``unet``,
            ``vae``, ``scheduler`` and ``device``.
        prompt_batch: Text prompts, one per image.
        seed_batch: One integer seed per prompt; used to create the
            generators when ``generator_batch`` is not given.
        controller: Attention controller registered on ``model``.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Passed to ``ptp_utils.diffusion_step``.
        generator_batch: Optional pre-built generators, one per prompt (a
            single generator is shared across the batch). When given,
            ``seed_batch`` is ignored.
        low_resource: Passed to ``ptp_utils.diffusion_step``; when False the
            unconditional and text embeddings are concatenated up front.

    Returns:
        The images decoded by ``ptp_utils.latent2image``.

    Raises:
        ValueError: If the number of generators differs from the number of
            prompts.
    """
    # Honour caller-supplied generators; previously they were silently
    # overwritten by generators derived from seed_batch.
    if generator_batch is None:
        generator_batch = [torch.Generator().manual_seed(seed) for seed in seed_batch]
    elif isinstance(generator_batch, torch.Generator):
        generator_batch = [generator_batch] * len(prompt_batch)
    else:
        generator_batch = list(generator_batch)
    if len(generator_batch) != len(prompt_batch):
        raise ValueError(
            f"expected one seed/generator per prompt, got {len(generator_batch)} "
            f"for {len(prompt_batch)} prompts"
        )

    ptp_utils.register_attention_control(model, controller)

    height = width = 512
    batch_size = len(prompt_batch)

    text_input = model.tokenizer(
        prompt_batch,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    # Empty prompts provide the unconditional embeddings.
    uncond_input = model.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]

    context = [uncond_embeddings, text_embeddings]
    if not low_resource:
        context = torch.cat(context)
    latent_batch = init_latent_batch_t8(model, height, width, generator_batch)

    # set timesteps
    extra_set_kwargs = {"offset": 1}
    model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
    for t in tqdm(model.scheduler.timesteps):
        latent_batch = ptp_utils.diffusion_step(model, controller, latent_batch, context, t, guidance_scale, low_resource)

    image_batch = ptp_utils.latent2image(model.vae, latent_batch)

    return image_batch


'''

initialize the latent batch based on multiple random seeds


modified by Team 8
based on the original function init_latent()

'''
def init_latent_batch_t8(model, height, width, generator_batch):
    """Draw one initial latent per generator and stack them into a batch.

    Args:
        model: Object exposing ``unet.in_channels`` and ``device``.
        height: Image height in pixels; the latent uses ``height // 8``.
        width: Image width in pixels; the latent uses ``width // 8``.
        generator_batch: Iterable of ``torch.Generator``; one latent is
            drawn from each, in order.

    Returns:
        Tensor of shape ``(len(generator_batch), in_channels, height // 8,
        width // 8)`` on ``model.device``.

    Raises:
        ValueError: If ``generator_batch`` is empty.
    """
    generator_batch = list(generator_batch)
    if not generator_batch:
        raise ValueError("generator_batch must contain at least one generator")

    shape = (1, model.unet.in_channels, height // 8, width // 8)
    # Sample every latent first and concatenate once instead of growing the
    # batch with a torch.cat per generator.
    latents = [torch.randn(shape, generator=generator) for generator in generator_batch]
    return torch.cat(latents, dim=0).to(model.device)

Without local blend

In [None]:
# First prompt is the base sample; the others are edited variants.
prompt_batch = ["A painting of a squirrel eating a burger",
           "A painting of a lion eating a burger",
           "A painting of a lion eating a burger",
           "A painting of a lion eating a burger",
           ]
# One seed per prompt.
seed_batch=[8888,9991,232,2344]

controller = AttentionReplace(prompt_batch, NUM_DIFFUSION_STEPS, cross_replace_steps=0.8, self_replace_steps=0.4,)


image_batch = text2image_ldm_stable_t8(model=ldm_stable, 
                                       prompt_batch=prompt_batch, 
                                       controller=controller, 
                                    num_inference_steps=NUM_DIFFUSION_STEPS, 
                                    guidance_scale=GUIDANCE_SCALE, 
                                    low_resource=LOW_RESOURCE,
                                    seed_batch=seed_batch
                                    )
_=ptp_utils.view_images(image_batch)

With local blend

In [None]:



# First prompt is the base sample; the others are edited variants.
prompt_batch = ["A painting of a squirrel eating a burger",
           "A painting of a lion eating a burger",
           "A painting of a lion eating a burger",
           "A painting of a lion eating a burger",
           ]

# One seed per prompt.
seed_batch=[8888,9991,232,2344]

# LocalBlend pairs words with prompts by position, so it needs one word per
# prompt; with only two words the last two prompts would get empty selectors.
blend_words = ("squirrel", "lion", "lion", "lion")
lb = LocalBlend(prompt_batch, blend_words)

controller = AttentionReplace(prompt_batch, NUM_DIFFUSION_STEPS, cross_replace_steps=0.8, self_replace_steps=0.4, local_blend=lb)


image_batch = text2image_ldm_stable_t8(model=ldm_stable,
                                       prompt_batch=prompt_batch,
                                       controller=controller,
                                       num_inference_steps=NUM_DIFFUSION_STEPS,
                                       guidance_scale=GUIDANCE_SCALE,
                                       low_resource=LOW_RESOURCE,
                                       seed_batch=seed_batch
                                       )
ptp_utils.view_images(image_batch)

### Combine Two Images

In [15]:
def get_combine_mapper_(x: str, y: str, z: str, tokenizer, max_len=77):
    """Build the token mapper from prompt ``x`` to the combined prompt ``z``.

    Words are compared position by position after splitting on spaces. For
    every word of ``z`` that differs from ``x``, the tokens of the ``x`` word
    are mapped onto the tokens of the ``z`` word; all other positions map to
    themselves.

    Args:
        x: Base prompt.
        y: Second source prompt. Kept for interface compatibility; the
            x -> z alignment does not depend on it.
        z: Combined prompt to align against ``x``.
        tokenizer: Tokenizer passed to ``seq_aligner.get_word_inds``.
        max_len: Side length of the square mapper.

    Returns:
        Float tensor of shape ``(max_len, max_len)``.
    """
    words_x = x.split(' ')
    words_z = z.split(' ')
    inds_replace = [i for i in range(len(words_z)) if words_z[i] != words_x[i]]
    inds_x = [seq_aligner.get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_z = [seq_aligner.get_word_inds(z, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_x) and inds_x[cur_inds][0] == i:
            inds_x_, inds_z_ = inds_x[cur_inds], inds_z[cur_inds]
            if len(inds_x_) == len(inds_z_):
                mapper[inds_x_, inds_z_] = 1
            else:
                # Token counts differ: normalise by the number of target
                # tokens of this very word. The previous code divided by the
                # token count of an unrelated word of `y` and raised
                # IndexError when `y` had fewer replaced words than `z`.
                ratio = 1 / len(inds_z_)
                for i_t in inds_z_:
                    mapper[inds_x_, i_t] = ratio
            cur_inds += 1
            i += len(inds_x_)
            j += len(inds_z_)
        elif cur_inds < len(inds_x):
            # Unchanged token before the next replaced word.
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            # Past the last replaced word: identity on the remaining tokens.
            mapper[j, j] = 1
            i += 1
            j += 1
    return torch.from_numpy(mapper).float()

def combine_mapper(prompts, tokenizer, max_len=77):
    """Stack the mappers for the two edited prompts of a three-prompt batch.

    ``prompts[0]`` is the base prompt. The first mapper aligns it with
    ``prompts[1]`` via ``seq_aligner.get_replacement_mapper_``; the second
    comes from ``get_combine_mapper_`` for ``prompts[2]``.

    Returns:
        Tensor stacking both mappers along a new leading axis.
    """
    x_seq, y_seq, z_seq = prompts[0], prompts[1], prompts[2]
    return torch.stack([
        seq_aligner.get_replacement_mapper_(x_seq, y_seq, tokenizer, max_len),
        get_combine_mapper_(x_seq, y_seq, z_seq, tokenizer, max_len),
    ])

In [16]:
class AttentionCombine(AttentionControlEdit):
    """Edit controller whose token mappers come from ``combine_mapper``."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        stacked_mappers = combine_mapper(prompts, tokenizer)
        self.mapper = stacked_mappers.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Project the base attention's token axis through ``self.mapper``."""
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)

In [None]:
# prompts[0] is the base; prompts[2] takes its subject from prompts[0] and
# its activity from prompts[1].
prompts = ["A woman playing soccer",
           "A man playing basketball",
           "A woman playing basketball"]

controller = AttentionCombine(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=.8, self_replace_steps=0.4)
# x_t is the latent returned by the basic-generation cell.
_ = run_and_display(prompts, controller, latent=x_t, run_baseline=False)

### Move Object




In [None]:


def aggregate_attention_t8(attention_store, res: int, from_where: List[str], is_cross: bool, select: int,prompt_batch):
    """Average the stored attention maps of one resolution for one prompt.

    Same as ``aggregate_attention`` but the batch size comes from the
    ``prompt_batch`` argument.

    Args:
        attention_store: Object providing ``get_average_attention()``.
        res: Side length of the maps to keep (maps with ``res ** 2`` pixels).
        from_where: Store locations to read, e.g. ``("up", "down")``.
        is_cross: Select cross-attention (True) or self-attention (False).
        select: Index of the prompt within ``prompt_batch``.
        prompt_batch: Prompts of the run; only its length is used.

    Returns:
        CPU tensor of shape ``[res, res, last_dim]``.
    """
    kind = 'cross' if is_cross else 'self'
    averaged = attention_store.get_average_attention()
    num_pixels = res ** 2
    batch_size = len(prompt_batch)

    per_layer = []
    for location in from_where:
        for item in averaged[f"{location}_{kind}"]:
            if item.shape[1] != num_pixels:
                continue
            # item: [batch_size * h, res * res, last_dim] -> [h, res, res, last_dim]
            per_layer.append(item.reshape(batch_size, -1, res, res, item.shape[-1])[select])

    stacked = torch.cat(per_layer, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]

    print("out.shape",mean_map.shape)#[res,res,max_num_words]

    return mean_map.cpu()


def show_cross_attention_t8(controller, resolution: int, from_where: List[str], prompt_index,prompt_batch):
    """Display one labelled cross-attention heat map per prompt token.

    Args:
        controller: Object with the accumulated attention.
        resolution: Side length of the attention maps to aggregate.
        from_where: Store locations to read.
        prompt_index: Index of the prompt within ``prompt_batch``.
        prompt_batch: Prompts of the run.
    """
    maps = aggregate_attention_t8(controller, resolution, from_where, True, prompt_index,prompt_batch)

    token_ids = tokenizer.encode(prompt_batch[prompt_index])
    decode = tokenizer.decode
    panels = []
    for idx, token_id in enumerate(token_ids):
        heat = maps[:, :, idx]
        # Scale so the strongest response becomes 255, then make it RGB.
        heat = 255 * heat / heat.max()
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)
        pixels = heat.numpy().astype(np.uint8)
        pixels = np.array(Image.fromarray(pixels).resize((256, 256)))
        panels.append(text_under_image(pixels, decode(int(token_id))))
    view_images(np.stack(panels, axis=0))

In [None]:
import math

import matplotlib.pyplot as plt


# Module-level slots that MoveLocalBlendT8.__call__ overwrites on every call
# with the boolean masks and the latent batch it was given.
masks_g=None
latent_batch_g=None




class MoveLocalBlendT8:
    """Latent-space blend that relocates the region attended to by one word.

    A boolean mask is derived from the 16x16 cross-attention of ``word`` in
    batch entry 0. Latent pixels of entry 0 under that mask are copied into
    entry 1 at an offset of ``(move_down, move_right)`` ("move"), and the
    masked pixels of entry 1 are overwritten with entry 0's pixels taken
    from the offset position ("remove").
    """

    # Number of leading batch entries whose attention is used.
    _VALID_BATCH_SIZE = 3
    # Attention heads per batch entry assumed when slicing the stored maps.
    _NUM_HEADS = 8

    def __init__(self, prompt, word, batch_size, move_down, move_right, threshold=0.3):
        """Build the word selector.

        Args:
            prompt: Prompt in which ``word`` is looked up.
            word: Word whose attention defines the region to move.
            batch_size: Stored on the instance; not used by ``__call__``.
            move_down: Vertical offset in latent pixels.
            move_right: Horizontal offset in latent pixels.
            threshold: Cut-off applied to the max-normalised attention mask.
        """
        alpha_layers = torch.zeros(self._VALID_BATCH_SIZE, 1, 1, 1, 1, MAX_NUM_WORDS)

        ind = get_word_inds(prompt, word, tokenizer)
        alpha_layers[:, :, :, :, :, ind] = 1

        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold
        self.batch_size = batch_size
        self.move_down = move_down
        self.move_right = move_right

    def __call__(self, x_t, attention_store, remove, move):
        """Apply the move and/or remove copy to ``x_t`` in place.

        Args:
            x_t: Latent batch; entry 0 is read, entry 1 is written.
            attention_store: Mapping with ``"down_cross"`` and ``"up_cross"``
                lists of attention tensors.
            remove: Whether to overwrite the masked pixels of entry 1.
            move: Whether to paste the masked pixels of entry 0 into entry 1
                at the offset position.

        Returns:
            ``x_t`` (the same tensor object, modified in place).
        """
        global masks_g, latent_batch_g

        batch, heads = self._VALID_BATCH_SIZE, self._NUM_HEADS
        # Keep only the 16x16 maps and only the leading batch entries.
        maps = [
            attn[:batch * heads].reshape(batch, -1, 1, 16, 16, attn.shape[-1])
            for attn in attention_store["down_cross"] + attention_store["up_cross"]
            if attn.shape[1] == 256
        ]
        maps = torch.cat(maps, dim=1)
        maps = (maps * self.alpha_layers).sum(-1).mean(1)

        # k = 0 makes the pooling a no-op; raise it to dilate the mask.
        k = 0
        masks = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        # Upsample the mask to the spatial size of x_t.
        masks = nnf.interpolate(masks, size=(x_t.shape[2:]))
        # Scale so the maximum of every mask is 1.
        masks = masks / masks.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        # Compare against the threshold; the mask becomes boolean.
        masks = masks.gt(self.threshold)

        # Expose the latest mask and latent at module level.
        masks_g = masks
        latent_batch_g = x_t

        # DEBUG is an optional notebook-level flag; treat it as off if unset.
        if globals().get("DEBUG", False):
            plt.imshow(masks_g.cpu().numpy()[0, 0])
            plt.show()

        if not (move or remove):
            return x_t

        res = x_t.shape[2]
        # Coordinates of the masked pixels and of their shifted positions;
        # pairs whose shifted position falls outside the latent are dropped.
        rows, cols = torch.nonzero(masks[0, 0], as_tuple=True)
        shifted_rows = rows + self.move_down
        shifted_cols = cols + self.move_right
        in_bounds = (
            (shifted_rows >= 0) & (shifted_rows < res)
            & (shifted_cols >= 0) & (shifted_cols < res)
        )
        rows, cols = rows[in_bounds], cols[in_bounds]
        shifted_rows, shifted_cols = shifted_rows[in_bounds], shifted_cols[in_bounds]

        # Views into the batch: all reads come from entry 0, all writes go to
        # entry 1, so the vectorised copies match the per-pixel loops exactly.
        source, target = x_t[0], x_t[1]
        if move:
            target[:, shifted_rows, shifted_cols] = source[:, rows, cols]
        if remove:
            target[:, rows, cols] = source[:, shifted_rows, shifted_cols]

        return x_t


In [None]:




def show_self_attention_comp_t8(attention_store, res: int, from_where: List[str],
                        max_com=10, select: int = 0,prompt_batch=None):
    """Display the leading SVD components of the averaged self-attention.

    Args:
        attention_store: Object with the accumulated attention.
        res: Side length of the attention maps to aggregate.
        from_where: Store locations to read.
        max_com: Number of right-singular vectors to render.
        select: Index of the prompt within ``prompt_batch``.
        prompt_batch: Prompts of the run, forwarded to
            ``aggregate_attention_t8``.
    """
    flat = aggregate_attention_t8(attention_store, res, from_where, False, select,prompt_batch).numpy()
    flat = flat.reshape((res ** 2, res ** 2))
    centered = flat - np.mean(flat, axis=1, keepdims=True)
    _, _, vh = np.linalg.svd(centered)
    tiles = []
    for i in range(max_com):
        comp = vh[i].reshape(res, res)
        # Shift and scale the component into the 0..255 range.
        comp = comp - comp.min()
        comp = 255 * comp / comp.max()
        rgb = np.repeat(np.expand_dims(comp, axis=2), 3, axis=2).astype(np.uint8)
        tiles.append(np.array(Image.fromarray(rgb).resize((256, 256))))
    view_images(np.concatenate(tiles, axis=1))




class MoveControllerT8():
    """Attention controller that shifts one word's cross-attention map.

    Batch layout used by the code below: entry 0 is the base sample, entry 1
    receives the base attention of ``word`` shifted by
    ``(move_down, move_right)``, and entry 2 receives the base attention
    scaled by the removal equalizer (``word`` weighted by -5).

    NOTE(review): ``control_mode`` is read in ``__call__`` but never assigned
    in this class; the caller must set it before the first attention call.
    ``num_att_layers`` starts at -1 and is likewise expected to be set
    externally.
    """

    def __init__(self,
                 prompts,
                 move_start_step,
                 move_end_step,
                 remove_start_step,
                 remove_end_step,
                 move_right,
                 move_down,
                 num_inference_steps: int,
                 cross_replace_steps: float,
                 self_replace_steps: float,
                 local_blend: Optional["MoveLocalBlendT8"] = None,
                 word=None,
                 ):
        """Set up the controller.

        Args:
            prompts: Prompt batch; ``prompts[0]`` is used to locate ``word``.
            move_start_step: Fraction of the steps at which moving starts.
            move_end_step: Fraction of the steps at which moving ends.
            remove_start_step: Fraction of the steps at which removal starts.
            remove_end_step: Fraction of the steps at which removal ends.
            move_right: Horizontal offset, in units of a 64-wide latent.
            move_down: Vertical offset, in units of a 64-high latent.
            num_inference_steps: Number of diffusion steps.
            cross_replace_steps: Forwarded to
                ``get_time_words_attention_alpha``.
            self_replace_steps: Float fraction, or ``(start, end)`` fractions,
                of the steps during which self-attention is replaced.
            local_blend: Optional blend applied in ``step_callback``.
            word: Word whose attention map is shifted and removed.
        """
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

        # Attention bookkeeping (same scheme as AttentionStore).
        self.step_store = self.get_empty_store()
        self.attention_store = {}

        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_inference_steps, cross_replace_steps, tokenizer).to(device)

        if isinstance(self_replace_steps, float):
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_inference_steps * self_replace_steps[0]), int(num_inference_steps * self_replace_steps[1])
        self.local_blend = local_blend

        self.word_index = get_word_inds(prompts[0], word, tokenizer)[0]
        # Upload the equalizer once instead of on every attention call.
        self.equalizer = get_equalizer(prompts[0], (word,), (-5,)).to(device)
        self.move_start_step = int(num_inference_steps * move_start_step)
        self.move_end_step = int(num_inference_steps * move_end_step)
        self.remove_start_step = int(num_inference_steps * remove_start_step)
        self.remove_end_step = int(num_inference_steps * remove_end_step)

        self.move_right = move_right
        self.move_down = move_down

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Edit ``attn`` and advance the counters when control is enabled."""
        if not self.control_mode:
            return attn

        attn = self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` and overwrite the edited batch entries.

        ``attn`` is laid out as ``[batch_size * h, res * res, last_dim]``.
        """
        self.store_attention(attn, is_cross, place_in_unet)

        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_repalce = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
                attn[1:] = attn_repalce_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def store_attention(self, attn, is_cross: bool, place_in_unet: str):
        """Append ``attn`` to the store of the step in progress."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def get_empty_store(self):
        """Return a fresh dict with one empty list per (place, kind) key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [],  "mid_self": [],  "up_self": []}

    def between_steps(self):
        """Add step_store into attention_store, then clear step_store."""
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of steps run.

        Called when displaying attention maps.
        """
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention

    def step_callback(self, x_t):
        """Run the local blend with the move/remove flags of this step."""
        # DEBUG is an optional notebook-level flag; treat it as off if unset.
        if globals().get("DEBUG", False):
            print("current step", self.cur_step)
        # local_blend defaults to None in __init__, so guard against it.
        if self.local_blend is None:
            return x_t
        remove = self.remove_start_step <= self.cur_step <= self.remove_end_step
        move = self.move_start_step <= self.cur_step <= self.move_end_step
        return self.local_blend(x_t, self.attention_store, remove=remove, move=move)

    def replace_self_attention(self, attn_base, att_replace):
        """Use the base self-attention for maps of at most 16*16 entries."""
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    def replace_cross_attention(self, attn_base, attn_replace):
        """Return the cross-attention to inject into the edited entries.

        Args:
            attn_base: Base attention, ``[h, res * res, last_dim]``.
            attn_replace: Attention of the edited entries,
                ``[batch_size - 1, h, res * res, last_dim]``; not modified.

        Returns:
            A copy of ``attn_replace`` in which entry 0 has the map of
            ``word`` replaced by the cyclically shifted base map, and entry 1
            is the base attention scaled by the equalizer.
        """
        attn_replace = attn_replace.clone().detach()

        h = attn_replace.shape[1]
        res = int(math.sqrt(attn_replace.shape[2]))
        # Offsets are given for a 64-pixel latent; rescale to this map size.
        right_move_len = int((self.move_right / 64) * res)
        down_move_len = int((self.move_down / 64) * res)

        map_original = attn_base[:, :, self.word_index].reshape(h, res, res)
        # Cyclic shift on the tensor's own device. torch.roll also handles
        # negative offsets, which the slice-based version could not.
        map_shifted = torch.roll(map_original, shifts=(down_move_len, right_move_len), dims=(1, 2))
        attn_replace[0, :, :, self.word_index] = map_shifted.reshape(h, res * res)

        equalizer = self.equalizer.to(attn_base.device)
        attn_remove = attn_base[None, :, :, :] * equalizer[:, None, None, :]
        attn_replace[1] = attn_remove[0]

        return attn_replace








class MoveEditorT8():
    """Drive a diffusion pipeline to generate images in which a prompt word is moved.

    The wrapped ``model`` must expose ``tokenizer``, ``text_encoder``,
    ``unet``, ``scheduler``, ``vae`` and ``device`` attributes; those are the
    ones accessed by the methods below.
    """

    def __init__(self, model):
        """Store the diffusion pipeline used by all other methods."""
        self.model = model

    def move(self, prompt, word, seed, move_down, move_right, local_blend_threshold, num_inference_steps,
             move_start_step=None,
             move_end_step=None,
             remove_start_step=None,
             remove_end_step=None,
             cross_replace_steps=None,
             ):
        """Generate and display a batch of three images for one move edit.

        The same ``prompt`` and ``seed`` are repeated three times; the batch is
        run through ``prompts_to_images`` under a ``MoveControllerT8`` that
        uses a ``MoveLocalBlendT8``. The resulting images and the controller's
        cross-/self-attention visualisations are then displayed.
        NOTE(review): the role of each of the three batch entries is decided
        inside MoveControllerT8 / MoveLocalBlendT8, not here.

        Args:
            prompt: Text prompt.
            word: Word of ``prompt`` to act on.
            seed: Seed shared by all three batch entries.
            move_down: Forwarded to the local blend and the controller.
            move_right: Forwarded to the local blend and the controller.
            local_blend_threshold: Passed as ``threshold`` to the local blend.
            num_inference_steps: Number of diffusion steps.
            move_start_step: Forwarded unchanged to ``MoveControllerT8``.
            move_end_step: Forwarded unchanged to ``MoveControllerT8``.
            remove_start_step: Forwarded unchanged to ``MoveControllerT8``.
            remove_end_step: Forwarded unchanged to ``MoveControllerT8``.
            cross_replace_steps: Forwarded unchanged to ``MoveControllerT8``.

        Returns:
            None. Results are shown through ``view_images`` and the
            ``show_*_t8`` helpers.
        """
        batch_size = 3
        prompt_batch = [prompt] * batch_size
        seed_batch = [seed] * batch_size

        local_blend = MoveLocalBlendT8(
            prompt,
            word,
            batch_size=batch_size,
            threshold=local_blend_threshold,
            move_down=move_down,
            move_right=move_right,
        )
        move_controller = MoveControllerT8(
            prompt_batch,
            num_inference_steps=num_inference_steps,
            cross_replace_steps=cross_replace_steps,
            self_replace_steps=0.4,
            local_blend=local_blend,
            word=word,
            move_start_step=move_start_step,
            move_end_step=move_end_step,
            remove_start_step=remove_start_step,
            remove_end_step=remove_end_step,
            move_right=move_right,
            move_down=move_down,
        )

        image_batch = self.prompts_to_images(
            prompt_batch=prompt_batch,
            controller=move_controller,
            num_inference_steps=num_inference_steps,
            guidance_scale=GUIDANCE_SCALE,
            low_resource=LOW_RESOURCE,
            seed_batch=seed_batch,
        )
        view_images(image_batch)

        # One cross-attention visualisation per batch entry.
        for prompt_index in range(batch_size):
            show_cross_attention_t8(
                move_controller,
                resolution=16,
                from_where=("up", "down"),
                prompt_index=prompt_index,
                prompt_batch=prompt_batch,
            )
        # Self-attention visualisation only for selections 0 and 1.
        for select in (0, 1):
            show_self_attention_comp_t8(
                move_controller,
                16,
                from_where=("up", "down"),
                select=select,
                prompt_batch=prompt_batch,
            )

    @torch.no_grad()
    def prompts_to_images(
        self,
        prompt_batch: List[str],
        seed_batch: List[int],
        controller,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        generator_batch: Optional[torch.Generator] = None,
        low_resource: bool = False,
    ):
        """Run the full denoising loop for a batch of prompts.

        Args:
            prompt_batch: Prompts, one per image.
            seed_batch: Seeds for the initial latents, one per image.
            controller: Attention controller; registered on the model and
                consulted in every diffusion step.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            generator_batch: Accepted for interface compatibility but unused;
                latents are always created from ``seed_batch``.
            low_resource: Forwarded to ``diffusion_step``, which currently
                ignores it.

        Returns:
            Whatever ``latent2image`` produces for the final latents.
        """
        register_attention_control(self.model, controller)

        # Output resolution is fixed; latents are created at 1/8 of it.
        height = width = 512
        batch_size = len(prompt_batch)
        pipe_tokenizer = self.model.tokenizer
        text_encoder = self.model.text_encoder
        model_device = self.model.device

        text_input = pipe_tokenizer(
            prompt_batch,
            padding="max_length",
            max_length=pipe_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(model_device))[0]

        # The unconditional branch uses empty prompts padded to the same length.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = pipe_tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(model_device))[0]

        context = [uncond_embeddings, text_embeddings]
        latent_batch = self.init_latent_batch(height, width, seed_batch)

        # Set the timesteps; "offset" is passed through as a scheduler kwarg.
        extra_set_kwargs = {"offset": 1}
        self.model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        for t in tqdm(self.model.scheduler.timesteps):
            latent_batch = self.diffusion_step(controller, latent_batch, context, t, guidance_scale, low_resource)

        return latent2image(self.model.vae, latent_batch)

    def init_latent_batch(self, height, width, seed_batch):
        """Create the initial noise latents, one per seed.

        Each latent is drawn from its own ``torch.Generator`` seeded with the
        corresponding entry of ``seed_batch``, so equal seeds yield equal
        latents.

        Args:
            height: Image height; the latent height is ``height // 8``.
            width: Image width; the latent width is ``width // 8``.
            seed_batch: Iterable of integer seeds.

        Returns:
            Tensor of shape ``(len(seed_batch), in_channels, height // 8,
            width // 8)`` on ``self.model.device``. An empty ``seed_batch``
            gives a batch of size zero.
        """
        latent_shape = (1, self.model.unet.in_channels, height // 8, width // 8)
        latents = [
            torch.randn(latent_shape, generator=torch.Generator().manual_seed(seed))
            for seed in seed_batch
        ]
        if latents:
            # Concatenate once instead of growing the batch tensor per seed.
            latent_batch = torch.cat(latents, 0)
        else:
            latent_batch = torch.empty((0,) + latent_shape[1:])
        return latent_batch.to(self.model.device)

    def diffusion_step(self, controller, latent_batch, context, t, guidance_scale, low_resource=False):
        """Perform one guided denoising step and return the updated latents.

        The UNet is evaluated twice: first on the unconditional embeddings
        (``context[0]``) with ``controller.control_mode`` set to False, then on
        the text embeddings (``context[1]``) with it set to True. The two
        predictions are combined with classifier-free guidance, the scheduler
        advances the latents, and ``controller.step_callback`` post-processes
        them.

        ``low_resource`` is accepted but not used.
        """
        # Unconditional (empty-prompt) pass.
        controller.control_mode = False
        noise_pred_uncond = self.model.unet(latent_batch, t, encoder_hidden_states=context[0])["sample"]

        # Text-conditioned pass.
        controller.control_mode = True
        noise_prediction_text = self.model.unet(latent_batch, t, encoder_hidden_states=context[1])["sample"]

        guidance_delta = noise_prediction_text - noise_pred_uncond
        noise_pred = noise_pred_uncond + guidance_scale * guidance_delta
        latent_batch = self.model.scheduler.step(noise_pred, t, latent_batch)["prev_sample"]
        return controller.step_callback(latent_batch)

    def get_original_postion(self, word, attention_store):
        """Placeholder that is not implemented yet; always returns None."""





In [None]:
# Demo: edit of the word "house" in a mountain scene.
# NOTE(review): DEBUG is presumably read by helpers defined elsewhere in the
# notebook -- confirm.
DEBUG = False

move_editor = MoveEditorT8(model=ldm_stable)

# Earlier trial, kept for reference: local_blend_threshold=0.65,
# num_inference_steps=20 and no cross_replace_steps argument (all other
# arguments identical to the call below).
move_editor.move(
    prompt="A photo of a house on a mountain",
    word="house",
    seed=8888,
    move_down=0,
    move_right=40,
    local_blend_threshold=0.7,
    num_inference_steps=15,
    move_start_step=0.5,
    move_end_step=1.0,
    remove_start_step=0.5,
    remove_end_step=1.0,
    cross_replace_steps=1.0,
)







In [None]:
# Demo: edit of the word "strawberry" using 70 diffusion steps.
DEBUG = False

move_editor = MoveEditorT8(model=ldm_stable)
move_editor.move(
    prompt="a strawberry on the ground",
    word="strawberry",
    seed=888,
    move_down=0,
    move_right=35,
    local_blend_threshold=0.3,
    num_inference_steps=70,
    move_start_step=0.4,
    move_end_step=1.0,
    remove_start_step=0.4,
    remove_end_step=1.0,
    cross_replace_steps=1.0,
)

In [None]:
# Demo: edit of the word "dog" using 40 diffusion steps.
DEBUG = False

move_editor = MoveEditorT8(model=ldm_stable)
move_editor.move(
    prompt="a dog in the forest",
    word="dog",
    seed=887,
    move_down=0,
    move_right=20,
    local_blend_threshold=0.4,
    num_inference_steps=40,
    move_start_step=0.4,
    move_end_step=1.0,
    remove_start_step=0.4,
    remove_end_step=1.0,
    cross_replace_steps=1.0,
)