# Data selection & augmentation

TODO:
- speed up...

# Imports

In [1]:
import random
import re
import shutil
from glob import glob
from pathlib import Path
from typing import Tuple
from warnings import warn

import numpy as np
import torch
from PIL import Image

# from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm

from src.utils_Img2Img import print_grid

# Data selection

Performs the following:
- selects *the same number* of images, at random, from each *super*-class of the original dataset, where a super-class is a first-level class, i.e.:
```
- latrunculin_B_high_conc     <- super-class
      - latrunculin B 30      <- not a super-class
            - file1
            - file2
      - latrunculin B 10      <- not a super-class either
            - file
```
- splits the data into train and test sets

# Original dataset

Load the original dataset

In [2]:
# Root folder of the original images (one sub-folder per class)
data_root_path = "/projects/imagesets/Golgi/128x128/"

# Preprocessing applied to every image when it is read from the dataset:
# resize -> tensor -> normalization
augmentations = transforms.Compose(
    [
        transforms.Resize(
            128,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map to [-1, 1] for SiLU
    ]
)

# Images are converted to RGB before preprocessing;
# class indices are returned as int64 tensors
dataset = ImageFolder(
    root=data_root_path,
    transform=lambda pil_image: augmentations(pil_image.convert("RGB")),
    target_transform=lambda class_idx: torch.tensor(class_idx).long(),
)

In [3]:
# Bare last expression: displays the ImageFolder summary (size, root, transforms)
dataset

Dataset ImageFolder
    Number of datapoints: 175020
    Root location: /projects/imagesets/Golgi/128x128/
    StandardTransform
Transform: <function <lambda> at 0x7f12a788d870>
Target transform: <function <lambda> at 0x7f12a788d990>

Check original dataset stats

In [4]:
# Number of samples per class in the original dataset.
# Labels and counts are converted to built-in ints, so that what is printed
# (and the sizes derived from `common_nb_samples`) does not depend on NumPy's
# scalar types and on the way a given NumPy version displays them.
unique, counts = np.unique(dataset.targets, return_counts=True)
print("Class names and indices:", [
      (cl, dataset.class_to_idx[cl]) for cl in dataset.classes])
print("Counts:", {int(label): int(count) for label, count in zip(unique, counts)})

# The size of the smallest class is the number of samples selected in every class
common_nb_samples = int(np.min(counts))
print("Common number of samples to select per class:", common_nb_samples)

Class names and indices: [('0', 0), ('1', 1)]
Counts: {0: 89446, 1: 85574}
Common number of samples to select per class: 85574


Define train/test split sizes

In [5]:
# Approximate proportion of the per-class samples assigned to each split.
# The proportions must sum to 1; the split-size computation below asserts that
# "test" is the split that receives the samples left over after rounding.
split_approx_props = {"train": 0.5, "test": 0.5}

In [6]:
# Turn the approximate proportions into exact per-class split sizes.
# The proportions are floats, so their sum is compared to 1 with a tolerance:
# an exact `==` can fail because of rounding (e.g. 0.7 + 0.2 + 0.1 != 1.0).
assert np.isclose(
    sum(split_approx_props.values()), 1.0, rtol=0.0, atol=1e-9
), "split proportions must sum to 1"

split_sizes = dict.fromkeys(split_approx_props.keys(), 0)

cumsum = 0
for idx, (split_name, prop) in enumerate(split_approx_props.items()):
    if idx == 1: # fill the test split with the remaining samples
        assert split_name == "test", "the second split must be named 'test'"
        # int(): keep built-in ints even if common_nb_samples is a NumPy scalar
        split_size = int(common_nb_samples - cumsum)
    else:
        # rounded down; what is lost here is recovered by the test split
        split_size = int(prop * common_nb_samples)
    cumsum += split_size
    split_sizes[split_name] = split_size
# every selected sample must belong to exactly one split
assert cumsum == common_nb_samples, "split sizes do not add up to common_nb_samples"

print("split_sizes:", split_sizes)

split_sizes: {'train': 42787, 'test': 42787}


Helper function

In [7]:
def get_full_names(selected_files: list[str], src_dir: Path) -> list[Tuple[Path, str]]:
    """Returns the (source, target) file path pairs,
    with possibly the intermediary directory name prefixed to it. eg:

    - DMSO (= src_dir)
        - file1
        - file2

    would give [(file1, file1), (file2, file2)], but:

    - latrunculin_B_high_conc (= src_dir)
        - latrunculin B 30
            - file1
            - file2
        - latrunculin B 10
            - file1

    would give [
        (latrunculin B 30/file1, latrunculin B 30_file1),
        (latrunculin B 30/file2, latrunculin B 30_file2),
        (latrunculin B 10/file1, latrunculin B 10_file1)
    ]                    ^                       ^
                         |                       |
    (notice the         '/'          vs         '_')

    Arguments
    =========
    - selected_files: list of str
        Paths of the selected files; each one must be located under `src_dir`,
        either directly or inside one intermediary directory.

    - src_dir: Path
        Directory the source paths are made relative to.

    Returns
    =======
    - list of (Path, str) pairs, in the order of `selected_files`:
        the path of the file relative to `src_dir`, and the flat file name
        obtained by joining the components of that path with '_'.

    Raises
    ======
    - ValueError if a file is not located under `src_dir`
    - AssertionError if a file is nested more than one directory below `src_dir`
    """
    selected_files_full_names = []

    for file in selected_files:
        # Components of the file path below src_dir.
        # `relative_to` refuses files located outside src_dir, instead of
        # silently returning whatever follows the common prefix of both paths.
        last_parts = Path(file).relative_to(src_dir).parts
        assert len(last_parts) <= 2, f"{file} is nested too deep below {src_dir}"
        # form the (Path, str) pair
        to_append = (Path(*last_parts), "_".join(last_parts))
        selected_files_full_names.append(to_append)

    return selected_files_full_names

Target directory

In [8]:
# Root of the selected dataset: the copy cell creates it with <split>/<class>/ sub-folders,
# and the augmentation cells read from (and write into) those same sub-folders
aug_data_root_path = "/projects/deepdevpath/Thomas/data/Golgi"  # train/test split done in the copy script

# Data selection

In [9]:
# do_copy=False means dry-run (no data will be written)
# do_copy=True requires that the destination directory does not exist yet
# (the copy cell raises a RuntimeError otherwise)
do_copy = True

Select & copy data into splits

In [11]:
# Select `common_nb_samples` files at random in each class, split them into
# train/test and copy them to <destination>/<split>/<class>/.
#
# Reproducibility: the candidate files are sorted (glob returns them in
# arbitrary order) and drawn with a dedicated RNG seeded with a fixed value,
# so that re-running the selection yields the same train/test split.
SELECTION_SEED = 0
selection_rng = random.Random(SELECTION_SEED)

desc = "Copying files" if do_copy else "*Not* copying files"

# Create general destination directory
# (before the progress bar is created: bailing out here leaves no dangling bar)
dst_dir = Path(aug_data_root_path)
print(f"Destination directory: {dst_dir}")
if dst_dir.exists():
    if do_copy:
        raise RuntimeError(f"{dst_dir} already exists!")
    else:
        warn(f"{dst_dir} already exists! But no copy will be made...")
if do_copy:
    print(f"Creating {dst_dir}")
    dst_dir.mkdir(parents=True, exist_ok=False)
else:
    print(f"Would create {dst_dir}")

# The context manager closes the progress bar even if a copy fails midway
with tqdm(total=common_nb_samples * len(dataset.classes), desc=desc) as pbar:
    # fill the class splits
    for class_name in dataset.classes:
        print(f"\n===========================> {class_name}")
        # Set the source and destination directories
        src_dir = Path(data_root_path, class_name)
        print(f"Source directory:      {src_dir}")

        # Get a list of all PNG files in the source directory
        # A single class might have multiple subdirectories
        # so we need some glob magic
        # (sorted: the random draw below must start from a deterministic order)
        pathname = src_dir.as_posix() + "/**/*.png"
        png_file_paths = sorted(glob(pathname, recursive=True))
        print(f"Found {len(png_file_paths)} PNG files")

        # Check if there are enough PNG files to select from
        if len(png_file_paths) < common_nb_samples:
            raise ValueError(f"Not enough PNG files in {src_dir} to select from.")

        # Select a random subset of `common_nb_samples` images files
        selected_files = selection_rng.sample(png_file_paths, common_nb_samples)

        # Split this selection into train/test splits
        # (the sample is in random order, so slicing it gives random splits)
        train_files = selected_files[: split_sizes["train"]]
        test_files = selected_files[split_sizes["train"] :]
        assert len(train_files) + len(test_files) == common_nb_samples
        files_dict = {"train": train_files, "test": test_files}

        for split, files in files_dict.items():
            # Copy the selected files to the destination directory
            # first get their name (+ possibly intermediary class)
            selected_files_full_names = get_full_names(files, src_dir)

            split_class_dir = Path(dst_dir, split, class_name)
            if not split_class_dir.exists():
                if do_copy:
                    print(f"Creating {split_class_dir}")
                    split_class_dir.mkdir(parents=True, exist_ok=False)
                else:
                    print(f"Would create {split_class_dir}")

            # then copy
            for srcfilename, trgtfilename in selected_files_full_names:
                src_path = Path(src_dir, srcfilename)
                dst_path = Path(dst_dir, split, class_name, trgtfilename)
                pbar.set_postfix_str(f"{srcfilename} -> {split}/{trgtfilename}")
                if do_copy:
                    # two source files may map to the same flattened target name:
                    # never overwrite silently
                    if dst_path.exists():
                        raise RuntimeError(f"{dst_path} already exists!")
                    shutil.copy(src_path, dst_path)
                pbar.update()
            if not do_copy:
                print(
                    f"    Would have copied {len(selected_files_full_names)} files to {split_class_dir}"
                )

C:   0%|          | 0/171148 [00:00<?, ?it/s]

Destination directory: /projects/deepdevpath/Thomas/data/Golgi
Creating /projects/deepdevpath/Thomas/data/Golgi

Source directory:      /projects/imagesets/Golgi/128x128/0
Found 89446 PNG files
Creating /projects/deepdevpath/Thomas/data/Golgi/train/0
Creating /projects/deepdevpath/Thomas/data/Golgi/test/0

Source directory:      /projects/imagesets/Golgi/128x128/1
Found 85574 PNG files
Creating /projects/deepdevpath/Thomas/data/Golgi/train/1
Creating /projects/deepdevpath/Thomas/data/Golgi/test/1


# Data augmentation

The symmetry group of a square, $\mathrm{Dih}_4$, is of order 8. So, given the assumed semantic invariance of our 2D square images (no up and down, no right and left, no bottom and above), we can achieve up to 8x data augmentation: each image together with its 7 non-trivial transforms.

In [None]:
# Codes of the augmentation operations applied to every image
# (decoded by `perform_data_aug_op`):
#   - "rot<k>":        k quarter turns (np.rot90) in the image plane
#   - "flip<f>rot<k>": same rotation, followed by a flip along array axis f
# The identity is not listed. Together with it, these 7 operations are the
# 8 symmetries of the square.
OPS = (
    "rot1",
    "rot2",
    "flip2rot2",
    "flip1rot2",
    "rot3",
    "flip2rot3",
    "flip1rot3",
)

In [None]:
def perform_data_aug_op(image: np.ndarray, op_code: str):
    """
    Performs the data augmentation operation given by op_code on the image.

    Arguments
    =========
    - image: array, shape (3, res, res)
        Only axes 1 and 2 are transformed.

    - op_code: str
        Either "id" or one of the codes listed in OPS:
        - "id":            the image is returned unchanged
        - "rot<k>":        k quarter turns (np.rot90) in the plane of axes (1, 2)
        - "flip<f>rot<k>": same rotation, followed by a flip along axis f

    Returns
    =======
    - modified_image: array, same shape
        Built from a copy of `image`: the input is never modified.

    Raises
    ======
    - AssertionError if op_code is neither "id" nor listed in OPS
    - ValueError if op_code does not follow one of the formats above
    """

    # "id" is accepted on top of OPS: OPS only lists the non-trivial operations
    assert op_code == "id" or op_code in OPS, f"op_code {op_code!r} is neither 'id' nor in OPS"

    image = image.copy()

    if op_code == "id":
        return image

    # The whole code must match: trailing characters are not silently ignored
    rot_match = re.fullmatch(r"rot(\d)", op_code)
    if rot_match:
        nb_quarter_turns = int(rot_match.group(1))
        return np.rot90(image, nb_quarter_turns, (1, 2))

    flip_rot_match = re.fullmatch(r"flip(\d)rot(\d)", op_code)
    if flip_rot_match:
        flip_axis = int(flip_rot_match.group(1))
        nb_quarter_turns = int(flip_rot_match.group(2))
        # rotate first, then flip
        return np.flip(np.rot90(image, nb_quarter_turns, (1, 2)), flip_axis)

    raise ValueError("No match!?")

Check the augmentations

In [None]:
# Visual sanity check: for each class and split, show the first image found
# followed by its augmented versions (one per op code in OPS).
# Must be run before the in-place augmentation: afterwards the folders also
# contain the augmented files and the file-count check below fails.
for class_name in dataset.classes:
    print(f"\n===========================> {class_name}")
    for split in ["train", "test"]:
        # Folder holding the images of this class & split
        # (not named `dir`, which would shadow the builtin)
        images_dir = Path(aug_data_root_path, split, class_name)
        print(f"Directory: {images_dir}")

        # Get a list of all PNG files in the directory
        png_file_paths = list(images_dir.glob('**/*.png'))
        assert len(png_file_paths) == split_sizes[split]
        print(f"Found {len(png_file_paths)} PNG files")

        # Visualize data aug (on the first file only)
        img_list = []
        if png_file_paths:
            # read the file once; the copy stays usable after the file is closed
            with Image.open(png_file_paths[0]) as opened_image:
                original_image = opened_image.copy()
            img_list.append(original_image)

            # assumes the image decodes to an (H, W, C) array — TODO confirm
            # (a single-channel PNG would give a 2D array and fail here)
            image = np.asarray(original_image).transpose(2, 0, 1)
            for op_code in OPS:
                modified_image = perform_data_aug_op(image, op_code)
                img_list.append(Image.fromarray(modified_image.transpose(1, 2, 0)))

        print_grid(img_list)

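Augment the images *in-place*: the augmented copies are written next to the originals, in the same folders (the originals are kept).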

In [None]:
# Write, next to every selected image, one augmented copy per op code in OPS,
# named "<original stem>_<op_code>.png".
# NOTE: cannot be re-run as is. Once augmented files exist, the number of PNG
# files no longer matches split_sizes and the check below fails.
# The context manager closes the progress bar even if a step fails midway.
with tqdm(total=common_nb_samples * len(dataset.classes) * len(OPS)) as pbar:
    for class_name in dataset.classes:
        print(f"\n===========================> {class_name}")
        for split in ["train", "test"]:
            # Folder holding the images of this class & split
            # (not named `dir`, which would shadow the builtin)
            images_dir = Path(aug_data_root_path, split, class_name)
            print(f"Directory: {images_dir}")

            # Get a list of all PNG files in the directory
            # (materialized now, so the files written below are not picked up)
            png_file_paths = list(images_dir.glob('**/*.png'))
            assert len(png_file_paths) == split_sizes[split]
            print(f"Found {len(png_file_paths)} PNG files")

            # Perform data aug
            for file in png_file_paths:
                # assumes the image decodes to an (H, W, C) array — TODO confirm
                # (a single-channel PNG would give a 2D array and fail here)
                with Image.open(file) as opened_image:
                    array = np.asarray(opened_image).transpose(2, 0, 1)

                for op_code in OPS:
                    modified_array = perform_data_aug_op(array, op_code)
                    modif_filename = file.stem + "_" + op_code + ".png"
                    modif_filepath = Path(file.parent, modif_filename)
                    # never overwrite an existing file
                    assert not modif_filepath.exists(), f"{modif_filepath} already exists"
                    img = Image.fromarray(modified_array.transpose(1, 2, 0))
                    img.save(modif_filepath)
                    pbar.update()

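Quick dirty check (adapt to correct folder names!). Note that `ls -l` prints a leading `total` line, so each count below is the number of files plus one.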

In [12]:
!cd $aug_data_root_path && ls -l train/0 | wc -l && ls -l train/1 | wc -l && ls -l test/0 | wc -l && ls -l test/1 | wc -l

42788
42788
42788
42788
