# import packages

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torchvision

from PIL import Image, ImageFilter

# Set random seed for reproducibility

In [2]:
# One seed drives both NumPy random sources used by this notebook.
SEED = 100

# Dedicated Generator instance, seeded explicitly.
rng = np.random.default_rng(SEED)
# Also seed NumPy's legacy global state (np.random.* functions).
np.random.seed(SEED)

# Set variables and preprocess pipeline

In [3]:
# Directory layout: raw images live in ../data/CCSN_v2, processed copies are
# written under ../data/processed/train and ../data/processed/test.
datadir = Path('../data')
raw_datadir = datadir / 'CCSN_v2'
train_datadir = datadir / 'processed' / 'train'
test_datadir = datadir / 'processed' / 'test'
train_ratio = 0.9  # share of processed images used for training; the rest is test

# Preprocessing steps, listed in the order they are applied.
# Callable entries transform an image; the dict entry lists the PIL filters
# that produce the extra variants.
pipeline = [
    lambda img: img.resize((256, 256)),  # -> 256 x 256
    lambda img: img.crop((14, 14, 241, 241)),  # -> 227 x 227
    {'img_filter': [ImageFilter.EDGE_ENHANCE, ImageFilter.UnsharpMask]},
]
# File-name tag of each saved variant, in variant order.
suffix = ['resize', 'edge', 'unsharp']

# Get raw data paths and split into train/test data

In [4]:
# Index the raw images under raw_datadir; the dataset exposes the class names
# (`classes`, `class_to_idx`) and the (path, class index) pairs (`samples`).
dataset = torchvision.datasets.ImageFolder(raw_datadir)

# Create one output directory per class in both the train and the test tree.
# exist_ok=True makes re-running the cell safe without a separate is_dir()
# check (and without the check-then-create race).
for img_cls in dataset.classes:
    for split_dir in (train_datadir, test_datadir):
        split_dir.joinpath(img_cls).mkdir(parents=True, exist_ok=True)

# One row per raw image: its path, its class index and its class name.
idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
df = pd.DataFrame(dataset.samples, columns=['img_path', 'img_idx'])
df['img_cls'] = df['img_idx'].map(idx_to_class)

# Split train/test.
# Every raw image yields len(suffix) processed images, so the split is sized
# on the processed-image count rather than on the raw-image count.
num_img_inclass = len(suffix)
num_processed_img = df.shape[0] * num_img_inclass
num_train_img = int(num_processed_img * train_ratio)
num_test_img = num_processed_img - num_train_img
# is_train[i, j] is 1.0 when variant j of raw image i belongs to the train
# set and 0.0 when it belongs to the test set.
# NOTE(review): the flags are shuffled per processed image, so the variants of
# one raw image can end up on both sides of the split - confirm this is
# intended, since it can leak near-duplicates into the test set.
is_train = rng.permutation(
    np.r_[np.ones(num_train_img), np.zeros(num_test_img)]
).reshape(-1, num_img_inclass)
print(f'Number of images for training: {num_train_img}')
print(f'Number of images for testing: {num_test_img}')

Number of images for training: 6866
Number of images for testing: 763


# Preprocess and save

In [5]:
# Build the processed variants of every raw image and save each variant into
# the train or the test tree according to `is_train`.
for row, (raw_img_path, img_cls) in enumerate(
    df[['img_path', 'img_cls']].values
):
    raw_img_path = Path(raw_img_path)
    # Output files are named '<stem>-<suffix><original extension>'.
    stem = raw_img_path.resolve().stem
    ext = raw_img_path.suffix
    # The context manager closes the source file once all variants are
    # written, instead of leaving the handle to the garbage collector.
    with Image.open(raw_img_path) as raw_img:
        # Start with one reference to the raw image per variant.
        variants = [raw_img] * num_img_inclass
        for proc in pipeline:
            if callable(proc):
                # Callable steps are applied to every variant.
                variants = [proc(variant) for variant in variants]
            elif isinstance(proc, dict) and 'img_filter' in proc:
                # Variant 0 stays unfiltered; the k-th filter (1-based) is
                # applied to variant k, so the filter list must hold at most
                # num_img_inclass - 1 entries.
                for k, img_filter in enumerate(proc['img_filter'], start=1):
                    variants[k] = variants[k].filter(img_filter)
            # Any other pipeline entry is ignored.
        for col, variant in enumerate(variants):
            out_root = train_datadir if is_train[row, col] else test_datadir
            processed_img_path = out_root.joinpath(
                img_cls,
                f'{stem}-{suffix[col]}{ext}'
            )
            variant.save(processed_img_path)