In [None]:
!pip install diffusers["torch"] transformers accelerate
!pip install git+https://github.com/huggingface/diffusers
!pip install einops
!pip install bayesian-optimization

In [None]:
!unzip VOC2012.zip
!unzip Cityscape.zip
!unzip Kvasir-SEG.zip
!unzip Vaihingen.zip

In [2]:
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderKL, DiffusionPipeline
import numpy as np
import abc

LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# code for store attention
class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if LOW_RESOURCE else 0

    @abc.abstractmethod
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

class EmptyControl(AttentionControl):

    def forward (self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [],  "mid_self": [],  "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 64 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

In [3]:
from PIL import Image
import cv2

## Visualization code utils
def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
    return img

In [4]:
# code for aggregaring attention
@torch.no_grad()
def aggregate_all_attention(prompts, attention_store: AttentionStore, from_where: List[str], is_cross: bool, select: int):
    attention_maps = attention_store.get_average_attention()
    att_8 = []
    att_16 = []
    att_32 = []
    att_64 = []
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == 8*8:
                cross_maps = item.reshape(len(prompts), -1, 8, 8, item.shape[-1])[select]
                att_8.append(cross_maps)
            if item.shape[1] == 16*16:
                cross_maps = item.reshape(len(prompts), -1, 16, 16, item.shape[-1])[select]
                att_16.append(cross_maps)
            if item.shape[1] == 32*32:
                cross_maps = item.reshape(len(prompts), -1, 32, 32, item.shape[-1])[select]
                att_32.append(cross_maps)
            if item.shape[1] == 64*64:
                cross_maps = item.reshape(len(prompts), -1, 64, 64, item.shape[-1])[select]
                att_64.append(cross_maps)

    # print(len(att_8), len(att_16), len(att_32), len(att_64)) # base: 0, 60, 10, 0, refiner: 4, 20, 20, 0
    atts = []
    for att in [att_8, att_16, att_32, att_64]:
        if len(att) == 0:
          continue
        att = torch.cat(att, dim=0)
        att = att.sum(0) / att.shape[0]
        atts.append(att.cpu())
    return atts


@torch.no_grad()
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


# visualize cross att
def show_cross_attention(prompts, tokenizer, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    j = 0
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.float().numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        if decoder(int(tokens[j])) == "++":
            j += 1
        image = text_under_image(image, decoder(int(tokens[j])))
        images.append(image)
        j+=1
        if j >= len(tokens):
            break
    view_images(np.stack(images, axis=0))


# visualize self att
def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).float().numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))

In [5]:
def encode_imgs(imgs, vae):
    # imgs: [B, 3, H, W]
    imgs = 2 * imgs - 1
    tmp = vae.encode(imgs)
    posterior = vae.encode(imgs).latent_dist.mean
    latents = posterior * 0.18215
    return latents


# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)  # 固定随机种子（CPU）
    if torch.cuda.is_available():  # 固定随机种子（GPU)
        torch.cuda.manual_seed(seed)  # 为当前GPU设置
        torch.cuda.manual_seed_all(seed)  # 为所有GPU设置
    np.random.seed(seed)  # 保证后续使用random函数时，产生固定的随机数
    torch.backends.cudnn.benchmark = False  # GPU、网络结构固定，可设置为True
    torch.backends.cudnn.deterministic = True  # 固定网络结构


# cam visual_code
def show_cam_on_image(img, mask):
    mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(img.size[1],img.size[0]), mode='bilinear', align_corners=False).squeeze().squeeze()
    img = np.float32(img) / 255.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    cam = np.uint8(255 * cam)
    return cam

In [6]:
## ptp utils function
from einops import rearrange

def init_latent(latent, model, height, width, generator, batch_size):
    if latent is None:
        print('enter here')
        latent = torch.randn(
            (1, model.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size,  model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents

@torch.no_grad()
def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            # hidden_states is nan (becasue the vae of SDXL does not support f16 precision)
            x = hidden_states
            context = encoder_hidden_states
            mask = attention_mask
            batch_size = len(x)
            h = self.heads
            # x torch.Size([2, 1024, 640])
            # to_q weight torch.Size([640, 640])
            # k torch.Size([2, 1024, 640])
            # v torch.Size([2, 1024, 640])
            # print('x', x.shape) # torch.Size([2, 1024, 640])
            # print('to_q weight', self.to_q.weight.shape) # torch.Size([640, 640])
            q = self.to_q(x)
            # print('q', q.shape)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            # print('k', k.shape) # torch.Size([2, 1024, 640])
            v = self.to_v(context)
            # print('v', v.shape) # torch.Size([2, 1024, 640])
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            ## controller foward function saving the attention map in self.step_store
            attn = controller(attn, is_cross, place_in_unet)

            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return to_out(out)

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count


@torch.no_grad()
def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False, height=None, width=None, base=True):
    if low_resource:
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        # 7. Prepare added time ids & embeddings
        # context = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds]

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds
        ) = context

        add_text_embeds = pooled_prompt_embeds
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)
        if base:
          add_time_ids = model._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=context[0].dtype,
            text_encoder_projection_dim=model.text_encoder_2.config.projection_dim, # 1280
          )
          negative_add_time_ids = add_time_ids
        else:
          add_time_ids, add_neg_time_ids = model._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            6.0, # aesthetic_score default
            2.5, # negative_aesthetic_score default
            original_size, # negative_original_size
            crops_coords_top_left, # negative_crops_coords_top_left,
            target_size, # negative_target_size,
            dtype=context[0].dtype,
            text_encoder_projection_dim=model.text_encoder_2.config.projection_dim,
          )

        # do_classifier_free_guidance
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        if base:
          add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
        else:
          add_neg_time_ids = add_neg_time_ids
          add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(model.device)
        add_text_embeds = add_text_embeds.to(model.device)
        add_time_ids = add_time_ids.to(model.device)

        latents_input = torch.cat([latents] * 2)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        noise_pred = model.unet(
            latents_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond = None,
            added_cond_kwargs=added_cond_kwargs
        )["sample"]

    return noise_pred


## text to image custom pipeline
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
    noise_sample_num=1,
    height: int = 512,
    width: int = 512,
    base: bool = True
):
    # 1. check input:
    #    height and width must be divisible by 8
    #    prompt_embeds and negative prompt_embeds must have the same shape
    #    pooled_prompt_embed and prompt_embeds are generated with the same text encoder
    #    negative_pooled_prompt_embed and negative_prompt_embeds are generated with the same text encoder
    # 2. Define call parameters
    register_attention_control(model, controller)
    height = height or model.default_sample_size * model.vae_scale_factor
    width = width or model.default_sample_size * model.vae_scale_factor
    batch_size = len(prompt)

    # 3. Encode input prompt (refereence: encode_prompt function)

    tokenizers = [model.tokenizer, model.tokenizer_2] if model.tokenizer is not None else [model.tokenizer_2]
    text_encoders = [model.text_encoder, model.text_encoder_2] if model.text_encoder is not None else [model.text_encoder_2]
    prompts = [prompt, prompt]
    prompt_embeds_list = []
    max_length = model.tokenizer.model_max_length if model.tokenizer is not None else model.tokenizer_2.model_max_length
    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
      text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
      )

      text_input_ids = text_inputs.input_ids
      untruncated_ids = tokenizer(prompt, padding='longest', return_tensors="pt").input_ids

      prompt_embeds = text_encoder(text_input_ids.to(model.device), output_hidden_states=True)
      pooled_prompt_embeds = prompt_embeds[0]
      # clip_skip is None
      prompt_embeds = prompt_embeds.hidden_states[-2]

      prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    negative_prompt = [""] * batch_size
    uncond_tokens = [negative_prompt, negative_prompt]
    negative_prompt_embeds_list = []

    for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
      max_length = prompt_embeds.shape[1]
      uncond_input = tokenizer(
          negative_prompt,
          padding="max_length",
          max_length=max_length,
          truncation=True,
          return_tensors="pt"
      )

      negative_prompt_embeds = text_encoder(
          uncond_input.input_ids.to(model.device),
          output_hidden_states=True
      )

      negative_pooled_prompt_embeds = negative_prompt_embeds[0]
      negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

      negative_prompt_embeds_list.append(negative_prompt_embeds)

    negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    prompt_embeds = prompt_embeds.to(dtype=model.text_encoder_2.dtype, device=model.device)
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=model.text_encoder_2.dtype, device=model.device)

    context = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds]

    # 4. Prepare timesteps
    model.scheduler.set_timesteps(num_inference_steps)
    # 5. Prepare latent variables
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    latents = latent.squeeze(1).to(model.device)

    latents = diffusion_step(model, controller, latents, context, num_inference_steps, guidance_scale, low_resource, height, width, base)

    return None, None

In [7]:
import torch.nn.functional as F

# main function
# Stable diffusion
def generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos, device, is_self=True, is_multi_self=False, is_cross_norm=True, weight=[0.8, 0.2], height=None, width=None, verbose=False, base=True):
    ## pos: position of the target class word int he prompt
    controller.reset()
    g_cpu = torch.Generator(4307)
    t = int(t)
    latents_noisy = ldm_stable.scheduler.add_noise(input_latent, noise, torch.tensor(t, device=device))
    images, x_t = text2image_ldm_stable(ldm_stable, prompts, controller, latent=latents_noisy, num_inference_steps=t, guidance_scale=GUIDANCE_SCALE, generator=g_cpu, low_resource=LOW_RESOURCE, height=height, width=width, base=base)
    layers = ("mid", "up", "down")
#     cross attention:  torch.Size([16, 16, 77])
#     cross attention:  torch.Size([32, 32, 77])
    cross_attention_maps = aggregate_all_attention(prompts, controller, layers, True, 0)
#     self attention:  torch.Size([16, 16, 256])
#     self attention:  torch.Size([32, 32, 1024])

#     self attention: torch.Size([8, 8, 64]) refiner
#     self attention: torch.Size([16, 16, 256]) refiner
#     self attention: torch.Size([32, 32, 1024]) refiner
    self_attention_maps = aggregate_all_attention(prompts, controller, ("up", "mid", "down"), False, 0)

    imgs = []
    ## res: resolution
    resolution_range = [16, 32] if base else [8, 16, 32]

    for idx, res in enumerate(resolution_range):
        out_att = cross_attention_maps[idx].permute(2,0,1).float()
        if is_cross_norm:
            att_max = torch.amax(out_att, dim=(1,2), keepdim=True)
            att_min = torch.amin(out_att, dim=(1,2), keepdim=True)
            out_att = (out_att - att_min) / (att_max - att_min)
        if is_multi_self:
            self_att = self_attention_maps[idx].view(res * res, res * res).float()
            self_att = self_att / self_att.max()
            out_att = torch.matmul(self_att.unsqueeze(0),out_att.view(-1,res*res,1)).view(-1,res,res)
        if res != resolution_range[-1]:
            out_att = F.interpolate(out_att.unsqueeze(0), size=(resolution_range[-1], resolution_range[-1]), mode='bilinear', align_corners=False).squeeze()
        ## 16*16: 0.8, 32*32: 0.2
        imgs.append(out_att * weight[idx])

    # aggregated cross attention map
    cross_att_map = torch.stack(imgs).sum(0)[pos].mean(0).view(resolution_range[-1]*resolution_range[-1], 1)
    # refine cross attention map with self attention map
    if is_self and not is_multi_self:
        self_att = self_attention_maps[-1].view(resolution_range[-1]*resolution_range[-1],resolution_range[-1]*resolution_range[-1]).float()
        self_att = self_att / self_att.max()
        for i in range(1):
            cross_att_map = torch.matmul(self_att, cross_att_map)
    # res here is the highest resulution iterated in previous for loop, 64
    att_map = cross_att_map.view(res, res)
    att_map = F.interpolate(att_map.unsqueeze(0).unsqueeze(0), size=(512,512), mode='bilinear', align_corners=False).squeeze().squeeze()
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
    att_map = F.sigmoid(8 * (att_map - 0.4))
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
    if verbose:
      att_map_map = Image.fromarray((att_map.cpu().detach().numpy() * 255).astype(np.uint8),mode="L")
      display(att_map_map)
      tokenizer = ldm_stable.tokenizer if ldm_stable.tokenizer is not None else ldm_stable.tokenizer_2
      for res in resolution_range:
        print("{}x{} cross att map".format(res, res))
        show_cross_attention(prompts, tokenizer, controller, res=res, from_where=layers)

    return att_map

In [8]:
from torchvision import transforms

def stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=[0.8, 0.2], t=100, base=True, prompt=None):
  ## img_path: path to the target image
  ## cls name: taget class in the prompt
  ## device: device of stable diffusion model
  ## blip device: device of BLIP model
  ## processor: BLIP processot
  ## model: BLIP model
  ## vae: vae of the stable diffusion model
  with torch.no_grad():
    same_seeds(3407)

    input_img = Image.open(img_path).convert("RGB")

    trans = []
    trans.append(transforms.ToTensor())
    trans = transforms.Compose(trans)

    img_tensor = (trans(input_img).unsqueeze(0)).to(device)
    rgb_512 = F.interpolate(img_tensor, (512, 512), mode='bilinear', align_corners=False).bfloat16()

    vae = ldm_stable.vae
    input_latent = encode_imgs(rgb_512, vae)
    # print('input latent', input_latent)
    noise = torch.randn_like(input_latent).to(device)
    raw_image = input_img
    if prompt is None:
      text = f"a photograph of {cls_name}"
      inputs = processor(raw_image,text,return_tensors="pt").to(blip_device) # processor: Blip processor

      # use blip and "++" emphasizing semantic information of target categories
      out = model.generate(**inputs)
      texts = processor.decode(out[0], skip_special_tokens=True)
      texts = text +"++"+ texts[len(text):] # ", highly realistic, artsy, trending, colorful"
    else:
      texts = prompt

    # weight is the weight of different layer's cross attn
    # pos is the position of target class word in the sentence, in "a photograph of plane" (plane)'s position is 4
    # t is the denoising step, usually set between 50 to 150
    prompts = [texts]
    # print("**** blip_prompt: "+texts+"****")
    pos = [4]   # pos of targer class word
    controller = AttentionStore()
    height = 512
    width = 512
    mask = generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos, device, is_self=True, is_multi_self=False, is_cross_norm=True, weight=weight, height=height, width=width, verbose=verbose, base=base)
    mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(raw_image.size[1],raw_image.size[0]), mode='bilinear', align_corners=False).squeeze().squeeze()
    # print(mask.shape, raw_image.size)
    if verbose:
      cam = show_cam_on_image(raw_image, mask)
      print("visual_cam")
      display(Image.fromarray(cam[:,:,::-1]))
    del img_tensor
    del rgb_512
    del noise
    if prompt is None:
      del inputs
    torch.cuda.empty_cache()
    return mask, texts

In [9]:
import os
import shutil
import time

import json
import time, os, shutil

def domain_test(processor, model, ldm_stable, refiner, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, base_weight, refiner_weight, base_t, refiner_t, alpha, augmented_label=False, thres_list=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]):
  start = time.time()
  if os.path.isdir(result_dir):
    shutil.rmtree(result_dir)
  os.mkdir(result_dir)
  os.mkdir(os.path.join(result_dir, 'mask'))
  for thres in thres_list:
    os.mkdir(os.path.join(result_dir, '{}'.format(thres)))
  cls_arr_dir = images_dir.replace("images", "class_array")
  augmented_label_path = images_dir.replace("images", augmented_label_file)
  # segmentations_dir = images_dir.replace("images", "segmentations")
  with open(augmented_label_path, 'r') as f:
    label_data = json.load(f)

  for img_file in os.listdir(images_dir):
    if not(img_file.endswith('.png') or img_file.endswith('.tif') or img_file.endswith('.jpg')):
      continue
    img_path = os.path.join(images_dir, img_file)
    # print(">>> ", img_path)
    seg_classes = label_data[img_path]

    for cls_name in seg_classes.keys():
      mask, prompt = stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=base_weight, t=base_t, base=True)
      mask_refiner, prompt = stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, refiner, verbose=False, weight=refiner_weight, t=refiner_t, base=False, prompt=prompt)
      mask = alpha * mask + (1 - alpha) * mask_refiner
      with open(os.path.join(result_dir, 'mask', '{}_{}.npy'.format(img_file.split('.')[0], cls_name)), 'wb') as f:
        np.save(f, mask)
      for mask_threshold in thres_list:
        mask_binary = np.where(mask > mask_threshold, 255, 0)
        mask_binary_img = Image.fromarray(mask_binary.astype(np.uint8))
        mask_binary_img.save(os.path.join(result_dir, '{}'.format(mask_threshold), '{}_{}.png'.format(img_file.split('.')[0], cls_name)))

      if augmented_label:
        for aug_cls_name in seg_classes[cls_name]:
          mask, prompt = stable_diffusion_inference(img_path, aug_cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=base_weight, t=base_t, base=True)
          mask_refiner, prompt = stable_diffusion_inference(img_path, aug_cls_name, device, blip_device, processor, model, refiner, verbose=False, weight=refiner_weight, t=refiner_t, base=False, prompt=prompt)
          with open(os.path.join(result_dir, 'mask', '{}_{}.npy'.format(img_file.split('.')[0], aug_cls_name)), 'wb') as f:
            np.save(f, mask)
          for mask_threshold in thres_list:
            mask_binary = np.where(mask > mask_threshold, 255, 0)
            mask_binary_img = Image.fromarray(mask_binary.astype(np.uint8))
            mask_binary_img.save(os.path.join(result_dir, '{}'.format(mask_threshold), '{}_{}.png'.format(img_file.split('.')[0], aug_cls_name)))

  print(">>>>>>>>>> test time: {:.2f}s".format(time.time() - start))


In [1]:
import shutil
import os
for dir in os.listdir('./'):
  if dir.startswith('results'):
    shutil.rmtree(dir)

In [10]:
from PIL import Image
import numpy as np
import os, time
from sklearn.metrics import f1_score, roc_curve, auc

def analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]):
  start = time.time()

  iou_res_ = {}
  pixel_acc_res_ = {}
  f1_res_ = {}
  for thres in thres_list:
    iou_res_[thres] = []
    pixel_acc_res_[thres] = []
    f1_res_[thres] = []

  for results_dir, segmentations_dir in zip(results_dir_list, segmentations_dir_list):
    cls_arr_dir = segmentations_dir.replace("segmentations", "class_array")
    images_dir = segmentations_dir.replace("segmentations", "images")
    augmented_label_path = segmentations_dir.replace("segmentations", augmented_label_file)
    print('>>>> ', results_dir)
    with open(augmented_label_path, 'r') as f:
      label_data = json.load(f)

    for thres in thres_list:
      predict_root_dir = os.path.join(results_dir, '{}'.format(thres))
      if not os.path.isdir(predict_root_dir):
        continue
      iou_res = []
      pixel_acc_res = []
      f1_res = []

      for seg_file in os.listdir(segmentations_dir):
        if not(seg_file.endswith('.png') or seg_file.endswith('.tif') or seg_file.endswith('.jpg')):
          continue
        img_path = os.path.join(images_dir, seg_file)
        seg_classes = label_data[img_path]

        for cls_name in seg_classes.keys():
          all_classes = [cls_name] + seg_classes[cls_name]
          seg_cls_arr = np.load(os.path.join(cls_arr_dir, '{}_{}.npy'.format(seg_file.split('.')[0], cls_name)))
          iou = -1
          pixel_acc = -1
          f1 = -1
          for cls in all_classes:
            predict_path = os.path.join(predict_root_dir, '{}_{}.png'.format(seg_file.split('.')[0], cls))
            if os.path.isfile(predict_path):
              predict_img = Image.open(predict_path)
              predict_cls_arr = np.asarray(predict_img) / 255

              if predict_cls_arr.shape != seg_cls_arr.shape:
                print('>>>invalid prediction', predict_path, seg_cls_arr.shape, predict_cls_arr.shape)
                continue
              intersection = np.sum(predict_cls_arr * seg_cls_arr).astype(np.float32)
              union = np.sum(np.logical_or(predict_cls_arr, seg_cls_arr)).astype(np.float32)
              correct = np.sum(predict_cls_arr == seg_cls_arr).astype(np.float32)

              iou_ = intersection / union
              pixel_acc_ = correct / (seg_cls_arr.shape[0] * seg_cls_arr.shape[1])
              f1_ = f1_score(seg_cls_arr.flatten(), predict_cls_arr.flatten())

              if f1_ > f1:
                f1 = f1_
                pixel_acc = pixel_acc_
                iou = iou_
          iou_res.append(iou)
          pixel_acc_res.append(pixel_acc)
          f1_res.append(f1)

      iou_res_[thres] += iou_res
      pixel_acc_res_[thres] += pixel_acc_res
      f1_res_[thres] += f1_res

      f1_mean = np.array(f1_res).mean()
      iou_mean = np.array(iou_res).mean()
      pixel_acc_mean = np.array(pixel_acc_res).mean()
      print('>>>> thres: {}, dice: {:.4f}, iou: {:.4f}, pixel_acc: {:.4f}'.format(thres, f1_mean, iou_mean, pixel_acc_mean))

  iou_res_summary = []
  pixel_res_summary = []
  f1_res_summary = []
  for thres in thres_list:
    iou_res_summary.append(np.array(iou_res_[thres]).mean())
    pixel_res_summary.append(np.array(pixel_acc_res_[thres]).mean())
    f1_res_summary.append(np.array(f1_res_[thres]).mean())

  iou_res_summary = np.asarray(iou_res_summary)
  iou_auc = auc(np.asarray(thres_list), iou_res_summary)
  iou_optim = iou_res_summary.max()
  iou_auc_over_optim = iou_auc / iou_optim

  pixel_res_summary = np.asarray(pixel_res_summary)
  pixel_auc = auc(np.asarray(thres_list), pixel_res_summary)
  pixel_optim = pixel_res_summary.max()
  pixel_auc_over_optim = pixel_auc / pixel_optim

  f1_res_summary = np.asarray(f1_res_summary)
  f1_auc = auc(np.asarray(thres_list), f1_res_summary)
  f1_optim = f1_res_summary.max()
  f1_auc_over_optim = f1_auc / f1_optim

  print('>>> dice AUC: {:.4f}, dic optimum: {:.4f}, dice AUC/optim: {:.4f}'.format(f1_auc, f1_optim, f1_auc_over_optim))
  print('>>> iou AUC: {:.4f}, iou optimum: {:.4f}, iou AUC/optim: {:.4f}'.format(iou_auc, iou_optim, iou_auc_over_optim))
  print('>>> pixel_acc AUC: {:.4f}, pixel_acc optimum: {:.4f}, pixel_acc AUC/optim: {:.4f}'.format(pixel_auc, pixel_optim, pixel_auc_over_optim))
  print('>>> analysis time: {:.2f}s'.format(time.time() - start))
  return f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim



In [None]:
from bayes_opt import BayesianOptimization
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

def black_box_function(base_t=100, refiner_t=0, base_map_weight1=0.8, base_map_weight2=0.2, refiner_map_weight1=0.3, refiner_map_weight2=0.5, refiner_map_weight3=0.2, alpha=0.5):
  VOC_label_map = {
    1:'aeroplane',
    2:'bicycle',
    3:'bird',
    4:'boat',
    5:'bottle',
    6:'bus',
    7:'car',
    8:'cat',
    9:'chair',
    10:'cow',
    11:'diningtable',
    12:'dog',
    13:'horse',
    14:'motorbike',
    15:'person',
    16:'pottedplant',
    17:'sheep',
    18:'sofa',
    19:'train',
    20:'tvmonitor'
  }

  Cityscape_label_map = {
    1: 'road', # flat
    2: 'person', # human
    3: 'building', # construction
    4: 'traffic light', # object
    5: 'vegetation', # nature
    6: 'car', # vehicle
    7: 'bus', # vehicle
    8: 'train', # vehicle
    9: 'motorcycle', # vehicle
    10: 'bicycle', #vehicle
  }

  Vaihingen_label_map = {
    1: 'building'
  }

  Kvasir_label_map = {
    1: 'tumor'
  }

  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
  model_key = "stabilityai/stable-diffusion-xl-base-1.0"
  refiner_key = "stabilityai/stable-diffusion-xl-refiner-1.0"
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.bfloat16)
  ldm_stable = StableDiffusionXLPipeline.from_pretrained(model_key, vae=vae, torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True).to(device)

  ldm_stable.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler",
                                                     beta_start=0.00085, beta_end=0.012,
                                                     steps_offset=1)

  refiner = DiffusionPipeline.from_pretrained(
    refiner_key,
    text_encoder_2=ldm_stable.text_encoder_2,
    vae=ldm_stable.vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
  ).to(device)
  refiner.scheduler = DDIMScheduler.from_pretrained(refiner_key, subfolder="scheduler",
                                                     beta_start=0.00085, beta_end=0.012,
                                                     steps_offset=1)

  blip_device = "cuda:0"
  # blip device
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(blip_device)

  datasets = ["VOC2012", "Cityscape", "Vaihingen", "Kvasir-SEG"]
  label_maps = [VOC_label_map, Cityscape_label_map, Vaihingen_label_map, Kvasir_label_map]
  base_map_weight_sum = base_map_weight1 + base_map_weight2
  base_weight = [base_map_weight1 / base_map_weight_sum, base_map_weight2 / base_map_weight_sum]
  refiner_map_weight_sum = refiner_map_weight1 + refiner_map_weight2 + refiner_map_weight3
  refiner_weight = [refiner_map_weight1 / refiner_map_weight_sum, refiner_map_weight2 / refiner_map_weight_sum, refiner_map_weight3 / refiner_map_weight_sum]
  root_dir = "results_0.9_{}_{}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{:.2f}".format(int(base_t), int(refiner_t), base_weight[0], base_weight[1], refiner_weight[0], refiner_weight[1], refiner_weight[2], alpha)

  augmented_label = True
  thres_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  if os.path.isdir(root_dir):
    shutil.rmtree(root_dir)
  os.mkdir(root_dir)
  augmented_label_file = 'aug_label_blip_bert_0.9.json'
  for ds, label_map in zip(datasets, label_maps):
    images_dir = os.path.join(ds, "images")
    result_dir = os.path.join(root_dir, ds)
    domain_test(processor, model, ldm_stable, refiner, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, base_weight, refiner_weight, base_t, refiner_t, alpha, augmented_label=augmented_label, thres_list=thres_list)

  results_dir_list = [os.path.join(root_dir, ds) for ds in datasets]
  segmentations_dir_list = [os.path.join(ds, "segmentations") for ds in datasets]
  f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim = analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=thres_list)

  del ldm_stable
  del refiner
  del model
  del processor
  torch.cuda.empty_cache()

  return f1_auc


def parameter_tuning():
  pbounds = {'base_t': (90, 110),
             'refiner_t': (90, 110),
             'base_map_weight1': (0.01, 0.99),
             'base_map_weight2': (0.01, 0.99),
             'refiner_map_weight1': (0.01, 0.99),
             'refiner_map_weight2': (0.01, 0.99),
             'refiner_map_weight3': (0.01, 0.99),
             'alpha': (0.01, 0.99)
            }
  optimizer = BayesianOptimization(
    f=black_box_function,
    pbounds=pbounds,
    verbose=2, # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
    random_state=1,
  )
  optimizer.maximize(
    init_points=0,
    n_iter= 14,
  )

parameter_tuning()

|   iter    |  target   |   alpha   | base_m... | base_m... |  base_t   | refine... | refine... | refine... | refiner_t |
-------------------------------------------------------------------------------------------------------------------------


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

  latents = latent.expand(batch_size,  model.unet.in_channels, height // 8, width // 8).to(model.device)


>>>>>>>>>> test time: 203.64s
>>>>>>>>>> test time: 89.26s
>>>>>>>>>> test time: 17.38s
>>>>>>>>>> test time: 30.36s
>>>>  results_0.9_96_96_0.99_0.01_0.34_0.22_0.43_0.42/VOC2012
>>>> thres: 0.1, dice: 0.3827, iou: 0.2751, pixel_acc: 0.4420
>>>> thres: 0.2, dice: 0.4426, iou: 0.3305, pixel_acc: 0.5675
>>>> thres: 0.3, dice: 0.4909, iou: 0.3776, pixel_acc: 0.6460
>>>> thres: 0.4, dice: 0.5289, iou: 0.4139, pixel_acc: 0.7004
>>>> thres: 0.5, dice: 0.5611, iou: 0.4461, pixel_acc: 0.7418
>>>> thres: 0.6, dice: 0.5863, iou: 0.4727, pixel_acc: 0.7786
>>>> thres: 0.7, dice: 0.6011, iou: 0.4880, pixel_acc: 0.8043
>>>> thres: 0.8, dice: 0.5996, iou: 0.4872, pixel_acc: 0.8250
>>>> thres: 0.9, dice: 0.5537, iou: 0.4382, pixel_acc: 0.8373
>>>>  results_0.9_96_96_0.99_0.01_0.34_0.22_0.43_0.42/Cityscape
>>>> thres: 0.1, dice: 0.2228, iou: 0.1428, pixel_acc: 0.2125
>>>> thres: 0.2, dice: 0.2379, iou: 0.1545, pixel_acc: 0.2958
>>>> thres: 0.3, dice: 0.2536, iou: 0.1678, pixel_acc: 0.3659
>>>> thres: 0

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

  latents = latent.expand(batch_size,  model.unet.in_channels, height // 8, width // 8).to(model.device)


>>>>>>>>>> test time: 202.03s
>>>>>>>>>> test time: 89.97s
>>>>>>>>>> test time: 17.50s
>>>>>>>>>> test time: 30.34s
>>>>  results_0.9_107_93_0.32_0.68_0.86_0.09_0.05_0.96/VOC2012
>>>> thres: 0.1, dice: 0.3636, iou: 0.2579, pixel_acc: 0.3836
>>>> thres: 0.2, dice: 0.4106, iou: 0.2981, pixel_acc: 0.5082
>>>> thres: 0.3, dice: 0.4509, iou: 0.3345, pixel_acc: 0.5795
>>>> thres: 0.4, dice: 0.4817, iou: 0.3646, pixel_acc: 0.6328
>>>> thres: 0.5, dice: 0.5059, iou: 0.3870, pixel_acc: 0.6778
>>>> thres: 0.6, dice: 0.5230, iou: 0.4028, pixel_acc: 0.7158
>>>> thres: 0.7, dice: 0.5348, iou: 0.4122, pixel_acc: 0.7511
>>>> thres: 0.8, dice: 0.5368, iou: 0.4128, pixel_acc: 0.7829
>>>> thres: 0.9, dice: 0.4832, iou: 0.3595, pixel_acc: 0.8036
>>>>  results_0.9_107_93_0.32_0.68_0.86_0.09_0.05_0.96/Cityscape
>>>> thres: 0.1, dice: 0.2201, iou: 0.1405, pixel_acc: 0.2246
>>>> thres: 0.2, dice: 0.2346, iou: 0.1522, pixel_acc: 0.3284
>>>> thres: 0.3, dice: 0.2448, iou: 0.1617, pixel_acc: 0.4195
>>>> thres:

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

  latents = latent.expand(batch_size,  model.unet.in_channels, height // 8, width // 8).to(model.device)


>>>>>>>>>> test time: 203.14s
>>>>>>>>>> test time: 89.95s
>>>>>>>>>> test time: 17.53s
>>>>>>>>>> test time: 30.42s
>>>>  results_0.9_98_108_0.33_0.67_0.44_0.30_0.25_0.80/VOC2012
>>>> thres: 0.1, dice: 0.3632, iou: 0.2577, pixel_acc: 0.3807
>>>> thres: 0.2, dice: 0.4105, iou: 0.2992, pixel_acc: 0.5050
>>>> thres: 0.3, dice: 0.4511, iou: 0.3360, pixel_acc: 0.5809
>>>> thres: 0.4, dice: 0.4850, iou: 0.3689, pixel_acc: 0.6367
>>>> thres: 0.5, dice: 0.5127, iou: 0.3951, pixel_acc: 0.6841
>>>> thres: 0.6, dice: 0.5336, iou: 0.4143, pixel_acc: 0.7263
>>>> thres: 0.7, dice: 0.5506, iou: 0.4293, pixel_acc: 0.7657
>>>> thres: 0.8, dice: 0.5556, iou: 0.4344, pixel_acc: 0.7952
>>>> thres: 0.9, dice: 0.5073, iou: 0.3849, pixel_acc: 0.8170
>>>>  results_0.9_98_108_0.33_0.67_0.44_0.30_0.25_0.80/Cityscape
>>>> thres: 0.1, dice: 0.2195, iou: 0.1401, pixel_acc: 0.2031
>>>> thres: 0.2, dice: 0.2317, iou: 0.1497, pixel_acc: 0.2840
>>>> thres: 0.3, dice: 0.2433, iou: 0.1595, pixel_acc: 0.3815
>>>> thres: