In [None]:
import os
os.chdir(os.path.abspath('..'))
%load_ext autoreload
%autoreload 2
from dotenv import load_dotenv
load_dotenv()
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
import matplotlib.pyplot as plt

# Figure size (matplotlib inches) allotted to each subplot cell in the crop
# plots below; figsize is (n_cols * SIZE_FACTOR, n_rows * SIZE_FACTOR).
SIZE_FACTOR = 2

def plot_crops_rgb(sample, size_factor=None):
    """Show each global crop of `sample` in a row next to a slice of the local crops.

    Local crops are split evenly across the global crops in order: row ``i``
    shows global crop ``i`` followed by local crops ``i*k .. i*k + k - 1`` with
    ``k = n_local // n_global``. This grouping is for display only. Each panel
    is min-max rescaled independently to [0, 1], so brightness is not
    comparable across panels.

    Args:
        sample: mapping with ``'global_crops'`` and ``'local_crops'`` lists whose
            items are dicts holding an ``'imgs'`` tensor. The
            ``permute(1, 2, 0)`` below implies a channel-first (C, H, W) layout;
            imshow additionally needs C == 3 (or 4). TODO confirm against the
            augmentation output.
        size_factor: figure size (inches) per subplot cell. Defaults to the
            module-level ``SIZE_FACTOR``, read at call time.

    Raises:
        AssertionError: if there are no global crops, or if the number of
            local crops is not a multiple of the number of global crops.
    """
    if size_factor is None:
        size_factor = SIZE_FACTOR

    glob_crops = sample['global_crops']
    loc_crops = sample['local_crops']
    n_glob_crops = len(glob_crops)
    n_loc_crops = len(loc_crops)
    assert n_glob_crops > 0, 'sample has no global crops'
    assert n_loc_crops % n_glob_crops == 0, (
        f'{n_loc_crops} local crops cannot be split evenly over {n_glob_crops} global crops')
    n_loc_per_glob = n_loc_crops // n_glob_crops

    print(f'global crops: n={n_glob_crops}, shape={tuple(sample["global_crops"][0]["imgs"].shape)}')
    if n_loc_crops:
        print(f'local crops: n={n_loc_crops}, shape={tuple(sample["local_crops"][0]["imgs"].shape)}')
    else:
        print('local crops: n=0')

    n_cols = 1 + n_loc_per_glob
    # squeeze=False keeps `axs` 2-D even with a single row or a single column,
    # so axs[i, j] indexing works for every crop configuration.
    fig, axs = plt.subplots(
        n_glob_crops, n_cols,
        figsize=(n_cols * size_factor, n_glob_crops * size_factor),
        squeeze=False,
    )

    def _plot(ax, crop, title):
        # Reorder (C, H, W) -> (H, W, C), the layout imshow expects for colour images.
        rgb_img = crop['imgs'].permute(1, 2, 0)
        lo = rgb_img.min()
        span = rgb_img.max() - lo
        # A constant image has zero range; dividing by it would yield NaNs.
        rgb_img = (rgb_img - lo) / span if span > 0 else rgb_img - lo
        ax.imshow(rgb_img)
        ax.axis('off')
        ax.set_title(title)

    for i, glob_crop in enumerate(glob_crops):
        _plot(axs[i, 0], glob_crop, f'global crop {i}')
        for j in range(n_loc_per_glob):
            loc_idx = i * n_loc_per_glob + j
            _plot(axs[i, 1 + j], loc_crops[loc_idx], f'local crop {loc_idx}')
    plt.show()
    # Release the figure; this function is typically called in a loop.
    plt.close(fig)

def plot_crops_chns(sample, size_factor=None):
    """Plot every channel of every crop as its own panel.

    Draws one figure for ``'global_crops'`` and one for ``'local_crops'``. In
    each figure, row ``i`` is crop ``i`` and column ``j`` is channel ``j`` of
    that crop's ``'imgs'`` tensor (indexed along its first dimension).

    Args:
        sample: mapping with ``'global_crops'`` and ``'local_crops'`` lists whose
            items are dicts holding an ``'imgs'`` tensor. The first dimension is
            presumably channels (C, H, W). TODO confirm.
        size_factor: figure size (inches) per subplot cell. Defaults to the
            module-level ``SIZE_FACTOR``, read at call time.
    """
    if size_factor is None:
        size_factor = SIZE_FACTOR

    def _plot(sample, key):
        print(key)
        crops = sample[key]

        ncrops = len(crops)
        if ncrops == 0:
            print('ncrops=0')
            return
        shape = crops[0]['imgs'].shape
        print(f'ncrops={ncrops}, shape={shape}')

        # Crops are not guaranteed to share a channel count (e.g. if the
        # augmentation samples a variable number of channels), so size the
        # grid for the widest crop and leave the surplus cells blank.
        n_cols = max(len(crop['imgs']) for crop in crops)
        # squeeze=False keeps `axs` 2-D for a single crop or a single channel.
        fig, axs = plt.subplots(
            ncrops, n_cols,
            figsize=(n_cols * size_factor, ncrops * size_factor),
            squeeze=False,
        )
        for i, crop in enumerate(crops):
            imgs = crop['imgs']
            n_chns = len(imgs)
            for j in range(n_cols):
                if j < n_chns:
                    axs[i, j].imshow(imgs[j])
                axs[i, j].axis('off')
        plt.show()
        plt.close(fig)

    _plot(sample, 'global_crops')
    _plot(sample, 'local_crops')

def plot_crops(sample, is_rgb=False):
    """Visualise the crops of an augmented `sample`.

    Args:
        sample: mapping with 'global_crops' and 'local_crops' lists, in the
            form accepted by plot_crops_rgb / plot_crops_chns.
        is_rgb: if truthy, draw each crop as one RGB image; otherwise draw
            every channel of every crop as a separate panel.
    """
    plotter = plot_crops_rgb if is_rgb else plot_crops_chns
    plotter(sample)

In [None]:
""" make datasets """
from omegaconf import OmegaConf
from dinov2.eval.setup import setup_logger
from dinov2.data.loaders import make_dataset
from dinov2.data.augmentations import make_augmentation

global logger
logger = setup_logger('dinov2', to_sysout=True, simple_prefix=True)

# make fmow
cfg = OmegaConf.create({
    'id': 'FmowDataset',
    'root': '/data/panopticon/datasets/',
    'split': 'fmow/metadata_v2/fmow_iwm_onid_train_val.parquet',
    'return_rgb': True
})
ds_fmow = make_dataset(cfg, seed=42)
# cfg.keep_sensors = ['rgb']
# ds_fmow_rgb = make_dataset(cfg, seed=42)

# make mmearth

# ds_mmearth = make_dataset(cfg, seed=42)


In [None]:
# make dino augm
# make_augmentation is given a list of augmentation configs; here it has one entry.
dino_augm_cfg = [
    OmegaConf.create(dict(
        id='ChnSpatialAugmentationV2',
        global_crops_scale=[0.32, 1.0],
        local_crops_number=4,
        global_crops_number=2,
        local_crops_scale=[0.05, 0.32],
        global_crops_size=224,
        local_crops_size=98,
        global_crops_spectral_size=[3, 3],
        local_crops_spectral_size=[3, 3],
        global_modes_probs=[1.0, 0.0, 0.0],
        local_modes_probs=[1.0, 0.0],
        color_jitter_args=dict(
            p=1.0,  # must be set explicitly; presumably the probability of applying the jitter
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
        ),
    ))
]
dino_augm_new = make_augmentation(dino_augm_cfg)

In [None]:
import random

# Number of distinct dataset items to visualise, and figure size (inches) per
# subplot cell for the plots below.
NUM_SAMPLES = 10
SIZE_FACTOR = 4
# Seeded RNG so the same indices are drawn on every run. Only the index choice
# is seeded here: the crops themselves presumably still vary between runs,
# because the augmentation draws from its own RNG (TODO confirm).
rng = random.Random(42)

ds = ds_fmow
# NOTE: `ds` aliases `ds_fmow`, so this replaces the transform on the dataset
# object built in the earlier cell.
ds.transform = dino_augm_new
# Distinct indices, capped at the dataset size so small datasets do not raise.
for idx in rng.sample(range(len(ds)), k=min(NUM_SAMPLES, len(ds))):
    sample = ds[idx]
    plot_crops(sample, is_rgb=True)