For this notebook to work, the fmow dataset needs to return `imgs` and `chn_ids` as lists.

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
# Move the working directory one level above the directory the kernel
# started in. The start directory is remembered on first execution so that
# re-running this cell is idempotent; a bare relative chdir('..') would climb
# one directory higher on every re-run.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(_NOTEBOOK_START_DIR, '..')))

In [None]:
# Prepare the datasets used by the correctness check and the benchmarks.
from omegaconf import OmegaConf
from dinov2.data.loaders import make_dataset
from dinov2.eval.setup import setup_logger
from dinov2.data.augmentations import make_augmentation

# A module-level assignment already creates a notebook global, so no
# `global` statement is needed.
logger = setup_logger('dinov2', to_sysout=True)

# Base configuration. It is never mutated below: every variant is derived with
# OmegaConf.merge, which returns a new config. This way no dataset can end up
# holding a config object that is changed afterwards.
# NOTE(review): whether make_dataset keeps a reference to its config is not
# visible here.
cfg = OmegaConf.create({
    'id': 'FmowDataset',
    'root': '/data/panopticon/datasets/',
    'split': 'fmow/metadata_v2/fmow_iwm_onid_train_val.parquet',
    'use_new_output': False,
})

# Base configuration without any sensor filter.
ds_all = make_dataset(cfg)

# Same split, with keep_sensors = ['rgb'].
ds_rgb = make_dataset(OmegaConf.merge(cfg, {'keep_sensors': ['rgb']}))

# Further single-sensor variants (disabled):
# ds_s2 = make_dataset(OmegaConf.merge(cfg, {'keep_sensors': ['s2']}))
# ds_wv = make_dataset(OmegaConf.merge(cfg, {'keep_sensors': ['wv23']}))

# Presorted split, no keep_sensors entry.
ds_wvs2 = make_dataset(OmegaConf.merge(
    cfg,
    {'split': 'fmow/metadata_v2/fmow_iwm_onid_train_val_presorted.parquet'},
))


I20241102 15:27:35 3076981 dinov2 loaders.py:68] Building dataset "FmowDataset" ...
I20241102 15:27:35 3076981 dinov2 augmentations.py:59] Augmentations in order: []
I20241102 15:27:36 3076981 dinov2 fmow.py:110] Using reference date: 2002-01-01
I20241102 15:27:38 3076981 dinov2 fmow.py:132] Dataset size: 95418, sensor counts: {'wv23': 89301, 's2': 72858, 'rgb': 95418}
I20241102 15:27:41 3076981 dinov2 fmow.py:167] Subsetted dataset to only include rows where every row has atleast one image >= 256 | #rows = 95418
I20241102 15:27:42 3076981 dinov2 fmow.py:132] Dataset size: 95418, sensor counts: {'wv23': 89301, 's2': 72858, 'rgb': 95418}
I20241102 15:27:42 3076981 dinov2 fmow.py:198] Subsetted dataset to only include rows where the number of available sensors >= 2 | #rows = 93190
I20241102 15:27:43 3076981 dinov2 fmow.py:132] Dataset size: 93190, sensor counts: {'wv23': 89301, 's2': 72858, 'rgb': 93190}
I20241102 15:27:43 3076981 dinov2 fmow.py:118] Normalizing images
I20241102 15:27:43

In [None]:
# New augmentation (ChnSpatialAugmentationV2): scalar crop sizes plus separate
# spectral-size ranges and mode probabilities.
dino_augm_cfg = [OmegaConf.create(dict(
    id='ChnSpatialAugmentationV2',
    global_crops_scale=[0.32, 1.0],
    local_crops_number=4,
    global_crops_number=2,
    local_crops_scale=[0.05, 0.32],
    global_crops_size=224,
    local_crops_size=98,
    global_crops_spectral_size=[6, 13],
    local_crops_spectral_size=[3, 6],
    global_modes_probs=[0.8, 0.1, 0.1],
    local_modes_probs=[0.2, 0.8],
))]
dino_augm_new = make_augmentation(dino_augm_cfg)


# Old augmentation (ChnSpatialAugmentation): crop sizes are two-element lists.
# NOTE(review): presumably [number of channels, spatial size] - confirm
# against the augmentation implementation.
dino_augm_cfg = [OmegaConf.create(dict(
    id='ChnSpatialAugmentation',
    global_crops_scale=[0.32, 1.0],
    global_crops_number=2,
    local_crops_number=4,
    local_crops_scale=[0.05, 0.32],
    global_crops_size=[13, 224],  # 224
    local_crops_size=[6, 98],  # 98
))]
dino_augm = make_augmentation(dino_augm_cfg)

I20241102 15:13:11 3076981 dinov2 augmentations.py:290] ###################################
I20241102 15:13:11 3076981 dinov2 augmentations.py:291] Using data augmentation parameters:
I20241102 15:13:11 3076981 dinov2 augmentations.py:292] id: ChnSpatialAugmentationV2
I20241102 15:13:11 3076981 dinov2 augmentations.py:293] local_crops_number: 4
I20241102 15:13:11 3076981 dinov2 augmentations.py:294] global_crops_scale: [0.32, 1.0]
I20241102 15:13:11 3076981 dinov2 augmentations.py:295] local_crops_scale: [0.05, 0.32]
I20241102 15:13:11 3076981 dinov2 augmentations.py:296] global_crops_size: 224
I20241102 15:13:11 3076981 dinov2 augmentations.py:297] local_crops_size: 98
I20241102 15:13:11 3076981 dinov2 augmentations.py:298] global_crops_spectral_size: [6, 13]
I20241102 15:13:11 3076981 dinov2 augmentations.py:299] local_crops_spectral_size: [3, 6]
I20241102 15:13:11 3076981 dinov2 augmentations.py:300] global_modes_probs: [0.8, 0.1, 0.1]
I20241102 15:13:11 3076981 dinov2 augmentations

## Correctness

In [None]:
def fwd(imgs, chn_ids, is_new):
    """Run one sample through either the new or the old augmentation.

    Args:
        imgs: sized, indexable collection holding the images of one sample.
        chn_ids: channel ids; indexed in parallel with ``imgs`` on the new
            path and passed through unchanged on the old path.
        is_new: truthy selects ``dino_augm_new``, falsy selects ``dino_augm``.

    Returns:
        Whatever the selected augmentation callable returns.
    """
    if not is_new:
        # Old input format: a single dict in which every image is wrapped in
        # its own one-element list and the channel ids are handed over as-is.
        sample = dict(
            imgs=[[img] for img in imgs],
            chn_ids=chn_ids,
        )
        return dino_augm(sample)

    # New input format: a list with one dict per image, each carrying that
    # image and its channel ids as one-element lists.
    per_image = [
        dict(imgs=[img], chn_ids=[chn_ids[idx]])
        for idx, img in enumerate(imgs)
    ]
    return dino_augm_new(per_image)

In [None]:
# Check shapes: augment the first sample of ds_wvs2 with the new augmentation
# and print, first for the global and then for the local crops, the number of
# crops followed by the image / channel-id shapes of every crop.
ds = ds_wvs2
imgs, chn_ids = ds[0]
out = fwd(imgs, chn_ids, True)

for crop_kind in ('global_crops', 'local_crops'):
    crops = out[crop_kind]
    print(len(crops))
    for data_obj in crops:
        print(data_obj['imgs'].shape, data_obj['chn_ids'].shape)

2
torch.Size([13, 224, 224]) torch.Size([13])
torch.Size([13, 224, 224]) torch.Size([13])
4
torch.Size([6, 98, 98]) torch.Size([6])
torch.Size([6, 98, 98]) torch.Size([6])
torch.Size([6, 98, 98]) torch.Size([6])
torch.Size([6, 98, 98]) torch.Size([6])


## Benchmarking

In [44]:
import numpy as np
import time

In [None]:
""" different modes """

def benchmark(ds):
    """Time loading one random sample and augmenting it with the new pipeline.

    Args:
        ds: sized, indexable dataset whose items unpack into
            ``(imgs, chn_ids)``.

    Returns:
        Tuple ``(data_time, augm_new_time)`` in seconds.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    imgs, chn_ids = ds[np.random.randint(0, len(ds))]
    data_time = time.perf_counter() - start

    start = time.perf_counter()
    # Only the elapsed time matters, so the augmentation result is discarded.
    fwd(imgs, chn_ids, is_new=True)
    augm_new_time = time.perf_counter() - start

    return data_time, augm_new_time



In [58]:
""" compare vs old augmentation on different subsets"""

import time
import numpy as np


def benchmark(ds):
    """Time loading one random sample and augmenting it with both pipelines.

    The same sample is passed first to the old and then to the new
    augmentation.

    Args:
        ds: sized, indexable dataset whose items unpack into
            ``(imgs, chn_ids)``.

    Returns:
        Tuple ``(data_time, augm_time, augm_new_time)`` in seconds, where
        ``augm_time`` belongs to the old and ``augm_new_time`` to the new
        augmentation.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    imgs, chn_ids = ds[np.random.randint(0, len(ds))]
    data_time = time.perf_counter() - start

    # Only the elapsed times matter, so the augmentation results are discarded.
    start = time.perf_counter()
    fwd(imgs, chn_ids, is_new=False)
    augm_time = time.perf_counter() - start

    start = time.perf_counter()
    fwd(imgs, chn_ids, is_new=True)
    augm_new_time = time.perf_counter() - start

    return data_time, augm_time, augm_new_time


nsamples = 100  # number of timed samples per dataset
# Keys are padded to the same width so the printed rows line up.
all_ds = dict(zip(
    ('rgb ', 'wvs2', 'all '),
    (ds_rgb, ds_wvs2, ds_all),
))

print('ds_name: data_time, augm_old_time, augm_new_time')
for ds_name, ds in all_ds.items():
    # One row per sample, one column per timing returned by benchmark();
    # mean and standard deviation are taken over the samples (axis 0).
    times = np.array([benchmark(ds) for _ in range(nsamples)])
    print(f'{ds_name}: {times.mean(0)}, {times.std(0)}')


ds_name: data_time, augm_old_time, augm_new_time
rgb : [0.14657643 0.02567428 0.03003799], [0.14690989 0.01672777 0.055954  ]
wvs2: [0.12613507 0.02861156 0.01434147], [0.06535134 0.01285906 0.01131173]
all : [0.13758    0.02398742 0.01641087], [0.04931806 0.00412526 0.00613196]


In [61]:
""" benchmark loading more channels to cuda """
import torch
import time

# Settings for the host-to-device transfer benchmark.
device = 'cuda:0'  # target device of the copies

nchns = list(range(6, 31, 6))  # channel counts to time: 6, 12, 18, 24, 30
B = 60  # batch size of the synthetic input

def benchmark(nchn):
    """Time moving one synthetic batch with ``nchn`` channels to ``device``.

    Args:
        nchn: number of channels per image in the synthetic batch.

    Returns:
        Seconds needed to copy the ``imgs`` and ``chn_ids`` tensors to
        ``device``, including the synchronisation that waits for the copies
        to finish.
    """
    x_dict = dict(
        imgs = torch.randn(B, nchn, 224, 224),
        chn_ids = torch.randint(0, 100, (B, nchn)),
    )
    target = torch.device(device)
    # Drain work that is already queued on the device so it is not billed to
    # this transfer.
    torch.cuda.synchronize(target)
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations.
    start = time.perf_counter()
    for k in x_dict:
        x_dict[k] = x_dict[k].to(device)
    # Wait on the device that was copied to; a bare synchronize() only waits
    # on the current device, which need not be `device`.
    torch.cuda.synchronize(target)
    to_cuda_time = time.perf_counter() - start
    return to_cuda_time

nsamples = 200  # transfers timed per channel count
warmup = 100    # leading measurements dropped before computing statistics
for nchn in nchns:
    all_runs = [benchmark(nchn) for _ in range(nsamples)]
    # Statistics are computed over the measurements after the warm-up only.
    times = np.asarray(all_runs[warmup:])
    print(f'{nchn}: {times.mean():.3f}, {times.std():.5f}')

6: 0.016, 0.00010
12: 0.032, 0.00020
18: 0.047, 0.00017
24: 0.063, 0.00564
30: 0.078, 0.00050
