In [1]:
import os, sys
import time, math
import argparse, random
from math import exp
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.backends import cudnn
from torch.autograd import Variable

import torchvision
import torchvision.transforms as tfs
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as FF
import torchvision.utils as vutils
from torchvision.utils import make_grid
from torchvision.models import vgg16

from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [2]:
# number of training steps
steps = 20000
# Device name
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# resume Training
resume = False
# number of evaluation steps
eval_step = 500
# learning rate
learning_rate = 0.0001
# pre-trained model directory
pretrained_model_dir = './trained_models/'
# directory to save models to
model_dir = './trained_models/'
# train data
trainset = 'its_train'
# test data
testset = 'its_test'
# model to be used
network = 'ffa'
# residual_groups
gps = 3
# residual_blocks
blocks = 12
# batch size
bs = 1
# crop image
crop = True
# Takes effect when crop = True
crop_size = 240
# No lr cos schedule
no_lr_sche = True
# perceptual loss
perloss = True

model_name = trainset + '_' + network.split('.')[0] + '_' + str(gps) + '_' + str(blocks)
pretrained_model_dir = pretrained_model_dir + model_name + '.pk'
model_dir = model_dir + model_name + '.pk'
log_dir = 'logs/' + model_name

if not os.path.exists('trained_models'):
    os.mkdir('trained_models')
if not os.path.exists('numpy_files'):
    os.mkdir('numpy_files')
if not os.path.exists('logs'):
    os.mkdir('logs')
if not os.path.exists('samples'):
    os.mkdir('samples')
if not os.path.exists(f"samples/{model_name}"):
    os.mkdir(f'samples/{model_name}')
if not os.path.exists(log_dir):
    os.mkdir(log_dir)
    
crop_size='whole_img'
if crop:
    crop_size = crop_size


In [3]:
!pip install einops
from functools import partial
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, expansion_factor = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention

class ScalableSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.,
        reduction_factor = 1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = ChanLayerNorm(dim)
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # similarity

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values

        out = torch.matmul(attn, v)

        # merge back heads

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)

class InteractiveWindowedSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        window_size,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.window_size = window_size
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = ChanLayerNorm(dim)
        self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size

        x = self.norm(x)

        wsz_h, wsz_w = default(wsz, height), default(wsz, width)
        assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # get output of LIM

        local_out = self.local_interactive_module(v)

        # divide into window (and split out heads) for efficient self attention

        q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))

        # similarity

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values

        out = torch.matmul(attn, v)

        # reshape the windows back to full feature map (and merge heads)

        out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)

        # add LIM output 

        out = out + local_out

        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        ff_expansion_factor = 4,
        dropout = 0.,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ssa_reduction_factor = 1,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        iwsa_window_size = None,
        norm_output = True
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_first = ind == 0

            self.layers.append(nn.ModuleList([
                ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                PEG(dim) if is_first else None,
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
            ]))

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for ssa, ff1, peg, iwsa, ff2 in self.layers:
            x = ssa(x) + x
            x = ff1(x) + x

            if exists(peg):
                x = peg(x)

            x = iwsa(x) + x
            x = ff2(x) + x

        return self.norm(x)

class ScalableViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        reduction_factor,
        window_size = None,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ff_expansion_factor = 4,
        channels = 64,
        dropout = 0.
    ):
        super().__init__()
        self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        num_stages = len(depth)
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))

        hyperparams_per_stage = [
            heads,
            ssa_dim_key,
            ssa_dim_value,
            reduction_factor,
            iwsa_dim_key,
            iwsa_dim_value,
            window_size,
        ]

        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        self.layers = nn.ModuleList([])

        for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            self.layers.append(nn.ModuleList([
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
                Downsample(layer_dim, layer_dim * 2) if not is_last else None
            ]))

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, img):
        x = self.to_patches(img)
        #print(x.shape,self.layers)
        x_fuses = []

        for transformer, downsample in self.layers:
            x = transformer(x)
            x_fuses.append(x)
            #print('bb ',x.shape)

            if exists(downsample):
                x = downsample(x)
                #print('xx ',x.shape)
        #print(x.shape)
        
        return x_fuses#self.mlp_head(x)

Collecting einops
  Obtaining dependency information for einops from https://files.pythonhosted.org/packages/29/0b/2d1c0ebfd092e25935b86509a9a817159212d82aa43d7fb07eca4eeff2c2/einops-0.7.0-py3-none-any.whl.metadata
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [None]:
import torch
from torch import nn
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.nn import init
import functools
from torch.optim import lr_scheduler
# from scalable_vit import ScalableViT
###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):

    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):

    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
    return scheduler


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)

def conv_layers(inp, oup, dilation):
    #if dilation:
    d_rate = dilation
    #else:
    #    d_rate = 1
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=3, padding=d_rate, dilation=d_rate),
        nn.ReLU(inplace=True)
    )


def feature_transform(inp, oup):
    conv2d = nn.Conv2d(inp, oup, kernel_size=1)  # no padding
    relu = nn.ReLU(inplace=True)
    layers = []
    layers += [conv2d, relu]
    return nn.Sequential(*layers)


def pool_layers(ceil_mode=True):
    return nn.MaxPool2d(kernel_size=5, stride=2)

class block_V(nn.Module):
    """
    Compact Dilation Convolution based Module
    """
    def __init__(self, in_channels):
        super(block_V, self).__init__()
        

        self.conv0_0 = conv_layers(in_channels, in_channels,1)


        self.pool0 = pool_layers()
        self.conv1_0 = conv_layers(in_channels, in_channels,2)
        self.conv1_1 = conv_layers(in_channels, in_channels,2)

        self.pool1 = pool_layers()
        self.conv2_0 = conv_layers(in_channels, in_channels,4)
        self.conv2_1 = conv_layers(in_channels, in_channels,4)
        
        self.pool2 = pool_layers()
        self.conv3_0 = conv_layers(in_channels, in_channels,8)
        self.conv3_1 = conv_layers(in_channels, in_channels,8)

        
      
        
    def forward(self, x):
        H, W = x.size()[2:]

        x = self.conv0_0(x)

        x = self.pool0(x)
        x = self.conv1_0(x)      
        x1 = self.conv1_1(x)
       
        x = self.pool1(x1)      
        x = self.conv2_0(x)
        x2 = self.conv2_1(x)
             
        x = self.pool2(x2)
        x = self.conv3_0(x)       
        x3 = self.conv3_1(x)

        return [x1,x2,x3]

class block_O(nn.Module):
    """
    Compact Dilation Convolution based Module
    """
    def __init__(self, in_channels):
        super(block_O, self).__init__()

        self.conv1_0 = conv_layers(in_channels, in_channels,2)
        self.conv1_1 = conv_layers(in_channels, in_channels,2)
   
        self.conv2_0 = conv_layers(in_channels, in_channels,2)
        self.conv2_1 = conv_layers(in_channels, in_channels,2)
              
        self.conv3_0 = conv_layers(in_channels, in_channels,2)
        self.conv3_1 = conv_layers(in_channels, in_channels,2)
        
    def forward(self, x):

        x = self.conv1_0(x)      
        x1 = self.conv1_1(x)
             
        x = self.conv2_0(x1)
        x2 = self.conv2_1(x)
             
        x = self.conv3_0(x2)       
        x3 = self.conv3_1(x)
        
        return [x1,x2,x3]
        
class blockbn(nn.Module):
    """
    Compact Dilation Convolution based Module
    """
    def __init__(self, inn):
        super(blockbn, self).__init__()
        #self.norm_layer = get_norm_layer(norm_type='batch')
        norm_layer = get_norm_layer(norm_type='batch')
        

        self.conv0_0= ScalableViT(
            num_classes = 1000,
            dim = inn,                               # starting model dimension. at every stage, dimension is doubled
            heads = (2, 4, 8, 16),                  # number of attention heads at each stage
            depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage
            ssa_dim_key = (4, 4, 4, 4),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
            reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
            window_size = (16, 16, None, None),     # window size of the IWSA at each stage. None means no windowing needed
            dropout = 0.1,                          # attention and feedforward dropout
            )



        #self.pool1 = pool_layers()
        self.c0 = nn.Conv2d(32, inn, 3, padding=1)
        self.c1 = nn.Conv2d(64, inn, 3, padding=1)
        self.c2 = nn.Conv2d(128, inn, 3, padding=1)
        self.c3 = nn.Conv2d(256, inn, 3, padding=1)
        self.c4 = nn.Conv2d(512, inn, 3, padding=1)

        #self.conv2_0 = ScalableSelfAttention(inn, heads = 2, dim_key = 40, dim_value = 32, reduction_factor = 8, dropout = 0.1)
  
       
        self.classifier = nn.Conv2d(inn*3, 3, kernel_size=1)
    def _conv_block(self, in_nc, out_nc, norm_layer, num_block=1, kernel_size=3,stride=1, padding=2,bias=False):
        conv = []
        for i in range(num_block):
            cur_in_nc = in_nc if i == 0 else out_nc
            conv += [nn.Conv2d(cur_in_nc, out_nc, kernel_size=kernel_size, stride=stride, 
                               padding=padding, dilation=padding, bias=False),
                     norm_layer(out_nc),
                     nn.ReLU(True)]
        return conv
        
    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)
        #print(x.shape)
        H, W = x.size()[2:]
        if x.size()[1] == 128:
            
            x = self.cc2(x)
            
        
        [x1,x2,x3,x4] = self.conv0_0(x)
       
        x1 = F.interpolate(x1, (H, W), mode="bilinear", align_corners=False)
        

        
        x2 = F.interpolate(x2, (H, W), mode="bilinear", align_corners=False)


        x3 = F.interpolate(x3, (H, W), mode="bilinear", align_corners=False)
        x4 = F.interpolate(x4, (H, W), mode="bilinear", align_corners=False)
        #x3=x1+x2
        x1=self.c1(x1)
        x2=self.c2(x2)
        x3=self.c3(x3)
        x4=self.c4(x4)
        #print(x1.shape,x2.shape,x3.shape,x4.shape,)

        return [x1,x2,x3,x4]


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class FFA(nn.Module):
    def __init__(self,gps = gps, blocks = blocks):
        super(FFA, self).__init__()

        self.bn=blockbn(64)
        self.bn1=blockbn(128)
        self.bn2=blockbn(128)
        self.conv1_1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv1_3 = nn.Conv2d(64, 128, 3, padding=1)
        
        self.conv2_1 = nn.Conv2d(64,128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 128, 3, padding=1)
        
        #self.conv4_1 = nn.Conv2d(128, 128, 3, padding=1)
        #self.conv4_2 = nn.Conv2d(128, 128, 3, padding=1)

        self.convo1 = nn.Conv2d(128, 64, 3, padding=1)
        self.convo2 = nn.Conv2d(64, 32, 3, padding=1)
        self.convo3 = nn.Conv2d(32, 3, 3, padding=1)
        self.tt = nn.Conv2d(128, 32, 3, padding=1)
       
        self.p = nn.MaxPool2d(3, stride=2)
        
        self.relu = nn.ReLU()
        
        self.score_final = nn.Conv2d(12, 3, 1)
        
        nn.init.constant_(self.score_final.weight, 0.25)
        nn.init.constant_(self.score_final.bias, 0)

        print('initialization done')
        
    def _conv_block(self, in_nc, out_nc, norm_layer, num_block=1, kernel_size=1,stride=1,bias=False):
        conv = []
        for i in range(num_block):
            cur_in_nc = in_nc if i == 0 else out_nc
            conv += [nn.Conv2d(cur_in_nc, out_nc, kernel_size=kernel_size, stride=stride, bias=False),
                     norm_layer(out_nc),
                     nn.ReLU(True)]
        return conv
    def get_weights(self):
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        H1, W1 = x.size()[2:]
        conv1 = self.conv1_1(x)
        conv2 = self.conv1_2(conv1)
        
        [z1,z2,z3,z4] = self.bn(conv2)
        y1= self.conv2_1(z1+conv2)
        y2= self.conv2_1(z1+z2)
        y3= self.conv2_1(z2+z3)
        y4= self.conv2_1(z1+z2+z3+z4)
        
        t1=self.conv2_2(y1+y2)
        t2=self.conv2_2(y2+y3)
        t3=self.conv2_2(y3+y4)

        #print(y1.shape,y2.shape,y3.shape,y4.shape,t1.shape,t2.shape,t3.shape,(t1+t2+t3).shape)
        #orig = self.conv1_3(conv2)
        fuse0=self.conv3_1(t1+t2+t3)
        fuse1=self.conv3_2(fuse0)
        fuse2 = self.convo1(fuse1)#fuse1)
        fuse2 = self.convo2(fuse2)#fuse1)
        #print(fuse2.shape)
        fuse3 = self.convo3(fuse2)
        fuse = F.interpolate(fuse3, (H1, W1), mode="bilinear", align_corners=False)

        return fuse#results

# model = FFA()
# x = torch.randn(1, 3, 256, 256)
# print(model(x).shape)


In [5]:
# --- Perceptual loss network  --- #
class PerLoss(torch.nn.Module):
    def __init__(self, vgg_model):
        super(PerLoss, self).__init__()
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, dehaze, gt):
        loss = []
        dehaze_features = self.output_features(dehaze)
        gt_features = self.output_features(gt)
        for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
            loss.append(F.mse_loss(dehaze_feature, gt_feature))

        return sum(loss)/len(loss)

In [6]:
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size=11, size_average=True):
    img1=torch.clamp(img1,min=0,max=1)
    img2=torch.clamp(img2,min=0,max=1)
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)

def psnr(pred, gt):
    pred=pred.clamp(0,1).cpu().numpy()
    gt=gt.clamp(0,1).cpu().numpy()
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10( 1.0 / rmse)

In [7]:
class RESIDE_Dataset(data.Dataset):
    def __init__(self, path, train, size=crop_size, format='.png'):
        super(RESIDE_Dataset, self).__init__()
        self.size = size
        self.train = train
        self.format = format
        self.haze_imgs_dir = os.listdir(os.path.join(path,'hazy'))
        self.haze_imgs = [os.path.join(path, 'hazy', img) for img in self.haze_imgs_dir]
        self.clear_dir = os.path.join(path,'clear')
        
    def __getitem__(self, index):
        haze = Image.open(self.haze_imgs[index])
        haze=haze.resize((256, 256), Image.BICUBIC)
        if isinstance(self.size, int):
            while haze.size[0] < self.size or haze.size[1] < self.size :
                index = random.randint(0, 2000)
                haze = Image.open(self.haze_imgs[index])
        img = self.haze_imgs[index]
        id = img.split('/')[-1]#.split('_')[0]
        clear_name = id #+ self.format
        clear_name=clear_name.split('\\')[-1]
    
        
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        clear =clear.resize((256, 256), Image.BICUBIC)
        #print(self.haze_imgs[index],os.path.join(self.clear_dir, clear_name))
        clear = tfs.CenterCrop(haze.size[::-1])(clear)
        if not isinstance(self.size, str):
            i, j, h, w = tfs.RandomCrop.get_params(haze, output_size=(self.size, self.size))
            haze = FF.crop(haze, i, j, h, w)
            clear = FF.crop(clear, i, j, h, w)
        haze, clear = self.augData(haze.convert("RGB"), clear.convert("RGB") )
        return haze, clear
    
    def augData(self, data, target):
        if self.train:
            rand_hor = random.randint(0,1)
            rand_rot = random.randint(0,3)
            data = tfs.RandomHorizontalFlip(rand_hor)(data)
            target = tfs.RandomHorizontalFlip(rand_hor)(target)
            if rand_rot:
                data = FF.rotate(data, 90*rand_rot)
                target = FF.rotate(target, 90*rand_rot)
        data = tfs.ToTensor()(data)
        data = tfs.Normalize(mean=[0.64,0.6,0.58], std=[0.14,0.15,0.152])(data)
        target = tfs.ToTensor()(target)
        return data, target

    def __len__(self):
        return len(self.haze_imgs)


# path to your 'data' folder
its_train_path = '/kaggle/input/o-haze'
its_test_path = '/kaggle/input/o-haze'

ITS_train_loader = DataLoader(dataset=RESIDE_Dataset(its_train_path, train=True, size=crop_size), batch_size=bs, shuffle=True)
ITS_test_loader = DataLoader(dataset=RESIDE_Dataset(its_test_path, train=False, size='whole img'), batch_size=1, shuffle=False)

In [8]:
print('log_dir :', log_dir)
print('model_name:', model_name)

models_ = {'ffa': FFA(gps = gps, blocks = blocks)}
loaders_ = {'its_train': ITS_train_loader, 'its_test': ITS_test_loader}
# loaders_ = {'its_train': ITS_train_loader, 'its_test': ITS_test_loader, 'ots_train': OTS_train_loader, 'ots_test': OTS_test_loader}
start_time = time.time()
T = steps

def train(net, loader_train, loader_test, optim, criterion):
    losses = []
    start_step = 0
    max_ssim = max_psnr = 0
    ssims, psnrs = [], []
    if resume and os.path.exists(pretrained_model_dir):
        print(f'resume from {pretrained_model_dir}')
        ckp = torch.load(pretrained_model_dir)
        losses = ckp['losses']
        net.load_state_dict(ckp['model'])
        start_step = ckp['step']
        max_ssim = ckp['max_ssim']
        max_psnr = ckp['max_psnr']
        psnrs = ckp['psnrs']
        ssims = ckp['ssims']
        print(f'Resuming training from step: {start_step} ***')
    else :
        print('Training from scratch *** ')
    for step in range(start_step+1, steps+1):
        net.train()
        lr = learning_rate
        if not no_lr_sche:
            lr = lr_schedule_cosdecay(step,T)
            for param_group in optim.param_groups:
                param_group["lr"] = lr
        x, y = next(iter(loader_train))
        x = x.to(device); y = y.to(device)
        out = net(x)
        loss = criterion[0](out,y)
        if perloss:
            loss2 = criterion[1](out,y)
            loss = loss + 0.04*loss2

        loss.backward()

        optim.step()
        optim.zero_grad()
        losses.append(loss.item())
        print(f'\rtrain loss: {loss.item():.5f} | step: {step}/{steps} | lr: {lr :.7f} | time_used: {(time.time()-start_time)/60 :.1f}',end='',flush=True)

        if step % eval_step ==0 :
            with torch.no_grad():
                ssim_eval, psnr_eval = test(net, loader_test, max_psnr, max_ssim, step)
            print(f'\nstep: {step} | ssim: {ssim_eval:.4f} | psnr: {psnr_eval:.4f}')

            ssims.append(ssim_eval)
            psnrs.append(psnr_eval)
            if ssim_eval > max_ssim and psnr_eval > max_psnr :
                max_ssim = max(max_ssim,ssim_eval)
                max_psnr = max(max_psnr,psnr_eval)
                torch.save({
                            'step': step,
                            'max_psnr': max_psnr,
                            'max_ssim': max_ssim,
                            'ssims': ssims,
                            'psnrs': psnrs,
                            'losses': losses,
                            'model': net.state_dict()
                }, model_dir)
                print(f'\n model saved at step : {step} | max_psnr: {max_psnr:.4f} | max_ssim: {max_ssim:.4f}')

    np.save(f'./numpy_files/{model_name}_{steps}_losses.npy',losses)
    np.save(f'./numpy_files/{model_name}_{steps}_ssims.npy',ssims)
    np.save(f'./numpy_files/{model_name}_{steps}_psnrs.npy',psnrs)

def test(net, loader_test, max_psnr, max_ssim, step):
    net.eval()
    torch.cuda.empty_cache()
    ssims, psnrs = [], []
    for i, (inputs, targets) in enumerate(loader_test):
        inputs = inputs.to(device); targets = targets.to(device)
        pred = net(inputs)
        # # print(pred)
        # tfs.ToPILImage()(torch.squeeze(targets.cpu())).save('111.png')
        # vutils.save_image(targets.cpu(),'target.png')
        # vutils.save_image(pred.cpu(),'pred.png')
        ssim1 = ssim(pred, targets).item()
        psnr1 = psnr(pred, targets)
        ssims.append(ssim1)
        psnrs.append(psnr1)
        #if (psnr1>max_psnr or ssim1 > max_ssim) and s :
#             ts=vutils.make_grid([torch.squeeze(inputs.cpu()),torch.squeeze(targets.cpu()),torch.squeeze(pred.clamp(0,1).cpu())])
#             vutils.save_image(ts,f'samples/{model_name}/{step}_{psnr1:.4}_{ssim1:.4}.png')
#             s=False
    return np.mean(ssims) ,np.mean(psnrs)


log_dir : logs/its_train_ffa_3_12
model_name: its_train_ffa_3_12
initialization done


In [9]:
%%time

# ckp = torch.load(model_dir, map_location=device)
# net = FFA(gps=gps, blocks=blocks)
# net = nn.DataParallel(net)
# net.load_state_dict(ckp['model'])
loader_train = loaders_[trainset]
loader_test = loaders_[testset]
net = models_[network]
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
criterion = []
criterion.append(nn.L1Loss().to(device))
if perloss:
    vgg_model = vgg16(pretrained=True).features[:16]
    vgg_model = vgg_model.to(device)
    for param in vgg_model.parameters():
        param.requires_grad = False
    criterion.append(PerLoss(vgg_model).to(device))
optimizer = optim.Adam(params = filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate, betas=(0.9,0.999), eps=1e-08)
optimizer.zero_grad()
train(net, loader_train, loader_test, optimizer, criterion)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:01<00:00, 339MB/s]


Training from scratch *** 
train loss: 0.08983 | step: 500/20000 | lr: 0.0001000 | time_used: 6.7
step: 500 | ssim: 0.7238 | psnr: 18.8582

 model saved at step : 500 | max_psnr: 18.8582 | max_ssim: 0.7238
train loss: 0.08118 | step: 1000/20000 | lr: 0.0001000 | time_used: 13.8
step: 1000 | ssim: 0.7828 | psnr: 19.1610

 model saved at step : 1000 | max_psnr: 19.1610 | max_ssim: 0.7828
train loss: 0.06218 | step: 1500/20000 | lr: 0.0001000 | time_used: 20.9
step: 1500 | ssim: 0.8104 | psnr: 20.9885

 model saved at step : 1500 | max_psnr: 20.9885 | max_ssim: 0.8104
train loss: 0.09905 | step: 2000/20000 | lr: 0.0001000 | time_used: 27.8
step: 2000 | ssim: 0.8215 | psnr: 21.5018

 model saved at step : 2000 | max_psnr: 21.5018 | max_ssim: 0.8215
train loss: 0.08462 | step: 2500/20000 | lr: 0.0001000 | time_used: 34.8
step: 2500 | ssim: 0.8252 | psnr: 21.5874

 model saved at step : 2500 | max_psnr: 21.5874 | max_ssim: 0.8252
train loss: 0.06083 | step: 3000/20000 | lr: 0.0001000 | time_

In [10]:
# its or ots
task = 'its'
# test imgs folder
test_imgs = '/kaggle/input/o-haze/hazyd/'

dataset = task
img_dir = test_imgs

output_dir = f'pred_FFA_{dataset}/'
print("pred_dir:",output_dir)

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

ckp = torch.load(model_dir, map_location=device)
net = FFA(gps=gps, blocks=blocks)
net = nn.DataParallel(net)
net.load_state_dict(ckp['model'])
net.eval()

for im in os.listdir(img_dir):
    haze = Image.open(img_dir+im)
    haze=haze.resize((256, 256), Image.BICUBIC)
    haze1 = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize(mean=[0.64, 0.6, 0.58],std=[0.14,0.15, 0.152])
    ])(haze)[None,::]
    haze_no = tfs.ToTensor()(haze)[None,::]
    with torch.no_grad():
        pred = net(haze1)
    ts = torch.squeeze(pred.clamp(0,1).cpu())
    # tensorShow([haze_no, pred.clamp(0,1).cpu()],['haze', 'pred'])
    
    haze_no = make_grid(haze_no, nrow=1, normalize=True)
    ts = make_grid(ts, nrow=1, normalize=True)
    image_grid = torch.cat((haze_no, ts), -1)
    vutils.save_image(image_grid, output_dir+im.split('.')[0]+'_FFA.png')

pred_dir: pred_FFA_its/
initialization done


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/o-haze/hazyd/'