Creating your own pix2pix dataset
=================================



## Installation requirements

To run this notebook, as well as the [07_visualizing_pix2pix_results.ipynb](07_visualizing_pix2pix_results.ipynb) notebook, you may need to install some additional Python packages. To do so, open a terminal, make sure your environment is active, and install them with conda:

```bash
conda activate dmlap
conda install -c conda-forge pycairo opencv scikit-image
```
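
If you are unsure whether the installation worked, a quick check like the one below reports which required modules cannot be found. This is a minimal sketch using only the standard library. The module list is an assumption based on the imports used in this notebook, and the import names differ from the conda package names: `opencv` is imported as `cv2`, `scikit-image` as `skimage`, and `pycairo` as `cairo`.

```python
import importlib.util

# Import names (not conda package names) of the modules this notebook relies on
REQUIRED_MODULES = ["cv2", "skimage", "cairo", "tensorflow", "matplotlib", "tqdm"]

def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are available.")
```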

```python
import os
import glob
import random
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import cv2
from skimage import io
from skimage import filters
from skimage import feature
from tensorflow.keras.preprocessing import image

from IPython.display import clear_output
```

## Setting up
Set your directories and the dataset options:

-   `TARGET_DIR` defines where your **target** images are located.
-   `SOURCE_DIR` defines where your **source** images are located, if you already have them. Otherwise, set this to `None` (an empty string `''` works too).
-   `DATASET_DIR` defines where your pix2pix dataset will be saved.
-   `IS_INPUT_PIX_TO_PIX`: set this to `True` if each input image already consists of a source and a target placed side by side. This is the case if you want to modify an existing pix2pix dataset, and then we only need to extract the target.
-   `INPUT_OUTPUT_TARGET_INDEX`: if we are manipulating a dataset that is already a pix2pix dataset, this defines whether the target image is on the left (`0`) or on the right (`1`).
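
To make the index convention concrete: a pix2pix pair is a single image, 256 pixels high and 512 wide, holding two 256x256 halves. The sketch below mirrors what `load_image` does when `IS_INPUT_PIX_TO_PIX` is `True`. It uses NumPy only, and `split_pair` is a hypothetical helper that is not part of the notebook.

```python
import numpy as np

def split_pair(pair, target_index):
    """Split an (H, 2*W, C) pix2pix image into (source, target).

    target_index follows INPUT_OUTPUT_TARGET_INDEX:
    0 -> [target, source], 1 -> [source, target].
    """
    half = pair.shape[1] // 2
    left, right = pair[:, :half], pair[:, half:]
    return (right, left) if target_index == 0 else (left, right)

# Dummy 256x512 RGB pair: left half filled with 0, right half with 255
pair = np.zeros((256, 512, 3), dtype=np.uint8)
pair[:, 256:] = 255
source, target = split_pair(pair, target_index=1)
print(source.shape, target.max())  # (256, 256, 3) 255
```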

Note that you have to provide the exact path to your image directories: this code does not search subdirectories for images. Also note that the most common use case for this system is providing a dataset of targets (desired outputs) that you process to create the corresponding inputs (e.g. with edge detection or by finding face landmarks). In that case you do not need to worry about the `SOURCE_DIR` directory below.

By default we load the "Face 2 comics" dataset. Download it from [Kaggle](https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic), unzip it, and place the `face2comics_v1.0.0_by_Sxela` directory in your `datasets` directory. This is already a "pix2pix-friendly" dataset, but the two halves of each pair are stored as separate images: the faces in `face/` and their comic versions in `comics/`. Here we only use the comics as targets and derive the sources from them. With the default settings below, each comic is converted to black and white, producing a "black and white to comics" dataset in `datasets/bw2comics`. Switching the transformation to edge detection gives an "edges to comics" dataset instead.
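
If you instead point `SOURCE_DIR` at the `face/` directory, the notebook pairs source and target images by their position in the file list, so it helps to check first that both directories contain the same file names. A minimal sketch, assuming matching pairs share identical file names; `paired_filenames` is a hypothetical helper:

```python
import os

def paired_filenames(source_dir, target_dir):
    """Return (common, only_source, only_target) as sorted lists of file names."""
    source = set(os.listdir(source_dir))
    target = set(os.listdir(target_dir))
    return (sorted(source & target),
            sorted(source - target),
            sorted(target - source))

SRC = 'datasets/face2comics_v1.0.0_by_Sxela/face/'
TGT = 'datasets/face2comics_v1.0.0_by_Sxela/comics/'
if os.path.isdir(SRC) and os.path.isdir(TGT):
    common, only_source, only_target = paired_filenames(SRC, TGT)
    print(f"{len(common)} pairs, {len(only_source)} unmatched sources, "
          f"{len(only_target)} unmatched targets")
```

If either "unmatched" count is non-zero, remove or rename those files before building the dataset.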

```python
TARGET_DIR = 'datasets/face2comics_v1.0.0_by_Sxela/comics/'
SOURCE_DIR = None  # Only used if we already have source image examples, e.g. 'datasets/face2comics_v1.0.0_by_Sxela/face/'

DATASET_DIR = 'datasets/bw2comics'
os.makedirs(DATASET_DIR, exist_ok=True)

IS_INPUT_PIX_TO_PIX = False
INPUT_OUTPUT_TARGET_INDEX = 1 # 0: [target, source], 1: [source, target]
```

## Load the images to process

Now let's load our target images, and optionally our source images if we have set the `SOURCE_DIR` directory.

```python
def load_image(path):
    size = (256, 256)  # (height, width)
    if IS_INPUT_PIX_TO_PIX:  # In case we are already loading a pix2pix image (two halves side by side)
        size = (256, 512)
    img = image.load_img(path, target_size=size)
    img = image.img_to_array(img)
    # If we are loading a pix2pix dataset just extract the target half
    if IS_INPUT_PIX_TO_PIX:
        if INPUT_OUTPUT_TARGET_INDEX == 0:
            img = img[:, :size[0], :]  # target on the left
        else:
            img = img[:, size[0]:, :]  # target on the right
    return img.astype(np.uint8)

def load_images_in_path(path, shuffle=False, limit=0):
    # Sort so that source and target directories are traversed in the same order
    fnames = sorted(glob.glob(os.path.join(path, "*")))
    print(f"Found {len(fnames)} files in '{path}'")
    if shuffle:
        random.shuffle(fnames)
    if limit > 0:
        fnames = fnames[:limit]
        print(f"Limiting number of files to {limit}")
    for f in fnames:
        # A generator: each image is loaded only when it is requested
        # See this: https://realpython.com/introduction-to-python-generators/
        yield load_image(f)
```
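
Because `load_images_in_path` is a generator, nothing is read from disk until you ask for the next item, and each `next()` call loads exactly one image. The toy sketch below replaces `load_image` with a hypothetical stand-in, `fake_load`, that only records which files were "loaded", so you can see this behaviour without any images on disk.

```python
from itertools import islice

loaded = []  # records which "files" have actually been loaded

def fake_load(fname):
    """Stand-in for load_image: records the call and returns a dummy value."""
    loaded.append(fname)
    return fname.upper()

def lazy_loader(fnames, limit=0):
    if limit > 0:
        fnames = fnames[:limit]
    for f in fnames:
        yield fake_load(f)

gen = lazy_loader(["a.png", "b.png", "c.png", "d.png"], limit=3)
print(loaded)                # [] (creating the generator loads nothing)
first = next(gen)
print(first, loaded)         # A.PNG ['a.png']
rest = list(islice(gen, 5))  # consumes what is left (limit=3, so two more)
print(rest, loaded)          # ['B.PNG', 'C.PNG'] ['a.png', 'b.png', 'c.png']
```

This is why a directory with thousands of images can be processed without holding all of them in memory at once.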

```python
if SOURCE_DIR:
    source_loader = iter(load_images_in_path(SOURCE_DIR))
    plt.imshow(next(source_loader))
    plt.show()

target_loader = iter(load_images_in_path(TARGET_DIR)) # create an iterator
plt.imshow(next(target_loader))
plt.show()
```

## Define our transformation



The code below has a number of transformations already set up for you. These are:

-   `apply_bw_cv2` turns the picture black and white. Note that the greyscale result is copied into three channels, because the architecture expects source and target images with the same number of channels.
-   `apply_canny_cv2` applies Canny edge detection using OpenCV. Two hysteresis thresholds, `thresh1` and `thresh2`, determine the result: gradients above the higher threshold start an edge, and weaker gradients are kept only if they are above the lower threshold and connected to such an edge. Experiment with these values to adjust the results to your liking. Additional details can be found [here](https://docs.opencv.org/4.x/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de). Set `invert=True` to get black edges on a white background.
-   `apply_canny_skimage` applies Canny edge detection using [scikit-image](https://scikit-image.org). Its `sigma` parameter is the standard deviation of the Gaussian blur applied before detecting edges: in general, a higher value produces fewer edges. See [this](https://scikit-image.org/docs/stable/auto_examples/edges/plot_canny.html) for additional details. It also accepts `invert`.
-   `apply_face_landmarks` used to detect face landmarks with [mlxtend](http://rasbt.github.io/mlxtend/) and draw them as polylines with the Canvas API (it failed whenever the face detector could not find a face in the image). **Removed**, as recent mlxtend updates stopped supporting facial landmarks; the function now raises `NotImplementedError`. Reimplementing it, possibly with ChatGPT's help, would be a nice task.
-   `load_source` simply returns the next source image from the directory specified in `SOURCE_DIR`.
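
The three-channel requirement for `apply_bw_cv2` just means repeating the single grey channel three times. Below is a NumPy-only sketch of the same idea. It assumes the luma weights 0.299, 0.587 and 0.114, which is the formula OpenCV documents for `COLOR_RGB2GRAY`; OpenCV's own integer rounding may differ from this by one grey level. `grey_three_channels` is a hypothetical helper.

```python
import numpy as np

def grey_three_channels(img):
    """Convert an (H, W, 3) uint8 RGB image to grey, repeated over 3 channels."""
    weights = np.array([0.299, 0.587, 0.114])
    grey = np.round(img.astype(np.float64) @ weights)  # (H, W) luma values
    grey = np.clip(grey, 0, 255).astype(np.uint8)
    # Same shape as cv2.merge([grey, grey, grey]): three identical channels
    return np.stack([grey, grey, grey], axis=-1)

img = np.zeros((256, 256, 3), dtype=np.uint8)
img[..., 0] = 255  # pure red image
out = grey_three_channels(img)
print(out.shape, out[0, 0])  # (256, 256, 3) [76 76 76]
```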

Set the `image_transformation` in the code below to the function that describes the transformation you want to apply.
If you feel confident, you can extend this to other image transformations by duplicating one of the functions and adapting it to your needs.
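
As an example of such an extension, here is a sketch of the inpainting idea mentioned in the code comments below: a transformation that blanks out a random square so that the network learns to fill in holes. The name `apply_random_hole` and its parameters are hypothetical, and it uses NumPy only. Like the other functions, it takes a target image and returns a source image of the same shape.

```python
import numpy as np

def apply_random_hole(img, hole_size=64, fill=0, rng=None):
    """Return a copy of img with a random hole_size x hole_size square set to fill."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    # integers() excludes the upper bound, so the square always fits inside the image
    top = rng.integers(0, h - hole_size + 1)
    left = rng.integers(0, w - hole_size + 1)
    out = img.copy()  # never modify the target image in place
    out[top:top + hole_size, left:left + hole_size] = fill
    return out

img = np.full((256, 256, 3), 200, dtype=np.uint8)
holed = apply_random_hole(img, hole_size=64, rng=np.random.default_rng(0))
print((holed == 0).all(axis=-1).sum())  # 4096 pixels (64 * 64) blanked out
```

To use it, you would set `image_transformation = apply_random_hole`.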

In the example below we use the `apply_bw_cv2` function: each comic (target) image is converted to greyscale, and that black-and-white version becomes the corresponding source in our dataset. To build an edges-to-comics dataset instead, switch to `apply_canny_cv2` or `apply_canny_skimage`.

```python
def apply_bw_cv2(img):
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Force three channels so the shape matches the RGB target, thanks ChatGPT!
    return cv2.merge([grey_img, grey_img, grey_img])

def apply_canny_cv2(img, thresh1=160, thresh2=250, invert=False):
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(grey_img, thresh1, thresh2)  # single channel: 255 on edges, 0 elsewhere
    if invert:
        edges = cv2.bitwise_not(edges)  # black edges on white
    return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

def apply_canny_skimage(img, sigma=1.5, invert=False):
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # feature.canny returns a boolean mask; scale it to 0/255 uint8
    edges = (feature.canny(grey_img, sigma=sigma)*255).astype(np.uint8)
    if invert:
        edges = cv2.bitwise_not(edges)
    return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

# IDEA: rewrite this function using the Mediapipe API instead of mlxtend
#       (ChatGPT I'm sure will be delighted to help you...)
def apply_face_landmarks(img, stroke_weight=2):
    raise NotImplementedError(
        """
        This function used mlxtend.image.extract_face_landmarks, which has been removed from the library.
        It would be worthwhile, if you are interested, to look at how to use the equivalent functionality
        in MediaPipe to get the logic below to work again. See the following notebook:
        https://github.com/googlesamples/mediapipe/blob/main/examples/face_landmarker/python/%5BMediaPipe_Python_Tasks%5D_Face_Landmarker.ipynb
        """
    )

    # # from py5canvas import canvas # if you installed canvas
    # import canvas # if you have the canvas.py in the current directory
    # from mlxtend.image import extract_face_landmarks

    # def landmark_polylines(landmarks):
    #     # https://pyimagesearch.com/2017/04/03/facial-landmarks-dlib-opencv-python/
    #     landmarks = np.array(landmarks).astype(np.float32)
    #     indices = [
    #         list(range(0, 17)),         # Jawline
    #         list(range(17, 22)),        # Left eyebrow
    #         list(range(22, 27)),        # Right eyebrow
    #         list(range(27, 31)),        # Nose bridge
    #         list(range(31, 36)),        # Lower nose
    #         list(range(36, 42)) + [36], # Left eye
    #         list(range(42, 48)) + [42], # Right eye
    #         list(range(48, 60)) + [48], # Outer lip
    #         list(range(60, 68)) + [60]  # Inner lip
    #     ]
    #     return [landmarks[I] for I in indices]

    # c = canvas.Canvas(256, 256)
    # c.background(0)
    # landmarks = extract_face_landmarks(img)
    # if landmarks is None:
    #     return None
    # c.stroke_weight(stroke_weight)
    # c.no_fill()
    # c.stroke(255)
    # paths = landmark_polylines(landmarks)
    # for path in paths:
    #     c.polyline(path)
    # return c.get_image()

# IDEA: It might be possible to use other Mediapipe functionalities, like:
#       - segmentation: https://developers.google.com/mediapipe/solutions/vision/image_segmenter
#       - pose landmarks: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker
#       to write other transformation functions... (For both of those, you then need to find datasets!)

# IDEA: Use Canvas (or openCV) to remove parts of the image (draw a rectangle/circle somewhere)
#       so that the net learns to complete an image with a hole in it (inpainting)

# Ignores `img` and returns the next image from the source iterator. Use this when
# SOURCE_DIR already contains your source images (see the partial example below).
def load_source(img, img_source_iterator):
    return next(img_source_iterator)

# Set this to the transformation you want to apply. If you are only working with a single
# folder of images that you want to process, set image_transformation to one of the
# filtering operations above, e.g.:
image_transformation = apply_bw_cv2
# image_transformation = apply_canny_cv2
# image_transformation = apply_canny_skimage

# # If you are working with existing sources, you can use
# # a Python partial to assign a fixed argument to load_source
# # and use it exactly like the other image transformations
# # (See: https://docs.python.org/3/library/functools.html#functools.partial)
# from functools import partial
# image_transformation = partial(load_source, img_source_iterator=iter(load_images_in_path(SOURCE_DIR)))

# Preview: first target image (left) next to its transformed version (right)
img = next(load_images_in_path(TARGET_DIR))
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.subplot(1, 2, 2)
plt.imshow(image_transformation(img))
plt.show()
```
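
The commented-out `partial` option above works because `functools.partial` fixes the `img_source_iterator` argument once, so every call to the resulting function ignores the target it receives and advances the same iterator. The standalone illustration below uses strings in place of images and names the function `demo_transformation` (a hypothetical name) so that running it does not overwrite the notebook's `image_transformation`.

```python
from functools import partial

def load_source(img, img_source_iterator):
    # Ignores img and returns the next pre-existing source image
    return next(img_source_iterator)

sources = iter(["source_1", "source_2", "source_3"])
demo_transformation = partial(load_source, img_source_iterator=sources)

targets = ["target_1", "target_2", "target_3"]
pairs = [(demo_transformation(t), t) for t in targets]
print(pairs)
# [('source_1', 'target_1'), ('source_2', 'target_2'), ('source_3', 'target_3')]
```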

## Create the dataset!

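Here we loop through all the target images, generate the corresponding source image, and stitch the two side by side into a single image. If `SOURCE_DIR` is set, the existing source images are paired with the targets instead. The input image directory might contain more images than you want to use; to process only the first few, set `NUM_IMAGES` to a non-zero value. `OUTPUT_TARGET_INDEX` decides whether the target is saved on the left (`0`) or on the right (`1`).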

```python
OUTPUT_TARGET_INDEX = 1  # 0: [target, source], 1: [source, target]

NUM_IMAGES = 0  # if you have a lot of images and want to process only a fraction, set this > 0

img_loader = iter(load_images_in_path(TARGET_DIR, limit=NUM_IMAGES))

def combine_images(source, target):
    # Stitch two 256x256 images side by side into one 256x512 image
    if OUTPUT_TARGET_INDEX == 1:
        combined = np.hstack([source, target])
    else:
        combined = np.hstack([target, source])
    return combined

if SOURCE_DIR:
    # CASE 1: we have source images, we want to combine them
    # create a source image iterator (pairs are matched by position in the file list)
    source_loader = iter(load_images_in_path(SOURCE_DIR, limit=NUM_IMAGES))
    # loop over both source and target and save the combined images
    for i, (source, target) in enumerate(zip(source_loader, img_loader)):
        clear_output(wait=True)
        print(f"Processing image {i+1}")
        # IDEA: you could also apply additional processing to either your
        # source or target here
        combined = combine_images(source, target)
        # check_contrast=False silences warnings for mostly black images (e.g. edge maps)
        io.imsave(os.path.join(DATASET_DIR, f"{i+1}.png"), combined, check_contrast=False)
else:
    # CASE 2: we only have targets, we create the sources and combine them
    for i, target in enumerate(img_loader):
        clear_output(wait=True)
        print(f"Processing image {i+1}")

        source = image_transformation(target)
        if source is None:
            print(f"Failed to transform image {i+1}")
            continue

        combined = combine_images(source, target)
        io.imsave(os.path.join(DATASET_DIR, f"{i+1}.png"), combined, check_contrast=False)
```
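
The pairs are saved as `1.png`, `2.png` and so on. Plain string sorting places `10.png` before `2.png`, so if you later need the files in the order they were created, sort on the number in the file name instead. A sketch, assuming the `{i+1}.png` naming used above; `numeric_sorted` is a hypothetical helper.

```python
import os

def numeric_sorted(fnames):
    """Sort paths like 'datasets/bw2comics/10.png' by their integer stem."""
    return sorted(fnames, key=lambda f: int(os.path.splitext(os.path.basename(f))[0]))

names = ["10.png", "2.png", "1.png"]
print(sorted(names))          # ['1.png', '10.png', '2.png']
print(numeric_sorted(names))  # ['1.png', '2.png', '10.png']
```

In the notebook you would call it as `numeric_sorted(glob.glob(os.path.join(DATASET_DIR, '*.png')))`.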