In [1]:
# Mount Google Drive inside the Colab runtime and expose the project folder
# stored on it to Python's import machinery.
from google.colab import drive

import os
import sys

# Mount point is /content/drive; force_remount=True is passed through to drive.mount.
drive.mount('/content/drive', force_remount=True)

# Add the project directory on Drive to the module search path so that modules
# kept there can be imported by later cells.
sys.path.append(
    os.path.abspath('/content/drive/My Drive/image_filtering/')
)

Mounted at /content/drive


In [2]:
from __future__ import print_function, division
import torchvision
import torch
from skimage import io, transform
import numpy as np
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import cv2
from math import log10, pi
import time
from google.colab import files

import utils
from datasetsRUAS import RedFlashDatasetRUAS
from vgg import Vgg16

from torchvision.utils import save_image

In [3]:
class MFFNet(torch.nn.Module):
    """Encoder / residual-stack / decoder network with concatenation skips.

    Layout (as built in ``__init__`` and wired in ``forward``):

    * encoder: three ``ConvLayer`` stages, 4 -> 32 -> 64 -> 128 channels, the
      second and third using stride 2;
    * body: sixteen ``ResidualBlock(128)`` modules applied in order
      (attributes ``res1`` .. ``res16``);
    * decoder: two ``UpsampleConvLayer`` stages (upsample factor 2) and a final
      ``ConvLayer`` producing 3 channels. Each decoder stage receives the
      previous decoder output concatenated (channel axis) with the encoder
      activation of matching depth, hence the doubled input channel counts.

    The ``in1`` .. ``in5`` InstanceNorm2d modules are constructed, and therefore
    appear in ``state_dict()``, but ``forward`` never applies them.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.conv1 = ConvLayer(4, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # Residual layers, registered as res1 .. res16 (same names and order
        # as writing them out one by one).
        for idx in range(1, 17):
            setattr(self, 'res%d' % idx, ResidualBlock(128))

        # Decoder; input widths are doubled because of the skip concatenation.
        self.deconv1 = UpsampleConvLayer(2 * 128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(2 * 64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(2 * 32, 3, kernel_size=9, stride=1)

        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run the network on ``X`` and return the output of ``deconv3``.

        No activation is applied after the last layer.
        """
        # Encoder activations are kept for the skip connections.
        skip1 = self.relu(self.conv1(X))
        skip2 = self.relu(self.conv2(skip1))
        skip3 = self.relu(self.conv3(skip2))

        # Residual stack: res1, res2, ..., res16 in that order.
        y = skip3
        for idx in range(1, 17):
            y = getattr(self, 'res%d' % idx)(y)

        # Decoder: concatenate along dim 1 with the matching encoder activation.
        y = self.relu(self.deconv1(torch.cat((y, skip3), 1)))
        y = self.relu(self.deconv2(torch.cat((y, skip2), 1)))
        return self.deconv3(torch.cat((y, skip1), 1))

class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding.

    The input is padded on every side by ``kernel_size // 2`` pixels with
    ``ReflectionPad2d`` and then convolved by a ``Conv2d`` that has no padding
    of its own.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size; also determines the padding.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride
        )

    def forward(self, x):
        """Pad ``x`` by reflection, convolve it, and return the result."""
        return self.conv2d(self.reflection_pad(x))

class ResidualBlock(torch.nn.Module):
    """Residual unit: ``x + conv2(relu(conv1(x)))``.

    Both convolutions are 3x3, stride-1 ``ConvLayer`` modules that keep the
    channel count unchanged. The ``in1`` / ``in2`` InstanceNorm2d modules are
    constructed (so their affine parameters are part of ``state_dict()``) but
    ``forward`` does not apply them.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1)
        self.conv1 = ConvLayer(channels, channels, **conv_args)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, **conv_args)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the conv branch output added to the unmodified input ``x``."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden) + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a padded convolution.

    When ``upsample`` is truthy the input is first resized with
    ``interpolate(mode='nearest', scale_factor=upsample)``. The (possibly
    resized) tensor is then reflection-padded by ``kernel_size // 2`` and
    convolved.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size; also determines the padding.
        stride: convolution stride.
        upsample: scale factor for the upsampling step, or ``None`` (default)
            to skip upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then pad and convolve it."""
        if self.upsample:
            resized = torch.nn.functional.interpolate(
                x, mode='nearest', scale_factor=self.upsample
            )
        else:
            resized = x
        return self.conv2d(self.reflection_pad(resized))

In [4]:
# Build the dataset, split it 70/30 into train/validation subsets, create the
# data loaders, and restore the network weights from the checkpoint on Drive.

# redflash_dataset = RedFlashDataset('/content/drive/My Drive/image_filtering/Deep Learning Final Project/images', True)
redflash_dataset = RedFlashDatasetRUAS('/content/drive/My Drive/image_filtering/Deep Learning Final Project/images', True)
train_size = int(0.7 * len(redflash_dataset.fileID))
val_size = len(redflash_dataset.fileID) - train_size
print(train_size, val_size)

# Use a fixed seed for the split. Training is resumed from a checkpoint across
# sessions; an unseeded random_split would draw a different partition every
# time the notebook is re-run, so samples held out for validation in one
# session could be trained on in the next.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    redflash_dataset, [train_size, val_size], generator=split_generator)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=0)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFilter = MFFNet()
model_name = 'MFF-net'

# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU runtime can still be loaded on a CPU-only one.
checkpoint_path = '/content/drive/My Drive/image_filtering/MFF-net-sr.ckpt'
imageFilter.load_state_dict(torch.load(checkpoint_path, map_location=device))
imageFilter = imageFilter.to(device).float()

# Initializing VGG16 model for perceptual loss
VGG = Vgg16(requires_grad=False)
VGG = VGG.to(device)

1014 435


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [5]:
# Training loop: pixel-wise MSE plus a VGG feature (perceptual) MSE term,
# optimised with Adam. The run starts at `start_epoch` and continues until
# `num_epochs`; weights are written to Drive every 10 epochs and at the end.
num_epochs = 600

criterion_img = nn.MSELoss()
criterion_vgg = nn.MSELoss()

# First epoch index of this run. NOTE(review): presumably the epoch reached by
# the checkpoint loaded into `imageFilter` -- confirm before changing.
start_epoch = 171

checkpoint_path = '/content/drive/My Drive/image_filtering/MFF-net-sr.ckpt'


def lr_for_epoch(epoch):
    """Return the learning rate for a 0-based epoch index.

    Step schedule: 1e-4 below epoch 300, 1e-5 from epoch 300, 1e-6 from
    epoch 600 (only reachable if `num_epochs` is raised above 600).
    """
    if epoch < 300:
        return 1e-4
    if epoch < 600:
        return 1e-5
    return 1e-6


learning_rate = lr_for_epoch(start_epoch)

# NOTE: only model weights are checkpointed, so Adam's moment estimates start
# from scratch whenever training is resumed.
optimizer = torch.optim.Adam(imageFilter.parameters(), lr=learning_rate)
total_step = len(train_loader)

start_time = time.time()
for epoch in range(start_epoch, num_epochs):
    # Per-epoch running sums (over batches) of the total, VGG and L2 losses.
    loss_tol = 0
    loss_tol_vgg  = 0
    loss_tol_l2   = 0

    # Periodic checkpoint, taken before the epoch's updates are applied.
    if epoch > 0 and epoch % 10 == 0:
      torch.save(imageFilter.state_dict(), checkpoint_path)
      files.download(checkpoint_path)

    # Apply the schedule every epoch; this replaces the separate
    # `epoch == 300` / `epoch == 600` branches and the pre-loop special case.
    learning_rate = lr_for_epoch(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate

    for batch in train_loader:
        inputs = batch[0].float().to(device)
        target = batch[1].float().to(device)

        outputs = imageFilter(inputs)

        loss_l2 = criterion_img( outputs, target )

        # Perceptual term: MSE between VGG features of the ImageNet-normalised
        # prediction and target, summed over the 4 returned feature levels.
        outputs_n = utils.normalize_ImageNet_stats(outputs)
        target_n  = utils.normalize_ImageNet_stats(target)

        feature_o = VGG(outputs_n, 3)
        feature_t = VGG(target_n, 3)
        VGG_loss = [criterion_vgg(feature_o[l], feature_t[l]) for l in range(3+1)]

        loss_vgg = sum(VGG_loss)
        loss = loss_l2 + 0.01*loss_vgg

        # Accumulate plain Python floats. Adding the loss tensors themselves
        # would keep every iteration's autograd history referenced for the
        # whole epoch.
        loss_tol += loss.item()
        loss_tol_vgg  += loss_vgg.item()
        loss_tol_l2   += loss_l2.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ( 'Epoch [{}/{}], Training Loss: {:.4f}, vgg Loss: {:.4f}, L2 Loss: {:.4f}' .format(epoch+1, num_epochs, loss_tol, loss_tol_vgg, loss_tol_l2) )

print("--- %0.4f seconds ---" % (time.time() - start_time)) 
torch.save(imageFilter.state_dict(), checkpoint_path)
files.download(checkpoint_path)

  input = Variable(torch.from_numpy(np.expand_dims(np.moveaxis(noise_im, -1, 0), axis=0)), volatile=True).to(device=self.device, dtype=torch.float)


Epoch [172/600], Training Loss: 2.8805, vgg Loss: 251.1579, L2 Loss: 0.3689
Epoch [173/600], Training Loss: 2.7482, vgg Loss: 243.9403, L2 Loss: 0.3087
Epoch [174/600], Training Loss: 2.5966, vgg Loss: 232.4969, L2 Loss: 0.2716
Epoch [175/600], Training Loss: 2.6914, vgg Loss: 242.8894, L2 Loss: 0.2625
Epoch [176/600], Training Loss: 2.6375, vgg Loss: 236.5473, L2 Loss: 0.2720
Epoch [177/600], Training Loss: 2.6431, vgg Loss: 236.7070, L2 Loss: 0.2760
Epoch [178/600], Training Loss: 2.6334, vgg Loss: 233.5445, L2 Loss: 0.2980
Epoch [179/600], Training Loss: 2.6579, vgg Loss: 235.6569, L2 Loss: 0.3013
Epoch [180/600], Training Loss: 2.7179, vgg Loss: 240.1791, L2 Loss: 0.3162


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [181/600], Training Loss: 2.6797, vgg Loss: 239.1833, L2 Loss: 0.2878
Epoch [182/600], Training Loss: 2.6446, vgg Loss: 235.2348, L2 Loss: 0.2922
Epoch [183/600], Training Loss: 2.6232, vgg Loss: 235.1160, L2 Loss: 0.2720
Epoch [184/600], Training Loss: 2.6246, vgg Loss: 233.3669, L2 Loss: 0.2909
Epoch [185/600], Training Loss: 2.7127, vgg Loss: 241.5767, L2 Loss: 0.2969
Epoch [186/600], Training Loss: 2.5903, vgg Loss: 233.4937, L2 Loss: 0.2553
Epoch [187/600], Training Loss: 2.6248, vgg Loss: 235.2117, L2 Loss: 0.2727
Epoch [188/600], Training Loss: 2.7209, vgg Loss: 242.1626, L2 Loss: 0.2993
Epoch [189/600], Training Loss: 2.7264, vgg Loss: 239.8883, L2 Loss: 0.3275
Epoch [190/600], Training Loss: 2.5535, vgg Loss: 230.6124, L2 Loss: 0.2473


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [191/600], Training Loss: 2.6489, vgg Loss: 234.6512, L2 Loss: 0.3024
Epoch [192/600], Training Loss: 2.6007, vgg Loss: 235.1931, L2 Loss: 0.2488
Epoch [193/600], Training Loss: 2.5707, vgg Loss: 231.9987, L2 Loss: 0.2507
Epoch [194/600], Training Loss: 2.7153, vgg Loss: 240.2122, L2 Loss: 0.3131
Epoch [195/600], Training Loss: 2.6754, vgg Loss: 237.2077, L2 Loss: 0.3034
Epoch [196/600], Training Loss: 2.6506, vgg Loss: 235.7891, L2 Loss: 0.2927
Epoch [197/600], Training Loss: 2.6299, vgg Loss: 233.9118, L2 Loss: 0.2908
Epoch [198/600], Training Loss: 2.5965, vgg Loss: 231.6221, L2 Loss: 0.2803
Epoch [199/600], Training Loss: 2.6333, vgg Loss: 234.5751, L2 Loss: 0.2875
Epoch [200/600], Training Loss: 2.6314, vgg Loss: 233.3569, L2 Loss: 0.2978


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [201/600], Training Loss: 2.5525, vgg Loss: 227.6976, L2 Loss: 0.2755
Epoch [202/600], Training Loss: 2.5677, vgg Loss: 227.8026, L2 Loss: 0.2897
Epoch [203/600], Training Loss: 2.5408, vgg Loss: 229.1649, L2 Loss: 0.2491
Epoch [204/600], Training Loss: 2.6703, vgg Loss: 236.1878, L2 Loss: 0.3084
Epoch [205/600], Training Loss: 2.6547, vgg Loss: 233.4007, L2 Loss: 0.3207
Epoch [206/600], Training Loss: 2.5335, vgg Loss: 228.2722, L2 Loss: 0.2508
Epoch [207/600], Training Loss: 2.5209, vgg Loss: 226.4546, L2 Loss: 0.2564
Epoch [208/600], Training Loss: 2.4592, vgg Loss: 222.1073, L2 Loss: 0.2381
Epoch [209/600], Training Loss: 2.5602, vgg Loss: 228.5112, L2 Loss: 0.2751
Epoch [210/600], Training Loss: 2.5884, vgg Loss: 230.8684, L2 Loss: 0.2798


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [211/600], Training Loss: 2.5634, vgg Loss: 230.3627, L2 Loss: 0.2597
Epoch [212/600], Training Loss: 2.6355, vgg Loss: 234.6770, L2 Loss: 0.2888
Epoch [213/600], Training Loss: 2.5721, vgg Loss: 231.0187, L2 Loss: 0.2619
Epoch [214/600], Training Loss: 2.5723, vgg Loss: 232.2348, L2 Loss: 0.2500
Epoch [215/600], Training Loss: 2.5236, vgg Loss: 228.1907, L2 Loss: 0.2417
Epoch [216/600], Training Loss: 2.5139, vgg Loss: 226.1651, L2 Loss: 0.2523
Epoch [217/600], Training Loss: 2.5356, vgg Loss: 228.7294, L2 Loss: 0.2483
Epoch [218/600], Training Loss: 2.6251, vgg Loss: 230.6628, L2 Loss: 0.3185
Epoch [219/600], Training Loss: 2.5007, vgg Loss: 223.5518, L2 Loss: 0.2651
Epoch [220/600], Training Loss: 2.6494, vgg Loss: 233.1225, L2 Loss: 0.3181


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [221/600], Training Loss: 2.5747, vgg Loss: 228.1846, L2 Loss: 0.2928
Epoch [222/600], Training Loss: 2.5641, vgg Loss: 228.6595, L2 Loss: 0.2775


KeyboardInterrupt: ignored

In [None]:
# Build a fresh MFFNet and load the weights stored as '<model_name>.ckpt' on
# Drive for inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFil = MFFNet()

model_name = 'MFF-net'
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU runtime also loads on a CPU-only one.
imageFil.load_state_dict( torch.load('/content/drive/My Drive/image_filtering/%s.ckpt'%(model_name), map_location=device) )
imageFil = imageFil.to(device).float()
# This model is only used for inference below; switch it to evaluation mode.
imageFil.eval()

In [None]:
# Run the trained network on the test sequences 1..5 and write, for each one,
# the network output and the (tone-mapped) RGB input as PNG files.
data_root = '/content/drive/My Drive/image_filtering/test'
out_root = '/content/drive/My Drive/image_filtering/output'
# Creates missing parent folders too and is a no-op if the folder exists.
os.makedirs(out_root, exist_ok=True)

# the parameter for color balance and brightness should be tuned for different scenes
color_gains = (1.1, 1, 1.5)   # per-channel gains for output channels 0, 1, 2
brightness = 1.5              # global gain applied on top of the channel gains

for seq in range(1,6):
    # input rgb image is obtained from demosaicing the raw (no other manipulation)
    # saved in 16-bit TIFF image
    rgb_path = os.path.join( data_root, 'rgb_%s.tiff' % (seq) )
    inputs = io.imread(rgb_path) / 65535

    guide_path = os.path.join( data_root, 'guide_%s.bmp' % (seq) )
    guided = io.imread(guide_path) / 255

    # Collapse the guide image to one channel by summing its first three channels.
    guided = guided[:,:,0]+guided[:,:,1]+guided[:,:,2]
    # Gain of 80 followed by a 0.4 power curve on the RGB input.
    inputs = (inputs*80)**0.4
    # Stack RGB + guide into the 4-channel layout expected by the network,
    # then convert HWC -> CHW and add a batch dimension.
    inputs = np.concatenate((inputs, guided[:,:,None]), 2)
    inputs = np.transpose(inputs,(2,0,1))
    inputs = torch.from_numpy(inputs)
    inputs = inputs[None,:,:,:].float()

    with torch.no_grad():
        inputs = inputs.to(device) 
        outputs = imageFil(inputs)
        # Limit the prediction to [0, 1] before applying the gains.
        outputs.clamp_(0, 1)

    for channel, gain in enumerate(color_gains):
        outputs[0,channel,:,:] = outputs[0,channel,:,:]*gain*brightness

    save_image(outputs[0,:,:,:], '%s/out_%s.png' % (out_root, seq))
    save_image(inputs[0,0:3,:,:], '%s/inp_%s.png' % (out_root, seq))