In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

sys.path.append('..')  # make the project-local packages below importable
from datasets.anorak import ANORAK
from models.histo_linear_decoder import LinearDecoder
from training.tiler import GridPadTiler

In [None]:
def move_to_device(x, device):
    """Recursively move every tensor in a (possibly nested) batch to ``device``.

    Tensors are moved with ``non_blocking=True``. Dicts are rebuilt as plain
    dicts, and lists, tuples and namedtuples keep their container type; each is
    rebuilt with its contents moved. Any other value is returned unchanged.

    Args:
        x: A tensor, or an arbitrarily nested dict/list/tuple/namedtuple of them.
        device: Target device (anything accepted by ``Tensor.to``).

    Returns:
        A structure with the same nesting as ``x`` whose tensors live on ``device``.
    """
    if torch.is_tensor(x):
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take fields positionally; passing a single
        # generator (as for plain tuples below) would raise TypeError.
        return type(x)(*(move_to_device(v, device) for v in x))
    if isinstance(x, (list, tuple)):
        return type(x)(move_to_device(v, device) for v in x)
    return x


def batch_size_of(batch):
    """Best-effort number of samples in ``batch``.

    For a list/tuple batch (e.g. ``(inputs, targets)``) only the first element
    is inspected:

    * an empty list/tuple batch counts as 0 samples;
    * a list/tuple of per-sample items (e.g. variable-size images) counts as its length;
    * an array-like with at least one dimension counts as its leading dimension;
    * anything else, including 0-d tensors/arrays, counts as 1.
    """
    if isinstance(batch, (list, tuple)):
        if not batch:
            return 0
        x = batch[0]
    else:
        x = batch
    if isinstance(x, (list, tuple)):
        return len(x)
    shape = getattr(x, "shape", None)
    if shape is None or len(shape) == 0:
        # No shape, or a scalar (shape == ()): indexing shape[0] would fail.
        return 1
    return int(shape[0])


In [None]:


# Run configuration.
# The data root can be overridden with the ANORAK_DATA_DIR environment variable;
# the default is the original author's local checkout.
DATA_DIR = os.environ.get(
    "ANORAK_DATA_DIR", "/home/valentin/workspaces/benchmark-vfm-ss/data/ANORAK"
)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
max_batches = 2  # number of validation batches to round-trip through the tiler
pl_dm = ANORAK(
    DATA_DIR,
    devices=1,
    num_workers=0,
    batch_size=1,
    img_size=(448, 448),
)

In [None]:
# Set up the datamodule before the validation dataloader is requested below
# (presumably this builds the train/val datasets — confirm in ANORAK.setup).
pl_dm.setup()


In [None]:

# Validation loader, progress counters, timer start, and the tiler whose
# window -> stitch round-trip is checked in the following cells.
loader = pl_dm.val_dataloader()
n = samples = 0
t0 = time.perf_counter()
tiler = GridPadTiler(448, 224, weighted_blend=False)


In [None]:
# Round-trip each validation image through the tiler (window -> stitch) so the
# next cells can check that stitching reconstructs the input exactly.
# The accumulators and the timer are (re)initialised here. Re-running this cell
# alone therefore neither double-counts samples nor includes the idle time since
# the previous cell in the reported duration.
images = []
images_stitched = []
samples = 0
t0 = time.perf_counter()
for n, batch in tqdm(enumerate(loader, start=1), total=max_batches):
    if device:
        batch = move_to_device(batch, device)
    samples += batch_size_of(batch)
    imgs, target = batch
    crops, origins, img_sizes = tiler.window(imgs)
    images.extend([img.cpu().numpy() for img in imgs])
    imgs_stitched = tiler.stitch(crops, origins, img_sizes)
    # Round before casting: plain astype truncates, so float noise such as
    # 126.9999 would become 126 and show up as a spurious stitching error.
    images_stitched.extend(
        [np.rint(img.cpu().numpy()).astype(np.uint8) for img in imgs_stitched]
    )

    if n >= max_batches:
        break
t1 = time.perf_counter()
print(f"Processed {samples} samples in {t1 - t0:.2f} seconds")




In [None]:
# Compare each original image with its stitched reconstruction.
# Both sides are quantised to uint8 (the stitched images already are) and then
# widened to int16 before subtracting. Subtracting two uint8 arrays wraps around
# (3 - 4 == 255), which would report a tiny negative error as a huge positive one.
assert len(images) == len(images_stitched), "original/stitched image counts differ"
images_diff = []
for i, (img, img_stitched) in enumerate(zip(images, images_stitched)):
    image_diff = (
        img.astype(np.uint8).astype(np.int16)
        - img_stitched.astype(np.uint8).astype(np.int16)
    )
    images_diff.append(image_diff)
    max_abs_diff = int(np.abs(image_diff).max())
    if max_abs_diff > 0:
        print(f"Image {i} max difference: {max_abs_diff}")


In [None]:
# Per-pixel absolute stitching error for one image, reduced over axis 0 (the axis
# the original code moved last for imshow, i.e. assumed channels). The result is
# shown as a heatmap so that even 1-level errors are visible. The difference is
# computed in int16 to avoid uint8 wrap-around.
IMAGE_INDEX = 1
assert len(images) > IMAGE_INDEX, f"need at least {IMAGE_INDEX + 1} images"
diff_map = np.abs(
    images[IMAGE_INDEX].astype(np.uint8).astype(np.int16)
    - images_stitched[IMAGE_INDEX].astype(np.int16)
).max(axis=0)
fig, ax = plt.subplots()
im = ax.imshow(diff_map, cmap="magma")
fig.colorbar(im, ax=ax, label="max |original - stitched| over channels")
ax.set_title(f"Image {IMAGE_INDEX}: stitching error")
ax.axis("off");

In [None]:
# Stitched reconstruction of the same image. The transpose moves axis 0 (assumed
# to be channels, CHW) to the end, because imshow expects HWC.
fig, ax = plt.subplots()
ax.imshow(images_stitched[1].transpose(1, 2, 0))
ax.set_title("Image 1: stitched reconstruction")
ax.axis("off");