In [320]:
import os, shutil, time
# from IPython.display import Image, display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data as D
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms

import glob
import os.path as osp

%matplotlib inline

*Den här är helt komplett och fungerar*

In [337]:
model_name = 'vgg-mse-'
test = False
custom_loss = False

In [322]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(dev)

cpu


## Prepare dataset

In [None]:
# Split into train and validation
os.makedirs('images/train/class/', exist_ok=True)
os.makedirs('images/val/class/', exist_ok=True)
for i, file in enumerate(os.listdir('testSet_resize')):
    if i < 1000:
        os.rename('testSet_resize/' + file, 'images/val/class/' + file) # 1k images
    else:
        os.rename('testSet_resize/'+ file, 'images/train/class/' + file) # 40k images

## Dataset

In [323]:
class PlacesImages(D.Dataset):
    def __init__(self, root, transform):
        self.filenames = []
        self.root = root
        self.transform = transform
        
        for fn in glob.glob(osp.join(self.root, '*.jpg')):
            self.filenames.append(fn)
        
        self.len = len(self.filenames)
    
    def __getitem__(self, index):
        img = Image.open(self.filenames[index])
        if img.mode != "RGB":
            img = img.convert("RGB")

        img = self.transform(img)
        img_original = np.asarray(img)

        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2,0,1))).float()
        img_original = rgb2gray(img_original)
        img_original = torch.from_numpy(img_original).unsqueeze(0).float()
        return img_original, img_ab
    
    def __len__(self):
        return self.len
    

In [324]:
if test:
    batch_size = 4
    train_path = 'images/sub_train/class'
    val_path = 'images/sub_val/class'
else:
    batch_size = 64
    train_path = 'images/train/class'
    val_path = 'images/val/class'   
    
train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
train_images = PlacesImages(train_path, train_transforms)
train_loader = D.DataLoader(train_images, batch_size=batch_size, shuffle=True, num_workers=4)

val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
val_images = PlacesImages(val_path, val_transforms)
val_loader = D.DataLoader(val_images, batch_size=batch_size, shuffle=False, num_workers=4)


## Models

In [325]:
class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=2, padding=1),
            nn.Upsample(scale_factor=4)
        )
    
    def forward(self, x):
        return self.net(x)

In [326]:
class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.net1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=4, stride=4),
            
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),  
            nn.BatchNorm2d(128),
            
            nn.Conv2d(128, 256, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            
            nn.Conv2d(256, 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 2, padding = 2, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 2, padding = 2, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 2, padding = 2, dilation = 1),
            nn.ReLU(),      
        )
        self.net2 = nn.Sequential(
            nn.Linear(4608, 6272),
            nn.ReLU(),
            nn.Linear(6272, 6272),
        )
        self.upsample = nn.Upsample(scale_factor=4)
    
    def forward(self, x):
        x = self.net1(x)
        x = self.net2(x.view(x.size(0), -1))
        x = self.upsample(x.view(x.size(0), 2, 56, 56))
        return x

In [327]:
model = VGGNet()
model.to(dev)

VGGNet(
  (net1): Sequential(
    (0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU()
    (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (14): ReLU()
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1

## Helper functions

In [328]:
alpha = 0.005
def customLoss():
    def criterion(trained, target):
        MSE = torch.mean((trained - target) ** 2)
        return (MSE + alpha / trained.std()), MSE
    return criterion


if custom_loss:
    criterion = customLoss()
    print('custom_loss')
else:
    criterion = nn.MSELoss()


optimizer = optim.Adam(model.parameters(), lr=1e-2)


custom_loss


In [329]:
#dataiter = iter(train_loader)
#feats, labels = dataiter.next()
#print(feats.shape)
#print(labels.shape)
#outputs = model(feats)
#print(outputs.shape)

In [330]:
class CSVWriter(object):
    def __init__(self):
        self.csv_file = open(model_name + 'stats.csv', 'a')
        self.csv_file.write('epoch, train_mse, val_mse, train_loss, val_loss, train_std, val_std \n')
        self.csv_file.close()
        
    def write(self, epoch, train_mse, val_mse, train_loss, val_loss, train_std, val_std):
        self.csv_file = open(model_name + 'stats.csv', 'a')
        self.csv_file.write('{}, {}, {}, {}, {}, {}, {} \n'.format(epoch, train_mse, val_mse, train_loss, val_loss, train_std, val_std))
        self.csv_file.close()

class AverageMeter(object):
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def calc_std(output_ab):
    ab = output_ab.detach().cpu().numpy()
    return np.std(ab)

def to_rgb(gray_input, ab_input, save_path=None, save_name=None, save_gray=True):
    plt.clf()
    color_img = torch.cat((gray_input, ab_input), 0).numpy()
    color_img = color_img.transpose((1, 2, 0))
    color_img[:, :, 0:1] = color_img[:, :, 0:1] * 100
    color_img[:, :, 1:3] = color_img[:, :, 1:3] * 255 - 128
    color_img = lab2rgb(color_img.astype(np.float64))
    gray_input = gray_input.squeeze().numpy()
    
    plt.imsave(arr=color_img, fname="{}{}".format(save_path['colorized'], save_name))
    if save_gray:
        plt.imsave(arr=gray_input, fname="{}{}".format(save_path['grayscale'], save_name), cmap="gray")

In [331]:
def train(train_loader, model, criterion, optimizer, epoch):
    print("Starting epoch {}".format(epoch))
    model.train()
    batch_time, losses, mse, std_dev = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    
    end = time.time()
    for i, (input_gray, input_ab) in enumerate(train_loader):
        input_gray, input_ab = input_gray.to(dev), input_ab.to(dev)
        
        output_ab = model(input_gray)
        if custom_loss:
            loss, MSE_val = criterion(output_ab, input_ab)
            losses.update(loss.item(), input_gray.size(0))
            mse.update(MSE_val.item(), input_gray.size(0))
        else:
            loss = criterion(output_ab, input_ab)
            losses.update(loss.item(), input_gray.size(0))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end) # record time for forward + backward
        end = time.time()
        
        std_dev.update(calc_std(output_ab)) # record std_dev
        
        if i % 25 == 0:
            print('Epoch [{0}][{1}/{2}]\tTime {3:.3f}\tLoss {4:.4f}'.format(epoch, i, len(train_loader), batch_time.avg, losses.avg))
    
    return losses.avg, std_dev.avg, mse.avg


In [332]:
def validate(val_loader, model, criterion, epoch, save_all=False):
    print("Start validation")
    model.eval()
    batch_time, losses, mse, std_dev = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    
    end = time.time()
    already_saved_images = False
    for i, (input_gray, input_ab) in enumerate(val_loader):
        input_gray, input_ab = input_gray.to(dev), input_ab.to(dev)
        
        output_ab = model(input_gray)
        if custom_loss:
            loss, MSE_val = criterion(output_ab, input_ab)
            losses.update(loss.item(), input_gray.size(0))
            mse.update(MSE_val.item(), input_gray.size(0))
        else:
            loss = criterion(output_ab, input_ab)
            losses.update(loss.item(), input_gray.size(0))
        
        std_dev.update(calc_std(output_ab)) # record std_dev
        
        if not already_saved_images and not save_all:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)):
                save_path = {'grayscale': model_name + 'outputs/gray/', 'colorized': model_name + 'outputs/color/'}
                save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)
                to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(),
                       save_path=save_path, save_name=save_name, save_gray=not(epoch > 0))
        
        if save_all:
            print('len', len(output_ab))
            print(output_ab.size())
            for j in range(len(output_ab)):
                save_path = {'grayscale': model_name + 'validation/gray/', 'colorized': model_name + 'validation/color/'}
                save_name = 'img-{}.jpg'.format(i * val_loader.batch_size + j)
                to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)
        
        batch_time.update(time.time() - end)
        end = time.time()
    
    return losses.avg, std_dev.avg, mse.avg

## Starting training

In [333]:
os.makedirs(model_name + 'outputs/color', exist_ok=True)
os.makedirs(model_name + 'outputs/gray', exist_ok=True)
os.makedirs(model_name + 'checkpoints', exist_ok=True)
best_losses = 1e10
epochs = 4

In [334]:
csv_f = CSVWriter()

for epoch in range(epochs):
    train_loss, train_std, train_mse = train(train_loader, model, criterion, optimizer, epoch)
    
    with torch.no_grad():
        val_loss, val_std, val_mse = validate(val_loader, model, criterion, epoch)
    
    csv_f.write(epoch, train_mse, val_mse, train_loss, val_loss, train_std, val_std)
    
    if val_loss < best_losses:
        best_losses = val_loss
        torch.save(model.state_dict(), model_name + 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1, val_loss))

torch.save(model.state_dict(), model_name + 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epochs, val_loss))
print('Training finished')

Starting epoch 0
Epoch [0][0/101]	Time 1.496	Loss 0.5623
Epoch [0][25/101]	Time 0.738	Loss 16073180.0236
Epoch [0][50/101]	Time 0.720	Loss 8194185.5381
Epoch [0][75/101]	Time 0.797	Loss 5498729.9110
Epoch [0][100/101]	Time 0.823	Loss 4168613.2249
Start validation


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


Starting epoch 1
Epoch [1][0/101]	Time 3.251	Loss 0.0543
Epoch [1][25/101]	Time 0.963	Loss 0.0357
Epoch [1][50/101]	Time 0.917	Loss 0.0332
Epoch [1][75/101]	Time 0.920	Loss 0.0326
Epoch [1][100/101]	Time 0.915	Loss 0.7874
Start validation
Starting epoch 2
Epoch [2][0/101]	Time 2.824	Loss 0.0716
Epoch [2][25/101]	Time 1.013	Loss 58.8018
Epoch [2][50/101]	Time 0.960	Loss 30.0170
Epoch [2][75/101]	Time 0.971	Loss 20.1541
Epoch [2][100/101]	Time 0.955	Loss 27.3896
Start validation


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


Starting epoch 3
Epoch [3][0/101]	Time 2.646	Loss 6.2794
Epoch [3][25/101]	Time 0.979	Loss 0.8560
Epoch [3][50/101]	Time 0.937	Loss 0.4835
Epoch [3][75/101]	Time 0.918	Loss 17.2436
Epoch [3][100/101]	Time 0.913	Loss 13.0881
Start validation


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


Training finished


<Figure size 432x288 with 0 Axes>

## Validation

In [335]:
#pretrained = torch.load('vgg-mse-checkpoints/model-epoch-3-losses-0.037.pth')
#model.load_state_dict(pretrained)
model.eval()

VGGNet(
  (net1): Sequential(
    (0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU()
    (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (14): ReLU()
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1

In [336]:
os.makedirs(model_name + 'validation/color', exist_ok=True)
os.makedirs(model_name + 'validation/gray', exist_ok=True)

with torch.no_grad():
    validate(val_loader, model, criterion, 0, True)

Start validation
len 4
torch.Size([4, 2, 224, 224])
len 4
torch.Size([4, 2, 224, 224])


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


len 3
torch.Size([3, 2, 224, 224])


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


<Figure size 432x288 with 0 Axes>