# Transforms visualization

In [None]:
from lib_dl.analysis.notebook import setup_notebook
# Project-specific notebook setup. "../../" presumably points at the repository
# root two directories up. TODO confirm what setup_notebook configures.
setup_notebook("../../")

# IPython autoreload, mode 2: reload imported modules before each cell runs,
# so edits to project modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

from lib_dl.analysis.visualize.images import show_dataset_samples
from transforms_2d import (
    create_trans2d_dataset,
    Trans2DConfig,
    TRANSFORMS,
)
from dataset_analysis.t2d_dataset.plotting_utils import plot_transforms
# from experiments.data_analysis.obj2d_dataset.benchmark_util import load_data

In [None]:
# Output directories for the generated figures (relative to this notebook).
MAIN_PLOTS_DIR = Path("../../neurips_2023/figures/dataset")
APPENDIX_PLOTS_DIR = MAIN_PLOTS_DIR / "ap_all_samples"

# Sizes passed to plot_transforms as its `figsize` argument.
# MAIN_FIGSIZE was previously assigned twice; only the last value (10, 2)
# ever took effect, so the dead (3.5, 1) assignment has been dropped.
MAIN_FIGSIZE = (10, 2)
APPENDIX_FIGSIZE = (4.5, 1)

## Main paper

In [None]:
n_samples = 8
config = Trans2DConfig(
    img_size=32,
    sampling_seed=483,
    transforms_sampling_seed=893,
    # Every split size and the batch size are tied to n_samples.
    **{
        key: n_samples
        for key in ("n_training_samples", "n_val_samples", "n_test_samples", "batch_size")
    },
)

# (transform name, [index], seed) triples for the main-paper figure; the
# middle element is presumably an object index. Confirm against plot_transforms.
transforms = [
    (name, [obj], seed)
    for name, obj, seed in (
        ("rotate", 28, 394),
        ("hue", 5, 5832),
        ("blur", 34, 22),
        ("translate", 16, 532),
    )
]

plot_transforms(config, transforms, MAIN_PLOTS_DIR, MAIN_FIGSIZE)

## Appendix: all transformations

In [None]:
n_samples = 6
config = Trans2DConfig(
    img_size=32,
    sampling_seed=992,
    transforms_sampling_seed=2532,
    # Every split size and the batch size are tied to n_samples.
    **dict.fromkeys(
        ("n_training_samples", "n_val_samples", "n_test_samples", "batch_size"),
        n_samples,
    ),
)
# Third element of every (transform, [index], seed) entry in the appendix figure.
seed_offset = 5
# Hand-picked overrides for the appendix figure: index into TRANSFORMS ->
# value used instead of the default `index + 8`. Comments name the transform.
_APPENDIX_OBJ_REMAP = {
    3: 46,  # h_flip
    8: 36,  # grayscale
    9: 37,  # posterize
    11: 38,  # sharpen
    12: 39,  # blur
    14: 24,  # pixelate
    16: 26,  # erasing
    17: 56,  # contrast
}


def obj_fn(o: int) -> int:
    """Return the index plotted for the transformation at position ``o``.

    The result is presumably an object index for plot_transforms (TODO confirm).

    Args:
        o: Index into ``TRANSFORMS``.

    Returns:
        The override from ``_APPENDIX_OBJ_REMAP`` if ``o`` has one,
        otherwise ``o + 8``.
    """
    return _APPENDIX_OBJ_REMAP.get(o, o + 8)
# One (transform name, [obj_fn(index)], seed_offset) entry per transformation,
# in TRANSFORMS order, all sharing the same third element.
transforms = [
    (name, [obj_fn(index)], seed_offset)
    for index, name in enumerate(TRANSFORMS)
]

plot_transforms(
    config,
    transforms,
    APPENDIX_PLOTS_DIR,
    APPENDIX_FIGSIZE,
)

## Checks

### Multiple objects per class

Use multiple objects per class instead of the single object used for pre-training.

In [None]:
n_samples = 8
# Same dataset settings as the main-paper figure. Obj2DConfig was used here
# but is never imported in this notebook (NameError); Trans2DConfig takes
# exactly these arguments.
config = Trans2DConfig(
    sampling_seed=483,
    transforms_sampling_seed=893,
    img_size=32,
    n_training_samples=n_samples,
    n_val_samples=n_samples,
    n_test_samples=n_samples,
    batch_size=n_samples,
)
transforms = [
    ("rotate", [28], 394),
    ("hue", [5], 5832),
    ("blur", [34], 22),
    # NOTE(review): the main-paper list has ("translate", [16], 532) in this
    # slot; confirm the repeated blur seed 22 is intended.
    ("blur", [16], 22),
]

# plots_dir=None presumably means the figure is only displayed, not saved.
plot_transforms(
    config,
    transforms,
    plots_dir=None,
    figsize=(7, 1),
    use_single_object=False,
)

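### Only the transformation differs

Build two datasets that share every setting except `transforms_sampling_seed`, and compare their test samples for a single transformation.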

In [None]:
n_samples = 16


def get_dataset(transforms_seed: int):
    """Create the obj2d transformation datasets for one transforms-sampling seed.

    All other seeds and sizes are fixed, so calls with different
    ``transforms_seed`` values differ only in ``transforms_sampling_seed``.
    The module-level ``n_samples`` is read at call time. Data is not normalized.
    """
    # NOTE(review): CTDataConfig and create_transforms_datasets are not imported
    # anywhere in this notebook; add the missing import before running this cell.
    size_kwargs = {
        key: n_samples
        for key in ("n_training_samples", "n_val_samples", "n_test_samples", "batch_size")
    }
    data_config = CTDataConfig(
        dataset="obj2d",
        n_classes=10,
        img_size=32,
        config_seed=5839,
        sampling_seed=483,
        transforms_sampling_seed=transforms_seed,
        **size_kwargs,
    )
    return create_transforms_datasets(data_config, normalize=False)

# Two datasets that differ only in the transforms-sampling seed.
data_1, data_2 = (get_dataset(seed) for seed in (593, 4020))
transform = "rw_translate"

print("transforms:", data_1.data.keys())
fig_1, fig_2 = (
    show_dataset_samples(data.data[transform], n_samples=n_samples, data_type="test")
    for data in (data_1, data_2)
)

In [None]:
# CIFAR-100 with augmentations

In [None]:
from utils.invariance_measurement import create_rw_datasets

In [None]:
# Two CIFAR-100 variants (second argument 17 vs 20), both left unnormalized.
datasets_1 = create_rw_datasets("cifar100", 17, normalize=False)
datasets_2 = create_rw_datasets("cifar100", 20, normalize=False)

n_samples = 8
transforms = ["translate", "hue", "rotate"]
for transform in transforms:
    print("transform:", transform)
    # Show this transformation from each variant, one after the other.
    for datasets in (datasets_1, datasets_2):
        show_dataset_samples(datasets[transform], n_samples=n_samples, data_type="test")