In [60]:
import os, shutil, time
# from IPython.display import Image, display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data as D
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms

import glob
import os.path as osp

%matplotlib inline

In [67]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(dev)

cpu


## Dataset

In [124]:
class PlacesImages(D.Dataset):
    def __init__(self, root, transform):
        self.filenames = []
        self.root = root
        self.transform = transform
        
        for fn in glob.glob(osp.join(self.root, '*.jpg')):
            self.filenames.append(fn)
        
        self.len = len(self.filenames)
    
    def __getitem__(self, index):
        img = Image.open(self.filenames[index])
        if img.mode != "RGB":
            img = img.convert("RGB")

        img = self.transform(img)
        img_original = np.asarray(img)

        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2,0,1))).float()
        img_original = rgb2gray(img_original)
        img_original = torch.from_numpy(img_original).unsqueeze(0).float()
        return img_original, img_ab
    
    def __len__(self):
        return self.len
    

In [125]:
train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
train_images = PlacesImages('images/sub_train/class', train_transforms)
train_loader = D.DataLoader(train_images, batch_size=4, shuffle=False)



In [126]:
dataiter = iter(train_loader)
feats, labels = dataiter.next()

In [127]:
print(feats.shape)
print(labels.shape)

torch.Size([4, 1, 224, 224])
torch.Size([4, 2, 224, 224])


## Models

In [128]:
class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=2, padding=1),
            nn.Upsample(scale_factor=4)
        )
    
    def forward(self, x):
        return self.net(x)

In [129]:
model = BasicNet()
model.to(dev)

BasicNet(
  (net): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(32, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (6): Upsample(scale_factor=4.0, mode=nearest)
  )
)

In [130]:
output = model(feats)
output = output.view(-1, 2, 224, 224)

In [131]:
print(output.shape)

torch.Size([4, 2, 224, 224])


In [132]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [133]:
for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(dev), labels.to(dev)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        if i % 10 == 0:
            print(i, loss.item())

0 0.497944176197052
10 0.048871491104364395
20 0.018883993849158287
30 0.005149592645466328
40 0.008907676674425602
50 0.009279435500502586
60 0.010624323971569538
70 0.006091078743338585
80 0.007649230770766735
90 0.0022815605625510216
100 0.009819429367780685
0 0.002118743257597089
10 0.0032317563891410828
20 0.0030198984313756227
30 0.0025561300572007895
40 0.005335967987775803
50 0.013648899272084236
60 0.003894361900165677
70 0.0024330897722393274
80 0.006499660201370716
90 0.0023360189516097307
100 0.0036955063696950674
