In [None]:
!pip install diffusers["torch"] transformers accelerate
!pip install git+https://github.com/huggingface/diffusers
!pip install einops
!pip install compel

In [None]:
!unzip kvasir-seg.zip

In [None]:
import shutil, os, random
import numpy as np
from PIL import Image


def sample_set_kavasir(root_dir, target_dir, sample_size=50):
  if not os.path.isdir(root_dir):
    raise ValueError('invalid root_dir: {}'.format(root_dir))
  if os.path.isdir(target_dir):
    shutil.rmtree(target_dir)
  os.mkdir(target_dir)
  os.mkdir(os.path.join(target_dir, 'images'))
  os.mkdir(os.path.join(target_dir, 'segmentations'))

  image_dir = os.path.join(root_dir, 'images')
  mask_dir = os.path.join(root_dir, 'masks')
  image_dir_2 = os.path.join(target_dir, 'images')
  mask_dir_2 = os.path.join(target_dir, 'segmentations')
  file_list = os.listdir(image_dir)
  sample_list = random.sample(file_list, sample_size)
  for f in sample_list:
    source = os.path.join(image_dir, f)
    target = os.path.join(image_dir_2, f)
    shutil.copy(source, target)

    source = os.path.join(mask_dir, f)
    target = os.path.join(mask_dir_2, f)
    shutil.copy(source, target)


root_dir = 'Kvasir-SEG' # size 1000
target_dir = 'Kvasir-SEG-sample'
sample_size = 100 # samll: 30, large: 300

sample_set_kavasir(root_dir, target_dir, sample_size)
print('>>> sample size', len(os.listdir(os.path.join(target_dir, 'images'))))

>>> sample size 100


In [None]:
# import locale
# locale.getpreferredencoding = lambda: "UTF-8"
# !zip -r Kvasir-SEG-sample.zip Kvasir-SEG-sample
# from google.colab import files
# files.download("Kvasir-SEG-sample.zip")

import shutil
import os
shutil.rmtree('Kvasir-SEG')
os.rename('Kvasir-SEG-sample', 'Kvasir-SEG')

In [None]:
!jar xvf buildings_vaihingen.zip

In [None]:
import shutil, os, random
import numpy as np
from PIL import Image

def transform(i):
  res = str(i)
  if len(res) < 3:
    prefix = "0" if len(res) == 2 else "00"
    res = prefix + res

  return res

def sample_set_vaihingen(root_dir, target_dir, sample_size=50):
  if not os.path.isdir(root_dir):
    raise ValueError('invalid root_dir: {}'.format(root_dir))
  if os.path.isdir(target_dir):
    shutil.rmtree(target_dir)
  os.mkdir(target_dir)
  os.mkdir(os.path.join(target_dir, 'images'))
  os.mkdir(os.path.join(target_dir, 'segmentations'))

  index_range = range(1, 169)
  sample_index = random.sample(index_range, sample_size)
  for idx in sample_index:
    i = transform(idx)
    img_name = "building_{}.tif".format(i)
    mask_name = "all_buildings_mask_{}.tif".format(i)
    img_path = os.path.join(root_dir, img_name)
    mask_path = os.path.join(root_dir, mask_name)

    img_path_2 = os.path.join(target_dir, 'images', '{}.tif'.format(i))
    mask_path_2 = os.path.join(target_dir, 'segmentations', '{}.tif'.format(i))
    shutil.copy(img_path, img_path_2)
    shutil.copy(mask_path, mask_path_2)

root_dir = 'buildings'
target_dir = 'Vaihingen-sample' # size 168
sample_size = 100 # small: 20, large: 150

sample_set_vaihingen(root_dir, target_dir, sample_size)
print('>>> sample size', len(os.listdir(os.path.join(target_dir, 'images'))))

>>> sample size 100


In [None]:
import shutil
import os
shutil.rmtree('buildings')
os.rename('Vaihingen-sample', 'Vaihingen')

In [18]:
import shutil, os, random
import numpy as np
from PIL import Image
from torchvision.datasets import VOCSegmentation

voc_data = VOCSegmentation('./', download=True)
print('>>> all size', len(voc_data))
VOC_label_map = {
  1:'aeroplane',
  2:'bicycle',
  3:'bird',
  4:'boat',
  5:'bottle',
  6:'bus',
  7:'car',
  8:'cat',
  9:'chair',
  10:'cow',
  11:'diningtable',
  12:'dog',
  13:'horse',
  14:'motorbike',
  15:'person',
  16:'pottedplant',
  17:'sheep',
  18:'sofa',
  19:'train',
  20:'tvmonitor'
}

def sample_set_voc(voc_data, target_dir, sample_size_per_class=5):
  if os.path.isdir(target_dir):
    shutil.rmtree(target_dir)
  os.mkdir(target_dir)
  os.mkdir(os.path.join(target_dir, 'images'))
  os.mkdir(os.path.join(target_dir, 'segmentations'))

  freq_map = {idx: sample_size_per_class for idx in VOC_label_map.keys()}
  count = 0

  # Single-class samples
  print(">>> singl-class sampels")
  for idx in range(len(voc_data)):
    if len(freq_map) == 0:
      break
    img, seg = voc_data[idx]
    seg_arr = np.asarray(seg)
    label = np.unique(seg_arr)
    if len(label) != 3:
      continue
    label = label[1]
    if not label in freq_map.keys():
      continue
    freq = freq_map[label]
    if freq == 1:
      del freq_map[label]
    else:
      freq_map[label] -= 1
    print('save 1 {} >>>'.format(idx))
    img.save(os.path.join(target_dir, 'images', '{}.png'.format(idx)))
    seg.save(os.path.join(target_dir, 'segmentations', '{}.png'.format(idx)))

  if len(freq_map) != 0:
    print(freq_map)
  selected_list = os.listdir(os.path.join(target_dir, 'images'))
  print(">>> multi-class sampels")
  for idx in range(len(voc_data)):
    if len(freq_map) == 0:
      break
    file_name = '{}.png'.format(idx)
    if file_name in selected_list:
      continue
    img, seg = voc_data[idx]
    seg_arr = np.asarray(seg)
    label = np.unique(seg_arr)
    save = False
    for l in label:
      if l in freq_map.keys():
        save = True
        freq = freq_map[l]
        if freq == 1:
          del freq_map[l]
        else:
          freq_map[l] -= 1
    if save:
      print('save 2 {} >>>'.format(idx))
      img.save(os.path.join(target_dir, 'images', file_name))
      seg.save(os.path.join(target_dir, 'segmentations', file_name))

  sample_size = len(os.listdir(os.path.join(target_dir, 'images')))
  target_sample_size = len(VOC_label_map.keys()) * sample_size_per_class
  print(sample_size, target_sample_size)
  selected_list = os.listdir(os.path.join(target_dir, 'images'))
  if sample_size < target_sample_size:
    print(">>> the num per class constraint is unachievable")
    count = target_sample_size - sample_size
    for idx in range(len(voc_data)):
      if count == 0:
        return
      file_name = '{}.png'.format(idx)
      if file_name in selected_list:
        continue
      print('save 3 {} >>>'.format(idx))
      img, seg = voc_data[idx]
      img.save(os.path.join(target_dir, 'images', file_name))
      seg.save(os.path.join(target_dir, 'segmentations', file_name))
      count -= 1


target_dir = 'VOC-sample'
sample_set_voc(voc_data, target_dir, sample_size_per_class=25)# size 1464
print('>>> sample size', len(os.listdir(os.path.join(target_dir, 'images'))))

Using downloaded and verified file: ./VOCtrainval_11-May-2012.tar
Extracting ./VOCtrainval_11-May-2012.tar to ./
>>> all size 1464
>>> singl-class sampels
save 1 1 >>>
save 1 3 >>>
save 1 4 >>>
save 1 6 >>>
save 1 7 >>>
save 1 9 >>>
save 1 10 >>>
save 1 11 >>>
save 1 17 >>>
save 1 18 >>>
save 1 20 >>>
save 1 22 >>>
save 1 23 >>>
save 1 25 >>>
save 1 26 >>>
save 1 28 >>>
save 1 30 >>>
save 1 33 >>>
save 1 36 >>>
save 1 38 >>>
save 1 39 >>>
save 1 42 >>>
save 1 43 >>>
save 1 44 >>>
save 1 46 >>>
save 1 47 >>>
save 1 49 >>>
save 1 50 >>>
save 1 52 >>>
save 1 54 >>>
save 1 56 >>>
save 1 57 >>>
save 1 60 >>>
save 1 61 >>>
save 1 63 >>>
save 1 65 >>>
save 1 66 >>>
save 1 67 >>>
save 1 68 >>>
save 1 70 >>>
save 1 73 >>>
save 1 78 >>>
save 1 82 >>>
save 1 85 >>>
save 1 87 >>>
save 1 88 >>>
save 1 89 >>>
save 1 92 >>>
save 1 93 >>>
save 1 94 >>>
save 1 96 >>>
save 1 98 >>>
save 1 99 >>>
save 1 101 >>>
save 1 103 >>>
save 1 104 >>>
save 1 105 >>>
save 1 108 >>>
save 1 111 >>>
save 1 112 >>>
save

In [19]:
import shutil
import os
os.rename('VOC-sample', 'VOC2012')

In [None]:
# !mkdir dataset
# !wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=$myusername&password=$mypassword&submit=Login' https://www.cityscapes-dataset.com/login/
# !wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -P dataset
# !wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -P dataset

# !mkdir -p CAP_augmentation/data/cityscapes
# !unzip -q -o dataset/gtFine_trainvaltest.zip -d CAP_augmentation/data/cityscapes
# !unzip -q -o dataset/leftImg8bit_trainvaltest.zip -d CAP_augmentation/data/cityscapes

# !wget https://github.com/shubham0204/Dataset_Archives/blob/master/cityscape_images.zip?raw=true -O cityscape_images.zip
# !unzip cityscape_images.zip

!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=yj373@cornell.edu&password=Yj373950911!&submit=Login' https://www.cityscapes-dataset.com/login/
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3

--2024-05-18 03:33:46--  https://www.cityscapes-dataset.com/login/
Resolving www.cityscapes-dataset.com (www.cityscapes-dataset.com)... 139.19.217.8
Connecting to www.cityscapes-dataset.com (www.cityscapes-dataset.com)|139.19.217.8|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.cityscapes-dataset.com/downloads/ [following]
--2024-05-18 03:33:47--  https://www.cityscapes-dataset.com/downloads/
Reusing existing connection to www.cityscapes-dataset.com:443.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘index.html’

index.html              [  <=>               ]  57.34K   189KB/s    in 0.3s    

2024-05-18 03:33:48 (189 KB/s) - ‘index.html’ saved [58715]

--2024-05-18 03:33:48--  https://www.cityscapes-dataset.com/file-handling/?packageID=1
Resolving www.cityscapes-dataset.com (www.cityscapes-dataset.com)... 139.19.217.8
Connecting to www.cityscapes-dataset.com (www.cityscapes-dataset.com)|139.19.217.8

In [None]:
!mkdir cityscape
!unzip -q -o gtFine_trainvaltest.zip -d cityscape
!unzip -q -o leftImg8bit_trainvaltest.zip -d cityscape

In [None]:
import os
import shutil, random
import numpy as np
from PIL import Image
from torchvision.datasets import Cityscapes

Cityscape_label_color_map = {
    (128, 64, 128): 'road', # flat
    (220, 20, 60): 'person', # human
    (70, 70, 70): 'building', # construction
    (250, 70, 30): 'traffic light', # object
    (107,142, 35): 'vegetation', # nature
    (0, 0, 142): 'car', # vehicle
    (0, 60, 100): 'bus', # vehicle
    (0, 80, 100): 'train', # vehicle
    (0, 0, 230): 'motorcycle', # vehicle
    (119, 11, 32): 'bicycle', #vehicle
}
Cityscape_label_map = {
    1: 'road', # flat
    2: 'person', # human
    3: 'building', # construction
    4: 'traffic light', # object
    5: 'vegetation', # nature
    6: 'car', # vehicle
    7: 'bus', # vehicle
    8: 'train', # vehicle
    9: 'motorcycle', # vehicle
    10: 'bicycle', #vehicle
}
Cityscape_label_map_reverse = {
    'road': 1, # flat
    'person': 2, # human
    'building': 3, # construction
    'traffic light': 4, # object
    'vegetation': 5, # nature
    'car': 6, # vehicle
    'bus': 7, # vehicle
    'train': 8, # vehicle
    'motorcycle': 9, # vehicle
    'bicycle': 10, #vehicle
}


def sample_set_cityscape(cityscape_data, target_dir, sample_size=50):
  if os.path.isdir(target_dir):
    shutil.rmtree(target_dir)
  os.mkdir(target_dir)
  os.mkdir(os.path.join(target_dir, 'images'))
  os.mkdir(os.path.join(target_dir, 'original_segmentations'))

  sample_indices = random.sample(range(len(cityscape_data)), sample_size)
  for idx in sample_indices:
    img, col = cityscape_data[idx]
    img.save(os.path.join(target_dir, 'images', '{}.png'.format(idx)))
    col.save(os.path.join(target_dir, 'original_segmentations', '{}.png'.format(idx)))

cityscape_data = Cityscapes('./cityscape', split='train', mode='fine',
                     target_type='color')
target_dir = 'Cityscape'
# sample_set_cityscape(cityscape_data, target_dir, sample_size_per_class=5)
sample_set_cityscape(cityscape_data, target_dir, sample_size=10)
print('>>> sample size', len(os.listdir(os.path.join(target_dir, 'images'))))

>>> sample size 10


In [None]:
import os
import nltk
import numpy as np
import shutil
import torch
from PIL import Image

Cityscape_rgb_map = {
    (128, 64,128): 'road', # flat
    (220, 20, 60): 'person', # human
    (70, 70, 70): 'building', # construction
    (250, 70, 30): 'traffic light', # object
    (107,142, 35): 'vegetation', # nature
    (0, 0, 142): 'car', # vehicle
    (0, 60, 100): 'bus', # vehicle
    (0, 80, 100): 'train', # vehicle
    (0, 0, 230): 'motorcycle', # vehicle
    (119, 11, 32): 'bicycle', #vehicle
}
Cityscape_int_map = {
    (128, 64,128): 1, # flat
    (220, 20, 60): 2, # human
    (70, 70, 70): 3, # construction
    (250, 70, 30): 4, # object
    (107,142, 35): 5, # nature
    (0, 0, 142): 6, # vehicle
    (0, 60, 100): 7, # vehicle
    (0, 80, 100): 8, # vehicle
    (0, 0, 230): 9, # vehicle
    (119, 11, 32): 10, #vehicle
}
seg_voc = Image.open('VOC2012/segmentations/1.png')
trans_seg_dir = 'Cityscape/segmentations'

if os.path.isdir(trans_seg_dir):
  shutil.rmtree(trans_seg_dir)
os.mkdir(trans_seg_dir)
seg_dir = 'Cityscape/original_segmentations'

for file in os.listdir(seg_dir):
  if not file.endswith('.png'):
    continue
  seg = Image.open(os.path.join(seg_dir, file)).convert('RGB')
  # display(seg)
  seg_arr = np.asarray(seg)
  h, w = seg_arr.shape[0], seg_arr.shape[1]
  seg_arr_ = np.zeros((h, w), dtype=np.uint8)
  for i in range(h):
    for j in range(w):
      value = tuple(seg_arr[i, j, :])
      if value in Cityscape_int_map.keys():
        seg_arr_[i, j] = Cityscape_int_map[value]
  # print('>>> ', file)
  trans = Image.fromarray(seg_arr_, mode='P')
  trans.putpalette(seg_voc.palette)
  trans.save(os.path.join(trans_seg_dir, file))
  # display(trans)
  # break

In [20]:
import os
import shutil
import numpy as np

from PIL import Image

Cityscape_label_map = {
    1: 'road', # flat
    2: 'person', # human
    3: 'building', # construction
    4: 'traffic light', # object
    5: 'vegetation', # nature
    6: 'car', # vehicle
    7: 'bus', # vehicle
    8: 'train', # vehicle
    9: 'motorcycle', # vehicle
    10: 'bicycle', #vehicle
}

VOC_label_map = {
  1:'aeroplane',
  2:'bicycle',
  3:'bird',
  4:'boat',
  5:'bottle',
  6:'bus',
  7:'car',
  8:'cat',
  9:'chair',
  10:'cow',
  11:'diningtable',
  12:'dog',
  13:'horse',
  14:'motorbike',
  15:'person',
  16:'pottedplant',
  17:'sheep',
  18:'sofa',
  19:'train',
  20:'tvmonitor'
}

Vaihingen_label_map = {
  1: 'building'
}

Kvasir_label_map = {
  1: 'tumor'
}

def generate_cls_seg(segmentation_dir, label_map, dataset="VOC2012"):
  cls_arr_dir = segmentation_dir.replace('segmentations', 'class_array')
  if os.path.isdir(cls_arr_dir):
    shutil.rmtree(cls_arr_dir)
  os.mkdir(cls_arr_dir)
  print('>>> ', segmentation_dir)
  for seg_file in os.listdir(segmentation_dir):
      seg_path = os.path.join(segmentation_dir, seg_file)

      if dataset == "Kvasir-SEG":
        seg = Image.open(seg_path).convert('L')
        seg_arr = np.asarray(seg)
        seg_arr = np.where(seg_arr > 200, 1, 0)
      else:
        seg = Image.open(seg_path)
        seg_arr = np.asarray(seg)
      # print(seg_arr.shape)
      seg_classes = np.unique(seg_arr)

      for cls in seg_classes:
        if cls in label_map.keys():
          seg_cls_arr = np.where(seg_arr == cls, 1, 0)
          cls_name = label_map[cls]
          with open(os.path.join(cls_arr_dir, '{}_{}.npy'.format(seg_file.split('.')[0], cls_name)), 'wb') as f:
            np.save(f, seg_cls_arr)

generate_cls_seg("VOC2012/segmentations", VOC_label_map, dataset="VOC2012")
# generate_cls_seg("Cityscape/segmentations", Cityscape_label_map, dataset="Cityscape")
# generate_cls_seg("Vaihingen/segmentations", Vaihingen_label_map, dataset="Vaihingen")
# generate_cls_seg("Kvasir-SEG/segmentations", Kvasir_label_map, dataset="Kvasir-SEG")


>>>  VOC2012/segmentations


In [21]:
import os
import nltk
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration, BertModel, BertTokenizer, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from torchvision import transforms
from scipy.spatial.distance import cosine
from nltk.stem import WordNetLemmatizer

device = "cuda:0"
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

vit_gpt2_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
vit_gpt2_feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

model_name = 'bert-base-uncased'
bert_tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)


preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/527 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.61k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/982M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/241 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
import os
import numpy as np
import json
from PIL import Image
VOC_label_map = {
  1:'aeroplane',
  2:'bicycle',
  3:'bird',
  4:'boat',
  5:'bottle',
  6:'bus',
  7:'car',
  8:'cat',
  9:'chair',
  10:'cow',
  11:'diningtable',
  12:'dog',
  13:'horse',
  14:'motorbike',
  15:'person',
  16:'pottedplant',
  17:'sheep',
  18:'sofa',
  19:'train',
  20:'tvmonitor'
}

Vaihingen_label_map = {
  1: 'building'
}

# Kvasir_label_map = {
#   1: 'polyp'
# }
Kvasir_label_map2 = {
  1: 'tumor'
}

Cityscape_label_map = {
  1: 'road', # flat
  2: 'person', # human
  3: 'building', # construction
  4: 'traffic light', # object
  5: 'vegetation', # nature
  6: 'car', # vehicle
  7: 'bus', # vehicle
  8: 'train', # vehicle
  9: 'motorcycle', # vehicle
  10: 'bicycle', #vehicle
}


def get_emb_by_idx(sentence, bert_tokenizer, bert_model, noun_dict):
  tokens = bert_tokenizer.tokenize(sentence)
  token_id = bert_tokenizer.convert_tokens_to_ids(tokens)
  input_ids = torch.tensor([token_id])
  with torch.no_grad():
    outputs = bert_model(input_ids)
    word_embeddings = outputs.last_hidden_state
  embs = []
  # for idx in indices:
  #   emb = word_embeddings.squeeze()[idx, :]
  #   embs.append(emb)
  for n in noun_dict.keys():
    emb_ = None
    for idx in noun_dict[n]:
      if emb_ is None:
        emb_ = word_embeddings.squeeze()[idx, :]
      else:
        emb_ += word_embeddings.squeeze()[idx, :]
    emb_ /= len(noun_dict[n])
    embs.append(emb_)
  return embs


def get_aug_cls(prompt, seg, label_map, thres=0.8):
  words = nltk.word_tokenize(prompt)
  tagged_words = nltk.tag.pos_tag(words)
  lemmatizer = WordNetLemmatizer()
  # nouns = []
  # noun_positions = []
  noun_dict = {}
  for i, (word, pos) in enumerate(tagged_words):
    if pos.startswith('N'):
      word = lemmatizer.lemmatize(word)
      if word in noun_dict.keys():
        noun_dict[word].append(i)
      else:
        noun_dict[word] = [i]

  print(noun_dict)
  seg_arr = np.asarray(seg).astype(np.uint8)
  cls_values = np.unique(seg_arr)
  embs = get_emb_by_idx(prompt, bert_tokenizer, bert_model, noun_dict)
  aug_labels = {}
  for cls in cls_values:
    if cls not in label_map.keys():
      continue
    cls_name = label_map[cls]
    print('>>> class name: {}'.format(cls_name))
    cls_aug = []
    for i, (noun, indices) in enumerate(noun_dict.items()):
      if cls_name == noun:
        continue
      prompt = prompt.replace(noun, cls_name)
      print(prompt)
      emb = get_emb_by_idx(prompt, bert_tokenizer, bert_model, {noun: indices})
      similarity = 1.0 - cosine(embs[i], emb[0])
      print('simlarity ({}, {}): {:.4f}, idx: {}'.format(cls_name, noun, similarity, indices))
      if similarity > thres:
        cls_aug.append(noun)
      prompt = prompt.replace(cls_name, noun)
    print("augemented class: ", cls_aug)
    aug_labels[cls_name] = cls_aug

  return aug_labels

# datasets = ["VOC2012", "Kvasir-SEG", "Vaihingen", "Cityscape"]
# label_maps = [VOC_label_map, Kvasir_label_map2, Vaihingen_label_map, Cityscape_label_map]
# datasets = ["Kvasir-SEG"]
# label_maps = [Kvasir_label_map2]
datasets = ["VOC2012"]
label_maps = [VOC_label_map]
thres = 0.85
for ds, label_map in zip(datasets, label_maps):
  image_dir = os.path.join(ds, "images")
  seg_dir = os.path.join(ds, "segmentations")
  ds_aug_labels = {}
  for file in os.listdir(image_dir):
    if file.startswith('.'):
      continue
    img_path = os.path.join(image_dir, file)
    seg_path = os.path.join(seg_dir, file)
    img = Image.open(img_path)
    seg = Image.open(seg_path)

    inputs = blip_processor(img, return_tensors="pt").to(device) # processor: Blip processor
    out = blip_model.generate(**inputs)
    prompt1 = blip_processor.decode(out[0], skip_special_tokens=True)
    print('** BLIP prompt: ' + prompt1)
    aug = get_aug_cls(prompt1, seg, label_map, thres=thres)
    ds_aug_labels[img_path] = aug

  with open(os.path.join(ds, "aug_label_blip_bert_{}.json".format(thres)), "w") as outfile:
    json.dump(ds_aug_labels, outfile)


** BLIP prompt: arafed scooter parked in front of a door in a stone building
{'scooter': [1], 'front': [4], 'door': [7], 'stone': [10], 'building': [11]}
>>> class name: motorbike
arafed motorbike parked in front of a door in a stone building
simlarity (motorbike, scooter): 0.8514, idx: [1]
arafed scooter parked in motorbike of a door in a stone building
simlarity (motorbike, front): 0.9439, idx: [4]
arafed scooter parked in front of a motorbike in a stone building
simlarity (motorbike, door): 0.8806, idx: [7]
arafed scooter parked in front of a door in a motorbike building
simlarity (motorbike, stone): 0.9253, idx: [10]
arafed scooter parked in front of a door in a stone motorbike
simlarity (motorbike, building): 0.7750, idx: [11]
augemented class:  ['scooter', 'front', 'door', 'stone']
** BLIP prompt: arafed cargo ship in a canal with a bridge and a train
{'ship': [2], 'canal': [5], 'bridge': [8], 'train': [11]}
>>> class name: boat
arafed cargo boat in a canal with a bridge and a tr

In [None]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!zip -r VOC2012.zip VOC2012
!zip -r Cityscape.zip Cityscape
!zip -r Vaihingen.zip Vaihingen
!zip -r Kvasir-SEG.zip Kvasir-SEG
from google.colab import files
files.download("VOC2012.zip")
files.download("Cityscape.zip")
files.download("Vaihingen.zip")
files.download("Kvasir-SEG.zip")