The following cell sets up all the constants needed for the model to run. 

In [None]:
import os

import numpy as np

BW_THRESHOLD = 20 # threshold for detecting colour; used to trim black edges & reject images

IMG_ROWS = 256 # height of input image
IMG_COLS = 256 # width of input image
NUM_CLASSES = 512 # number of bins to divide up the RGB colour space

BATCH_SIZE = 24 # size of mini-batch for training
EPOCHS = 2 # number of epochs to train
SAVE_INTERVAL = 500 # number of mini-batches before checkpointing model weights

NUM_NEIGHB = 5 # number of nearest neighbours for soft-encoding the probability of a bin
SIGMA_NEIGHB = 5 # standard deviation for soft-encoding
TO_REBAL = True # Boolean representing whether to use class rebalancing

# factors for class rebalancing, indexed by colour bin. The file is loaded only
# if it exists, so the preprocessing cells can run before it has been produced;
# FACTORS is None otherwise and rebalancing cannot be applied.
FACTORS_PATH = 'factors.npy'
FACTORS = np.load(FACTORS_PATH) if os.path.exists(FACTORS_PATH) else None
if TO_REBAL and FACTORS is None:
    print('Warning: TO_REBAL is True but ' + FACTORS_PATH + ' was not found; '
          'class rebalancing is unavailable until it exists.')

EPSILON = 1e-4 # small value to prevent log(0)
TEMPERATURE = 0.8 # parameter for reconstructing colour from a probability distribution

# folders for images
LINE_PATH = 'lines' # greyscale outlines from processing, used to produce input for the network
COLOUR_PATH = 'colours' # colour image from processing, used to produce expected output for the network

# arrays for dividing up the RGB colour space into bins: each channel is split
# into 8 bins of width 32 with centres 16, 48, ..., 240. Bin k has
# B = centre[k % 8], G = centre[(k // 8) % 8], R = centre[k // 64].
CLASS_MAP_B = np.asarray(([32*i+16 for i in range(8)]*64))
CLASS_MAP_G = np.asarray(([32*int(i/8)+16 for i in range(64)]*8))
CLASS_MAP_R = np.asarray(([32*int(i/64)+16 for i in range(512)]))
CLASS_MAP = np.vstack((CLASS_MAP_B, CLASS_MAP_G, CLASS_MAP_R)).T # shape (512, 3), columns B, G, R


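The following cell processes images from the dataset, converting the original colour illustrations into 256x256 images suitable for the network. Black borders are trimmed, images smaller than 256 pixels in either dimension are skipped, and each remaining image is cut into square tiles along its longer side. Tiles whose colours are too dull (low mean saturation) are also skipped.
For every kept tile, a line drawing and the corresponding colour image are written to 'LINE_PATH' and 'COLOUR_PATH' respectively.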

In [None]:
import cv2
import os
import numpy as np

# make sure both output folders exist before any processed image is written
for out_dir in (LINE_PATH, COLOUR_PATH):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

def trim(img, threshold=None):
    '''Takes an image and trims away black borders.

    Repeatedly removes the top, bottom, left or right edge (checked in that
    order, one edge per step) while that edge's mean value is below
    ``threshold``.

    Args:
        img: image array indexed as img[row, col, ...], e.g. from cv2.imread.
        threshold: mean-value cut-off for a "black" edge; defaults to
            BW_THRESHOLD.

    Returns:
        A view of ``img`` with dark borders removed, or ``np.zeros((0, 0, 3))``
        if every row or every column was trimmed away.
    '''
    if threshold is None:
        threshold = BW_THRESHOLD
    # Iterative rather than recursive: the recursive version used one stack
    # frame per removed row/column and hit RecursionError on thick borders.
    while True:
        # nothing left (guard both axes: indexing img[:, 0] on a
        # zero-width image would raise IndexError)
        if img.shape[0] == 0 or img.shape[1] == 0:
            return np.zeros((0, 0, 3))
        if np.mean(img[0]) < threshold:          # trim above
            img = img[1:]
        elif np.mean(img[-1]) < threshold:       # trim below
            img = img[:-1]
        elif np.mean(img[:, 0]) < threshold:     # trim left
            img = img[:, 1:]
        elif np.mean(img[:, -1]) < threshold:    # trim right
            img = img[:, :-1]
        else:
            return img

def process(img, size=None):
    '''Takes an image and produces the correctly formatted line and colour versions.

    Args:
        img: BGR image (as read by cv2), expected to be a square tile.
        size: target length of the shorter side; defaults to
            min(IMG_ROWS, IMG_COLS). (Previously this was read implicitly
            from a module-level ``size`` variable defined further down.)

    Returns:
        (line, colour): ``colour`` is ``img`` scaled so that its shorter side
        is ``size``; ``line`` is a single-channel image of the same size with
        dark strokes on a light background.
    '''
    if size is None:
        size = min(IMG_ROWS, IMG_COLS)

    # shrink dimensions so the shorter side becomes `size`
    ratio = size / min(img.shape[0], img.shape[1])
    colour = cv2.resize(img, (0, 0), fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)

    # convert to greyscale
    grey = cv2.cvtColor(colour, cv2.COLOR_BGR2GRAY)

    # erode, dilate & subtract: the difference is large only where the
    # intensity changes, i.e. along edges
    kernel = np.ones((2, 2), np.uint8)
    erosion = cv2.erode(grey, kernel, iterations=1)
    dilation = cv2.dilate(grey, kernel, iterations=1)
    diff = cv2.absdiff(erosion, dilation)
    diff = cv2.multiply(diff, 2)   # strengthen the edges
    line = cv2.bitwise_not(diff)   # invert: dark lines on a white background

    return line, colour

dir_name = '/kaggle/input/tagged-anime-illustrations/danbooru-images/danbooru-images'
file_paths = []
for (dir_path, _, file_names) in os.walk(dir_name):
    file_paths += [os.path.join(dir_path, file) for file in file_names]

# process every readable image; `size` is the side length of the square outputs
size = min(IMG_ROWS, IMG_COLS)
for index, file_path in enumerate(file_paths):
    img = cv2.imread(file_path, cv2.IMREAD_COLOR)
    # cv2.imread returns None (it does not raise) for files it cannot decode,
    # e.g. non-image or corrupt files; skip them instead of crashing in trim()
    if img is None:
        continue

    # trim black borders
    img = trim(img)

    # check if image is too small
    h, w = float(img.shape[0]), float(img.shape[1])
    if h < size or w < size:
        # print(file_path, ': small')
        continue

    # create cropped square versions (tiles) of side min(h, w) along the
    # longer axis. The step is chosen so the first tile starts at 0 and the
    # last one ends exactly at the far edge; tiles overlap when the longer
    # side is not an exact multiple of the shorter one.
    num_slides, hskip, wskip, min_dim = 1, 0, 0, h
    if h > w:
        num_slides = np.ceil(h / w)
        hskip = w - (w * num_slides - h) / (num_slides - 1)
        min_dim = w
    elif h < w:
        num_slides = np.ceil(w / h)
        wskip = h - (h * num_slides - w) / (num_slides - 1)
        min_dim = h

    tiles = [img[int(hskip * n):int(hskip * n + min_dim),
                 int(wskip * n):int(wskip * n + min_dim)]
             for n in range(int(num_slides))]

    for crop_num, tile in enumerate(tiles):
        # check if the tile is too dull using mean saturation (HSV channel 1)
        tile_hsv = cv2.cvtColor(tile, cv2.COLOR_BGR2HSV)
        if cv2.mean(tile_hsv)[1] < BW_THRESHOLD:
            # print(file_path, ': dull')
            continue

        # tile is suitable for processing
        line, colour = process(tile)

        # write processed images to file as "<image index>-<tile index>.jpg"
        out_name = '{}-{}.jpg'.format(index, crop_num)
        cv2.imwrite(os.path.join(LINE_PATH, out_name), line)
        cv2.imwrite(os.path.join(COLOUR_PATH, out_name), colour)

The following cell performs the final tasks of handling the image data. Specifically, the cell:
* Defines a function 'soft_encode', which takes in a colour image and produces a probability distribution for each pixel. Each pixel's probability distribution represents the chance that the pixel's colour belongs to a particular class defined by the CLASS_MAP. This represents the expected output from the network. Optionally, class rebalancing is applied to the probability distribution.
* Specifies a custom DataGenerator for producing mini-batches of data; needed to reduce RAM requirements.
* Splits up the data set into training and validation sets by assigning the file names into 'train_names.txt' or 'valid_names.txt' respectively.

In [None]:
import os
import cv2
import numpy as np
import random
import sklearn.neighbors as nn
from keras.utils import Sequence

# Fitted nearest-neighbour indexes over CLASS_MAP, keyed by NUM_NEIGHB, so the
# ball tree is built once rather than once per image.
_NEIGHB_FINDERS = {}


def _class_map_neighbours():
    '''Returns a NearestNeighbors index over CLASS_MAP, fitting it on first use.'''
    finder = _NEIGHB_FINDERS.get(NUM_NEIGHB)
    if finder is None:
        # n_neighbors is passed by keyword: recent scikit-learn versions make
        # estimator constructor arguments keyword-only
        finder = nn.NearestNeighbors(n_neighbors=NUM_NEIGHB,
                                     algorithm='ball_tree').fit(CLASS_MAP)
        _NEIGHB_FINDERS[NUM_NEIGHB] = finder
    return finder


def soft_encode(bgr):
    ''' Takes in a colour image and produces a 'soft-encoded' probability
    distribution of colour classes for each pixel. Also rebalances each pixel
    based on how common its colour class is.

    Args:
        bgr: (h, w, 3) colour image; channel order is assumed to be BGR (as
            returned by cv2) to match the column order of CLASS_MAP.

    Returns:
        float array of shape (h, w, NUM_CLASSES). For each pixel only its
        NUM_NEIGHB nearest bins are non-zero, holding Gaussian weights that
        sum to 1 (before the optional TO_REBAL scaling).

    Raises:
        RuntimeError: if TO_REBAL is set but FACTORS has not been loaded.
    '''
    # flatten the image into a list of pixels
    h, w = bgr.shape[:2]
    pixels = h * w
    bgr = bgr.reshape((pixels, 3))

    # for each pixel, find the NUM_NEIGHB nearest bin centres in CLASS_MAP
    dist_neighb, index_neighb = _class_map_neighbours().kneighbors(bgr)

    # smooth the distances using a gaussian kernel, normalised per pixel
    weights = np.exp(-dist_neighb ** 2 / (2 * SIGMA_NEIGHB ** 2))
    weights = weights / np.sum(weights, axis=1, keepdims=True)

    # multiply the weights by a factor for class rebalancing
    if TO_REBAL:
        # looked up defensively so a missing factors file gives a clear error
        # instead of a NameError deep inside the data generator
        factors = globals().get('FACTORS')
        if factors is None:
            raise RuntimeError('TO_REBAL is True but FACTORS (class rebalancing '
                               'factors) has not been loaded')
        # each pixel is scaled by the factor of its single nearest bin
        scale = np.asarray(factors)[index_neighb[:, 0]]
        if scale.ndim == 1:
            scale = scale[:, np.newaxis]
        weights = weights * scale

    # scatter the weights into a dense (pixels, NUM_CLASSES) distribution
    encode = np.zeros((pixels, NUM_CLASSES))
    encode[np.arange(pixels)[:, np.newaxis], index_neighb] = weights
    return encode.reshape(h, w, NUM_CLASSES)

class DataGenerator(Sequence):
    ''' Produces mini-batches for training & evaluation, since the dataset
    takes too much memory to load at once.

    Each item is a pair (X, y):
        X: float32 (batch, IMG_ROWS, IMG_COLS, 1) line drawings scaled to [0, 1].
        y: float32 (batch, IMG_ROWS // 4, IMG_COLS // 4, NUM_CLASSES)
           soft-encoded colour targets from soft_encode().
    Images are read from LINE_PATH / COLOUR_PATH using the file names listed
    in ``names_file``. Each sample is flipped left-right with probability 0.5.
    '''
    def __init__(self, names_file):
        super().__init__()
        with open(names_file, 'r') as f:
            # ignore blank lines so they are never treated as file names
            self.names = [name for name in f.read().splitlines() if name]
        np.random.shuffle(self.names)

    def __len__(self):
        return int(np.ceil(len(self.names) / float(BATCH_SIZE)))

    def __getitem__(self, index):
        out_rows, out_cols = IMG_ROWS // 4, IMG_COLS // 4
        batch_names = self.names[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]

        X = np.empty((len(batch_names), IMG_ROWS, IMG_COLS, 1), dtype=np.float32)
        y = np.empty((len(batch_names), out_rows, out_cols, NUM_CLASSES), dtype=np.float32)

        for i, name in enumerate(batch_names):
            line_filename = os.path.join(LINE_PATH, name)
            colour_filename = os.path.join(COLOUR_PATH, name)

            # cv2.imread returns None instead of raising on a missing/bad file
            line = cv2.imread(line_filename, cv2.IMREAD_GRAYSCALE)
            if line is None:
                raise FileNotFoundError('cannot read line image ' + line_filename)
            line = line / 255.0  # scale to [0, 1]

            bgr = cv2.imread(colour_filename, cv2.IMREAD_COLOR)
            if bgr is None:
                raise FileNotFoundError('cannot read colour image ' + colour_filename)
            # cv2.resize takes dsize as (width, height), and its third
            # positional parameter is `dst`, so interpolation must be a keyword
            bgr = cv2.resize(bgr, (out_cols, out_rows), interpolation=cv2.INTER_AREA)
            encode = soft_encode(bgr)

            # 50% chance to randomly flip the image
            if np.random.random_sample() > 0.5:
                line = np.fliplr(line)
                encode = np.fliplr(encode)

            X[i, :, :, 0] = line
            y[i] = encode

        return X, y

    def on_epoch_end(self):
        np.random.shuffle(self.names)
    
def split_data(line_path=None):
    ''' Splits the data set into training & validation sets.

    Lists the '.jpg' files (case-insensitive) in ``line_path`` (default
    LINE_PATH), randomly holds out 0.8% of them for validation, and writes the
    two shuffled name lists, one name per line, to 'train_names.txt' and
    'valid_names.txt' in the current directory.
    '''
    if line_path is None:
        line_path = LINE_PATH
    names = [f for f in os.listdir(line_path) if f.lower().endswith('.jpg')]

    num_samples = len(names) # e.g. 351204
    num_train_samples = int(num_samples * 0.992) # e.g. 348394
    num_valid_samples = num_samples - num_train_samples # e.g. 2810
    print('num samples: ' + str(num_samples))
    print('num train samples: ' + str(num_train_samples))
    print('num valid samples: ' + str(num_valid_samples))

    valid_names = random.sample(names, num_valid_samples)
    # set membership keeps this O(n); testing against the list was O(n * m)
    valid_set = set(valid_names)
    train_names = [n for n in names if n not in valid_set]
    np.random.shuffle(train_names)
    np.random.shuffle(valid_names)

    with open('train_names.txt', 'w') as file:
        file.write('\n'.join(train_names))

    with open('valid_names.txt', 'w') as file:
        file.write('\n'.join(valid_names))
        
# Write train_names.txt / valid_names.txt from the images in LINE_PATH
split_data()


The following cell defines the network architecture. The network takes a 256x256x1 input array which represents a 256x256 grayscale line drawing. After several convolutional blocks, the network outputs a 64x64x512 array, which represents a 64x64 square of pixels, each containing a probability distribution over the 512 classes that the colour space has been divided into. This is later upsampled to produce a colour prediction.

In [None]:
from keras.layers import Input, Conv2D, BatchNormalization, UpSampling2D
from keras.models import Model

def build_model():
    '''Builds the colour-classification network.

    Input: (IMG_ROWS, IMG_COLS, 1) greyscale line drawing.
    Output: per-location softmax over NUM_CLASSES colour bins at a quarter
    of the input resolution (64x64x512 for the default 256x256 input). The
    network has seven 3x3 conv stages (stages 1-3 end with a stride-2 conv,
    stages 5-6 use dilation 2), each followed by batch normalisation, then a
    2x upsampling, three more convs and a 1x1 softmax classifier.
    '''
    def relu_conv(tensor, filters, layer_name, **options):
        # every hidden conv is 3x3, 'same'-padded, ReLU, he_normal initialised
        layer = Conv2D(filters, (3, 3), activation='relu', padding='same',
                       name=layer_name, kernel_initializer="he_normal", **options)
        return layer(tensor)

    # (stage number, filters, number of convs, stride of last conv, dilation rate)
    stages = (
        (1, 64, 2, 2, 1),
        (2, 128, 2, 2, 1),
        (3, 256, 3, 2, 1),
        (4, 512, 3, 1, 1),
        (5, 512, 3, 1, 2),
        (6, 512, 3, 1, 2),
        (7, 512, 3, 1, 1),
    )

    inputs = Input(shape=(IMG_ROWS, IMG_COLS, 1))
    x = inputs
    for stage, filters, depth, last_stride, dilation in stages:
        for k in range(1, depth + 1):
            # only pass non-default options, exactly as the layers were declared
            options = {}
            if dilation != 1:
                options['dilation_rate'] = dilation
            if k == depth and last_stride != 1:
                options['strides'] = (last_stride, last_stride)
            x = relu_conv(x, filters, 'conv{}_{}'.format(stage, k), **options)
        x = BatchNormalization()(x)

    # decoder: double the spatial resolution, then three 256-filter convs
    x = UpSampling2D(size=(2, 2))(x)
    for k in (1, 2, 3):
        x = relu_conv(x, 256, 'conv8_{}'.format(k))

    outputs = Conv2D(NUM_CLASSES, (1, 1), activation='softmax', padding='same', name='pred')(x)
    return Model(inputs=inputs, outputs=outputs, name="ClassificationModel")


The following cell builds, compiles and fits (i.e. trains) a network. The custom callbacks are used to save model weights after every nth batch (as specified by SAVE_INTERVAL), and after every training epoch.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import keras

class LossTracker(keras.callbacks.Callback):
    '''Records the training loss of every batch and saves the model weights
    every SAVE_INTERVAL batches, counting batches across all epochs
    (files: model_batch<count>.hd5).'''
    def on_train_begin(self, logs=None):
        self.losses = []
        self.cumul_batches = 0

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        # save weights every nth batch (count is 1-based)
        self.cumul_batches += 1
        if self.cumul_batches % SAVE_INTERVAL == 0:
            self.model.save_weights('model_batch{}.hd5'.format(self.cumul_batches))


class SaveEpoch(keras.callbacks.Callback):
    '''Saves the model weights at the end of every epoch (model_epoch<epoch>.hd5).'''
    def on_epoch_end(self, epoch, logs=None):
        self.model.save_weights('model_epoch{}.hd5'.format(epoch))


# build and compile a new model
new_model = build_model()
new_model.compile(optimizer=keras.optimizers.Adam(), loss='categorical_crossentropy')

# set up callbacks (instances get their own names so that re-running this
# cell does not try to call an instance that has replaced its class)
loss_tracker = LossTracker()
save_epoch = SaveEpoch()
callbacks = [loss_tracker, save_epoch]

# summary() prints the table itself and returns None
new_model.summary()

# set up generators for training and validation data
train_gen = DataGenerator('train_names.txt')
valid_gen = DataGenerator('valid_names.txt')

# train a network model
new_model.fit_generator(generator=train_gen,
                        validation_data=valid_gen,
                        epochs=EPOCHS,
                        verbose=1,
                        callbacks=callbacks
                        )

keras.backend.clear_session()

The following cell evaluates each saved checkpoint, and determines the one with the lowest validation loss.

In [None]:
# evaluate each set of saved weights to determine validation loss
losses = loss_tracker.losses.copy()
length = len(losses)
val_losses = []
results = []   # (weights file, validation loss) pairs
# NOTE(review): range() stops before `length`, so a checkpoint written on the
# very last batch (when length is a multiple of SAVE_INTERVAL) is not evaluated.
saved_weights = [('model_batch'+str(i)+'.hd5') for i in range(SAVE_INTERVAL, length, SAVE_INTERVAL)]
for weights in saved_weights:
    candidate = build_model()
    candidate.compile(optimizer=keras.optimizers.Adam(), loss='categorical_crossentropy')
    candidate.load_weights(weights)
    val_loss = candidate.evaluate_generator(valid_gen, verbose=1)
    val_losses.append(val_loss)
    # list.append takes exactly one argument: store the pair as a tuple
    results.append((weights, val_loss))
    print(weights, ':', val_loss)
    keras.backend.clear_session()

if not results:
    raise RuntimeError('No batch checkpoints to evaluate: training ran for '
                       + str(length) + ' batches but SAVE_INTERVAL is '
                       + str(SAVE_INTERVAL))
# sorted() is stable, so ties keep the earliest checkpoint
sorted_results = sorted(results, key=lambda x: x[1])
best_model, val_loss_best = sorted_results[0]


The following cell plots a graph showing the training loss throughout training and the validation loss of each saved checkpoint.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# generate loss graph
# batch numbers are 1-based, matching the batch counts in the checkpoint names
x = np.arange(1, length + 1)
# one evaluated checkpoint every SAVE_INTERVAL batches; derived from
# val_losses itself so the two scatter arrays always have the same length
x_interval = SAVE_INTERVAL * np.arange(1, len(val_losses) + 1)
x_best = (val_losses.index(val_loss_best) + 1) * SAVE_INTERVAL

fig, ax = plt.subplots(figsize=(12, 9))

p1 = ax.plot(x, losses, linewidth=0.2, zorder=1)
p2 = ax.scatter(x_interval, val_losses, color='r', marker='o', zorder=2)
p3 = ax.scatter(x_best, val_loss_best, color='lime',  marker='o', zorder=3)

# include validation losses in the y range so no checkpoint is clipped
all_losses = list(losses) + list(val_losses)
y_min = np.floor(min(all_losses))
y_max = np.ceil(max(all_losses))
ax.set_ylim([y_min, y_max])
ax.set_xlabel('batch number')
ax.set_ylabel('loss')
ax.set_title('Model loss')
ax.legend([p1[0], p2, p3], ['training loss', 'validation loss', 'lowest validation loss'])

plt.savefig('graph.png')
plt.show()

The following cell:
* Loads a trained model
* Reads in line drawings from 'test_X'
* Uses the model to make colour predictions and writes them into 'test_y'

In [None]:
import os
import cv2
import numpy as np
from keras.models import load_model

def get_prediction(model, test):
    '''Predicts a colour image for one line drawing.

    Args:
        model: network built by build_model() with trained weights.
        test: (IMG_ROWS, IMG_COLS) uint8 greyscale line drawing.

    Returns:
        float (IMG_ROWS, IMG_COLS, 3) BGR image with values in the range of
        the CLASS_MAP bin centres.
    '''
    # the network predicts at a quarter of the input resolution (same as the
    # targets produced by DataGenerator)
    out_rows, out_cols = IMG_ROWS // 4, IMG_COLS // 4

    holder = np.empty((1, IMG_ROWS, IMG_COLS, 1))
    holder[0, :, :, 0] = test / 255  # same [0, 1] scaling as in training

    z = model.predict(holder)[0]
    z = np.reshape(z, (out_rows * out_cols, NUM_CLASSES))

    # sharpen each distribution with TEMPERATURE (p ** (1 / T), renormalised),
    # then take its expectation over the bin centres ("annealed mean")
    probs = np.exp(np.log(z + EPSILON) / TEMPERATURE)
    probs = probs / np.sum(probs, axis=1, keepdims=True)
    out_img = probs @ CLASS_MAP  # (pixels, 3) in B, G, R order

    out_img = np.reshape(out_img, (out_rows, out_cols, 3))
    # cv2.resize takes dsize as (width, height)
    out_img = cv2.resize(out_img, (IMG_COLS, IMG_ROWS), interpolation=cv2.INTER_LINEAR)

    return out_img

# The checkpoints are written with save_weights(), so they hold weights only
# (no architecture); rebuild the network and load the weights into it.
new_model = build_model()
new_model.load_weights(best_model) # THIS CAN BE CHANGED AS NECESSARY
X = 'test_X'
y = 'test_y'

if not os.path.exists(y):
    os.makedirs(y)

for filename in os.listdir(X):
    in_path = os.path.join(X, filename)
    test = cv2.imread(in_path, cv2.IMREAD_GRAYSCALE)
    lines = cv2.imread(in_path, cv2.IMREAD_COLOR)
    # cv2.imread returns None for files it cannot decode; skip non-images
    if test is None or lines is None:
        continue
    if test.shape != (IMG_ROWS, IMG_COLS):
        raise ValueError('{} is {}x{}, expected {}x{}'.format(
            filename, test.shape[0], test.shape[1], IMG_ROWS, IMG_COLS))

    predicted_colours = get_prediction(new_model, test)

    # overlay lines on top of predicted colours: dark line pixels darken the
    # prediction, white pixels leave it unchanged
    result = predicted_colours * (lines / 255)

    # write an explicit 8-bit image (rounded and clipped) rather than relying
    # on imwrite's implicit float conversion
    cv2.imwrite(os.path.join(y, filename), np.clip(np.rint(result), 0, 255).astype(np.uint8))
