In [None]:
# for correct pytorch version
# import pkg_resources
# pkg_resources.require("torch==1.11.0")
# pkg_resources.require("torchmetrics==0.11.0")

# main net
import torch
import torch.nn as nn

import numpy as np
# ====================================== #

# for dataset
from os import listdir, sep

from torchvision import transforms
# from torchvision.io import read_image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# loading raw image
from skimage import io

# demosaicing raw image
from colour_demosaicing import demosaicing_CFA_Bayer_bilinear
# ====================================== #

# for loss and optimizer
import torch.optim as optim
from kornia.color import rgb_to_lab
from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure as MSSSIM
# ====================================== #

# quality of life
from tqdm import tqdm
# ====================================== #

In [None]:
class DeepispLL(nn.Module):
    """One low-level DeepISP block operating on a 64-channel feature map.

    The first 61 channels go through conv + ReLU, the last 3 channels through
    conv + tanh, and the two halves are concatenated back into 64 channels.

    Args:
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(self, kernel=(3,3), stride=1, padding=1):
        super(DeepispLL, self).__init__()

        self.padding = padding
        self.kernel = kernel
        self.stride = stride

        # Layers are created once here so their weights are registered as
        # parameters (visible to the optimizer and state_dict) and persist
        # across calls. Building them inside forward() gave fresh random,
        # untrainable weights on every call.
        self.conv_rh = nn.Conv2d(61, 61, kernel_size=kernel, stride=stride, padding=padding)
        self.conv_lh = nn.Conv2d(3, 3, kernel_size=kernel, stride=stride, padding=padding)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Apply the block to ``x`` (64 channels on dim 1) and return 64 channels."""
        rh = self.relu(self.conv_rh(x[:, :61, :, :]))
        lh = self.tanh(self.conv_lh(x[:, 61:, :, :]))

        # TODO: a residual sum on the 3 image channels (e.g. lh += x[:, 61:])
        # is still not implemented.
        return torch.cat((rh, lh), 1)


class DeepispHL(nn.Module):
    """One high-level DeepISP block: 64->64 conv, ReLU, then 2x2 max-pool.

    Args:
        kernel: convolution kernel size.
        stride: convolution stride (default 2).
        padding: convolution padding.
    """

    def __init__(self, kernel=(3,3), stride=2, padding=1):
        super(DeepispHL, self).__init__()

        self.padding = padding
        self.kernel = kernel
        self.stride = stride

        # Registered once so the conv weights are trainable and stable across
        # calls (previously a new random conv was built on every forward).
        self.conv = nn.Conv2d(64, 64, kernel_size=kernel, stride=stride, padding=padding)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return ``pool(relu(conv(x)))`` for a 64-channel input."""
        return self.pool(self.relu(self.conv(x)))


class GlobalPool2d(nn.Module):
    """Global average pooling over the spatial axes: (N, C, H, W) -> (N, C)."""

    def __init__(self):
        super(GlobalPool2d, self).__init__()

    def forward(self, x):
        # Average over dims 2 and 3 directly. The previous version unpacked the
        # shape as (b, c, w, h) and pooled with kernel (W, H), which is wrong
        # (or raises) whenever H != W.
        return x.mean(dim=(2, 3))


def triu(rgb):
    """Quadratic colour features of the first pixel in ``rgb``.

    Returns the row-major upper triangle of the outer product of
    ``[r, g, b, 1]`` with itself, i.e.
    ``[r*r, r*g, r*b, r, g*g, g*b, g, b*b, b, 1]``, as a float32 tensor of
    shape (10,).

    Only ``rgb[0, 0]``, ``rgb[0, 1]`` and ``rgb[0, 2]`` are read; any further
    rows of a batched (N, 3) input are ignored.
    """
    r, g, b = rgb[0, 0], rgb[0, 1], rgb[0, 2]
    one = torch.ones_like(r)
    # torch.stack keeps the result on rgb's device and in the autograd graph;
    # the former NumPy-backed buffer always lived on the CPU.
    res = torch.stack((r * r, r * g, r * b, r, g * g, g * b, g, b * b, b, one))
    return res.float()


def Tform(I, W):
    """Apply a per-sample quadratic colour transform to an RGB image batch.

    Each output pixel of sample ``n`` is ``W_n @ phi(p)``, where ``p = (r, g, b)``
    is the input pixel and
    ``phi(p) = [r*r, r*g, r*b, r, g*g, g*b, g, b*b, b, 1]`` (the row-major upper
    triangle of the outer product of ``[r, g, b, 1]`` with itself).

    Args:
        I: image batch of shape (N, 3, H, W).
        W: transform coefficients with N * 30 elements (e.g. shape (N, 30)),
            read per sample as a (3, 10) matrix in row-major order.

    Returns:
        Tensor of shape (N, 3, H, W) with I's dtype and device.

    Raises:
        ValueError: if ``I`` does not have exactly 3 channels.
    """
    n, c, h, w = I.shape
    if c != 3:
        raise ValueError(f'Tform expects 3 input channels, got {c}')

    r, g, b = I[:, 0], I[:, 1], I[:, 2]
    features = torch.stack(
        (r * r, r * g, r * b, r, g * g, g * b, g, b * b, b, torch.ones_like(r)),
        dim=1,
    )  # (N, 10, H, W)
    coeffs = W.reshape(n, 3, 10).to(features.dtype)

    # Vectorised replacement for the former per-pixel Python loop. Each sample
    # uses its own coefficients; the old code reshaped W to a single (3, 10)
    # matrix, which failed for N > 1.
    return torch.einsum('nkf,nfhw->nkhw', coeffs, features)


class DeepISP(nn.Module):
    """DeepISP-style network with a low-level and a high-level branch.

    The low-level branch maps the 3-channel input to 64 channels. Its first
    61 channels feed the high-level branch, which ends in 30 coefficients.
    Those coefficients parametrise ``self.T`` (``Tform``), which is applied to
    the last 3 low-level channels.

    Args:
        n_ll: number of ``DeepispLL`` blocks in the low-level branch.
        n_hl: number of ``DeepispHL`` blocks in the high-level branch.
        stride: stride passed to both stem convolutions and to every block.
        padding: padding passed to both stem convolutions and to every block.
    """

    def __init__(self, n_ll, n_hl, stride=1, padding=1):
        super(DeepISP, self).__init__()
        self.stride = stride
        self.padding = padding

        conv_opts = dict(kernel_size=(3, 3), stride=self.stride, padding=self.padding)
        block_opts = dict(stride=self.stride, padding=self.padding)

        # Instantiation order is: both stems, low-level blocks, high-level
        # blocks, pooling, head. This keeps seeded weight initialisation
        # reproducible.
        low_stem = nn.Conv2d(3, 64, **conv_opts)
        high_stem = nn.Conv2d(61, 64, **conv_opts)
        low_blocks = [DeepispLL(**block_opts) for _ in range(n_ll)]
        # NOTE(review): passing `stride` here overrides DeepispHL's own default
        # of 2. With stride=1, each high-level block shrinks H and W by 2
        # (max-pool only) rather than by 4 (stride-2 conv + max-pool). Confirm
        # that this is intended.
        high_blocks = [DeepispHL(**block_opts) for _ in range(n_hl)]

        # 3 -> 64 channels, followed by n_ll low-level blocks.
        self.lowlevel = nn.Sequential(low_stem, *low_blocks)
        # 61 -> 64 channels, n_hl high-level blocks, global average pool to
        # (N, 64), then a linear head producing the 30 transform coefficients.
        self.highlevel = nn.Sequential(high_stem, *high_blocks, GlobalPool2d(), nn.Linear(64, 30))

        # Quadratic colour transform applied to the low-level image channels.
        self.T = Tform

    def forward(self, x):
        """Run both branches on ``x`` (assumed (N, 3, H, W)) and apply the transform."""
        features = self.lowlevel(x)
        coeffs = self.highlevel(features[:, :61, :, :])
        return self.T(features[:, 61:, :, :], coeffs)

In [None]:
class S7Dataset(Dataset):
    """Paired (raw, reference) random crops from the S7 ISP dataset.

    Each scene is a sub-directory of ``directory`` holding
    ``<target>.dng`` (raw input) and ``<target><suffix>.jpg`` (reference).
    Scenes are sorted by name and split by index into train and test.

    Args:
        directory: dataset root, containing one sub-directory per scene.
        mode: 'train' (first ``factor`` fraction of scenes) or 'test' (the rest).
        target: 'm' for medium exposure or 's' for short exposure.
        factor: fraction of scenes used for training.
        crop_size: side length of the random square crop.

    Raises:
        ValueError: on an unknown ``mode`` or ``target``.
    """

    def __init__(self, directory, mode, target, factor, crop_size):
        from os import path

        self.directory = directory

        self.raw_transform = demosaicing_CFA_Bayer_bilinear
        self.crop_size = crop_size

        self.dng = '.dng'
        self.jpg = '.jpg'

        # The listing is sorted once so that the train/test split and the
        # index -> scene mapping are reproducible (os.listdir order is
        # arbitrary). Plain files and hidden entries are skipped.
        self.scenes = sorted(
            name for name in listdir(self.directory)
            if not name.startswith('.') and path.isdir(path.join(self.directory, name))
        )
        self.l = len(self.scenes)

        if mode == 'train':
            self.len = 0, int(self.l * factor)
        elif mode == 'test':
            self.len = int(self.l * factor), self.l
        else:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

        if target == 'm':
            self.target = 'medium_exposure'
        elif target == 's':
            self.target = 'short_exposure'
            self.jpg = '1.jpg'
        else:
            raise ValueError(f"target must be 'm' or 's', got {target!r}")

    def __len__(self):
        """Number of scenes in this split."""
        return self.len[1] - self.len[0]

    def __getitem__(self, idx):
        """Return one aligned random crop as float32 (C, crop_size, crop_size) tensors.

        Returns:
            (raw, reference): the demosaiced raw crop divided by 1024, and the
            reference JPEG crop with its values unchanged.
        """
        scene = sep.join([self.directory, self.scenes[idx + self.len[0]]])

        i_img = io.imread(sep.join([scene, f'{self.target}{self.dng}']))
        o_img = io.imread(sep.join([scene, f'{self.target}{self.jpg}']))

        # NOTE(review): /1024 presumably maps 10-bit raw values to ~[0, 1]; confirm.
        i_img = self.raw_transform(i_img) / 1024

        h, w = i_img.shape[:2]
        # randint's upper bound is exclusive. The +1 makes the last valid
        # offset reachable and allows crop_size equal to the image size.
        x = np.random.randint(0, h - self.crop_size + 1)
        y = np.random.randint(0, w - self.crop_size + 1)

        # HWC -> CHW must permute the axes. A plain reshape keeps the memory
        # order and scrambles pixels across channels.
        i_img = torch.tensor(i_img[x:x + self.crop_size, y:y + self.crop_size, :]).permute(2, 0, 1)
        o_img = torch.tensor(o_img[x:x + self.crop_size, y:y + self.crop_size, :]).permute(2, 0, 1)

        # NOTE(review): the reference is not rescaled (likely 0-255 for an
        # 8-bit JPEG), while the input is. kornia's rgb_to_lab, used by the
        # loss, expects RGB in [0, 1]; confirm whether o_img should be divided
        # by 255.

        return i_img.float().contiguous(), o_img.float().contiguous()


def get_data(data_path, batch_size, target='m', factor=0.7, crop_size=256):
    """Build the train and test DataLoaders over the S7 dataset at ``data_path``.

    Both splits share the directory, target exposure, split factor and crop
    size. Only the training loader shuffles.

    Returns:
        (train_loader, test_loader)
    """
    shared = dict(directory=data_path, target=target, factor=factor, crop_size=crop_size)

    # Both datasets are built first, then both loaders, matching the original order.
    datasets = {mode: S7Dataset(mode=mode, **shared) for mode in ('train', 'test')}

    train_loader = DataLoader(datasets['train'], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(datasets['test'], batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
class deepISPloss():
    """DeepISP training loss: a weighted Lab L1 term plus an MS-SSIM term on L.

    ``loss = (1 - alpha) * mean|Lab(x) - Lab(target)|
             + alpha * (1 - MS-SSIM(L(x), L(target)))``

    Args:
        alpha: weight of the MS-SSIM term.
    """

    def __init__(self, alpha=0.5):
        self.alpha = alpha
        self.MSSSIM = MSSSIM()

    def __call__(self, x, target):
        """Return the scalar loss for RGB batches ``x`` and ``target``.

        Both are converted with kornia's ``rgb_to_lab`` (assumed (N, 3, H, W)).
        """
        lab_x = rgb_to_lab(x).float()
        lab_tar = rgb_to_lab(target).float()

        l1 = torch.mean(torch.abs(lab_x - lab_tar))
        # Only the first (L) channel is passed to MS-SSIM. MS-SSIM is a
        # similarity (1 = identical), so a minimised loss must use
        # 1 - MS-SSIM; adding it directly would reward dissimilar images.
        ssim = self.MSSSIM(lab_x[:, :1, :, :], lab_tar[:, :1, :, :])

        return (1 - self.alpha) * l1 + self.alpha * (1 - ssim)

In [None]:
# NOTE(review): absolute, machine-specific dataset location; consider making it configurable.
data_path = '/home/jupyter/mnt/datasets/S7Dataset/S7-ISP-Dataset'

# One scene per batch; the default split factor, target and crop size are used.
train, test = get_data(data_path=data_path, batch_size=1)

print(f'train batch number {len(train)}')
print(f'test  batch number {len(test)}')

In [None]:
# Training hyper-parameters.
e = 1
lr = 0.01
momentum = 0.9

make_checkpoints = True
checkpoint_path = '/home/jupyter/work/resources/deepISP-implementation/checkp'

epochs = list(range(e))

model = DeepISP(0, 0).float()
criterion = deepISPloss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

test_loss = 0

print('Starting trainig...')

for epoch in epochs:
    model.train()
    train_iter = tqdm(train, ncols=100, desc='Epoch: {}, training'.format(epoch))
    for (x, target) in train_iter:
        optimizer.zero_grad()
        y = model(x.float())
        loss = criterion(y, target)

        # Without backward(), the gradients stay empty and step() is a no-op.
        loss.backward()
        optimizer.step()
        train_iter.set_postfix(str=f'loss: {loss.item()}')

    # Evaluation: reset the running loss each epoch, skip autograd
    # bookkeeping, and accumulate plain floats. Summing loss tensors would
    # keep every computation graph alive.
    model.eval()
    test_loss = 0.0
    test_iter = tqdm(test, ncols=128, desc='Epoch: {}, testing '.format(epoch))
    with torch.no_grad():
        for idx, (x, target) in enumerate(test_iter):
            y = model(x.float())
            loss = criterion(y, target)
            test_loss += loss.item()
            test_iter.set_postfix(str=f'loss: {test_loss / (idx + 1)}')
    test_loss /= max(len(test_iter), 1)

print('Training done!')

In [None]:
# Save the final model/optimizer state, controlled by `make_checkpoints`.
if make_checkpoints:
    import os

    os.makedirs(checkpoint_path, exist_ok=True)
    # float() accepts both a Python number and a 0-d tensor and keeps the
    # filename free of a tensor repr.
    final_loss = float(test_loss)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': final_loss,
    }, checkpoint_path + '/model_e{}_loss{}'.format(epoch, final_loss))