In [1]:
from typing import List,Union
import torch
from torchvision import transforms
from PIL import Image
import os
import pyrallis
from  util.model import DINOHead

# from safetensors import safe_open
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    # StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import torch
from config import RunConfig
from util import ptp_utils
from util.pipeline_attend_and_excite import AttendAndExcitePipeline
from util.ptp_utils import AttentionStore
from util.pipeline_attend_and_excite import TextaulStableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer,CLIPImageProcessor

In [2]:
root = "/opt/data/private/stable_diffusion_model"
textual_name = "textual_inversion_find_new_style_13"
placeholder_token = '<style>'
nums_token = 768
#"runwayml/stable-diffusion-v1-5"
#stabilityai/stable-diffusion-2-1-base
#stabilityai/stable-diffusion-2-1
#stabilityai/stable-diffusion-xl-base-1.0
model_id = "stabilityai/stable-diffusion-2-1"
moldel_root = f'{root}/{textual_name}/model_5000.pt'
state_dict = torch.load(moldel_root)['model']
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16',map_location='cpu')
head = DINOHead(in_dim=768, out_dim=1024,cls_dim=6, nlayers=3)
image_model = torch.nn.Sequential(backbone, head)
image_model.load_state_dict(state_dict)
name =''
for x in model_id.split('/'):
    name += x+'_'
name =name+f"pca_{nums_token}.pt"
save_path = os.path.join("token_dict", name)
token_dict = torch.load(save_path)
pipe = TextaulStableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16)
pipe.unet.enable_xformers_memory_efficient_attention()
setattr(pipe, 'image_model', None)
setattr(pipe, 'token_dict', None)
pipe.image_model = image_model
pipe.token_dict = token_dict
pipe.to('cuda:0')

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

TextaulStableDiffusionPipeline {
  "_class_name": "TextaulStableDiffusionPipeline",
  "_diffusers_version": "0.20.2",
  "_name_or_path": "stabilityai/stable-diffusion-2-1",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [8]:
image_size = 224
crop_pct = 0.875
interpolation = 3
trans = transforms.Compose([
    transforms.Resize(int(image_size / crop_pct),interpolation=interpolation),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),])
image_path = 'datasets/sub_style_dataset/Baroque/adriaen-van-de-velde_agony-in-the-garden-1665.jpg'
referenceImage = Image.open(image_path)
referenceImage = referenceImage.convert("RGB")
referenceImage = trans(referenceImage)
referenceImage = referenceImage.unsqueeze(0)
referenceImage = referenceImage.cuda()
image_model.cuda() 
with torch.no_grad():
    weight,logits = image_model(referenceImage)
print(weight)
print(logits)

tensor([[-0.0330, -0.1685,  0.0798,  ...,  0.0277,  0.3194, -0.0250]],
       device='cuda:0')
tensor([[ 14.2816,  12.3414,  -8.9154,   2.7792,   7.8765, -13.5120]],
       device='cuda:0')


In [3]:
def run_on_prompt(prompt:Union[str, List[str]],
                  model: TextaulStableDiffusionPipeline,
                  controller: AttentionStore,
                  token_indices: List[int],
                  seed: torch.Generator,
                  config: RunConfig,
                  referenceImages: List[torch.Tensor] = None,
                  referenceLocattion: List[int] = None,) -> Image.Image:
    if controller is not None:
        ptp_utils.register_attention_control(model, controller)
    if isinstance(model,AttendAndExcitePipeline):
        outputs = model(prompt=prompt,
                        attention_store=controller,
                        indices_to_alter=token_indices,
                        attention_res=config.attention_res,
                        guidance_scale=config.guidance_scale,
                        generator=seed,
                        num_inference_steps=config.n_inference_steps,
                        max_iter_to_alter=config.max_iter_to_alter,
                        run_standard_sd=config.run_standard_sd,
                        thresholds=config.thresholds,
                        scale_factor=config.scale_factor,
                        scale_range=config.scale_range,
                        smooth_attentions=config.smooth_attentions,
                        sigma=config.sigma,
                        kernel_size=config.kernel_size,
                        sd_2_1=config.sd_2_1)
    else:
        outputs = model(prompt=prompt,
                        referenceImages = referenceImages,
                        referenceLocattion = referenceLocattion,
                        num_inference_steps = config.n_inference_steps,
                        guidance_scale = config.guidance_scale,
                        generator=seed)
    image = outputs.images[0]
    return image

@pyrallis.wrap()
def main(config: RunConfig):
    prompts = [
            "a photo of a dog in the style of S",
            # "a photo of the <style>_1,4k",
            # "a photo of the <style>_2,4k",
            # "a photo of the <style>_3,4k",
            # "a photo of the <style>_4,4k",
            # "a photo of the <style>_5,4k",
            # "a photo of the <style>_6,4k",
            # "a photo of the <style>_7,4k",
            # "a photo of the <style>_8,4k",
            # "a photo of the <style>_9,4k",
            # "a photo of the <object>_9 is on the grass",
            ] 
    # objiect_name = ["panda","cat","dog","anemone fish","hen","bee eater","box turtle","African elephant","rat","lion"]
    # for name in objiect_name:
    #     prompts.append(f"A {name} is drinking water,4k")   
    image_size = 224
    crop_pct = 0.875
    interpolation = 3
    trans = transforms.Compose([
        transforms.Resize(int(image_size / crop_pct),interpolation=interpolation),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),])
    image_path = 'datasets/sub_style_dataset/Baroque/adriaen-van-de-velde_agony-in-the-garden-1665.jpg'
    referenceImage = Image.open(image_path)
    referenceImage = referenceImage.convert("RGB")
    image_name = image_path.split('/')[-1].split('.')[0]
    referenceImage = trans(referenceImage)
    # image_model.cuda()
    # referenceImage = referenceImage.unsqueeze(0).cuda()
    # weight,logit = image_model(referenceImage)
    referenceImage = [referenceImage]
    token_indices = [9]
    # referenceImage =[]
    # token_indices = []
    controller = AttentionStore()
    for prompt in prompts:
        os.makedirs(f"./images/{textual_name}/{prompt}/{image_name}",exist_ok=True)
        for seed in range(0,5):
            g = torch.Generator('cuda').manual_seed(seed)
            image = run_on_prompt(prompt=prompt,
                                model=pipe,
                                controller=controller,
                                token_indices=token_indices,
                                seed=g,
                                config=config,
                                referenceImages = referenceImage,
                                referenceLocattion = token_indices)
            image.save(f"./images/{textual_name}/{prompt}/{image_name}/{seed}.png")


In [4]:
main()

usage: ipykernel_launcher.py [-h] [--config_path str] [--prompt str]
                             [--sd_2_1 str] [--token_indices str]
                             [--seeds str] [--n_inference_steps str]
                             [--guidance_scale str] [--max_iter_to_alter str]
                             [--attention_res str] [--run_standard_sd str]
                             [--thresholds str] [--scale_factor str]
                             [--scale_range str] [--smooth_attentions str]
                             [--sigma str] [--kernel_size str]
                             [--save_cross_attention_maps str]
ipykernel_launcher.py: error: unrecognized arguments: --f=/root/.local/share/jupyter/runtime/kernel-v2-6024DpgIWwROocUR.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
