In [None]:
from dataclasses import dataclass, field
from typing import Optional
import sys
sys.path.append('/scratch/2023-fall-sp-le/langseg')

import torch
from torch import nn
from torch.nn import functional as F
torch.backends.cuda.matmul.allow_tf32 = True
from transformers import CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline

import cv2
import numpy as np
from PIL import Image
from  matplotlib import pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
from methods.prompt_engineering import extract_class_embeddings, extract_clip_text_embeddings
BICUBIC = InterpolationMode.BICUBIC

from datasets.coco_stuff import coco_stuff_categories
from datasets.cityscapes import cat_to_label_id as cityscapes_cats
from datasets.utils import get_dataset
from utils.metrics import RunningScore
from utils.plotting import *

from methods.diffusion_patch import *
from methods.diffusion import *
from methods.diffusion_seg import *
from methods.diffusion_utils import *
from methods.text_embeddings import *
from methods.diffseg import get_semantics, get_pred_mask

%load_ext autoreload
%autoreload 2

In [None]:
# @dataclass
# class SegmentationConfig:
#     dir_dataset: str = field(
#         default="/sinergia/ozaydin/segment/STEGO-master/data/cocostuff", metadata={"help": "dir dataset"}
#     )
#     dataset_name: str = field(
#         default="voc2012", metadata={"help": "for get_dataset"}
#     )
#     split: str = field(
#         default="train", metadata={"help": "which split to use"}
#     )
#     resolution: int = field(
#         default=512, metadata={"help": "resolution of the images, e.g, 512, 768, 1024"}
#     )
#     mask_res: int = field(
#         default=320, metadata={"help": "resolution of the masks, e.g, 64, 320, 512"}
#     )
#     dense_clip_arch: str = field(
#         default="RN50x16", metadata={"help": "not used in cocostuff"}
#     )

# args = SegmentationConfig()

@dataclass
class SegmentationConfig:
    """Settings for the dataset used in this notebook.

    Each field carries a short ``help`` string in its dataclass metadata.
    The values are forwarded to ``get_dataset`` further down.
    """

    # Root directory of the dataset on disk.
    dir_dataset: str = field(
        metadata=dict(help="dir dataset"),
        default="/scratch/2023-fall-sp-le/data/VOCdevkit/VOC2010",
    )
    # Dataset identifier understood by ``get_dataset``.
    dataset_name: str = field(
        metadata=dict(help="for get_dataset"),
        default="pascal_context",
    )
    # Which split of the dataset to load.
    split: str = field(
        metadata=dict(help="which split to use"),
        default="val",
    )
    # Image resolution.
    resolution: int = field(
        metadata=dict(help="resolution of the images, e.g, 512, 768, 1024"),
        default=512,
    )
    # Ground-truth mask resolution.
    mask_res: int = field(
        metadata=dict(help="resolution of the masks, e.g, 64, 320, 512"),
        default=320,
    )
    # Dense-CLIP backbone name (see the help text for its applicability).
    dense_clip_arch: str = field(
        metadata=dict(help="not used in cocostuff"),
        default="RN50x16",
    )


# Single configuration instance shared by the cells below.
args = SegmentationConfig()

In [None]:
# Load the configured split; `get_dataset` also returns the category list and
# the colour palette used for rendering.
dataset, categories, palette = get_dataset(
    dir_dataset=args.dir_dataset,
    dataset_name=args.dataset_name,
    split=args.split,
    resolution=args.resolution,
    mask_res=args.mask_res,
    dense_clip_arch=args.dense_clip_arch
)

# Build both directions of the label-id <-> category-name mapping.
if args.dataset_name == "coco_stuff":
    # Sequence of names; the position in the sequence is used as the label id.
    label_id_to_cat = coco_stuff_categories
    cat_to_label_id = {v: i for i, v in enumerate(label_id_to_cat)}
elif args.dataset_name == "cityscapes":
    # Cityscapes ships a name -> id mapping, so invert it instead.
    cat_to_label_id = cityscapes_cats
    label_id_to_cat = {i: c for c, i in cat_to_label_id.items()}
elif args.dataset_name in ("voc2012", "pascal_context"):
    # Use the category list returned by `get_dataset`; position = label id.
    label_id_to_cat = categories
    cat_to_label_id = {v: i for i, v in enumerate(label_id_to_cat)}
else:
    # Fail fast: without this branch the mappings would stay undefined (or,
    # in a reused kernel, silently keep the values of a previous dataset).
    raise ValueError(
        f"Unsupported dataset_name {args.dataset_name!r}; expected one of: "
        "coco_stuff, cityscapes, voc2012, pascal_context"
    )

In [None]:
idxx = 4  # index of the dataset sample inspected in the following cells

# Fetch the sample once so the image and its ground truth come from the same
# `__getitem__` call and the item is not loaded repeatedly.
sample = dataset[idxx]
val_img = sample["img"].cpu().numpy()
val_gt = sample["gt"].cpu().numpy()

# `np.unique` returns the distinct label ids in ascending order.
present_ids = np.unique(val_gt)

# NOTE(review): the smallest label id is dropped unconditionally for the
# legend - presumably it is the ignore label (-1) or background; confirm.
lab_ids = present_ids[1:]
val_pil_img = render_results(val_img, val_gt, palette)
_ = get_legends(lab_ids, palette, label_id_to_cat, is_voc2012 = args.dataset_name == "voc2012")

# From here on `val_gt` is the ground-truth tensor with a leading batch axis.
val_gt = sample["gt"].unsqueeze(0)
val_labels = [label_id_to_cat[c] for c in present_ids if c != -1]  # don't process label -1 (ignored unlabelled pixels)
val_labels = [l for l in val_labels if l != "background"]  # don't feed "background" as text input

In [None]:
import torch
from methods import gem
import requests
from PIL import Image

# Checkpoint selection. Alternatives kept for reference:
# model_name = 'ViT-B/16'  # 'ViT-B-16-quickgelu'
# pretrained = 'openai'  # 'metaclip_400m'
# model_name = 'ViT-L/14-quickgelu'
model_name, pretrained = 'ViT-B/16-quickgelu', 'metaclip_400m'

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Instantiate the GEM model for the selected checkpoint on `device`.
gem_model = gem.create_gem_model(
    model_name=model_name,
    pretrained=pretrained,
    device=device,
)

In [None]:
# Build the image transform and prepare the inspected sample for the model.
preprocess = gem.get_gem_img_transform(img_size=(448, 448))

# Open the file inside a context manager so its handle is released as soon as
# the pixels are read, and force 3-channel RGB so grayscale / palette / RGBA
# files go through the transform the same way as ordinary RGB images.
with Image.open(dataset[idxx]["p_img"]) as pil_img:
    image = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(device)

# Every category name is used as a text prompt (a `[1:]` slice would skip the
# first entry).
class_names = label_id_to_cat#[1:]

# Inference only: no gradients are needed.
with torch.no_grad():
    logits = gem_model(image, class_names, normalize=False, return_ori=False)  # [1, num_class, W, H]

In [None]:
# Rescale the raw logits with the model's own `min_max` helper before plotting
# (presumably a min-max normalization of each map - TODO confirm in methods/gem).
normed_logits = gem_model.min_max(logits)
# Optional: pass the image, the prompts and the rescaled maps to GEM's
# visualization helper; nothing below depends on this call.
gem.visualize(image, class_names, normed_logits)  # (optional visualization)

In [None]:
def run_gem_voc(output_size=320, img_size=(448, 448)):
    """Score GEM argmax predictions against the ground truth of every sample.

    For each item of the notebook-global ``dataset`` the image file at
    ``sample["p_img"]`` is preprocessed, ``gem_model`` produces logits for all
    names in ``label_id_to_cat``, the per-pixel argmax over the class axis is
    taken as the prediction, and it is accumulated into a ``RunningScore``
    together with ``sample["gt"]``. The progress bar shows the running mIoU.

    Reads the globals ``dataset``, ``label_id_to_cat``, ``gem_model``,
    ``gem`` and ``device``.

    Args:
        output_size: Forwarded to ``gem_model`` as ``output_size``. It has to
            produce predictions with the same spatial size as the
            ground-truth masks (presumably ``args.mask_res`` - TODO confirm).
        img_size: Forwarded to ``gem.get_gem_img_transform`` as ``img_size``.

    Returns:
        The ``RunningScore`` instance accumulated over the whole dataset.

    Notes:
        Post-processing variants that were explored here earlier and are no
        longer part of the function:

        * an extra background channel derived from the thresholded maximum
          softmax probability (remark at the time: without softmax the
          person class was very bad);
        * marking pixels as background where the softmax perplexity
          exceeded a fixed threshold;
        * per-mask voting through ``get_semantics`` / ``get_pred_mask``;
        * resetting predicted classes that do not occur in the ground-truth
          labels of the image (false-positive elimination).
    """
    from tqdm import tqdm

    # Loop-invariant setup: transform, text prompts and the score accumulator.
    preprocess = gem.get_gem_img_transform(img_size=img_size)
    class_names = label_id_to_cat#[1:]
    running_score = RunningScore(len(label_id_to_cat))

    pbar = tqdm(range(len(dataset)))
    for idx in pbar:
        # Index the dataset once per iteration; every access goes through
        # `__getitem__` again, so repeated indexing repeats the loading work.
        sample = dataset[idx]
        val_gt = sample["gt"].unsqueeze(0)  # add a leading batch axis

        # Close the file handle promptly and force RGB so non-RGB files are
        # handled like ordinary RGB images by the transform.
        with Image.open(sample["p_img"]) as pil_img:
            image = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = gem_model(image, class_names, output_size=output_size, normalize=False)  # [1, num_class, W, H]

        # Hard prediction: highest-scoring class per pixel.
        pred = logits.argmax(dim=1)

        running_score.update(val_gt.cpu().numpy(), pred.cpu().numpy())

        # Recomputed every step only to show the running mIoU on the bar.
        metrics, cls_iou = running_score.get_scores()
        miou = metrics["Mean IoU"]
        pbar.set_description(
            f"mIoU {miou:.3f}"
        )
    return running_score

# Run the evaluation over the whole loaded split; the returned accumulator is
# read again when the final metrics are printed.
running_score = run_gem_voc()

In [None]:
# Final scores of the evaluation run (original results).
metrics, cls_iou = running_score.get_scores()

# Overall metrics, shown as percentages with one decimal place.
overall_pct = {name: f"{value*100:.1f}" for name, value in metrics.items()}
print(overall_pct)

# Per-class IoU, keyed by category name instead of label id.
per_class_pct = {label_id_to_cat[cls_id]: f"{iou*100:.1f}" for cls_id, iou in cls_iou.items()}
print(per_class_pct)