# **Train & Test Transform**
Rescale and randomly crop the image to improve generalization. <br/>
For the test transform, use a fixed (center) crop position to get consistent input.

In [0]:
# Training pipeline: resize, then take a random 224x224 crop and randomly
# flip horizontally for augmentation.
# `transforms.Resize` replaces `transforms.Scale`, which was only a deprecated
# alias for it and is no longer available in current torchvision releases.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
])

# Test pipeline: same resize, but a deterministic center crop so that every
# evaluation sees the same input region.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

# **Attribute Dict**
Use an attribute dict as a compact way to pass arguments to training/testing.

In [0]:
from torchvision import datasets, transforms
from skimage.color import rgb2lab
from skimage.transform import resize
from skimage import color
from PIL import Image

import torch.utils.data as data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import sklearn.neighbors as nbrs

class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as the built-in ``dict`` constructor."""
        super().__init__(*args, **kwargs)
        # Alias the instance namespace to the mapping itself, so that
        # ``d.key`` and ``d['key']`` always refer to the same storage.
        self.__dict__ = self
        



# **Color Rebalancing**

In [0]:
class ColorRebal(object):
    """Per-class rebalancing factors computed from a stored prior distribution."""

    def __init__(self, gamma):
        """Compute the rebalancing factors and write them to disk.

        Reads the prior from ``./resources/probs.npy``, blends it with a
        uniform distribution of the same shape, inverts the blend, and
        rescales it so that ``sum(prior_probs * prior_factor) == 1``. The
        result is saved to ``./resources/rebal_probs.npy``.

        Args:
            gamma: Weight given to the uniform distribution in the blend;
                the prior receives ``1 - gamma``.
        """
        self.gamma = gamma
        self.prior_probs = np.load('./resources/probs.npy')

        # Uniform distribution with the same shape as the prior.
        flat = np.ones(self.prior_probs.shape)
        self.uni_probs = flat / np.sum(flat)

        blended = (1 - self.gamma) * self.prior_probs + self.gamma * self.uni_probs
        inverse = blended ** (-1)
        # Re-normalize so the prior-weighted sum of the factors equals 1.
        self.prior_factor = inverse / np.sum(self.prior_probs * inverse)
        np.save('./resources/rebal_probs.npy', self.prior_factor)

    def forward(self, max_encode):
        """Look up the factor of each class index and add a new axis 1.

        Args:
            max_encode: Integer array of class indices into ``prior_factor``.

        Returns:
            Array of factors shaped like ``max_encode`` with an extra
            length-1 axis inserted at position 1.
        """
        return np.expand_dims(self.prior_factor[max_encode], axis=1)


# **Encode Layer**
Encode the image from the Lab color space into the corresponding color classes.

In [0]:
class Encode():
    """Soft-encode points over their nearest color-bin centers.

    Every input point gets Gaussian-weighted, normalized weights on its
    ``NN`` nearest centers (loaded from ``km_filepath``); all other bins
    stay zero.
    """

    def __init__(self, NN, sigma, batch_size, km_filepath):
        """Load the bin centers and build the nearest-neighbor index.

        Args:
            NN: Number of nearest centers each point is spread over
                (coerced to ``int``).
            sigma: Standard deviation of the Gaussian kernel applied to the
                neighbor distances.
            batch_size: Expected batch size; the working buffer is
                preallocated with ``56 * 56 * batch_size`` rows.
            km_filepath: Path of the ``.npy`` file holding the bin centers.
        """
        self.cc = np.load(km_filepath)  # original note: pts.npy is (199, 2)
        self.num_colors = self.cc.shape[0]
        self.NN = int(NN)
        self.sigma = sigma
        # Pass the int-coerced value so that a float-like ``NN`` cannot reach
        # scikit-learn (the raw argument was used here before).
        self.nbrs = nbrs.NearestNeighbors(
            n_neighbors=self.NN, algorithm='ball_tree').fit(self.cc)
        self.encode_vec = np.zeros((56 * 56 * batch_size, self.num_colors))

    def encode(self, pts_origin):
        """Return the soft encoding of ``pts_origin``.

        Args:
            pts_origin: Input accepted by the file-level ``flatten`` helper,
                which must yield one row per point.

        Returns:
            Whatever ``restore(buffer, pts_origin)`` produces from the
            ``(num_points, num_colors)`` weight buffer.
            NOTE(review): the result may share memory with the internal
            buffer, which is reused by the next call; confirm against
            ``restore`` and copy if it must outlive the next ``encode``.
        """
        flat_pts = flatten(pts_origin)
        num_pts = flat_pts.shape[0]

        if self.encode_vec.shape[0] == num_pts:
            self.encode_vec[...] = 0
        else:
            # The batch does not match the preallocated row count (e.g. a
            # smaller final batch): size the buffer to the actual number of
            # points instead of indexing out of range or handing ``restore``
            # a buffer with the wrong number of rows.
            self.encode_vec = np.zeros((num_pts, self.num_colors))

        # Distances to, and indices of, the NN nearest centers of each point.
        d, indices = self.nbrs.kneighbors(flat_pts)

        # Gaussian kernel on the distances, normalized per point to sum to 1.
        weights = np.exp(-d ** 2 / (2 * self.sigma ** 2))
        weights = weights / np.sum(weights, axis=1, keepdims=True)

        # Column vector of row ids so it broadcasts against ``indices``.
        pts_ind = np.arange(num_pts, dtype='int')[:, np.newaxis]
        self.encode_vec[pts_ind, indices] = weights

        return restore(self.encode_vec, pts_origin)

class EncodeMax(object):
    """Soft encoding plus the index of the highest-weighted class."""

    def __init__(self, NN, sigma, batch_size):
        """Build the underlying ``Encode`` instance on ``./resources/pts.npy``.

        Args:
            NN: Number of nearest neighbors, forwarded to ``Encode``.
            sigma: Gaussian kernel width, forwarded to ``Encode``.
            batch_size: Batch size, forwarded to ``Encode``.
        """
        self.NN, self.sigma = NN, sigma
        self.nnenc = Encode(
            self.NN,
            self.sigma,
            batch_size,
            km_filepath='./resources/pts.npy',
        )

    def forward(self, x):
        """Encode ``x`` and take the arg-max over axis 1.

        Returns:
            Tuple ``(encoding, max_index)``: the encoder output, and the
            ``int32`` indices of its maximum along axis 1.
            Original shape notes: (40, 199, 56, 56) and (40, 56, 56).
        """
        soft_encoding = self.nnenc.encode(x)
        hard_labels = np.argmax(soft_encoding, axis=1)
        return soft_encoding, hard_labels.astype(np.int32)

# **GrayScale Mask**
Screen out black-and-white (grayscale) images by voiding their contribution.

In [0]:
class GrayScaleMask(object):
    """Per-sample 0/1 mask marking samples that contain any large value."""

    def forward(self, bottom):
        """Return 1.0 for samples with at least one ``|value| > 5``, else 0.0.

        Args:
            bottom: Object whose ``.numpy()`` yields an array with the sample
                axis first and three further axes (e.g. a CPU torch tensor).

        Returns:
            Float array with the sample axis followed by three length-1
            axes, i.e. ``(N, 1, 1, 1)`` for a 4-D input.
        """
        values = bottom.numpy()
        # A sample counts as "colored" if any entry over axes 1-3 exceeds the
        # threshold in magnitude.
        has_color = np.any(np.abs(values) > 5, axis=(1, 2, 3))
        return has_color[:, np.newaxis, np.newaxis, np.newaxis].astype('float')