In [None]:
import os

# Remember the directory the kernel started in the first time this cell runs.
# A bare os.chdir("../../") is relative to the *current* directory, so
# re-running the cell would climb two more levels on every execution.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()

# Move two levels up from the notebook's start directory (presumably the
# repository root, since later cells use paths such as "config/..." and
# "/src/..." relative to the working directory -- confirm).
os.chdir(os.path.join(NOTEBOOK_START_DIR, "..", ".."))
os.getcwd()

In [None]:
import sys
import argparse
import os

# Emulate a command-line invocation inside the notebook. sys.argv is replaced
# globally (not just passed to the parser) so that any module imported later
# which inspects the command line sees the same "--config" option.
sys.argv = ["view", "--config", "config/single_task_object_detection.yaml"]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="Path to the config file",
)
# Explicitly parse everything after the program name (argparse's default).
args = parser.parse_args(sys.argv[1:])

print(args.config)

In [None]:
from dataloader import VOC08Attr
from torchvision.transforms import transforms
from config_experiments import config
import torch
from torch.utils.data import DataLoader

import numpy as np

In [None]:
# Preprocessing pipeline for training images:
#   1. Resize with size=600 and max_size=1000 (per the torchvision docs the
#      shorter edge is matched to `size` and the longer edge is capped at
#      `max_size`).
#   2. Convert the image to a tensor.
#   3. Normalize each channel with the mean/std taken from the config.
normalization_stats = config["transform"]
transform_train = transforms.Compose(
    [
        transforms.Resize(size=600, max_size=1000),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=normalization_stats["mean"],
            std=normalization_stats["std"],
        ),
    ]
)

In [None]:
# Training split of the dataset, with the preprocessing pipeline attached.
train_data = VOC08Attr(train=True, transform=transform_train)

# Number of images per batch comes from the config. The dataset provides its
# own collate function, and shuffling is disabled so batches come out in a
# fixed order.
images_per_batch = config["preprocessing"]["n_images"]
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=images_per_batch,
    shuffle=False,
    collate_fn=train_data.collate_fn,
)

In [None]:
import torch
import torchvision.transforms.functional as F
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt
import numpy as np


def show(imgs):
    """Display one image tensor, or a list of them, side by side in a single row.

    Args:
        imgs: A single image or a list of images. Each image must support
            ``.detach()`` and be accepted by
            ``torchvision.transforms.functional.to_pil_image``.
    """
    plt.rcParams["savefig.bbox"] = "tight"
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array two-dimensional even for one image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for axis, tensor_img in zip(axes[0], images):
        pil_img = F.to_pil_image(tensor_img.detach())
        # imshow is given the PIL image converted to a NumPy array.
        axis.imshow(np.asarray(pil_img))
        # Uncomment to hide the axis ticks and labels:
        # axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


def show_bbox_with_transform(image, box, mean, std, labels=None, color="white"):
    """Undo per-channel normalization on an image and display it with boxes drawn.

    Args:
        image: Normalized image, indexable by channel along its first axis
            (the first three channels are de-normalized).
        box: Bounding boxes forwarded unchanged to ``draw_bounding_boxes``.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations the image was divided by.
        labels: Optional box labels forwarded to ``draw_bounding_boxes``.
        color: Box color(s) forwarded to ``draw_bounding_boxes``.
    """
    # np.array copies by default, so the caller's image is left untouched.
    pixels = np.array(image)
    # Invert the normalization channel by channel: x = x_norm * std + mean.
    for c in (0, 1, 2):
        pixels[c] = pixels[c] * std[c] + mean[c]
    # Clamp to [0, 1] and rescale to 8-bit values before drawing.
    pixels = (np.clip(pixels, 0, 1) * 255).astype(np.uint8)
    annotated = draw_bounding_boxes(
        torch.from_numpy(pixels), box, colors=color, labels=labels, width=2
    )
    show(annotated)

In [None]:
# Visual sanity check: draw the ROIs of (up to) the first two images of the
# first batch. The de-normalization statistics are read from the same config
# entries used by the Normalize transform, so the two cannot drift apart.
for batch_images, rois, classes, offsets, attrs, indices_batch in train_dataloader:
    # indices_batch is compared against the image position below, i.e. it is
    # used as the per-ROI index of the image inside the batch.
    roi_image_index = indices_batch.squeeze(-1)

    for image_idx, box_color in enumerate(("white", "black")):
        # Guard against batches that contain a single image.
        if image_idx >= len(batch_images):
            break
        image_rois = rois[roi_image_index == image_idx]
        print(image_rois)
        show_bbox_with_transform(
            batch_images[image_idx],
            image_rois,
            config["transform"]["mean"],
            config["transform"]["std"],
            labels=None,
            color=box_color,
        )

    # Only the first batch is inspected.
    break

In [None]:
def get_normalize_values_target_class(data_loader, num_classes=None):
    """Compute the per-class mean and standard deviation of the target offsets.

    Every batch yielded by ``data_loader`` must be a 6-tuple whose third and
    fourth elements are the per-sample class labels and the per-sample offset
    tensors. Samples whose label lies outside ``1..num_classes`` (e.g. label
    0) are ignored.

    Args:
        data_loader: Iterable of batches as described above.
        num_classes: Number of classes to collect statistics for. When
            ``None`` the value is read from ``config["global"]["num_classes"]``.

    Returns:
        dict mapping class id -> ``{"mean": [...], "std": [...]}`` where both
        lists are computed element-wise over all offsets of that class.
        Classes without any sample are left out of the result (a warning is
        emitted) instead of crashing in ``torch.stack``.

    Note:
        ``torch.std`` uses the unbiased estimator by default, so a class with
        exactly one sample gets NaN standard deviations.
    """
    import warnings

    if num_classes is None:
        num_classes = config["global"]["num_classes"]

    offsets_by_class = {class_id: [] for class_id in range(1, num_classes + 1)}
    for _, _, train_cls, train_offset, _, _ in data_loader:
        for cls, offset in zip(train_cls, train_offset):
            label = cls.item()
            # Keys start at 1, so label 0 is skipped by the lookup.
            if label in offsets_by_class:
                offsets_by_class[label].append(offset)

    mean_std_by_class = {}
    classes_without_samples = []
    for class_id, class_offsets in offsets_by_class.items():
        if not class_offsets:
            # torch.stack raises on an empty list; record and skip the class.
            classes_without_samples.append(class_id)
            continue
        stacked = torch.stack(class_offsets)
        mean_std_by_class[class_id] = {
            "mean": torch.mean(stacked, dim=0).tolist(),
            "std": torch.std(stacked, dim=0).tolist(),
        }

    if classes_without_samples:
        warnings.warn(
            "No samples found for classes "
            f"{classes_without_samples}; they are omitted from the statistics."
        )

    return mean_std_by_class

In [None]:
# Iterates over every batch of the training loader to collect the per-class
# offset statistics.
mean_std_by_class = get_normalize_values_target_class(data_loader=train_dataloader)

In [None]:
import json

# Persist the statistics relative to the current working directory.
# NOTE: the file has a ".yaml" extension but its content is written with
# json.dump. JSON object keys are always strings, so the integer class ids
# are stored as strings ("1", "2", ...).
stats_file_path = (
    os.getcwd()
    + "/src/single_task_object_detection/"
    + "target_mean_std_by_class.yaml"
)
with open(stats_file_path, "w") as stats_file:
    json.dump(mean_std_by_class, stats_file)

In [None]:
import yaml

# Read the statistics file back with a YAML parser to check that it loads.
# The file is produced by json.dump in the previous cell, so the class ids
# come back as string keys rather than integers.
stats_file_path = (
    os.getcwd()
    + "/src/single_task_object_detection/"
    + "target_mean_std_by_class.yaml"
)
with open(stats_file_path, "r") as stats_file:
    mean_std_by_class = yaml.safe_load(stats_file)

print(mean_std_by_class)