In [2]:
# Mount Google Drive at /content/drive so files stored there are reachable from this runtime.
# NOTE(review): `google.colab` is presumably only importable on a Colab runtime — confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import torch
import torch.nn as nn
import glob
from tqdm import tqdm
import math
import pickle

In [32]:
# Optional checkpoint paths. When a path is not None, the matching sub-network
# (cwan_l / cwan_ab) loads its state_dict from that file before training starts.
_CWAN_L_PATH = None
_CWAN_AB_PATH = None
# Training hyperparameters.
# _LR and _WEIGHT_DECAY are passed to the Adam optimizer.
_LR = 1e-5
_WEIGHT_DECAY = 0.05
# NOTE(review): _BATCH_SIZE is not used in the visible cells — presumably consumed by the data-loading code.
_BATCH_SIZE = 64
# The training loop iterates range(_START_EPOCH, _EPOCH), so _EPOCH itself is excluded.
_START_EPOCH = 0 + 1
_EPOCH = 200
# Dataset parameter.
# NOTE(review): presumably the number of samples stored per dataset file — confirm against the loader.
_ONE_FILE_SIZE = 2000

In [8]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device => {}'.format(device))

device => cpu


## LAB

In [18]:
class LAB(nn.Module):
    """sRGB <-> CIE XYZ <-> CIE LAB colour-space conversions on torch tensors.

    The single-image helpers (`rgb2xyz`, `xyz2lab`, `lab2xyz`, `xyz2rgb`) work
    on (3, height, width) tensors. `forward` (RGB -> LAB) and `lab2rgb`
    (LAB -> RGB) loop over a leading batch axis and return a
    (batch, 3, height, width) tensor.

    Conversion constants are created on the device of the tensor being
    converted, and none of the helpers writes into the tensor it is given.

    NOTE(review): the `show_results*` options call `plt`, which is not imported
    in the visible notebook cells — presumably `matplotlib.pyplot`; confirm it
    is in scope before enabling them.
    """

    # Linear-RGB -> XYZ matrix used by rgb2xyz; xyz2rgb uses its inverse.
    _XYZ_FROM_RGB = ((0.412453, 0.357580, 0.180423),
                     (0.212671, 0.715160, 0.072169),
                     (0.019334, 0.119193, 0.950227))

    def __init__(self):
        super().__init__()
        # Reference white points, keyed by illuminant name and then observer angle.
        self.illuminants = \
    {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
           '10': (1.111420406956693, 1, 0.3519978321919493)},
     "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
             '10': (0.9672062750333777, 1, 0.8142801513128616)},
     "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
             '10': (0.9579665682254781, 1, 0.9092525159847462)},
     "D65": {'2': (0.95047, 1., 1.08883),   # This was: `lab_ref_white`
             '10': (0.94809667673716, 1, 1.0730513595166162)},
     "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
             '10': (0.9441713925645873, 1, 1.2064272211720228)},
     "E": {'2': (1.0, 1.0, 1.0),
           '10': (1.0, 1.0, 1.0)}}
        # Kept for backward compatibility only: the conversions below follow the
        # device of their input tensor instead of this attribute.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    @staticmethod
    def _match(constant, reference):
        """Move `constant` to the device of `reference` (and to its dtype when that is floating point)."""
        dtype = reference.dtype if reference.is_floating_point() else constant.dtype
        return constant.to(device=reference.device, dtype=dtype)

    def _get_xyz_coords(self,illuminant,observer):
        """ Get the XYZ coordinates from illuminant and observer
        Parameters
        ==========
        illuminant : {"A","D50","D55","D65","D75","E"}
        observer : {"2","10"}
        Returns
        ==========
        XYZ coordinate Tensor Float (shape (3,), created on the CPU)
        Raises
        ==========
        ValueError : if the illuminant/observer combination is not in the table
        """
        try:
            return torch.tensor(self.illuminants[illuminant][observer]).float()
        except KeyError:
            raise ValueError("Unknown illuminat:'{}'/observer:'{}' combination".format(illuminant,observer))

    def _check_shape(self,tensor):
        """Raise ValueError unless `tensor` is a single (3,height,width) image."""
        if tensor.dim() != 3 or tensor.shape[0] != 3:
            raise ValueError("Input array must have (3,height,width)")

    def xyz2rgb(self,xyz_tensor,show_results=False):
        """XYZ to RGB color space conversion.
        Parameters
        ==========
        xyz_tensor : shape -> (3,height,width) Tensor
        show_results : whether to display the resulting rgb image
        Returns
        ==========
        rgb_tensor : shape -> (3,height,width) Tensor, clamped to [0,1]
        """
        xyz = xyz_tensor.permute(1,2,0)
        rgb_from_xyz = self._match(torch.inverse(torch.tensor(self._XYZ_FROM_RGB)),xyz)
        # matmul allocates a new tensor, so the masked writes below never touch the input.
        rgb = torch.matmul(xyz,rgb_from_xyz.t())
        # Linear RGB -> gamma-encoded sRGB (piecewise transfer function).
        mask = rgb > 0.0031308
        rgb[mask] = 1.055 * torch.pow(rgb[mask],1/2.4) - 0.055
        rgb[~mask] *= 12.92
        rgb = torch.clamp(rgb,0,1)
        rgb = rgb.permute(2,0,1)
        if show_results:
            rgb_numpy = rgb.cpu().detach().numpy().transpose(1,2,0)
            plt.imshow(rgb_numpy)
            plt.show()
        return rgb

    def lab2xyz(self,lab_tensor,show_results=False,illuminant='D65',observer='2'):
        """LAB to XYZ color space conversion.
        Parameters
        ==========
        lab_tensor : shape -> (3,height,width) Tensor
        show_results : whether to display the resulting xyz image
        illuminant : reference white name, see `_get_xyz_coords`
        observer : observer angle, see `_get_xyz_coords`
        Returns
        ==========
        xyz_tensor : shape -> (3,height,width) Tensor
        """
        l,a,b = lab_tensor[0],lab_tensor[1],lab_tensor[2]
        y = (l+16.)/116.
        x = (a / 500.) + y
        z = y - (b / 200.)

        # stack allocates a new tensor, so the masked writes do not reach lab_tensor.
        xyz = torch.stack([x,y,z],dim=0)
        # Invert the piecewise function applied in xyz2lab.
        mask = xyz > 0.2068966
        xyz[mask] = torch.pow(xyz[mask],3.)
        xyz[~mask] = (xyz[~mask] - 16. / 116.) / 7.787

        # Scale each channel by the reference white of the chosen illuminant.
        xyz_ref_white = self._match(self._get_xyz_coords(illuminant,observer),xyz)
        xyz = xyz * xyz_ref_white.view(3,1,1)
        if show_results:
            xyz_numpy = xyz.cpu().detach().numpy().transpose(1,2,0)
            plt.imshow(xyz_numpy)
            plt.show()
        return xyz

    def lab2rgb(self,lab_tensor,show_results_xyz=False,show_results_rgb=False):
        """LAB to RGB color space conversion for a batch.
        Parameters
        ==========
        lab_tensor : shape -> (batch,3,height,width) Tensor
        show_results_xyz : whether to display the resulting xyz image
        show_results_rgb : whether to display the resulting rgb image

        Returns
        ==========

        rgb_tensor : shape -> (batch,3,height,width) Tensor

        """
        if lab_tensor.shape[0] == 0:
            # Nothing to convert: return an empty batch instead of failing in stack.
            return lab_tensor.new_empty(lab_tensor.shape)
        results = []
        for i in range(lab_tensor.shape[0]):
            xyz = self.lab2xyz(lab_tensor[i],show_results_xyz)
            rgb = self.xyz2rgb(xyz,show_results_rgb)
            results.append(rgb)
        return torch.stack(results)

    def rgb2xyz(self,rgb_tensor,show_results=False):
        """RGB to XYZ color space conversion.
        Parameters
        ==========
        rgb_tensor : shape -> (3,height,width) Tensor (left unmodified)
        show_results : whether to display the resulting xyz image
        Returns
        ==========
        xyz_tensor : shape -> (3,height,width) Tensor
        Raises
        ==========
        ValueError : if rgb_tensor is not (3,height,width)
        what is xyz_tensor?
        -------------------
            -> https://www.dic-color.com/knowledge/xyz.html 
        """
        self._check_shape(rgb_tensor) #must have input shape {3,height,width}
        # permute() only returns a view; clone so the masked writes below cannot
        # overwrite the caller's image (or the batch it was sliced from).
        rgb = rgb_tensor.permute(1,2,0).clone()
        # Gamma-encoded sRGB -> linear RGB (piecewise transfer function).
        mask = rgb > 0.04045
        rgb[mask] = torch.pow((rgb[mask] + 0.055)/1.055,2.4)
        rgb[~mask] /= 12.92
        xyz_from_rgb = self._match(torch.tensor(self._XYZ_FROM_RGB),rgb)
        xyz = torch.matmul(rgb,xyz_from_rgb.t())
        if show_results: # show matplotlib
            xyz_numpy = xyz.cpu().detach().numpy()
            plt.imshow(xyz_numpy)
            plt.show()

        xyz = xyz.permute(2,0,1)
        return xyz

    def xyz2lab(self,xyz_tensor,show_results=False,illuminant='D65',observer='2'):
        """XYZ to CIE-LAB color space conversion.
        Parameters
        ==========
        xyz_tensor : shape -> (3,height,width) Tensor
        show_results : whether to display the resulting lab image
        illuminant : reference white name, see `_get_xyz_coords`
        observer : observer angle, see `_get_xyz_coords`
        Returns
        ==========
        lab_tensor : shape -> (3,height,width) Tensor
        
        what is lab_tensor?
        -------------------
            -> http://rysys.co.jp/dpex/help_laboutput.html 

        """
        xyz = xyz_tensor.permute(1,2,0)

        xyz_ref_white = self._match(self._get_xyz_coords(illuminant,observer),xyz)
        # The division allocates a new tensor, so the masked writes stay local.
        xyz = xyz / xyz_ref_white

        mask = xyz > 0.008856
        xyz[mask] = torch.pow(xyz[mask],1/3)
        xyz[~mask] = 7.787 * xyz[~mask] + 16. / 116.
        x,y,z = xyz[...,0],xyz[...,1],xyz[...,2]
        L = (116. * y) - 16.
        a = 500. * (x - y)
        b = 200. * (y - z)
        lab = torch.stack([L,a,b],dim=-1)
        if show_results:
            lab_numpy = lab.cpu().detach().numpy()
            plt.imshow(lab_numpy)
            plt.show()

        lab = lab.permute(2,0,1)
        return lab

    def forward(self,rgb_tensor,show_xyz_results=False,show_lab_results=False):
        """RGB to LAB conversion for a batch.
        Parameters
        ==========
        rgb_tensor : shape -> (batch,3,height,width) Tensor (left unmodified)
        show_xyz_results : whether to display each intermediate xyz image
        show_lab_results : whether to display each resulting lab image
        Returns
        ==========
        lab_tensor : shape -> (batch,3,height,width) Tensor
        """
        if rgb_tensor.shape[0] == 0:
            # Nothing to convert: return an empty batch instead of failing in stack.
            return rgb_tensor.new_empty(rgb_tensor.shape)
        results = []
        for i in range(rgb_tensor.shape[0]):
            xyz = self.rgb2xyz(rgb_tensor[i],show_xyz_results)
            lab = self.xyz2lab(xyz,show_lab_results)
            results.append(lab)
        return torch.stack(results)


## Parts to use in CWAN_L and CWAN_AB

In [19]:
class MemoryBlock(nn.Module):
    """Chain of residual blocks whose outputs are fused with earlier memories.

    `forward` runs `x` through every residual block in turn, concatenates all
    intermediate outputs with the tensors already in `ys` along dim 1, and
    reduces the result to `channels` feature maps with a 1x1 gate convolution.

    Parameters
    ==========
    channels : feature channels of the input, of every residual block and of the output
    num_resblock : number of residual blocks in the recursive unit
    num_memblock : number of tensors expected in `ys` when `forward` is called
    """
    def __init__(self,channels,num_resblock,num_memblock):
        super().__init__()
        blocks = [ResidualBlock(channels) for _ in range(num_resblock)]
        self.recursive_unit = nn.ModuleList(blocks)
        gate_in_channels = (num_resblock + num_memblock) * channels
        self.gate_unit = ReLUConv(gate_in_channels,channels,1,1,0)

    def forward(self,x,ys):
        """Return the gated output and append it to `ys` (the list is mutated in place)."""
        short_term = []
        for block in self.recursive_unit:
            x = block(x)
            short_term.append(x)
        fused = self.gate_unit(torch.cat(short_term + ys,1))
        # The caller's list doubles as the memory handed to the next block.
        ys.append(fused)
        return fused

class ResidualBlock(nn.Module):
    """Two BN-ReLU-Conv layers with an identity skip connection.

    Parameters
    ==========
    channels : input and output channels of both convolutions
    k, s, p : kernel size, stride and padding handed to both convolutions
    """
    def __init__(self,channels,k=3,s=1,p=1):
        super().__init__()
        self.relu_conv1 = BNReLUConv(channels,channels,k,s,p)
        self.relu_conv2 = BNReLUConv(channels,channels,k,s,p)

    def forward(self,x):
        """Return conv2(conv1(x)) + x."""
        return self.relu_conv2(self.relu_conv1(x)) + x

class ReLUConv(nn.Sequential):
    """ReLU followed by a bias-free 2-D convolution.

    Parameters
    ==========
    in_channels : channels of the input
    channels : channels produced by the convolution
    k, s, p : kernel size, stride and padding of the convolution
    inplace : forwarded to nn.ReLU; when True the ReLU overwrites its input tensor
    """
    def __init__(self,in_channels,channels,k=3,s=1,p=1,inplace=True):
        super().__init__()
        # Assigning sub-modules as attributes registers them in order as 'relu' then 'conv'.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv = nn.Conv2d(in_channels,channels,kernel_size=k,stride=s,padding=p,bias=False)

class BNReLUConv(nn.Sequential):
    """BatchNorm2d, then ReLU, then a bias-free 2-D convolution.

    Parameters
    ==========
    in_channels : channels of the input (and of the batch norm)
    channels : channels produced by the convolution
    k, s, p : kernel size, stride and padding of the convolution
    inplace : forwarded to nn.ReLU
    """
    def __init__(self,in_channels,channels,k=3,s=1,p=1,inplace=True):
        super().__init__()
        # Assigning sub-modules as attributes registers them in order as 'bn', 'relu', 'conv'.
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv = nn.Conv2d(in_channels,channels,kernel_size=k,stride=s,padding=p,bias=False)

## CWAN_L

In [20]:
class CWAN_L(nn.Module):
    """ 'L' of LAB's model.
    motivation
    ==========
    focus on enhancing image lightness and denoising
    name
    ====
    k -> kernel_size
    n -> output channels size
    x -> repeat number
    parameters
    =========
    k3n32 -> memory blocks. Each block uses local short skip connections within the block to represent short-term memory, as well as long skip connections sourced from previous blocks to represent long-term memory.
    returns
    =======
    enhanced lightness image (1xHxW)

    """
    def __init__(self):
        super().__init__()
        # inplace=False: this ReLU is applied directly to the module input, so an
        # in-place ReLU would overwrite the caller's tensor and the `residual`
        # alias that forward() adds back at the end.
        self.feature_extractor = ReLUConv(1,32,inplace=False)
        self.k3n1 = nn.Sequential(
                nn.Conv2d(32,1,(3,3),stride=1,padding=1)
        )
        # Block i receives i+1 earlier memories: the extracted features plus
        # the outputs of all previous memory blocks.
        self.memory_blocks = nn.ModuleList(
                [MemoryBlock(32,3,i+1) for i in range(3)]
        )

    def forward(self,l):
        """Enhance a batch of lightness channels.

        l : shape -> (batch,1,height,width) Tensor (left unmodified)
        returns a tensor of the same shape: the input plus a learned residual.
        """
        residual = l
        out = self.feature_extractor(l)
        # `ys` is the long-term memory; every memory block appends its output to it.
        ys = [out]
        for memory_block in self.memory_blocks:
            out = memory_block(out,ys)
        out = self.k3n1(out)
        out = out + residual
        return out

## CWAN_AB

In [24]:
class CWAN_AB(nn.Module):
    """ 'AB' of LAB's model.
    motivation
    ==========
    color information drives the attention of CWAN_AB
    name
    ====
    k -> kernel_size
    n -> output channels size
    parameters
    ==========
    returns
    =======
    1.enhanced color images(2xHxW)
    2.color attention maps(2xHxW)
    3.sparse attention points(2xHxW)
    """
    def __init__(self):
        super().__init__()

        self.k3n32_1 = nn.Sequential(
                nn.Conv2d(2,32,(3,3),stride=1,padding=1),
                nn.ReLU()
        )
        self.k3n32_2 = nn.Sequential(
                nn.Conv2d(2,32,(3,3),stride=1,padding=1),
                nn.ReLU()
        )
        # Build one instance per stage. Assigning the same nn.Sequential object
        # to both attributes made the two stages share every weight. The
        # state_dict keys and shapes are unchanged, so checkpoints saved with
        # the shared layout still load.
        self.k3n64_k1n128_k3n64_1 = self._make_k3n64_k1n128_k3n64()
        self.k3n64_k1n128_k3n64_2 = self._make_k3n64_k1n128_k3n64()
        # NOTE(review): the 1x1 convolution below uses padding=1, which grows the
        # feature map by one pixel on every side; the unpadded 3x3 `k3n4` shrinks
        # it back, so the output size matches the input. Left as is because
        # changing either padding would alter the outputs of trained weights.
        self.k3n64_k1n128_k3n64_3 = nn.Sequential(
                nn.Conv2d(64,64,(3,3),stride=1,padding=1),
                nn.ReLU(),
                nn.Conv2d(64,128,(1,1),stride=1,padding=1),
                nn.ReLU(),
                nn.Conv2d(128,64,(3,3),stride=1,padding=1),
                nn.ReLU()
        )
        self.k3n2 = nn.Sequential(
                nn.Conv2d(64,2,(3,3),stride=1,padding=1)
        )
        self.k3n4 = nn.Sequential(
                nn.Conv2d(64,4,(3,3))
        )

    @staticmethod
    def _make_k3n64_k1n128_k3n64():
        """Create a fresh 32 -> 64 -> 128 -> 64 channel conv/ReLU stack with its own weights."""
        return nn.Sequential(
                nn.Conv2d(32,64,(3,3),stride=1,padding=1),
                nn.ReLU(),
                nn.Conv2d(64,128,(1,1),stride=1,padding=0),
                nn.ReLU(),
                nn.Conv2d(128,64,(3,3),stride=1,padding=1),
                nn.ReLU()
        )

    def forward(self,ab):
        """Enhance a batch of AB colour channels.

        ab : shape -> (batch,2,height,width) Tensor
        returns (enhance_ab, attention_map, attention_points), each (batch,2,height,width)
        """
        residual = ab
        # Stage 1: predict the colour attention map from the raw AB channels.
        k3n32_1_output = self.k3n32_1(ab)
        k3n64_k1n128_k3n64_1_output = self.k3n64_k1n128_k3n64_1(k3n32_1_output)
        k3n2_output = self.k3n2(k3n64_k1n128_k3n64_1_output)
        attention_map = k3n2_output
        # Stage 2: process the input after the attention map has been added to it.
        k3n32_2_output = self.k3n32_2(residual + k3n2_output)
        k3n64_k1n128_k3n64_2_output = self.k3n64_k1n128_k3n64_2(k3n32_2_output)
        k3n64_k1n128_k3n64_3_output = self.k3n64_k1n128_k3n64_3(k3n64_k1n128_k3n64_2_output)
        k3n4_output = self.k3n4(k3n64_k1n128_k3n64_3_output)
        # Of the four output channels, the first two are the colour residual and
        # the last two are the attention points.
        attention_points = k3n4_output[:,2:]
        enhance_ab = residual + k3n4_output[:,:2]
        return enhance_ab,attention_map,attention_points

# CWAN Network

In [25]:
class CWAN(nn.Module):
    """Full network: converts RGB to LAB, then enhances L and AB with separate sub-networks.

    Attributes
    ==========
    cwan_l : CWAN_L, applied to the L channel
    cwan_ab : CWAN_AB, applied to the A and B channels
    lab_converter : LAB, the RGB <-> LAB converter
    """
    def __init__(self):
        super().__init__()
        self.cwan_l = CWAN_L()
        self.cwan_ab = CWAN_AB()
        self.lab_converter = LAB()

    def lab2rgb(self,lab):
        """ LAB tensor to RGB tensor
        Parameter
        =========
        lab : LAB image tensor
        Returns
        =========
        rgb : RGB image tensor
        """
        rgb = self.lab_converter.lab2rgb(lab)
        return rgb

    def l_test(self,tensor):
        """Run only the lightness branch on an RGB batch and return the enhanced L channel."""
        lab = self.lab_converter(tensor)
        l = lab[:,:1]
        l_output = self.cwan_l(l)
        return l_output

    def ab_test(self,tensor):
        """Run only the colour branch on an RGB batch and return the enhanced AB channels."""
        lab = self.lab_converter(tensor)
        # Slice with `1:` to keep both the A and B channels and the channel axis;
        # `lab[:,1]` selected only channel A and dropped a dimension.
        ab = lab[:,1:]
        ab_output,_,_ = self.cwan_ab(ab)
        return ab_output

    def forward(self,tensor):
        """Enhance an RGB batch.

        Parameter
        =========
        tensor : RGB image tensor, shape -> (batch,3,height,width)
        Returns
        =========
        generated_image : enhanced LAB image (L, A, B channels concatenated)
        attention_map, attention_points : auxiliary outputs of CWAN_AB
        l_output, ab_output : the separate enhanced L and AB tensors
        """
        lab = self.lab_converter(tensor)
        l,ab = lab[:,:1],lab[:,1:]
        l_output = self.cwan_l(l)
        ab_output,attention_map,attention_points = self.cwan_ab(ab)
        # Concatenate along the channel axis. This keeps the result on the same
        # device as the sub-network outputs and avoids filling a random CPU
        # buffer just to overwrite it.
        generated_image = torch.cat([l_output,ab_output],dim=1)
        return generated_image,attention_map,attention_points,l_output,ab_output

# Training

In [28]:
#training network
cwan = CWAN()
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only runtime (and the reverse).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
if _CWAN_L_PATH is not None:
  cwan.cwan_l.load_state_dict(torch.load(_CWAN_L_PATH,map_location=device))
if _CWAN_AB_PATH is not None:
  cwan.cwan_ab.load_state_dict(torch.load(_CWAN_AB_PATH,map_location=device))
cwan = cwan.train().to(device)
# Stand-alone converter, kept in eval mode.
lab = LAB()
lab = lab.eval().to(device)

In [29]:
# Training setup: Adam updates only the colour branch (cwan_ab) parameters.
optimizer = torch.optim.Adam(
    cwan.cwan_ab.parameters(),
    lr=_LR,
    weight_decay=_WEIGHT_DECAY,
)
# Loss criteria: L1 and mean-squared error.
loss_func = nn.L1Loss()
loss_mse_func = nn.MSELoss()

In [30]:
long_dic = {}
# Empty containers for the loss values recorded during training.
loss_list = []
loss_map_list = []
loss_huber_list = []
loss_mse_list = []

In [None]:
for e in tqdm(range(_START_EPOCH,_EPOCH)):
  print("now {} epoch".format(e))
  print("+++++++++++++++++++++++++")