In [None]:
import os
import sys

# Make the `datasets` and `training` packages importable (presumably located in the parent directory).
sys.path.append('..')

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import v2 as T

from datasets.anorak import ANORAK
from training.tiler import GridPadTiler

In [None]:
# Set up the ANORAK data module; both the validation and the training loaders
# used further down come from it.
# The dataset location can be overridden via the ANORAK_ROOT environment
# variable instead of editing this machine-specific default path.
DATA_ROOT = os.environ.get(
    "ANORAK_ROOT",
    "/home/valentin/workspaces/benchmark-vfm-ss/data/ANORAK_10x",
)

data_module = ANORAK(
    root=DATA_ROOT,
    devices=1,
    num_workers=0,
    batch_size=128,
    img_size=(448, 448),
    num_classes=7,
    num_metrics=1,
)

# Prepare the data module; the dataloader cells below rely on this having run.
data_module.setup()

In [None]:
# Validation dataloader; a specific batch is pulled from it in the next cell.
dataloader = data_module.val_dataloader()

In [None]:
# Pick the `batch_idx`-th validation batch.
# The `else` clause runs only if the loop finishes without `break`, i.e. the
# loader yielded fewer batches than requested. Without it, `sample_batch`
# would silently be the *last* batch (or undefined for an empty loader).
batch_idx = 3
for i, sample_batch in enumerate(dataloader):
    if i == batch_idx:
        break
else:
    raise IndexError(f"validation dataloader has no batch with index {batch_idx}")


In [None]:
# Split the batch into the images and the per-sample target dicts (indexed below by "masks" and "labels").
images, targets = sample_batch

In [None]:
# Dimensions of the first image in the batch (torch.Size).
images[0].size()

In [None]:
# Tiler used for the window/stitch round trip below.
# NOTE(review): 448 matches img_size of the data module; 224 is presumably the
# stride (i.e. 50% overlap) — confirm against GridPadTiler's signature.
# weighted_blend=True presumably blends overlapping tiles when stitching — TODO confirm.
tiler = GridPadTiler(448, 224, weighted_blend=True)

In [None]:
# Round trip through the tiler: cut the batch into tiles, then reassemble them.
# The tiles are cast to float32 before stitching (presumably so the blend can
# hold fractional values — TODO confirm against GridPadTiler.stitch).
crops, origins, img_sizes = tiler.window(images)
crops = crops.to(dtype=torch.float32)
images_stitches = tiler.stitch(crops, origins, img_sizes)

In [None]:
# Index of a *sample within* the loaded batch, used by the inspection cells below.
# NOTE(review): this reuses (and overwrites) the `batch_idx` name that selected the
# batch earlier; a distinct name such as `sample_idx` would be clearer.
batch_idx = 11

In [None]:
# Dimensions of the reassembled image for the selected sample.
images_stitches[batch_idx].size()

In [None]:
# Dimensions of the mask stack belonging to the selected sample.
targets[batch_idx]["masks"].size()

In [None]:
# Original (un-tiled) input image of the selected sample, channels moved last for imshow.
fig, ax = plt.subplots()
ax.imshow(images[batch_idx].permute(1, 2, 0).cpu().numpy())
ax.set_title(f"Input image (sample {batch_idx})")
ax.axis("off")
plt.show()

In [None]:
# First mask of the selected sample, titled with the class label paired with it.
# Nearest interpolation keeps the mask edges crisp instead of smoothing them.
sample_target = targets[batch_idx]
fig, ax = plt.subplots()
ax.imshow(sample_target["masks"][0, ...].cpu().numpy(), interpolation="nearest")
ax.set_title(f"Sample {batch_idx}: mask 0 (label {sample_target['labels'][0].item()})")
ax.axis("off")
plt.show()

In [None]:
# Reassembled image of the selected sample, for visual comparison with the input.
# The stitched tensor is float: round and clamp before the uint8 cast.
# A bare cast truncates (254.999 -> 254), and out-of-range floats do not
# convert to uint8 reliably.
stitched_display = (
    images_stitches[batch_idx]
    .permute(1, 2, 0)
    .round()
    .clamp(0, 255)
    .to(torch.uint8)
    .cpu()
    .numpy()
)
fig, ax = plt.subplots()
ax.imshow(stitched_display)
ax.set_title(f"Stitched image (sample {batch_idx})")
ax.axis("off")
plt.show()

In [None]:
# Class labels paired with the masks of the first sample.
# NOTE(review): uses sample 0 while the cells above inspect sample `batch_idx` — confirm which was intended.
targets[0]["labels"]

In [None]:
# Every image of the validation batch on a grid, titled with its index in the batch.
n = len(images)
cols = 4  # number of images per row
rows = max(1, (n + cols - 1) // cols)  # ceil division; subplots() needs at least one row

# squeeze=False keeps `axes` a 2-D array even when the grid is a single row/column.
fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < n:
        image = images[i].permute(1, 2, 0).cpu().numpy()
        ax.imshow(image)
        ax.set_title(f"#{i}", fontsize=8)
        ax.axis("off")  # hide axis
    else:
        ax.remove()  # remove empty subplot

plt.tight_layout()
plt.show()

In [None]:
# Training dataloader from the same data module, to inspect a training batch next.
train_dataloader = data_module.train_dataloader()

In [None]:
# Pick the `batch_idx`-th training batch.
# The `else` clause runs only if the loop finishes without `break` (too few
# batches). Otherwise `sample_batch` would silently keep a stale value, e.g.
# the validation batch from above.
batch_idx = 0
for i, sample_batch in enumerate(train_dataloader):
    if i == batch_idx:
        break
else:
    raise IndexError(f"training dataloader has no batch with index {batch_idx}")

In [None]:
# Split the training batch into images and per-sample target dicts; this replaces the validation `images`/`targets` above.
images, targets = sample_batch

In [None]:
def to_per_pixel_targets_semantic(
    targets: list[dict],
    ignore_idx,
):
    per_pixel_targets = []
    for target in targets:
        per_pixel_target = torch.full(
            target["masks"].shape[-2:],
            ignore_idx,
            dtype=target["labels"].dtype,
            device=target["labels"].device,
        )

        for i, mask in enumerate(target["masks"]):
            per_pixel_target[mask] = target["labels"][i]

        per_pixel_targets.append(per_pixel_target)

    return per_pixel_targets

In [None]:
# Dense per-pixel label map for every sample in the training batch.
# NOTE(review): uncovered pixels are filled with 0. If 0 is also a real class id,
# background and that class become indistinguishable — confirm this is intended.
masks = to_per_pixel_targets_semantic(targets=targets, ignore_idx=0)

In [None]:
# Spatial size of the first dense label map.
masks[0].size()

In [None]:
# Every image of the training batch on a grid, titled with its index in the batch.
n = len(images)
cols = 4  # number of images per row
rows = max(1, (n + cols - 1) // cols)  # ceil division; subplots() needs at least one row

# squeeze=False keeps `axes` a 2-D array even when the grid is a single row/column.
fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < n:
        image = images[i].permute(1, 2, 0).cpu().numpy()
        ax.imshow(image)
        ax.set_title(f"#{i}", fontsize=8)
        ax.axis("off")  # hide axis
    else:
        ax.remove()  # remove empty subplot

plt.tight_layout()
plt.show()

In [None]:
# Dense per-pixel label maps of the training batch, one panel per sample.
n = len(masks)
cols = 4  # number of images per row
rows = max(1, (n + cols - 1) // cols)  # ceil division; subplots() needs at least one row

# By default imshow rescales colours to each panel's own min/max, so the same
# class id could get different colours in different panels. Use one shared range.
vmin = min((int(m.min()) for m in masks), default=0)
vmax = max((int(m.max()) for m in masks), default=1)

# squeeze=False keeps `axes` a 2-D array even when the grid is a single row/column.
fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < n:
        image = masks[i].cpu().numpy()
        # Nearest interpolation: label ids must not be blended into in-between values.
        ax.imshow(image, vmin=vmin, vmax=vmax, interpolation="nearest")
        ax.set_title(f"#{i}", fontsize=8)
        ax.axis("off")  # hide axis
    else:
        ax.remove()  # remove empty subplot

plt.tight_layout()
plt.show()

In [None]:
# Summary of the second sample's masks instead of dumping the full tensor:
# shape, dtype, and the number of pixels each mask covers.
sample_masks = targets[1]["masks"]
sample_masks.shape, sample_masks.dtype, sample_masks.flatten(1).sum(dim=1)