## Copyright 2022 Google LLC. Double-click for license information.

In [1]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Prompt-to-Prompt with Stable Diffusion

In [2]:
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_utils
import seq_aligner



For loading the Stable Diffusion using Diffusers, follow the instuctions https://huggingface.co/blog/stable_diffusion and update ```MY_TOKEN``` with your token.
Set ```LOW_RESOURCE``` to ```True``` for running on 12GB GPU.

In [3]:
LOW_RESOURCE = True 
NUM_DIFFUSION_STEPS = 20
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 200
model_id = "runwayml/stable-diffusion-v1-5"
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)
tokenizer = ldm_stable.tokenizer

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
NUM_DIFFUSION_STEPS = 25

In [5]:
from diffusers import DPMSolverMultistepScheduler

ldm_stable.safety_checker = None
ldm_stable.scheduler = DPMSolverMultistepScheduler.from_config(ldm_stable.scheduler.config)
ldm_stable.enable_attention_slicing()
ldm_stable.enable_vae_slicing()
ldm_stable.enable_model_cpu_offload()
ldm_stable.vae.to("cpu").float()  # float32 for CPU compatibility

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

## Prompt-to-Prompt Attnetion Controllers
Our main logic is implemented in the `forward` call in an `AttentionControl` object.
The forward is called in each attention layer of the diffusion model and it can modify the input attnetion weights `attn`.

`is_cross`, `place_in_unet in ("down", "mid", "up")`, `AttentionControl.cur_step` help us track the exact attention layer and timestamp during the diffusion iference.


In [6]:
class LocalBlend:

    def __call__(self, x_t, attention_store):
        k = 1
        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
        maps = torch.cat(maps, dim=1)
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = mask.gt(self.threshold)
        mask = (mask[:1] + mask[1:]).float()
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t
       
    def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3):
        alpha_layers = torch.zeros(len(prompts),  1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if type(words_) is str:
                words_ = [words_]
            for word in words_:
                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
                alpha_layers[i, :, :, :, :, ind] = 1
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold


class AttentionControl(abc.ABC):
    
    def step_callback(self, x_t):
        return x_t
    
    def between_steps(self):
        return
    
    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if LOW_RESOURCE else 0
    
    @abc.abstractmethod
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
    
    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

class EmptyControl(AttentionControl):
    
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        return attn
    
    
class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [],  "mid_self": [],  "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention


    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

        
class AttentionControlEdit(AttentionStore, abc.ABC):
    
    def step_callback(self, x_t):
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t
        
    def replace_self_attention(self, attn_base, att_replace):
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace
    
    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        raise NotImplementedError
    
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_repalce = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
                attn[1:] = attn_repalce_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn
    
    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        super(AttentionControlEdit, self).__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

class AttentionReplace(AttentionControlEdit):

    def replace_cross_attention(self, attn_base, att_replace):
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
      
    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
        

class AttentionRefine(AttentionControlEdit):

    def replace_cross_attention(self, attn_base, att_replace):
        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
        return attn_replace

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])


class AttentionReweight(AttentionControlEdit):

    def replace_cross_attention(self, attn_base, att_replace):
        if self.prev_controller is not None:
            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
        return attn_replace

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
                local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
        super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.equalizer = equalizer.to(device)
        self.prev_controller = controller


def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
                  Tuple[float, ...]]):
    if type(word_select) is int or type(word_select) is str:
        word_select = (word_select,)
    equalizer = torch.ones(len(values), 77)
    values = torch.tensor(values, dtype=torch.float32)
    for word in word_select:
        inds = ptp_utils.get_word_inds(text, word, tokenizer)
        equalizer[:, inds] = values
    return equalizer


In [7]:
from PIL import Image

def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    ptp_utils.view_images(np.stack(images, axis=0))
    

def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    ptp_utils.view_images(np.concatenate(images, axis=1))

In [8]:
def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None):
    if run_baseline:
        print("w.o. prompt-to-prompt")
        images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
        print("with prompt-to-prompt")
    images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, low_resource=LOW_RESOURCE)
    ptp_utils.view_images(images)
    return images, x_t

In [9]:
# p_cross = random.uniform(0.0, 1.0)
# p_self = random.uniform(0.0, 1.0)

# prompts = ["photograph of a girl riding a horse", "photograph of a girl riding a dick"] 

# generate_and_save_image(prompts, p_cross, p_self)


## CLIP based filtering

In [13]:
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import CLIPProcessor, CLIPModel
import torch

In [14]:
clip_model_name='openai/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(
    clip_model_name,
    device_map="auto",
    offload_folder="./offload",  # local folder where CPU-offloaded weights go
    offload_state_dict=True,   # offload weights to CPU when not in use
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)

processor = CLIPProcessor.from_pretrained(clip_model_name)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [15]:
device = model.device.type
device

'cuda'

In [16]:
def get_image_embedding(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        embeddings = model.get_image_features(**inputs)
    return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

def get_text_embedding(text):
    inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        embeddings = model.get_text_features(**inputs)
    return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)


In [17]:
def cosine_similarity(a, b):
    return (a @ b.T).item()

In [18]:

def respects_filter(caption_before, image_before, caption_after, image_after):
    img_emb_before = get_image_embedding(image_before)
    img_emb_after = get_image_embedding(image_after)
    txt_emb_before = get_text_embedding(caption_before)
    txt_emb_after = get_text_embedding(caption_after)
    
    # Image-Image similarity
    img_img_sim = cosine_similarity(img_emb_before, img_emb_after)
    
    # Image-Caption similarities
    img_txt_sim_before = cosine_similarity(img_emb_before, txt_emb_before)
    img_txt_sim_after = cosine_similarity(img_emb_after, txt_emb_after)
    
    # Directional similarity
    # Vector difference in image embeddings and text embeddings
    img_dir = img_emb_after - img_emb_before
    txt_dir = txt_emb_after - txt_emb_before
    img_dir = img_dir / img_dir.norm()
    txt_dir = txt_dir / txt_dir.norm()
    directional_sim = cosine_similarity(img_dir, txt_dir)

    ## debug statement
    recorded_sim = {
        'img_img_sim': img_img_sim,
        'img_txt_sim_input': img_txt_sim_before,
        'img_txt_sim_output': img_txt_sim_after,
        'directional_sim': directional_sim
    }    
    if (
        img_img_sim >= 0.75 and
        img_txt_sim_before >= 0.2 and
        img_txt_sim_after >= 0.2 and
        directional_sim >= 0.05
    ):
        recorded_sim['is_similar'] = True
        return recorded_sim
    else:
        recorded_sim['is_similar'] = False
        return recorded_sim

## Image generation using prompt-to-prompt method

In [16]:
import random
import gc
import torch

generator = torch.Generator(device="cuda").manual_seed(1988)

def clean_cuda():
    torch.cuda.empty_cache()
    gc.collect()
    ldm_stable.enable_model_cpu_offload()




In [17]:
def save_metadata_records(file_metadata, my_dict):
    with open(file_metadata, 'a', encoding='utf-8') as f:
        f.write(json.dumps(my_dict) + '\n')  # Append as one JSON object per line

In [18]:
def generate_and_save_image(caption_data, p_cross, p_self, new_guidance_scale):
    latents = ldm_stable.prepare_latents(
    batch_size=1,
    height=512,
    width=512,
    generator=generator,
    num_channels_latents=4,  # typically 4 for Stable Diffusion latents
    dtype=torch.float32,
    device=generator.device,
    )

    prompts = [caption_data['input_caption'], caption_data['output_caption']]

    controller = AttentionRefine(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=p_cross, self_replace_steps=p_self)
    images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller, latent=latents, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=new_guidance_scale, generator=generator, low_resource=LOW_RESOURCE)

    
    inp_img = Image.fromarray(images[0])
    out_img = Image.fromarray(images[1])

    recorded_sim = respects_filter(prompts[0], inp_img, prompts[1], out_img)
    # only use it as a dataset when images are similar
    if recorded_sim['is_similar']:
        caption_data['recorded_sim'] = recorded_sim
        caption_data['input_image_path'] = './input_images/'+str(caption_data['surrogate_id'])+'.webp'
        caption_data['output_image_path'] = './output_images/'+str(caption_data['surrogate_id'])+'.webp'

        save_metadata_records('meta_data_file.json', caption_data)

        inp_img.save(caption_data['input_image_path'], format="WEBP", quality=20, method=6)    
        out_img.save(caption_data['output_image_path'], format="WEBP", quality=20, method=6)    

    # clean up gpu memory after every image generation
    clean_cuda()
    clean_cuda()

In [19]:
# p_cross = random.uniform(0.0, 1.0)
# p_self = random.uniform(0.0, 1.0)

# c =   {
#     "input_caption": "Flooding Painting - Storm Malta by John or Giovanni Schranz",
#     "edit_instruction": "Add a sailboat in the distance",
#     "output_caption": "Flooding Painting - Storm Malta by John or Giovanni Schranz with a sailboat in the distance",
#     "id": 1
#   }
# c['surrogate_id'] = surrogate_id
# surrogate_id += 1
# generate_and_save_image(c, p_cross, p_self, 2.5)

In [20]:
clean_cuda()

## Read the GPT generated input/output prompt and create paired images

In [21]:
# Read JSON from file
import json
file_path = '../gpt-prompt-merged-pilot.json'
with open(file_path, 'r', encoding='utf-8') as file:
    all_captions = json.load(file)

In [22]:
all_captions[127]

{'input_caption': "Elisabeth Sonrel, Les Rameaux (Palm Sunday, 1897). Mireille Mosler, Master Drawings New York, 26 January-2 February. Estimate–$100,000-$125,000: The Art Nouveau painter and illustrator Elisabeth Sonrel (1874-1953) was among the few women selected to hang her paintings in Paris's Salon in the 19th century. Sonrel, who studied at the École des Beaux-Arts, desired to be known as more than a mere illustrator of stationery, posters and books, and in 1897 her wish was granted when Les Rameaux (Palm Sunday) was included in the Salon. This watercolour, of the same composition, will be shown for the first time in the US by Mireille Mosler during Master Drawings New York",
 'edit_instruction': 'Replace the image with Les Rameaux (Palm Sunday) by Elisabeth Sonrel.',
 'output_caption': 'Les Rameaux (Palm Sunday) by Elisabeth Sonrel exhibited at Master Drawings New York',
 'id': 127}

In [23]:
guidance_values = [3.2, 5.5, 7.5]
surrogate_id = 757
all_captions = all_captions[254:]

In [24]:
print(all_captions[0])
print(surrogate_id)

{'input_caption': 'House shed dogs pond nature wallpapers house shed dogs for Wallpaper home photos', 'edit_instruction': 'Remove the dogs from the image', 'output_caption': 'House shed pond nature wallpapers house for Wallpaper home photos', 'id': 254}
757


In [25]:
for caption in all_captions:
    for gv in guidance_values:
        p_cross = random.uniform(0.0, 1.0)
        p_self = random.uniform(0.0, 1.0)
        caption['surrogate_id'] = surrogate_id
        surrogate_id = surrogate_id+1
        # strip input_caption
        caption['input_caption'] = caption['input_caption'][:75]
        caption['output_caption'] = caption['output_caption'][:75]
        generate_and_save_image(caption, p_cross, p_self, gv)

  batch_size,  model.unet.in_channels, height // 8, width // 8).to(model.device)


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
clean_cuda()

In [41]:
## Merge data gathered from all different sessions of local and colab

In [40]:
'''
I have run different sessions in colab, in local, and on different set of prompts
This file is for merging all the values into one json file
'''

import json
import os
from PIL import Image


file_list = [
    '../instruct_dataset/metadata.json'
]

# Load and collect all entries
all_data = []
for file_path in file_list:
    with open(file_path, "r") as f:
        data = json.load(f)
        for record in data:
            base_name1 = os.path.basename(record['input_image_path'])
            record['input_image_path'] = f'./input_images/{base_name1}'
            base_name2= os.path.basename(record['output_image_path'])
            record['output_image_path'] = f'./output_images/{base_name2}'
            inp_img_path = f'../instruct_dataset/input_images/{record["surrogate_id"]}.webp'
            out_img_path = f'../instruct_dataset/output_images/{record["surrogate_id"]}.webp'
            inp_img = Image.open(inp_img_path)
            out_img = Image.open(out_img_path)

            if 'recorded_sim' not in record:
                # calculate recorded sim
                recorded_sim = respects_filter(record['input_caption'][:75], inp_img, record['output_caption'][:75], out_img)
                record['recorded_sim'] = recorded_sim

            all_data.append(record)
            if len(all_data) % 500 == 0:
                print("1000 done")

# Save to a single metadata.json file
with open("metadata1.json", "w") as out_f:
    json.dump(all_data, out_f, indent=2)

print("Saved combined metadata to metadata.json")


1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
1000 done
Saved combined metadata to metadata.json


In [43]:
import math
new_data = []
with open('metadata1.json', "r") as f:
    data = json.load(f)
    for record in data:
        if 'recorded_sim' in record:
            ds = record['recorded_sim']['directional_sim']
            if math.isnan(ds) or ds <= 0.05:
                continue
            new_data.append(record)

with open("metadata2.json", "w") as out_f:
    json.dump(all_data, out_f, indent=2)
         

In [44]:
len(new_data)

5256

In [42]:
len(all_data)

5936