<a href="https://colab.research.google.com/github/vfrantc/weather_experiments/blob/main/run_gcanet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/cddlyf/GCANet

This is the implementation of our WACV 2019 paper "Gated Context Aggregation Network for Image Dehazing and Deraining" by Dongdong Chen, Mingming He, Qingnan Fan, et al.

In [1]:
import os
import argparse
import numpy as np
from PIL import Image
from glob import glob

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
from matplotlib import pyplot as plt
%matplotlib inline

In [3]:
from google.colab import drive

In [4]:
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
!git clone https://github.com/cddlyf/GCANet.git

Cloning into 'GCANet'...
remote: Enumerating objects: 28, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 28 (delta 0), reused 1 (delta 0), pack-reused 25[K
Unpacking objects: 100% (28/28), done.


In [6]:
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def edge_compute(x):
    x_diffx = torch.abs(x[:,:,1:] - x[:,:,:-1])
    x_diffy = torch.abs(x[:,1:,:] - x[:,:-1,:])

    y = x.new(x.size())
    y.fill_(0)
    y[:,:,1:] += x_diffx
    y[:,:,:-1] += x_diffx
    y[:,1:,:] += x_diffy
    y[:,:-1,:] += x_diffy
    y = torch.sum(y,0,keepdim=True)/3
    y /= 4
    return y

In [7]:
class ShareSepConv(nn.Module):
    def __init__(self, kernel_size):
        super(ShareSepConv, self).__init__()
        assert kernel_size % 2 == 1, 'kernel size should be odd'
        self.padding = (kernel_size - 1)//2
        weight_tensor = torch.zeros(1, 1, kernel_size, kernel_size)
        weight_tensor[0, 0, (kernel_size-1)//2, (kernel_size-1)//2] = 1
        self.weight = nn.Parameter(weight_tensor)
        self.kernel_size = kernel_size

    def forward(self, x):
        inc = x.size(1)
        expand_weight = self.weight.expand(inc, 1, self.kernel_size, self.kernel_size).contiguous()
        return F.conv2d(x, expand_weight,
                        None, 1, self.padding, 1, inc)

In [8]:
class SmoothDilatedResidualBlock(nn.Module):
    def __init__(self, channel_num, dilation=1, group=1):
        super(SmoothDilatedResidualBlock, self).__init__()
        self.pre_conv1 = ShareSepConv(dilation*2-1)
        self.conv1 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group, bias=False)
        self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)
        self.pre_conv2 = ShareSepConv(dilation*2-1)
        self.conv2 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group, bias=False)
        self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)

    def forward(self, x):
        y = F.relu(self.norm1(self.conv1(self.pre_conv1(x))))
        y = self.norm2(self.conv2(self.pre_conv2(y)))
        return F.relu(x+y)

In [9]:
class ResidualBlock(nn.Module):
    def __init__(self, channel_num, dilation=1, group=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group, bias=False)
        self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)
        self.conv2 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group, bias=False)
        self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)

    def forward(self, x):
        y = F.relu(self.norm1(self.conv1(x)))
        y = self.norm2(self.conv2(y))
        return F.relu(x+y)

In [10]:
class GCANet(nn.Module):
    def __init__(self, in_c=4, out_c=3, only_residual=True):
        super(GCANet, self).__init__()
        self.conv1 = nn.Conv2d(in_c, 64, 3, 1, 1, bias=False)
        self.norm1 = nn.InstanceNorm2d(64, affine=True)
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        self.norm2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = nn.Conv2d(64, 64, 3, 2, 1, bias=False)
        self.norm3 = nn.InstanceNorm2d(64, affine=True)

        self.res1 = SmoothDilatedResidualBlock(64, dilation=2)
        self.res2 = SmoothDilatedResidualBlock(64, dilation=2)
        self.res3 = SmoothDilatedResidualBlock(64, dilation=2)
        self.res4 = SmoothDilatedResidualBlock(64, dilation=4)
        self.res5 = SmoothDilatedResidualBlock(64, dilation=4)
        self.res6 = SmoothDilatedResidualBlock(64, dilation=4)
        self.res7 = ResidualBlock(64, dilation=1)

        self.gate = nn.Conv2d(64 * 3, 3, 3, 1, 1, bias=True)

        self.deconv3 = nn.ConvTranspose2d(64, 64, 4, 2, 1)
        self.norm4 = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.norm5 = nn.InstanceNorm2d(64, affine=True)
        self.deconv1 = nn.Conv2d(64, out_c, 1)
        self.only_residual = only_residual

    def forward(self, x):
        y = F.relu(self.norm1(self.conv1(x)))
        y = F.relu(self.norm2(self.conv2(y)))
        y1 = F.relu(self.norm3(self.conv3(y)))

        y = self.res1(y1)
        y = self.res2(y)
        y = self.res3(y)
        y2 = self.res4(y)
        y = self.res5(y2)
        y = self.res6(y)
        y3 = self.res7(y)

        gates = self.gate(torch.cat((y1, y2, y3), dim=1))
        gated_y = y1 * gates[:, [0], :, :] + y2 * gates[:, [1], :, :] + y3 * gates[:, [2], :, :]
        y = F.relu(self.norm4(self.deconv3(gated_y)))
        y = F.relu(self.norm5(self.deconv2(y)))
        if self.only_residual:
            y = self.deconv1(y)
        else:
            y = F.relu(self.deconv1(y))

        return y

# Run

In [12]:
!ls GCANet/models

wacv_gcanet_dehaze.pth	wacv_gcanet_derain.pth


In [33]:
indir = 'examples'
outdir = 'output'
#os.mkdir(outdir)
model = 'GCANet/models/wacv_gcanet_dehaze.pth'
only_residual = True
net = GCANet(in_c=4, out_c=3, only_residual=True)
torch.cuda.set_device(0)
net.cuda()
net.load_state_dict(torch.load(model, map_location='cpu'))
net.eval()

GCANet(
  (conv1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (norm3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): SmoothDilatedResidualBlock(
    (pre_conv1): ShareSepConv()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (pre_conv2): ShareSepConv()
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (norm2): InstanceNorm2

In [16]:
def restore_image(img_path):
  img = Image.open(img_path).convert('RGB')
  im_w, im_h = img.size
  if im_w % 4 != 0 or im_h % 4 != 0:
      img = img.resize((int(im_w // 4 * 4), int(im_h // 4 * 4))) 
  img = np.array(img).astype('float')
  img_data = torch.from_numpy(img.transpose((2, 0, 1))).float()
  edge_data = edge_compute(img_data)
  in_data = torch.cat((img_data, edge_data), dim=0).unsqueeze(0) - 128 
  in_data = in_data.cuda()
  with torch.no_grad():
      pred = net(Variable(in_data))

  out_img_data = (pred.data[0].cpu().float() + img_data).round().clamp(0, 255)
  out_img = out_img_data.numpy().astype(np.uint8).transpose(1, 2, 0)
  return out_img

In [17]:
!cp /content/drive/MyDrive/weather_experiments/weather_test.zip .
!unzip weather_test.zip 

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test/outdoor/input/0194_1_0.2.jpg  
  inflating: test/outdoor/input/1986_0.9_0.08.jpg  
  inflating: test/outdoor/input/0251_0.8_0.16.jpg  
  inflating: test/outdoor/input/1953_0.8_0.2.jpg  
  inflating: test/outdoor/input/0127_0.8_0.2.jpg  
  inflating: test/outdoor/input/0003_0.8_0.2.jpg  
  inflating: test/outdoor/input/1861_0.85_0.2.jpg  
  inflating: test/outdoor/input/0317_0.85_0.08.jpg  
  inflating: test/outdoor/input/0269_0.95_0.2.jpg  
  inflating: test/outdoor/input/1055_0.95_0.2.jpg  
  inflating: test/outdoor/input/0267_0.9_0.16.jpg  
  inflating: test/outdoor/input/1848_0.8_0.12.jpg  
  inflating: test/outdoor/input/0082_0.85_0.16.jpg  
  inflating: test/outdoor/input/0138_0.9_0.08.jpg  
  inflating: test/outdoor/input/1724_0.95_0.2.jpg  
  inflating: test/outdoor/input/1889_0.85_0.08.jpg  
  inflating: test/outdoor/input/0302_0.9_0.16.jpg  
  inflating: test/outdoor/input/1868_1_0.08.jpg  
  in

In [18]:
from tqdm.notebook import tqdm

In [19]:
def process_one(input_dir, output_dir):
  print(f'{input_dir} ----> {output_dir}')
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  for fname in tqdm(list(glob(os.path.join(input_dir, '*.jpg'))) + list(glob(os.path.join(input_dir, '*.png')))):
    image = restore_image(fname)
    reconstructed = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(output_dir, os.path.basename(fname)), image)

In [20]:
!mkdir gca_dehaze

In [22]:
!ls test

fog	     rain	     raindroptesta.txt	sand	    snowtest100k_L.txt
fog.txt      rain1400	     rain.txt		sand.txt    snow.txt
outdoor      rain1400.txt    real		snow	    test1
outdoor.txt  rain_drop_test  real.txt		Snow100K-L  test1.txt


In [34]:
!rm -rf gca_dehaze

In [35]:
process_one('test/fog/input/', 'gca_dehaze/fog')
process_one('test/outdoor/input/', 'gca_dehaze/outdoor')
process_one('test/rain/input/', 'gca_dehaze/rain')
process_one('test/real/input/', 'gca_dehaze/real')
process_one('test/sand/input/', 'gca_dehaze/sand')
process_one('test/snow/input/', 'gca_dehaze/snow')
process_one('test/test1/input/', 'gca_dehaze/test1')

test/fog/input/ ----> gca_dehaze/fog


  0%|          | 0/300 [00:00<?, ?it/s]

test/outdoor/input/ ----> gca_dehaze/outdoor


  0%|          | 0/500 [00:00<?, ?it/s]

test/rain/input/ ----> gca_dehaze/rain


  0%|          | 0/200 [00:00<?, ?it/s]

test/real/input/ ----> gca_dehaze/real


  0%|          | 0/497 [00:00<?, ?it/s]

test/sand/input/ ----> gca_dehaze/sand


  0%|          | 0/323 [00:00<?, ?it/s]

test/snow/input/ ----> gca_dehaze/snow


  0%|          | 0/204 [00:00<?, ?it/s]

test/test1/input/ ----> gca_dehaze/test1


  0%|          | 0/750 [00:00<?, ?it/s]

In [36]:
!zip -r gca_dehaze.zip gca_dehaze/
!cp gca_dehaze.zip /content/drive/MyDrive/weather_experiments/

updating: gca_dehaze/ (stored 0%)
updating: gca_dehaze/real/ (stored 0%)
updating: gca_dehaze/real/195.png (deflated 1%)
updating: gca_dehaze/real/663.png (deflated 3%)
updating: gca_dehaze/real/169.png (deflated 1%)
updating: gca_dehaze/real/367.png (deflated 0%)
updating: gca_dehaze/real/260.png (deflated 2%)
updating: gca_dehaze/real/559.png (deflated 1%)
updating: gca_dehaze/real/271.png (deflated 0%)
updating: gca_dehaze/real/003.png (deflated 1%)
updating: gca_dehaze/real/008.png (deflated 1%)
updating: gca_dehaze/real/113.png (deflated 2%)
updating: gca_dehaze/real/135.png (deflated 1%)
updating: gca_dehaze/real/631.png (deflated 1%)
updating: gca_dehaze/real/290.png (deflated 1%)
updating: gca_dehaze/real/257.png (deflated 1%)
updating: gca_dehaze/real/469.png (deflated 1%)
updating: gca_dehaze/real/350.png (deflated 1%)
updating: gca_dehaze/real/495.png (deflated 2%)
updating: gca_dehaze/real/214.png (deflated 1%)
updating: gca_dehaze/real/591.png (deflated 1%)
updating: gca_d

In [28]:
indir = 'examples'
outdir = 'output'
#os.mkdir(outdir)
model = 'GCANet/models/wacv_gcanet_derain.pth'
only_residual = True
net = GCANet(in_c=4, out_c=3, only_residual=True)
torch.cuda.set_device(0)
net.cuda()
net.load_state_dict(torch.load(model, map_location='cpu'))
net.eval()

GCANet(
  (conv1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (norm3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): SmoothDilatedResidualBlock(
    (pre_conv1): ShareSepConv()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (pre_conv2): ShareSepConv()
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (norm2): InstanceNorm2

In [29]:
!mkdir gca_derain

In [30]:
!ls test

fog	     rain	     raindroptesta.txt	sand	    snowtest100k_L.txt
fog.txt      rain1400	     rain.txt		sand.txt    snow.txt
outdoor      rain1400.txt    real		snow	    test1
outdoor.txt  rain_drop_test  real.txt		Snow100K-L  test1.txt


In [31]:
process_one('test/rain1400/input/', 'gca_derain/rain1400')
process_one('test/rain/input/', 'gca_derain/rain')
process_one('test/real/input/', 'gca_derain/real')
process_one('test/test1/input/', 'gca_derain/test1')

test/rain1400/input/ ----> gca_derain/rain1400


  0%|          | 0/1400 [00:00<?, ?it/s]

test/rain/input/ ----> gca_derain/rain


  0%|          | 0/200 [00:00<?, ?it/s]

test/real/input/ ----> gca_derain/real


  0%|          | 0/497 [00:00<?, ?it/s]

test/test1/input/ ----> gca_derain/test1


  0%|          | 0/750 [00:00<?, ?it/s]

In [32]:
!zip -r gca_rain.zip gca_derain/
!cp gca_rain.zip /content/drive/MyDrive/weather_experiments/

  adding: gca_derain/ (stored 0%)
  adding: gca_derain/real/ (stored 0%)
  adding: gca_derain/real/195.png (deflated 0%)
  adding: gca_derain/real/663.png (deflated 1%)
  adding: gca_derain/real/169.png (deflated 1%)
  adding: gca_derain/real/367.png (deflated 0%)
  adding: gca_derain/real/260.png (deflated 1%)
  adding: gca_derain/real/559.png (deflated 1%)
  adding: gca_derain/real/271.png (deflated 0%)
  adding: gca_derain/real/003.png (deflated 1%)
  adding: gca_derain/real/008.png (deflated 0%)
  adding: gca_derain/real/113.png (deflated 2%)
  adding: gca_derain/real/135.png (deflated 1%)
  adding: gca_derain/real/631.png (deflated 0%)
  adding: gca_derain/real/290.png (deflated 0%)
  adding: gca_derain/real/257.png (deflated 1%)
  adding: gca_derain/real/469.png (deflated 1%)
  adding: gca_derain/real/350.png (deflated 1%)
  adding: gca_derain/real/495.png (deflated 1%)
  adding: gca_derain/real/214.png (deflated 0%)
  adding: gca_derain/real/591.png (deflated 0%)
  adding: gca_d