In [18]:
import logging
import os
import shutil
from glob import glob
from typing import Any, List, Tuple

import cv2
import dask.bag
from distributed import Client
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)


def make_patches(
    img_path_idx_tuple: Tuple[str, int],
    target_path: str,
    stage: str,
    hr_image_size: int = 128,
    lr_image_size: int = 32,
    stride_hr: int = 4,
    stride_lr: int = 1,
    scaling_factor: int = 4,
) -> None:
    """Cut one image into square HR patches and matching down-scaled LR patches.

    The image is read with ``cv2.imread``; HR patches are cut from it directly,
    then the image is resized by ``1 / scaling_factor`` (bicubic) and LR patches
    are cut from the resized copy. Patches are written as
    ``<target_path>/<stage>/<hr|lr>/<zero-padded idx>_<patch number>.png``.

    Patches are taken with a sliding window: the window's top-left corner moves
    by ``stride`` pixels per step, and only windows that fit entirely inside
    the image are written. HR and LR patch numbers only refer to the same image
    region when the LR size and stride are the HR ones divided by
    ``scaling_factor`` (e.g. 128/100 and 32/25 for a factor of 4).

    Args:
        img_path_idx_tuple: ``(image path, numeric index)``; the index is used
            as the file-name prefix of every patch.
        target_path: Root output directory.
        stage: Sub-directory of ``target_path`` (e.g. ``"train"``).
        hr_image_size: Side length of HR patches, in pixels.
        lr_image_size: Side length of LR patches, in pixels.
        stride_hr: Sliding-window step for HR patches, in pixels.
        stride_lr: Sliding-window step for LR patches, in pixels.
        scaling_factor: Down-scaling factor between the HR and LR image.

    Raises:
        ValueError: If the image cannot be read or a stride is not positive.
    """
    img_path, idx = img_path_idx_tuple
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {img_path}")

    def generate(image: Any, image_size: int, stride: int, is_lr: bool) -> None:
        """Write every full ``image_size`` window of ``image``, stepping by ``stride``."""
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        h, w = image.shape[:2]
        out_dir = os.path.join(target_path, stage, "lr" if is_lr else "hr")
        image_index = 0

        # The ranges are empty when the image is smaller than one patch,
        # so undersized images simply produce no output.
        for top in range(0, h - image_size + 1, stride):
            for left in range(0, w - image_size + 1, stride):
                path_to_img = os.path.join(
                    out_dir,
                    f'{str(idx).rjust(10, "0")}_{image_index}.png',
                )
                patch = image[top : top + image_size, left : left + image_size]

                # cv2.imwrite reports failure through its return value.
                if not cv2.imwrite(path_to_img, patch):
                    logging.warning(f"Failed to write patch {path_to_img}")
                image_index += 1

    # HR
    generate(img, hr_image_size, stride_hr, False)

    # LR
    width = int(img.shape[1] / scaling_factor)
    height = int(img.shape[0] / scaling_factor)
    img = cv2.resize(img, (width, height), interpolation=cv2.INTER_CUBIC)
    generate(img, lr_image_size, stride_lr, True)


def make_center_crop(
    img_path_idx_tuple: Tuple[str, int],
    target_path: str,
    stage: str,
    hr_image_size: int = 128,
    scaling_factor: int = 4,
) -> None:
    """Write a centred HR crop of one image and its down-scaled LR counterpart.

    The crop is ``hr_image_size`` pixels per side, clamped to the image's own
    width/height when the image is smaller. The LR image is the crop resized
    by ``1 / scaling_factor`` (bicubic), so the HR/LR size ratio holds even for
    clamped crops. Both files are written as ``<zero-padded idx>.png`` under
    ``<target_path>/<stage>/hr`` and ``<target_path>/<stage>/lr``.

    Args:
        img_path_idx_tuple: ``(image path, numeric index)``; the index becomes
            the output file name.
        target_path: Root output directory.
        stage: Sub-directory of ``target_path`` (e.g. ``"val"``).
        hr_image_size: Requested side length of the HR crop, in pixels.
        scaling_factor: Down-scaling factor between the HR and LR image.

    Raises:
        ValueError: If the image cannot be read.
    """
    img_path, idx = img_path_idx_tuple
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {img_path}")

    height, width = img.shape[:2]

    # Clamp the crop to the available image area.
    crop_width = min(hr_image_size, width)
    crop_height = min(hr_image_size, height)

    # Offsets are derived from the leftover margin so the crop is exactly
    # crop_width x crop_height even for odd sizes (mid +/- size // 2 would
    # lose one pixel).
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    img_hr = img[top : top + crop_height, left : left + crop_width]

    # Size the LR image from the actual crop, not the requested HR size.
    img_lr = cv2.resize(
        img_hr,
        (int(crop_width / scaling_factor), int(crop_height / scaling_factor)),
        interpolation=cv2.INTER_CUBIC,
    )

    file_name = f'{str(idx).rjust(6, "0")}.png'
    path_to_img_hr = os.path.join(target_path, stage, "hr", file_name)
    path_to_img_lr = os.path.join(target_path, stage, "lr", file_name)

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(path_to_img_hr, img_hr):
        logging.warning(f"Failed to write image {path_to_img_hr}")
    if not cv2.imwrite(path_to_img_lr, img_lr):
        logging.warning(f"Failed to write image {path_to_img_lr}")


def get_images(data_path: str) -> Tuple[List[str], List[str], List[str]]:
    """Recursively collect image files under ``data_path`` and split them by folder.

    Files ending in jpeg/jpg/png/bmp (all-lowercase or all-uppercase extension)
    are collected. A path containing a ``val`` directory component goes to the
    validation split, one containing a ``test`` component goes to the test
    split, and everything else is treated as training data.

    Args:
        data_path: Root directory to search.

    Returns:
        ``(train, val, test)`` lists of image paths, each sorted.
    """
    extensions = ("jpeg", "jpg", "png", "bmp")
    # Every pattern needs the "*" wildcard; a bare ".jpeg" would only match a
    # file literally named ".jpeg".
    patterns = [
        os.path.join(data_path, "**", f"*.{ext}")
        for ext in extensions + tuple(e.upper() for e in extensions)
    ]

    # A set guards against duplicates on case-insensitive file systems, where
    # "*.jpg" and "*.JPG" can match the same file.
    found = set()
    for pattern in patterns:
        found.update(glob(pattern, recursive=True))
    images = sorted(found)

    logging.info(f"Total of {len(images)} found under the {data_path}")

    train_hr_images = []
    val_hr_images = []
    test_hr_images = []

    for img_path in tqdm(images):
        # Normalise separators so the check also works where os.sep is "\\".
        normalized = img_path.replace(os.sep, "/")
        if "/val/" in normalized:
            val_hr_images.append(img_path)
        elif "/test/" in normalized:
            test_hr_images.append(img_path)
        else:
            train_hr_images.append(img_path)

    logging.info(
        f"Train/Validation/Test split sizes: {len(train_hr_images)}/{len(val_hr_images)}/{len(test_hr_images)}"
    )

    return sorted(train_hr_images), sorted(val_hr_images), sorted(test_hr_images)

In [21]:
# Dataset locations: raw images live in the "original" sub-folder of the
# directory that receives the generated train/val/test patches.
target_dir = "/media/xultaeculcis/2TB/datasets/sr/pre-training/classic"
data_dir = f"{target_dir}/original"

# Super-resolution scale between the LR and HR images.
scaling_factor = 4

# HR patch geometry (pixels).
hr_image_size = 128
stride_hr = 100

# LR geometry is derived from the HR values so the two can never drift apart.
lr_image_size = hr_image_size // scaling_factor
stride_lr = stride_hr // scaling_factor

In [32]:
# Create every <split>/<resolution> output folder up front (no-op if it already exists).
for split_name in ("train", "val", "test"):
    for resolution_name in ("lr", "hr"):
        os.makedirs(os.path.join(target_dir, split_name, resolution_name), exist_ok=True)

# Discover the source images and split them into train/val/test lists.
train, val, test = get_images(data_dir)

INFO:root:Total of 17600 found under the /media/xultaeculcis/2TB/datasets/sr/pre-training/classic/original
100%|██████████| 17600/17600 [00:00<00:00, 1829893.91it/s]
INFO:root:Train/Validation/Test split sizes: 15914/924/762


In [27]:
# Create the distributed client with 8 workers and one thread per worker.
# It is closed explicitly in a later cell via `c.close()`.
c = Client(n_workers=8, threads_per_worker=1)

In [29]:
# --- Validation split: one centre crop (HR + LR) per image ---
indices = list(range(len(val)))
items = zip(sorted(val), indices)
val_options = dict(
    target_path=target_dir,
    stage="val",
    hr_image_size=hr_image_size,
    scaling_factor=scaling_factor,
)
val_bag = dask.bag.from_sequence(items)
_ = val_bag.map(make_center_crop, **val_options).compute()
logging.info("Done with val dataset")

# --- Training split: HR/LR patches per image, spread over 1000 partitions ---
indices = list(range(len(train)))
items = zip(sorted(train), indices)
train_options = dict(
    target_path=target_dir,
    stage="train",
    hr_image_size=hr_image_size,
    lr_image_size=lr_image_size,
    stride_hr=stride_hr,
    stride_lr=stride_lr,
    scaling_factor=scaling_factor,
)
train_bag = dask.bag.from_sequence(items, npartitions=1000)
_ = train_bag.map(make_patches, **train_options).compute()
logging.info("Done with train dataset")

INFO:root:Done with val dataset
INFO:root:Done with train dataset


In [30]:
# Close the client `c` only after the val/train computations have finished.
c.close()

In [None]:
# In case something went wrong: wipe the generated train and val outputs
# so the processing cells can be re-run from scratch.
for failed_split in ("train", "val"):
    shutil.rmtree(os.path.join(target_dir, failed_split))

In [17]:
# For a really big dataset, use this instead of shutil.rmtree: delete the
# generated training PNGs in parallel.
# The context manager closes the client once the delete has finished, and
# also when the delete job raises, so no workers are left running.
with Client(n_workers=8, threads_per_worker=2) as c:
    images = glob(os.path.join(target_dir, "train/**/*.png"), recursive=True)

    # Nothing to schedule when the folder is already empty.
    if images:
        _ = dask.bag.from_sequence(images, npartitions=10000).map(os.remove).compute()

tornado.application - ERROR - Uncaught exception GET /status/ws (::1)
HTTPServerRequest(protocol='http', host='localhost:8787', method='GET', uri='/status/ws', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
  File "/home/xultaeculcis/anaconda3/envs/sr/lib/python3.8/site-packages/tornado/websocket.py", line 954, in _accept_connection
    open_result = handler.open(*handler.open_args, **handler.open_kwargs)
  File "/home/xultaeculcis/anaconda3/envs/sr/lib/python3.8/site-packages/tornado/web.py", line 3173, in wrapper
    return method(self, *args, **kwargs)
  File "/home/xultaeculcis/anaconda3/envs/sr/lib/python3.8/site-packages/bokeh/server/views/ws.py", line 137, in open
    raise ProtocolError("Token is expired.")
bokeh.protocol.exceptions.ProtocolError: Token is expired.


In [31]:
# List every PNG below <target_dir>/train, searching sub-folders recursively.
images = glob(os.path.join(target_dir, "train/**/*.png"), recursive=True)