In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from model import Net
from skimage.color import rgb2lab, lab2rgb
from torch.autograd import Variable

# scale to 256 * 256 then perform random crop to 224 * 224
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

# scale to 224 * 224
transform_val = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])

class ColorizationDataset(datasets.ImageFolder):
    '''
    Custom Dataset for colorization
    '''
    def __getitem__(self, index):
        sample, target = super(ColorizationDataset, self).__getitem__(index)
        # color range L: [0,100], a&b: [-128, 127]
        lab = np.float32(rgb2lab(sample.permute(1, 2, 0)))
        l = lab[:, :, 0:1]
        # normalize a&b to [0, 1]
        ab = (lab[:, :, 1:] + 128) / 255
        return l, ab, target
    
class ColorizationDatasetCIFAR10(datasets.CIFAR10):
    '''
    Custom Dataset for colorization
    '''
    def __getitem__(self, index):
        sample, target = super(ColorizationDatasetCIFAR10, self).__getitem__(index)
        # color range L: [0,100], a&b: [-128, 127]
        lab = np.float32(rgb2lab(sample.permute(1, 2, 0)))
        l = lab[:, :, 0:1]
        # normalize a&b to [0, 1]
        ab = (lab[:, :, 1:] + 128) / 255
        return l, ab, target
    
data_root = '/home/shared/places365_standard'
train_dir = os.path.join(data_root, 'train')
train_set = ColorizationDataset(train_dir, transform=transform_train)
#train_set = ColorizationDatasetCIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)

val_dir = os.path.join(data_root, 'val')
val_set = ColorizationDataset(val_dir, transform=transform_val)
#val_set = ColorizationDatasetCIFAR10(root='./data', train=False, download=True, transform=transform_val)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4)

def classificationLoss(outputs, labels):
    '''
    Calculate the cross-entropy loss for the classification network.
    The classification loss affects the low-level features network,
    global features network, and the classification network.
    '''
    return nn.CrossEntropyLoss()(outputs, labels)


def colorizationLoss(coloroutputs, colorlabels, classoutputs, classlabels):
    '''
    Calculate the loss for the colorization network.
    The colorization loss affects the entire network.
    '''
    # Choose an alpha so that the MSE loss and the cross-entropy loss
    # are similar in magnitude.
    alpha = 1.0 / 300 
    # Calculate the MSE loss for the colorization network
    colorLoss = nn.MSELoss()(coloroutputs, colorlabels)
    # Calculate the classification loss
    classLoss = classificationLoss(classoutputs, classlabels)
    # Calculate std loss
    stdLoss = 0.0001/coloroutputs[0,:,:,0].std() + 0.0001/coloroutputs[0,:,:,1].std()
    return colorLoss + alpha * classLoss + stdLoss


def train(train_loader, model, optimizer, use_gpu):
    '''
    Train both colorization and classification networks. 
    '''
    model.train()
    count = 0
    loss_sum = 0.0
    running_loss_sum = 0.0
    best_running_loss = float("inf")
    running_losses = []
    for i, (inputs, colors, labels) in enumerate(train_loader):
        if use_gpu:
            inputs, colors, labels = Variable(inputs.cuda()), Variable(colors.cuda()), Variable(labels.cuda())
        else:
            inputs = inputs.float()
            colors = colors.float()
        # predict and compute loss
        color_preds, label_preds = model(inputs)
        loss = colorizationLoss(color_preds, colors, label_preds, labels)
        
        # record loss
        loss_sum += loss.item()
        count += inputs.size(0)
        running_loss_sum += loss.item()
        
        # compute gradient and do Adadelta step for both colorization and classification networks.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # print statistics and save checkpoint
        step_size = 200
        if i % step_size == (step_size - 1):    # print every step_size mini-batches
            running_loss = running_loss_sum / step_size
            running_losses.append(running_loss)
            print('[%d, %5d] running loss: %.5f' % (epoch + 1, i + 1, running_loss))
            running_checkpoint = {
                'epoch': epoch + 1,
                'iteration': i,
                'running_losses': running_losses,
                'state_dict': model.state_dict(),
                'optimizer' : optimizer.state_dict()
            }
            # save the running checkpoint
            if running_loss < best_running_loss:
                best_running_loss = running_loss
                # save the best running checkpoint
                print('Saving the new best running checkpoint')
                save_checkpoint(running_checkpoint, 'best_running_checkpoint_improv.pth.tar')
            print('Saving the new running checkpoint')
            save_checkpoint(running_checkpoint, 'running_checkpoint_improv.pth.tar')
            running_loss_sum = 0.0
            
    return loss_sum / count
        
def validate(val_loader, model, use_gpu):
    '''
    Evaluate both colorization and classification networks.
    '''
    model.eval()
    count = 0
    loss_sum = 0.0
    with torch.no_grad():
        for i, (inputs, colors, labels) in enumerate(val_loader):
            if use_gpu:
                inputs, colors, labels = Variable(inputs.cuda()), Variable(colors.cuda()), Variable(labels.cuda())
            else:
                inputs = inputs.float()
                colors = colors.float()
            # predict and compute losses
            color_preds, label_preds = model(inputs)
            loss = colorizationLoss(color_preds, colors, label_preds, labels)
            
            # record loss
            loss_sum += loss.item()
            count += inputs.size(0)
    
    return loss_sum / count


def save_checkpoint(state, filename='checkpoint_improv.pth.tar'):
    '''
    Save checkpoint for the given state. 
    '''
    torch.save(state, filename)

def load_checkpoint(filename='checkpoint_improv.pth.tar'):
    '''
    Load checkpoint from file. 
    '''
    if os.path.isfile(filename):
        return torch.load(filename)
    else:
        return None
    
# Checking GPU availability
use_gpu = torch.cuda.is_available()

# Initialize model and optimizer
model = Net()
if use_gpu:
    model = model.cuda()
    print('Use GPU')
else:
    print('Use CPU')
# Use ADADELTA optimizer    
optimizer = optim.Adadelta(model.parameters(), lr=1)
start_epoch = 0
total_epochs = 20
best_loss = float("inf")
train_losses = []
val_losses = []
# Load from checkpoint if exists
checkpoint = load_checkpoint()
if checkpoint:
    start_epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    best_loss = min(val_losses)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('Loaded checkpoint from epoch %d' % start_epoch)
for epoch in range(start_epoch, total_epochs):
    print('Start training for epoch %d' % (epoch + 1))
    # train for one epoch
    train_loss = train(train_loader, model, optimizer, use_gpu)
    train_losses.append(train_loss)
    print('Training loss for epoch %d: %.5f' % (epoch + 1, train_loss))
    
    # evaluate on validation set
    print('Start validation for epoch %d' % (epoch + 1))
    val_loss = validate(val_loader, model, use_gpu)
    val_losses.append(val_loss)
    print('Validation loss for epoch %d: %.5f' % (epoch + 1, val_loss))
    
    new_checkpoint = {
        'epoch': epoch + 1,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict()
    }
    if val_loss < best_loss:
        best_loss = val_loss
        # save the best checkpoint
        print('Saving the new best checkpoint')
        save_checkpoint(new_checkpoint, 'best_checkpoint_improv.pth.tar')
    # save checkpoint
    print('Saving the checkpoint for epoch %d' % (epoch + 1))
    save_checkpoint(new_checkpoint)
    save_checkpoint(new_checkpoint, 'checkpoint%d_improv.pth.tar' % (epoch + 1))


Use GPU
Loaded checkpoint from epoch 11
Start training for epoch 12
[12,   200] running loss: 0.01549
Saving the new best running checkpoint
Saving the new running checkpoint
[12,   400] running loss: 0.01555
Saving the new running checkpoint
[12,   600] running loss: 0.01548
Saving the new best running checkpoint
Saving the new running checkpoint
[12,   800] running loss: 0.01547
Saving the new best running checkpoint
Saving the new running checkpoint
[12,  1000] running loss: 0.01546
Saving the new best running checkpoint
Saving the new running checkpoint
[12,  1200] running loss: 0.01547
Saving the new running checkpoint
[12,  1400] running loss: 0.01555
Saving the new running checkpoint
[12,  1600] running loss: 0.01558
Saving the new running checkpoint
[12,  1800] running loss: 0.01544
Saving the new best running checkpoint
Saving the new running checkpoint
[12,  2000] running loss: 0.01554
Saving the new running checkpoint
[12,  2200] running loss: 0.01537
Saving the new best run

[13,  8200] running loss: 0.01530
Saving the new running checkpoint
[13,  8400] running loss: 0.01524
Saving the new running checkpoint
[13,  8600] running loss: 0.01523
Saving the new running checkpoint
[13,  8800] running loss: 0.01529
Saving the new running checkpoint
[13,  9000] running loss: 0.01527
Saving the new running checkpoint
[13,  9200] running loss: 0.01519
Saving the new running checkpoint
[13,  9400] running loss: 0.01520
Saving the new running checkpoint
[13,  9600] running loss: 0.01528
Saving the new running checkpoint
[13,  9800] running loss: 0.01540
Saving the new running checkpoint
[13, 10000] running loss: 0.01522
Saving the new running checkpoint
[13, 10200] running loss: 0.01533
Saving the new running checkpoint
[13, 10400] running loss: 0.01537
Saving the new running checkpoint
[13, 10600] running loss: 0.01532
Saving the new running checkpoint
[13, 10800] running loss: 0.01525
Saving the new running checkpoint
[13, 11000] running loss: 0.01530
Saving the new

[15,  2600] running loss: 0.01502
Saving the new running checkpoint
[15,  2800] running loss: 0.01497
Saving the new best running checkpoint
Saving the new running checkpoint
[15,  3000] running loss: 0.01508
Saving the new running checkpoint
[15,  3200] running loss: 0.01507
Saving the new running checkpoint
[15,  3400] running loss: 0.01504
Saving the new running checkpoint
[15,  3600] running loss: 0.01501
Saving the new running checkpoint
[15,  3800] running loss: 0.01499
Saving the new running checkpoint
[15,  4000] running loss: 0.01515
Saving the new running checkpoint
[15,  4200] running loss: 0.01500
Saving the new running checkpoint
[15,  4400] running loss: 0.01496
Saving the new best running checkpoint
Saving the new running checkpoint
[15,  4600] running loss: 0.01503
Saving the new running checkpoint
[15,  4800] running loss: 0.01499
Saving the new running checkpoint
[15,  5000] running loss: 0.01500
Saving the new running checkpoint
[15,  5200] running loss: 0.01504
Savi

[16, 11200] running loss: 0.01492
Saving the new running checkpoint
[16, 11400] running loss: 0.01496
Saving the new running checkpoint
[16, 11600] running loss: 0.01498
Saving the new running checkpoint
[16, 11800] running loss: 0.01495
Saving the new running checkpoint
[16, 12000] running loss: 0.01489
Saving the new running checkpoint
[16, 12200] running loss: 0.01496
Saving the new running checkpoint
[16, 12400] running loss: 0.01482
Saving the new running checkpoint
[16, 12600] running loss: 0.01489
Saving the new running checkpoint
[16, 12800] running loss: 0.01486
Saving the new running checkpoint
[16, 13000] running loss: 0.01490
Saving the new running checkpoint
[16, 13200] running loss: 0.01491
Saving the new running checkpoint
[16, 13400] running loss: 0.01492
Saving the new running checkpoint
[16, 13600] running loss: 0.01484
Saving the new running checkpoint
[16, 13800] running loss: 0.01499
Saving the new running checkpoint
[16, 14000] running loss: 0.01496
Saving the new

[18,  5200] running loss: 0.01463
Saving the new best running checkpoint
Saving the new running checkpoint
[18,  5400] running loss: 0.01473
Saving the new running checkpoint
[18,  5600] running loss: 0.01473
Saving the new running checkpoint
[18,  5800] running loss: 0.01463
Saving the new running checkpoint
[18,  6000] running loss: 0.01470
Saving the new running checkpoint
[18,  6200] running loss: 0.01468
Saving the new running checkpoint
[18,  6400] running loss: 0.01475
Saving the new running checkpoint
[18,  6600] running loss: 0.01471
Saving the new running checkpoint
[18,  6800] running loss: 0.01473
Saving the new running checkpoint
[18,  7000] running loss: 0.01466
Saving the new running checkpoint
[18,  7200] running loss: 0.01468
Saving the new running checkpoint
[18,  7400] running loss: 0.01469
Saving the new running checkpoint
[18,  7600] running loss: 0.01478
Saving the new running checkpoint
[18,  7800] running loss: 0.01478
Saving the new running checkpoint
[18,  800

[19, 13800] running loss: 0.01464
Saving the new running checkpoint
[19, 14000] running loss: 0.01458
Saving the new running checkpoint
Training loss for epoch 19: 0.00011
Start validation for epoch 19
Validation loss for epoch 19: 0.00012
Saving the new best checkpoint
Saving the checkpoint for epoch 19
Start training for epoch 20
[20,   200] running loss: 0.01447
Saving the new best running checkpoint
Saving the new running checkpoint
[20,   400] running loss: 0.01459
Saving the new running checkpoint
[20,   600] running loss: 0.01454
Saving the new running checkpoint
[20,   800] running loss: 0.01445
Saving the new best running checkpoint
Saving the new running checkpoint
[20,  1000] running loss: 0.01450
Saving the new running checkpoint
[20,  1200] running loss: 0.01443
Saving the new best running checkpoint
Saving the new running checkpoint
[20,  1400] running loss: 0.01454
Saving the new running checkpoint
[20,  1600] running loss: 0.01452
Saving the new running checkpoint
[20, 