In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import rgb2lab, lab2rgb

# Training: resize so the shorter side is 256 (256 x 256 for square inputs),
# take a random 224 x 224 crop, then flip horizontally with probability 0.5.
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

# Validation: scale to exactly 224 x 224. A (height, width) tuple is required:
# an int size only rescales the shorter side, which leaves non-square images
# at differing sizes that cannot be stacked into one batch.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

class ColorizationDataset(datasets.ImageFolder):
    '''
    ImageFolder variant that yields Lab-split samples for colorization.

    Each item is ``(l, ab, target)``:
      * ``l``  -- the L (lightness) channel from ``rgb2lab``, H x W x 1.
      * ``ab`` -- the a and b channels rescaled by ``(ab + 128) / 255``,
                  H x W x 2.
      * ``target`` -- the class index produced by ``ImageFolder``.
    ``l`` and ``ab`` are NumPy arrays in channel-last layout (as returned by
    ``rgb2lab``); any channel-first conversion is left to the caller.
    '''

    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        if isinstance(sample, torch.Tensor):
            # ToTensor() produces C x H x W; rgb2lab wants the channel last.
            rgb = sample.permute(1, 2, 0).numpy()
        else:
            # No tensor transform was applied (e.g. transform=None), so the
            # sample is an image/array that is already H x W x C.
            rgb = np.asarray(sample)
        # color range L: [0, 100], a&b: roughly [-128, 127]
        lab = rgb2lab(rgb)
        l = lab[:, :, 0:1]
        # Map a&b from [-128, 127] onto [0, 1].
        ab = (lab[:, :, 1:] + 128) / 255
        return l, ab, target

# Root of the dataset; it must contain ImageFolder-style 'train' and 'val'
# subdirectories. Set the PLACES365_ROOT environment variable to point
# elsewhere instead of editing this path.
data_root = os.environ.get('PLACES365_ROOT',
                           '/Users/yuli/Downloads/places365_standard')

traindir = os.path.join(data_root, 'train')
trainset = ColorizationDataset(traindir, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=16, shuffle=True, num_workers=4)

valdir = os.path.join(data_root, 'val')
valset = ColorizationDataset(valdir, transform=transform_val)
# Validation order is kept fixed (shuffle=False) so results are comparable.
valloader = torch.utils.data.DataLoader(
    valset, batch_size=16, shuffle=False, num_workers=4)

    
# Pull a single training batch to sanity-check the pipeline. Use the builtin
# next(): DataLoader iterators no longer provide a .next() method in recent
# PyTorch releases.
dataiter = iter(trainloader)
lchannels, abchannels, labels = next(dataiter)


def colorizationLoss(coloroutputs, colorlabels, classoutputs, classlabels,
                     alpha=1.0 / 300):
    '''
    Calculate the joint loss for the colorization network.

    loss = MSE(coloroutputs, colorlabels) + alpha * CE(classoutputs, classlabels)

    The colorization (MSE) term affects the entire network. The
    classification term is computed by ``classificationLoss`` and scaled by
    ``alpha`` so that the two terms are similar in magnitude.

    Args:
        coloroutputs: predicted color channels.
        colorlabels: target color channels, same shape as ``coloroutputs``.
        classoutputs: raw class scores from the classification branch.
        classlabels: target class indices.
        alpha: weight of the classification term (default 1/300).

    Returns:
        Scalar loss tensor (both terms use mean reduction).
    '''
    colorLoss = F.mse_loss(coloroutputs, colorlabels)
    classLoss = classificationLoss(classoutputs, classlabels)
    # Both objectives are minimized, so the weighted classification loss is
    # added. Subtracting it would drive the optimizer to *maximize* the
    # cross-entropy and corrupt the shared features.
    return colorLoss + alpha * classLoss

def classificationLoss(outputs, labels):
    '''
    Cross-entropy loss for the classification branch.

    Gradients from this term reach the low-level features network, the
    global features network, and the classification network.

    Args:
        outputs: raw (unnormalized) class scores.
        labels: target class indices.

    Returns:
        The batch-mean cross-entropy as a scalar tensor.
    '''
    loss = F.cross_entropy(outputs, labels)
    return loss
