# 0. Set up

In [None]:
!nvidia-smi

In [None]:
#import 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import rawpy
from tqdm import tqdm as pbar
import copy
from livelossplot import PlotLosses
import matplotlib.pyplot as plt
import seaborn
import cv2
seaborn.set()
import scipy
import albumentations as A
import cv2
import numpy as np
from PIL import Image

In [None]:
# Root folder of the dataset. The list-file names below are appended to it
# with plain string concatenation (see read_file_list), hence the leading "/".
data_path = 'dataset'
train_path = '/Sony_train_list.txt'
test_path = '/Sony_test_list.txt'
val_path = '/Sony_val_list.txt'
# Reproducibility switches (currently disabled).
# np.random.seed(0)
# torch.manual_seed(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# 1. Preprocess raw data from camera sensor

![](figures/3a.png)

Pack the raw Bayer sensor data into 4 channels (R-G-B-G). Doing so also reduces the spatial resolution by a factor of 2 in each dimension.

## 1.1 Pack raw is used for input 

In [None]:
def pack_raw(raw, black_level=512, white_level=16383):
    """
    Pack a raw Bayer mosaic into four half-resolution planes.

    Input:
        raw: object returned from rawpy.imread(); only its
             ``raw_image_visible`` attribute (a 2-D array) is read
        black_level: value subtracted from every sample before scaling
        white_level: value that maps to 1.0 after scaling
    Output: float32 numpy array in shape (H // 2, W // 2, 4), i.e.
            (1424, 2128, 4) for a (2848, 4256) frame. Channel order is
            red, green_1, blue, green_2, taken from mosaic positions
            (0, 0), (0, 1), (1, 1) and (1, 0) of every 2x2 cell.
    Raises: ValueError if white_level is not greater than black_level
    """
    if white_level <= black_level:
        raise ValueError('white_level must be greater than black_level')

    im = raw.raw_image_visible.astype(np.float32) # shape of (H, W)
    # Subtract the black level (clipping at 0) and normalise by the usable range
    im = np.maximum(im - black_level, 0) / (white_level - black_level)

    # Drop a trailing odd row/column so that all four planes have the same
    # shape; without this an odd-sized mosaic cannot be stacked.
    H = im.shape[0] - im.shape[0] % 2
    W = im.shape[1] - im.shape[1] % 2

    # One plane per position inside the 2x2 Bayer cell
    red = im[0:H:2, 0:W:2]
    green_1 = im[0:H:2, 1:W:2]
    blue = im[1:H:2, 1:W:2]
    green_2 = im[1:H:2, 0:W:2]

    # Final shape: (H // 2, W // 2, 4)
    return np.stack((red, green_1, blue, green_2), axis=2)

In [None]:
# x_img = rawpy.imread(data_path + '/Sony/short/00001_00_0.04s.ARW')
# x_img = pack_raw(x_img)
# x_img.shape

## 1.2 Post process is used for ground truth

In [None]:
def post_process(raw, out_size=(2128, 1424)):
    """
    Develop a raw file into a normalised image used as ground truth.

    Input:
        raw: object returned from rawpy.imread()
        out_size: (width, height) handed to cv2.resize. The default gives
                  the 1424 x 2128 size that pack_raw documents for its output.
    Output: float32 numpy array in shape (out_size[1], out_size[0], 3),
            i.e. (1424, 2128, 3) by default, scaled by 1 / 65535
    """
    max_output = 65535.0  # largest value of the 16-bit output requested below
    im = raw.postprocess(use_camera_wb=True, no_auto_bright=True, output_bps=16)
    im = (im / max_output).astype(np.float32)
    # cv2.resize takes its target size as (width, height), not (rows, cols)
    im = cv2.resize(im, tuple(out_size), interpolation = cv2.INTER_AREA)
    return im

In [None]:
# y_img = rawpy.imread(data_path + '/Sony/long/00001_00_10s.ARW')
# y_img = post_process(y_img)
# y_img.shape

## 1.3 Batch process all data

**Files' name explanation**

The file lists are provided. In each row, there are a short-exposed image path, the corresponding long-exposed image path, camera ISO and F number. 
Note that multiple short-exposed images may correspond to the same long-exposed image.

The file name contains the image information. For example, in "10019_00_0.033s.RAF":
- the first digit "1" means it is from the test set ("0" for training set and "2" for validation set)
- "0019" is the image ID
- the following "00" is the number in the sequence/burst
- "0.033s" is the exposure time 1/30 seconds.

There is some misalignment with the ground truth for images 10034, 10045 and 10172. I've removed those images for quantitative results, but they can still be used for qualitative evaluations.

In [None]:
def read_file_list(file_list):
    """
    Input: name of a list file, appended as-is to the global data_path
    Output: Pandas dataframe with one row per line of the file and the
            columns 'X', 'Y', 'ISO' and 'F-stop'
    """
    column_names = ['X', 'Y', 'ISO', 'F-stop']
    list_file = data_path + file_list
    # The list files are space separated and have no header row
    return pd.read_csv(list_file, sep=" ", header=None, names=column_names)

In [None]:
train_list = read_file_list('/Sony_train_list.txt')
train_list.head()

In [None]:
train_list['Y'].value_counts().hist()

In [None]:
def batch_process_raw(data, hide_progree=False):
    """
    Load every image pair listed in ``data`` into memory.

    Input:
        data: Pandas dataframe returned from read_file_list. The first
              column holds the short-exposure paths and the second the
              long-exposure paths, each starting with a "." that is dropped.
        hide_progree: when True the progress bars are disabled (the
                      parameter name is kept as-is for existing callers)
    Output: a dictionary of
            X : amplified numpy array, shape (len(data), 1424, 2128, 4)
            Y : numpy array, shape (number of distinct Y paths, 1424, 2128, 3)
            X_Y_map: int32 numpy array of shape (len(data), 2); each row is
                     (position in X, position in Y) of a corresponding pair
    """

    # Multiple Xs can have the same Y
    m_x = len(data)
    # Computed once: the distinct ground-truth paths in order of appearance
    unique_y = data['Y'].unique()
    m_y = len(unique_y)

    X = np.zeros((m_x, 1424, 2128, 4), dtype=np.float32)
    Y = np.zeros((m_y, 1424, 2128, 3), dtype=np.float32)

    # Mapping of X to Y
    X_map = []
    Y_map = []

    for i in pbar(range(m_x), disable=hide_progree):
        # Purely positional access: row i, columns 0 and 1
        x_path = data.iloc[i, 0][1:] # remove the "." in the name
        y_path = data.iloc[i, 1][1:] # remove the "." in the name

        # Shutter speed is in the file name
        x_shutter_speed = x_path.split('_')[-1].split('s.')[0]
        y_shutter_speed = y_path.split('_')[-1].split('s.')[0]
        amp_ratio = float(y_shutter_speed)/float(x_shutter_speed)

        # Close the raw file as soon as it has been packed
        with rawpy.imread(data_path + x_path) as raw:
            X[i] = pack_raw(raw) * amp_ratio

    y_column = data['Y'].to_numpy()
    for i in pbar(range(m_y), disable=hide_progree):
        current_y = unique_y[i]

        y_path = current_y[1:]
        with rawpy.imread(data_path + y_path) as raw:
            Y[i] = post_process(raw)

        # Positions (not index labels) of the rows that share this ground
        # truth, so the map stays valid for dataframes whose index does not
        # start at 0.
        X_map_temp = np.flatnonzero(y_column == current_y).tolist()
        X_map += X_map_temp
        Y_map += [i]*len(X_map_temp)

    X_Y_map = np.array((X_map, Y_map), dtype=np.int32).T
    dataset = {'X':X, 'Y':Y, 'X_Y_map':X_Y_map}

    return dataset

In [None]:
# Smoke test of the loader on the first 10 rows only, with progress bars hidden.
# NOTE: the name train_dataset is rebound further down to a SeeInTheDarkDataset.
train_dataset = batch_process_raw(train_list.head(10), True)
print("Shape of X_train:", train_dataset['X'].shape)
print("Shape of Y_train:", train_dataset['Y'].shape)
print("Shape of X_Y_map_train:", train_dataset['X_Y_map'].shape)

In [None]:
train_dataset['X_Y_map']

# 3. Make batches of image patches for training, validation and testing

In [None]:
def numpy_to_torch(image):
    """
    Input: numpy array laid out as (H x W x C)
    Output: torch tensor laid out as (C x H x W), built with
            torch.from_numpy on the transposed array
    """
    # Move the channel axis to the front, then wrap the result as a tensor
    channels_first = image.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)

In [None]:
#dataset class
class SeeInTheDarkDataset(Dataset):
    """
    Dataset of (input, ground truth) pairs built on the dictionary returned
    by batch_process_raw.

    Every item is a pair of float tensors in (C, H, W) layout clamped to
    [0, 1], whether or not a transform is configured.
    """

    def __init__(self, dataset = None, transform = None):
        # dataset: dictionary with the keys 'X', 'Y' and 'X_Y_map'
        # transform: optional callable invoked as transform(image=X, mask=Y)
        #            that returns a dictionary with the keys 'image' and 'mask'
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        # One sample per row of the X -> Y mapping
        return self.dataset['X_Y_map'].shape[0]

    @staticmethod
    def _to_clamped_tensor(image):
        """Turn an (H, W, C) array into a (C, H, W) tensor clamped to [0, 1]."""
        # A contiguous copy keeps torch.from_numpy happy whatever strides the
        # transform produced, and leaves the stored arrays untouched.
        channels_first = np.ascontiguousarray(image.transpose((2, 0, 1)))
        return torch.clamp(torch.from_numpy(channels_first), min=0.0, max=1.0)

    def __getitem__(self, i):
        x_index, y_index = self.dataset['X_Y_map'][i][0], self.dataset['X_Y_map'][i][1]

        X, Y = self.dataset['X'][x_index], self.dataset['Y'][y_index]

        if self.transform:
            # Input and target go through the pipeline together
            transformed = self.transform(image=X, mask=Y)
            X = transformed['image']
            Y = transformed['mask']

        # Conversion and clamping happen on every path, so callers always get
        # the same kind of item.
        return self._to_clamped_tensor(X), self._to_clamped_tensor(Y)
    

# 4. Model architecture

![](figures/atten_Unet_GAN.png)

### Generator 

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    """
    Re-initialise the layers of ``net`` in place.

    Modules whose class name contains 'Conv' or 'Linear' and that have a
    ``weight`` get it drawn with the scheme named by init_type ('normal',
    'xavier', 'kaiming' or 'orthogonal'); their bias, when present, is set
    to 0. Modules whose class name contains 'BatchNorm2d' get weight
    ~ N(1, gain) and bias 0. An unknown init_type raises NotImplementedError
    as soon as a Conv/Linear module is visited.
    """
    weight_initializers = {
        'normal': lambda weight: init.normal_(weight, 0.0, gain),
        'xavier': lambda weight: init.xavier_normal_(weight, gain=gain),
        'kaiming': lambda weight: init.kaiming_normal_(weight, a=0, mode='fan_in'),
        'orthogonal': lambda weight: init.orthogonal_(weight, gain=gain),
    }

    def init_func(m):
        layer_name = type(m).__name__
        is_conv_or_linear = 'Conv' in layer_name or 'Linear' in layer_name
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type not in weight_initializers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_initializers[init_type](m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in layer_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class conv_block(nn.Module):
    """Two 3x3 convolutions (stride 1, padding 1), each followed by BatchNorm2d and ReLU."""

    def __init__(self,ch_in,ch_out):
        super().__init__()
        stages = []
        # First stage maps ch_in -> ch_out, second keeps ch_out
        for stage_in in (ch_in, ch_out):
            stages.append(nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True))
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self,x):
        return self.conv(x)

class up_conv(nn.Module):
    """2x upsampling followed by a 3x3 convolution, BatchNorm2d and ReLU."""

    def __init__(self,ch_in,ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self,x):
        return self.up(x)

class Recurrent_block(nn.Module):
    """
    Recurrent convolution: one 3x3 conv + BatchNorm2d + ReLU unit that is
    applied repeatedly with shared weights.

    The unit is first applied to the input, then ``t`` more times to
    (input + previous result), i.e. t + 1 applications in total.

    Raises ValueError at construction when t < 1; previously such a block
    could be built but failed inside forward() with an UnboundLocalError.
    """

    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        if t < 1:
            raise ValueError('t must be at least 1, got %r' % (t,))
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        # Initial pass on the raw input
        x1 = self.conv(x)
        # t recurrent passes, each re-injecting the input
        for _ in range(self.t):
            x1 = self.conv(x+x1)
        return x1
        
class RRCNN_block(nn.Module):
    """
    Residual recurrent block: a 1x1 convolution maps ch_in to ch_out, two
    stacked Recurrent_block units process the result, and the projected
    input is added back to their output.
    """

    def __init__(self,ch_in,ch_out,t=2):
        super().__init__()
        recurrent_units = [Recurrent_block(ch_out, t=t) for _ in range(2)]
        self.RCNN = nn.Sequential(*recurrent_units)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined


class single_conv(nn.Module):
    """One 3x3 convolution (stride 1, padding 1) followed by BatchNorm2d and ReLU."""

    def __init__(self,ch_in,ch_out):
        super().__init__()
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self,x):
        return self.conv(x)

class Attention_block(nn.Module):
    """
    Attention gate. The gating signal ``g`` (F_g channels) and the features
    ``x`` (F_l channels) are each projected to F_int channels by a 1x1
    convolution + BatchNorm2d, summed and passed through ReLU. A further
    1x1 convolution + BatchNorm2d + Sigmoid turns the sum into a
    single-channel map that multiplies ``x``.
    """

    def __init__(self,F_g,F_l,F_int):
        super().__init__()

        def projection(channels_in, channels_out):
            # 1x1 convolution followed by batch normalisation
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            ]

        self.W_g = nn.Sequential(*projection(F_g, F_int))
        self.W_x = nn.Sequential(*projection(F_l, F_int))
        self.psi = nn.Sequential(*projection(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        gate_features = self.W_g(g)
        skip_features = self.W_x(x)
        attention = self.psi(self.relu(gate_features + skip_features))
        return x * attention


class U_Net(nn.Module):
    """
    U-Net built from conv_block / up_conv: five encoder stages
    (64, 128, 256, 512, 1024 channels) separated by 2x2 max-pooling, four
    decoder stages that concatenate the matching encoder output with the
    up-sampled features, and a final 1x1 convolution to output_ch channels
    with no activation.
    """

    def __init__(self,img_ch=3,output_ch=1):
        super().__init__()

        # Shared 2x2 down-sampling between encoder stages
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = conv_block(img_ch, 64)
        self.Conv2 = conv_block(64, 128)
        self.Conv3 = conv_block(128, 256)
        self.Conv4 = conv_block(256, 512)
        self.Conv5 = conv_block(512, 1024)

        # Decoder: each Up_convN sees skip + up-sampled features, hence the
        # doubled input width
        self.Up5 = up_conv(1024, 512)
        self.Up_conv5 = conv_block(1024, 512)

        self.Up4 = up_conv(512, 256)
        self.Up_conv4 = conv_block(512, 256)

        self.Up3 = up_conv(256, 128)
        self.Up_conv3 = conv_block(256, 128)

        self.Up2 = up_conv(128, 64)
        self.Up_conv2 = conv_block(128, 64)

        # Output projection
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        # Encoder path
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool(enc1))
        enc3 = self.Conv3(self.Maxpool(enc2))
        enc4 = self.Conv4(self.Maxpool(enc3))
        bottleneck = self.Conv5(self.Maxpool(enc4))

        # Decoder path: up-sample, concatenate the skip first, convolve
        up5 = self.Up5(bottleneck)
        dec5 = self.Up_conv5(torch.cat((enc4, up5), dim=1))

        up4 = self.Up4(dec5)
        dec4 = self.Up_conv4(torch.cat((enc3, up4), dim=1))

        up3 = self.Up3(dec4)
        dec3 = self.Up_conv3(torch.cat((enc2, up3), dim=1))

        up2 = self.Up2(dec3)
        dec2 = self.Up_conv2(torch.cat((enc1, up2), dim=1))

        return self.Conv_1x1(dec2)


class R2U_Net(nn.Module):
    """
    U-Net variant whose encoder and decoder stages are RRCNN_block units
    (each run with ``t`` recurrent steps). Five encoder stages
    (64 ... 1024 channels) separated by 2x2 max-pooling, four decoder
    stages with skip concatenation, and a final 1x1 convolution to
    output_ch channels with no activation.
    """

    def __init__(self,img_ch=3,output_ch=1,t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Kept as an attribute; forward() up-samples through the up_conv blocks
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(img_ch, 64, t=t)
        self.RRCNN2 = RRCNN_block(64, 128, t=t)
        self.RRCNN3 = RRCNN_block(128, 256, t=t)
        self.RRCNN4 = RRCNN_block(256, 512, t=t)
        self.RRCNN5 = RRCNN_block(512, 1024, t=t)

        # Decoder
        self.Up5 = up_conv(1024, 512)
        self.Up_RRCNN5 = RRCNN_block(1024, 512, t=t)

        self.Up4 = up_conv(512, 256)
        self.Up_RRCNN4 = RRCNN_block(512, 256, t=t)

        self.Up3 = up_conv(256, 128)
        self.Up_RRCNN3 = RRCNN_block(256, 128, t=t)

        self.Up2 = up_conv(128, 64)
        self.Up_RRCNN2 = RRCNN_block(128, 64, t=t)

        # Output projection
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        # Encoder path
        enc1 = self.RRCNN1(x)
        enc2 = self.RRCNN2(self.Maxpool(enc1))
        enc3 = self.RRCNN3(self.Maxpool(enc2))
        enc4 = self.RRCNN4(self.Maxpool(enc3))
        bottleneck = self.RRCNN5(self.Maxpool(enc4))

        # Decoder path: up-sample, concatenate the skip first, refine
        up5 = self.Up5(bottleneck)
        dec5 = self.Up_RRCNN5(torch.cat((enc4, up5), dim=1))

        up4 = self.Up4(dec5)
        dec4 = self.Up_RRCNN4(torch.cat((enc3, up4), dim=1))

        up3 = self.Up3(dec4)
        dec3 = self.Up_RRCNN3(torch.cat((enc2, up3), dim=1))

        up2 = self.Up2(dec3)
        dec2 = self.Up_RRCNN2(torch.cat((enc1, up2), dim=1))

        return self.Conv_1x1(dec2)



class AttU_Net(nn.Module):
    """
    U-Net with attention gates on the skip connections. Defaults to a
    4-channel input and a 3-channel output.

    Five conv_block encoder stages (64 ... 1024 channels) separated by 2x2
    max-pooling. At each decoder stage the up-sampled features gate the
    matching encoder output through an Attention_block before the two are
    concatenated and convolved. The final 1x1 convolution has no activation.
    """

    def __init__(self,img_ch=4,output_ch=3):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = conv_block(img_ch, 64)
        self.Conv2 = conv_block(64, 128)
        self.Conv3 = conv_block(128, 256)
        self.Conv4 = conv_block(256, 512)
        self.Conv5 = conv_block(512, 1024)

        # Decoder: up-sampling, attention gate, fusion convolution
        self.Up5 = up_conv(1024, 512)
        self.Att5 = Attention_block(512, 512, 256)
        self.Up_conv5 = conv_block(1024, 512)

        self.Up4 = up_conv(512, 256)
        self.Att4 = Attention_block(256, 256, 128)
        self.Up_conv4 = conv_block(512, 256)

        self.Up3 = up_conv(256, 128)
        self.Att3 = Attention_block(128, 128, 64)
        self.Up_conv3 = conv_block(256, 128)

        self.Up2 = up_conv(128, 64)
        self.Att2 = Attention_block(64, 64, 32)
        self.Up_conv2 = conv_block(128, 64)

        # Output projection
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        # Encoder path
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool(enc1))
        enc3 = self.Conv3(self.Maxpool(enc2))
        enc4 = self.Conv4(self.Maxpool(enc3))
        bottleneck = self.Conv5(self.Maxpool(enc4))

        # Decoder path: the up-sampled tensor is the gate, the encoder output
        # is the gated skip; the skip comes first in the concatenation
        gate5 = self.Up5(bottleneck)
        skip4 = self.Att5(gate5, enc4)
        dec5 = self.Up_conv5(torch.cat((skip4, gate5), dim=1))

        gate4 = self.Up4(dec5)
        skip3 = self.Att4(gate4, enc3)
        dec4 = self.Up_conv4(torch.cat((skip3, gate4), dim=1))

        gate3 = self.Up3(dec4)
        skip2 = self.Att3(gate3, enc2)
        dec3 = self.Up_conv3(torch.cat((skip2, gate3), dim=1))

        gate2 = self.Up2(dec3)
        skip1 = self.Att2(gate2, enc1)
        dec2 = self.Up_conv2(torch.cat((skip1, gate2), dim=1))

        return self.Conv_1x1(dec2)


class R2AttU_Net(nn.Module):
    """
    U-Net variant combining RRCNN_block stages (``t`` recurrent steps each)
    with Attention_block gates on the skip connections. Five encoder stages
    (64 ... 1024 channels) separated by 2x2 max-pooling, four gated decoder
    stages, and a final 1x1 convolution to output_ch channels with no
    activation.
    """

    def __init__(self,img_ch=3,output_ch=1,t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Kept as an attribute; forward() up-samples through the up_conv blocks
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(img_ch, 64, t=t)
        self.RRCNN2 = RRCNN_block(64, 128, t=t)
        self.RRCNN3 = RRCNN_block(128, 256, t=t)
        self.RRCNN4 = RRCNN_block(256, 512, t=t)
        self.RRCNN5 = RRCNN_block(512, 1024, t=t)

        # Decoder: up-sampling, attention gate, recurrent fusion block
        self.Up5 = up_conv(1024, 512)
        self.Att5 = Attention_block(512, 512, 256)
        self.Up_RRCNN5 = RRCNN_block(1024, 512, t=t)

        self.Up4 = up_conv(512, 256)
        self.Att4 = Attention_block(256, 256, 128)
        self.Up_RRCNN4 = RRCNN_block(512, 256, t=t)

        self.Up3 = up_conv(256, 128)
        self.Att3 = Attention_block(128, 128, 64)
        self.Up_RRCNN3 = RRCNN_block(256, 128, t=t)

        self.Up2 = up_conv(128, 64)
        self.Att2 = Attention_block(64, 64, 32)
        self.Up_RRCNN2 = RRCNN_block(128, 64, t=t)

        # Output projection
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        # Encoder path
        enc1 = self.RRCNN1(x)
        enc2 = self.RRCNN2(self.Maxpool(enc1))
        enc3 = self.RRCNN3(self.Maxpool(enc2))
        enc4 = self.RRCNN4(self.Maxpool(enc3))
        bottleneck = self.RRCNN5(self.Maxpool(enc4))

        # Decoder path: the up-sampled tensor is the gate, the encoder output
        # is the gated skip; the skip comes first in the concatenation
        gate5 = self.Up5(bottleneck)
        skip4 = self.Att5(gate5, enc4)
        dec5 = self.Up_RRCNN5(torch.cat((skip4, gate5), dim=1))

        gate4 = self.Up4(dec5)
        skip3 = self.Att4(gate4, enc3)
        dec4 = self.Up_RRCNN4(torch.cat((skip3, gate4), dim=1))

        gate3 = self.Up3(dec4)
        skip2 = self.Att3(gate3, enc2)
        dec3 = self.Up_RRCNN3(torch.cat((skip2, gate3), dim=1))

        gate2 = self.Up2(dec3)
        skip1 = self.Att2(gate2, enc1)
        dec2 = self.Up_RRCNN2(torch.cat((skip1, gate2), dim=1))

        return self.Conv_1x1(dec2)

### Discriminator

In [None]:
class CNNBlock(nn.Module):
    """
    Discriminator building block: a bias-free 4x4 convolution with
    reflect padding of 1 and the given stride, followed by BatchNorm2d
    and LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        self.conv = nn.Sequential(
            convolution,
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        features = self.conv(x)
        return features


class Discriminator(nn.Module):
    """
    Conditional patch discriminator.

    The input ``x`` and the candidate image ``y`` are concatenated on the
    channel axis. An initial stride-2 convolution (no normalisation) is
    followed by one CNNBlock per remaining entry of ``features`` (stride 2,
    except stride 1 for the last entry) and a final stride-1 convolution
    that outputs a single-channel map of raw logits.

    Parameters:
        x_channels: number of channels of ``x``
        y_channels: number of channels of ``y``
        features: channel widths of the successive stages
    """

    def __init__(self, x_channels=4, y_channels = 3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default; accept any sequence from callers
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                x_channels + y_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final stage keeps the resolution. The decision is made
            # by position, so a width that happens to equal the last one is
            # not mistaken for the last stage.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(in_channels, feature, stride=stride))
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        # Condition the decision on the input by stacking it with the candidate
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x


def Discriminator_test():
    """
    Smoke test: push one random 256x256 pair through a Discriminator,
    print the shape of its output and check it.
    """
    x = torch.randn((1, 4, 256, 256))
    y = torch.randn((1, 3, 256, 256))
    model = Discriminator(x_channels=4, y_channels = 3)
    # Shape check only, so no autograd graph is needed
    with torch.no_grad():
        preds = model(x, y)
#     print(model)
    print(preds.shape)
    # Three stride-2 stages take 256 -> 32; the two stride-1 4x4
    # convolutions with padding 1 then give 31 and 30.
    assert preds.shape == (1, 1, 30, 30), preds.shape

Discriminator_test()


# 5. Training and testing code

In [None]:
def calculate_psnr(target, output):
    """
    Calculate Peak Signal To Noise Ratio for a peak value of 1
    Input: torch tensors of shape (m, C, H, W); ``output`` is clamped to
           [0, 1] on a copy, so neither argument is modified
    Output: average of PSNR for that batch, as a 0-dim float32 tensor
    Raises: ValueError for an empty batch
    """

    m, C, H, W = target.shape
    if m == 0:
        raise ValueError('cannot compute PSNR of an empty batch')

    # Work in float32: half-precision inputs (e.g. produced under autocast)
    # would otherwise be squared and summed in half precision.
    target = target.float()
    output = torch.clamp(output.float(), min=0.0, max=1.0)

    # Mean squared error per image, over all channels and pixels
    mse = ((target - output) ** 2).reshape(m, -1).mean(dim=1)
    psnr = -10 * torch.log10(mse)

    return psnr.mean()

In [None]:
# Train on cuda if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device, 'to train')

In [None]:
#data augmentation
# SeeInTheDarkDataset calls this pipeline as transform(image=X, mask=Y), i.e.
# the input and its ground truth are passed together.
# NOTE(review): presumably albumentations then applies the same crop/flip to
# both — confirm against the albumentations documentation.
my_transforms = A.Compose([
    A.RandomCrop(width=512, height=512),  # 512 x 512 patches
    A.HorizontalFlip(p=0.2),  # applied with probability 0.2
    A.VerticalFlip(p=0.2)  # applied with probability 0.2
])

In [None]:
# train dataset
train_list = read_file_list(train_path)
train_data = batch_process_raw(train_list)

In [None]:
train_dataset = SeeInTheDarkDataset(dataset = train_data, transform =my_transforms)
train_loader = DataLoader(dataset = train_dataset, batch_size = 16, shuffle = True)

In [None]:
# Validation dataset
val_list = read_file_list(val_path)
val_data = batch_process_raw(val_list)

In [None]:
val_dataset = SeeInTheDarkDataset(dataset = val_data, transform =my_transforms)
val_loader = DataLoader(dataset = val_dataset, batch_size = 16, shuffle = True)

In [None]:
X,y = next(iter(val_loader))
print(X.shape, y.shape)

In [None]:
# Hyper-parameters for the GAN training cells below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # used by train_fn / test_model and to place the models
LEARNING_RATE = 2e-4  # Adam learning rate for both generator and discriminator
NUM_WORKERS = 2  # NOTE(review): not referenced in the visible cells — confirm it is still needed
CHANNELS_IMG = 3  # NOTE(review): not referenced in the visible cells — confirm it is still needed
L1_LAMBDA = 100  # weight of the L1 term in the generator loss (see train_fn)
LAMBDA_GP = 10  # NOTE(review): not referenced in the visible cells — confirm it is still needed
NUM_EPOCHS = 500  # number of iterations of the training loop

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from livelossplot import PlotLosses

torch.backends.cudnn.benchmark = True


def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """
    Run one training epoch of the conditional GAN.

    Input:
        disc, gen: discriminator and generator modules
        loader: iterable of (x, y) batches of input / ground-truth tensors
        opt_disc, opt_gen: optimizers of the two networks
        l1_loss, bce: reconstruction and adversarial loss callables
        g_scaler, d_scaler: gradient scalers for the mixed-precision steps
    Output: PSNR of the generator output against the ground truth, averaged
            over all samples of the epoch
    Raises: ValueError if the loader yields no samples

    Reads the module-level DEVICE and L1_LAMBDA.
    """
    gen.train()
    disc.train()
    loop = tqdm(loader, leave=True)
    total = 0
    total_psnr = 0
    for x, y in loop:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach: the discriminator step must not back-propagate into the generator
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Ground truth first, prediction second: calculate_psnr clamps its
        # second argument, and it is the prediction that can leave [0, 1].
        psnr_batch = calculate_psnr(y.detach(), y_fake.detach()).item()
        total += x.shape[0]
        total_psnr += psnr_batch*x.shape[0]

        loop.set_postfix(
            D_real=torch.sigmoid(D_real).mean().item(),
            D_fake=torch.sigmoid(D_fake).mean().item(),
            D_psnr = psnr_batch
        )

    if total == 0:
        raise ValueError('train_fn received a loader that yielded no samples')

    mean_psnr = total_psnr/total
    return mean_psnr
        


In [None]:
def test_model(gen, loader, hide_progress=False):
    """
    Evaluate the generator on a loader without computing gradients.

    Input:
        gen: generator module (switched to eval mode here)
        loader: iterable of (x, y) batches of input / ground-truth tensors
        hide_progress: when True the progress bar is disabled
    Output: PSNR of the generator output against the ground truth, averaged
            over all samples of the loader
    Raises: ValueError if the loader yields no samples

    Reads the module-level DEVICE.
    """
    loop = tqdm(loader, leave=True, disable=hide_progress)
    total = 0
    total_psnr = 0
    gen.eval()
    with torch.no_grad():
        for x, y in loop:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            with torch.cuda.amp.autocast():
                y_fake = gen(x)
            # Ground truth first, prediction second: calculate_psnr clamps
            # its second argument, and it is the prediction that can leave
            # [0, 1].
            psnr_batch = calculate_psnr(y.detach(), y_fake.detach()).item()
            total += x.shape[0]
            total_psnr += psnr_batch*x.shape[0]

    if total == 0:
        raise ValueError('test_model received a loader that yielded no samples')

    mean_psnr = total_psnr/total
    return mean_psnr
        

In [None]:
import os

from livelossplot import PlotLosses

file_name = 'AttUnetGanSSIM'
disc = Discriminator().to(DEVICE)
gen = AttU_Net().to(DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()
check_point = 10  # save to disk every check_point epochs
# The MS_SSIM module that used to be built here was removed: MS_SSIM is never
# imported in this notebook (NameError on a fresh kernel) and the module was
# not used by the loop.

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
liveloss = PlotLosses()

# Best generator so far, judged by validation PSNR
best_val_psnr = float('-inf')
best_model_weights = copy.deepcopy(gen.state_dict())

# torch.save does not create missing directories
os.makedirs('trained_model', exist_ok=True)

for epoch in range(NUM_EPOCHS):

    plot_logs = {}
    train_psnr = train_fn(
        disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
    )
    val_psnr = test_model(gen, val_loader)

    # Snapshot the weights only when validation improves, so that the file
    # written below really holds the best model rather than the state the
    # generator had before the current epoch.
    if val_psnr > best_val_psnr:
        best_val_psnr = val_psnr
        best_model_weights = copy.deepcopy(gen.state_dict())

    plot_logs['PSNR'] = train_psnr
    plot_logs['val_PSNR'] = val_psnr

    liveloss.update(plot_logs)
    liveloss.draw()

    print(f'[epoch {epoch}/{NUM_EPOCHS}, TRAIN PSNR={train_psnr}, VAL PSNR={val_psnr}]')

    # Check point: every check_point epochs, and once more after the last
    # epoch so late improvements are not lost
    if epoch%check_point==0 or epoch == NUM_EPOCHS - 1:
        torch.save(best_model_weights, f'trained_model/{file_name}_best_model.pt')


In [None]:
val_data['X'].max()

In [None]:
# Shape check on a fresh, untrained generator. `Generator` is not defined in
# this notebook; AttU_Net is the generator class used by the training cell.
# A separate name is used so the trained `gen` is not overwritten.
random_data = torch.rand(32, 4, 256, 256)
shape_check_gen = AttU_Net()
# Only the output shape is of interest, so skip building the autograd graph
with torch.no_grad():
    shape_check_output = shape_check_gen(random_data)
shape_check_output.shape

## Test

In [None]:
# Test dataset
test_list = read_file_list('/Sony_test_list.txt')
test_dataset = batch_process_raw(test_list)


In [None]:
test_list

In [None]:
# test_model iterates over (x, y) tensor batches, so the raw arrays returned by
# batch_process_raw are wrapped the same way as the validation data.
test_loader = DataLoader(
    dataset = SeeInTheDarkDataset(dataset = test_dataset, transform = my_transforms),
    batch_size = 16,
    shuffle = False,
)
score = test_model(gen, test_loader)
print('Peak Signal Noise Ratio on test dataset {:.2f}'.format(score))

## Test custom image

In [None]:
# NOTE(review): display_an_example is not defined in the visible part of this
# notebook — confirm it is defined or imported before running this cell.
display_an_example(gen, test_list, test_dataset, 65)

In [None]:
# NOTE(review): display_custom_image is not defined in the visible part of this
# notebook — confirm it is defined or imported before running this cell.
display_custom_image(gen, 'custom_images/image_1.arw', 8)