In [2]:
import os, sys
import time, math
import argparse, random
from math import exp
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.backends import cudnn
from torch.autograd import Variable

import torchvision
import torchvision.transforms as tfs
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as FF
import torchvision.utils as vutils
from torchvision.utils import make_grid
from torchvision.models import vgg16

from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [5]:
from noise import pnoise3

In [3]:
from utils import *
from ffa_net import *

In [2]:
# ---------------------------------------------------------------------------
# Run configuration (everything a reader might tune lives here)
# ---------------------------------------------------------------------------
# number of training steps
steps = 20000
# Device name
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# resume Training (from the checkpoint at `pretrained_model_dir`)
resume = False
# evaluate on the test set every `eval_step` training steps
eval_step = 5000
# learning rate
learning_rate = 0.0001
# pre-trained model directory
pretrained_model_dir = '../input/ffa-net-for-single-image-dehazing-pytorch/trained_models/'
# train data
trainset = 'its_train'
# test data
testset = 'its_test'
# model to be used
network = 'ffa'
# residual_groups
gps = 3
# residual_blocks
blocks = 12
# batch size
bs = 1
# crop image
crop = True
# Takes effect when crop = True
crop_size = 240
# No lr cos schedule
no_lr_sche = True
# perceptual loss
perloss = True

# Use `crop_size` only when cropping is enabled; otherwise train on whole images.
crop_size = crop_size if crop else 'whole_img'

# Run identifiers derived from the settings above.
model_name = f'{trainset}_{network}_{gps}_{blocks}'
log_dir = f'logs/{model_name}'
# Checkpoint *file* written by train() and read by the inference cell:
# torch.save / torch.load need a file path, not a directory.
model_dir = os.path.join('./trained_models/', f'{model_name}.pk')
os.makedirs(os.path.dirname(model_dir), exist_ok=True)


In [None]:
# Dataset roots: RESIDE ITS for training, SOTS (indoor) for testing.
its_train_path = '../input/indoor-training-set-its-residestandard'
its_test_path = '../input/synthetic-objective-testing-set-sots-reside/indoor'

# Training pairs: shuffled, batch size `bs`, sized by `crop_size`
# (how `size` is applied is up to RESIDE_Dataset — see utils).
its_train_dataset = RESIDE_Dataset(its_train_path, train=True, size=crop_size)
ITS_train_loader = DataLoader(its_train_dataset, batch_size=bs, shuffle=True)

# Test pairs: evaluated one at a time, in a fixed order.
its_test_dataset = RESIDE_Dataset(its_test_path, train=False, size='whole img')
ITS_test_loader = DataLoader(its_test_dataset, batch_size=1, shuffle=False)

### Define Train / Test Functions

In [8]:
# Echo the run identifiers so the log records which run this is.
print('log_dir :', log_dir)
print('model_name:', model_name)

# Lookup tables keyed by the `network`, `trainset` and `testset` config values.
models_ = {
    'ffa': FFA(gps=gps, blocks=blocks),
}
loaders_ = {
    'its_train': ITS_train_loader,
    'its_test': ITS_test_loader,
}

# Reference point for the "time_used" value printed by train().
start_time = time.time()
# Total step count handed to lr_schedule_cosdecay(step, T).
T = steps

def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one is exhausted.

    Keeping one iterator alive per pass (instead of calling ``iter(loader)``
    every step) walks a shuffled loader epoch by epoch. It also avoids
    rebuilding the sampler/workers per step, and stops an unshuffled loader
    from returning its first batch every time.

    Raises:
        ValueError: if ``loader`` yields no batches at all.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError('training loader yielded no batches')


def train(net, loader_train, loader_test, optim, criterion):
    """Train ``net`` for ``steps`` steps, evaluating and checkpointing periodically.

    Reads notebook-level settings: ``resume``, ``pretrained_model_dir``,
    ``steps``, ``learning_rate``, ``no_lr_sche``, ``T``, ``device``,
    ``perloss``, ``eval_step``, ``model_dir``, ``model_name``, ``start_time``.

    Args:
        net: model to optimise (possibly wrapped in ``nn.DataParallel``).
        loader_train: iterable of ``(input, target)`` training batches.
        loader_test: loader handed to ``test`` for evaluation.
        optim: optimizer over ``net``'s trainable parameters.
        criterion: list of losses. ``criterion[0]`` is always applied;
            ``criterion[1]`` is added with weight 0.04 when ``perloss`` is set.

    Every ``eval_step`` steps the model is evaluated. When both SSIM and PSNR
    beat their best values so far, a checkpoint dict is saved to ``model_dir``.
    At the end, the loss/SSIM/PSNR histories are written to ``./numpy_files/``.
    """
    losses = []
    start_step = 0
    max_ssim = max_psnr = 0
    ssims, psnrs = [], []
    # NOTE(review): torch.load needs a checkpoint *file*; pretrained_model_dir is
    # configured as a directory path, so resume=True would fail as configured.
    if resume and os.path.exists(pretrained_model_dir):
        print(f'resume from {pretrained_model_dir}')
        ckp = torch.load(pretrained_model_dir, map_location=device)
        losses = ckp['losses']
        net.load_state_dict(ckp['model'])
        start_step = ckp['step']
        max_ssim = ckp['max_ssim']
        max_psnr = ckp['max_psnr']
        psnrs = ckp['psnrs']
        ssims = ckp['ssims']
        print(f'Resuming training from step: {start_step} ***')
    else:
        print('Training from scratch *** ')

    # Create output folders up front, so a missing directory fails immediately
    # instead of after the whole (multi-hour) run.
    os.makedirs('./numpy_files', exist_ok=True)
    checkpoint_dir = os.path.dirname(model_dir)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    batches = _infinite_batches(loader_train)
    for step in range(start_step + 1, steps + 1):
        net.train()  # test() leaves the model in eval mode
        lr = learning_rate
        if not no_lr_sche:
            lr = lr_schedule_cosdecay(step, T)
            for param_group in optim.param_groups:
                param_group["lr"] = lr
        x, y = next(batches)
        x = x.to(device)
        y = y.to(device)
        out = net(x)
        loss = criterion[0](out, y)
        if perloss:
            loss = loss + 0.04 * criterion[1](out, y)

        loss.backward()
        optim.step()
        optim.zero_grad()
        loss_value = loss.item()
        losses.append(loss_value)
        print(f'\rtrain loss: {loss_value:.5f} | step: {step}/{steps} | lr: {lr :.7f} | time_used: {(time.time()-start_time)/60 :.1f}', end='', flush=True)

        if step % eval_step == 0:
            with torch.no_grad():
                ssim_eval, psnr_eval = test(net, loader_test, max_psnr, max_ssim, step)
            print(f'\nstep: {step} | ssim: {ssim_eval:.4f} | psnr: {psnr_eval:.4f}')

            ssims.append(ssim_eval)
            psnrs.append(psnr_eval)
            # Checkpoint only when BOTH metrics improve on their best so far.
            if ssim_eval > max_ssim and psnr_eval > max_psnr:
                max_ssim, max_psnr = ssim_eval, psnr_eval
                torch.save({
                    'step': step,
                    'max_psnr': max_psnr,
                    'max_ssim': max_ssim,
                    'ssims': ssims,
                    'psnrs': psnrs,
                    'losses': losses,
                    'model': net.state_dict()
                }, model_dir)
                print(f'\n model saved at step : {step} | max_psnr: {max_psnr:.4f} | max_ssim: {max_ssim:.4f}')

    np.save(f'./numpy_files/{model_name}_{steps}_losses.npy', losses)
    np.save(f'./numpy_files/{model_name}_{steps}_ssims.npy', ssims)
    np.save(f'./numpy_files/{model_name}_{steps}_psnrs.npy', psnrs)

def test(net, loader_test, max_psnr, max_ssim, step):
    """Evaluate ``net`` on ``loader_test`` and return its mean SSIM and PSNR.

    Args:
        net: model to evaluate; it is switched to (and left in) eval mode.
        loader_test: iterable of ``(inputs, targets)`` batches.
        max_psnr, max_ssim, step: unused here; kept so existing calls from
            ``train`` keep working.

    Returns:
        ``(mean_ssim, mean_psnr)`` over all batches (``nan`` for an empty loader).
    """
    net.eval()
    torch.cuda.empty_cache()
    ssims, psnrs = [], []
    # Inference only: don't build an autograd graph, even if the caller forgot
    # to wrap this call in torch.no_grad().
    with torch.no_grad():
        for inputs, targets in loader_test:
            inputs = inputs.to(device)
            targets = targets.to(device)
            pred = net(inputs)
            ssims.append(ssim(pred, targets).item())
            psnrs.append(psnr(pred, targets))

    return np.mean(ssims), np.mean(psnrs)


log_dir : logs/its_train_ffa_3_12
model_name: its_train_ffa_3_12


### Train FFA-Net

In [9]:
%%time

loader_train = loaders_[trainset]
loader_test = loaders_[testset]

net = models_[network].to(device)
if device == 'cuda':
    # Split batches across visible GPUs and let cuDNN pick the fastest conv kernels.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# criterion[0]: pixel-wise L1 loss; criterion[1] (optional): VGG16-based perceptual loss.
criterion = [nn.L1Loss().to(device)]
if perloss:
    vgg_model = vgg16(pretrained=True).features[:16].to(device)
    # The VGG feature extractor is a fixed loss network; freeze it.
    for param in vgg_model.parameters():
        param.requires_grad = False
    criterion.append(PerLoss(vgg_model).to(device))

trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params=trainable_params, lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)
optimizer.zero_grad()
train(net, loader_train, loader_test, optimizer, criterion)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))


Training from scratch *** 
train loss: 0.10126 | step: 5000/20000 | lr: 0.0001000 | time_used: 103.0
step: 5000 | ssim: 0.8149 | psnr: 19.5125

 model saved at step : 5000 | max_psnr: 19.5125 | max_ssim: 0.8149
train loss: 0.05416 | step: 10000/20000 | lr: 0.0001000 | time_used: 208.3
step: 10000 | ssim: 0.8526 | psnr: 20.2081

 model saved at step : 10000 | max_psnr: 20.2081 | max_ssim: 0.8526
train loss: 0.04772 | step: 15000/20000 | lr: 0.0001000 | time_used: 313.8
step: 15000 | ssim: 0.8572 | psnr: 20.8563

 model saved at step : 15000 | max_psnr: 20.8563 | max_ssim: 0.8572
train loss: 0.06030 | step: 20000/20000 | lr: 0.0001000 | time_used: 419.2
step: 20000 | ssim: 0.8892 | psnr: 23.1375

 model saved at step : 20000 | max_psnr: 23.1375 | max_ssim: 0.8892
CPU times: user 4h 25min 5s, sys: 2h 27min 3s, total: 6h 52min 9s
Wall time: 7h 2min 4s


### Test FFA-Net

In [10]:
# its or ots (only used to name the output folder)
task = 'its'
# folder of hazy test images
test_imgs = '../input/synthetic-objective-testing-set-sots-reside/indoor/hazy/'

dataset = task
img_dir = test_imgs

output_dir = f'pred_FFA_{dataset}/'
print("pred_dir:",output_dir)
os.makedirs(output_dir, exist_ok=True)

# Load the checkpoint written by train(). Training wraps the model in
# nn.DataParallel only on CUDA, so parameter names may or may not carry a
# 'module.' prefix. Strip it and load into a plain FFA placed on `device`.
# (Wrapping CPU-resident weights in DataParallel fails at forward on a CUDA
# machine.)
ckp = torch.load(model_dir, map_location=device)
prefix = 'module.'
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in ckp['model'].items()
}
net = FFA(gps=gps, blocks=blocks).to(device)
net.load_state_dict(state_dict)
net.eval()

# Model-input normalisation; presumably matches RESIDE_Dataset's — TODO confirm in utils.
to_model_input = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
])

for im in sorted(os.listdir(img_dir)):
    # Close the file promptly and guarantee 3 channels for the 3-channel Normalize.
    with Image.open(os.path.join(img_dir, im)) as img:
        haze = img.convert('RGB')
    haze1 = to_model_input(haze).unsqueeze(0).to(device)
    haze_no = tfs.ToTensor()(haze)
    with torch.no_grad():
        pred = net(haze1)
    ts = pred.clamp(0, 1).cpu()[0]

    # Both halves are already in [0, 1]. Save them side by side as-is:
    # make_grid(normalize=True) would min-max stretch each image and
    # misrepresent the dehazed output.
    image_grid = torch.cat((haze_no, ts), -1)
    out_name = os.path.splitext(im)[0] + '_FFA.png'
    vutils.save_image(image_grid, os.path.join(output_dir, out_name))

pred_dir: pred_FFA_its/
