# Diffusion self-guidance for controllable image generation

This notebook is an unofficial implementation of the paper [Diffusion Self-Guidance for Controllable Image Generation](https://arxiv.org/abs/2306.00986). If you are reading this and want to use it, my suggestion is to treat this implementation as a starting point rather than a finished product — it works in some cases, but more research is needed to get reliable results for each kind of edit.

In [None]:
from __future__ import annotations
import math, random, torch, matplotlib.pyplot as plt, numpy as np, matplotlib as mpl, shutil, os, gzip, pickle, re, copy
from pathlib import Path
from operator import itemgetter
from itertools import zip_longest
from functools import partial
import fastcore.all as fc
from glob import glob

from torch import tensor, nn, optim
import torch.nn.functional as F
from tqdm.auto import tqdm
import torchvision.transforms.functional as TF
from torch.nn import init
from diffusers import LMSDiscreteScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import AutoTokenizer, CLIPTextModel

# from miniai.core import *

from einops import rearrange
from fastprogress import progress_bar
from PIL import Image
from torchvision.io import read_image,ImageReadMode

# Compact tensor reprs: two decimals, 140-character lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed torch's global RNG once at import time.
torch.manual_seed(1)
# Make reversed grayscale the default matplotlib colormap.
mpl.rcParams['image.cmap'] = 'gray_r'

#### Helper functions

In [None]:
def add_dims_right(x,y):
    "Append trailing singleton axes to `x` until it has as many dims as `y` (no-op if it already does)."
    n_missing = y.ndim - x.ndim
    # A non-positive count yields an empty list, so the index degrades to a plain `...`.
    index = (Ellipsis, *([None] * n_missing))
    return x[index]

def add_dims_left(x, y):
    "Prepend leading singleton axes to `x` until it has as many dims as `y` (no-op if it already does)."
    n_missing = y.ndim - x.ndim
    # A non-positive count yields an empty list, so the index degrades to a plain `...`.
    index = (*([None] * n_missing), Ellipsis)
    return x[index]

In [None]:
def get_embeddings(prompt, concat_unconditional=False, device='cpu'):
    """Encode `prompt` with the module-level `tokeniser` and `text_encoder`.

    When `concat_unconditional` is true, the embedding of the empty prompt is
    stacked in front of the prompt embedding along dim 0. The result is moved
    to `device` before being returned.
    """
    tokens = tokeniser(prompt, padding="max_length", max_length=tokeniser.model_max_length, truncation=True, return_tensors="pt")
    seq_len = tokens.input_ids.shape[-1]
    with torch.no_grad():
        cond = text_encoder(tokens.input_ids)[0]
        if not concat_unconditional:
            result = cond
        else:
            # Pad the empty prompt to the same length so both embeddings can be concatenated.
            blank = tokeniser([""], padding="max_length", max_length=seq_len, return_tensors="pt")
            uncond = text_encoder(blank.input_ids)[0]
            result = torch.cat([uncond, cond])
    return result.to(device)

In [None]:
def encode_img(input_img):
    """Encode an image tensor with the module-level `vae` and return a sampled, scaled latent.

    A batch axis is added when the input has fewer than four dims. The input is
    mapped with `x*2 - 1` before encoding and the sample is multiplied by 0.18215.
    """
    batch = input_img.unsqueeze(0) if len(input_img.shape) < 4 else input_img
    with torch.no_grad():
        encoded = vae.encode(batch*2 - 1)
    # Sampling happens outside the no_grad block, as in the original.
    sample = encoded.latent_dist.sample()
    return 0.18215 * sample

In [None]:
def process(image):
    "Clip `image` to [-1, 1] and rescale it linearly to [0, 1]."
    clipped = image.clip(-1,1)
    return (clipped + 1) / 2

In [None]:
@fc.delegates(plt.Axes.imshow)
def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):
    "Draw a PIL image, array or PyTorch tensor on `ax` (created if missing) and return the axis."
    tensor_like = fc.hasattrs(im, ('cpu','permute','detach'))
    if tensor_like:
        im = im.detach().cpu()
        # Three dims with a small leading axis is treated as channels-first and moved to channels-last.
        channels_first = len(im.shape)==3 and im.shape[0]<5
        if channels_first: im = im.permute(1,2,0)
    elif not isinstance(im,np.ndarray):
        im = np.array(im)
    # Drop a trailing single-channel axis.
    if im.shape[-1]==1: im = im[...,0]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    for clear_ticks in (ax.set_xticks, ax.set_yticks): clear_ticks([])
    if noframe: ax.axis('off')
    return ax

In [None]:
# NOTE(review): this decorator needs `subplots` to exist already, but in this file `subplots`
# is defined in a later cell — confirm the cells are run in a working order.
@fc.delegates(subplots)
def get_grid(
    n:int, # Number of axes
    nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))`
    ncols:int=None, # Number of columns, defaulting to `ceil(n/rows)`
    title:str=None, # If passed, title set to the figure
    weight:str='bold', # Title font weight
    size:int=14, # Title font size
    **kwargs,
): # fig and axs
    "Return a grid of at least `n` axes, `rows` by `cols`; surplus axes are switched off"
    # Round the derived dimension up so that nrows*ncols >= n. Rounding down left the grid
    # short (e.g. n=5 produced a 2x2 grid), so trailing images had no axis to be drawn on.
    if nrows: ncols = ncols or int(np.ceil(n/nrows))
    elif ncols: nrows = nrows or int(np.ceil(n/ncols))
    else:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n/nrows))
    fig,axs = subplots(nrows, ncols, **kwargs)
    # Hide the axes beyond the first `n`.
    for i in range(n, nrows*ncols): axs.flat[i].set_axis_off()
    if title is not None: fig.suptitle(title, weight=weight, size=size)
    return fig,axs

In [None]:
@fc.delegates(plt.subplots, keep=True)
def subplots(
    nrows:int=1, # Number of rows in returned axes grid
    ncols:int=1, # Number of columns in returned axes grid
    figsize:tuple=None, # Width, height in inches of the returned figure
    imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure
    suptitle:str=None, # Title to be set to returned figure
    **kwargs
): # fig and axs
    "Create a figure whose size defaults to `imsize` inches per cell; axes always come back as an array"
    if figsize is None:
        width, height = ncols*imsize, nrows*imsize
        figsize = (width, height)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if suptitle is not None:
        fig.suptitle(suptitle)
    # With a single cell matplotlib returns a bare Axes; wrap it so callers can always index.
    single_cell = nrows*ncols==1
    if single_cell: axes = np.array([axes])
    return fig, axes

In [None]:
@fc.delegates(subplots)
def show_images(ims:list, # Images to show
                nrows:int=None, # Number of rows in grid
                ncols:int=None, # Number of columns in grid (auto-calculated if None)
                titles:list=None, # Optional list of titles for each image
                **kwargs):
    "Lay `ims` out on a grid from `get_grid`, drawing each with `show_image` and its title, if any"
    _, grid = get_grid(len(ims), nrows, ncols, **kwargs)
    captions = titles or []
    # zip_longest pads the shorter sequences with None.
    for image, caption, axis in zip_longest(ims, captions, grid.flat):
        show_image(image, ax=axis, title=caption)

#### Attention and activation collection/storage

In [None]:
from diffusers.models.attention_processor import AttnProcessor, Attention

def get_features(hook, layer, inp, out):
    """Forward-hook callback that stores the layer output on `hook.feats`.

    `Hook` binds itself as the first argument via `partial`; the remaining three
    arguments are supplied by the forward hook. Each call overwrites the
    previously stored value.
    """
    # The old `hasattr` guard made the same assignment that followed unconditionally.
    hook.feats = out

class Hook():
    """Register `func` as a forward hook on `model`, with this Hook bound as its first argument.

    The registration handle is kept on `self.hook`; call `remove()` to detach it.
    `__del__` also calls `remove()`.
    """
    def __init__(self, model, func): self.hook = model.register_forward_hook(partial(func, self))
    def remove(self):
        # `__del__` can run on an instance whose `__init__` raised before `self.hook`
        # was assigned, so look the handle up defensively instead of assuming it exists.
        handle = getattr(self, 'hook', None)
        if handle is not None: handle.remove()
    def __del__(self): self.remove()

def get_attn_dict(processor, model):
    "Build `{name: processor(name=name)}` for every key of `model.attn_processors`."
    return {layer_name: processor(name=layer_name) for layer_name in model.attn_processors.keys()}

class AttnStorage:
    "Two-level dict of attention maps, keyed first by layer name and then by prediction type."
    def __init__(self):
        self.storage = {}
    def __call__(self, attention_map, name, pred_type='orig'):
        # Create the per-layer dict on first use, then overwrite the slot for this prediction type.
        per_layer = self.storage.setdefault(name, {})
        per_layer[pred_type] = attention_map
    def flush(self):
        # Rebind to a fresh dict rather than clearing in place.
        self.storage = {}

class CustomAttnProcessor(AttnProcessor):
    """Attention processor that can record its attention probabilities in `attn_storage`.

    Recording is off after construction and is switched on or off with
    `set_storage`. When on, each call passes the attention probabilities to
    `attn_storage` under this processor's `name` and the current `pred_type`.
    """
    def __init__(self, attn_storage, name=None): 
        # NOTE(review): presumably fc.store_attr() saves `attn_storage` and `name` on self —
        # both are read back as self.attn_storage / self.name in __call__.
        fc.store_attr()
        self.store = False  # nothing is recorded until set_storage() enables it
        # NOTE(review): `name` defaults to None, but this membership test needs a string.
        self.type = "attn2" if "attn2" in name else "attn1"
    def set_storage(self, store, pred_type): 
        "Enable/disable recording and set the `pred_type` key used when recording."
        self.store = store
        self.pred_type = pred_type
    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        "Compute attention for `attn`, optionally recording the probabilities, and return the projected hidden states."
        # Batch size and sequence length come from the encoder states when they are given,
        # otherwise from the hidden states themselves.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        # Without encoder states, keys and values are computed from the hidden states.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
     
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # Mark the probabilities as requiring grad before they are (optionally) stored.
        attention_probs.requires_grad_(True)
        
        if self.store: self.attn_storage(attention_probs, self.name, pred_type=self.pred_type) ## stores the attention maps in attn_storage
        
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        
        return hidden_states

def prepare_attention(model, attn_storage, pred_type='orig', set_store=True):
    "Call `set_storage(set_store, pred_type)` on every attention processor of `model` (`attn_storage` is unused)."
    for proc in model.attn_processors.values():
        proc.set_storage(set_store, pred_type)

#### Self guidance equations

In [None]:
def normalise(x):
    "Min-max rescale `x` over all of its elements (a constant input divides by zero)."
    lowest, highest = x.min(), x.max()
    return (x - lowest) / (highest - lowest)

In [None]:
def threshold_attention(attn, s=10):
    "Soft threshold: normalise `attn`, centre it at 0.5, scale by `s`, apply a sigmoid, then normalise again."
    centred = normalise(attn) - 0.5
    squashed = (s * centred).sigmoid()
    return normalise(squashed)

In [None]:
def get_shape(attn, s=20):
    "Shape of an attention map: `threshold_attention` with sharpness `s` (default 20)."
    return threshold_attention(attn, s)
def get_size(attn):
    "Size of an attention map: thresholded map summed over dims 1 and 2, averaged, divided by `attn.shape[-2]`."
    mask = threshold_attention(attn)
    mean_area = mask.sum((1,2)).mean()
    return 1/attn.shape[-2] * mean_area
def get_centroid(attn):
    "Attention-weighted mean (column, row) position, stacked in that order; one column per entry of the last dim."
    # Anything that is not 3-D gets a trailing axis added.
    if len(attn.shape) != 3: attn = attn[:,:,None]
    # The second-to-last dim is treated as a flattened square grid.
    side = int(tensor(attn.shape[-2]).sqrt().item())
    row_pos = torch.arange(side).view(-1, 1, 1).to(attn.device)
    col_pos = torch.arange(side).view(1, -1, 1).to(attn.device)
    grid = rearrange(attn.mean(0), '(h w) d -> h w d', h=side)
    col_moment = torch.sum(col_pos * grid, dim=[0,1])
    row_moment = torch.sum(row_pos * grid, dim=[0,1])
    total_mass = grid.sum((0,1))
    return torch.stack([col_moment, row_moment]) / total_mass
def get_appearance(attn, feats):
    "Scalar appearance: `feats` averaged over dims 0 and 1, weighted by the detached shape of `attn`."
    if len(attn.shape) != 3: attn = attn[:,:,None]
    side = int(tensor(attn.shape[-2]).sqrt().item())
    # The shape mask is detached, so gradients reach this value only through `feats`.
    mask = get_shape(attn).detach().mean(0).view(side, side, attn.shape[-1])
    # assumes the two trailing dims of `feats` match the (side, side) attention grid — TODO confirm
    mean_feats = feats.mean((0,1))[:,:,None]
    weighted = (mask*mean_feats).sum()
    return weighted / mask.sum()

#### G functions

Single image editing. These are the functions that are closest to the paper. In the experiments section below, I played around with variations on these equations in pursuit of better results.

In [None]:
def fix_shapes(orig_attns, edit_attns, indices, tau=1):
    """Mean signed shape difference `tau*shape(orig) - shape(edit)`.

    The difference is averaged per attention layer, then over layers for each
    token index, then over all indices.
    """
    per_token = []
    for token in indices:
        per_layer = []
        for layer, edit_map in enumerate(edit_attns):
            orig_slice, edit_slice = orig_attns[layer][:,:,token], edit_map[:,:,token]
            diff = tau*get_shape(orig_slice) - get_shape(edit_slice)
            per_layer.append(diff.mean())
        per_token.append(torch.stack(per_layer).mean())
    return torch.stack(per_token).mean()

def fix_appearances(orig_attns, orig_feats, edit_attns, edit_feats, indices, attn_idx=-1):
    """Mean squared appearance difference between the original and edited runs.

    For each token index the last three attention maps of each run are averaged
    before `get_appearance` is applied. `attn_idx` is accepted but not used.
    """
    def _mean_of_last_three(attns, token):
        return torch.stack([a[:,:,token] for a in attns[-3:]]).mean(0)

    losses = []
    for token in indices:
        orig_map = _mean_of_last_three(orig_attns, token)
        edit_map = _mean_of_last_three(edit_attns, token)
        gap = get_appearance(orig_map, orig_feats) - get_appearance(edit_map, edit_feats)
        losses.append(gap.pow(2).mean())
    return torch.stack(losses).mean()

def fix_sizes(orig_attns, edit_attns, indices, tau=1):
    "Mean over attention layers of the signed size difference `tau*size(orig) - size(edit)` for `indices`."
    diffs = []
    for layer, edit_map in enumerate(edit_attns):
        orig_sel = orig_attns[layer][:,:,indices]
        edit_sel = edit_map[:,:,indices]
        diffs.append(tau*get_size(orig_sel) - get_size(edit_sel))
    return torch.stack(diffs).mean()

def position_deltas(orig_attns, edit_attns, indices, target_centroid=None):
    """Mean signed difference between a target centroid and the edited centroid.

    The target is `target_centroid` when given, otherwise the centroid of the
    original attention for the same layer.
    """
    diffs = []
    for layer, edit_map in enumerate(edit_attns):
        orig_sel = orig_attns[layer][:,:,indices]
        edit_sel = edit_map[:,:,indices]
        if target_centroid is not None:
            goal = tensor(target_centroid)
        else:
            goal = get_centroid(orig_sel)
        diffs.append(goal.to(orig_sel.device) - get_centroid(edit_sel))
    return torch.stack(diffs).mean()

def fix_selfs(origs, edits):
    "Mean over layers of the signed difference between thresholded original and edited maps."
    diffs = [
        (threshold_attention(origs[layer]) - threshold_attention(edit_map)).mean()
        for layer, edit_map in enumerate(edits)
    ]
    return torch.stack(diffs).mean()

In [None]:
def get_attns(attn_storage, attn_type='attn2'):
    "Return the `'orig'` and `'edit'` maps, as two lists, for every stored layer whose name contains `attn_type`."
    selected = [maps for layer_name, maps in attn_storage.storage.items() if attn_type in layer_name]
    origs = [maps['orig'] for maps in selected]
    edits = [maps['edit'] for maps in selected]
    return origs, edits

def edit_layout(attn_storage, indices, appearance_weight=0.5, orig_feats=None, edit_feats=None, **kwargs):
    "Guidance term made only of the appearance loss, scaled by `appearance_weight`."
    attn_orig, attn_edit = get_attns(attn_storage)
    appearance = fix_appearances(attn_orig, orig_feats, attn_edit, edit_feats, indices, **kwargs)
    return appearance_weight*appearance

def edit_appearance(attn_storage, indices, shape_weight=1, **kwargs):
    "Guidance term made only of the shape loss, scaled by `shape_weight`; extra keyword arguments are ignored."
    attn_orig, attn_edit = get_attns(attn_storage)
    shape = fix_shapes(attn_orig, attn_edit, indices)
    return shape_weight*shape

def resize_object(attn_storage, indices, relative_size=2, shape_weight=1, size_weight=1, appearance_weight=0.1, orig_feats=None, edit_feats=None, **kwargs):
    """Guidance term for resizing: weighted shape + appearance + size losses.

    When `indices` has more than one entry it is unpacked as
    `(obj_idx, other_idx)` and the two are concatenated; every term is then
    computed over the combined indices. `relative_size` is the `tau` of the size term.
    """
    attn_orig, attn_edit = get_attns(attn_storage)
    if len(indices) > 1:
        obj_idx, other_idx = indices
        indices = torch.cat([obj_idx, other_idx])
    shape_loss = fix_shapes(attn_orig, attn_edit, indices)
    appearance_loss = fix_appearances(attn_orig, orig_feats, attn_edit, edit_feats, indices)
    # NOTE(review): the size term runs over the combined indices, whereas the first
    # move_object in this file uses obj_idx only — confirm this is intended.
    size_loss = fix_sizes(attn_orig, attn_edit, indices, tau=relative_size)
    return shape_weight*shape_loss + appearance_weight*appearance_loss + size_weight*size_loss

def move_object(attn_storage, indices, target_centroid=None, shape_weight=1, size_weight=1, appearance_weight=0.5, position_weight=1, orig_feats=None, edit_feats=None, **kwargs):
    """Guidance term for moving an object: weighted shape + appearance + size + position losses.

    `indices` must be an `(obj_idx, other_idx)` pair. Shape and appearance are
    computed over both sets of indices; size and position over `obj_idx` only.
    `target_centroid` is forwarded to `position_deltas`.

    Raises `ValueError` when `indices` is not a pair.

    NOTE(review): a later cell in this file defines `move_object` again; once that
    cell has run, this definition is shadowed.
    """
    # The size and position terms need `obj_idx`. Without this check a non-pair input only
    # failed later with an UnboundLocalError or an unpacking error.
    if len(indices) != 2:
        raise ValueError('move_object expects `indices` to be an (obj_idx, other_idx) pair.')
    obj_idx, other_idx = indices
    indices = torch.cat([obj_idx, other_idx])
    origs, edits = get_attns(attn_storage)
    shape_term = shape_weight*fix_shapes(origs, edits, indices)
    appearance_term = appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices)
    size_term = size_weight*fix_sizes(origs, edits, obj_idx)
    position_term = position_weight*position_deltas(origs, edits, obj_idx, target_centroid=target_centroid)
    return shape_term + appearance_term + size_term + position_term

#### Inference loop

In [None]:
def do_self_guidance(t, n, scheduler):
    """Decide whether self-guidance is applied at step `t` of `n`.

    The schedule depends on the scheduler's class name: always on for an early
    window, always off for a late window, and on every even step in between.
    Returns None for any scheduler other than DDPMScheduler / LMSDiscreteScheduler.
    """
    kind = type(scheduler).__name__
    if kind == "DDPMScheduler":
        always_until, never_from = int((3*n)/16), int(n - n/32)
    elif kind == "LMSDiscreteScheduler":
        always_until, never_from = int(n/5), n - 5
    else:
        return None
    if t <= always_until: return True
    if t >= never_from: return False
    return bool(t % 2 == 0)

In [None]:
def all_word_indexes(prompt, tokeniser, object_to_edit=None, **kwargs):
    """Extracts token indexes by treating all words in the prompt as separate objects.

    Without `object_to_edit`, returns the positions of every prompt token whose
    id is below 49405. With `object_to_edit`, returns `(obj_idx, other_idx)`:
    the positions of the object's tokens, and the remaining word positions
    (sorted, int64). Extra keyword arguments are accepted and ignored.
    """
    prompt_inputs = tokeniser(prompt, padding="max_length", max_length=tokeniser.model_max_length, truncation=True, return_tensors="pt").input_ids
    # NOTE(review): presumably this threshold filters out start/end/padding tokens; CLIP's
    # special ids are 49406/49407, so `< 49405` also drops ordinary id 49405 — confirm.
    word_idx = torch.where(prompt_inputs < 49405)[1]
    if object_to_edit is None: return word_idx
    obj_inputs = tokeniser(object_to_edit, add_special_tokens=False).input_ids
    obj_idx = torch.cat([torch.where(prompt_inputs == o)[1] for o in obj_inputs])
    # Remove *all* object positions from the word positions. Taking the union of the
    # "not equal to token o" positions over each object token put the object's own tokens
    # back into `other_idx` whenever the object spanned more than one token.
    other_positions = sorted(set(word_idx.tolist()) - set(obj_idx.tolist()))
    # The explicit dtype keeps an empty result usable as an index (it would otherwise be float32).
    other_idx = torch.tensor(other_positions, dtype=torch.long)
    return obj_idx, other_idx

def choose_object_indexes(prompt, tokeniser, objects:list=None, object_to_edit=None):
    """Extracts token indexes only for user-defined objects.

    Without `object_to_edit`, returns the concatenated prompt positions of every
    string in `objects`. With `object_to_edit`, returns `(obj_idx, other_idx)`,
    where `other_idx` covers the remaining entries of `objects`. `objects` is
    never modified; when no other objects remain, `other_idx` is an empty int64 tensor.
    """
    prompt_inputs = tokeniser(prompt, padding="max_length", max_length=tokeniser.model_max_length, truncation=True, return_tensors="pt").input_ids

    def _positions(text):
        "Positions in the prompt of each token of `text`, concatenated."
        token_ids = tokeniser(text, add_special_tokens=False).input_ids
        return torch.cat([torch.where(prompt_inputs == tok)[1] for tok in token_ids])

    # Filter into a new list instead of calling `objects.remove(...)`, which silently
    # altered the caller's list and so changed the result of a second call.
    others = [o for o in (objects or []) if object_to_edit is None or o != object_to_edit]
    # torch.cat rejects an empty list, so build the empty case explicitly.
    other_idx = torch.cat([_positions(o) for o in others]) if others else torch.empty(0, dtype=torch.long)
    if object_to_edit is None: return other_idx
    return _positions(object_to_edit), other_idx

In [None]:
def sg_sample(
    prompt,
    model,
    scheduler,
    guidance_func,
    g_weight=10, 
    feature_layer=None, 
    idx_func=all_word_indexes,
    objects:list=None,
    obj_to_edit=None,
    use_same_seed=False,
    seed=None, steps=50, guidance_scale=5., device='cuda', height=512, width=512, return_original=True
):
    """Sample an original image and a self-guided edit of it from the same prompt.

    Two denoising trajectories are advanced step by step. The "orig" one is plain
    classifier-free guidance computed under `no_grad`. The "edit" one is computed
    with gradients enabled; on steps selected by `do_self_guidance`, the scalar
    returned by `guidance_func` is backpropagated to the edit latents and that
    gradient, scaled by `g_weight` and the current sigma, is added to the noise
    prediction.

    `guidance_func` is called as `guidance_func(storage, indices, orig_feats=..., edit_feats=...)`
    and may be a `functools.partial`. `idx_func` is called with `objects=` and
    `object_to_edit=` keywords. The function also uses the module-level
    `tokeniser` and `vae`, and `get_embeddings`.

    Side effects: replaces every attention processor of `model`, calls
    `scheduler.set_timesteps(steps)` and reseeds torch's global RNG.

    Returns the decoded `(orig_img, edit_img)`, or only `edit_img` when
    `return_original` is false. Raises `ValueError` when the chosen
    `idx_func` / `guidance_func` needs `objects` / `obj_to_edit` and it is missing.
    """
    # Without an explicit seed, draw one; the edit path gets its own seed unless use_same_seed.
    if seed is None: seed = int(torch.rand((1,)) * 1000000)
    seed_2 = int(torch.rand((1,)) * 1000000) if not use_same_seed else seed
    
    # set up the custom attn processor and use to replace standard model processors
    storage = AttnStorage()
    processor = partial(CustomAttnProcessor, storage)
    attn_dict = get_attn_dict(processor, model)
    model.set_attn_processor(attn_dict)
    
    # set up the hook to collect activations from feature_layer
    g_name = guidance_func.func.__name__ if isinstance(guidance_func, partial) else guidance_func.__name__
    # Every guidance function except edit_appearance gets a default feature layer.
    if g_name not in ['edit_appearance'] and feature_layer is None:
        feature_layer = model.up_blocks[-1].resnets[-2]
    if feature_layer is not None: hook = Hook(feature_layer, get_features)
    
    # get indexes of editable and non-editable objects from token sequence
    if idx_func.__name__ == 'choose_object_indexes' and objects is None:
        raise ValueError('Provide a list of object strings from the prompt.')
    if g_name not in ['edit_layout', 'edit_appearance', 'edit_layout_2'] and obj_to_edit is None:
        raise ValueError('Provide an object string for editing.')
    indices = idx_func(prompt, tokeniser, objects=objects, object_to_edit=obj_to_edit)
    
    # set up embeddings, latents and scheduler
    uncond_embeddings = get_embeddings("", concat_unconditional=False, device=device)
    cond_embeddings = get_embeddings(prompt, concat_unconditional=False, device=device)
    scheduler.set_timesteps(steps)
    # The edit path steps an independent copy of the scheduler.
    scheduler_2 = copy.deepcopy(scheduler)
    shape = (1, model.config.in_channels, height // 8, width // 8)
    # torch.manual_seed reseeds the global RNG and returns the generator used for the draw.
    orig_latents = torch.randn(shape, generator=torch.manual_seed(seed)).to(device) * scheduler.init_noise_sigma
    edit_latents = torch.randn(shape, generator=torch.manual_seed(seed_2)).to(device) * scheduler.init_noise_sigma
    
    for i, t in enumerate(progress_bar(scheduler.timesteps, leave=False)):
        # calculate noise_pred on the original unedited solution path
        latent_model_input = scheduler.scale_model_input(orig_latents, t) ## note orig_latents
        with torch.no_grad(): 
            # don't store attention for the uncond prediction
            prepare_attention(model, storage, set_store=False)
            uncond = model(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample

            # do store attention for the cond prediction
            prepare_attention(model, storage, pred_type='orig', set_store=True)
            cond = model(latent_model_input, t, encoder_hidden_states=cond_embeddings).sample
            # Read the hooked features right away: the next forward pass overwrites hook.feats.
            orig_feats = hook.feats if feature_layer is not None else None
        
        # classifier-free guidance on original solution path
        orig_noise_pred = uncond + guidance_scale * (cond - uncond)
        orig_latents = scheduler.step(orig_noise_pred, t, orig_latents).prev_sample
        
        # The edit latents must carry a .grad after backward().
        edit_latents.requires_grad_(True)
        edit_latents.retain_grad()
        
        # recalculate noise_pred for edited solution path and allow grads to flow this time
        latent_model_input = scheduler_2.scale_model_input(edit_latents, t) ## note edit_latents
        prepare_attention(model, storage, set_store=False)
        uncond = model(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample

        prepare_attention(model, storage, pred_type='edit', set_store=True)
        cond = model(latent_model_input, t, encoder_hidden_states=cond_embeddings).sample
        edit_feats = hook.feats if feature_layer is not None else None
        
        # perform guidance with flexible g function
        edit_noise_pred = uncond + guidance_scale * (cond - uncond)
        if do_self_guidance(i, len(scheduler.timesteps), scheduler):
            g = guidance_func(storage, indices, orig_feats=orig_feats, edit_feats=edit_feats)
            g.backward()
            # assumes the scheduler exposes a per-step `sigmas` sequence — TODO confirm for non-LMS schedulers
            sig_t = scheduler.sigmas[i]
            edit_noise_pred += g_weight*sig_t*edit_latents.grad
        # Step on detached tensors so each iteration starts a fresh autograd graph.
        edit_latents = scheduler_2.step(edit_noise_pred.detach(), t, edit_latents.detach()).prev_sample
        # Drop this step's attention maps before the next iteration.
        storage.flush()
        
    # Undo the latent scaling before decoding.
    orig_latents = 1 / 0.18215 * orig_latents
    edit_latents = 1 / 0.18215 * edit_latents

    with torch.no_grad(): edit_img = vae.decode(edit_latents).sample
    if not return_original: return edit_img
    with torch.no_grad(): orig_img = vae.decode(orig_latents).sample
    return orig_img, edit_img

In [None]:
def sample_original(prompt, seed=None, height=512, width=512, steps=50, guidance_scale=5, device='cuda'):
    """Sample an image for `prompt` with plain classifier-free guidance (no self-guidance).

    Uses the module-level `model`, `scheduler` and `vae`. The unconditional and
    conditional predictions are computed in one batched forward pass. A random
    seed is drawn when `seed` is None. Returns the decoded image tensor.
    """
    if seed is None: seed = int(torch.rand((1,)) * 1000000)
    embeddings = get_embeddings(prompt, concat_unconditional=True, device=device)
    scheduler.set_timesteps(steps)
    # Read the channel count from the model config, as sg_sample does, instead of the
    # attribute shortcut on the model itself.
    shape = (1, model.config.in_channels, height // 8, width // 8)
    latents = torch.randn(shape, generator=torch.manual_seed(seed)).to(device)
    latents = latents * scheduler.init_noise_sigma
    
    for t in progress_bar(scheduler.timesteps, leave=False):
        # Duplicate the latents so the uncond and cond halves share one forward pass.
        latent_model_input = torch.cat([latents] * 2).to(device)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = model(latent_model_input, t, encoder_hidden_states=embeddings).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        
    # Undo the latent scaling before decoding.
    latents = 1 / 0.18215 * latents
    with torch.no_grad(): image = vae.decode(latents).sample
    return image

In [None]:
# Load the Stable Diffusion v1.5 components as module-level globals; the sampling and
# embedding helpers in this file read `tokeniser`, `text_encoder`, `vae`, `model` and `scheduler` by name.
# The UNet and the VAE are moved to CUDA; the text encoder stays on its default device.
model = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet').to('cuda')
tokeniser = AutoTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae').to('cuda')
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')

#### Sample new appearances

#### Sample new layouts

Additional experimental code.

In [None]:
def fix_appearances_2(orig_attns, orig_feats, edit_attns, edit_feats, indices, attn_idx=-1):
    """Mean squared difference between the raw original and edited features.

    The same value is computed once per entry of `indices` and then averaged, so
    `indices` only needs to be non-empty. The attention arguments and `attn_idx`
    are accepted but not used.
    """
    per_index = [(orig_feats - edit_feats).pow(2).mean() for _ in indices]
    return torch.stack(per_index).mean()

def edit_layout_2(attn_storage, indices, appearance_weight=0.5, orig_feats=None, edit_feats=None, **kwargs):
    "Variant of `edit_layout` that uses `fix_appearances_2` for the appearance loss."
    attn_orig, attn_edit = get_attns(attn_storage)
    appearance = fix_appearances_2(attn_orig, orig_feats, attn_edit, edit_feats, indices, **kwargs)
    return appearance_weight*appearance

#### Move an object

Additional experimental code.

In [None]:
def roll_shape(x, direction='up', factor=0.5):
    """Circularly roll a flattened square map by `factor` of its side length.

    `x` is reshaped from `(n, h*h, d)` to `(n, h, h, d)`, rolled, and flattened
    back. 'up'/'left' roll towards lower row/column indices, 'down'/'right'
    towards higher ones. Any other `direction` returns the map unrolled.
    """
    h = w = int(math.sqrt(x.shape[-2]))
    # 'down' used to carry the same negative row offset as 'up', so both moved the map upwards.
    offsets = {
        'up': (int(-h*factor), 0),
        'down': (int(h*factor), 0),
        'right': (0, int(w*factor)),
        'left': (0, int(-w*factor)),
    }
    mag = offsets.get(direction, (0, 0))
    # reshape (rather than view) also accepts non-contiguous inputs such as sliced attention maps.
    grid = x.reshape(x.shape[0], h, w, x.shape[-1])
    return grid.roll(mag, dims=(1,2)).reshape(x.shape[0], h*w, x.shape[-1])

In [None]:
def shift_shape(x, direction='up'):
    """Circularly shift a flattened square map by a quarter of its side length.

    `x` is reshaped from `(n, h*h, d)` to `(n, h, h, d)`, shifted with
    wrap-around, and flattened back. 'up'/'left' move content towards lower
    row/column indices, 'down'/'right' towards higher ones. An unrecognised
    `direction` returns an all-zero map, as before.

    NOTE(review): 'left' and 'right' previously moved content the opposite way to
    their names (and to `roll_shape`); callers that compensated for that need updating.
    """
    h = w = int(math.sqrt(x.shape[-2]))
    grid = x.reshape(x.shape[0], h, w, x.shape[-1])
    # (shift, dim) per direction. A roll reproduces the old two-slice copy whenever the side is
    # divisible by 4, and also works for other sides, where the slice sizes used to mismatch.
    rolls = {
        'up': (-(h//4), 1),
        'down': (h//4, 1),
        'left': (-(w//4), 2),
        'right': (w//4, 2),
    }
    if direction not in rolls:
        return torch.zeros_like(grid).reshape(x.shape[0], h*w, x.shape[-1])
    shift, dim = rolls[direction]
    return grid.roll(shift, dims=dim).reshape(x.shape[0], h*w, x.shape[-1])

# def shift_shape(x, direction='up'): 
#     h = w = int(math.sqrt(x.shape[-2]))
#     shape = (x.shape[0], h, w, x.shape[-1])
#     x = x.view(shape)
#     shift = torch.zeros_like(x)
    
#     if direction == 'up':
#         shift[:, :h//4, :, :] = x[:, h//4:, :, :]
#     elif direction == 'down':
#         shift[:, h//4:, :, :] = x[:, :h//4, :, :]
#     elif direction == 'right':
#         shift[:, :, :w//4, :] = x[:, :, w//4:, :]
#     elif direction == 'left':
#         shift[:, :, w//4:, :] = x[:, :, :w//4, :]
    
#     return shift.view(x.shape[0], h*h, x.shape[-1])

In [None]:
def fix_shapes_3(orig_attns, edit_attns, indices, tau=fc.noop):
    """Mean squared difference between a transformed original shape and the edited shape.

    `tau` is a callable applied to the original shape (identity by default).
    Maps with fewer than three dims get a trailing axis first. Averaged per layer,
    then per token index, then over all indices.
    """
    per_token = []
    for token in indices:
        per_layer = []
        for layer, edit_map in enumerate(edit_attns):
            orig_slice, edit_slice = orig_attns[layer][:,:,token], edit_map[:,:,token]
            if len(orig_slice.shape) < 3:
                orig_slice, edit_slice = orig_slice[...,None], edit_slice[...,None]
            gap = tau(get_shape(orig_slice)) - get_shape(edit_slice)
            per_layer.append(gap.pow(2).mean())
        per_token.append(torch.stack(per_layer).mean())
    return torch.stack(per_token).mean()

In [None]:
def fix_selfs_2(origs, edits, t=fc.noop):
    """Mean signed difference between transformed original maps and edited maps.

    Each original map gets a trailing axis and is averaged over dim 0, then `t`
    (identity by default) is applied and the result squeezed before the edited
    map is subtracted.
    """
    diffs = []
    for layer, edit_map in enumerate(edits):
        orig_mean = origs[layer][...,None].mean(0)
        gap = t(orig_mean).squeeze() - edit_map
        diffs.append(gap.mean())
    return torch.stack(diffs).mean()

In [None]:
def move_object(attn_storage, indices, t=fc.noop, shape_weight=1, size_weight=1, self_weight=0.1, appearance_weight=0.5, position_weight=1, orig_feats=None, edit_feats=None, **kwargs):
    """Experimental guidance term for moving: move + shape + appearance losses.

    `indices` must be an `(obj_idx, other_idx)` pair. The shape term runs over
    `obj_idx`, the appearance term (`fix_appearances_2`) over both sets, and the
    move term (`fix_shapes_3` with `t` as the transform) over `other_idx`.
    `size_weight` and `self_weight` are accepted but not used by the active terms.

    Raises `ValueError` when `indices` is not a pair.

    NOTE(review): this redefines the earlier `move_object` in this file and has no
    `target_centroid` parameter (it would be absorbed by `**kwargs`).
    """
    # Both halves are needed below; a non-pair input otherwise failed later with an
    # UnboundLocalError or an unpacking error.
    if len(indices) != 2:
        raise ValueError('move_object expects `indices` to be an (obj_idx, other_idx) pair.')
    obj_idx, other_idx = indices
    indices = torch.cat([obj_idx, other_idx])
    origs, edits = get_attns(attn_storage)
    # orig_selfs = [v['orig'] for k,v in attn_storage.storage.items() if 'attn1' in k and v['orig'].shape[-1] == 4096]
    # edit_selfs = [v['edit'] for k,v in attn_storage.storage.items() if 'attn1' in k and v['orig'].shape[-1] == 4096]
    shape_term = shape_weight*fix_shapes(origs, edits, obj_idx)
    appearance_term = appearance_weight*fix_appearances_2(origs, orig_feats, edits, edit_feats, indices)
    # size_term = size_weight*fix_sizes(origs, edits, obj_idx)
    # position_term = position_weight*position_deltas_2(origs, edits, obj_idx, target_centroid=target_centroid)
    # self_term = self_weight*fix_selfs_2(orig_selfs, edit_selfs, t=t)
    # NOTE(review): the move term is computed over other_idx, not obj_idx — confirm intended.
    move_term = position_weight*fix_shapes_3(origs, edits, other_idx, tau=t)
    return move_term + shape_term + appearance_term

#### Resize an object

Additional experimental code.

In [None]:
def enlarge(x, scale_factor=1):
    """Upscale a flattened square map and centre-crop it back to its original size.

    `x` is `(n, h*h, d)`. The grid is interpolated by `scale_factor` (must be
    >= 1), cropped around its centre, flattened again and multiplied by `scale_factor`.
    """
    assert scale_factor >= 1
    side = int(math.sqrt(x.shape[-2]))
    half = side//2
    grid = rearrange(x, 'n (h w) d -> n d h w', h=side)
    grid = F.interpolate(grid, scale_factor=scale_factor)
    # The enlarged grid is square, so one centre coordinate serves both axes.
    centre = grid.shape[-1]//2
    lo, hi = centre - half, centre + half
    window = grid[:,:,lo:hi,lo:hi]
    return rearrange(window, 'n d h w -> n (h w) d', h=side) * scale_factor

In [None]:
def shrink(x, scale_factor=1):
    """Downscale a flattened square map, padding with zeros to keep its original size.

    `x` is `(n, h*h, d)`. The grid is pasted into the centre of a zero canvas
    `int(1/scale_factor)` times larger, the canvas is interpolated by
    `scale_factor` (must be <= 1), flattened again and multiplied by `scale_factor`.
    """
    assert scale_factor <= 1
    side = int(math.sqrt(x.shape[-2]))
    half = side//2
    grid = rearrange(x, 'n (h w) d -> n d h w', h=side)
    canvas_side = side*int(1/scale_factor)
    canvas = torch.zeros(grid.shape[0], grid.shape[1], canvas_side, canvas_side).to(grid.device)
    # The canvas is square, so one centre coordinate serves both axes.
    centre = canvas_side//2
    lo, hi = centre - half, centre + half
    canvas[:,:,lo:hi,lo:hi] = grid
    reduced = F.interpolate(canvas, scale_factor=scale_factor)
    return rearrange(reduced, 'n d h w -> n (h w) d', h=side) * scale_factor

In [None]:
def resize(x, scale_factor=1):
    """Resize a flattened square map: `enlarge` for factors above 1, `shrink` below 1.

    A factor of exactly 1 returns `x` unchanged.
    """
    # Forward the factor: enlarge/shrink default to scale_factor=1, so calling them
    # without it made every resize a no-op.
    if scale_factor > 1: return enlarge(x, scale_factor)
    if scale_factor < 1: return shrink(x, scale_factor)
    return x

In [None]:
def fix_shapes_2(orig_attns, edit_attns, indices, tau=1):
    """Mean signed shape difference against a boosted original map.

    For each layer and token the target is the shape of
    `(orig + orig + orig.max()*tau)` clipped at zero; the edited shape is
    subtracted from it. Averaged per layer, then per token, then over all indices.
    """
    per_token = []
    for token in indices:
        per_layer = []
        for layer, edit_map in enumerate(edit_attns):
            orig_slice, edit_slice = orig_attns[layer][:,:,token], edit_map[:,:,token]
            boost = orig_slice + (orig_slice.max() * tau)
            target = get_shape((orig_slice + boost).clip(min=0))
            per_layer.append((target - get_shape(edit_slice)).mean())
        per_token.append(torch.stack(per_layer).mean())
    return torch.stack(per_token).mean()

In [None]:
# def fix_shapes_3(orig_attns, edit_attns, indices, tau=fc.noop):
#     shapes = []
#     for o in indices:
#         deltas = []
#         for i in range(len(edit_attns)):
#             orig, edit = orig_attns[i][:,:,o], edit_attns[i][:,:,o]
#             if len(orig.shape) < 3: orig, edit = orig[...,None], edit[...,None]
#             delta = (tau(get_shape(orig)) - get_shape(edit)).pow(2).mean()
#             deltas.append(delta.mean())
#         shapes.append(torch.stack(deltas).mean())
#     return torch.stack(shapes).mean()

In [None]:
def resize_object_2(attn_storage, indices, t=fc.noop, relative_size=2, shape_weight=1, size_weight=1, appearance_weight=0.1, orig_feats=None, edit_feats=None, self_weight=0.1, **kwargs):
    """Experimental guidance term for resizing: shape + appearance + size losses.

    `indices` must be an `(obj_idx, other_idx)` pair. The shape term runs over
    `other_idx`, the appearance term over both sets, and the size term
    (`fix_shapes_2` with `tau=t`) over `obj_idx`. `relative_size` and
    `self_weight` are accepted but not used by the active terms.

    Raises `ValueError` when `indices` is not a pair.

    NOTE(review): `fix_shapes_2` multiplies `orig.max()` by `tau`, so `t` has to be
    a number here; the `fc.noop` default looks unusable — confirm the intended default.
    """
    # Both halves are needed below; a non-pair input otherwise failed later with an
    # UnboundLocalError or an unpacking error.
    if len(indices) != 2:
        raise ValueError('resize_object_2 expects `indices` to be an (obj_idx, other_idx) pair.')
    obj_idx, other_idx = indices
    indices = torch.cat([obj_idx, other_idx])
    origs, edits = get_attns(attn_storage)
    # orig_selfs = [v['orig'] for k,v in attn_storage.storage.items() if 'attn1' in k][-1]
    # edit_selfs = [v['edit'] for k,v in attn_storage.storage.items() if 'attn1' in k][-1]
    shape_term = shape_weight*fix_shapes(origs, edits, other_idx)
    appearance_term = appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices)
    size_term = size_weight*fix_shapes_2(origs, edits, obj_idx, tau=t)
    # self_term = self_weight*fix_selfs(orig_selfs, edit_selfs)
    return shape_term + appearance_term + size_term