In [None]:
%cd ..

from omegaconf import OmegaConf
from dotenv import load_dotenv
import torch
import os
import matplotlib.pyplot as plt

from dinov2.configs import dinov2_default_config
from dinov2.train.setup import setup_dataloader

In [None]:
# Populate os.environ from a .env file (if present) so PROJECTPATH can be read.
load_dotenv()

# Fail early with a clear message: os.path.join(None, ...) would otherwise raise
# an opaque TypeError when PROJECTPATH is missing.
project_path = os.getenv("PROJECTPATH")
if project_path is None:
    raise RuntimeError(
        "PROJECTPATH is not set; define it in the environment or in a .env file"
    )
config_file = os.path.join(project_path, "configs/vitl_test.yaml")

# Optional dotlist overrides, e.g. ["train.batch_size_per_gpu=2"].
# OmegaConf.from_cli() is deliberately not used: inside a Jupyter kernel sys.argv
# holds the kernel launcher's arguments (e.g. "-f <connection file>"), which
# from_cli() would merge into the config as junk keys.
CONFIG_OVERRIDES = []

default_cfg = OmegaConf.create(dinov2_default_config)
cfg = OmegaConf.load(config_file)
cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_dotlist(CONFIG_OVERRIDES))

# Keep the loader minimal for interactive inspection of a single sample.
cfg.train.num_workers = 1
cfg.train.batch_size_per_gpu = 1
# For single-channel datasets, optionally also set: cfg.datasets[0].channels = 1

In [None]:
# Build the dataloader from the merged config; torch.float16 is the very same
# dtype object as torch.half. NOTE(review): presumably this is the dtype the
# loader emits its inputs in — confirm against setup_dataloader.
dataloader = setup_dataloader(cfg, torch.float16, use_full_image=False)

# One persistent iterator, so each `next(get_iter)` in later cells yields a new batch.
get_iter = dataloader.__iter__()

In [None]:
# Pull a single collated batch from the dataloader iterator.
example_data = next(get_iter)
# Display the batch's keys; the crop tensors used below are looked up by these names.
example_data.keys()

In [None]:
# Collated global crops; the leading dimension indexes individual crops.
global_crops = example_data["collated_global_crops"]
n_global_crops = len(global_crops)

# Start the shared display range from the global crops' extremes
# (extended with the local crops in the next cell).
vmin, vmax = global_crops.min(), global_crops.max()

In [None]:
# Collated local crops; the leading dimension indexes individual crops.
local_crops = example_data["collated_local_crops"]
n_local_crops = len(local_crops)

# Widen the display range so global and local crops share one grey scale.
lo, hi = local_crops.min(), local_crops.max()
vmin = min(vmin, lo)
vmax = max(vmax, hi)

In [None]:
# One panel per global crop, showing its first slice (index 0 along dim 1) on the
# shared [vmin, vmax] range so panels are directly comparable.
# squeeze=False keeps `axes` a 2-D array even with a single crop; plt.subplots
# would otherwise return a bare Axes and .flatten() would fail.
fig, axes = plt.subplots(1, n_global_crops, squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    im = ax.imshow(global_crops[i][0], cmap="gray", vmin=vmin, vmax=vmax)
    ax.axis("off")
    ax.set_title(f"crop {i + 1}")

fig.suptitle("Global crops: slice 0")
fig.colorbar(im, ax=axes, fraction=0.02, pad=0.1)

plt.show()

In [None]:
# Lay the local crops (first slice of each) out on a 2-row grid using the shared
# [vmin, vmax] range. (n + 1) // 2 is ceil(n / 2) columns for n >= 0.
n_cols = (n_local_crops + 1) // 2
fig, axes = plt.subplots(2, n_cols)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < n_local_crops:
        im = ax.imshow(local_crops[i][0], cmap="gray", vmin=vmin, vmax=vmax)
    # Hide every panel's frame and ticks. With an odd crop count this also hides
    # the final, unused grid cell, which would otherwise render as an empty box.
    ax.axis("off")

plt.subplots_adjust(wspace=0.1, hspace=0)
fig.colorbar(im, ax=axes, fraction=0.03, pad=0.1)

plt.show()

In [None]:
# Cross-section view: take the middle row of every global crop, giving a
# (slices, columns) image per crop. The unpacking requires each crop to be 3-D.
slices, rows, columns = global_crops[0].shape
row = rows // 2

# squeeze=False keeps `axes` 2-D even for a single crop, so .flatten() is safe.
fig, axes = plt.subplots(1, n_global_crops, squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    im = ax.imshow(global_crops[i][:, row, :], cmap="gray", vmin=vmin, vmax=vmax)
    ax.axis("off")
    ax.set_title(f"crop {i + 1}")

fig.suptitle(f"Global crops: cross-section at row {row}")
fig.colorbar(im, ax=axes, fraction=0.02, pad=0.1)

plt.show()