In [1]:
import os
import sys

from google.colab import drive

# Mount Google Drive into the Colab filesystem (force_remount=True re-mounts
# even when a mount is already present).
drive.mount('/content/drive', force_remount=True)

# Put the project folder on the module search path so the notebook can import
# the local modules (utils, datasets, vgg) that live there.
sys.path.append(os.path.abspath('/content/drive/My Drive/image_filtering/'))

Mounted at /content/drive


In [2]:
from __future__ import print_function, division
import torchvision
import torch
from skimage import io, transform
import numpy as np
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import cv2
from math import log10, pi
import time
from google.colab import files

import utils
from datasets import RedFlashDataset
from vgg import Vgg16

from torchvision.utils import save_image

In [3]:
class MFFNet(torch.nn.Module):
    """Encoder / residual-trunk / decoder CNN with concatenation skip links.

    The network takes a 4-channel image batch and returns a 3-channel one:

    * encoder: three ``ConvLayer`` stages (4->32, 32->64 with stride 2,
      64->128 with stride 2), each followed by ReLU;
    * trunk: 16 ``ResidualBlock(128)`` modules applied in sequence;
    * decoder: two ``UpsampleConvLayer`` stages (x2 nearest upsampling each)
      and a final 9x9 ``ConvLayer``; before every decoder stage the current
      features are concatenated along the channel axis with the encoder
      output of matching resolution, which is why the decoder input widths
      are doubled.

    Because of the two stride-2 stages and the two x2 upsamplings, height and
    width need to be multiples of 4 for the skip concatenations to line up.

    The ``in1`` .. ``in5`` InstanceNorm2d modules are constructed but are not
    called in ``forward``. They are retained because their affine parameters
    are part of the module's ``state_dict``.
    """

    def __init__(self):
        super().__init__()

        # --- encoder ---
        self.conv1 = ConvLayer(4, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # --- residual trunk: registered as res1 .. res16, in that order ---
        for block_idx in range(1, 17):
            setattr(self, 'res%d' % block_idx, ResidualBlock(128))

        # --- decoder (input widths doubled by the skip concatenations) ---
        self.deconv1 = UpsampleConvLayer(128*2, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64*2, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32*2, 3, kernel_size=9, stride=1)

        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run the network on ``X`` and return the 3-channel prediction."""
        # Encoder activations are kept for the skip connections below.
        skip_full = self.relu(self.conv1(X))
        skip_half = self.relu(self.conv2(skip_full))
        skip_quarter = self.relu(self.conv3(skip_half))

        # Residual trunk: res1 -> res2 -> ... -> res16.
        features = skip_quarter
        for block_idx in range(1, 17):
            features = getattr(self, 'res%d' % block_idx)(features)

        # Decoder: concatenate with the matching encoder output, then decode.
        features = self.relu(self.deconv1(torch.cat((features, skip_quarter), 1)))
        features = self.relu(self.deconv2(torch.cat((features, skip_half), 1)))
        # The last layer has no activation.
        return self.deconv3(torch.cat((features, skip_full), 1))

class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then passed through an unpadded ``Conv2d``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the square convolution kernel.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        return self.conv2d(self.reflection_pad(x))

class ResidualBlock(torch.nn.Module):
    """Two 3x3 ``ConvLayer`` stages with an identity shortcut.

    ``forward`` computes ``conv2(relu(conv1(x))) + x``. The channel count is
    the same on input and output.

    The ``in1`` / ``in2`` InstanceNorm2d modules are constructed but are not
    called in ``forward``; their affine parameters still appear in the
    module's ``state_dict``.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the conv branch output added to the unchanged input."""
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a padded convolution.

    When ``upsample`` is truthy the input is first resized with
    ``interpolate(mode='nearest', scale_factor=upsample)``. The result is then
    reflection-padded by ``kernel_size // 2`` and convolved.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the square convolution kernel.
        stride: stride of the convolution.
        upsample: scale factor for the nearest-neighbour resize; ``None``
            (the default) or any falsy value skips the resize.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then pad by reflection and convolve."""
        resized = x
        if self.upsample:
            resized = torch.nn.functional.interpolate(
                resized, mode='nearest', scale_factor=self.upsample)
        return self.conv2d(self.reflection_pad(resized))

In [4]:
# --- Data: one dataset, split 70 % / 30 % into training and validation ---
redflash_dataset = RedFlashDataset(
    '/content/drive/My Drive/image_filtering/Deep Learning Final Project/images',
    True,
)
# Sizes are derived from the dataset's fileID list.
train_size = int(0.7 * len(redflash_dataset.fileID))
val_size = len(redflash_dataset.fileID) - train_size
print(train_size, val_size)

train_dataset, val_dataset = torch.utils.data.random_split(
    redflash_dataset, [train_size, val_size])

# Both loaders use batches of 16, reshuffle every epoch and load in the main
# process (num_workers=0).
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=True, num_workers=0)

# --- Model: run on the GPU when one is available, otherwise on the CPU ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFilter = MFFNet()
model_name = 'MFF-net'
# To resume from a saved checkpoint, uncomment:
# imageFilter.load_state_dict(torch.load(f'/content/drive/My Drive/image_filtering/MFF-net-nd.ckpt'))
imageFilter = imageFilter.to(device).float()

# --- VGG16 feature extractor for the perceptual loss ---
VGG = Vgg16(requires_grad=False).to(device)

1014 435


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [None]:
# Training configuration.
num_epochs = 600
learning_rate = 1e-4

# Learning-rate milestones: {epoch index (0-based): learning rate used from
# that epoch onward}. With num_epochs = 600 the loop runs epochs 0..599, so
# the 600 entry only takes effect if num_epochs is raised above 600.
lr_milestones = {300: 1e-5, 600: 1e-6}

# Single place for the checkpoint location used by every save/download below.
checkpoint_path = '/content/drive/My Drive/image_filtering/MFF-net-nd.ckpt'

# Pixel-space loss and feature-space (VGG) loss are both mean squared errors.
criterion_img = nn.MSELoss()
criterion_vgg = nn.MSELoss()

optimizer = torch.optim.Adam(imageFilter.parameters(), lr=learning_rate)
total_step = len(train_loader)  # not read below; kept in case other cells use it

start_epoch = 0

start_time = time.time()
for epoch in range(start_epoch, num_epochs):
    # Per-epoch running sums, kept as plain Python floats.
    loss_tol = 0.0
    loss_tol_vgg = 0.0
    loss_tol_l2 = 0.0

    # Checkpoint (and download from Colab) every 20 epochs, except at epoch 0.
    if epoch > 0 and epoch % 20 == 0:
        torch.save(imageFilter.state_dict(), checkpoint_path)
        files.download(checkpoint_path)

    # Step the learning rate when a milestone epoch is reached.
    if epoch in lr_milestones:
        learning_rate = lr_milestones[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for i, im in enumerate(train_loader):
        # Each batch is indexed as (network input, target image).
        inputs = im[0].float().to(device)
        target = im[1].float().to(device)

        outputs = imageFilter(inputs)

        loss_l2 = criterion_img(outputs, target)

        # Perceptual loss: MSE between VGG features of prediction and target,
        # summed over feature indices 0..3.
        outputs_n = utils.normalize_ImageNet_stats(outputs)
        target_n = utils.normalize_ImageNet_stats(target)

        feature_o = VGG(outputs_n, 3)
        feature_t = VGG(target_n, 3)
        loss_vgg = sum(criterion_vgg(feature_o[l], feature_t[l]) for l in range(3 + 1))

        loss = loss_l2 + 0.01 * loss_vgg

        # Accumulate with .item(): adding the loss tensors themselves would
        # chain every iteration's autograd history into the running sums and
        # keep it referenced until the end of the epoch.
        loss_tol += loss.item()
        loss_tol_vgg += loss_vgg.item()
        loss_tol_l2 += loss_l2.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ( 'Epoch [{}/{}], Training Loss: {:.4f}, vgg Loss: {:.4f}, L2 Loss: {:.4f}' .format(epoch+1, num_epochs, loss_tol, loss_tol_vgg, loss_tol_l2) )

print("--- %0.4f seconds ---" % (time.time() - start_time))
# Final checkpoint after the last epoch.
torch.save(imageFilter.state_dict(), checkpoint_path)
files.download(checkpoint_path)

Epoch [1/600], Training Loss: 11.5563, vgg Loss: 933.6595, L2 Loss: 2.2197
Epoch [2/600], Training Loss: 7.2331, vgg Loss: 685.7779, L2 Loss: 0.3753
Epoch [3/600], Training Loss: 5.7115, vgg Loss: 534.6409, L2 Loss: 0.3651
Epoch [4/600], Training Loss: 5.0596, vgg Loss: 476.6694, L2 Loss: 0.2929
Epoch [5/600], Training Loss: 4.6662, vgg Loss: 440.3136, L2 Loss: 0.2631
Epoch [6/600], Training Loss: 4.4699, vgg Loss: 420.1012, L2 Loss: 0.2688
Epoch [7/600], Training Loss: 4.3311, vgg Loss: 408.6003, L2 Loss: 0.2451
Epoch [8/600], Training Loss: 4.0230, vgg Loss: 380.2422, L2 Loss: 0.2206
Epoch [9/600], Training Loss: 4.0003, vgg Loss: 378.7560, L2 Loss: 0.2127
Epoch [10/600], Training Loss: 3.8848, vgg Loss: 367.2649, L2 Loss: 0.2121
Epoch [11/600], Training Loss: 3.7086, vgg Loss: 351.5185, L2 Loss: 0.1934
Epoch [12/600], Training Loss: 3.6544, vgg Loss: 344.5465, L2 Loss: 0.2089
Epoch [13/600], Training Loss: 3.5948, vgg Loss: 341.0427, L2 Loss: 0.1844
Epoch [14/600], Training Loss: 3.

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [21/600], Training Loss: 3.0972, vgg Loss: 296.2436, L2 Loss: 0.1347
Epoch [22/600], Training Loss: 3.0356, vgg Loss: 288.8590, L2 Loss: 0.1470
Epoch [23/600], Training Loss: 2.9808, vgg Loss: 285.3190, L2 Loss: 0.1276
Epoch [24/600], Training Loss: 2.9318, vgg Loss: 280.5728, L2 Loss: 0.1261
Epoch [25/600], Training Loss: 2.9498, vgg Loss: 282.0716, L2 Loss: 0.1291
Epoch [26/600], Training Loss: 2.9815, vgg Loss: 285.4479, L2 Loss: 0.1271
Epoch [27/600], Training Loss: 2.9040, vgg Loss: 279.5026, L2 Loss: 0.1090
Epoch [28/600], Training Loss: 2.8773, vgg Loss: 275.7560, L2 Loss: 0.1197
Epoch [29/600], Training Loss: 2.8390, vgg Loss: 271.7835, L2 Loss: 0.1211
Epoch [30/600], Training Loss: 2.8230, vgg Loss: 271.6785, L2 Loss: 0.1062
Epoch [31/600], Training Loss: 2.7263, vgg Loss: 262.0810, L2 Loss: 0.1055
Epoch [32/600], Training Loss: 2.7307, vgg Loss: 262.2357, L2 Loss: 0.1084
Epoch [33/600], Training Loss: 2.7164, vgg Loss: 260.7654, L2 Loss: 0.1087
Epoch [34/600], Training 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [41/600], Training Loss: 2.6713, vgg Loss: 257.5684, L2 Loss: 0.0956
Epoch [42/600], Training Loss: 2.6029, vgg Loss: 250.8204, L2 Loss: 0.0947
Epoch [43/600], Training Loss: 2.5732, vgg Loss: 248.6102, L2 Loss: 0.0871
Epoch [44/600], Training Loss: 2.5268, vgg Loss: 243.9189, L2 Loss: 0.0876
Epoch [45/600], Training Loss: 2.5234, vgg Loss: 243.6022, L2 Loss: 0.0874
Epoch [46/600], Training Loss: 2.5446, vgg Loss: 245.9496, L2 Loss: 0.0851
Epoch [47/600], Training Loss: 2.5819, vgg Loss: 249.6548, L2 Loss: 0.0854
Epoch [48/600], Training Loss: 2.5650, vgg Loss: 247.6674, L2 Loss: 0.0883
Epoch [49/600], Training Loss: 2.4630, vgg Loss: 238.2375, L2 Loss: 0.0806
Epoch [50/600], Training Loss: 2.4876, vgg Loss: 238.8417, L2 Loss: 0.0991
Epoch [51/600], Training Loss: 2.4843, vgg Loss: 240.4035, L2 Loss: 0.0803
Epoch [52/600], Training Loss: 2.5281, vgg Loss: 244.4685, L2 Loss: 0.0834
Epoch [53/600], Training Loss: 2.5375, vgg Loss: 244.2849, L2 Loss: 0.0947
Epoch [54/600], Training 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [61/600], Training Loss: 2.3208, vgg Loss: 224.6367, L2 Loss: 0.0744
Epoch [62/600], Training Loss: 2.3952, vgg Loss: 231.7842, L2 Loss: 0.0773
Epoch [63/600], Training Loss: 2.3827, vgg Loss: 230.0897, L2 Loss: 0.0818
Epoch [64/600], Training Loss: 2.3859, vgg Loss: 230.0778, L2 Loss: 0.0851
Epoch [65/600], Training Loss: 2.3926, vgg Loss: 231.8032, L2 Loss: 0.0746
Epoch [66/600], Training Loss: 2.3290, vgg Loss: 226.5315, L2 Loss: 0.0637
Epoch [67/600], Training Loss: 2.3301, vgg Loss: 225.4751, L2 Loss: 0.0754
Epoch [68/600], Training Loss: 2.3278, vgg Loss: 225.1710, L2 Loss: 0.0760
Epoch [69/600], Training Loss: 2.3574, vgg Loss: 228.5060, L2 Loss: 0.0723
Epoch [70/600], Training Loss: 2.3295, vgg Loss: 225.7002, L2 Loss: 0.0725
Epoch [71/600], Training Loss: 2.3319, vgg Loss: 226.2174, L2 Loss: 0.0697
Epoch [72/600], Training Loss: 2.2895, vgg Loss: 222.1776, L2 Loss: 0.0677
Epoch [73/600], Training Loss: 2.2910, vgg Loss: 222.2459, L2 Loss: 0.0685
Epoch [74/600], Training 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [81/600], Training Loss: 2.2198, vgg Loss: 215.6718, L2 Loss: 0.0631
Epoch [82/600], Training Loss: 2.2771, vgg Loss: 219.7778, L2 Loss: 0.0793
Epoch [83/600], Training Loss: 2.1869, vgg Loss: 212.1122, L2 Loss: 0.0658
Epoch [84/600], Training Loss: 2.2036, vgg Loss: 213.5495, L2 Loss: 0.0681
Epoch [85/600], Training Loss: 2.2121, vgg Loss: 214.1334, L2 Loss: 0.0708
Epoch [86/600], Training Loss: 2.2445, vgg Loss: 217.8952, L2 Loss: 0.0655
Epoch [87/600], Training Loss: 2.2177, vgg Loss: 215.3435, L2 Loss: 0.0642
Epoch [88/600], Training Loss: 2.2127, vgg Loss: 214.4559, L2 Loss: 0.0682
Epoch [89/600], Training Loss: 2.2132, vgg Loss: 214.8472, L2 Loss: 0.0648
Epoch [90/600], Training Loss: 2.1228, vgg Loss: 206.3576, L2 Loss: 0.0593
Epoch [91/600], Training Loss: 2.1356, vgg Loss: 207.1765, L2 Loss: 0.0639
Epoch [92/600], Training Loss: 2.1582, vgg Loss: 209.3665, L2 Loss: 0.0645
Epoch [93/600], Training Loss: 2.1036, vgg Loss: 204.2016, L2 Loss: 0.0616
Epoch [94/600], Training 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [101/600], Training Loss: 2.1253, vgg Loss: 206.8781, L2 Loss: 0.0565
Epoch [102/600], Training Loss: 2.1421, vgg Loss: 207.9642, L2 Loss: 0.0625
Epoch [103/600], Training Loss: 2.1112, vgg Loss: 205.3547, L2 Loss: 0.0577
Epoch [104/600], Training Loss: 2.1320, vgg Loss: 207.2634, L2 Loss: 0.0594
Epoch [105/600], Training Loss: 2.1576, vgg Loss: 209.7170, L2 Loss: 0.0604
Epoch [106/600], Training Loss: 2.1305, vgg Loss: 206.9461, L2 Loss: 0.0611
Epoch [107/600], Training Loss: 2.1038, vgg Loss: 204.4311, L2 Loss: 0.0595
Epoch [108/600], Training Loss: 2.0913, vgg Loss: 203.1441, L2 Loss: 0.0599
Epoch [109/600], Training Loss: 2.1052, vgg Loss: 204.1129, L2 Loss: 0.0641
Epoch [110/600], Training Loss: 2.0697, vgg Loss: 201.2137, L2 Loss: 0.0576
Epoch [111/600], Training Loss: 2.1246, vgg Loss: 206.8163, L2 Loss: 0.0564
Epoch [112/600], Training Loss: 2.0952, vgg Loss: 203.3494, L2 Loss: 0.0617
Epoch [113/600], Training Loss: 2.0719, vgg Loss: 201.9991, L2 Loss: 0.0519
Epoch [114/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [121/600], Training Loss: 2.0569, vgg Loss: 199.7825, L2 Loss: 0.0591
Epoch [122/600], Training Loss: 2.0376, vgg Loss: 198.2166, L2 Loss: 0.0554
Epoch [123/600], Training Loss: 2.0410, vgg Loss: 198.4445, L2 Loss: 0.0565
Epoch [124/600], Training Loss: 1.9852, vgg Loss: 193.1863, L2 Loss: 0.0533
Epoch [125/600], Training Loss: 2.0673, vgg Loss: 201.2149, L2 Loss: 0.0552
Epoch [126/600], Training Loss: 2.0059, vgg Loss: 195.6562, L2 Loss: 0.0494
Epoch [127/600], Training Loss: 2.1094, vgg Loss: 204.2231, L2 Loss: 0.0671
Epoch [128/600], Training Loss: 2.0465, vgg Loss: 198.8950, L2 Loss: 0.0575
Epoch [129/600], Training Loss: 2.0389, vgg Loss: 198.8577, L2 Loss: 0.0504
Epoch [130/600], Training Loss: 1.9978, vgg Loss: 194.9067, L2 Loss: 0.0488
Epoch [131/600], Training Loss: 2.0372, vgg Loss: 198.2379, L2 Loss: 0.0549
Epoch [132/600], Training Loss: 2.0701, vgg Loss: 200.9349, L2 Loss: 0.0607
Epoch [133/600], Training Loss: 2.0278, vgg Loss: 197.0298, L2 Loss: 0.0575
Epoch [134/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [141/600], Training Loss: 1.9491, vgg Loss: 189.7445, L2 Loss: 0.0517
Epoch [142/600], Training Loss: 2.0116, vgg Loss: 195.1289, L2 Loss: 0.0603
Epoch [143/600], Training Loss: 2.0155, vgg Loss: 196.2849, L2 Loss: 0.0526
Epoch [144/600], Training Loss: 1.9381, vgg Loss: 188.9018, L2 Loss: 0.0491
Epoch [145/600], Training Loss: 1.9465, vgg Loss: 189.9302, L2 Loss: 0.0472
Epoch [146/600], Training Loss: 1.9627, vgg Loss: 190.8764, L2 Loss: 0.0539
Epoch [147/600], Training Loss: 1.9469, vgg Loss: 189.8692, L2 Loss: 0.0482
Epoch [148/600], Training Loss: 1.9682, vgg Loss: 191.6403, L2 Loss: 0.0517
Epoch [149/600], Training Loss: 1.9920, vgg Loss: 193.9537, L2 Loss: 0.0525
Epoch [150/600], Training Loss: 1.9411, vgg Loss: 189.1895, L2 Loss: 0.0492
Epoch [151/600], Training Loss: 1.9479, vgg Loss: 189.6544, L2 Loss: 0.0514
Epoch [152/600], Training Loss: 2.0192, vgg Loss: 195.6980, L2 Loss: 0.0623
Epoch [153/600], Training Loss: 1.9857, vgg Loss: 193.4897, L2 Loss: 0.0508
Epoch [154/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [161/600], Training Loss: 1.9907, vgg Loss: 193.2356, L2 Loss: 0.0583
Epoch [162/600], Training Loss: 1.9564, vgg Loss: 190.3137, L2 Loss: 0.0532
Epoch [163/600], Training Loss: 1.9922, vgg Loss: 193.4031, L2 Loss: 0.0582
Epoch [164/600], Training Loss: 2.0196, vgg Loss: 196.4139, L2 Loss: 0.0554
Epoch [165/600], Training Loss: 1.9032, vgg Loss: 185.3521, L2 Loss: 0.0497
Epoch [166/600], Training Loss: 1.9294, vgg Loss: 187.5685, L2 Loss: 0.0537
Epoch [167/600], Training Loss: 1.8777, vgg Loss: 183.2142, L2 Loss: 0.0455
Epoch [168/600], Training Loss: 1.9469, vgg Loss: 189.6611, L2 Loss: 0.0503
Epoch [169/600], Training Loss: 1.8732, vgg Loss: 182.6618, L2 Loss: 0.0465
Epoch [170/600], Training Loss: 1.8697, vgg Loss: 182.5990, L2 Loss: 0.0437
Epoch [171/600], Training Loss: 1.9025, vgg Loss: 185.6637, L2 Loss: 0.0459
Epoch [172/600], Training Loss: 1.8911, vgg Loss: 184.0429, L2 Loss: 0.0507
Epoch [173/600], Training Loss: 1.9114, vgg Loss: 186.0200, L2 Loss: 0.0512
Epoch [174/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [181/600], Training Loss: 1.8988, vgg Loss: 184.9813, L2 Loss: 0.0490
Epoch [182/600], Training Loss: 1.9324, vgg Loss: 188.2636, L2 Loss: 0.0498
Epoch [183/600], Training Loss: 1.9948, vgg Loss: 193.8800, L2 Loss: 0.0560
Epoch [184/600], Training Loss: 1.8879, vgg Loss: 183.8911, L2 Loss: 0.0490
Epoch [185/600], Training Loss: 1.9161, vgg Loss: 186.9598, L2 Loss: 0.0465
Epoch [186/600], Training Loss: 1.8695, vgg Loss: 182.5023, L2 Loss: 0.0445
Epoch [187/600], Training Loss: 1.9412, vgg Loss: 189.0012, L2 Loss: 0.0512
Epoch [188/600], Training Loss: 1.8701, vgg Loss: 182.2346, L2 Loss: 0.0477
Epoch [189/600], Training Loss: 1.8765, vgg Loss: 182.9290, L2 Loss: 0.0472
Epoch [190/600], Training Loss: 1.9139, vgg Loss: 186.7927, L2 Loss: 0.0459
Epoch [191/600], Training Loss: 1.8720, vgg Loss: 182.6627, L2 Loss: 0.0454
Epoch [192/600], Training Loss: 1.9014, vgg Loss: 185.6053, L2 Loss: 0.0454
Epoch [193/600], Training Loss: 1.8637, vgg Loss: 181.4019, L2 Loss: 0.0497
Epoch [194/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [201/600], Training Loss: 1.8415, vgg Loss: 179.4422, L2 Loss: 0.0471
Epoch [202/600], Training Loss: 1.8642, vgg Loss: 181.1709, L2 Loss: 0.0525
Epoch [203/600], Training Loss: 1.8611, vgg Loss: 181.5173, L2 Loss: 0.0459
Epoch [204/600], Training Loss: 1.8894, vgg Loss: 184.4031, L2 Loss: 0.0454
Epoch [205/600], Training Loss: 1.8331, vgg Loss: 178.9824, L2 Loss: 0.0433
Epoch [206/600], Training Loss: 1.9243, vgg Loss: 186.6570, L2 Loss: 0.0577
Epoch [207/600], Training Loss: 1.8670, vgg Loss: 182.1244, L2 Loss: 0.0457
Epoch [208/600], Training Loss: 1.8398, vgg Loss: 179.8654, L2 Loss: 0.0411
Epoch [209/600], Training Loss: 1.8281, vgg Loss: 178.1551, L2 Loss: 0.0465
Epoch [210/600], Training Loss: 1.8904, vgg Loss: 184.0487, L2 Loss: 0.0499
Epoch [211/600], Training Loss: 1.8578, vgg Loss: 180.6813, L2 Loss: 0.0510
Epoch [212/600], Training Loss: 1.7920, vgg Loss: 174.8145, L2 Loss: 0.0438
Epoch [213/600], Training Loss: 1.8279, vgg Loss: 178.4357, L2 Loss: 0.0435
Epoch [214/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [221/600], Training Loss: 1.8174, vgg Loss: 177.4225, L2 Loss: 0.0432
Epoch [222/600], Training Loss: 1.8128, vgg Loss: 176.8945, L2 Loss: 0.0438
Epoch [223/600], Training Loss: 1.7970, vgg Loss: 175.4481, L2 Loss: 0.0425
Epoch [224/600], Training Loss: 1.8295, vgg Loss: 178.6370, L2 Loss: 0.0431
Epoch [225/600], Training Loss: 1.8026, vgg Loss: 175.8191, L2 Loss: 0.0444
Epoch [226/600], Training Loss: 1.8353, vgg Loss: 179.2039, L2 Loss: 0.0433
Epoch [227/600], Training Loss: 1.8447, vgg Loss: 180.4374, L2 Loss: 0.0404
Epoch [228/600], Training Loss: 1.8954, vgg Loss: 184.2567, L2 Loss: 0.0528
Epoch [229/600], Training Loss: 1.8441, vgg Loss: 179.6826, L2 Loss: 0.0473
Epoch [230/600], Training Loss: 1.8527, vgg Loss: 180.9912, L2 Loss: 0.0428
Epoch [231/600], Training Loss: 1.8463, vgg Loss: 180.0589, L2 Loss: 0.0457
Epoch [232/600], Training Loss: 1.8435, vgg Loss: 178.9785, L2 Loss: 0.0538
Epoch [233/600], Training Loss: 1.8431, vgg Loss: 179.8038, L2 Loss: 0.0451
Epoch [234/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [241/600], Training Loss: 1.8293, vgg Loss: 178.5887, L2 Loss: 0.0434
Epoch [242/600], Training Loss: 1.7998, vgg Loss: 175.7234, L2 Loss: 0.0425
Epoch [243/600], Training Loss: 1.8180, vgg Loss: 177.7014, L2 Loss: 0.0410
Epoch [244/600], Training Loss: 1.8443, vgg Loss: 179.9693, L2 Loss: 0.0446
Epoch [245/600], Training Loss: 1.8166, vgg Loss: 177.3201, L2 Loss: 0.0434
Epoch [246/600], Training Loss: 1.7776, vgg Loss: 173.6227, L2 Loss: 0.0414
Epoch [247/600], Training Loss: 1.7597, vgg Loss: 171.8193, L2 Loss: 0.0415
Epoch [248/600], Training Loss: 1.8464, vgg Loss: 179.5715, L2 Loss: 0.0507
Epoch [249/600], Training Loss: 1.8024, vgg Loss: 175.9720, L2 Loss: 0.0426
Epoch [250/600], Training Loss: 1.8031, vgg Loss: 175.9715, L2 Loss: 0.0434
Epoch [251/600], Training Loss: 1.7452, vgg Loss: 170.7064, L2 Loss: 0.0381
Epoch [252/600], Training Loss: 1.8183, vgg Loss: 177.2831, L2 Loss: 0.0455
Epoch [253/600], Training Loss: 1.8481, vgg Loss: 179.6443, L2 Loss: 0.0516
Epoch [254/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [261/600], Training Loss: 1.7822, vgg Loss: 174.2456, L2 Loss: 0.0398
Epoch [262/600], Training Loss: 1.8388, vgg Loss: 178.8741, L2 Loss: 0.0500
Epoch [263/600], Training Loss: 1.7734, vgg Loss: 173.4106, L2 Loss: 0.0393
Epoch [264/600], Training Loss: 1.8112, vgg Loss: 176.7741, L2 Loss: 0.0434
Epoch [265/600], Training Loss: 1.8047, vgg Loss: 176.3627, L2 Loss: 0.0411
Epoch [266/600], Training Loss: 1.7884, vgg Loss: 174.5018, L2 Loss: 0.0434
Epoch [267/600], Training Loss: 1.7999, vgg Loss: 175.3584, L2 Loss: 0.0463
Epoch [268/600], Training Loss: 1.7911, vgg Loss: 174.4986, L2 Loss: 0.0461
Epoch [269/600], Training Loss: 1.8076, vgg Loss: 176.2291, L2 Loss: 0.0453
Epoch [270/600], Training Loss: 1.7992, vgg Loss: 175.7985, L2 Loss: 0.0412
Epoch [271/600], Training Loss: 1.7834, vgg Loss: 173.9717, L2 Loss: 0.0437
Epoch [272/600], Training Loss: 1.7940, vgg Loss: 174.7382, L2 Loss: 0.0466
Epoch [273/600], Training Loss: 1.7681, vgg Loss: 172.4726, L2 Loss: 0.0434
Epoch [274/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [281/600], Training Loss: 1.7871, vgg Loss: 174.1915, L2 Loss: 0.0451
Epoch [282/600], Training Loss: 1.7549, vgg Loss: 171.6912, L2 Loss: 0.0380
Epoch [283/600], Training Loss: 1.7459, vgg Loss: 170.5379, L2 Loss: 0.0405
Epoch [284/600], Training Loss: 1.7755, vgg Loss: 173.1501, L2 Loss: 0.0440
Epoch [285/600], Training Loss: 1.7804, vgg Loss: 173.7881, L2 Loss: 0.0425
Epoch [286/600], Training Loss: 1.7092, vgg Loss: 166.9778, L2 Loss: 0.0395
Epoch [287/600], Training Loss: 1.7567, vgg Loss: 171.7063, L2 Loss: 0.0396
Epoch [288/600], Training Loss: 1.7262, vgg Loss: 167.7046, L2 Loss: 0.0492
Epoch [289/600], Training Loss: 1.7374, vgg Loss: 169.7118, L2 Loss: 0.0402
Epoch [290/600], Training Loss: 1.7117, vgg Loss: 167.3385, L2 Loss: 0.0383
Epoch [291/600], Training Loss: 1.7626, vgg Loss: 171.7106, L2 Loss: 0.0455
Epoch [292/600], Training Loss: 1.7437, vgg Loss: 169.9451, L2 Loss: 0.0442
Epoch [293/600], Training Loss: 1.7714, vgg Loss: 173.1273, L2 Loss: 0.0401
Epoch [294/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [301/600], Training Loss: 1.6757, vgg Loss: 164.2875, L2 Loss: 0.0328
Epoch [302/600], Training Loss: 1.7124, vgg Loss: 167.9174, L2 Loss: 0.0332
Epoch [303/600], Training Loss: 1.6511, vgg Loss: 161.9526, L2 Loss: 0.0316
Epoch [304/600], Training Loss: 1.6901, vgg Loss: 165.7748, L2 Loss: 0.0323
Epoch [305/600], Training Loss: 1.7230, vgg Loss: 169.0007, L2 Loss: 0.0330
Epoch [306/600], Training Loss: 1.6778, vgg Loss: 164.5928, L2 Loss: 0.0319
Epoch [307/600], Training Loss: 1.6623, vgg Loss: 163.0844, L2 Loss: 0.0314
Epoch [308/600], Training Loss: 1.6516, vgg Loss: 162.0123, L2 Loss: 0.0315
Epoch [309/600], Training Loss: 1.6800, vgg Loss: 164.7685, L2 Loss: 0.0323
Epoch [310/600], Training Loss: 1.7116, vgg Loss: 167.8959, L2 Loss: 0.0326
Epoch [311/600], Training Loss: 1.7014, vgg Loss: 166.8943, L2 Loss: 0.0325
Epoch [312/600], Training Loss: 1.6620, vgg Loss: 163.0536, L2 Loss: 0.0314
Epoch [313/600], Training Loss: 1.6435, vgg Loss: 161.2491, L2 Loss: 0.0310
Epoch [314/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [321/600], Training Loss: 1.7031, vgg Loss: 167.0530, L2 Loss: 0.0326
Epoch [322/600], Training Loss: 1.6598, vgg Loss: 162.8181, L2 Loss: 0.0316
Epoch [323/600], Training Loss: 1.6761, vgg Loss: 164.4040, L2 Loss: 0.0320
Epoch [324/600], Training Loss: 1.6936, vgg Loss: 166.0426, L2 Loss: 0.0331
Epoch [325/600], Training Loss: 1.7024, vgg Loss: 166.9852, L2 Loss: 0.0326
Epoch [326/600], Training Loss: 1.6614, vgg Loss: 162.9764, L2 Loss: 0.0316
Epoch [327/600], Training Loss: 1.6695, vgg Loss: 163.7980, L2 Loss: 0.0315
Epoch [328/600], Training Loss: 1.6304, vgg Loss: 159.8929, L2 Loss: 0.0315
Epoch [329/600], Training Loss: 1.6698, vgg Loss: 163.7894, L2 Loss: 0.0319
Epoch [330/600], Training Loss: 1.6729, vgg Loss: 164.0600, L2 Loss: 0.0323
Epoch [331/600], Training Loss: 1.6860, vgg Loss: 165.3926, L2 Loss: 0.0321
Epoch [332/600], Training Loss: 1.6776, vgg Loss: 164.5397, L2 Loss: 0.0322
Epoch [333/600], Training Loss: 1.6458, vgg Loss: 161.4090, L2 Loss: 0.0317
Epoch [334/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [341/600], Training Loss: 1.6952, vgg Loss: 166.2509, L2 Loss: 0.0327
Epoch [342/600], Training Loss: 1.6583, vgg Loss: 162.6704, L2 Loss: 0.0316
Epoch [343/600], Training Loss: 1.6464, vgg Loss: 161.4667, L2 Loss: 0.0317
Epoch [344/600], Training Loss: 1.6570, vgg Loss: 162.5226, L2 Loss: 0.0318
Epoch [345/600], Training Loss: 1.6713, vgg Loss: 163.9286, L2 Loss: 0.0320
Epoch [346/600], Training Loss: 1.6825, vgg Loss: 164.9790, L2 Loss: 0.0327
Epoch [347/600], Training Loss: 1.6811, vgg Loss: 164.8654, L2 Loss: 0.0324
Epoch [348/600], Training Loss: 1.6535, vgg Loss: 162.1907, L2 Loss: 0.0315
Epoch [349/600], Training Loss: 1.6712, vgg Loss: 163.8576, L2 Loss: 0.0326
Epoch [350/600], Training Loss: 1.6420, vgg Loss: 161.0491, L2 Loss: 0.0315
Epoch [351/600], Training Loss: 1.6625, vgg Loss: 163.0517, L2 Loss: 0.0319
Epoch [352/600], Training Loss: 1.6393, vgg Loss: 160.7988, L2 Loss: 0.0313
Epoch [353/600], Training Loss: 1.6459, vgg Loss: 161.4538, L2 Loss: 0.0314
Epoch [354/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [361/600], Training Loss: 1.6844, vgg Loss: 165.2324, L2 Loss: 0.0321
Epoch [362/600], Training Loss: 1.6450, vgg Loss: 161.3398, L2 Loss: 0.0316
Epoch [363/600], Training Loss: 1.6690, vgg Loss: 163.7146, L2 Loss: 0.0319
Epoch [364/600], Training Loss: 1.6750, vgg Loss: 164.3456, L2 Loss: 0.0316
Epoch [365/600], Training Loss: 1.6575, vgg Loss: 162.5703, L2 Loss: 0.0318
Epoch [366/600], Training Loss: 1.6415, vgg Loss: 161.0323, L2 Loss: 0.0311
Epoch [367/600], Training Loss: 1.6447, vgg Loss: 161.3646, L2 Loss: 0.0310
Epoch [368/600], Training Loss: 1.6817, vgg Loss: 164.9697, L2 Loss: 0.0320
Epoch [369/600], Training Loss: 1.6800, vgg Loss: 164.8154, L2 Loss: 0.0319
Epoch [370/600], Training Loss: 1.6406, vgg Loss: 160.9430, L2 Loss: 0.0312
Epoch [371/600], Training Loss: 1.6840, vgg Loss: 165.1725, L2 Loss: 0.0323
Epoch [372/600], Training Loss: 1.7006, vgg Loss: 166.8328, L2 Loss: 0.0323
Epoch [373/600], Training Loss: 1.6688, vgg Loss: 163.6595, L2 Loss: 0.0322
Epoch [374/6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Epoch [381/600], Training Loss: 1.6163, vgg Loss: 158.5844, L2 Loss: 0.0304
Epoch [382/600], Training Loss: 1.6548, vgg Loss: 162.3361, L2 Loss: 0.0314
Epoch [383/600], Training Loss: 1.6767, vgg Loss: 164.4312, L2 Loss: 0.0324
Epoch [384/600], Training Loss: 1.6550, vgg Loss: 162.3834, L2 Loss: 0.0312
Epoch [385/600], Training Loss: 1.6533, vgg Loss: 162.1615, L2 Loss: 0.0316
Epoch [386/600], Training Loss: 1.6404, vgg Loss: 160.9347, L2 Loss: 0.0310
Epoch [387/600], Training Loss: 1.6754, vgg Loss: 164.3246, L2 Loss: 0.0322
Epoch [388/600], Training Loss: 1.6596, vgg Loss: 162.7931, L2 Loss: 0.0317
Epoch [389/600], Training Loss: 1.6364, vgg Loss: 160.5434, L2 Loss: 0.0310
Epoch [390/600], Training Loss: 1.6701, vgg Loss: 163.8172, L2 Loss: 0.0320


In [None]:
# Build a fresh network and restore trained weights for inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFil = MFFNet()

model_name = 'MFF-net'
# NOTE(review): this loads '<model_name>.ckpt' (MFF-net.ckpt), while the
# training cell writes 'MFF-net-nd.ckpt' -- confirm which checkpoint is meant.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when only the CPU is available.
imageFil.load_state_dict(
    torch.load('/content/drive/My Drive/image_filtering/%s.ckpt'%(model_name),
               map_location=device)
)
imageFil = imageFil.to(device).float()
# Inference only from here on: switch the module to evaluation mode.
imageFil.eval()

In [None]:
# Run the restored network on the test sequences and write PNG results.
data_root = '/content/drive/My Drive/image_filtering/test'
out_root = '/content/drive/My Drive/image_filtering/output'
# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(out_root, exist_ok=True)

# The parameters for color balance and brightness should be tuned for
# different scenes. Each output channel c is multiplied by
# channel_gain[c] and then by brightness_gain.
channel_gain = (1.1, 1, 1.5)
brightness_gain = 1.5

for seq in range(1,6):
    # input rgb image is obtained from demosaicing the raw (no other manipulation)
    # saved in 16-bit TIFF image
    file = ('rgb_%s.tiff' % (seq) )
    filename = os.path.join( data_root, file )
    rgb = io.imread(filename) / 65535

    # Guide image, read from an 8-bit BMP and scaled by 1/255.
    file = ('guide_%s.bmp' % (seq) )
    filename = os.path.join( data_root, file )
    guided = io.imread(filename) / 255

    # Collapse the guide to one plane by summing its first three channels.
    guided = guided[:,:,0]+guided[:,:,1]+guided[:,:,2]
    # Scale by 80 and raise to the power 0.4 before feeding the network.
    # NOTE(review): presumably mirrors the preprocessing in RedFlashDataset -- confirm.
    rgb = (rgb*80)**0.4

    # Stack rgb + guide into a 4-channel image, move channels first, and add
    # a leading batch axis.
    stacked = np.concatenate((rgb, guided[:,:,None]), 2)
    inputs = torch.from_numpy(np.transpose(stacked,(2,0,1)))
    inputs = inputs[None,:,:,:].float()

    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = imageFil(inputs)
    # Restrict the prediction to [0, 1] before applying the gains.
    outputs.clamp_(0, 1)

    for channel, gain in enumerate(channel_gain):
        outputs[0,channel,:,:] = outputs[0,channel,:,:]*gain*brightness_gain

    # Save the prediction and the first three input channels side by side on disk.
    save_image(outputs[0,:,:,:], '%s/out_%s.png' % (out_root, seq))
    save_image(inputs[0,0:3,:,:], '%s/inp_%s.png' % (out_root, seq))