# 0. Set up

In [None]:
# Show the NVIDIA GPUs visible to this machine and their current usage
!nvidia-smi

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import numpy as np
import pandas as pd
import rawpy
from tqdm import tqdm as pbar
import copy
from livelossplot import PlotLosses
import matplotlib.pyplot as plt
import seaborn
seaborn.set()
import scipy

In [None]:
# NOTE(review): IPython runs each `!` command in its own subshell, so environment
# variables set by env.sh are not kept for this kernel or for later cells.
# `source` is also a bash builtin and may be unavailable if the shell is not bash.
!source  /scratch/yt2188/temp/env.sh 


In [None]:
# Root folder of the dataset; read_file_list() and batch_process_raw() append
# the file-list paths to it by plain string concatenation.
data_path = 'dataset'
# Uncomment for reproducible runs (fixed NumPy/PyTorch seeds, deterministic cuDNN)
# np.random.seed(0)
# torch.manual_seed(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# 1. Preprocess raw data from camera sensor

![](figures/3a.png)

Pack the raw Bayer sensor data into 4 channels (R, G, B, G). This also halves the resolution along each spatial dimension.

## 1.1 `pack_raw` prepares the network input

In [None]:
def pack_raw(raw, black_level=512, white_level=16383):
    """
    Pack a Bayer raw frame into a 4-channel, half-resolution float image.

    Input:  raw: object returned from rawpy.imread() (only `raw_image_visible` is used)
            black_level: sensor black level subtracted before normalisation
                         (default 512, the value used for the Sony data)
            white_level: sensor saturation level (default 16383)
    Output: float32 numpy array of shape (H//2, W//2, 4), channels ordered
            (R, G1, B, G2); e.g. (1424, 2128, 4) for a 2848 x 4256 frame.

    The channel names assume an RGGB mosaic: R at (0, 0), G1 at (0, 1),
    B at (1, 1) and G2 at (1, 0) of every 2x2 cell. If H or W is odd, the
    last row/column is dropped so that all four channels have the same shape.
    """
    im = raw.raw_image_visible.astype(np.float32)
    # Subtract the black level and scale so that saturation maps to 1.0
    im = np.maximum(im - black_level, 0) / (white_level - black_level)

    # Crop to even dimensions so that every 2x2 Bayer cell is complete
    H = im.shape[0] // 2 * 2
    W = im.shape[1] // 2 * 2

    red = im[0:H:2, 0:W:2]
    green_1 = im[0:H:2, 1:W:2]
    blue = im[1:H:2, 1:W:2]
    green_2 = im[1:H:2, 0:W:2]

    return np.stack((red, green_1, blue, green_2), axis=-1)

In [None]:
# x_img = rawpy.imread(data_path + '/Sony/short/00001_00_0.04s.ARW')
# x_img = pack_raw(x_img)
# x_img.shape

## 1.2 `post_process` prepares the ground truth

In [None]:
def post_process(raw):
    """
    Render a raw file into a float RGB image with values in [0, 1].

    Input:  object returned from rawpy.imread()
    Output: float32 numpy array (H x W x 3), e.g. (2848, 4256, 3) for the Sony data.
            Rendering uses the camera white balance, no auto-brightening and
            16 bits per sample, so dividing by 65535 maps it to [0, 1].
    """
    rgb16 = raw.postprocess(use_camera_wb=True, no_auto_bright=True, output_bps=16)
    full_scale = 65535.0  # largest 16-bit value
    return (rgb16 / full_scale).astype(np.float32)

In [None]:
# y_img = rawpy.imread(data_path + '/Sony/long/00001_00_10s.ARW')
# y_img = post_process(y_img)
# y_img.shape

## 1.3 Batch process all data

**File name convention**

The file lists are provided. Each row contains a short-exposure image path, the corresponding long-exposure image path, the camera ISO and the F-number.
Note that multiple short-exposure images may correspond to the same long-exposure image.

The file name contains the image information. For example, in "10019_00_0.033s.RAF":
- the first digit "1" means it is from the test set ("0" for the training set and "2" for the validation set)
- "0019" is the image ID
- the following "00" is the number in the sequence/burst
- "0.033s" is the exposure time (1/30 second).

There is some misalignment with the ground truth for images 10034, 10045 and 10172. I've removed those images for quantitative results, but they can still be used for qualitative evaluations.

In [None]:
def read_file_list(file_list):
    """
    Read a space-separated file list into a dataframe.

    Input:  file_list: path of the list relative to `data_path`; it is appended
            verbatim, so it should start with '/'.
    Output: dataframe with columns 'X' (short-exposure path), 'Y' (long-exposure
            path), 'ISO' and 'F-stop'.
    """
    columns = ['X', 'Y', 'ISO', 'F-stop']
    list_path = data_path + file_list
    return pd.read_csv(list_path, sep=" ", header=None, names=columns)

In [None]:
# Quick look at the training file list (columns X, Y, ISO, F-stop)
train_list = read_file_list('/Sony_train_list.txt')
train_list.head()

In [None]:
def batch_process_raw(data, hide_progree=False):
    """
    Load and preprocess every (short, long) exposure pair listed in a file list.

    Input:  data: Pandas dataframe returned from read_file_list; columns 'X' and 'Y'
                  hold paths starting with "." that are relative to `data_path`
            hide_progree: if True, disable the tqdm progress bars
    Output: a dictionary of
            X : float32 array (m_x, 1424, 2128, 4) of packed short exposures, each
                multiplied by its exposure ratio (long / short shutter time)
            Y : float32 array (m_y, 2848, 4256, 3) of post-processed long exposures,
                one per distinct 'Y' path (multiple Xs can share the same Y)
            X_Y_map: int32 array (n_pairs, 2); each row is (index into X, index into Y)

    The indexes in X_Y_map are positional (0 .. m_x-1), so the mapping is also
    correct when `data` has a non-default index (e.g. after slicing or filtering).
    The array sizes are those of the Sony data.
    """
    x_paths = data['X'].tolist()
    y_paths = data['Y'].tolist()
    unique_y = data['Y'].unique()  # computed once, in order of first appearance

    m_x = len(x_paths)
    m_y = len(unique_y)

    X = np.zeros((m_x, 1424, 2128, 4), dtype=np.float32)
    Y = np.zeros((m_y, 2848, 4256, 3), dtype=np.float32)

    for i in pbar(range(m_x), disable=hide_progree):
        x_path = x_paths[i][1:]  # remove the "." in the name
        y_path = y_paths[i][1:]  # remove the "." in the name

        # Shutter time is in the file name, e.g. ".../00001_00_0.04s.ARW" -> "0.04"
        x_shutter_speed = x_path.split('_')[-1].split('s.')[0]
        y_shutter_speed = y_path.split('_')[-1].split('s.')[0]
        amp_ratio = float(y_shutter_speed) / float(x_shutter_speed)

        with rawpy.imread(data_path + x_path) as raw:
            X[i] = pack_raw(raw) * amp_ratio

    # Mapping of X to Y, built from positions (not index labels) of matching rows
    y_column = data['Y'].to_numpy()
    X_map = []
    Y_map = []

    for i in pbar(range(m_y), disable=hide_progree):
        current_y = unique_y[i]

        with rawpy.imread(data_path + current_y[1:]) as raw:
            Y[i] = post_process(raw)

        x_indexes = np.flatnonzero(y_column == current_y).tolist()
        X_map += x_indexes
        Y_map += [i] * len(x_indexes)

    X_Y_map = np.array((X_map, Y_map), dtype=np.int32).T
    return {'X': X, 'Y': Y, 'X_Y_map': X_Y_map}

In [None]:
# Smoke test: process only the first 10 rows of the list, with progress bars hidden
train_dataset = batch_process_raw(train_list.head(10), True)
print("Shape of X_train:", train_dataset['X'].shape)
print("Shape of Y_train:", train_dataset['Y'].shape)
print("Shape of X_Y_map_train:", train_dataset['X_Y_map'].shape)

# 2. Data augmentation
Randomly crop, flip and transpose the data, then clip it to [0, 1]. The exposure amplification is applied earlier, in `batch_process_raw`.

In [None]:
def numpy_to_torch(image):
    """
    Convert an image from a numpy array (H x W x C) to a torch tensor (C x H x W).

    Input:  numpy array (H x W x C)
    Output: torch tensor (C x H x W). The tensor shares memory with `image`,
            except when `image` has negative strides (e.g. a view produced by
            np.flip). torch.from_numpy rejects such arrays, so a copy is made first.
    """
    if any(stride < 0 for stride in image.strides):
        image = image.copy()
    return torch.from_numpy(image.transpose((2, 0, 1)))

In [None]:
def augment_data(x_input, y_output, ps):
    """
    Randomly crop, flip and transpose an aligned (input, ground truth) pair.

    Input:  x_input: numpy array (H x W x C_in), the packed raw image
            y_output: numpy array (2H x 2W x C_out), the ground truth. It has twice
                      the resolution of x_input, so its crop coordinates are doubled.
            ps: patch size (integer) of the input crop
    Output: X: torch tensor (C_in x ps x ps)
            Y: torch tensor (C_out x 2*ps x 2*ps)
            The same random flips/transpose are applied to both, and both are
            clipped to [0, 1].
    Raises: ValueError if x_input is smaller than ps in either dimension.
    """
    H, W = x_input.shape[0], x_input.shape[1]
    if H < ps or W < ps:
        raise ValueError(f"patch size {ps} is larger than the input image ({H} x {W})")

    # Random crop; the +1 makes the last valid offset (W - ps / H - ps) reachable
    xx = np.random.randint(0, W - ps + 1)
    yy = np.random.randint(0, H - ps + 1)
    x_patch = x_input[yy:yy + ps, xx:xx + ps, :]
    y_patch = y_output[2 * yy:2 * (yy + ps), 2 * xx:2 * (xx + ps), :]

    # Random flip of the first (vertical) axis
    if np.random.randint(2) == 1:
        x_patch = np.flip(x_patch, axis=0)
        y_patch = np.flip(y_patch, axis=0)

    # Random flip of the second (horizontal) axis
    if np.random.randint(2) == 1:
        x_patch = np.flip(x_patch, axis=1)
        y_patch = np.flip(y_patch, axis=1)

    # Random transpose of the two spatial axes
    if np.random.randint(2) == 1:
        x_patch = np.transpose(x_patch, (1, 0, 2))
        y_patch = np.transpose(y_patch, (1, 0, 2))

    # Clip saturated values (x was amplified by the exposure ratio)
    x_patch = np.clip(x_patch, a_min=0.0, a_max=1.0)
    y_patch = np.clip(y_patch, a_min=0.0, a_max=1.0)

    return numpy_to_torch(x_patch), numpy_to_torch(y_patch)

In [None]:
# x_aug, y_aug = augment_data(X_train[0], Y_train[0], 512)
# print("Shape of X_aug:", x_aug.shape)
# print("Shape of Y_aug:", y_aug.shape)

# 3. Make batches of image patches for training, validation and testing

In [None]:
def make_batch(dataset, image_indexes, patch_size):
    """
    Build one training batch of randomly augmented patches.

    Input:  dataset: dictionary of X, Y, X_Y_map (returned from batch_process_raw())
            image_indexes: m rows of (x_index, y_index), e.g. a slice of a random
                           permutation of dataset['X_Y_map']
            patch_size: side length of the input patch
    Output: X torch tensor (m, 4, patch_size, patch_size) on `device`
            Y torch tensor (m, 3, 2*patch_size, 2*patch_size) on `device`
    """
    inputs, targets = dataset['X'], dataset['Y']
    batch_len = len(image_indexes)
    target_size = 2 * patch_size

    X_patches = torch.zeros(batch_len, 4, patch_size, patch_size, dtype=torch.float32, device=device)
    Y_patches = torch.zeros(batch_len, 3, target_size, target_size, dtype=torch.float32, device=device)

    for slot, (x_index, y_index) in enumerate(image_indexes):
        x_patch, y_patch = augment_data(inputs[x_index], targets[y_index], patch_size)
        X_patches[slot] = x_patch
        Y_patches[slot] = y_patch

    return X_patches, Y_patches

In [None]:
def make_batch_test(dataset, image_indexes):
    """
    Prepare a batch of full-resolution images for testing (no crop, no augmentation).

    Input:  dataset: dictionary of X, Y, X_Y_map (returned from batch_process_raw())
            image_indexes: m rows of (x_index, y_index), e.g. rows of dataset['X_Y_map']
    Output: X torch tensor (m, C_x, H_x, W_x) and Y torch tensor (m, C_y, H_y, W_y)
            on `device`, both clipped to [0, 1]. The sizes come from the arrays in
            `dataset`; for the Sony data they are (m, 4, 1424, 2128) and
            (m, 3, 2848, 4256).
    """
    X = dataset['X']
    Y = dataset['Y']
    m = len(image_indexes)

    # Read the image sizes from the data instead of hard-coding one camera's resolution
    x_h, x_w, x_c = X.shape[1:]
    y_h, y_w, y_c = Y.shape[1:]

    X_images = torch.zeros(m, x_c, x_h, x_w, dtype=torch.float32, device=device)
    Y_images = torch.zeros(m, y_c, y_h, y_w, dtype=torch.float32, device=device)

    for i, (x_index, y_index) in enumerate(image_indexes):
        X_images[i] = numpy_to_torch(np.clip(X[x_index], a_min=0.0, a_max=1.0))
        Y_images[i] = numpy_to_torch(np.clip(Y[y_index], a_min=0.0, a_max=1.0))

    return X_images, Y_images

In [None]:
# batch_size = 2
# random_orders = np.random.permutation(train_dataset['X_Y_map'])
# splitted_random_orders = np.array_split(random_orders, range(batch_size, len(random_orders), batch_size))

# first_batch_indexes = splitted_random_orders
# last_batch_indexes = splitted_random_orders

# x_batch, y_batch = make_batch(train_dataset, first_batch_indexes, 512)
# print('Shape of first batch X and Y:', x_batch.shape, y_batch.shape)

# x_batch, y_batch = make_batch(train_dataset, last_batch_indexes, 512)
# print('Shape of last batch X and Y:', x_batch.shape, y_batch.shape)

# 4. Model architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> LeakyReLU(0.2)) stages.

    padding=1 keeps the spatial size; channels go in_ch -> out_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # A single Sequential named `f`, so state_dict keys are f.0 ... f.5
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        return self.f(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling (halves H and W), then a DoubleConv in_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_ch, out_ch)
        self.f = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.f(x)


class Up(nn.Module):
    """Decoder step: upsample x1, concatenate the skip connection x2, then DoubleConv.

    x1 is expected to have in_ch channels and x2 in_ch // 2 channels. The 2x2
    transposed conv turns x1 into in_ch // 2 channels at twice the size, so the
    concatenation has in_ch channels again. If the upsampled x1 is smaller than
    x2, it is zero-padded to x2's size before concatenation. This happens when
    an encoder input had an odd height or width and max-pooling rounded down.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.upsample(x1)

        # Match the skip tensor's spatial size; pad order is (left, right, top, bottom)
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping in_ch feature channels to out_ch output channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.f = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.f(x)

class Unet(nn.Module):
    """U-Net mapping a packed raw image to an RGB image at twice its resolution.

    Input (m, 4, H, W) -> output (m, 3, 2H, 2W). There are four 2x poolings,
    so with H and W divisible by 16 the skip connections line up exactly.
    The final 1x1 conv produces 12 channels, and PixelShuffle(2) rearranges
    them into 3 channels at double resolution.
    """

    def __init__(self):
        super().__init__()
        # Encoder (attribute names are state_dict keys; keep them stable)
        self.inc = DoubleConv(4, 32)
        self.d1 = Down(32, 64)
        self.d2 = Down(64, 128)
        self.d3 = Down(128, 256)
        self.d4 = Down(256, 512)

        # Decoder
        self.u1 = Up(512, 256)
        self.u2 = Up(256, 128)
        self.u3 = Up(128, 64)
        self.u4 = Up(64, 32)

        # Head
        self.outc = OutConv(32, 12)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections
        skips = [self.inc(x)]
        for down in (self.d1, self.d2, self.d3, self.d4):
            skips.append(down(skips[-1]))

        # Decoder: start from the bottleneck and merge skips from deepest to shallowest
        out = skips.pop()
        for up in (self.u1, self.u2, self.u3, self.u4):
            out = up(out, skips.pop())

        return self.pixel_shuffle(self.outc(out))

# 5. Training and testing code

In [None]:
def calculate_psnr(target, output):
    """
    Calculate the mean Peak Signal-to-Noise Ratio of a batch (peak value 1.0).

    Input:  target, output: torch tensors of shape (m, C, H, W); output is
            clamped to [0, 1] before comparison
    Output: 0-dim torch tensor, the average PSNR in dB over the m images
            (an image identical to its target contributes +inf)

    `output` is NOT modified in place. Callers may pass a detached tensor that
    shares storage with a tensor still needed elsewhere (e.g. for the loss).
    """
    m = target.shape[0]
    clamped = torch.clamp(output, min=0.0, max=1.0)
    # Per-image mean squared error over all C*H*W values
    mse = ((target - clamped) ** 2).reshape(m, -1).mean(dim=1)
    psnr = -10 * torch.log10(mse)
    return psnr.mean()

In [None]:
def train_model(model, train_dataset, val_dataset, optimizer, scheduler, check_point, batch_size, num_epochs,
                patch_size=512):
    """
    Train `model` with an L1 loss on random patches and track validation PSNR.

    Input:  model: network mapping (m, 4, h, w) packed raw input to (m, 3, 2h, 2w) RGB
            train_dataset, val_dataset: dictionaries returned by batch_process_raw()
            optimizer, scheduler: torch optimizer and LR scheduler
                                  (scheduler.step() is called once per epoch)
            check_point: save the best weights every `check_point` epochs
            batch_size: number of patches per optimisation step
            num_epochs: number of passes over all training pairs
            patch_size: side length of the input training patches (default 512)
    Side effects:
            - writes the best weights (by validation PSNR) to
              'trained_model/best_model.pt' every `check_point` epochs and once
              more after the last epoch, so a late best model is never lost
            - appends "epoch  best_val_psnr  train_psnr  val_psnr" lines to
              'trained_model/training_log.txt'
            - updates a livelossplot chart every epoch
    Returns: the best validation PSNR seen (previously returned nothing).
    """
    import os

    os.makedirs('trained_model', exist_ok=True)
    liveloss = PlotLosses()
    criterion = nn.L1Loss()
    m_train = len(train_dataset['X_Y_map'])
    best_psnr = float('-inf')
    best_model_weights = copy.deepcopy(model.state_dict())

    for epoch in pbar(range(num_epochs)):
        # ---- Training phase ----
        model.train()
        psnr_sum = 0.0

        # Shuffle the (x, y) pairs and split them into batches of batch_size
        random_orders = np.random.permutation(train_dataset['X_Y_map'])
        batches = np.array_split(random_orders, range(batch_size, len(random_orders), batch_size))

        for batch_indexes in batches:
            image, target = make_batch(train_dataset, batch_indexes, patch_size=patch_size)

            optimizer.zero_grad()
            y_hat = model(image)
            loss = criterion(y_hat, target)
            loss.backward()
            optimizer.step()

            # Measure PSNR only after backward(): nothing done to the detached
            # output here can then affect the loss or its gradient
            psnr_sum += calculate_psnr(target, y_hat.detach()).item() * image.size(0)

        train_psnr = psnr_sum / m_train

        # ---- Validation phase ----
        val_psnr = test_model(model, val_dataset)
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            best_model_weights = copy.deepcopy(model.state_dict())

        # Periodic checkpoint of the best weights so far
        if epoch % check_point == 0:
            torch.save(best_model_weights, 'trained_model/best_model.pt')

        scheduler.step()

        # Update live plot every epoch
        liveloss.update({'PSNR': train_psnr, 'val_PSNR': val_psnr})
        liveloss.draw()

        # Write to log file every epoch: Epoch - Best Val PSNR - Train PSNR - Val PSNR
        with open("trained_model/training_log.txt", "a") as f:
            f.write("\n{:4d} \t{:.5f} \t{:.5f} \t{:.5f}".format(epoch, best_psnr, train_psnr, val_psnr))

    # Save once more in case the best weights came after the last periodic checkpoint
    torch.save(best_model_weights, 'trained_model/best_model.pt')
    return best_psnr

In [None]:
def test_model(model, dataset, hide_progress=True):
    """
    Evaluate `model` on every (x, y) pair of `dataset` at full resolution.

    Input:  model: the network to evaluate (it is switched to eval mode)
            dataset: dictionary returned by batch_process_raw()
            hide_progress: if True, disable the tqdm progress bar
    Output: mean PSNR (float) over all rows of dataset['X_Y_map']
    """
    model.eval()
    pairs = dataset['X_Y_map']
    total = 0

    with torch.no_grad():
        for pair in pbar(pairs, disable=hide_progress):
            # One pair per forward pass: a batch of size 1
            image, target = make_batch_test(dataset, pair[np.newaxis])
            total += calculate_psnr(target, model(image)).item()

    return total / len(pairs)

In [None]:
def display_an_example(model, image_list, dataset, index):
    """
    Display one example: the short exposure, the model output and the ground truth.

    Input:  model: trained network
            image_list: dataframe (from read_file_list) that `dataset` was built from;
                        its first column holds the short-exposure paths
            dataset: dictionary returned by batch_process_raw(image_list)
            index: row of dataset['X_Y_map'] to display
    """
    pair = dataset['X_Y_map'][index]
    # X_Y_map rows are grouped by ground truth, so `index` is not necessarily a
    # row of image_list; look the input file up through the pair's x index.
    x_index = pair[0]

    model.eval()
    with torch.no_grad():
        image, ground_truth = make_batch_test(dataset, np.expand_dims(pair, 0))
        y_hat = torch.clamp(model(image), min=0.0, max=1.0)

    # Convert from torch tensor (1, C, H, W) to numpy (H, W, C)
    y_hat = y_hat.squeeze().cpu().numpy().transpose((1, 2, 0))
    ground_truth = ground_truth.squeeze().cpu().numpy().transpose((1, 2, 0))

    x_path = image_list.iloc[x_index, 0][1:]  # remove the "." in the name
    image_to_display = post_process(rawpy.imread(data_path + x_path))

    panels = [(image_to_display, 'Original image'),
              (y_hat, 'Denoised by model'),
              (ground_truth, 'Ground Truth')]

    fig = plt.figure(figsize=(30, 10))
    for position, (img, title) in enumerate(panels, start=1):
        fig.add_subplot(1, 3, position)
        plt.imshow(img, vmin=0, vmax=1)
        plt.title(title)
        plt.axis('off')
        # grid(False) instead of grid(b=None): `b` was renamed and later removed in matplotlib
        plt.grid(False)

    plt.show()

In [None]:
def display_custom_image(model, image_path, amp_ratio, render=False):
    """
    Run the model on an arbitrary raw file and show it next to a direct rendering.

    Input:  model: trained network
            image_path: path of a raw file readable by rawpy
            amp_ratio: multiplier applied to the packed raw data before clipping
                       (same role as the exposure ratio applied in batch_process_raw)
            render: if True, also save the model output to
                    'custom_images/processed.png' as an 8-bit RGB PNG
    """
    model.eval()

    orig_image = post_process(rawpy.imread(image_path))

    fig = plt.figure(figsize=(20, 10))
    fig.add_subplot(1, 2, 1)
    plt.imshow(orig_image, vmin=0, vmax=1)
    plt.title('Original image')
    plt.axis('off')
    plt.grid(False)

    image = pack_raw(rawpy.imread(image_path)) * amp_ratio
    image = numpy_to_torch(np.clip(image, a_min=0.0, a_max=1.0)).unsqueeze(0)
    image = image.to(device)
    with torch.no_grad():
        y_hat = torch.clamp(model(image), min=0.0, max=1.0)
    image = y_hat.squeeze().cpu().numpy().transpose((1, 2, 0))

    fig.add_subplot(1, 2, 2)
    plt.imshow(image, vmin=0, vmax=1)
    plt.title('Denoised by model')
    plt.axis('off')
    plt.grid(False)

    if render:
        # scipy.misc.toimage no longer exists in SciPy; write the PNG with matplotlib,
        # rounding [0, 1] floats to the nearest 8-bit level
        as_uint8 = (np.clip(image, 0.0, 1.0) * 255 + 0.5).astype(np.uint8)
        plt.imsave('custom_images/processed.png', as_uint8)

    plt.show()

# 6. Put everything together

## Train

In [None]:
# Train on cuda if available.
# `device` is a global read by make_batch, make_batch_test and
# display_custom_image, so this cell must run before any of them is called.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device, 'to train')

In [None]:
# Train dataset
# NOTE(review): batch_process_raw allocates dense float32 arrays for every image
# in the list at once, so this needs a lot of RAM for the full training list.
train_list = read_file_list('/Sony_train_list.txt')
train_dataset = batch_process_raw(train_list)


In [None]:
# Validation dataset (used by train_model to pick the best weights)
val_list = read_file_list('/Sony_val_list.txt')
val_dataset = batch_process_raw(val_list)

In [None]:
# Initialize the model with fresh weights.
# To resume from the last saved checkpoint instead, uncomment the load_state_dict line.
my_model = Unet()
# my_model.load_state_dict(torch.load('trained_model/best_model.pt',map_location='cuda:0'))
my_model = my_model.to(device)

In [None]:
# Initialize optimizer (Adam with the AMSGrad variant)
# NOTE(review): train_model calls scheduler.step() once per epoch, and the
# milestone (1000) is beyond num_epochs (150), so the learning rate is never
# decayed in this run. Confirm this is intended.
optimizer = optim.Adam(my_model.parameters(), lr=1e-5, amsgrad=True)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[1000], gamma=0.1)

# Train model; the best weights (by validation PSNR) go to trained_model/best_model.pt
train_model(my_model, train_dataset, val_dataset, optimizer, scheduler, check_point=10, batch_size=32, num_epochs=150)

## Test

In [None]:
# Test dataset: the first 100 rows of the test file list
test_list = read_file_list('/Sony_test_list.txt')[:100]
test_dataset = batch_process_raw(test_list)

# Initialize and load the best weights saved by train_model.
# map_location=device (instead of a hard-coded 'cuda') lets the checkpoint
# also load on machines without a GPU.
my_model = Unet()
my_model.load_state_dict(torch.load('trained_model/best_model.pt', map_location=device))
my_model = my_model.to(device)

In [None]:
# Display the test file list used to build test_dataset
test_list

In [None]:
# Mean PSNR over all test pairs at full resolution
score = test_model(my_model, test_dataset, hide_progress=False)
print('Peak Signal Noise Ratio on test dataset {:.2f}'.format(score))

## Test custom image

In [None]:
# Show row 65 of test_dataset['X_Y_map'] (needs at least 66 pairs)
display_an_example(my_model, test_list, test_dataset, 65)

In [None]:
# Run the model on a custom raw file, amplifying the packed raw data by 8
display_custom_image(my_model, 'custom_images/image_1.arw', 8)