In [1]:
from dataclasses import dataclass, field
from PIL import Image
import os, math, warnings
import sys
sys.path.append('/scratch/2023-fall-sp-le/langseg')

import torch
torch.backends.cuda.matmul.allow_tf32 = True
from diffusers import StableDiffusionPipeline
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from datasets.coco_stuff import coco_stuff_categories
from datasets.cityscapes import cat_to_label_id as cityscapes_cats
from datasets.utils import get_dataset
from utils.metrics import RunningScore
from methods.diffusion_patch import *
from methods.diffusion import *
from methods.diffusion_utils import *
from methods.text_embeddings import *
from methods.diffusion_seg import *
from methods.grabcut import *
from methods.pipeline_patch import patch_sd_call, patch_sdxl_call, patch_sd_prepare_latents
from methods.multilabel_classifiers import CLIPMultilabelClassifier, BLIPMultilabelClassifier
#pip install --upgrade diffusers transformers nltk accelerate torch_kmeans igraph peft compel torchvision ftfy open_clip_torch einops
%load_ext autoreload
%autoreload 2

In [2]:
def _option(default, help_text):
    """Return a dataclass field with the given default and a ``help`` metadata entry."""
    return field(default=default, metadata={"help": help_text})


@dataclass
class SegmentationConfig:
    """Settings for the segmentation evaluation run.

    Every attribute is forwarded as a keyword argument to ``get_dataset``.
    The human-readable description of each option is stored under the
    ``"help"`` key of the field metadata.
    """

    # Root directory of the dataset on disk.
    dir_dataset: str = _option("/sinergia/ozaydin/segment/STEGO-master/data/cityscapes", "dir dataset")
    # Dataset identifier understood by get_dataset.
    dataset_name: str = _option("cityscapes", "for get_dataset")
    # Dataset split to evaluate on.
    split: str = _option("val", "which split to use")
    # Side length of the input images.
    resolution: int = _option(512, "resolution of the images, e.g, 512, 768, 1024")
    # Side length of the ground-truth / predicted masks.
    mask_res: int = _option(320, "resolution of the masks, e.g, 64, 320, 512")
    # Architecture name passed through to get_dataset.
    dense_clip_arch: str = _option("RN50x16", "not used in cocostuff")


# Single configuration instance shared by the rest of the notebook.
args = SegmentationConfig()

In [3]:
# Collect the get_dataset keyword arguments from the config by name so the
# call site stays in sync with the option names.
_dataset_kwargs = {
    option: getattr(args, option)
    for option in (
        "dir_dataset",
        "dataset_name",
        "split",
        "resolution",
        "mask_res",
        "dense_clip_arch",
    )
}
dataset, categories, palette = get_dataset(**_dataset_kwargs)

# Two-way lookup between integer label ids and category names:
# `label_id_to_cat` is indexed by label id, `cat_to_label_id` by name.
label_id_to_cat = categories
cat_to_label_id = {name: label_id for label_id, name in enumerate(label_id_to_cat)}

In [4]:
# Checkpoint used for this run. Other candidates (currently disabled):
#   "stabilityai/stable-diffusion-2-1"
#   "CompVis/stable-diffusion-v1-4"
#   "stabilityai/stable-diffusion-xl-base-1.0"
model_id = "runwayml/stable-diffusion-v1-5"

# Load the fp16 safetensors weights and move the whole pipeline to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    variant="fp16",
    torch_dtype=torch.float16,
    use_safetensors=True,
    # device_map="auto"
).to("cuda")

# Disabled experiments, kept for reference:
#   unet_id = "mhdang/dpo-sd1.5-text2image-v1"
#   unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
#   pipe.unet = unet
#   load_model_weights(pipe, './TexForce/lora_weights/sd15_refl/', 'unet+lora')
#   load_model_weights(pipe, './TexForce/lora_weights/sd15_texforce/', 'text+lora')
#   pipe.forward = patch_sd_call(pipe)
#   pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.4, b2=1.6)

# Project helper applied to the loaded pipeline before it is used for segmentation.
configure_ldm(pipe)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [5]:
# Attention layers whose maps are recorded during the UNet forward pass.
attention_layers_to_use = ATTENTION_LAYERS

# Store that accumulates the attention maps emitted by the hooked layers.
attention_store = AttentionStore(
    layer_keys=attention_layers_to_use,
    low_resource=False,
    no_uncond=True,
)
attention_store.num_att_layers = len(attention_layers_to_use)

# Start from an empty handle registry and keep whatever the helper returns.
# NOTE(review): re-running this cell registers hooks again without removing
# the previous ones — confirm whether the old handles must be released first.
handles = register_attention_hooks(
    pipe.unet,
    attention_store,
    attention_layers_to_use,
    {},
)

In [6]:
def run_exp():
    """Evaluate three attention-based segmentation variants on every sample of `dataset`.

    For each sample, the category names that occur in its ground-truth mask are
    used as the text prompts (label -1 and the name "background" are skipped).
    The attention maps collected by `attention_store` during `training_step`
    are then turned into three predictions:

    * ``pred0`` — `get_random_walk_mask` on the aggregated map,
    * ``pred1`` — `get_specclust_mask` on the aggregated map and self-attention,
    * ``pred2`` — `diffseg` clusters named via `label_clusters`.

    Each prediction is accumulated in its own `RunningScore`, and the running
    mean IoU of all three is shown in the progress bar.

    Relies on the notebook globals `dataset`, `pipe`, `attention_store`,
    `args`, `label_id_to_cat` and `cat_to_label_id`.

    Returns:
        tuple: ``(running_score_0, running_score_1, running_score_2)`` for
        ``pred0``, ``pred1`` and ``pred2`` respectively.
    """
    # Size the confusion matrices from the dataset that is actually loaded,
    # not from an unrelated category list.
    num_classes = len(label_id_to_cat)
    running_score_0 = RunningScore(num_classes)
    running_score_1 = RunningScore(num_classes)
    running_score_2 = RunningScore(num_classes)

    pbar = tqdm(range(len(dataset)))
    for idx in pbar:
        # Fetch the sample once so image and mask come from the same access.
        sample = dataset[idx]
        val_img = sample["img"].permute(1,2,0)[None,...].numpy()
        val_gt = sample["gt"].unsqueeze(0)
        gt_np = val_gt.cpu().numpy()

        val_labels = sorted(np.unique(gt_np))
        val_labels = [l for l in val_labels if l != -1] # don't process label -1 (ignored unlabelled pixels)
        val_labels = [label_id_to_cat[c] for c in val_labels]
        val_labels = [l for l in val_labels if l != "background"] # don't feed "background" as text input

        # Disabled alternative: predict `val_labels` with a multilabel
        # classifier (clip_classifier + get_pred_label_names) instead of
        # reading them from the ground truth.

        text_embeds, concept_ind, concept_indices, _ = get_txt_embeddings(
            pipe.tokenizer, pipe.text_encoder, val_labels, label_id_to_cat, cat_to_label_id,
            use_compel=False
        )

        # One copy of the image per text embedding.
        val_img = val_img.repeat(len(text_embeds), axis=0)
        attention_store.reset()

        training_step(pipe, text_embeds, val_img, attention_store, no_uncond=True, normalize=False, low_resource=False)

        ca, sa = get_attention_maps(
            attention_store.get_average_attention(),
            batch_size=1,
            label_indices=concept_indices,
            output_size=64,
            average_layers=True,
            apply_softmax=True,
            softmax_dim=-1,
            simple_average=False
        )
        agg_map = get_agg_map(ca, sa, walk_len=1, beta=1, minmax_norm=False)

        pred0 = get_random_walk_mask(
            agg_map, cat_to_label_id,
            concept_ind, val_labels, args.mask_res
        ).long()

        pred1 = get_specclust_mask(
            agg_map, sa, cat_to_label_id,
            concept_ind, val_labels,
            output_size=args.mask_res
        ).long()

        clusters = diffseg(sa, out_res=args.mask_res, refine=True, kl_thresh=0.8)
        # Upscale to the same resolution as the clusters (was hard-coded to 320).
        agg_map_up = upscale_attn(agg_map, args.mask_res, is_cross=True)[..., concept_ind]
        cluster_labels = label_clusters(clusters, agg_map_up, reshape_logits=False)
        # Values of `cluster_labels` index into `val_labels`; translate them to
        # dataset label ids. Masks are taken from the untouched `cluster_labels`
        # so a pixel that was already rewritten cannot be matched again when its
        # new label id happens to equal a later cluster index.
        pred2 = cluster_labels.clone()
        for c in cluster_labels.unique().tolist():
            pred2[cluster_labels == c] = cat_to_label_id[val_labels[c]]

        running_score_0.update(gt_np, pred0.cpu().numpy())
        metrics_0, cls_iou_0 = running_score_0.get_scores()
        miou_0 = metrics_0["Mean IoU"]

        running_score_1.update(gt_np, pred1.cpu().numpy())
        metrics_1, cls_iou_1 = running_score_1.get_scores()
        miou_1 = metrics_1["Mean IoU"]

        running_score_2.update(gt_np, pred2.cpu().numpy())
        metrics_2, cls_iou_2 = running_score_2.get_scores()
        miou_2 = metrics_2["Mean IoU"]

        # Disabled alternative for pred2: DiffusionGraphCut(agg_map.cpu(), sa.cpu(),
        # concept_ind, val_labels, cat_to_label_id)(args.mask_res)[None]

        pbar.set_description(
            f"mIoU_0 {miou_0:.3f} | "
            f"mIoU_1 {miou_1:.3f} | "
            f"mIoU_2 {miou_2:.3f}"
        )
    return running_score_0, running_score_1, running_score_2

# Silence NumPy divide/invalid warnings while the experiment runs.
with np.errstate(divide='ignore', invalid='ignore'):
    rs = run_exp()

  acc_cls = np.diag(hist) / hist.sum(axis=1)
  iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
mIoU_0 0.112 | : 100%|██████████| 500/500 [03:04<00:00,  2.72it/s]


In [7]:
# Report each evaluated variant: overall metrics first, then per-class IoU
# keyed by category name. All values are shown as percentages.
for running_score in rs:
    metrics, cls_iou = running_score.get_scores() # original results
    overall = {name: f"{value*100:.1f}" for name, value in metrics.items()}
    per_class = {label_id_to_cat[label_id]: f"{iou*100:.1f}" for label_id, iou in cls_iou.items()}
    print(overall)
    print(per_class)
    print("#" * 100)

{'Pixel Acc': '31.7', 'Mean Acc': '41.9', 'FreqW Acc': '24.8', 'Mean IoU': '11.2'}
{'road': '9.1', 'sidewalk': '10.6', 'parking lot': '2.3', 'rail track': '15.3', 'building': '35.0', 'wall': '6.2', 'fence': '11.6', 'guard rail': '4.8', 'bridge': '7.7', 'tunnel': 'nan', 'pole': '5.0', 'polegroup': '0.0', 'traffic light': '0.6', 'traffic sign': '1.5', 'vegetation': '55.4', 'terrain': '0.6', 'sky': '19.3', 'person': '5.1', 'rider': '0.2', 'car': '39.7', 'truck': '17.6', 'bus': '15.3', 'caravan': '2.2', 'trailer': '8.4', 'train': '11.6', 'motorcycle': '0.7', 'bicycle': '6.6'}


Validation results (Cityscapes `val` split; all values in %)

{'Pixel Acc': '31.7', 'Mean Acc': '41.9', 'FreqW Acc': '24.8', 'Mean IoU': '11.2'}

{'road': '9.1', 'sidewalk': '10.6', 'parking lot': '2.3', 'rail track': '15.3', 'building': '35.0', 'wall': '6.2', 'fence': '11.6', 'guard rail': '4.8', 'bridge': '7.7', 'tunnel': 'nan', 'pole': '5.0', 'polegroup': '0.0', 'traffic light': '0.6', 'traffic sign': '1.5', 'vegetation': '55.4', 'terrain': '0.6', 'sky': '19.3', 'person': '5.1', 'rider': '0.2', 'car': '39.7', 'truck': '17.6', 'bus': '15.3', 'caravan': '2.2', 'trailer': '8.4', 'train': '11.6', 'motorcycle': '0.7', 'bicycle': '6.6'}