In [1]:
import glob
import numpy as np
from PIL import Image
import torch 
import torch.nn as nn
from torch.nn.functional import relu
from torch.utils.data import Dataset, DataLoader
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms

In [2]:
%env CUDA_VISIBLE_DEVICES=""

env: CUDA_VISIBLE_DEVICES=""


In [3]:
set_dir = glob.glob('./../../dataset/*.jpg')
train_set = set_dir[:5703]
val_set = set_dir[5703:]

In [4]:
PIC_SIZE = 256
class ColorizationDataset(Dataset):
    def __init__(self, paths):
        self.transforms = transforms.Resize((PIC_SIZE, PIC_SIZE),  Image.BICUBIC)
        
        self.pic_size = PIC_SIZE
        self.paths = paths
        
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype('float32')
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. -1.
        ab = img_lab[[1, 2], ...] / 110.
               
        return {'L': L, 'ab':ab}
    
    def __len__(self):
        return len(self.paths)

In [5]:
def make_dataloaders(batch_size=16, n_workers=2, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,pin_memory=pin_memory)
    
    return dataloader

train_dl = make_dataloaders(paths=train_set)
len(train_dl)

0

In [6]:
class Conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        
    def forward(self, inputs):
        x = relu(self.conv1(inputs))
        x = relu(self.conv2(x))
        
        return x

In [7]:
class Encoder_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convolution = Conv_block(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        
    def forward (self, inputs):
        x = self.convolution(inputs)
        p = self.pool(x)
        
        return x, p

In [8]:
class Decoder_block(nn.Module):
    def __init__ (self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        self.dconv1 = nn.Conv2d(in_channels+out_channels, out_channels, kernel_size=3, padding=1)
        self.dconv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        
    def forward(self, inputs, skip):
        x = self.upconv(inputs)
        x = torch.cat([x, skip], dim=1)
        x = relu(self.dconv1(x))
        x = relu(self.dconv2(x))
        
        return x

In [9]:
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        #Encoder part
        self.econv1 = Encoder_block(1, 64)
        self.econv2 = Encoder_block(64, 128)
        self.econv3 = Encoder_block(128, 256)
        self.econv4 = Encoder_block(256, 512)
        
        #Bootleneck
        self.b = Conv_block(512, 1024)
        
        #Decoder part
        self.dconv1 = Decoder_block(1024, 512)
        self.dconv2 = Decoder_block(512,256)
        self.dconv3 = Decoder_block(256, 128)
        self.dconv4 = Decoder_block(128, 64)
        
        #Output layer
        self.outconv = nn.Conv2d(64, 2, kernel_size = 1)
        
    def forward (self, inputs):
        #Encoder part
        x1, p1 = self.econv1(inputs)
        x2, p2 = self.econv2(p1)       
        x3, p3 = self.econv3(p2)        
        x4, p4 = self.econv4(p3)
        
        #Bootleneck
        b = self.b(p4)
        
        #Decoder part
        d1 = self.dconv1(b, x4)
        d2 = self.dconv2(d1, x3)
        d3 = self.dconv3(d2, x2)
        d4 = self.dconv4(d3, x1)
        
        #Output layer
        out = self.outconv(d4)
        
        return out
    
    def predict(self, input_data):
        self.eval()
        
        with torch.no_grad():
            pred = self.forward(input_data)
            
        prediction = torch.cat ((input_data, pred), dim=1)
        prediction[:, 0, :, :] += 1.
        prediction[:, 0, :, :] *= 50.
        prediction[:, 1, :, :] *= 110.
        prediction[:, 2, :, :] *= 110.
        pred_arr = prediction.numpy()
        
        from skimage import color
        rgb_image = np.transpose(pred_arr, (0, 2, 3, 1))
        rgb_image = color.lab2rgb(rgb_image[0])
        rgb_image = np.clip(rgb_image, 0, 1) * 255
        rgb_image = rgb_image.astype(np.uint8)
        image_pil = Image.fromarray(rgb_image)
            
        self.train()
        return image_pil

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train_model(model, train_dl, epochs, learning_rate=0.001):
    criterion = nn.MSELoss()  # Funkcja straty Mean Squared Error
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        for batch_idx, data in tqdm(enumerate(train_dl), total=len(train_dl)):
            inputs, targets = data['L'], data['ab']
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_dl)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

model = UNet().to('cpu')
train_model(model, train_dl, epochs=100, learning_rate=0.0002)

0it [00:00, ?it/s]


ZeroDivisionError: float division by zero

In [11]:
model = UNet()
model.load_state_dict(torch.load('./trained_model.pth', map_location=torch.device('cpu')))

FileNotFoundError: [Errno 2] No such file or directory: './trained_model.pth'

In [12]:
from torchsummary import summary
summary(model, (1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
            Conv2d-2         [-1, 64, 256, 256]          36,928
        Conv_block-3         [-1, 64, 256, 256]               0
         MaxPool2d-4         [-1, 64, 128, 128]               0
     Encoder_block-5  [[-1, 64, 256, 256], [-1, 64, 128, 128]]               0
            Conv2d-6        [-1, 128, 128, 128]          73,856
            Conv2d-7        [-1, 128, 128, 128]         147,584
        Conv_block-8        [-1, 128, 128, 128]               0
         MaxPool2d-9          [-1, 128, 64, 64]               0
    Encoder_block-10  [[-1, 128, 128, 128], [-1, 128, 64, 64]]               0
           Conv2d-11          [-1, 256, 64, 64]         295,168
           Conv2d-12          [-1, 256, 64, 64]         590,080
       Conv_block-13          [-1, 256, 64, 64]               0
        M

In [156]:
def ssim_check(validiation_set):
    from skimage.metrics import structural_similarity as ssim
    from tqdm.notebook import tqdm
    ssim_list = list()
    for elements in tqdm(val_set):
        pic = glob.glob(elements)
        img_test = make_dataloaders(paths=pic, batch_size=1)
        channel = next(iter(img_test))

        Ls, abs_ = channel['L'], channel['ab']
        
        pred = model.predict(Ls)
    
        image_ref = Image.open(elements)
        image_ref = image_ref.resize((256, 256), Image.BICUBIC)
        image_ref = np.array(image_ref)
        
        image_pred = np.array(pred)
        
        ssim_val = ssim(image_ref, image_pred, win_size=3, multichannel=True)
        ssim_list.append(ssim_val)       
        
    return ssim_list

In [13]:
def predict(model: UNet, image):
    '''
    Image color prediction
    '''

    img_normalized = image / 50.0 - 1
    img_tensor = torch.tensor(img_normalized).float().unsqueeze(0).unsqueeze(0)
    
    image_pred = model.predict(img_tensor)
    
    return image_pred
