In [None]:
# Install the third-party packages used by the cells below (run once per environment).
!pip install diffusers["torch"] transformers accelerate
!pip install git+https://github.com/huggingface/diffusers
!pip install einops
!pip install bayesian-optimization
# `compel` is imported by the prompt-to-prompt utilities cell, so it has to be
# installed here too; otherwise that import fails on a fresh environment.
!pip install compel

In [None]:
# Unpack the four dataset archives; the paths are relative, so the zip files
# are expected to sit next to the notebook's working directory.
!unzip VOC2012.zip
!unzip Vaihingen.zip
!unzip Kvasir-SEG.zip
!unzip Cityscape.zip

In [3]:
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, DiffusionPipeline
import numpy as np
import abc

# Notebook-wide settings.
# LOW_RESOURCE is read by AttentionControl and forwarded as `low_resource` by generate_att.
LOW_RESOURCE = False
# Not referenced by the code visible in this notebook.
NUM_DIFFUSION_STEPS = 50
# Forwarded as `guidance_scale` by generate_att.
GUIDANCE_SCALE = 7.5
# Not referenced by the code visible in this notebook; presumably the tokenizer
# context length — TODO confirm.
MAX_NUM_WORDS = 77

# code for store attention
class AttentionControl(abc.ABC):
    """Base class for hooks that see every attention map produced by the UNet.

    An instance is called once per attention layer. It keeps a running layer
    counter (``cur_att_layer``) and, once all registered layers have been
    visited, advances ``cur_step`` and fires ``between_steps``. Subclasses
    implement ``forward`` to inspect or edit the attention tensor.
    """

    def __init__(self):
        self.cur_step = 0
        # Filled in by whoever registers the controller on a model.
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Rewind the step and layer counters."""
        self.cur_step = 0
        self.cur_att_layer = 0

    @property
    def num_uncond_att_layers(self):
        """Number of leading layer calls to skip (only non-zero in low-resource mode)."""
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    def step_callback(self, x_t):
        """Per-step latent hook; the default returns the latent unchanged."""
        return x_t

    def between_steps(self):
        """Hook fired after the last attention layer of a step; no-op by default."""
        return None

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        past_uncond_layers = self.cur_att_layer >= self.num_uncond_att_layers
        if past_uncond_layers and LOW_RESOURCE:
            attn = self.forward(attn, is_cross, place_in_unet)
        elif past_uncond_layers:
            # Only the second half of the leading dimension is handed to
            # `forward`; the result is written back in place.
            half = attn.shape[0] // 2
            attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            # Every layer has been visited: start the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

class EmptyControl(AttentionControl):
    """Controller that passes every attention map through untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity hook: nothing is stored and nothing is edited.
        return attn


class AttentionStore(AttentionControl):
    """Controller that records attention maps and accumulates them over steps.

    ``step_store`` collects the maps seen during the current step, keyed by
    ``"<place>_<cross|self>"``. ``attention_store`` holds the running sum over
    all completed steps.
    """

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def reset(self):
        """Rewind the counters and drop everything recorded so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per (place, kind) pair."""
        return {
            f"{place}_{kind}": []
            for kind in ("cross", "self")
            for place in ("down", "mid", "up")
        }

    @torch.no_grad()
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Maps with more than 64*64 query positions are skipped to limit memory use.
        if attn.shape[1] <= 64 ** 2:
            kind = "cross" if is_cross else "self"
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        """Fold the maps of the finished step into the running sum."""
        if not self.attention_store:
            # First completed step: adopt its maps as the accumulator.
            self.attention_store = self.step_store
        else:
            for key, accumulated in self.attention_store.items():
                for idx in range(len(accumulated)):
                    accumulated[idx] += self.step_store[key][idx]
        self.step_store = self.get_empty_store()

    @torch.no_grad()
    def get_average_attention(self):
        """Return the accumulated maps divided by the number of completed steps."""
        return {
            key: [summed / self.cur_step for summed in maps]
            for key, maps in self.attention_store.items()
        }

In [4]:
from PIL import Image
import cv2

## Visualization code utils
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile images into a white-backed grid and display it in the notebook.

    Args:
        images: a list of HxWx3 arrays, a 4-D array (N, H, W, 3), or a single
            HxWx3 array.
        num_rows: number of grid rows.
        offset_ratio: gap between tiles, as a fraction of the tile height.

    The grid is padded with blank white tiles up to the next multiple of
    ``num_rows`` so that no image is dropped.
    """
    if type(images) is list:
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        images = [images]
        num_images = 1
    # Number of blank tiles needed to reach a multiple of num_rows. The earlier
    # `num_images % num_rows` under-padded (e.g. 4 images on 3 rows), which
    # silently discarded the trailing images.
    num_empty = -num_images % num_rows

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, _ = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    # White canvas; tiles are pasted row by row with `offset` pixels between them.
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            top, left = i * (h + offset), j * (w + offset)
            image_[top: top + h, left: left + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    # `display` is the IPython built-in available inside notebook cells.
    display(pil_img)


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of ``image`` with a white caption band holding ``text`` below it.

    The band is 20% of the image height and the text is centred horizontally.
    """
    height, width, channels = image.shape
    caption_band = int(height * .2)
    canvas = np.ones((height + caption_band, width, channels), dtype=np.uint8) * 255
    canvas[:height] = image

    font = cv2.FONT_HERSHEY_SIMPLEX
    text_width, text_height = cv2.getTextSize(text, font, 1, 2)[0]
    origin = ((width - text_width) // 2, height + caption_band - text_height // 2)
    cv2.putText(canvas, text, origin, font, 1, text_color, 2)
    return canvas


In [5]:
# code for aggregaring attention
@torch.no_grad()
def aggregate_all_attention(prompts, attention_store: AttentionStore, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps per resolution (8, 16, 32 and 64).

    Args:
        prompts: prompt list; its length is used to split the batch dimension.
        attention_store: store providing ``get_average_attention()``.
        from_where: UNet places to read ("down", "mid", "up").
        is_cross: read cross-attention maps when True, self-attention otherwise.
        select: index of the prompt whose maps are kept.

    Returns:
        A list of four CPU tensors, one per resolution in the order 8, 16, 32, 64.
    """
    averaged = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    resolutions = (8, 16, 32, 64)
    by_resolution = {res: [] for res in resolutions}

    for location in from_where:
        for item in averaged[f"{location}_{kind}"]:
            for res in resolutions:
                if item.shape[1] == res * res:
                    per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
                    by_resolution[res].append(per_prompt[select])

    atts = []
    for res in resolutions:
        stacked = torch.cat(by_resolution[res], dim=0)
        # Mean over every collected head/layer entry at this resolution.
        atts.append((stacked.sum(0) / stacked.shape[0]).cpu())

    del averaged
    torch.cuda.empty_cache()
    return atts


@torch.no_grad()
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of one resolution for one prompt.

    Only maps whose second dimension equals ``res ** 2`` are used. The result
    is returned on the CPU.
    """
    averaged = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    num_pixels = res ** 2
    selected = [
        item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
        for location in from_where
        for item in averaged[f"{location}_{kind}"]
        if item.shape[1] == num_pixels
    ]
    stacked = torch.cat(selected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()


# visualize cross att
def show_cross_attention(prompts, tokenizer, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one captioned cross-attention heat map per prompt token.

    Args:
        prompts: list of prompts; ``prompts[select]`` is tokenised for the captions.
        tokenizer: object providing ``encode`` and ``decode``.
        attention_store: store the averaged attention is read from.
        res: attention resolution to visualise.
        from_where: UNet places to aggregate over.
        select: index of the prompt to visualise.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    # `i` indexes the attention channel, `j` indexes the token used as caption.
    j = 0
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        # Scale so the strongest response of this channel maps to 255.
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        # image = image.numpy().astype(np.uint8)
        # Cast to float32 first before converting to NumPy.
        image = image.float().numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        # A token that decodes to "++" is skipped for captioning only.
        # NOTE(review): after a skip the caption index `j` runs ahead of the
        # channel index `i`; confirm this pairing is intended.
        if decoder(int(tokens[j])) == "++":
            j += 1
        image = text_under_image(image, decoder(int(tokens[j])))
        images.append(image)
        j+=1
        if j >= len(tokens):
            break
    view_images(np.stack(images, axis=0))


# visualize self att
def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    """Display the leading right-singular vectors of the averaged self-attention.

    The (res^2 x res^2) self-attention matrix is row-centred and decomposed
    with an SVD. The first ``max_com`` rows of ``vh`` are rendered as 256x256
    grey-level tiles placed side by side.
    """
    averaged = aggregate_attention(prompts, attention_store, res, from_where, False, select)
    attention_matrix = averaged.float().numpy().reshape((res ** 2, res ** 2))
    centred = attention_matrix - np.mean(attention_matrix, axis=1, keepdims=True)
    _, _, vh = np.linalg.svd(centred)

    tiles = []
    for rank in range(max_com):
        component = vh[rank].reshape(res, res)
        # Stretch the component to the 0..255 range.
        component = component - component.min()
        component = 255 * component / component.max()
        grey_rgb = np.repeat(np.expand_dims(component, axis=2), 3, axis=2).astype(np.uint8)
        tiles.append(np.array(Image.fromarray(grey_rgb).resize((256, 256))))
    view_images(np.concatenate(tiles, axis=1))

In [6]:
def encode_imgs(imgs, vae):
    """Encode images into scaled VAE latents.

    Args:
        imgs: image batch, [B, 3, H, W]; rescaled here with ``2 * imgs - 1``.
        vae: object whose ``encode(x).latent_dist.mean`` gives the posterior mean.

    Returns:
        The posterior mean multiplied by 0.18215.
    """
    rescaled = 2 * imgs - 1
    posterior_mean = vae.encode(rescaled).latent_dist.mean
    return posterior_mean * 0.18215


# fix random seed
def same_seeds(seed):
    """Seed NumPy and PyTorch (CPU and, if present, CUDA) and make cuDNN deterministic."""
    # NumPy's global RNG.
    np.random.seed(seed)
    # PyTorch CPU RNG.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Current GPU, then every GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    # Disable the cuDNN autotuner and request deterministic kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True


# cam visual_code
def show_cam_on_image(img, mask):
    """Blend a JET heat map of ``mask`` over ``img`` and return a uint8 BGR array.

    Args:
        img: PIL image (its ``size`` is read as (width, height)); assumed to be
            in RGB channel order, as produced by ``Image.open(...).convert("RGB")``.
        mask: 2-D tensor of activations, expected in [0, 1]; it is resized to
            the image resolution.

    Returns:
        HxWx3 uint8 array in OpenCV (BGR) channel order. Reverse the last axis
        to obtain RGB for display with PIL.
    """
    mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(img.size[1], img.size[0]),
                         mode='bilinear', align_corners=False).squeeze().squeeze()
    img = np.float32(img) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        # cv2.applyColorMap yields BGR while the PIL image is RGB. Put the photo
        # in BGR as well so both layers share one channel order; previously the
        # caller's final channel flip fixed the heat map but swapped the photo's
        # red and blue channels.
        img = img[:, :, ::-1]
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    # Renormalise the sum back into [0, 1] before quantising.
    cam = cam / np.max(cam)
    cam = np.uint8(255 * cam)
    return cam

In [7]:
## ptp utils function
from einops import rearrange
from compel import Compel

def init_latent(latent, model, height, width, generator, batch_size):
    """Return ``(latent, latents)``: the base latent and its batch-expanded copy.

    When ``latent`` is None a single Gaussian latent of shape
    (1, in_channels, height // 8, width // 8) is drawn with ``generator``.
    ``latents`` is that latent expanded to ``batch_size`` and moved to
    ``model.device``.
    """
    channels = model.unet.config.in_channels
    latent_hw = (height // 8, width // 8)
    if latent is None:
        latent = torch.randn((1, channels, *latent_hw), generator=generator)
    batched = latent.expand(batch_size, channels, *latent_hw).to(model.device)
    return latent, batched


@torch.no_grad()
def register_attention_control(model, controller):
    """Route every UNet attention map through ``controller``.

    Each module under ``model.unet`` whose class is named ``Attention`` gets its
    ``forward`` replaced by a hand-written attention that calls
    ``controller(attn, is_cross, place_in_unet)`` on the softmaxed attention
    probabilities. When done, ``controller.num_att_layers`` is set to the
    number of patched modules.

    Args:
        model: pipeline exposing ``unet``.
        controller: callable controller, or None to use a pass-through.
    """
    def ca_forward(self, place_in_unet):
        # `self` is the attention module being patched.
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            # Only the first entry of the list is applied to the output.
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            # Note: `cross_attention_kwargs` is accepted but not used.
            x = hidden_states
            context = encoder_hidden_states
            mask = attention_mask
            batch_size = len(x)
            h = self.heads
            q = self.to_q(x)
            # Cross-attention iff a context is supplied; otherwise attend to x itself.
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            # Fold the heads into the batch dimension.
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                # Positions where the mask is False get the most negative value.
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            ## the controller sees (and may replace) the attention map here
            attn = controller(attn, is_cross, place_in_unet)

            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            # Unfold the heads back into the channel dimension.
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return to_out(out)

        return forward

    class DummyController:
        # Pass-through used when no controller is given.

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Depth-first walk: patch 'Attention' modules, recurse into everything else.
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        # The place label is derived from the child's name in the UNet.
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    # Counts every patched module, self- and cross-attention alike.
    controller.num_att_layers = cross_att_count


@torch.no_grad()
def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False, height=None, width=None):
    """Run the UNet once at timestep ``t`` and return ``latents`` unchanged.

    The noise prediction is discarded. ``controller``, ``guidance_scale``,
    ``height`` and ``width`` are accepted but not used.

    Args:
        context: pair of embeddings; indexed as ``context[0]`` and ``context[1]``
            in low-resource mode, concatenated otherwise.
        low_resource: when True, run two separate UNet calls instead of one
            call on a doubled batch.
    """
    unet = model.unet
    if not low_resource:
        doubled_latents = torch.cat((latents, latents))
        joined_context = torch.cat(context)
        unet(doubled_latents, t, encoder_hidden_states=joined_context, added_cond_kwargs={})["sample"]
    else:
        for idx in (0, 1):
            unet(latents, t, encoder_hidden_states=context[idx])["sample"]
    return latents


## text to image custom pipeline
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
    noise_sample_num=1,
    height: int = None,
    width: int = None
):
    """Run a single UNet pass on ``latent`` so ``controller`` records attention.

    Despite the name, no image is produced and the function always returns
    ``(None, None)``. It registers ``controller`` on the model, builds the
    unconditional and text embeddings, and calls ``diffusion_step`` once with
    ``num_inference_steps`` as the timestep argument.

    Args:
        model: pipeline exposing ``tokenizer``, ``unet``, ``scheduler``,
            ``device`` and either ``text_encoder`` or ``bert``.
        prompt: list of prompts; its length is the batch size.
        controller: attention controller to register.
        latent: starting latent; drawn at random when None.
        height, width: default to 512 when None.
        noise_sample_num: accepted but not used.
    """
    register_attention_control(model, controller)
    if height is None:
        height = 512
    if width is None:
        width = 512
    batch_size = len(prompt)

    def tokenize(texts, max_length):
        # Pad/truncate to a fixed length and return PyTorch tensors.
        return model.tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

    text_input = tokenize(prompt, model.tokenizer.model_max_length)
    model.unet.to(model.device)
    empty_prompts = [""] * batch_size

    if hasattr(model, 'text_encoder'):
        text_embeddings = model.text_encoder(text_input.input_ids.to(model.device), attention_mask=None)[0]
        uncond_input = tokenize(empty_prompts, model.tokenizer.model_max_length)
        uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device), attention_mask=None)[0]
    else:
        # Pipelines without a `text_encoder` expose `bert`; a fixed length of 77 is used.
        text_input = tokenize(prompt, 77)
        text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
        uncond_input = tokenize(empty_prompts, 77)
        uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]

    context = [uncond_embeddings, text_embeddings]
    latent, _ = init_latent(latent, model, height, width, generator, batch_size)
    # The un-expanded latent is what is fed to the UNet.
    latents = latent.squeeze(1).to(model.device)

    ## Sets the discrete timesteps used for the diffusion chain (to be run before inference).
    model.scheduler.set_timesteps(num_inference_steps)
    diffusion_step(model, controller, latents, context, num_inference_steps, guidance_scale, low_resource, height, width)

    return None, None


In [8]:
import torch.nn.functional as F

# main function
# Stable diffusion
def generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos, device, is_self=True, is_multi_self=False, is_cross_norm=True, weight=[0.3,0.5,0.1,0.1], height=None, width=None, verbose=False):
    """Build a 512x512 attention mask for the prompt tokens at ``pos``.

    The input latent is noised to timestep ``t``, one UNet pass is run with
    ``controller`` recording attention, and the cross-attention maps at
    resolutions 8/16/32/64 are blended with ``weight``. The result is
    optionally refined with the 64x64 self-attention, then upsampled and
    squashed into [0, 1].

    Args:
        t: diffusion timestep (cast to int).
        ldm_stable: diffusion pipeline (``scheduler`` and ``tokenizer`` are used here).
        input_latent: clean latent of the image.
        noise: noise tensor added to ``input_latent``.
        prompts: list of prompts; maps of the first prompt are used.
        controller: attention store; it is reset at the start.
        pos: indices of the target tokens in the prompt; their maps are averaged.
        device: device for the timestep tensor.
        is_self: refine the blended map with the 64x64 self-attention.
        is_multi_self: refine each resolution with its own self-attention
            instead (this disables the ``is_self`` refinement).
        is_cross_norm: min-max normalise every token map per resolution.
        weight: blend weights for the 8, 16, 32 and 64 resolutions.
        height, width: forwarded to the sampling helper.
        verbose: display the mask and the per-resolution cross-attention maps.

    Returns:
        A 512x512 float tensor, min-max normalised to [0, 1].
    """
    controller.reset()
    # Seeded CPU generator. The seed used to be passed as torch.Generator's
    # positional `device` argument, which does not seed anything and asks for
    # a device with that index.
    g_cpu = torch.Generator().manual_seed(4307)
    t = int(t)
    latents_noisy = ldm_stable.scheduler.add_noise(input_latent, noise, torch.tensor(t, device=device))
    # Called for its side effect only: the controller records the attention maps.
    text2image_ldm_stable(ldm_stable, prompts, controller, latent=latents_noisy, num_inference_steps=t, guidance_scale=GUIDANCE_SCALE, generator=g_cpu, low_resource=LOW_RESOURCE, height=height, width=width)
    layers = ("mid", "up", "down")
    # Cross-attention shapes per resolution: [8, 8, 77], [16, 16, 77], [32, 32, 77], [64, 64, 77]
    cross_attention_maps = aggregate_all_attention(prompts, controller, layers, True, 0)
    # Self-attention shapes per resolution: [8, 8, 64], [16, 16, 256], [32, 32, 1024], [64, 64, 4096]
    self_attention_maps = aggregate_all_attention(prompts, controller, ("up", "mid", "down"), False, 0)

    final_res = 64
    imgs = []
    ## res: resolution
    for idx, res in enumerate([8, 16, 32, 64]):
        # Put the token axis first.
        out_att = cross_attention_maps[idx].permute(2, 0, 1).float()
        if is_cross_norm:
            att_max = torch.amax(out_att, dim=(1, 2), keepdim=True)
            att_min = torch.amin(out_att, dim=(1, 2), keepdim=True)
            out_att = (out_att - att_min) / (att_max - att_min)
        if is_multi_self:
            self_att = self_attention_maps[idx].view(res * res, res * res).float()
            self_att = self_att / self_att.max()
            out_att = torch.matmul(self_att.unsqueeze(0), out_att.view(-1, res * res, 1)).view(-1, res, res)
        if res != final_res:
            out_att = F.interpolate(out_att.unsqueeze(0), size=(final_res, final_res), mode='bilinear', align_corners=False).squeeze()
        ## 8*8: 0.3, 16*16: 0.5, 32*32: 0.1, 64*64: 0.1
        imgs.append(out_att * weight[idx])

    # aggregated cross attention map, averaged over the target tokens
    cross_att_map = torch.stack(imgs).sum(0)[pos].mean(0).view(final_res * final_res, 1)
    # refine cross attention map with self attention map
    if is_self and not is_multi_self:
        self_att = self_attention_maps[3].view(final_res * final_res, final_res * final_res).float()
        self_att = self_att / self_att.max()
        cross_att_map = torch.matmul(self_att, cross_att_map)
    # Explicit resolution instead of relying on the loop variable left over above.
    att_map = cross_att_map.view(final_res, final_res)
    att_map = F.interpolate(att_map.unsqueeze(0).unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False).squeeze().squeeze()
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
    # Sharpen around 0.4, then stretch back to the full [0, 1] range.
    att_map = torch.sigmoid(8 * (att_map - 0.4))
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())

    del cross_att_map
    torch.cuda.empty_cache()

    if verbose:
        att_map_map = Image.fromarray((att_map.cpu().detach().numpy() * 255).astype(np.uint8), mode="L")

        display(att_map_map)
        tokenizer = ldm_stable.tokenizer
        print("8x8 cross att map")
        show_cross_attention(prompts, tokenizer, controller, res=8, from_where=layers)
        print("16x16 cross att map")
        show_cross_attention(prompts, tokenizer, controller, res=16, from_where=layers)
        print("32x32 cross att map")
        show_cross_attention(prompts, tokenizer, controller, res=32, from_where=layers)
        print("64x64 cross att map")
        show_cross_attention(prompts, tokenizer, controller, res=64, from_where=layers)

    return att_map

In [9]:
from torchvision import transforms

def stable_diffusion_inference(img_path, cls_name, device, blip_device, processor, model, ldm_stable, verbose=False, weight=[0.3,0.5,0.1,0.1], t=100):
  """Produce a soft segmentation mask for ``cls_name`` in one image.

  Args:
    img_path: path to the target image.
    cls_name: target class inserted into the prompt.
    device: device of the stable diffusion model.
    blip_device: device of the BLIP model.
    processor: BLIP processor.
    model: BLIP captioning model.
    ldm_stable: diffusion pipeline; its ``vae`` (or ``vqvae``) encodes the image.
    verbose: also display the attention maps and the heat-map overlay.
    weight: blend weights of the cross-attention resolutions.
    t: noising timestep handed to ``generate_att``.

  Returns:
    The mask resized to the original image resolution.
  """
  with torch.no_grad():
    same_seeds(3407)

    raw_image = Image.open(img_path).convert("RGB")
    to_tensor = transforms.Compose([transforms.ToTensor()])
    img_tensor = to_tensor(raw_image).unsqueeze(0).to(device)
    rgb_512 = F.interpolate(img_tensor, (512, 512), mode='bilinear', align_corners=False).bfloat16()

    vae = ldm_stable.vae if hasattr(ldm_stable, 'vae') else ldm_stable.vqvae
    input_latent = encode_imgs(rgb_512, vae)
    noise = torch.randn_like(input_latent).to(device)

    # BLIP continues the seed text; "++" is inserted right after the class name.
    text = f"a photograph of {cls_name}"
    inputs = processor(raw_image, text, return_tensors="pt").to(blip_device)
    generated = model.generate(**inputs)
    caption = processor.decode(generated[0], skip_special_tokens=True)
    prompts = [text + "++" + caption[len(text):]]

    # Position of the target class word in the prompt: in "a photograph of plane"
    # the word "plane" sits at index 4.
    pos = [4]
    controller = AttentionStore()

    mask = generate_att(t, ldm_stable, input_latent, noise, prompts, controller, pos, device,
                        is_self=True, is_multi_self=False, is_cross_norm=True,
                        weight=weight, height=512, width=512, verbose=verbose)
    # PIL reports (width, height); interpolate expects (height, width).
    original_hw = (raw_image.size[1], raw_image.size[0])
    mask = F.interpolate(mask[None, None], size=original_hw, mode='bilinear', align_corners=False).squeeze().squeeze()

    if verbose:
      cam = show_cam_on_image(raw_image, mask)
      print("visual_cam")
      display(Image.fromarray(cam[:,:,::-1]))

    del img_tensor, noise, inputs
    torch.cuda.empty_cache()
    return mask

In [10]:
import json
import time, os, shutil

def _save_class_masks(mask, result_dir, file_stem, cls_name, thres_list):
  """Write one soft mask as .npy plus one binarised PNG per threshold."""
  with open(os.path.join(result_dir, 'mask', '{}_{}.npy'.format(file_stem, cls_name)), 'wb') as f:
    np.save(f, mask)
  for mask_threshold in thres_list:
    mask_binary = np.where(mask > mask_threshold, 255, 0)
    mask_binary_img = Image.fromarray(mask_binary.astype(np.uint8))
    mask_binary_img.save(os.path.join(result_dir, '{}'.format(mask_threshold), '{}_{}.png'.format(file_stem, cls_name)))


def domain_test(processor, model, ldm_stable, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, augmented_label=False, thres_list=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], weight=[0.3,0.5,0.1,0.1], t = 100):
  """Generate and store masks for every labelled class of every image in a dataset.

  ``result_dir`` is wiped and recreated with a ``mask`` sub-directory (soft
  masks as .npy) and one sub-directory per threshold (binarised PNGs).

  Args:
    processor, model: BLIP processor and model.
    ldm_stable: diffusion pipeline.
    blip_device, device: devices of BLIP and of the diffusion model.
    images_dir: directory holding the images; the label JSON path is derived
      by replacing "images" in this path with ``augmented_label_file``.
    result_dir: output directory (deleted first if it exists).
    label_map: accepted but not used.
    augmented_label_file: JSON file name mapping image path -> {class: [augmented classes]}.
    augmented_label: also run every augmented class name of each class.
    thres_list: binarisation thresholds.
    weight, t: forwarded to ``stable_diffusion_inference``.
  """
  start = time.time()
  if os.path.isdir(result_dir):
    shutil.rmtree(result_dir)
  # makedirs also creates missing parents; exist_ok tolerates repeated thresholds.
  os.makedirs(os.path.join(result_dir, 'mask'))
  for thres in thres_list:
    os.makedirs(os.path.join(result_dir, '{}'.format(thres)), exist_ok=True)
  augmented_label_path = images_dir.replace("images", augmented_label_file)

  with open(augmented_label_path, 'r') as f:
    label_data = json.load(f)

  for img_file in os.listdir(images_dir):
    if not img_file.endswith(('.png', '.tif', '.jpg')):
      continue
    img_path = os.path.join(images_dir, img_file)
    seg_classes = label_data[img_path]
    # Same stem convention as the evaluation code: everything before the first dot.
    file_stem = img_file.split('.')[0]

    for cls_name in seg_classes.keys():
      # The class itself first, then (optionally) its augmented names.
      prompt_classes = [cls_name]
      if augmented_label:
        prompt_classes.extend(seg_classes[cls_name])
      for prompt_cls in prompt_classes:
        mask = stable_diffusion_inference(img_path, prompt_cls, device, blip_device, processor, model, ldm_stable, verbose=False, weight=weight, t=t)
        _save_class_masks(mask, result_dir, file_stem, prompt_cls, thres_list)

  print(">>>>>>>>>> test time: {:.2f}s".format(time.time() - start))


In [11]:
from PIL import Image
import numpy as np
import os, time
from sklearn.metrics import f1_score, roc_curve, auc
from numba import cuda

def analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]):
  """Score the thresholded predictions against the ground-truth class arrays.

  For every dataset, threshold, image and labelled class, the prediction of
  the class itself and of each augmented name is compared with the ground
  truth; the candidate with the best F1 supplies that class's dice, IoU and
  pixel accuracy. A class without any usable prediction contributes -1 to all
  three metrics.

  Args:
    results_dir_list: one prediction directory per dataset (with one
      sub-directory per threshold).
    segmentations_dir_list: matching ground-truth directories; the
      ``class_array``, ``images`` and label-JSON paths are derived by replacing
      "segmentations" in each path.
    augmented_label_file: JSON file name mapping image path -> {class: [augmented classes]}.
    thres_list: thresholds to evaluate.

  Returns:
    ``(f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim)``: area
    under each metric-vs-threshold curve and the best value over thresholds.
  """
  start = time.time()
  image_suffixes = ('.png', '.tif', '.jpg')

  # Per-threshold metric lists, pooled over all datasets.
  iou_res_ = {}
  pixel_acc_res_ = {}
  f1_res_ = {}
  for thres in thres_list:
    iou_res_[thres] = []
    pixel_acc_res_[thres] = []
    f1_res_[thres] = []

  for results_dir, segmentations_dir in zip(results_dir_list, segmentations_dir_list):
    cls_arr_dir = segmentations_dir.replace("segmentations", "class_array")
    images_dir = segmentations_dir.replace("segmentations", "images")
    augmented_label_path = segmentations_dir.replace("segmentations", augmented_label_file)
    print('>>>> ', results_dir)
    with open(augmented_label_path, 'r') as f:
      label_data = json.load(f)

    for thres in thres_list:
      predict_root_dir = os.path.join(results_dir, '{}'.format(thres))
      if not os.path.isdir(predict_root_dir):
        continue
      iou_res = []
      pixel_acc_res = []
      f1_res = []

      for seg_file in os.listdir(segmentations_dir):
        if not seg_file.endswith(image_suffixes):
          continue
        img_path = os.path.join(images_dir, seg_file)
        seg_classes = label_data[img_path]
        file_stem = seg_file.split('.')[0]

        for cls_name in seg_classes.keys():
          all_classes = [cls_name] + seg_classes[cls_name]
          seg_cls_arr = np.load(os.path.join(cls_arr_dir, '{}_{}.npy'.format(file_stem, cls_name)))
          # -1 marks "no usable prediction" for this class.
          iou = -1
          pixel_acc = -1
          f1 = -1
          for cls in all_classes:
            predict_path = os.path.join(predict_root_dir, '{}_{}.png'.format(file_stem, cls))
            if os.path.isfile(predict_path):
              # Close the file handle right away; this loop opens many images.
              with Image.open(predict_path) as predict_img:
                predict_cls_arr = np.asarray(predict_img) / 255

              if predict_cls_arr.shape != seg_cls_arr.shape:
                print('>>>invalid prediction', predict_path, seg_cls_arr.shape, predict_cls_arr.shape)
                continue
              intersection = np.sum(predict_cls_arr * seg_cls_arr).astype(np.float32)
              union = np.sum(np.logical_or(predict_cls_arr, seg_cls_arr)).astype(np.float32)
              correct = np.sum(predict_cls_arr == seg_cls_arr).astype(np.float32)

              # NOTE(review): yields nan when both masks are empty (union == 0).
              iou_ = intersection / union
              pixel_acc_ = correct / (seg_cls_arr.shape[0] * seg_cls_arr.shape[1])
              f1_ = f1_score(seg_cls_arr.flatten(), predict_cls_arr.flatten())

              # Keep the metrics of the candidate with the best F1.
              if f1_ > f1:
                f1 = f1_
                pixel_acc = pixel_acc_
                iou = iou_
          iou_res.append(iou)
          pixel_acc_res.append(pixel_acc)
          f1_res.append(f1)

      iou_res_[thres] += iou_res
      pixel_acc_res_[thres] += pixel_acc_res
      f1_res_[thres] += f1_res

      f1_mean = np.array(f1_res).mean()
      iou_mean = np.array(iou_res).mean()
      pixel_acc_mean = np.array(pixel_acc_res).mean()
      print('>>>> thres: {}, dice: {:.4f}, iou: {:.4f}, pixel_acc: {:.4f}'.format(thres, f1_mean, iou_mean, pixel_acc_mean))

  # Mean of each metric per threshold, pooled over all datasets.
  iou_res_summary = []
  pixel_res_summary = []
  f1_res_summary = []
  for thres in thres_list:
    iou_res_summary.append(np.array(iou_res_[thres]).mean())
    pixel_res_summary.append(np.array(pixel_acc_res_[thres]).mean())
    f1_res_summary.append(np.array(f1_res_[thres]).mean())

  thresholds = np.asarray(thres_list)

  iou_res_summary = np.asarray(iou_res_summary)
  iou_auc = auc(thresholds, iou_res_summary)
  iou_optim = iou_res_summary.max()
  iou_auc_over_optim = iou_auc / iou_optim

  pixel_res_summary = np.asarray(pixel_res_summary)
  pixel_auc = auc(thresholds, pixel_res_summary)
  pixel_optim = pixel_res_summary.max()
  pixel_auc_over_optim = pixel_auc / pixel_optim

  f1_res_summary = np.asarray(f1_res_summary)
  f1_auc = auc(thresholds, f1_res_summary)
  f1_optim = f1_res_summary.max()
  f1_auc_over_optim = f1_auc / f1_optim

  print('>>> dice AUC: {:.4f}, dic optimum: {:.4f}, dice AUC/optim: {:.4f}'.format(f1_auc, f1_optim, f1_auc_over_optim))
  print('>>> iou AUC: {:.4f}, iou optimum: {:.4f}, iou AUC/optim: {:.4f}'.format(iou_auc, iou_optim, iou_auc_over_optim))
  print('>>> pixel_acc AUC: {:.4f}, pixel_acc optimum: {:.4f}, pixel_acc AUC/optim: {:.4f}'.format(pixel_auc, pixel_optim, pixel_auc_over_optim))
  print('>>> analysis time: {:.2f}s'.format(time.time() - start))
  return f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim

In [None]:
from bayes_opt import BayesianOptimization
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
# import ray

# @ray.remote(num_gpus=1)
def black_box_function(t = 100, map_weight1=0.3, map_weight2=0.5, map_weight3=0.1, map_weight4=0.1):
  """Objective for the hyper-parameter search: run all datasets and return the dice AUC.

  The four map weights are normalised to sum to one. Every dataset is
  processed with ``domain_test`` and scored with ``analysis``. Results are
  written under a directory named after ``t`` and the normalised weights,
  which is deleted first if it already exists.

  Args:
    t: noising timestep (used as ``int(t)`` in the directory name).
    map_weight1..map_weight4: unnormalised blend weights for the 8, 16, 32 and
      64 cross-attention resolutions.

  Returns:
    The dice (F1) area under the threshold curve.
  """
  VOC_label_map = {
    1:'aeroplane',
    2:'bicycle',
    3:'bird',
    4:'boat',
    5:'bottle',
    6:'bus',
    7:'car',
    8:'cat',
    9:'chair',
    10:'cow',
    11:'diningtable',
    12:'dog',
    13:'horse',
    14:'motorbike',
    15:'person',
    16:'pottedplant',
    17:'sheep',
    18:'sofa',
    19:'train',
    20:'tvmonitor'
  }

  Cityscape_label_map = {
    1: 'road', # flat
    2: 'person', # human
    3: 'building', # construction
    4: 'traffic light', # object
    5: 'vegetation', # nature
    6: 'car', # vehicle
    7: 'bus', # vehicle
    8: 'train', # vehicle
    9: 'motorcycle', # vehicle
    10: 'bicycle', #vehicle
  }

  Vaihingen_label_map = {
    1: 'building'
  }

  Kvasir_label_map = {
    1: 'tumor'
  }

  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  # Other checkpoints that were tried with this notebook:
  #   runwayml/stable-diffusion-v1-5
  #   stabilityai/stable-diffusion-2-1-base, stabilityai/stable-diffusion-2-base,
  #   stabilityai/stable-diffusion-2
  #   CompVis/stable-diffusion-v1-1 ... v1-4
  #   CompVis/ldm-text2im-large-256, stabilityai/sdxl-turbo
  model_key = "stabilityai/stable-diffusion-2-1"

  ldm_stable = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.bfloat16).to(device)
  # ldm_stable = DiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.bfloat16).to(device)

  ldm_stable.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler",
                                                     beta_start=0.00085,beta_end=0.012,
                                                     steps_offset=1)

  # BLIP shares the diffusion model's device. The previous hard-coded "cuda:0"
  # failed on machines where `device` had fallen back to the CPU.
  blip_device = device
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(blip_device)

  datasets = ["VOC2012", "Cityscape", "Vaihingen", "Kvasir-SEG"]
  label_maps = [VOC_label_map, Cityscape_label_map, Vaihingen_label_map, Kvasir_label_map]

  raw_weights = [map_weight1, map_weight2, map_weight3, map_weight4]
  map_weight_sum = sum(raw_weights)
  if map_weight_sum > 0:
    weight = [w / map_weight_sum for w in raw_weights]
  else:
    # All-zero weights cannot be normalised; fall back to a uniform blend
    # instead of raising ZeroDivisionError in the middle of a search.
    weight = [1.0 / len(raw_weights)] * len(raw_weights)

  root_dir = "results_0.9_{}_{:.2f}_{:.2f}_{:.2f}_{:.2f}".format(int(t), *weight)
  thres_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  # thres_list = [0.2, 0.4, 0.6, 0.8]
  if os.path.isdir(root_dir):
    shutil.rmtree(root_dir)
  os.mkdir(root_dir)
  augmented_label_file = 'aug_label_blip_bert_0.9.json'
  for ds, label_map in zip(datasets, label_maps):
    images_dir = os.path.join(ds, "images")
    result_dir = os.path.join(root_dir, ds)
    domain_test(processor, model, ldm_stable, blip_device, device, images_dir, result_dir, label_map, augmented_label_file, augmented_label=True, thres_list=thres_list, weight=weight, t=t)

  results_dir_list = [os.path.join(root_dir, ds) for ds in datasets]
  segmentations_dir_list = [os.path.join(ds, "segmentations") for ds in datasets]
  f1_auc, f1_optim, iou_auc, iou_optim, pixel_auc, pixel_optim = analysis(results_dir_list, segmentations_dir_list, augmented_label_file, thres_list=thres_list)

  # Drop the models before the next evaluation reloads them.
  del ldm_stable
  del model
  del processor
  torch.cuda.empty_cache()

  return f1_auc


def parameter_tuning(init_points: int = 0, n_iter: int = 9, random_state: int = 1):
  """Search for the best `t` and map weights with Bayesian optimization.

  The optimizer maximizes the scalar returned by `black_box_function` over the
  box constraints declared in `pbounds` below.

  Args:
    init_points: Number of random exploration steps requested before the
      guided search (forwarded to `optimizer.maximize`).
    n_iter: Number of Bayesian-optimization steps (forwarded to
      `optimizer.maximize`).
    random_state: Seed forwarded to `BayesianOptimization`.

  Returns:
    The `BayesianOptimization` instance after `maximize` has run, so the caller
    can inspect the outcome (e.g. `optimizer.max`) instead of relying on the
    printed log. Calling with no arguments uses the original settings.
  """
  # Search space: one (lower, upper) pair per keyword argument of the objective.
  # NOTE(review): the four map weights look like raw, un-normalized values that
  # the objective rescales to sum to 1 -- confirm against black_box_function.
  pbounds = {
    't': (90, 110),
    'map_weight1': (0.0, 1.0),
    'map_weight2': (0.0, 1.0),
    'map_weight3': (0.0, 1.0),
    'map_weight4': (0.0, 1.0),
  }
  optimizer = BayesianOptimization(
    f=black_box_function,
    pbounds=pbounds,
    verbose=2, # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
    random_state=random_state,
  )
  optimizer.maximize(
    init_points=init_points,
    n_iter=n_iter,
  )
  # Hand the optimizer back so the best parameters are not lost after the run.
  return optimizer

# Launch the Bayesian-optimization search over `t` and the four map weights.
parameter_tuning()

|   iter    |  target   | map_we... | map_we... | map_we... | map_we... |     t     |
-------------------------------------------------------------------------------------


model_index.json:   0%|          | 0.00/537 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/633 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/824 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/939 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 123.23s
>>>>>>>>>> test time: 62.07s
>>>>>>>>>> test time: 10.59s
>>>>>>>>>> test time: 18.82s
>>>>  results_0.9_92_0.29_0.50_0.00_0.21/VOC2012
>>>> thres: 0.1, dice: 0.4163, iou: 0.3020, pixel_acc: 0.5125
>>>> thres: 0.2, dice: 0.5079, iou: 0.3879, pixel_acc: 0.6684
>>>> thres: 0.3, dice: 0.5765, iou: 0.4574, pixel_acc: 0.7568
>>>> thres: 0.4, dice: 0.6268, iou: 0.5111, pixel_acc: 0.8113
>>>> thres: 0.5, dice: 0.6558, iou: 0.5406, pixel_acc: 0.8429
>>>> thres: 0.6, dice: 0.6659, iou: 0.5489, pixel_acc: 0.8619
>>>> thres: 0.7, dice: 0.6558, iou: 0.5344, pixel_acc: 0.8694
>>>> thres: 0.8, dice: 0.6202, iou: 0.4934, pixel_acc: 0.8670
>>>> thres: 0.9, dice: 0.5102, iou: 0.3836, pixel_acc: 0.8455
>>>>  results_0.9_92_0.29_0.50_0.00_0.21/Cityscape
>>>> thres: 0.1, dice: 0.2217, iou: 0.1422, pixel_acc: 0.2427
>>>> thres: 0.2, dice: 0.2310, iou: 0.1496, pixel_acc: 0.3545
>>>> thres: 0.3, dice: 0.2431, iou: 0.1624, pixel_acc: 0.4584
>>>> thres: 0.4, dice: 0.2509, iou: 0.1

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 122.67s
>>>>>>>>>> test time: 62.74s
>>>>>>>>>> test time: 10.55s
>>>>>>>>>> test time: 18.97s
>>>>  results_0.9_105_0.32_0.30_0.12_0.26/VOC2012
>>>> thres: 0.1, dice: 0.3911, iou: 0.2798, pixel_acc: 0.4559
>>>> thres: 0.2, dice: 0.4652, iou: 0.3485, pixel_acc: 0.6044
>>>> thres: 0.3, dice: 0.5262, iou: 0.4071, pixel_acc: 0.6955
>>>> thres: 0.4, dice: 0.5723, iou: 0.4525, pixel_acc: 0.7565
>>>> thres: 0.5, dice: 0.6076, iou: 0.4876, pixel_acc: 0.8026
>>>> thres: 0.6, dice: 0.6274, iou: 0.5066, pixel_acc: 0.8320
>>>> thres: 0.7, dice: 0.6289, iou: 0.5045, pixel_acc: 0.8508
>>>> thres: 0.8, dice: 0.6007, iou: 0.4723, pixel_acc: 0.8563
>>>> thres: 0.9, dice: 0.4921, iou: 0.3655, pixel_acc: 0.8407
>>>>  results_0.9_105_0.32_0.30_0.12_0.26/Cityscape
>>>> thres: 0.1, dice: 0.2238, iou: 0.1438, pixel_acc: 0.2386
>>>> thres: 0.2, dice: 0.2325, iou: 0.1501, pixel_acc: 0.3391
>>>> thres: 0.3, dice: 0.2369, iou: 0.1544, pixel_acc: 0.4380
>>>> thres: 0.4, dice: 0.2440, iou: 0

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 124.96s
>>>>>>>>>> test time: 63.48s
>>>>>>>>>> test time: 10.64s
>>>>>>>>>> test time: 19.14s
>>>>  results_0.9_99_0.14_0.22_0.28_0.36/VOC2012
>>>> thres: 0.1, dice: 0.4224, iou: 0.3080, pixel_acc: 0.5248
>>>> thres: 0.2, dice: 0.5092, iou: 0.3922, pixel_acc: 0.6655
>>>> thres: 0.3, dice: 0.5687, iou: 0.4498, pixel_acc: 0.7432
>>>> thres: 0.4, dice: 0.6028, iou: 0.4839, pixel_acc: 0.7875
>>>> thres: 0.5, dice: 0.6168, iou: 0.4959, pixel_acc: 0.8142
>>>> thres: 0.6, dice: 0.6148, iou: 0.4918, pixel_acc: 0.8311
>>>> thres: 0.7, dice: 0.5994, iou: 0.4750, pixel_acc: 0.8398
>>>> thres: 0.8, dice: 0.5589, iou: 0.4351, pixel_acc: 0.8438
>>>> thres: 0.9, dice: 0.4434, iou: 0.3262, pixel_acc: 0.8321
>>>>  results_0.9_99_0.14_0.22_0.28_0.36/Cityscape
>>>> thres: 0.1, dice: 0.2312, iou: 0.1501, pixel_acc: 0.2648
>>>> thres: 0.2, dice: 0.2458, iou: 0.1621, pixel_acc: 0.3813
>>>> thres: 0.3, dice: 0.2537, iou: 0.1701, pixel_acc: 0.4673
>>>> thres: 0.4, dice: 0.2576, iou: 0.1

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 123.19s
>>>>>>>>>> test time: 61.36s
>>>>>>>>>> test time: 10.17s
>>>>>>>>>> test time: 18.10s
>>>>  results_0.9_91_0.19_0.47_0.16_0.19/VOC2012
>>>> thres: 0.1, dice: 0.4532, iou: 0.3345, pixel_acc: 0.5852
>>>> thres: 0.2, dice: 0.5561, iou: 0.4365, pixel_acc: 0.7257
>>>> thres: 0.3, dice: 0.6204, iou: 0.5058, pixel_acc: 0.7943
>>>> thres: 0.4, dice: 0.6558, iou: 0.5436, pixel_acc: 0.8324
>>>> thres: 0.5, dice: 0.6724, iou: 0.5600, pixel_acc: 0.8543
>>>> thres: 0.6, dice: 0.6752, iou: 0.5607, pixel_acc: 0.8678
>>>> thres: 0.7, dice: 0.6615, iou: 0.5417, pixel_acc: 0.8725
>>>> thres: 0.8, dice: 0.6223, iou: 0.4957, pixel_acc: 0.8681
>>>> thres: 0.9, dice: 0.5141, iou: 0.3880, pixel_acc: 0.8475
>>>>  results_0.9_91_0.19_0.47_0.16_0.19/Cityscape
>>>> thres: 0.1, dice: 0.2246, iou: 0.1447, pixel_acc: 0.2621
>>>> thres: 0.2, dice: 0.2424, iou: 0.1603, pixel_acc: 0.3946
>>>> thres: 0.3, dice: 0.2534, iou: 0.1723, pixel_acc: 0.5061
>>>> thres: 0.4, dice: 0.2609, iou: 0.1

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 124.78s
>>>>>>>>>> test time: 63.88s
>>>>>>>>>> test time: 10.72s
>>>>>>>>>> test time: 19.28s
>>>>  results_0.9_90_0.29_0.21_0.17_0.33/VOC2012
>>>> thres: 0.1, dice: 0.3844, iou: 0.2735, pixel_acc: 0.4395
>>>> thres: 0.2, dice: 0.4530, iou: 0.3370, pixel_acc: 0.5815
>>>> thres: 0.3, dice: 0.5081, iou: 0.3881, pixel_acc: 0.6732
>>>> thres: 0.4, dice: 0.5494, iou: 0.4280, pixel_acc: 0.7335
>>>> thres: 0.5, dice: 0.5788, iou: 0.4564, pixel_acc: 0.7786
>>>> thres: 0.6, dice: 0.5940, iou: 0.4685, pixel_acc: 0.8099
>>>> thres: 0.7, dice: 0.5896, iou: 0.4609, pixel_acc: 0.8315
>>>> thres: 0.8, dice: 0.5579, iou: 0.4289, pixel_acc: 0.8422
>>>> thres: 0.9, dice: 0.4373, iou: 0.3171, pixel_acc: 0.8301
>>>>  results_0.9_90_0.29_0.21_0.17_0.33/Cityscape
>>>> thres: 0.1, dice: 0.2245, iou: 0.1444, pixel_acc: 0.2423
>>>> thres: 0.2, dice: 0.2310, iou: 0.1489, pixel_acc: 0.3456
>>>> thres: 0.3, dice: 0.2368, iou: 0.1538, pixel_acc: 0.4474
>>>> thres: 0.4, dice: 0.2372, iou: 0.1

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 122.54s
>>>>>>>>>> test time: 61.37s
>>>>>>>>>> test time: 10.58s
>>>>>>>>>> test time: 18.87s
>>>>  results_0.9_92_0.00_1.00_0.00_0.00/VOC2012
>>>> thres: 0.1, dice: 0.5526, iou: 0.4302, pixel_acc: 0.7184
>>>> thres: 0.2, dice: 0.6332, iou: 0.5185, pixel_acc: 0.8085
>>>> thres: 0.3, dice: 0.6713, iou: 0.5610, pixel_acc: 0.8527
>>>> thres: 0.4, dice: 0.6903, iou: 0.5787, pixel_acc: 0.8739
>>>> thres: 0.5, dice: 0.6916, iou: 0.5787, pixel_acc: 0.8842
>>>> thres: 0.6, dice: 0.6843, iou: 0.5689, pixel_acc: 0.8898
>>>> thres: 0.7, dice: 0.6605, iou: 0.5402, pixel_acc: 0.8861
>>>> thres: 0.8, dice: 0.6053, iou: 0.4807, pixel_acc: 0.8735
>>>> thres: 0.9, dice: 0.4841, iou: 0.3591, pixel_acc: 0.8489
>>>>  results_0.9_92_0.00_1.00_0.00_0.00/Cityscape
>>>> thres: 0.1, dice: 0.2356, iou: 0.1554, pixel_acc: 0.3482
>>>> thres: 0.2, dice: 0.2493, iou: 0.1688, pixel_acc: 0.4629
>>>> thres: 0.3, dice: 0.2633, iou: 0.1825, pixel_acc: 0.5377
>>>> thres: 0.4, dice: 0.2789, iou: 0.2

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 122.63s
>>>>>>>>>> test time: 61.90s
>>>>>>>>>> test time: 10.62s
>>>>>>>>>> test time: 18.87s
>>>>  results_0.9_92_0.00_0.56_0.44_0.00/VOC2012
>>>> thres: 0.1, dice: 0.5410, iou: 0.4214, pixel_acc: 0.6977
>>>> thres: 0.2, dice: 0.6256, iou: 0.5127, pixel_acc: 0.7928
>>>> thres: 0.3, dice: 0.6611, iou: 0.5522, pixel_acc: 0.8344
>>>> thres: 0.4, dice: 0.6791, iou: 0.5714, pixel_acc: 0.8602
>>>> thres: 0.5, dice: 0.6856, iou: 0.5755, pixel_acc: 0.8740
>>>> thres: 0.6, dice: 0.6776, iou: 0.5627, pixel_acc: 0.8784
>>>> thres: 0.7, dice: 0.6548, iou: 0.5342, pixel_acc: 0.8771
>>>> thres: 0.8, dice: 0.6033, iou: 0.4773, pixel_acc: 0.8663
>>>> thres: 0.9, dice: 0.4840, iou: 0.3602, pixel_acc: 0.8447
>>>>  results_0.9_92_0.00_0.56_0.44_0.00/Cityscape
>>>> thres: 0.1, dice: 0.2393, iou: 0.1576, pixel_acc: 0.3272
>>>> thres: 0.2, dice: 0.2594, iou: 0.1754, pixel_acc: 0.4526
>>>> thres: 0.3, dice: 0.2668, iou: 0.1846, pixel_acc: 0.5335
>>>> thres: 0.4, dice: 0.2799, iou: 0.1

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 122.69s
>>>>>>>>>> test time: 62.34s
>>>>>>>>>> test time: 10.62s
>>>>>>>>>> test time: 18.93s
>>>>  results_0.9_110_0.00_0.00_1.00_0.00/VOC2012
>>>> thres: 0.1, dice: 0.5478, iou: 0.4321, pixel_acc: 0.6987
>>>> thres: 0.2, dice: 0.5881, iou: 0.4746, pixel_acc: 0.7722
>>>> thres: 0.3, dice: 0.5959, iou: 0.4795, pixel_acc: 0.8077
>>>> thres: 0.4, dice: 0.5887, iou: 0.4694, pixel_acc: 0.8254
>>>> thres: 0.5, dice: 0.5645, iou: 0.4427, pixel_acc: 0.8336
>>>> thres: 0.6, dice: 0.5254, iou: 0.4040, pixel_acc: 0.8347
>>>> thres: 0.7, dice: 0.4713, iou: 0.3530, pixel_acc: 0.8309
>>>> thres: 0.8, dice: 0.3894, iou: 0.2774, pixel_acc: 0.8211
>>>> thres: 0.9, dice: 0.2570, iou: 0.1692, pixel_acc: 0.8053
>>>>  results_0.9_110_0.00_0.00_1.00_0.00/Cityscape
>>>> thres: 0.1, dice: 0.2487, iou: 0.1653, pixel_acc: 0.3835
>>>> thres: 0.2, dice: 0.2691, iou: 0.1822, pixel_acc: 0.5255
>>>> thres: 0.3, dice: 0.2829, iou: 0.1935, pixel_acc: 0.6140
>>>> thres: 0.4, dice: 0.2801, iou: 0

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



>>>>>>>>>> test time: 122.92s
>>>>>>>>>> test time: 62.19s
>>>>>>>>>> test time: 10.70s
>>>>>>>>>> test time: 19.09s
>>>>  results_0.9_95_0.00_0.50_0.50_0.00/VOC2012


In [2]:
# !unzip VOC2012_small.zip
import shutil
import os
# Remove result directories left over from earlier runs. Only real directories
# are removed: shutil.rmtree raises on a plain file or on a symlink, so a stray
# entry such as an archive whose name starts with 'results' would otherwise
# abort the cleanup part-way through.
for entry in os.listdir('./'):
  if not entry.startswith('results'):
    continue
  if os.path.isdir(entry) and not os.path.islink(entry):
    shutil.rmtree(entry)