In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import timm
import torch
import argparse
import matplotlib.pyplot as plt
import numpy as np
import math
import torch.nn.functional as F
from PIL import Image
from mpl_toolkits.axes_grid1 import make_axes_locatable
from core.configs import cfg
from core.train_learners import Test
from core.datasets.build import transform

In [None]:
def parse_args(args_str: str = None):
    """Parse a command-line style string and merge the result into the global ``cfg``.

    Args:
        args_str: whitespace-separated arguments, e.g. ``"-cfg path.yaml KEY VALUE"``.
            When empty or ``None``, argparse falls back to ``sys.argv`` (note: inside a
            notebook kernel that is the kernel's own argv).

    Returns:
        The parsed ``argparse.Namespace`` (``config_file``, ``proctitle``, ``opts``).

    Side effects:
        Allows new keys on ``cfg``, merges the YAML file and any ``KEY VALUE`` overrides
        from ``opts`` into it, then freezes ``cfg``.
    """
    parser = argparse.ArgumentParser(
        description="Active Domain Adaptive Semantic Segmentation Training")
    parser.add_argument("-cfg", "--config-file", default="", metavar="FILE",
                        help="path to config file", type=str)
    parser.add_argument("--proctitle", type=str, default="HALO",
                        help="allow a process to change its title")
    parser.add_argument("opts", help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)

    argv = None if not args_str else args_str.split()
    parsed = parser.parse_args(argv)

    # Strip a stray line ending that can ride along on the last override token.
    if parsed.opts:
        parsed.opts[-1] = parsed.opts[-1].strip('\r\n')

    cfg.set_new_allowed(True)
    cfg.merge_from_file(parsed.config_file)
    cfg.merge_from_list(parsed.opts)
    cfg.freeze()

    return parsed

In [None]:
# Merge configs/gtav/test.yaml into the global `cfg` (parse_args also freezes it)
# and build the test-time learner from that config.
args = parse_args("-cfg configs/gtav/test.yaml")
learner = Test(cfg)

In [None]:
# Test-time preprocessing. INPUT_SIZE_TEST is unpacked as (w, h) and handed to
# Resize as (h, w); labels are not resized (resize_label=False). `h` and `w` are
# reused when loading the image to build a placeholder label of matching size.
w, h = cfg.INPUT.INPUT_SIZE_TEST
trans = transform.Compose([
    transform.Resize((h, w), resize_label=False),
    transform.ToTensor(),
    transform.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255)
])

In [None]:
img_path = "datasets/cityscapes/leftImg8bit/train/aachen/aachen_000133_000019_leftImg8bit.png"

# Open inside a context manager so the file handle is released once the RGB
# copy exists (convert() returns a new, fully loaded image).
with Image.open(img_path) as img_file:
    img_pil = img_file.convert('RGB')

plt.imshow(img_pil)
plt.show()

# The project transforms take (image, label). The label is only a placeholder
# here and its transformed value is discarded ([0] keeps the image), so use a
# deterministic zero tensor rather than random noise.
img = trans(img_pil, torch.zeros(1, h, w))[0]
print(img.shape)

In [None]:
# Inference only: disable autograd so no computation graph is retained for the
# forward pass; later cells only read (and detach) the outputs.
with torch.no_grad():
    out = learner.forward(img.unsqueeze(0))

In [None]:
# learner.forward returns a pair: the segmentation output and the decoder
# features. `size` is the last two dims of `output`; the scoring cell upsamples
# the decoder features to it.
output, decoder_out = out
print("output.shape: ", output.shape)
print("decoder_out.shape: ", decoder_out.shape)
size = output.shape[-2:]
print("size: ", size)

In [None]:
from core.active.build import select_pixels_to_label
from core.active.floating_region import FloatingRegionScore

# Active-selection parameters taken from the config.
per_region_pixels = (2 * cfg.ACTIVE.RADIUS_K + 1) ** 2
active_radius = cfg.ACTIVE.RADIUS_K
mask_radius = cfg.ACTIVE.MASK_RADIUS_K
active_ratio = cfg.ACTIVE.RATIO / len(cfg.ACTIVE.SELECT_ITER)
uncertainty_type = cfg.ACTIVE.UNCERTAINTY
purity_type = cfg.ACTIVE.PURITY
K = cfg.ACTIVE.K
num_pixel_cur = size[0] * size[1]

# Upsample decoder features to the output resolution under a new name, so that
# `decoder_out` still matches the shape printed earlier and re-running this cell
# does not resample an already-resampled tensor.
decoder_out_resized = F.interpolate(decoder_out, size=size, mode='bilinear', align_corners=True)

# NOTE(review): in_channels=19 is hard-coded; presumably the number of classes in
# `output` -- confirm against the model/config before reusing on other datasets.
floating_region_score = FloatingRegionScore(
    in_channels=19, size=2 * active_radius + 1, purity_type=purity_type, K=K)

# Scoring is pure inference; no gradients are needed.
with torch.no_grad():
    score, _, _ = floating_region_score(
        output, decoder_out=decoder_out_resized, normalize=cfg.ACTIVE.NORMALIZE,
        unc_type=uncertainty_type, pur_type=purity_type)

# Budget-driven selection as used in training. Left disabled: it needs `active`,
# `selected` and `ground_truth`, which this notebook does not construct.
# active_regions = math.ceil(num_pixel_cur * active_ratio / per_region_pixels)
# score, active, selected, active_mask = select_pixels_to_label(
#     score, active_regions, active_radius, mask_radius,
#     active, selected, active_mask, ground_truth
#     )


In [None]:
# Shape of the score map returned by FloatingRegionScore.
score.shape

In [None]:
# Mark the top 5% highest-scoring pixels as selected.
top_ratio = 0.05
num_elements = score.numel()
num_true_elements = int(num_elements * top_ratio)

flattened = score.detach().flatten()
active_mask = torch.zeros_like(score, dtype=torch.bool)

if num_true_elements > 0:
    # The k-th largest value is the cut-off. (Indexing a descending sort at k, as
    # before, picks the (k+1)-th value and admits one pixel too many.) topk also
    # avoids sorting the whole tensor. Ties at the cut-off are kept by `>=`.
    threshold = torch.topk(flattened, num_true_elements).values[-1]
    active_mask[score >= threshold] = True

In [None]:
def select_pixels_to_label(score, active_regions, active_radius, mask_radius,
                           active, selected, active_mask, ground_truth):
    """Greedily select up to ``active_regions`` square windows around score peaks.

    Each step takes the current global maximum of the 2-D ``score`` map. It then:
      * sets a (2*mask_radius+1)^2 window around the peak to -inf in ``score`` and
        True in ``active``, so the area is not picked again;
      * sets a (2*active_radius+1)^2 window to True in ``selected`` and copies the
        matching ``ground_truth`` values into ``active_mask``.
    Windows are clamped at 0 on the low side; slicing clamps the high side.
    Stops early once every score is -inf.

    All four tensors are modified in place and also returned. This definition
    shadows the ``select_pixels_to_label`` imported in the scoring cell.

    Returns:
        (score, active, selected, active_mask)
    """
    neg_inf = -float('inf')
    for _ in range(active_regions):
        col_max, row_of_col_max = torch.max(score, dim=0)
        peak, col = torch.max(col_max, dim=0)
        if peak == neg_inf:
            break
        c = col.item()
        r = row_of_col_max[c].item()

        label_rows = slice(max(r - active_radius, 0), r + active_radius + 1)
        label_cols = slice(max(c - active_radius, 0), c + active_radius + 1)
        block_rows = slice(max(r - mask_radius, 0), r + mask_radius + 1)
        block_cols = slice(max(c - mask_radius, 0), c + mask_radius + 1)

        score[block_rows, block_cols] = neg_inf
        active[block_rows, block_cols] = True
        selected[label_rows, label_cols] = True
        active_mask[label_rows, label_cols] = ground_truth[label_rows, label_cols]

    return score, active, selected, active_mask

In [None]:
# One greedy selection step by hand: find the highest-scoring pixel and the
# label / mask windows that select_pixels_to_label would use around it.
values, indices_h = torch.max(score, dim=0)        # max of each column and its row
max_value, indices_w = torch.max(values, dim=0)    # column holding the global max

# Peak coordinates get their own names: plain `w` / `h` hold the test image size
# from the transform cell, and overwriting them breaks re-running the image cell.
peak_w = indices_w.item()
peak_h = indices_h[peak_w].item()

# Start indices are clamped at 0. End indices may run past the map; tensor
# slicing tolerates that.
active_start_w = max(peak_w - active_radius, 0)
active_start_h = max(peak_h - active_radius, 0)
active_end_w = peak_w + active_radius + 1
active_end_h = peak_h + active_radius + 1

mask_start_w = max(peak_w - mask_radius, 0)
mask_start_h = max(peak_h - mask_radius, 0)
mask_end_w = peak_w + mask_radius + 1
mask_end_h = peak_h + mask_radius + 1

In [None]:
# Peak score value located by the single-step walkthrough cell.
max_value

In [None]:
# Undo the Normalize transform for display, using the same mean/std that the
# transform cell passed to transform.Normalize (not a separately hard-coded copy).
# NOTE(review): if cfg.INPUT.TO_BGR255 is True the channels may come back in BGR
# order -- confirm and flip with img_np[..., ::-1] if the colours look swapped.
PIXEL_MEAN = np.asarray(cfg.INPUT.PIXEL_MEAN, dtype=np.float32).reshape(1, 1, 3)
PIXEL_STD = np.asarray(cfg.INPUT.PIXEL_STD, dtype=np.float32).reshape(1, 1, 3)

# .cpu() keeps the conversions working if these tensors live on a GPU.
img_np = img.permute(1, 2, 0).detach().cpu().numpy()
# Round and clip before casting. A bare astype(np.uint8) truncates, and its
# result is undefined for values slightly outside [0, 255].
img_np = np.clip(np.rint(img_np * PIXEL_STD + PIXEL_MEAN), 0, 255).astype(np.uint8)
score_np = score.detach().cpu().numpy()
active_mask_np = active_mask.detach().cpu().numpy()

In [None]:
import io

# Three stacked panels: the image, the score overlay, and the selected-pixel overlay.
cmap_score = 'viridis'
alpha = 0.7

fig, axes = plt.subplots(3, 1, constrained_layout=True, figsize=(12, 12), dpi=300)
fig.suptitle("HALO pixel selection", fontsize=16)

for ax in axes:
    ax.axis('off')

axes[0].set_title('Original image')
axes[0].imshow(img_np)

# img_np is RGB, so a colormap would be ignored for it; only the overlays use one.
axes[1].set_title('Pixels score')
axes[1].imshow(img_np)
im_score = axes[1].imshow(score_np, cmap=cmap_score, alpha=alpha)
# A colorbar made with fig.colorbar(ax=...) is laid out by constrained_layout.
fig.colorbar(im_score, ax=axes[1], shrink=0.8)

axes[2].set_title('Selected Pixels (top 5%)')
axes[2].imshow(img_np)
axes[2].imshow(active_mask_np, cmap='autumn', alpha=0.3)

# Render to an in-memory PNG (tight crop, 300 dpi), then show that bitmap.
img_buf = io.BytesIO()
fig.savefig(img_buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
plt.close(fig)
img_buf.seek(0)

im = Image.open(img_buf).convert('RGB')
fig_out, ax_out = plt.subplots(figsize=(12, 12))
ax_out.axis('off')
ax_out.imshow(im)
plt.show()