In [None]:
import os, sys
import shutil
from google.colab import drive
# Mount Google Drive (forcing a remount if it is already mounted) and make the
# project folder importable from later cells.
drive.mount('/content/gdrive', force_remount=True)
project_path = '/content/gdrive/MyDrive/master/true_2.5D_sketches_generator/2.5D_sketches_256'
# Guard so re-running this cell does not keep appending duplicate entries.
if project_path not in sys.path:
  sys.path.append(project_path)

Mounted at /content/gdrive


In [None]:
import tensorflow as tf
from tensorflow.keras import layers

import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
# reference: https://keras.io/examples/vision/depth_estimation/
class DataLoader(tf.keras.utils.Sequence):
  """Keras ``Sequence`` yielding (image, 2.5D sketch) batches from a DataFrame.

  ``data`` must provide an ``"image"`` and a ``"normal"`` column holding file
  paths; rows are addressed by their index labels, so any index type works.
  Each batch is a pair of float64 arrays of shape
  ``(batch, *dim, n_channels)`` holding the raw pixel values returned by
  ``cv2.imread`` (no resizing, colour conversion or normalisation is done).
  """

  def __init__(self, data, batch_size=6, dim=(256, 256), n_channels=3, shuffle=True):
    """
    Initialization.

    Args:
      data: DataFrame with "image" and "normal" file-path columns.
      batch_size: maximum number of samples per batch; the final batch of an
        epoch may be smaller.
      dim: (height, width) every image on disk is expected to have.
      n_channels: number of channels every image on disk is expected to have.
      shuffle: if True, the sample order is shuffled in ``on_epoch_end``
        (called here and, by Keras, after each epoch).
    """
    self.data = data
    # Row labels of ``data``; batches are looked up by label.
    self.indices = self.data.index.tolist()
    self.dim = dim
    self.n_channels = n_channels
    self.batch_size = batch_size
    self.shuffle = shuffle
    # Kept for compatibility; not read by the methods defined in this class.
    self.min_depth = 0.1
    self.on_epoch_end()

  def __len__(self):
    """Return the number of batches per epoch (the last one may be partial)."""
    return int(np.ceil(len(self.data) / self.batch_size))

  def __getitem__(self, index):
    """Return batch number ``index`` as an ``(x, y)`` pair.

    Raises:
      IndexError: if ``index`` is outside ``[0, len(self))``.
    """
    n_batches = len(self)
    if not 0 <= index < n_batches:
      raise IndexError(f"batch index {index} out of range for {n_batches} batches")
    # Clamp ``end`` for the short final batch instead of overwriting
    # self.batch_size, which would shift the start offset of that batch and
    # corrupt __len__ for every later epoch.
    start = index * self.batch_size
    end = min(start + self.batch_size, len(self.indices))
    # self.index holds (possibly shuffled) positions into self.indices;
    # map them back to DataFrame row labels.
    batch = [self.indices[pos] for pos in self.index[start:end]]
    return self.load_batch(batch)

  def on_epoch_end(self):
    """
    Rebuild the sample order, shuffling it when ``self.shuffle`` is set.
    """
    # Positions 0..n-1 into self.indices; __getitem__ slices batches from here.
    self.index = np.arange(len(self.indices))
    if self.shuffle:
      np.random.shuffle(self.index)

  def load(self, image_path, sketch_path):
    """
    Load image and 2.5D sketch pair.

    Both files are read with ``cv2.imread`` and returned unchanged.

    Raises:
      FileNotFoundError: if either file cannot be read (``cv2.imread``
        returns None in that case instead of raising).
    """
    img = cv2.imread(image_path)
    if img is None:
      raise FileNotFoundError(f"could not read image: {image_path}")
    sketch = cv2.imread(sketch_path)
    if sketch is None:
      raise FileNotFoundError(f"could not read sketch: {sketch_path}")
    return img, sketch

  def load_batch(self, batch):
    """
    Load one batch of data.

    Args:
      batch: list of row labels of ``self.data``.

    Returns:
      ``(x, y)``: float64 arrays of shape ``(len(batch), *dim, n_channels)``
      with the images ("image" column) and their 2.5D sketches ("normal"
      column).
    """
    # Size by len(batch) rather than self.batch_size so a short final batch
    # contains no uninitialised np.empty rows.
    shape = (len(batch), *self.dim, self.n_channels)
    x = np.empty(shape)
    y = np.empty(shape)
    for i, row in enumerate(batch):
      x[i], y[i] = self.load(
        self.data["image"][row],
        self.data["normal"][row],
      )
    return x, y