# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import rgb2lab, lab2rgb

# Training transform: scale to exactly 256 * 256, then perform a random crop
# to 224 * 224 and a random horizontal flip.
# NOTE: Resize is given an explicit (height, width) pair. A bare int only
# fixes the shorter edge and keeps the aspect ratio, which does not produce
# the fixed-size image described above.
transform_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# Validation transform: scale to exactly 224 * 224.
# The explicit (height, width) pair guarantees every image has the same
# shape, which the default DataLoader collation needs in order to stack a
# batch; Resize(224) would leave non-square images non-square.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

class ColorizationDataset(datasets.ImageFolder):
    '''
    ImageFolder variant that serves colorization samples in Lab color space.

    Each item is a tuple (l, ab, target):
      * l      -- the L channel of the image, kept as a trailing axis of size 1
      * ab     -- the a and b channels, rescaled with (value + 128) / 255
      * target -- the class index produced by ImageFolder
    '''

    def __getitem__(self, index):
        image, target = super(ColorizationDataset, self).__getitem__(index)
        # Move the first axis to the end before converting to Lab.
        # NOTE(review): assumes the configured transform ends with ToTensor,
        # so `image` is channels-first here -- confirm for other transforms.
        lab_image = rgb2lab(image.permute(1, 2, 0))
        # color range L: [0,100], a&b: [-128, 127]
        lightness = lab_image[..., :1]
        # normalize a&b to [0, 1]
        chroma = (lab_image[..., 1:] + 128) / 255
        # NOTE(review): both arrays stay channels-last (H, W, C); verify that
        # the models consuming these batches expect that layout.
        return lightness, chroma, target

# Root folder of the dataset; it must contain 'train' and 'val' sub-folders
# laid out the way ImageFolder expects (one directory per class).
data_root = '/Users/yuli/Downloads/places365_standard'
train_dir = os.path.join(data_root, 'train')
train_set = ColorizationDataset(train_dir, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)

val_dir = os.path.join(data_root, 'val')
val_set = ColorizationDataset(val_dir, transform=transform_val)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=16, shuffle=False, num_workers=4)

# Pull a single batch from the training loader as a quick sanity check.
# Use the built-in next(): the iterator's .next() method was only a
# Python 2 compatibility alias and is not available on current DataLoader
# iterators.
dataiter = iter(train_loader)
lchannels, abchannels, labels = next(dataiter)


def classificationLoss(outputs, labels):
    '''
    Calculate the cross-entropy loss for the classification network.
    The classification loss affects the low-level features network,
    global features network, and the classification network.

    outputs: raw class scores (logits), one row per sample.
    labels:  target class indices, one per sample.
    Returns the cross-entropy averaged over the batch (a scalar tensor).
    '''
    return nn.CrossEntropyLoss()(outputs, labels)


def colorizationLoss(coloroutputs, colorlabels, classoutputs, classlabels,
                     alpha=1.0 / 300):
    '''
    Calculate the loss for the colorization network.
    The colorization loss affects the entire network.

    coloroutputs / colorlabels: predicted and ground-truth color channels.
    classoutputs / classlabels: class logits and target class indices.
    alpha: weight of the classification term. Choose it so that the MSE
           loss and the cross-entropy loss are similar in magnitude.
    Returns MSE(color) + alpha * cross-entropy(class) as a scalar tensor.
    '''
    # Calculate the MSE loss for the colorization network
    colorLoss = nn.MSELoss()(coloroutputs, colorlabels)
    # Calculate the classification loss
    classLoss = classificationLoss(classoutputs, classlabels)
    # The cross-entropy is a non-negative penalty, so it must be ADDED.
    # Subtracting it would make the optimizer reward wrong class predictions
    # and leave the total loss unbounded below.
    return colorLoss + alpha * classLoss


def train(train_loader, color_model, class_model, color_optimizer, class_optimizer):
    '''
    Train both colorization and classification models for one pass over
    train_loader.

    train_loader:    iterable of (inputs, colors, labels) batches.
    color_model:     network called on inputs to predict the color channels.
    class_model:     network called on inputs to predict the class scores.
    color_optimizer: optimizer stepped for the colorization network.
    class_optimizer: optimizer stepped for the classification network.

    Returns the loss averaged over all samples seen in this pass.
    Raises ZeroDivisionError if train_loader yields no samples.
    '''
    color_model.train()
    class_model.train()
    count = 0
    loss_sum = 0.0
    for inputs, colors, labels in train_loader:
        batch_size = inputs.size(0)

        # predict and compute loss
        color_preds = color_model(inputs)
        label_preds = class_model(inputs)
        loss = colorizationLoss(color_preds, colors, label_preds, labels)

        # record loss
        # The loss is built from mean-reduced criteria, i.e. it is already a
        # per-batch average. Weight it by the batch size so that dividing by
        # the sample count below yields a true per-sample average (this also
        # handles a smaller final batch correctly).
        loss_sum += loss.item() * batch_size
        count += batch_size

        # compute gradient and do one optimizer step for both colorization and classification networks.
        color_optimizer.zero_grad()
        class_optimizer.zero_grad()
        loss.backward()
        color_optimizer.step()
        class_optimizer.step()

    return loss_sum / count
        
def validate(val_loader, color_model, class_model):
    '''
    Evaluate both colorization and classification models.

    val_loader:  iterable of (inputs, colors, labels) batches.
    color_model: network called on inputs to predict the color channels.
    class_model: network called on inputs to predict the class scores.

    Returns the loss averaged over all samples in val_loader. No gradients
    are tracked and no parameters are updated.
    Raises ZeroDivisionError if val_loader yields no samples.
    '''
    color_model.eval()
    class_model.eval()
    count = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, colors, labels in val_loader:
            batch_size = inputs.size(0)

            # predict and compute losses
            color_preds = color_model(inputs)
            label_preds = class_model(inputs)
            loss = colorizationLoss(color_preds, colors, label_preds, labels)

            # record loss
            # The loss is already a per-batch average, so weight it by the
            # batch size; otherwise dividing by the sample count would
            # shrink the result by roughly a factor of the batch size.
            loss_sum += loss.item() * batch_size
            count += batch_size

    return loss_sum / count


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    '''
    Write `state` to `filename` using torch.save.

    state:    the object to serialize (the training loop passes a dict).
    filename: destination path; defaults to 'checkpoint.pth.tar' in the
              current working directory.
    '''
    torch.save(obj=state, f=filename)

def load_checkpoint(filename='checkpoint.pth.tar'):
    '''
    Read a checkpoint previously written with torch.save.

    filename: path of the checkpoint; defaults to 'checkpoint.pth.tar'.
    Returns the loaded object, or None when `filename` is not an existing
    regular file.
    '''
    # Guard clause: nothing to resume from.
    if not os.path.isfile(filename):
        return None
    # NOTE: torch.load deserializes with pickle; only load trusted files.
    return torch.load(filename)

# Initialize models and optimizers
# TODO: both networks must be constructed here before this script can run;
# everything below uses color_model and class_model.
# color_model = 
# class_model =
color_optimizer = optim.Adadelta(color_model.parameters(), lr=1)
class_optimizer = optim.Adadelta(class_model.parameters(), lr=1)
start_epoch = 0
total_epochs = 10
best_loss = float("inf")
losses = []
# Load from checkpoint if exists
checkpoint = load_checkpoint()
if checkpoint:
    start_epoch = checkpoint['epoch']
    losses = checkpoint['losses']
    # Restore the best validation loss reached before the restart. Without
    # this, the first epoch after resuming would always overwrite
    # 'best_checkpoint.pth.tar', even with a worse model.
    if losses:
        best_loss = min(losses)
    color_model.load_state_dict(checkpoint['color_state_dict'])
    class_model.load_state_dict(checkpoint['class_state_dict'])
    color_optimizer.load_state_dict(checkpoint['color_optimizer'])
    class_optimizer.load_state_dict(checkpoint['class_optimizer'])
for epoch in range(start_epoch, total_epochs):
    # train for one epoch
    train_loss = train(train_loader, color_model, class_model, color_optimizer, class_optimizer)

    # evaluate on validation set
    val_loss = validate(val_loader, color_model, class_model)
    losses.append(val_loss)
    # 'epoch' stores the NEXT epoch to run, so a resumed run continues
    # right after the last completed epoch.
    new_checkpoint = {
        'epoch': epoch + 1,
        'losses': losses,
        'color_state_dict': color_model.state_dict(),
        'class_state_dict': class_model.state_dict(),
        'color_optimizer' : color_optimizer.state_dict(),
        'class_optimizer' : class_optimizer.state_dict()
    }
    if val_loss < best_loss:
        best_loss = val_loss
        # save the best checkpoint
        save_checkpoint(new_checkpoint, 'best_checkpoint.pth.tar')
    # save checkpoint
    save_checkpoint(new_checkpoint)
