In [None]:
import shutil
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from scipy.signal import fftconvolve
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import tqdm

from vacation.data import download_dataset

# Single seed shared by every random source in this notebook.
SEED = 42

# The torchvision transforms draw their randomness from torch's global RNG,
# so it has to be seeded as well for the augmentations to be repeatable.
torch.manual_seed(SEED)

# NumPy generator used for picking which source images get augmented.
rng = np.random.default_rng(SEED)

In [None]:
# Populate the data directory that the cells below read "Galaxy10_DECals.h5" from.
# NOTE(review): overwrite=True presumably re-fetches the data on every run —
# confirm against vacation.data.download_dataset before relying on a cached copy.
download_dataset("../../.data/", overwrite=True)

In [None]:
def numpy_to_tensor(image, device=None):
    """Convert a channel-last image array into a channel-first float tensor.

    Parameters
    ----------
    image : np.ndarray
        Three-dimensional array; its axis 2 is moved to the front.
    device : str or torch.device, optional
        Device the tensor is moved to. When omitted, "cuda" is used if a CUDA
        device is available and "cpu" otherwise, so the notebook also runs on
        machines without a GPU.

    Returns
    -------
    torch.Tensor
        float32 tensor whose axes are ordered (2, 0, 1) relative to ``image``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    return torch.from_numpy(image).permute(2, 0, 1).float().to(device)


def tensor_to_numpy(im_tensor):
    """Turn a channel-first tensor back into a channel-last NumPy array.

    Values are limited to the closed interval [0, 1], the axes are reordered
    from (0, 1, 2) to (1, 2, 0), and the result is copied to host memory.
    """
    bounded = im_tensor.clamp(min=0, max=1)
    channel_last = bounded.permute(1, 2, 0)
    return channel_last.cpu().numpy()


# Random augmentation pipeline; the stages are applied in the order listed.
augmentation_steps = [
    # Mirror along each image axis independently, each with probability 0.5.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # Additive zero-mean Gaussian noise; clip=True keeps the result bounded.
    transforms.GaussianNoise(mean=0.0, sigma=0.1, clip=True),
    # Only brightness and saturation are jittered; contrast and hue stay fixed.
    transforms.ColorJitter(brightness=0.4, contrast=0, saturation=0.5, hue=0),
]

transform = transforms.Compose(augmentation_steps)


def augmented_class(path: str, class_index: int, target_count: int = 2600):
    """Create augmented images for one class so that it reaches ``target_count``.

    Source images of the requested class are drawn at random (with
    replacement) via the module-level ``rng`` and passed through the
    module-level ``transform`` pipeline.

    Parameters
    ----------
    path : str
        HDF5 file providing an ``"images"`` and an ``"ans"`` (label) dataset.
    class_index : int
        Label value of the class to augment.
    target_count : int, optional
        Desired number of samples of this class after augmentation.

    Returns
    -------
    augmented_images : np.ndarray
        Array of shape ``(needed_count, *image_shape)`` with the new images.
    augmented_labels : np.ndarray
        ``uint8`` array of shape ``(needed_count,)`` filled with
        ``class_index`` — exactly one label per generated image.

    Raises
    ------
    ValueError
        If images are needed but the file holds no sample of ``class_index``.
    """
    with h5py.File(path, "r") as hf:

        labels = np.array(hf["ans"])
        images = hf["images"]

        original_indices = np.where(labels == class_index)[0]
        # A class that already meets the target needs nothing; never ask
        # NumPy for a negative-sized array.
        needed_count = max(0, target_count - len(original_indices))

        print(
            f"[Class {class_index}] Current: {len(original_indices)}, Adding: {needed_count}"
        )

        if needed_count > 0 and len(original_indices) == 0:
            raise ValueError(
                f"No samples with label {class_index} found in '{path}'; "
                "nothing to augment from."
            )

        # Use the dataset's shape metadata instead of reading an image from disk.
        augmented_images = np.zeros((needed_count, *images.shape[1:]))
        # One label per generated image (not one per label in the source file),
        # so images and labels can be appended to a dataset side by side.
        augmented_labels = np.full(needed_count, class_index, dtype=np.uint8)

        # NOTE(review): tensor_to_numpy clamps to [0, 1] and numpy_to_tensor
        # does not rescale; if the stored images are 8-bit (0-255) the output
        # saturates — confirm the expected value range of hf["images"].
        for i in tqdm(range(needed_count), desc=f"Augmenting class {class_index}"):
            idx = rng.choice(original_indices)
            augmented_images[i] = tensor_to_numpy(
                transform(numpy_to_tensor(images[idx]))
            )

    return augmented_images, augmented_labels


def extend_dataset(
    path: str,
    target_path: str,
    images: np.ndarray,
    labels: np.ndarray,
    overwrite: bool = True,
):
    """Append images and labels to a copy of an HDF5 dataset file.

    If ``target_path`` does not exist yet it is created as a copy of ``path``.
    The ``"images"`` and ``"ans"`` datasets of the target are then grown along
    axis 0 and the new entries are written at the end.

    Parameters
    ----------
    path : str
        Source HDF5 file that is copied when the target does not exist.
    target_path : str
        HDF5 file that receives the additional samples.
    images : np.ndarray
        New images; the first axis indexes samples.
    labels : np.ndarray
        New labels, one per entry of ``images``.
    overwrite : bool, optional
        When false, an already existing target raises ``FileExistsError``.
        When true, an existing target is modified in place.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` do not contain the same number of samples.
    FileExistsError
        If the target exists and ``overwrite`` is false.

    Notes
    -----
    NOTE(review): an existing target is *extended*, not recreated from
    ``path``, so repeating the same call appends the same samples again —
    confirm whether that accumulation is intended.
    """
    # Validate before touching any file: unequal lengths would leave the
    # image and label datasets with different sizes.
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}."
        )

    target_path = Path(target_path)

    if not target_path.is_file():
        shutil.copy(Path(path), target_path)
    elif not overwrite:
        raise FileExistsError(
            "This file already exists. Set 'overwrite = True' to overwrite the file!"
        )

    with h5py.File(target_path, mode="a") as hf:

        images_h5 = hf["images"]
        labels_h5 = hf["ans"]

        original_size = images_h5.shape[0]
        new_size = original_size + len(images)

        # Both datasets grow by the same amount so they stay aligned.
        images_h5.resize(new_size, axis=0)
        labels_h5.resize(new_size, axis=0)

        images_h5[original_size:] = images
        labels_h5[original_size:] = labels

In [None]:
# Generate augmented samples for class index 5, aiming for 2100 samples in total.
# NOTE(review): the "_3" suffix in the variable names does not match the class
# index 5 passed here — confirm which class is intended.
augmented_im_3, augmented_label_3 = augmented_class(
    "../../.data/Galaxy10_DECals.h5", 5, target_count=2100
)

In [None]:
# extend_dataset writes to the target file and has no return value, so its
# result must not be unpacked (doing so raises TypeError after the write).
extend_dataset(
    path="../../.data/Galaxy10_DECals.h5",
    target_path="../../.data/Galaxy10_DECals_augmented.h5",
    images=augmented_im_3,
    labels=augmented_label_3,
)

In [None]:
# Show channel index 2 of one randomly chosen augmented image.
# The index is bounded by the image array itself and drawn from the seeded
# generator so the same sample is shown on every fresh run.
sample_index = int(rng.integers(len(augmented_im_3)))

fig, ax = plt.subplots()
channel_plot = ax.imshow(augmented_im_3[sample_index][..., 2])
ax.set_title(f"Augmented sample {sample_index} (channel 2)")
fig.colorbar(channel_plot, ax=ax);