In [1]:
# Environment setup (Colab): diffusers + torch extras, transformers and accelerate for the
# diffusion / BLIP models, einops (used for `rearrange` in the attention patch) and
# bayesian-optimization (not referenced in this part of the notebook).
# NOTE(review): diffusers is installed from PyPI and then again from the GitHub main branch;
# the second install replaces the first. Nothing is pinned, so results may drift across
# releases — consider pinning versions (and using %pip) for reproducibility.
!pip install diffusers["torch"] transformers accelerate
!pip install git+https://github.com/huggingface/diffusers
!pip install einops
!pip install bayesian-optimization

Collecting diffusers[torch]
  Downloading diffusers-0.28.1-py3-none-any.whl (2.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/2.2 MB[0m [31m30.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate
  Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4->diffusers[torch])
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4->diffusers[torch])
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Coll

In [1]:
# Fetch NLTK resources. 'averaged_perceptron_tagger' backs nltk.tag.pos_tag, which is used
# to find noun tokens in the BLIP caption; 'punkt' and 'wordnet' are downloaded as well.
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
# !scp drive/MyDrive/dataset/VOC2012.zip ./
# !scp drive/MyDrive/dataset/Cityscape.zip ./
# !scp drive/MyDrive/dataset/Kvasir-SEG.zip ./
# !scp drive/MyDrive/dataset/Vaihingen.zip ./

!unzip VOC2012.zip
!unzip Cityscape.zip
!unzip Kvasir-SEG.zip
!unzip Vaihingen.zip

Archive:  VOC2012.zip
   creating: VOC2012/
  inflating: VOC2012/aug_label_blip_bert_0.9.json  
  inflating: VOC2012/aug_label_blip_bert_0.85.json  
   creating: VOC2012/images/
  inflating: VOC2012/images/877.png  
  inflating: VOC2012/images/483.png  
  inflating: VOC2012/images/182.png  
  inflating: VOC2012/images/458.png  
  inflating: VOC2012/images/168.png  
  inflating: VOC2012/images/286.png  
  inflating: VOC2012/images/294.png  
  inflating: VOC2012/images/554.png  
  inflating: VOC2012/images/54.png   
  inflating: VOC2012/images/456.png  
  inflating: VOC2012/images/57.png   
  inflating: VOC2012/images/319.png  
  inflating: VOC2012/images/779.png  
  inflating: VOC2012/images/214.png  
  inflating: VOC2012/images/82.png   
  inflating: VOC2012/images/434.png  
  inflating: VOC2012/images/561.png  
  inflating: VOC2012/images/141.png  
  inflating: VOC2012/images/78.png   
  inflating: VOC2012/images/17.png   
  inflating: VOC2012/images/5.png    
  inflating: VOC2012/ima

In [2]:
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, DiffusionPipeline
import numpy as np
import abc

# When True, the unconditional and conditional UNet passes run as two separate calls
# (see diffusion_step) instead of one concatenated batch, and AttentionControl processes
# the whole attention tensor of the conditional call.
LOW_RESOURCE = False
# NOTE(review): not referenced in the visible part of the notebook.
NUM_DIFFUSION_STEPS = 50
# Classifier-free guidance scale passed by generate_att to text2image_ldm_stable
# (it is not used by diffusion_step).
GUIDANCE_SCALE = 7.5
# NOTE(review): not referenced in the visible part of the notebook; equals the 77-token
# max_length used when tokenizing prompts for the LDM branch.
MAX_NUM_WORDS = 77

# code for store attention
class AttentionControl(abc.ABC):
    """Base class for hooks called from every patched UNet attention layer.

    `register_attention_control` sets `num_att_layers`. Each call counts one layer;
    once every layer of a UNet pass has been seen, the step counter advances and
    `between_steps` fires. Subclasses implement `forward`.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # replaced by register_attention_control
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        # Identity hook on the latents.
        return x_t

    def between_steps(self):
        # Hook run after each completed UNet pass; does nothing by default.
        return

    @property
    def num_uncond_att_layers(self):
        # With LOW_RESOURCE the unconditional pass is a separate UNet call; its layer
        # calls are counted first and passed through untouched.
        if not LOW_RESOURCE:
            return 0
        return self.num_att_layers

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        in_conditional_part = self.cur_att_layer >= self.num_uncond_att_layers
        if in_conditional_part and LOW_RESOURCE:
            attn = self.forward(attn, is_cross, place_in_unet)
        elif in_conditional_part:
            # Batched CFG: rows are [unconditional | conditional]; only the conditional
            # half goes through forward and is written back in place.
            half = attn.shape[0] // 2
            attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        # Restart layer/step counting (num_att_layers is kept).
        self.cur_att_layer = 0
        self.cur_step = 0

class EmptyControl(AttentionControl):
    """Controller that neither records nor modifies attention maps."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Pass-through.
        return attn


class AttentionStore(AttentionControl):
    """Controller that records attention probabilities per UNet location.

    Maps of the current pass accumulate in `step_store`. At the end of each pass,
    `between_steps` adds them element-wise into `attention_store`, so
    `get_average_attention` can divide by the number of completed passes.
    """

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        # One list per (UNet location, attention kind).
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    @torch.no_grad()
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        kind = "cross" if is_cross else "self"
        # Maps with more than 64*64 query positions are skipped to bound memory.
        if attn.shape[1] <= 64 ** 2:
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        if not self.attention_store:
            # First completed pass: adopt its lists as the running sums.
            self.attention_store = self.step_store
        else:
            for key, stored in self.attention_store.items():
                for idx in range(len(stored)):
                    stored[idx] += self.step_store[key][idx]
        self.step_store = self.get_empty_store()

    @torch.no_grad()
    def get_average_attention(self):
        # Running sums divided by the number of completed passes.
        steps = self.cur_step
        return {key: [item / steps for item in maps]
                for key, maps in self.attention_store.items()}

    def reset(self):
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


In [3]:
from PIL import Image
import cv2

## Visualization code utils
def view_images(images, palette=None, num_rows=1, offset_ratio=0.02):
    """Tile images into a grid and display it in the notebook.

    Args:
        images: a list/tuple of (H, W, C) arrays, a 4-D array of such images, or a
            single (H, W, C) array.
        palette: flat [r, g, b, ...] list. When given, the grid is converted to PIL 'P'
            mode and this palette is applied (index-based false colouring, e.g. from
            create_palette). When None, the RGB grid is displayed unchanged.
        num_rows: number of grid rows; missing tiles in the last row are white.
        offset_ratio: gap between tiles as a fraction of the tile height.
    """
    if isinstance(images, (list, tuple)):
        images = list(images)
    elif images.ndim == 4:
        images = list(images)
    else:
        images = [images]
    if len(images) == 0:
        return

    # Number of white tiles needed so that the images fill complete rows.
    num_empty = (-len(images)) % num_rows
    empty_image = np.ones(images[0].shape, dtype=np.uint8) * 255
    tiles = [image.astype(np.uint8) for image in images] + [empty_image] * num_empty

    h, w, c = tiles[0].shape
    offset = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows
    grid = np.ones((h * num_rows + offset * (num_rows - 1),
                    w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            top, left = i * (h + offset), j * (w + offset)
            grid[top:top + h, left:left + w] = tiles[i * num_cols + j]

    pil_img = Image.fromarray(grid)
    if palette is not None:
        pil_img = pil_img.convert('P')
        pil_img.putpalette(palette)
    display(pil_img)


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of `image` with a white strip below it (20% of its height)
    containing `text`, horizontally centred."""
    height, width, channels = image.shape
    strip = int(height * .2)
    canvas = np.full((height + strip, width, channels), 255, dtype=np.uint8)
    canvas[:height] = image
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    origin = ((width - text_w) // 2, height + strip - text_h // 2)
    cv2.putText(canvas, text, origin, font, 1, text_color, 2)
    return canvas


In [4]:
# code for aggregaring attention
@torch.no_grad()
def aggregate_all_attention(prompts, attention_store: AttentionStore, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of prompt `select` per resolution.

    Returns a list of four CPU tensors for resolutions 8, 16, 32 and 64 (in that
    order), each of shape (res, res, K), where K is the key length of the stored maps.
    Only layers in `from_where` ("down"/"mid"/"up") are used.
    """
    attention_maps = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    resolutions = (8, 16, 32, 64)
    side_by_pixels = {res * res: res for res in resolutions}
    buckets = {res: [] for res in resolutions}
    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            side = side_by_pixels.get(item.shape[1])
            if side is None:
                continue
            # (prompts*heads, side*side, K) -> heads of the selected prompt.
            buckets[side].append(item.reshape(len(prompts), -1, side, side, item.shape[-1])[select])
    atts = []
    for res in resolutions:
        stacked = torch.cat(buckets[res], dim=0)
        atts.append((stacked.sum(0) / stacked.shape[0]).cpu())
    del attention_maps
    torch.cuda.empty_cache()
    return atts


@torch.no_grad()
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps at one resolution `res` for prompt `select`.

    Returns a CPU tensor of shape (res, res, K), averaged over heads and the layers
    listed in `from_where`.
    """
    attention_maps = attention_store.get_average_attention()
    suffix = 'cross' if is_cross else 'self'
    num_pixels = res ** 2
    selected = [
        item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
        for location in from_where
        for item in attention_maps[f"{location}_{suffix}"]
        if item.shape[1] == num_pixels
    ]
    merged = torch.cat(selected, dim=0)
    return (merged.sum(0) / merged.shape[0]).cpu()


# visualize cross att
def show_cross_attention(prompts, tokenizer, attention_store: AttentionStore, palette, res: int, from_where: List[str], select: int = 0, cls_name=''):
    """Display one labelled cross-attention map per token of prompt `select`.

    The "++" emphasis token is skipped while keeping each map paired with its own
    token label. Each map is scaled to 0-255, resized to 256x256, captioned with its
    decoded token and shown through view_images with `palette`.

    Args:
        prompts: prompt list used to collect the attention.
        tokenizer: tokenizer of the diffusion text encoder.
        attention_store: AttentionStore holding the recorded maps.
        palette: flat palette list forwarded to view_images.
        res: attention resolution to display (8, 16, 32 or 64).
        from_where: UNet locations to aggregate ("down"/"mid"/"up").
        select: index of the prompt to show.
        cls_name: currently unused (kept for callers).
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    # The token axis of the maps is bounded by the text-encoder context (the prompt is
    # truncated there), while tokenizer.encode here is not; never index past it.
    num_maps = attention_maps.shape[-1]
    images = []
    for idx, token_id in enumerate(tokens):
        if idx >= num_maps:
            break
        word = decoder(int(token_id))
        if word == "++":
            # Skip the emphasis marker without shifting later labels.
            continue
        image = attention_maps[:, :, idx]
        peak = image.max()
        if peak > 0:
            image = 255 * image / peak
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.float().numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        images.append(text_under_image(image, word))
    if not images:
        return
    view_images(np.stack(images, axis=0), palette)


# visualize self att
def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0, palette=None):
    """Display the leading principal components of the averaged self-attention.

    The (res*res, res*res) self-attention matrix is row-centred and decomposed with
    SVD. The first `max_com` right-singular vectors are min-max scaled to 0-255,
    reshaped to (res, res), resized to 256x256 and shown side by side.

    Args:
        prompts, attention_store, res, from_where, select: as in aggregate_attention.
        max_com: number of components to show (capped at res*res).
        palette: optional palette for view_images. When None, the palette PIL
            picks for this grid is reused, so the image is shown with its own colors.
    """
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).float().numpy().reshape((res ** 2, res ** 2))
    _, _, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(min(max_com, vh.shape[0])):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        peak = image.max()
        if peak > 0:
            image = 255 * image / peak
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        images.append(np.array(image))
    grid = np.concatenate(images, axis=1)
    if palette is None:
        # view_images needs a palette; reuse the one PIL assigns when converting this
        # grid to 'P' mode so the colors come back unchanged.
        palette = Image.fromarray(grid).convert('P').getpalette()
    view_images(grid, palette)

In [5]:
def encode_imgs(imgs, vae, scaling_factor=0.18215):
    """Encode images in [0, 1] into scaled latents.

    Handles both autoencoder output types:
    - KL autoencoders return an output with `latent_dist`; its mean is used (no sampling).
    - VQ models (the `vqvae` of LDM pipelines) return an output with `latents`.

    Args:
        imgs: tensor [B, 3, H, W] with values in [0, 1].
        vae: autoencoder exposing `encode`.
        scaling_factor: multiplier applied to the latents (0.18215 by default).
            NOTE(review): the same factor is applied to VQ latents — confirm it
            matches the VQ model's configuration.

    Returns:
        The latent tensor multiplied by `scaling_factor`.
    """
    imgs = 2 * imgs - 1  # [0, 1] -> [-1, 1]
    encoded = vae.encode(imgs)
    latent_dist = getattr(encoded, "latent_dist", None)
    if latent_dist is not None:
        posterior = latent_dist.mean
    else:
        posterior = encoded.latents
    return posterior * scaling_factor


# fix random seed
def same_seeds(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all GPUs), and
    make cuDNN deterministic, so that repeated runs produce the same random numbers."""
    import random

    random.seed(seed)  # Python's built-in RNG
    torch.manual_seed(seed)  # PyTorch CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # current GPU
        torch.cuda.manual_seed_all(seed)  # all GPUs
    np.random.seed(seed)  # NumPy's legacy global RNG
    # Disable cuDNN autotuning (could be True for a fixed GPU and network, at the cost of reproducibility).
    torch.backends.cudnn.benchmark = False
    # Force deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True


# cam visual_code
def show_cam_on_image(img, mask):
    """Overlay a JET heat map of `mask` on a PIL image.

    Args:
        img: PIL RGB image.
        mask: 2-D torch tensor, expected in [0, 1]; it is bilinearly resized to the
            image size.

    Returns:
        uint8 array (H, W, 3) in BGR channel order (OpenCV convention); callers flip it
        with [:, :, ::-1] before displaying.
    """
    import torch.nn.functional as F

    mask = F.interpolate(mask.float().unsqueeze(0).unsqueeze(0), size=(img.size[1], img.size[0]),
                         mode='bilinear', align_corners=False)[0, 0]
    mask = mask.detach().cpu().numpy()
    # PIL delivers RGB while cv2.applyColorMap produces BGR: bring the photo into BGR
    # so both layers use the same channel order before blending.
    img = np.float32(img)[:, :, ::-1] / 255.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [6]:
## ptp utils function
from einops import rearrange

def init_latent(latent, model, height, width, generator, batch_size):
    """Return (latent, latents).

    `latent` is the given tensor, or a fresh (1, C, height//8, width//8) normal sample
    drawn with `generator` when None. `latents` is that tensor broadcast (expand) to
    `batch_size` and moved to model.device.
    """
    latent_shape = (model.unet.config.in_channels, height // 8, width // 8)
    if latent is None:
        latent = torch.randn((1, *latent_shape), generator=generator)
    batched = latent.expand(batch_size, *latent_shape).to(model.device)
    return latent, batched


@torch.no_grad()
def register_attention_control(model, controller):
    """Patch every `Attention` module of `model.unet` to route its attention
    probabilities through `controller`.

    Each matching module gets an instance-level `forward` that recomputes attention
    explicitly (q/k/v projections, scaled dot product, softmax). It calls
    `controller(attn, is_cross, place_in_unet)` on the probabilities before applying
    them to the values. `place_in_unet` is "down", "up" or "mid", taken from the name
    of the top-level UNet child the module sits under. The number of patched modules
    (self- and cross-attention alike) is stored in `controller.num_att_layers`.

    Args:
        model: pipeline-like object exposing `.unet`.
        controller: an AttentionControl (or compatible callable), or None to install
            a pass-through that records nothing.
    """
    def ca_forward(self, place_in_unet):
        # Build the replacement forward for one Attention module (`self`).
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            # Only the first entry is applied; any further entries are skipped
            # (NOTE(review): presumably dropout — confirm).
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            # NOTE(review): cross_attention_kwargs are accepted but ignored, and no other
            # options of the original module's forward are applied — confirm none are
            # needed for the loaded UNet.
            x = hidden_states
            context = encoder_hidden_states
            mask = attention_mask
            batch_size = len(x)
            h = self.heads
            q = self.to_q(x)
            # Cross-attention iff encoder states are given; otherwise keys/values come from x.
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            # Fold heads into the batch axis: (b, n, h*d) -> (b*h, n, d).
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                # NOTE(review): the mask is treated as boolean (True = keep); a float /
                # additive mask would fail at `~mask`.
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            ## controller forward hook (e.g. AttentionStore saves the map in its step_store)
            attn = controller(attn, is_cross, place_in_unet)

            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            # Unfold heads back into the channel axis, then apply the output projection.
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return to_out(out)

        return forward

    class DummyController:
        # Pass-through used when no controller is supplied.

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Depth-first walk: patch modules whose class is named 'Attention' (without
        # descending into them) and return the running count of patched modules.
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    # Top-level UNet children are labelled by substring of their name; children whose
    # names contain none of "down"/"up"/"mid" are not patched.
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    # AttentionControl uses this to detect the end of one UNet pass.
    controller.num_att_layers = cross_att_count


@torch.no_grad()
def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False, height=None, width=None):
    """Run the UNet once at timestep `t` so the registered attention hooks fire.

    The noise prediction is deliberately discarded and `latents` is returned unchanged;
    only the side effect on the attention controller matters. `controller`,
    `guidance_scale`, `height` and `width` are accepted for interface compatibility
    but not used.

    Args:
        context: [unconditional_embeddings, text_embeddings].
        low_resource: run the two embeddings as separate UNet calls instead of one
            concatenated batch.
    """
    if low_resource:
        # Unconditional pass first, then the conditional one.
        for embeddings in (context[0], context[1]):
            _ = model.unet(latents, t, encoder_hidden_states=embeddings)["sample"]
        return latents

    # Batched classifier-free-guidance layout: [unconditional | conditional].
    doubled_latents = torch.cat([latents, latents])
    stacked_context = torch.cat(context)
    _ = model.unet(doubled_latents, t, encoder_hidden_states=stacked_context, added_cond_kwargs={})["sample"]
    return latents


## text to image custom pipeline
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
    noise_sample_num=1,
    height: int = None,
    width: int = None
):
    """Run one conditioned UNet pass on `latent` so that `controller` records attention.

    Despite the name, no image is generated. The function registers `controller`
    on the UNet, embeds the prompts and empty strings, evaluates the UNet once via
    diffusion_step with `num_inference_steps` as the timestep, and returns
    (None, None). The attention maps end up in `controller`.

    Args:
        model: Stable Diffusion pipeline (uses `.text_encoder`) or LDM pipeline (uses `.bert`).
        prompt: list of prompts; batch_size = len(prompt).
        controller: attention controller to register on the UNet.
        num_inference_steps: used as the UNet timestep and passed to scheduler.set_timesteps.
        guidance_scale: forwarded to diffusion_step, which does not use it.
        generator: only used by init_latent when `latent` is None.
        latent: (noisy) latent to run the UNet on.
        low_resource: run unconditional and conditional passes separately.
        noise_sample_num: unused.
        height, width: image size (default 512), used to size a freshly sampled latent.

    Returns:
        (None, None)
    """
    register_attention_control(model, controller)
    height = 512 if height is None else height
    width = 512 if width is None else width
    batch_size = len(prompt)

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    model.unet.to(model.device)
    if hasattr(model, 'text_encoder'):
      # Stable Diffusion: conditional and unconditional (empty prompt) embeddings.
      text_input_ids = text_input.input_ids
      prompt_embeds = model.text_encoder(text_input_ids.to(model.device), attention_mask=None)
      text_embeddings = prompt_embeds[0]

      uncond_input = model.tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
      )
      negative_prompt_embeds = model.text_encoder(
          uncond_input.input_ids.to(model.device),
          attention_mask=None,
      )
      uncond_embeddings = negative_prompt_embeds[0]
    else:
      # LDM: re-tokenize with a fixed max_length of 77 and embed with `model.bert`.
      text_input = model.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
      text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
      uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
      uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]

    context = [uncond_embeddings, text_embeddings]
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    # NOTE(review): the batch-expanded `latents` from init_latent is discarded and the raw
    # `latent` is used instead (squeeze(1) only drops a size-1 axis 1). With more than one
    # prompt the latent batch may not match the context batch — confirm single prompts only.
    latents = latent.squeeze(1).to(model.device)

    ## Sets the discrete timesteps used for the diffusion chain (to be run before inference).
    # NOTE(review): this mutates the scheduler's state; the timesteps it sets are not read below.
    model.scheduler.set_timesteps(num_inference_steps)
    latents = diffusion_step(model, controller, latents, context, num_inference_steps, guidance_scale, low_resource, height, width)

    return None, None


In [7]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
# main function
# Stable diffusion

def create_palette(cmap_name, num_colors=256):
    """Sample `num_colors` evenly spaced colors from a matplotlib colormap.

    Returns a flat [r, g, b, r, g, b, ...] list of ints in 0-255, as accepted by
    PIL's `Image.putpalette`. The alpha channel is dropped.
    """
    cmap = plt.get_cmap(cmap_name)
    rgba_samples = (cmap(idx / num_colors) for idx in range(num_colors))
    return [int(channel * 255) for rgba in rgba_samples for channel in rgba[:3]]

def generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos_positions, device,
                 is_self=True, is_multi_self=False, is_cross_norm=True, weight=(0.3, 0.5, 0.1, 0.1),
                 height=None, width=None, verbose=False, alpha=8, beta=0.4, neg_positions=(), neg_weight=1.0, cls_name=''):
    """Build a (512, 512) attention mask in [0, 1] for the prompt tokens at `pos_positions`.

    Steps:
    1. Noise `input_latent` to timestep `t` and run one UNet pass while `controller`
       records attention.
    2. Fuse the cross-attention maps of resolutions 8/16/32/64 with `weight`.
    3. Average the maps of the target tokens, optionally subtracting `neg_weight` times
       the mean map of `neg_positions`.
    4. Propagate the result through the 64x64 self-attention.
    5. Upsample to 512x512, min-max normalize and sharpen with sigmoid(alpha * (x - beta)).

    Args:
        t: timestep used both for add_noise and as the UNet timestep.
        ldm_stable: diffusion pipeline.
        input_latent, noise: clean latent and the noise added to it.
        prompts: prompt list (only prompt 0 is aggregated).
        controller: AttentionStore that collects the maps; reset on entry.
        pos_positions: non-empty list of index lists into the token axis, one per target token.
        device: device of the timestep tensor.
        is_self: refine with the 64x64 self-attention (ignored when is_multi_self).
        is_multi_self: refine each resolution with its own self-attention before fusing.
        is_cross_norm: min-max normalize each token map per resolution.
        weight: fusion weights for resolutions 8, 16, 32, 64.
        verbose: display per-resolution cross-attention maps.
        neg_positions, neg_weight: tokens whose mean map is subtracted, and its scale.
        cls_name: label forwarded to the visualization.

    Returns:
        (512, 512) tensor normalized to [0, 1].
    """
    controller.reset()
    # torch.Generator's positional argument is a *device*, so the seed is set explicitly.
    # (It is only consumed by init_latent when no latent is given, which is not the case here.)
    g_cpu = torch.Generator().manual_seed(4307)
    t = int(t)
    latents_noisy = ldm_stable.scheduler.add_noise(input_latent, noise, torch.tensor(t, device=device))
    # Called for its side effect: the controller records the attention maps.
    text2image_ldm_stable(ldm_stable, prompts, controller, latent=latents_noisy, num_inference_steps=t,
                          guidance_scale=GUIDANCE_SCALE, generator=g_cpu, low_resource=LOW_RESOURCE,
                          height=height, width=width)
    layers = ("mid", "up", "down")
    # Per resolution 8/16/32/64: cross maps are (res, res, 77), self maps (res, res, res*res).
    cross_attention_maps = aggregate_all_attention(prompts, controller, layers, True, 0)
    self_attention_maps = aggregate_all_attention(prompts, controller, ("up", "mid", "down"), False, 0)

    final_res = 64
    imgs = []
    for idx, res in enumerate([8, 16, 32, 64]):
        out_att = cross_attention_maps[idx].permute(2, 0, 1).float()  # (tokens, res, res)
        if is_cross_norm:
            att_max = torch.amax(out_att, dim=(1, 2), keepdim=True)
            att_min = torch.amin(out_att, dim=(1, 2), keepdim=True)
            out_att = (out_att - att_min) / (att_max - att_min)
        if is_multi_self:
            self_att = self_attention_maps[idx].view(res * res, res * res).float()
            self_att = self_att / self_att.max()
            out_att = torch.matmul(self_att.unsqueeze(0), out_att.view(-1, res * res, 1)).view(-1, res, res)
        if res != final_res:
            out_att = F.interpolate(out_att.unsqueeze(0), size=(final_res, final_res), mode='bilinear', align_corners=False).squeeze()
        imgs.append(out_att * weight[idx])
    # Weighted sum over resolutions, computed once: (tokens, 64, 64).
    fused = torch.stack(imgs).sum(0)

    def token_map(position):
        # `position` is a list of token indices; their maps are averaged into a column vector.
        return fused[position].mean(0).view(final_res * final_res, 1)

    # Aggregated cross-attention map of the target tokens.
    cross_att_map = token_map(pos_positions[0])
    for pos in pos_positions[1:]:
        cross_att_map += token_map(pos)
    if len(pos_positions) > 1:
        cross_att_map /= len(pos_positions)

    if len(neg_positions) > 0:
        cross_att_map_neg = torch.zeros_like(cross_att_map)
        for pos in neg_positions:
            cross_att_map_neg += token_map(pos)
        cross_att_map_neg /= len(neg_positions)
        cross_att_map -= neg_weight * cross_att_map_neg

    # Refine the cross-attention map with the 64x64 self-attention.
    if is_self and not is_multi_self:
        self_att = self_attention_maps[3].view(final_res * final_res, final_res * final_res).float()
        self_att = self_att / self_att.max()
        cross_att_map = torch.matmul(self_att, cross_att_map)

    att_map = cross_att_map.view(final_res, final_res)
    att_map = F.interpolate(att_map.unsqueeze(0).unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False).squeeze()
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
    # Sharpen around `beta`, then renormalize to [0, 1].
    att_map = torch.sigmoid(alpha * (att_map - beta))
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())

    del cross_att_map, fused
    torch.cuda.empty_cache()

    if verbose:
        palette = create_palette('viridis')
        tokenizer = ldm_stable.tokenizer
        for vis_res in (8, 16, 32, 64):
            print("{}x{} cross att map".format(vis_res, vis_res))
            show_cross_attention(prompts, tokenizer, controller, palette, res=vis_res, from_where=layers, cls_name=cls_name)

    return att_map

In [8]:
from torchvision import transforms

def stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=[0.3,0.5,0.1,0.1], t=100, alpha=8, beta=0.4, seed=3407, negative_token=False):
  """Produce a soft segmentation mask for class `cls_name` in the image at `img_path`.

  Pipeline:
  1. Encode the image with the diffusion model's autoencoder and add noise at timestep `t`.
  2. Caption the image with BLIP, conditioned on "a photograph of {cls_name}".
  3. Insert "++" right after the class words to emphasize them.
  4. Run one UNet pass to collect attention.
  5. Turn the class-token cross-attention into a mask (see generate_att).

  Args:
    img_path: path to the target image.
    cls_name: target class used in the prompt.
    device: device of the diffusion model.
    blip_device: device of the BLIP model.
    processor: BLIP processor.
    model: BLIP model used for conditional captioning.
    ldm_stable: diffusion pipeline (Stable Diffusion with `.vae`, or LDM with `.vqvae`).
    verbose: show attention maps and the CAM overlay (also saves 'cam.pdf').
    weight: weights of the 8/16/32/64 cross-attention resolutions.
    t: noising timestep, usually between 50 and 150.
    alpha, beta: sigmoid sharpening parameters for the final map.
    seed: RNG seed.
    negative_token: subtract the attention of other nouns in the caption.

  Returns:
    Tensor of shape (H, W) matching the input image, with values in [0, 1].
  """
  with torch.no_grad():
    same_seeds(seed)

    input_img = Image.open(img_path).convert("RGB")
    img_tensor = transforms.ToTensor()(input_img).unsqueeze(0).to(device)

    # Stable Diffusion pipelines expose `.vae`; LDM pipelines expose `.vqvae`.
    vae = ldm_stable.vae if hasattr(ldm_stable, 'vae') else ldm_stable.vqvae
    # Cast to the autoencoder's own dtype instead of a hard-coded bfloat16, so float16 and
    # float32 pipelines work too (identical behavior for bfloat16 pipelines).
    rgb_512 = F.interpolate(img_tensor, (512, 512), mode='bilinear', align_corners=False).to(vae.dtype)
    input_latent = encode_imgs(rgb_512, vae)
    noise = torch.randn_like(input_latent).to(device)
    raw_image = input_img

    text = f"a photograph of {cls_name}"
    inputs = processor(raw_image, text, return_tensors="pt").to(blip_device)

    # BLIP continues the caption from `text`; "++" is inserted right after the class words
    # to emphasize them, e.g. "a photograph of cat++ sitting on a sofa".
    # NOTE(review): assumes the decoded caption starts with exactly `text`.
    out = model.generate(**inputs)
    texts = processor.decode(out[0], skip_special_tokens=True)
    texts = text + "++" + texts[len(text):]

    prompts = [texts]
    token_ids = ldm_stable.tokenizer.encode(texts)
    tokens = [ldm_stable.tokenizer.decode(int(_)) for _ in token_ids]
    tagged_tokens = nltk.tag.pos_tag(tokens)
    # Target-class tokens are those after "of" and before "++"; e.g. for
    # "a photograph of plane" the class token is at index 4 (index 0 is the start token).
    pos_positions = []
    neg_positions = []
    pos_start = False
    neg_start_pos = -1
    for i, (word, tag) in enumerate(tagged_tokens):
      if word == 'of':
        pos_start = True
        continue
      if word == '++':
        neg_start_pos = i + 1
        break
      if pos_start:
        pos_positions.append([i])
    if negative_token:
      # Other nouns in the BLIP continuation (excluding the final token) become negatives.
      for i, (word, tag) in enumerate(tagged_tokens[neg_start_pos:-1]):
        if tag.startswith('N'):
          if word != cls_name:
            neg_positions.append([i + neg_start_pos])

    controller = AttentionStore()

    height = 512
    width = 512
    mask = generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos_positions, device,
                        is_self=True, is_multi_self=False, is_cross_norm=True, weight=weight, height=height, width=width,
                        verbose=verbose, alpha=alpha, beta=beta, neg_positions=neg_positions, cls_name=cls_name)
    # Resize the 512x512 mask to the original image size (PIL size is (W, H)).
    mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(raw_image.size[1], raw_image.size[0]), mode='bilinear', align_corners=False)[0, 0]
    if verbose:
      cam = show_cam_on_image(raw_image, mask)
      print("visual_cam")
      pil_img = Image.fromarray(cam[:,:,::-1])
      display(pil_img)
      pil_img.save('cam.pdf')
    del img_tensor
    del noise
    del inputs
    torch.cuda.empty_cache()
    return mask

In [9]:
import json
import time, os, shutil

def domain_test(processor, model, ldm_stable, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, augmented_label=False, thres_list=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], weight=[0.3,0.5,0.1,0.1], t = 100, alpha=8, beta=0.4, seed=3407, negative_token=False):
  """Run stable_diffusion_inference on every image in `images_dir` and save the masks.

  `result_dir` is recreated from scratch. For every (image, class) pair it receives:
  - `mask/{stem}_{cls}.npy`: the raw soft mask;
  - `{thres}/{stem}_{cls}.png`: one binary 0/255 PNG per threshold in `thres_list`.

  Classes come from the JSON file obtained by replacing "images" with
  `augmented_label_file` in `images_dir`; it maps image path -> {class: [aliases]}.
  When `augmented_label` is True, every alias is processed as well. `label_map` is
  unused. The remaining parameters are forwarded to stable_diffusion_inference.
  """
  start = time.time()
  if os.path.isdir(result_dir):
    shutil.rmtree(result_dir)
  # makedirs also creates result_dir (and any missing parents).
  os.makedirs(os.path.join(result_dir, 'mask'))
  for thres in thres_list:
    os.mkdir(os.path.join(result_dir, '{}'.format(thres)))
  augmented_label_path = images_dir.replace("images", augmented_label_file)

  with open(augmented_label_path, 'r') as f:
    label_data = json.load(f)

  def run_and_save(img_path, stem, cls_name):
    # Infer the mask for one class name; write the raw mask and one binary PNG per threshold.
    mask = stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=weight, t=t, alpha=alpha, beta=beta, seed=seed, negative_token=negative_token)
    with open(os.path.join(result_dir, 'mask', '{}_{}.npy'.format(stem, cls_name)), 'wb') as f:
      np.save(f, mask)
    for mask_threshold in thres_list:
      mask_binary = np.where(mask > mask_threshold, 255, 0)
      mask_binary_img = Image.fromarray(mask_binary.astype(np.uint8))
      mask_binary_img.save(os.path.join(result_dir, '{}'.format(mask_threshold), '{}_{}.png'.format(stem, cls_name)))

  size = 0
  print('>>> seed: {}'.format(seed))
  for img_file in sorted(os.listdir(images_dir)):
    if not img_file.endswith(('.png', '.tif', '.jpg')):
      continue
    img_path = os.path.join(images_dir, img_file)
    size += 1
    # Same naming convention as the class_array / prediction files read by analysis().
    stem = img_file.split('.')[0]
    seg_classes = label_data[img_path]

    for cls_name in seg_classes.keys():
      run_and_save(img_path, stem, cls_name)
      if augmented_label:
        for aug_cls_name in seg_classes[cls_name]:
          run_and_save(img_path, stem, aug_cls_name)
  ds_name = images_dir.split('/')[-2]
  print(">>>>>>>>>> dataset: {}, size: {}, test time: {:.2f}s".format(ds_name, size, time.time() - start))


In [10]:
from PIL import Image
import numpy as np
import os, time
from sklearn.metrics import f1_score, roc_curve, auc
from numba import cuda

def _mask_scores(predict_cls_arr, seg_cls_arr):
  """Score one binary prediction mask against a same-shape ground-truth mask.

  The prediction is expected to hold 0/1 values (analysis divides the saved 0/255 PNG
  by 255). The ground-truth class array is presumably 0/1 as well — TODO confirm
  against the script that writes class_array/*.npy.

  Returns:
    (dice, iou, pixel_acc) as floats. Dice is 2*TP / (|pred| + |gt|), which for 0/1
    labels is exactly the binary F1 score sklearn's f1_score reports, without its
    per-call label validation on every full-size mask. When both masks are empty
    they agree perfectly, so Dice and IoU are 1.0 instead of 0/0 = nan.
  """
  predict = np.asarray(predict_cls_arr, dtype=np.float64)
  target = np.asarray(seg_cls_arr, dtype=np.float64)
  intersection = float(np.sum(predict * target))
  union = float(np.sum(np.logical_or(predict, target)))
  pixel_acc = float(np.sum(predict == target)) / target.size
  if union == 0:
    return 1.0, 1.0, pixel_acc
  dice = 2.0 * intersection / float(np.sum(predict) + np.sum(target))
  return dice, intersection / union, pixel_acc


def _auc_or_nan(thresholds, values):
  """Trapezoidal AUC of values over thresholds, or nan when fewer than two points exist.

  sklearn's auc raises ValueError for fewer than two points, which happens when most
  threshold directories are missing.
  """
  if len(thresholds) < 2:
    return float('nan')
  return auc(np.asarray(thresholds, dtype=np.float64), np.asarray(values, dtype=np.float64))


def _summarize_metric(res_by_thres, thres_list):
  """Return (AUC, optimum, AUC/optimum) of the per-threshold mean of one metric.

  Thresholds with no collected scores are left out of the curve instead of
  contributing nan.
  """
  thres_used = [thres for thres in thres_list if len(res_by_thres[thres]) > 0]
  if not thres_used:
    nan = float('nan')
    return nan, nan, nan
  means = np.asarray([np.mean(res_by_thres[thres]) for thres in thres_used])
  metric_auc = _auc_or_nan(thres_used, means)
  optim = means.max()
  return metric_auc, optim, metric_auc / optim


def analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
  """Score thresholded prediction masks against ground truth and summarise over thresholds.

  The function processes every (results_dir, segmentations_dir) pair. For each threshold
  directory ``results_dir/<thres>/`` that exists, it scores each ground-truth class of
  each image against the prediction for that class name or any of its augmented names
  from the augmented-label JSON. The candidate with the best Dice is kept. A class with
  no usable prediction is scored 0.

  Args:
    results_dir_list: per-dataset result directories written by domain_test.
    segmentations_dir_list: matching ``<dataset>/segmentations`` directories. The
      ``class_array`` and ``images`` directories and the augmented-label JSON are
      found by replacing "segmentations" in this path.
    augmented_label_file: file name of the augmented-label JSON. It is keyed by image
      path and maps each class name to a list of alternative class names.
    thres_list: thresholds to evaluate. Threshold directories that do not exist are skipped.

  Returns:
    (dice_auc, dice_optimum, iou_auc, iou_optimum, pixel_acc_auc, pixel_acc_optimum),
    computed from per-threshold means pooled over all datasets.
  """
  import json  # keeps this cell runnable without relying on an earlier cell's import

  start = time.time()

  # Per-threshold scores pooled over all datasets, used for the final summary.
  iou_res_ = {thres: [] for thres in thres_list}
  pixel_acc_res_ = {thres: [] for thres in thres_list}
  f1_res_ = {thres: [] for thres in thres_list}

  for results_dir, segmentations_dir in zip(results_dir_list, segmentations_dir_list):
    cls_arr_dir = segmentations_dir.replace("segmentations", "class_array")
    images_dir = segmentations_dir.replace("segmentations", "images")
    augmented_label_path = segmentations_dir.replace("segmentations", augmented_label_file)
    print('>>>> ', results_dir)
    with open(augmented_label_path, 'r') as f:
      label_data = json.load(f)

    # The per-dataset AUC must use the thresholds actually evaluated as its x-values.
    # Using thres_list misaligns x and y as soon as a directory is skipped.
    thres_domain = []
    iou_domain = []
    pixel_acc_domain = []
    f1_domain = []

    for thres in thres_list:
      predict_root_dir = os.path.join(results_dir, '{}'.format(thres))
      if not os.path.isdir(predict_root_dir):
        continue
      iou_res = []
      pixel_acc_res = []
      f1_res = []

      for seg_file in os.listdir(segmentations_dir):
        if not seg_file.endswith(('.png', '.tif', '.jpg')):
          continue
        stem = seg_file.split('.')[0]
        img_path = os.path.join(images_dir, seg_file)
        seg_classes = label_data[img_path]

        for cls_name in seg_classes.keys():
          all_classes = [cls_name] + seg_classes[cls_name]
          seg_cls_arr = np.load(os.path.join(cls_arr_dir, '{}_{}.npy'.format(stem, cls_name)))
          best = None  # (dice, iou, pixel_acc) of the best-Dice candidate name
          for cls in all_classes:
            predict_path = os.path.join(predict_root_dir, '{}_{}.png'.format(stem, cls))
            if not os.path.isfile(predict_path):
              continue
            with Image.open(predict_path) as predict_img:
              predict_cls_arr = np.asarray(predict_img) / 255

            if predict_cls_arr.shape != seg_cls_arr.shape:
              print('>>>invalid prediction', predict_path, seg_cls_arr.shape, predict_cls_arr.shape)
              continue
            scores = _mask_scores(predict_cls_arr, seg_cls_arr)
            if best is None or scores[0] > best[0]:
              best = scores

          if best is None:
            # No usable prediction is scored as a complete miss. The previous -1
            # sentinel was averaged into the means and could push them below 0.
            print('>>>no valid prediction', predict_root_dir, stem, cls_name)
            best = (0.0, 0.0, 0.0)
          f1, iou, pixel_acc = best
          f1_res.append(f1)
          iou_res.append(iou)
          pixel_acc_res.append(pixel_acc)

      iou_res_[thres] += iou_res
      pixel_acc_res_[thres] += pixel_acc_res
      f1_res_[thres] += f1_res

      f1_mean = np.mean(f1_res)
      iou_mean = np.mean(iou_res)
      pixel_acc_mean = np.mean(pixel_acc_res)
      thres_domain.append(thres)
      iou_domain.append(iou_mean)
      pixel_acc_domain.append(pixel_acc_mean)
      f1_domain.append(f1_mean)
      print('>>>> thres: {}, dice: {:.4f}, iou: {:.4f}, pixel_acc: {:.4f}'.format(thres, f1_mean, iou_mean, pixel_acc_mean))

    iou_domain_auc = _auc_or_nan(thres_domain, iou_domain)
    pixel_acc_domain_auc = _auc_or_nan(thres_domain, pixel_acc_domain)
    f1_domain_auc = _auc_or_nan(thres_domain, f1_domain)
    print('>>> dice AUC: {:.4f}, iou AUC: {:.4f}, pixel_acc AUC: {:.4f}'.format(f1_domain_auc, iou_domain_auc, pixel_acc_domain_auc))

  f1_auc, f1_optim, f1_auc_over_optim = _summarize_metric(f1_res_, thres_list)
  iou_auc, iou_optim, iou_auc_over_optim = _summarize_metric(iou_res_, thres_list)
  pixel_auc, pixel_optim, pixel_auc_over_optim = _summarize_metric(pixel_acc_res_, thres_list)

  print('>>> dice AUC: {:.4f}, dic optimum: {:.4f}, dice AUC/optim: {:.4f}'.format(f1_auc, f1_optim, f1_auc_over_optim))
  print('>>> iou AUC: {:.4f}, iou optimum: {:.4f}, iou AUC/optim: {:.4f}'.format(iou_auc, iou_optim, iou_auc_over_optim))
  print('>>> pixel_acc AUC: {:.4f}, pixel_acc optimum: {:.4f}, pixel_acc AUC/optim: {:.4f}'.format(pixel_auc, pixel_optim, pixel_auc_over_optim))
  print('>>> analysis time: {:.2f}s'.format(time.time() - start))
  return f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim

In [None]:
from bayes_opt import BayesianOptimization
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration


# model_key = "stabilityai/stable-diffusion-2-1-base"
# model_key = "stabilityai/stable-diffusion-2-base"
# model_key = "stabilityai/stable-diffusion-2-1"
# model_key = "stabilityai/stable-diffusion-2"

# model_key = "CompVis/stable-diffusion-v1-1"
# model_key = "CompVis/stable-diffusion-v1-2"
# model_key = "CompVis/stable-diffusion-v1-3"
# model_key = "CompVis/stable-diffusion-v1-4"
# model_key = "runwayml/stable-diffusion-v1-5"

def black_box_function(t=100, map_weight1=0.3, map_weight2=0.5, map_weight3=0.1, map_weight4=0.1, alpha=8, beta=0.4, seed=3408, negative_token=False):
  """Run the full segmentation pipeline for one hyper-parameter setting and score it.

  The function loads the diffusion pipeline named by the notebook-global ``model_key``
  (it must be assigned before calling) and the BLIP captioning model. It then runs
  ``domain_test`` on every dataset and ``analysis`` on the masks written under a fresh
  ``results_*`` directory. It is called directly and is also the objective function
  for BayesianOptimization.

  Args:
    t, alpha, beta, seed, negative_token: forwarded unchanged to ``domain_test``.
    map_weight1..map_weight4: non-negative weights. They are normalised to sum to 1
      and passed to ``domain_test`` as ``weight``.

  Returns:
    Dice AUC over the threshold list, pooled over all datasets. This is the value
    the optimiser maximises.

  Raises:
    ValueError: if the four map weights do not sum to a positive value.
  """
  import os
  import shutil  # the notebook otherwise imports shutil only in a later cell

  VOC_label_map = {
    1:'aeroplane',
    2:'bicycle',
    3:'bird',
    4:'boat',
    5:'bottle',
    6:'bus',
    7:'car',
    8:'cat',
    9:'chair',
    10:'cow',
    11:'diningtable',
    12:'dog',
    13:'horse',
    14:'motorbike',
    15:'person',
    16:'pottedplant',
    17:'sheep',
    18:'sofa',
    19:'train',
    20:'tvmonitor'
  }

  Cityscape_label_map = {
    1: 'road', # flat
    2: 'person', # human
    3: 'building', # construction
    4: 'traffic light', # object
    5: 'vegetation', # nature
    6: 'car', # vehicle
    7: 'bus', # vehicle
    8: 'train', # vehicle
    9: 'motorcycle', # vehicle
    10: 'bicycle', #vehicle
  }

  Vaihingen_label_map = {
    1: 'building'
  }

  Kvasir_label_map = {
    1: 'tumor'
  }

  # Validate and normalise the weights before loading any model, so bad input fails fast.
  map_weight_sum = map_weight1 + map_weight2 + map_weight3 + map_weight4
  if map_weight_sum <= 0:
    raise ValueError('map weights must sum to a positive value, got {}'.format(map_weight_sum))
  weight = [w / map_weight_sum for w in (map_weight1, map_weight2, map_weight3, map_weight4)]
  map_weight1_, map_weight2_, map_weight3_, map_weight4_ = weight

  datasets = ["VOC2012", "Cityscape", "Vaihingen", "Kvasir-SEG"]
  label_maps = [VOC_label_map, Cityscape_label_map, Vaihingen_label_map, Kvasir_label_map]
  thres_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  augmented_label_file = 'aug_label_blip_bert_0.9.json'
  # NOTE(review): root_dir does not include model_key, so consecutive calls for
  # different checkpoints with the same settings delete each other's masks.
  root_dir = "results_0.9_{}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{}".format(int(t), map_weight1_, map_weight2_, map_weight3_, map_weight4_, alpha, beta, negative_token)

  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
  # BLIP used to be pinned to "cuda:0", which crashed on CPU-only runtimes.
  blip_device = "cuda:0" if torch.cuda.is_available() else "cpu"

  ldm_stable = model = processor = None
  try:
    if model_key == "CompVis/ldm-text2im-large-256":
      ldm_stable = DiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.bfloat16).to(device)
    else:
      ldm_stable = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.bfloat16).to(device)

    ldm_stable.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler",
                                                       beta_start=0.00085, beta_end=0.012,
                                                       steps_offset=1)

    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(blip_device)
    print(model_key)

    if os.path.isdir(root_dir):
      shutil.rmtree(root_dir)
    os.mkdir(root_dir)
    for ds, label_map in zip(datasets, label_maps):
      images_dir = os.path.join(ds, "images")
      result_dir = os.path.join(root_dir, ds)
      domain_test(processor, model, ldm_stable, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, augmented_label=True, thres_list=thres_list, weight=weight, t=t, alpha=alpha, beta=beta, seed=seed, negative_token=negative_token)

    results_dir_list = [os.path.join(root_dir, ds) for ds in datasets]
    segmentations_dir_list = [os.path.join(ds, "segmentations") for ds in datasets]
    # analysis returns (dice_auc, dice_opt, iou_auc, iou_opt, pixel_auc, pixel_opt); only Dice AUC is the objective.
    f1_auc = analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=thres_list)[0]
    return f1_auc
  finally:
    # Release the large models even if a dataset run raises. Otherwise repeated calls
    # from the notebook or the optimiser keep accumulating GPU memory.
    del ldm_stable, model, processor
    torch.cuda.empty_cache()


def parameter_tuning(n_iter=5, random_state=None):
  """Tune black_box_function's continuous hyper-parameters with Bayesian optimisation.

  The two hand-picked configurations are queued with ``probe(..., lazy=True)`` and
  evaluated when ``maximize`` runs. After them, ``n_iter`` optimiser-suggested points
  are evaluated. ``seed`` and ``negative_token`` are not tuned, so they keep
  black_box_function's defaults.

  Args:
    n_iter: number of optimiser-guided evaluations after the initial probes.
    random_state: forwarded to BayesianOptimization to make its suggestions
      reproducible. None keeps the previous unseeded behaviour.

  Returns:
    The BayesianOptimization instance. ``optimizer.max`` holds the best target and
    parameters; previously these were only visible in the printed log.
  """
  pbounds = {
    't': (90, 110),
    'map_weight1': (0.001, 0.999),
    'map_weight2': (0.001, 0.999),
    'map_weight3': (0.001, 0.999),
    'map_weight4': (0.001, 0.999),
    'alpha': (1, 10),
    'beta': (0, 0.8),
  }
  optimizer = BayesianOptimization(
    f=black_box_function,
    pbounds=pbounds,
    random_state=random_state,
    verbose=2, # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
  )
  # Hand-picked starting configurations, evaluated before any optimiser suggestion.
  initial_points = [
    {'t': 92, 'map_weight1': 0.001, 'map_weight2': 0.999, 'map_weight3': 0.001, 'map_weight4': 0.001, 'alpha': 7.5, 'beta': 0.4},
    {'t': 92, 'map_weight1': 0.001, 'map_weight2': 0.999, 'map_weight3': 0.001, 'map_weight4': 0.001, 'alpha': 7.5, 'beta': 0.3},
  ]
  for point in initial_points:
    optimizer.probe(point, lazy=True)
  optimizer.maximize(
    init_points=0,
    n_iter=n_iter,
  )
  return optimizer

# parameter_tuning(n_iter=25)
import random

seed = 3871

# Evaluate every Stable Diffusion v1.x checkpoint with one shared configuration.
# black_box_function reads the notebook-global `model_key`, so the loop variable is
# deliberately that global name.
#
# Other sweeps previously run from this cell (same call, different arguments):
#   - v1-4 with alpha=8, beta=0.4 (negative_token default False)
#   - v1-5 with alpha=8, beta=0.4, negative_token=False
#   - stabilityai/stable-diffusion-2 and -2-1 with map_weight1..4 = 0.0, 1.0, 0.0, 0.0,
#     alpha=16, beta=0.5, negative_token=True
#   - two random seeds (seed = random.randint(0, 5000)), each running v1-1..v1-5 with
#     weights 0.3/0.5/0.1/0.1 and SD-2 / SD-2-1 with weights 0/1/0/0, all at alpha=8, beta=0.4
SD_V1_MODEL_KEYS = [
  "CompVis/stable-diffusion-v1-1",
  "CompVis/stable-diffusion-v1-2",
  "CompVis/stable-diffusion-v1-3",
  "CompVis/stable-diffusion-v1-4",
  "runwayml/stable-diffusion-v1-5",
]

# Keep every checkpoint's Dice AUC; previously only the last call's value was displayed.
f1_auc_by_model = {}
for model_key in SD_V1_MODEL_KEYS:
  f1_auc_by_model[model_key] = black_box_function(t=100, map_weight1=0.3, map_weight2=0.5, map_weight3=0.1, map_weight4=0.1, alpha=16, beta=0.5, seed=seed, negative_token=True)
f1_auc_by_model




model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors not found


Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/288 [00:00<?, ?B/s]

safety_checker/config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/731 [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/942 [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/926 [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/704 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

CompVis/stable-diffusion-v1-1
>>> seed: 3871




>>>>>>>>>> dataset: VOC2012, size: 500, test time: 472.23s
>>> seed: 3871
>>>>>>>>>> dataset: Cityscape, size: 50, test time: 366.62s
>>> seed: 3871
>>>>>>>>>> dataset: Vaihingen, size: 150, test time: 135.59s
>>> seed: 3871
>>>>>>>>>> dataset: Kvasir-SEG, size: 300, test time: 302.34s
>>>>  results_0.9_100_0.30_0.50_0.10_0.10_16.00_0.50_True/VOC2012
>>>> thres: 0.1, dice: 0.5788, iou: 0.4719, pixel_acc: 0.7368
>>>> thres: 0.2, dice: 0.6239, iou: 0.5192, pixel_acc: 0.7909
>>>> thres: 0.3, dice: 0.6465, iou: 0.5424, pixel_acc: 0.8172
>>>> thres: 0.4, dice: 0.6592, iou: 0.5553, pixel_acc: 0.8356
>>>> thres: 0.5, dice: 0.6680, iou: 0.5636, pixel_acc: 0.8496
>>>> thres: 0.6, dice: 0.6740, iou: 0.5678, pixel_acc: 0.8605
>>>> thres: 0.7, dice: 0.6754, iou: 0.5670, pixel_acc: 0.8686
>>>> thres: 0.8, dice: 0.6712, iou: 0.5589, pixel_acc: 0.8756
>>>> thres: 0.9, dice: 0.6503, iou: 0.5304, pixel_acc: 0.8784
>>> dice AUC: 0.5233, iou AUC: 0.4375, pixel_acc AUC: 0.6706
>>>>  results_0.9_100_0.30_0

In [None]:
import shutil
import os

# Remove every results* directory left behind by black_box_function runs.
for entry in os.listdir('./'):
  # Only directories: shutil.rmtree raises on a plain file whose name starts with 'results'.
  if entry.startswith('results') and os.path.isdir(entry):
    shutil.rmtree(entry)

In [None]:
!mkdir 'pdf'
import os
for f in os.listdir('./'):
  if f.endswith('.pdf'):
    shutil.copy(f, os.path.join('pdf', f))
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!zip -r pdf.zip pdf
from google.colab import files
files.download("pdf.zip")

  adding: pdf/ (stored 0%)
  adding: pdf/final_map_couch.pdf (deflated 87%)
  adding: pdf/photograph_2_8_sofa.pdf (deflated 88%)
  adding: pdf/table_7_64_couch.pdf (deflated 87%)
  adding: pdf/710.pdf (deflated 2%)
  adding: pdf/in_8_32_couch.pdf (deflated 88%)
  adding: pdf/in_8_8_sofa.pdf (deflated 89%)
  adding: pdf/living_10_8_couch.pdf (deflated 88%)
  adding: pdf/room_11_8_sofa.pdf (deflated 88%)
  adding: pdf/++_5_16_couch.pdf (deflated 89%)
  adding: pdf/in_8_64_sofa.pdf (deflated 88%)
  adding: pdf/columns_13_32_sofa.pdf (deflated 89%)
  adding: pdf/in_8_16_sofa.pdf (deflated 89%)
  adding: pdf/table_7_32_sofa.pdf (deflated 88%)
  adding: pdf/a_1_64_sofa.pdf (deflated 88%)
  adding: pdf/of_3_64_sofa.pdf (deflated 88%)
  adding: pdf/<|startoftext|>_0_32_couch.pdf (deflated 92%)
  adding: pdf/living_10_32_couch.pdf (deflated 88%)
  adding: pdf/room_11_16_sofa.pdf (deflated 88%)
  adding: pdf/++_5_16_sofa.pdf (deflated 88%)
  adding: pdf/<|endoftext|>_14_64_sofa.pdf (deflated 88%

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>