In [1]:
!pip install -r requirements.txt 

In [1]:
import os
from typing import Any, List, Optional, Tuple, Union

import pandas as pd
import torch
from PIL import Image
from tqdm.auto import tqdm

from pipeline import Pipeline
from utils import AttentionStore, register_attention_control

In [2]:
# Model identifier passed to `Pipeline.from_pretrained` in `load_model`.
sd_version = "CompVis/stable-diffusion-v1-4"
# Token positions forwarded to the pipeline as `indices_to_alter`.
# NOTE(review): the same indices are used for every prompt in the CSV - confirm that is intended.
token_indices = [2, 5]
# Seeds the per-prompt torch.Generator and is also forwarded to the pipeline as `seed`.
seed = 42

# Input / output locations.
output_path = "./results"
prompts_address = "../our prompts/prompts/prompts.csv"

# Settings forwarded verbatim to the pipeline call in `run_on_prompt`.
n_inference_steps = 50
guidance_scale = 7.5
max_iter_to_alter = 30
attention_res = 16
apply_loss = True

# Analysis switches; when do_analysis is on, analysis_steps also becomes the
# "step" column of each analysis.csv.
do_analysis = False
analysis_steps = [step for step in range(51)]  # 0..50 inclusive
scale_factor = 20
scale_range = (1.0, 0.5)
save_cross_attention_maps = True

In [3]:
def load_model(sd_version):
    """Load ``sd_version`` with ``Pipeline.from_pretrained`` and move it to a device.

    The device is ``cuda:0`` when ``torch.cuda.is_available()`` is true, otherwise
    the CPU. Returns whatever ``Pipeline.to(device)`` returns.
    """
    use_gpu = torch.cuda.is_available()
    target_device = torch.device("cuda:0" if use_gpu else "cpu")
    pipe = Pipeline.from_pretrained(sd_version)
    return pipe.to(target_device)

def run_on_prompt(prompt: Union[str, List[str]],
                  model: Pipeline,
                  controller: Optional[AttentionStore],
                  generator: torch.Generator,
                  seed: int) -> Tuple[Image.Image, Any]:
    """Generate one image for ``prompt`` using the notebook-level settings.

    Every other pipeline argument (token indices, guidance scale, step counts,
    analysis switches, output path, ...) is read from the notebook's
    configuration globals at call time.

    Parameters
    ----------
    prompt : str or list of str
        Text prompt forwarded to ``model``; the driver loop passes a single string.
    model : Pipeline
        Pipeline returned by ``load_model``.
    controller : AttentionStore or None
        When not None, it is hooked into ``model`` with
        ``register_attention_control`` before generation. It is forwarded to the
        pipeline as ``attention_store`` in either case.
    generator : torch.Generator
        Seeded RNG forwarded to the pipeline.
    seed : int
        Forwarded to the pipeline as ``seed``.

    Returns
    -------
    tuple
        ``(image, analysis_dict)``. ``image`` is ``outputs.images[0]`` (assumed to
        be a PIL image). ``analysis_dict`` is the second value returned by the
        pipeline call; presumably per-step analysis data when ``do_analysis`` is
        on, but confirm against ``pipeline.Pipeline``.
    """
    if controller is not None:
        # Hook this prompt's attention store into the model's attention layers.
        register_attention_control(model, controller)
    outputs, analysis_dict = model(prompt=prompt,
                                   attention_store=controller,
                                   indices_to_alter=token_indices,
                                   attention_res=attention_res,
                                   guidance_scale=guidance_scale,
                                   generator=generator,
                                   seed=seed,
                                   num_inference_steps=n_inference_steps,
                                   max_iter_to_alter=max_iter_to_alter,
                                   apply_loss=apply_loss,
                                   do_analysis=do_analysis,
                                   analysis_steps=analysis_steps,
                                   scale_factor=scale_factor,
                                   scale_range=scale_range,
                                   save_cross_attention_maps=save_cross_attention_maps,
                                   output_path=output_path)
    image = outputs.images[0]
    return image, analysis_dict

def save_image(image, output_path, prompt, seed):
    """Save ``image`` as ``<output_path>/<prompt>/<int(seed)>/image.png``.

    The directory is created if needed. ``prompt`` is used verbatim as a path
    component, so a prompt containing ``/`` produces nested directories.
    """
    target_dir = f"{output_path}/{prompt}/{int(seed)}/"
    os.makedirs(target_dir, exist_ok=True)
    image.save(target_dir + "image.png")

def save_analysis(analysis_dict, output_path, prompt, seed, steps=None):
    """Write analysis values to ``<output_path>/<prompt>/<int(seed)>/analysis.csv``.

    Parameters
    ----------
    analysis_dict : mapping
        Column name -> values, passed to ``pd.DataFrame``. Columns are expected to
        line up with ``steps``; pandas raises ValueError on mismatched lengths.
        The mapping is not modified.
    output_path : str
        Root results directory.
    prompt : str
        Used verbatim as a directory name.
    seed : int or float
        Truncated with ``int()`` for the directory name.
    steps : sequence, optional
        Values for the trailing ``"step"`` column. Defaults to the notebook-level
        ``analysis_steps``.
    """
    if steps is None:
        steps = analysis_steps
    target_dir = f"{output_path}/{prompt}/{int(seed)}/"
    os.makedirs(target_dir, exist_ok=True)
    # Build a new mapping rather than assigning into analysis_dict, so the
    # caller's dict is left untouched; "step" is still the last column.
    columns = {**analysis_dict, "step": steps}
    pd.DataFrame(data=columns).to_csv(target_dir + "analysis.csv", index=False)

In [None]:
# Load the pipeline once; the generation loop below reuses it for every prompt.
sd_model = load_model(sd_version=sd_version)

In [None]:
# Make sure the results root exists (makedirs also creates missing parents).
os.makedirs(output_path, exist_ok=True)

# Sorted so the processing order, and therefore the printed numbering, is deterministic.
prompts = sorted(pd.read_csv(prompts_address)["prompt"].tolist())
# Upper bound on how many prompts to run; lower it for a quick smoke test.
n_max = 1000000

# Put the generator on the same kind of device load_model chooses, so the loop
# also runs on CPU-only machines instead of failing on torch.Generator('cuda').
generator_device = "cuda" if torch.cuda.is_available() else "cpu"

for index, prompt in enumerate(tqdm(prompts[:n_max], desc="prompts")):
    # tqdm.write emits the line without corrupting the progress-bar output.
    tqdm.write(f"{index+1}: {prompt}")
    # Fresh, identically seeded generator per prompt, so its RNG state does not
    # depend on how many prompts were processed before this one.
    generator = torch.Generator(generator_device).manual_seed(seed)
    # New attention store per prompt; run_on_prompt registers it on the model.
    controller = AttentionStore()
    image, analysis_dict = run_on_prompt(prompt=prompt,
                                         model=sd_model,
                                         controller=controller,
                                         generator=generator,
                                         seed=seed)
    save_image(image=image, output_path=output_path, prompt=prompt, seed=seed)
    if do_analysis:
        save_analysis(analysis_dict, output_path, prompt, seed)
