In [None]:
# Mount Google Drive so the dataset archive (MyDrive/data.zip) and the checkpoint
# directory (ckpt_path, under MyDrive) are reachable from this Colab runtime.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
!unzip /content/drive/MyDrive/data.zip

In [3]:
# Both splits live under the same extraction root; make_dataset() picks the
# split-specific list files (data/train_*.txt vs data/test_*.txt) below it.
train_raincityscapes_path = test_raincityscapes_path = "/content"

In [4]:
import datetime
import os
import torch
import os.path
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
import random
from PIL import Image
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np

In [5]:
class Compose(object):
    """Chain joint transforms over an (image, mask, depth) triple.

    Each transform is called as ``t(img, mask, depth)`` and must return the new
    triple, so geometric operations stay aligned across all three layers.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask, depth):
        # All three layers must start with the same spatial size.
        assert img.size == mask.size
        assert img.size == depth.size
        triple = (img, mask, depth)
        for step in self.transforms:
            new_img, new_mask, new_depth = step(*triple)
            triple = (new_img, new_mask, new_depth)
        return triple

class Resize(object):
    """Resize an (image, mask, depth) triple to one fixed size with bilinear filtering.

    ``size`` is given as (height, width). PIL's ``resize`` expects (width, height),
    so the order is reversed once here.
    """

    def __init__(self, size):
        self.size = tuple(reversed(size))  # stored as (w, h) for PIL

    def __call__(self, img, mask, depth):
        assert img.size == mask.size
        assert img.size == depth.size
        target_size = self.size
        return tuple(layer.resize(target_size, Image.BILINEAR) for layer in (img, mask, depth))


class RandomCrop(object):
    """Crop one random ``size`` x ``size`` square, identical for image, mask and depth."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, mask, depth):
        assert img.size == mask.size
        assert img.size == depth.size
        w, h = img.size
        if w < self.size or h < self.size:
            # random.randint would otherwise fail with an opaque "empty range" error.
            raise ValueError('RandomCrop size %d exceeds input size %dx%d (w x h)' % (self.size, w, h))
        x1 = random.randint(0, w - self.size)
        y1 = random.randint(0, h - self.size)
        # A single shared box keeps the three layers pixel-aligned.
        box = (x1, y1, x1 + self.size, y1 + self.size)
        return img.crop(box), mask.crop(box), depth.crop(box)


class RandomHorizontallyFlip(object):
    """Mirror an (image, mask, depth) triple left-to-right with probability 0.5.

    One coin toss decides for all three layers, so they stay aligned.
    """

    def __call__(self, img, mask, depth):
        if random.random() >= 0.5:
            return img, mask, depth
        mirror = Image.FLIP_LEFT_RIGHT
        return img.transpose(mirror), mask.transpose(mirror), depth.transpose(mirror)

In [6]:
class AvgMeter(object):
    """Running weighted average of a scalar, such as a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def check_mkdir(dir_name):
    """Create ``dir_name``, including missing parent directories, if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` avoids the race of a separate
    exists-check followed by ``mkdir``. It also works when intermediate
    directories are missing. It raises FileExistsError if ``dir_name``
    exists but is not a directory.
    """
    os.makedirs(dir_name, exist_ok=True)


class ReplayBuffer():
    """Fixed-size pool of previously seen samples.

    Each pushed sample is either returned as-is or swapped for a randomly
    chosen stored sample (see ``push_and_pop``).
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch into the pool and return a batch of the same length.

        While the pool is not full, every sample is stored and returned
        unchanged. Once full, each sample is, with probability 0.5, swapped for
        a random stored sample: the stored copy is returned and the new sample
        takes its slot. Otherwise the new sample is returned as-is.

        :param data: tensor whose first dimension indexes samples.
        :return: tensor of the same shape, detached from the autograd graph.
        """
        to_return = []
        # detach() replaces the legacy ``.data`` access; both yield views without grad history.
        for element in data.detach():
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        # torch.autograd.Variable is deprecated and a no-op on tensors; return the tensor directly.
        return torch.cat(to_return)


def mse_loss(input, target):
    """Mean squared error: summed squared differences divided by ``input``'s element count."""
    squared_error = (input - target).pow(2)
    return squared_error.sum() / input.numel()


def weights_init_normal(m):
    """Initialize one module in place; suitable for ``nn.Module.apply``.

    - Class name containing 'Conv' (so ConvTranspose2d too): weight ~ N(0, 0.02).
      Any bias is left untouched.
    - Class name containing 'BatchNorm2d': weight ~ N(1, 0.02), bias = 0.
      Layers built with ``affine=False`` have no weight/bias (they are None)
      and are skipped instead of raising AttributeError.
    - Anything else is left unchanged.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)

In [7]:
def _read_list_file(root, list_name):
    """Return ``root``-joined paths for every non-blank line of ``root/list_name``."""
    with open(os.path.join(root, list_name)) as handle:
        return [os.path.join(root, line.rstrip('\n')) for line in handle if line.strip()]


def make_dataset(root, is_train):
    """Build ``[image, ground_truth, depth]`` path triples for one split.

    Reads three parallel list files under ``root/data/``:
    ``{split}_images.txt``, ``{split}_gt.txt`` and ``{split}_depth.txt``, where
    ``split`` is 'train' or 'test'. Each non-blank line is a path relative to
    ``root``, and line i of the three files describes the same sample.

    :param root: dataset root directory.
    :param is_train: True for the training split, False for the test split.
    :return: list of ``[image_path, gt_path, depth_path]`` lists.
    :raises ValueError: if the three list files contain different numbers of entries.
    """
    split = 'train' if is_train else 'test'
    images = _read_list_file(root, 'data/%s_images.txt' % split)
    gts = _read_list_file(root, 'data/%s_gt.txt' % split)
    depths = _read_list_file(root, 'data/%s_depth.txt' % split)

    # Misaligned lists would otherwise raise IndexError or silently drop samples.
    if not (len(images) == len(gts) == len(depths)):
        raise ValueError('%s list files disagree in length: %d images, %d gt, %d depth'
                         % (split, len(images), len(gts), len(depths)))

    return [[img, gt, depth] for img, gt, depth in zip(images, gts, depths)]



class ImageFolder(data.Dataset):
    """Dataset of (input image, ground-truth image, depth map) triples.

    Paths come from ``make_dataset(root, is_train)``. For each item:
    1. the input and ground-truth images are forced to 3-channel RGB;
    2. ``triple_transform`` (if given) is applied jointly to all three;
    3. ``transform`` is applied to the input;
    4. ``target_transform`` is applied to both the ground truth and the depth map.
    """

    def __init__(self, root, triple_transform=None, transform=None, target_transform=None, is_train=True):
        self.root = root
        self.imgs = make_dataset(root, is_train)
        self.triple_transform = triple_transform
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as 3-channel RGB.

        RGBA loses its alpha channel. Other modes (CMYK, palette, grayscale, ...)
        get a proper colour conversion instead of having their first three
        bands mislabelled as RGB, or being passed through with the wrong
        channel count.
        """
        if image.mode == 'RGB':
            return image
        return image.convert('RGB')

    def __getitem__(self, index):
        img_path, gt_path, depth_path = self.imgs[index]

        img = self._to_rgb(Image.open(img_path))
        target = self._to_rgb(Image.open(gt_path))
        # The depth map keeps its native mode; it is not converted here.
        depth = Image.open(depth_path)

        if self.triple_transform is not None:
            img, target, depth = self.triple_transform(img, target, depth)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
            depth = self.target_transform(depth)

        return img, target, depth

    def __len__(self):
        return len(self.imgs)

In [8]:
# Training inputs are resized to one fixed size (see triple_transform), so let
# cuDNN benchmark and cache the fastest convolution algorithms.
cudnn.benchmark = True

ckpt_path = '/content/drive/MyDrive/ckpt'
exp_name = 'DGNLNet'

# Hyper-parameters. train() decays the learning rate polynomially,
# lr * (1 - curr_iter / iter_num) ** lr_decay, and main() gives bias
# parameters twice the base rate.
args = dict(
    iter_num=20,
    train_batch_size=2,
    last_iter=0,
    lr=5e-4,
    lr_decay=0.9,
    weight_decay=0,
    momentum=0.9,
    resume_snapshot='',
    val_freq=50000000,
    img_size_h=512,
    img_size_w=1024,
    crop_size=512,
    snapshot_epochs=1,
)

In [9]:
# Tensor conversion only. Mean/std normalization happens inside each network's
# forward() (e.g. DGNLNet), so adding transforms.Normalize here would normalize twice.
transform = transforms.Compose([transforms.ToTensor()])
to_pil = transforms.ToPILImage()

# Joint (image, ground truth, depth) augmentation: resize all three to the
# configured resolution, then flip them together so they stay aligned.
# RandomCrop(args['crop_size']) is available but currently not used.
triple_transform = Compose([
    Resize((args['img_size_h'], args['img_size_w'])),
    RandomHorizontallyFlip(),
])

In [10]:
# Training triples are resized and flipped jointly; the test split is only converted to tensors.
train_set = ImageFolder(
    train_raincityscapes_path,
    triple_transform=triple_transform,
    transform=transform,
    target_transform=transform,
    is_train=True,
)
train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], shuffle=True, num_workers=4)

test1_set = ImageFolder(
    test_raincityscapes_path,
    transform=transform,
    target_transform=transform,
    is_train=False,
)
test1_loader = DataLoader(test1_set, batch_size=2)

In [11]:
# Sanity check: number of training triples found by make_dataset (12595 at the last recorded run).
len(train_set)

12595

In [12]:
# L1 losses: `criterion` compares the network output with the ground truth in
# train(); `criterion_depth` is presumably used for the depth prediction.
criterion = nn.L1Loss()
criterion_depth = nn.L1Loss()
# One log file per run, named by its start time. strftime avoids the spaces and
# colons of str(datetime.now()), which are invalid in filenames on some
# filesystems (e.g. Windows / synced Drive folders) and need quoting in shells.
log_path = os.path.join(ckpt_path, exp_name, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.txt')

In [13]:
class DGNL(nn.Module):
    """Depth-guided non-local block.

    Attention runs from a 1/4-downsampled query grid (h/4 x w/4) to a key/value
    grid that is max-pooled once more (h/8 x w/8). The attention weights combine
    two relation maps:
      * Ra (appearance): softmax over keys of theta(x)^T phi(x);
      * Rd (depth): softmax over keys of min(d_q / d_k, d_k / d_q), which is
        largest where query and key depths are similar.
    S = softmax(Ra * Rd) aggregates g(x). The result is projected back to
    ``in_channels`` by ``z``, upsampled to the input size and added to ``x``.

    forward(x, depth_map):
        x: (n, c, h, w) features with h, w >= 8.
        depth_map: (n, 1, H, W) map, bilinearly resampled to both grids.
        returns: tensor shaped like ``x``.
    """

    def __init__(self, in_channels):
        super(DGNL, self).__init__()

        self.eps = 1e-6  # guards the depth ratios against division by zero
        self.sigma_pow2 = 100  # not used by forward()

        self.theta = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)
        self.phi = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)
        self.g = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)

        # Fixed-initialized 4x4 per-channel average (weights 1/16), stride 4.
        self.down = nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=4, groups=in_channels, bias=False)
        self.down.weight.data.fill_(1. / 16)

        self.z = nn.Conv2d(int(in_channels / 2), in_channels, kernel_size=1)

    def forward(self, x, depth_map):
        n, c, h, w = x.size()
        half_c = int(c / 2)
        h4, w4 = int(h / 4), int(w / 4)
        h8, w8 = int(h / 8), int(w / 8)

        x_down = self.down(x)

        # values: [n, h8 * w8, c / 2]
        g = F.max_pool2d(self.g(x_down), kernel_size=2, stride=2).view(n, half_c, -1).transpose(1, 2)

        ### appearance relation map
        # queries: [n, h4 * w4, c / 2]
        theta = self.theta(x_down).view(n, half_c, -1).transpose(1, 2)
        # keys: [n, c / 2, h8 * w8]
        phi = F.max_pool2d(self.phi(x_down), kernel_size=2, stride=2).view(n, half_c, -1)
        # [n, h4 * w4, h8 * w8]
        Ra = F.softmax(torch.bmm(theta, phi), 2)

        ### depth relation map
        depth1 = F.interpolate(depth_map, size=[h4, w4], mode='bilinear', align_corners=True).view(n, 1, h4 * w4).transpose(1, 2)
        depth2 = F.interpolate(depth_map, size=[h8, w8], mode='bilinear', align_corners=True).view(n, 1, h8 * w8)

        # both broadcast to [n, h4 * w4, h8 * w8]
        depth1_expand = depth1.expand(n, h4 * w4, h8 * w8)
        depth2_expand = depth2.expand(n, h4 * w4, h8 * w8)

        # ratio of the smaller to the larger depth (about 1 when they agree)
        Rd = torch.min(depth1_expand / (depth2_expand + self.eps), depth2_expand / (depth1_expand + self.eps))
        Rd = F.softmax(Rd, 2)

        S = F.softmax(Ra * Rd, 2)

        # [n, c / 2, h4, w4]
        y = torch.bmm(S, g).transpose(1, 2).contiguous().view(n, half_c, h4, w4)

        # F.interpolate replaces the deprecated F.upsample (identical result).
        return x + F.interpolate(self.z(y), size=x.size()[2:], mode='bilinear', align_corners=True)



class NLB(nn.Module):
    """Non-local (self-attention) block without depth guidance.

    Queries come from a 1/4-downsampled grid (h/4 x w/4). Keys and values come
    from that grid max-pooled once more (h/8 x w/8). The attended features are
    projected back to ``in_channels`` by ``z``, upsampled to the input size and
    added to ``x``.

    forward(x): x is (n, c, h, w) with h, w >= 8; returns a tensor shaped like ``x``.
    """

    def __init__(self, in_channels):
        super(NLB, self).__init__()
        self.theta = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)
        self.phi = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)
        self.g = nn.Conv2d(in_channels, int(in_channels / 2), kernel_size=1)

        # Fixed-initialized 4x4 per-channel average (weights 1/16), stride 4.
        self.down = nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=4, groups=in_channels, bias=False)
        self.down.weight.data.fill_(1. / 16)

        self.z = nn.Conv2d(int(in_channels / 2), in_channels, kernel_size=1)

    def forward(self, x):
        n, c, h, w = x.size()
        half_c = int(c / 2)
        x_down = self.down(x)

        # [n, (h / 4) * (w / 4), c / 2]
        theta = self.theta(x_down).view(n, half_c, -1).transpose(1, 2)
        # [n, c / 2, (h / 8) * (w / 8)]
        phi = F.max_pool2d(self.phi(x_down), kernel_size=2, stride=2).view(n, half_c, -1)
        # [n, (h / 8) * (w / 8), c / 2]
        g = F.max_pool2d(self.g(x_down), kernel_size=2, stride=2).view(n, half_c, -1).transpose(1, 2)
        # [n, (h / 4) * (w / 4), (h / 8) * (w / 8)]
        f = F.softmax(torch.bmm(theta, phi), 2)
        # [n, c / 2, h / 4, w / 4]
        y = torch.bmm(f, g).transpose(1, 2).contiguous().view(n, half_c, int(h / 4), int(w / 4))

        # F.interpolate replaces the deprecated F.upsample (identical result).
        return x + F.interpolate(self.z(y), size=x.size()[2:], mode='bilinear', align_corners=True)


class DepthWiseDilatedResidualBlock(nn.Module):
    """Residual block of grouped dilated 3x3 convolutions and 1x1 projections.

    ``reduced_channels`` is accepted but not used. The output has the same
    shape as the input (padding == dilation for every 3x3 conv).
    """

    def __init__(self, reduced_channels, channels, dilation):
        super(DepthWiseDilatedResidualBlock, self).__init__()
        wide = channels * 2
        # expand (1x1) -> grouped dilated 3x3 -> project back (1x1, linear)
        self.conv0 = nn.Sequential(
            nn.Conv2d(channels, wide, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
            nn.ReLU6(inplace=True),
            nn.Conv2d(wide, wide, kernel_size=3, padding=dilation, dilation=dilation, groups=channels, bias=False),
            nn.ReLU6(inplace=True),
            nn.Conv2d(wide, channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False),
        )
        # depthwise dilated 3x3 -> pointwise linear 1x1
        self.conv1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation, groups=channels,
                      bias=False),
            nn.ReLU6(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False),
        )

    def forward(self, x):
        residual = self.conv0(x)
        residual = self.conv1(residual)
        return residual + x


class DilatedResidualBlock(nn.Module):
    """Two dilated 3x3 convolutions (ReLU in between) plus an identity skip connection."""

    def __init__(self, channels, dilation):
        super(DilatedResidualBlock, self).__init__()
        # padding == dilation keeps the spatial size for a 3x3 kernel
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
        self.conv0 = nn.Sequential(nn.Conv2d(channels, channels, **conv_kwargs), nn.ReLU())
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        return x + self.conv1(self.conv0(x))


class SpatialRNN(nn.Module):
    """Single-direction spatial recurrence over an (N, C, H, W) feature map.

    Scanning in the chosen direction, the first slice is ``relu(x_0)`` and each
    later slice is ``relu(alpha * previous_output + x_i)``, with a learnable
    per-channel ``alpha``. ``direction`` is one of:
      "right" (left -> right), "left" (right -> left),
      "down" (top -> bottom), "up" (bottom -> top).
    The output has the same shape as the input.
    """

    # direction -> (dimension to scan, whether the scan starts at the far end)
    _SCANS = {
        "right": (3, False),
        "left": (3, True),
        "down": (2, False),
        "up": (2, True),
    }

    def __init__(self, alpha = 1.0, channel_num = 1, direction = "right"):
        super(SpatialRNN, self).__init__()
        self.alpha = nn.Parameter(torch.Tensor([alpha] * channel_num))
        self.direction = direction

    def __getitem__(self, item):
        return self.alpha[item]

    def __len__(self):
        return len(self.alpha)

    def forward(self, x):
        """
        :param x: (N, C, H, W) tensor; C must equal ``channel_num`` (or channel_num == 1)
        :return: tensor of the same shape as ``x``
        :raises ValueError: if ``direction`` is not one of the supported values
        """
        if self.direction not in self._SCANS:
            # Previously this printed a message and *returned* the KeyError class.
            raise ValueError("Invalid direction in SpatialRNN: %r" % (self.direction,))
        dim, reverse = self._SCANS[self.direction]

        # A reverse scan is a forward scan over the flipped tensor, flipped back afterwards.
        if reverse:
            x = x.flip(dim)
        gain = self.alpha.unsqueeze(1)  # (C, 1): broadcasts over the remaining spatial axis
        steps = [x.select(dim, 0).clamp(min=0)]
        for i in range(1, x.size(dim)):
            steps.append((gain * steps[-1] + x.select(dim, i)).clamp(min=0))
        out = torch.stack(steps, dim)
        return out.flip(dim) if reverse else out



class TVLoss(nn.Module):
    """Total-variation loss on an (N, C, H, W) batch.

    Squared differences between vertically and horizontally adjacent pixels are
    each averaged per sample (over C and the valid spatial positions), summed,
    scaled by ``2 * tv_loss_weight`` and divided by the batch size.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        down_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
        right_diff = x[:, :, :, 1:] - x[:, :, :, :-1]
        h_term = down_diff.pow(2).sum() / self.tensor_size(down_diff)
        w_term = right_diff.pow(2).sum() / self.tensor_size(right_diff)
        return self.tv_loss_weight * 2 * (h_term + w_term) / x.size(0)

    @staticmethod
    def tensor_size(t):
        """Number of elements per sample: C * H * W."""
        _, channels, height, width = t.size()
        return channels * height * width


# NLB is defined once, directly after DGNL earlier in this cell; the verbatim
# second copy that used to follow TVLoss silently shadowed it and is removed.

In [14]:
class DGNLNet_fast(nn.Module):
    """Lightweight depth-guided non-local rain-removal network.

    Compared with DGNLNet, the depth branch is shallower (a single dilated
    bottleneck conv) and the rain-removal branch uses depthwise-separable
    residual blocks. Both branches run at 1/4 of the input resolution; input
    height/width are assumed divisible by 16 (TODO confirm for callers).

    forward(x):
        x: RGB batch, presumably in [0, 1] (ToTensor output).
        Training mode returns ``(restored, depth)`` with depth upsampled to the
        input size; eval mode returns only ``restored``, clamped to [0, 1].
    """

    def __init__(self, num_features=64):
        super(DGNLNet_fast, self).__init__()
        # Frozen per-channel normalization statistics (requires_grad=False, so
        # optimizer groups that filter on requires_grad skip them).
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)

        def gn_selu(layer, channels):
            # conv/deconv -> GroupNorm(32 groups) -> SELU: the unit used throughout the depth branch
            return nn.Sequential(layer, nn.GroupNorm(num_groups=32, num_channels=channels), nn.SELU(inplace=True))

        ############################################ Depth prediction network
        # 1/4 -> 1/8 -> 1/16, dilated bottleneck, back up to 1/4 with additive skips
        self.conv1 = gn_selu(nn.Conv2d(3, 32, 8, stride=4, padding=2), 32)
        self.conv2 = gn_selu(nn.Conv2d(32, 64, 4, stride=2, padding=1), 64)
        self.conv3 = gn_selu(nn.Conv2d(64, 128, 4, stride=2, padding=1), 128)
        self.conv5 = gn_selu(nn.Conv2d(128, 128, 3, padding=2, dilation=2), 128)
        self.conv8 = gn_selu(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), 64)
        self.conv9 = gn_selu(nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), 32)

        self.depth_pred = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=32),
            nn.SELU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

        ############################################ Rain removal network
        self.head = nn.Sequential(
            nn.Conv2d(3, 32, 1, 1, 0, 1, bias=False),  # pointwise
            nn.ReLU6(inplace=True),
            nn.Conv2d(32, 32, kernel_size=8, stride=4, padding=2, groups=32, bias=False),  # depthwise, /4
            nn.ReLU6(inplace=True),
            nn.Conv2d(32, num_features, 1, 1, 0, 1, 1, bias=False),  # pointwise-linear
        )

        self.body = nn.Sequential(*[
            DepthWiseDilatedResidualBlock(num_features, num_features, rate)
            for rate in (1, 2, 2, 4, 8, 4, 2, 2, 1)
        ])

        self.dgnlb = DGNL(num_features)

        self.tail = nn.Sequential(
            nn.ConvTranspose2d(num_features, 32, kernel_size=8, stride=4, padding=2, groups=32, bias=False),  # x4
            nn.ReLU6(inplace=True),
            nn.Conv2d(32, 3, 1, 1, 0, 1, bias=False),  # pointwise-linear
        )

        # No plain nn.ReLU is used in this variant (ReLU6 is a separate class);
        # the loop is kept for parity with the other networks.
        for m in self.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = True

    def forward(self, x):
        x = (x - self.mean) / self.std

        ################################## depth prediction (1/4 resolution)
        d_f1 = self.conv1(x)
        d_f2 = self.conv2(d_f1)
        d_f3 = self.conv3(d_f2)
        d_f5 = self.conv5(d_f3)
        d_f8 = self.conv8(d_f5)
        d_f9 = self.conv9(d_f8 + d_f2)
        depth_pred = self.depth_pred(d_f9 + d_f1)

        ################################## rain removal, guided by the detached depth map
        f = self.head(x)
        f = self.body(f)
        f = self.dgnlb(f, depth_pred.detach())
        r = self.tail(f)
        x = x + r

        x = (x * self.std + self.mean).clamp(min=0, max=1)

        if self.training:
            # F.interpolate replaces the deprecated F.upsample (identical result).
            return x, F.interpolate(depth_pred, size=x.size()[2:], mode='bilinear', align_corners=True)

        return x



class DGNLNet(nn.Module):
    """Depth-guided non-local rain-removal network (full-size variant).

    Two branches share the normalized input:
      * depth branch: an encoder (conv1-conv4, down to 1/16), a dilated
        bottleneck (conv5-conv7) and a decoder (conv8-conv10, depth_pred) with
        additive skips. It predicts a 1-channel map in (0, 1) at input resolution.
      * rain-removal branch: works at 1/2 resolution; its DGNL block is guided
        by the detached depth map.
    In training mode forward returns ``(restored, depth_pred)``; in eval mode it
    returns only ``restored`` (clamped to [0, 1]).
    """

    def __init__(self, num_features=64):
        super(DGNLNet, self).__init__()
        # Frozen per-channel normalization statistics.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)

        def gn_selu(layer, channels):
            # conv/deconv -> GroupNorm(32 groups) -> SELU
            return nn.Sequential(layer, nn.GroupNorm(num_groups=32, num_channels=channels), nn.SELU(inplace=True))

        ############################################ Depth prediction network
        self.conv1 = gn_selu(nn.Conv2d(3, 32, 4, stride=2, padding=1), 32)
        self.conv2 = gn_selu(nn.Conv2d(32, 64, 4, stride=2, padding=1), 64)
        self.conv3 = gn_selu(nn.Conv2d(64, 128, 4, stride=2, padding=1), 128)
        self.conv4 = gn_selu(nn.Conv2d(128, 256, 4, stride=2, padding=1), 256)
        self.conv5 = gn_selu(nn.Conv2d(256, 256, 3, padding=2, dilation=2), 256)
        self.conv6 = gn_selu(nn.Conv2d(256, 256, 3, padding=4, dilation=4), 256)
        self.conv7 = gn_selu(nn.Conv2d(256, 256, 3, padding=2, dilation=2), 256)
        self.conv8 = gn_selu(nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), 128)
        self.conv9 = gn_selu(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), 64)
        self.conv10 = gn_selu(nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), 32)

        self.depth_pred = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=32),
            nn.SELU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.SELU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

        ############################################ Rain removal network
        self.head = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, num_features, kernel_size=1, stride=1, padding=0), nn.ReLU(),
        )
        self.body = nn.Sequential(*[
            DilatedResidualBlock(num_features, rate) for rate in (1, 1, 2, 2, 4, 8, 4, 2, 2, 1, 1)
        ])

        self.dgnlb = DGNL(num_features)

        self.tail = nn.Sequential(
            nn.ConvTranspose2d(num_features, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
        )

        for module in self.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = True

    def forward(self, x):
        x = (x - self.mean) / self.std

        # depth prediction: encoder, dilated bottleneck, decoder with skips
        d1 = self.conv1(x)
        d2 = self.conv2(d1)
        d3 = self.conv3(d2)
        bottleneck = self.conv7(self.conv6(self.conv5(self.conv4(d3))))
        up = self.conv8(bottleneck)
        up = self.conv9(up + d3)
        up = self.conv10(up + d2)
        depth_pred = self.depth_pred(up + d1)

        # rain removal guided by the (detached) depth estimate
        features = self.body(self.head(x))
        features = self.dgnlb(features, depth_pred.detach())
        restored = x + self.tail(features)
        restored = (restored * self.std + self.mean).clamp(min=0, max=1)

        return (restored, depth_pred) if self.training else restored


class basic_NL(nn.Module):
    """Rain-removal baseline: dilated residual body plus a plain non-local block.

    There is no depth guidance. Features live at 1/2 resolution (stride-2 head,
    stride-2 transposed tail). A predicted residual is added to the normalized
    input, and the result is de-normalized and clamped to [0, 1].
    """

    def __init__(self, num_features=64):
        super(basic_NL, self).__init__()
        # Frozen per-channel normalization statistics.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)

        self.head = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, num_features, kernel_size=1, stride=1, padding=0), nn.ReLU(),
        )
        self.body = nn.Sequential(*[
            DilatedResidualBlock(num_features, rate) for rate in (1, 1, 2, 2, 4, 8, 4, 2, 2, 1, 1)
        ])

        self.nlb = NLB(num_features)

        self.tail = nn.Sequential(
            nn.ConvTranspose2d(num_features, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
        )

        for module in self.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = True

    def forward(self, x):
        normalized = (x - self.mean) / self.std
        features = self.nlb(self.body(self.head(normalized)))
        restored = normalized + self.tail(features)
        return (restored * self.std + self.mean).clamp(min=0, max=1)



class basic(nn.Module):
    """Plain rain-removal baseline: stride-2 head, dilated residual body, transposed-conv tail.

    The predicted residual is added to the normalized input, and the result is
    de-normalized and clamped to [0, 1].
    """

    def __init__(self, num_features=64):
        super(basic, self).__init__()
        # Frozen per-channel normalization statistics.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)

        self.head = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, num_features, kernel_size=1, stride=1, padding=0), nn.ReLU(),
        )
        self.body = nn.Sequential(*[
            DilatedResidualBlock(num_features, rate) for rate in (1, 1, 2, 2, 4, 8, 4, 2, 2, 1, 1)
        ])
        self.tail = nn.Sequential(
            nn.ConvTranspose2d(num_features, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
        )

        for module in self.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = True

    def forward(self, x):
        normalized = (x - self.mean) / self.std
        restored = normalized + self.tail(self.body(self.head(normalized)))
        return (restored * self.std + self.mean).clamp(min=0, max=1)



class depth_predciton(nn.Module):
    """Stand-alone depth-prediction network (the depth branch of DGNLNet).

    Encoder conv1-conv4 (down to 1/16), dilated bottleneck conv5-conv7, and a
    decoder conv8-conv10 + depth_pred with additive skips. It returns a
    1-channel map in (0, 1) at the input resolution.
    """

    def __init__(self):
        super(depth_predciton, self).__init__()
        # Frozen per-channel normalization statistics.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)

        def gn_selu(layer, channels):
            # conv/deconv -> GroupNorm(32 groups) -> SELU
            return nn.Sequential(layer, nn.GroupNorm(num_groups=32, num_channels=channels), nn.SELU(inplace=True))

        self.conv1 = gn_selu(nn.Conv2d(3, 32, 4, stride=2, padding=1), 32)
        self.conv2 = gn_selu(nn.Conv2d(32, 64, 4, stride=2, padding=1), 64)
        self.conv3 = gn_selu(nn.Conv2d(64, 128, 4, stride=2, padding=1), 128)
        self.conv4 = gn_selu(nn.Conv2d(128, 256, 4, stride=2, padding=1), 256)
        self.conv5 = gn_selu(nn.Conv2d(256, 256, 3, padding=2, dilation=2), 256)
        self.conv6 = gn_selu(nn.Conv2d(256, 256, 3, padding=4, dilation=4), 256)
        self.conv7 = gn_selu(nn.Conv2d(256, 256, 3, padding=2, dilation=2), 256)
        self.conv8 = gn_selu(nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), 128)
        self.conv9 = gn_selu(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), 64)
        self.conv10 = gn_selu(nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), 32)

        self.depth_pred = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=32),
            nn.SELU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.SELU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = (x - self.mean) / self.std

        d1 = self.conv1(x)
        d2 = self.conv2(d1)
        d3 = self.conv3(d2)
        bottleneck = self.conv7(self.conv6(self.conv5(self.conv4(d3))))
        up = self.conv8(bottleneck)
        up = self.conv9(up + d3)
        up = self.conv10(up + d2)
        return self.depth_pred(up + d1)

In [15]:
def main():
    """Build DGNLNet and its Adam optimizer, optionally resume, then train.

    Trainable parameters are split into two groups. Biases go in group 0 at
    ``2 * args['lr']``, with no ``weight_decay`` key, so Adam's default of 0
    applies. All other parameters go in group 1 at ``args['lr']`` with
    ``args['weight_decay']``.

    If ``args['resume_snapshot']`` is non-empty, the model and optimizer state
    are loaded from ``<ckpt_path>/<exp_name>/<snapshot>.pth`` and
    ``<snapshot>_optim.pth``. The configured base learning rates are then
    re-applied. The args dict is written to ``log_path``, truncating it, and
    ``train`` is started.
    """
    net = DGNLNet().cuda().train()
    torch.cuda.empty_cache()

    trainable = [(name, param) for name, param in net.named_parameters() if param.requires_grad]
    optimizer = optim.Adam([
        {'params': [param for name, param in trainable if name.endswith('bias')],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in trainable if not name.endswith('bias')],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ])

    if args['resume_snapshot']:
        print('training resumes from \'%s\'' % args['resume_snapshot'])
        snapshot_dir = os.path.join(ckpt_path, exp_name)
        net.load_state_dict(torch.load(os.path.join(snapshot_dir, args['resume_snapshot'] + '.pth')))
        optimizer.load_state_dict(torch.load(os.path.join(snapshot_dir, args['resume_snapshot'] + '_optim.pth')))
        # The saved optimizer state carries the lr from when it was saved; reset to the base rates.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Use a context manager so the log file handle is closed deterministically.
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)

In [16]:
def train(net, optimizer):
    """Run the training loop, starting at ``args['last_iter']``.

    Each step does the following:
    - Sets a polynomial learning-rate decay, ``(1 - curr_iter / iter_num) ** lr_decay``,
      with group 0 at twice the base rate.
    - Runs one forward/backward pass on a batch from ``train_loader``. The loss is
      ``criterion(result, gts) + criterion_depth(depth_pred, dps)``.
    - Appends the running averages to ``log_path``.

    ``validate`` is called when ``(curr_iter + 1) % args['val_freq'] == 0``. A
    snapshot is saved when ``(curr_iter + 1) % args['snapshot_epochs'] == 0``.
    Returns once ``curr_iter`` exceeds ``args['iter_num']``.

    Args:
        net: model returning ``(result, depth_pred)`` in training mode.
        optimizer: optimizer whose param group 0 / group 1 get 2x / 1x the base lr.
    """
    curr_iter = args['last_iter']
    while True:
        train_loss_record = AvgMeter()
        train_net_loss_record = AvgMeter()
        train_depth_loss_record = AvgMeter()

        for inputs, gts, dps in train_loader:
            decay = (1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * decay
            optimizer.param_groups[1]['lr'] = args['lr'] * decay

            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            gts = gts.cuda()
            dps = dps.cuda()

            optimizer.zero_grad()

            result, depth_pred = net(inputs)

            loss_net = criterion(result, gts)
            loss_depth = criterion_depth(depth_pred, dps)

            loss = loss_net + loss_depth

            loss.backward()

            optimizer.step()

            # .item() stores plain Python floats, so the meters don't keep GPU tensors alive.
            train_loss_record.update(loss.item(), batch_size)
            train_net_loss_record.update(loss_net.item(), batch_size)
            train_depth_loss_record.update(loss_depth.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [lr %.13f], [loss_net %.5f], [loss_depth %.5f]' % \
                  (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
                   train_net_loss_record.avg, train_depth_loss_record.avg)
            print(log)
            with open(log_path, 'a') as log_file:
                log_file.write(log + '\n')

            if (curr_iter + 1) % args['val_freq'] == 0:
                validate(net, curr_iter, optimizer)

            if (curr_iter + 1) % args['snapshot_epochs'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, ('%d_test.pth' % (curr_iter + 1) )))
                torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, ('%d_optim_test.pth' % (curr_iter + 1) )))

            if curr_iter > args['iter_num']:
                return

In [17]:
def validate(net, curr_iter, optimizer):
    """Score ``net`` on ``test1_loader`` with ``criterion`` and save a snapshot.

    The network is put in eval mode for the pass. It is returned to train mode
    afterwards, even if evaluation or saving raises. Snapshot files are named
    ``iter_<n>_loss1_<l1>_loss2_<l2>_lr_<lr>.pth`` and ``..._optim.pth``, written
    under ``<ckpt_path>/<exp_name>``.

    Args:
        net: model whose eval-mode output is passed directly to ``criterion``.
            Presumably this is a single tensor in eval mode; confirm against DGNLNet.
        curr_iter: iteration counter from ``train``; ``curr_iter + 1`` is reported.
        optimizer: its state is saved alongside the weights, and group 1's lr is
            embedded in the snapshot name.
    """
    print('validating...')
    net.eval()
    try:
        loss_record1, loss_record2 = AvgMeter(), AvgMeter()
        # NOTE(review): nothing updates loss_record2, and the loader's depth targets
        # are unused, so "loss2" reports AvgMeter's initial average. It is kept so
        # snapshot names keep their existing format.
        num_batches = len(test1_loader)

        with torch.no_grad():
            for i, (inputs, gts, _dps) in enumerate(test1_loader):
                inputs = inputs.cuda()
                gts = gts.cuda()

                res = net(inputs)

                loss = criterion(res, gts)
                # .item() keeps a plain float in the meter instead of a GPU tensor.
                loss_record1.update(loss.item(), inputs.size(0))

                print('processed test1 %d / %d' % (i + 1, num_batches))

        snapshot_name = 'iter_%d_loss1_%.5f_loss2_%.5f_lr_%.6f' % (curr_iter + 1, loss_record1.avg, loss_record2.avg,
                                                                   optimizer.param_groups[1]['lr'])
        print('[validate]: [iter %d], [loss1 %.5f], [loss2 %.5f]' % (curr_iter + 1, loss_record1.avg, loss_record2.avg))
        torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '_optim.pth'))
    finally:
        net.train()

In [None]:
# Build the model and optimizer, optionally resume from args['resume_snapshot'], then train.
main()

In [None]:
import os
import time
import sys
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
from pathlib import Path
# from nets import DGNLNet_fast
# from misc import check_mkdir
import matplotlib.pyplot as plt

ckpt = "/content/drive/MyDrive/ckpt/DGNLNet/22_test.pth"
# De-rained outputs are written here, one file per input, keeping the input's file name.
output_dir = "/content/drive/MyDrive/generated_test"

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

torch.manual_seed(2019)
torch.cuda.set_device(0)

# The network runs at a fixed 512x1024 (h, w); outputs are resized back to the input size.
transform = transforms.Compose([
    transforms.Resize([512, 1024]),
    transforms.ToTensor()])

to_pil = transforms.ToPILImage()


if __name__ == '__main__':
    root = "/content"
    # One image file name per line, relative to data/images.
    with open(os.path.join(root, 'data/test_images.txt')) as list_file:
        image = [os.path.join(root, 'data/images', img_name.strip('\n')) for img_name in list_file]

    os.makedirs(output_dir, exist_ok=True)

    # Build the network and load the checkpoint once, not once per image.
    net = DGNLNet().cuda()
    net.load_state_dict(torch.load(ckpt, map_location=lambda storage, loc: storage.cuda(0)))
    net.eval()

    for i, img_path in enumerate(image, start=1):
        name = os.path.basename(img_path)
        # convert('RGB') drops an alpha channel and also handles grayscale/palette images.
        img = Image.open(Path(img_path)).convert('RGB')

        with torch.no_grad():
            w, h = img.size
            img_var = transform(img).unsqueeze(0).cuda()

            res = net(img_var)

            torch.cuda.synchronize()

            # ToPILImage scales floats by 255 and casts to uint8, so values outside [0, 1]
            # would wrap around; clamp first.
            result = transforms.Resize((h, w))(to_pil(res.squeeze(0).clamp(0, 1).cpu()))

        print(i)
        result.save(os.path.join(output_dir, name))

In [20]:
import os
import time
import sys
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
from pathlib import Path
# from nets import DGNLNet_fast
# from misc import check_mkdir
import matplotlib.pyplot as plt

ckpt = "/content/drive/MyDrive/ckpt/DGNLNet/22_test.pth"

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

torch.manual_seed(2019)
torch.cuda.set_device(0)

transform = transforms.Compose([
    transforms.Resize([512,1024]),
    transforms.ToTensor() ])

to_pil = transforms.ToPILImage()



if __name__ == '__main__':
    root = "/content"


    img = "/content/drive/MyDrive/generated_test/1647760644.0353005_rain_M.png"

    # net = DGNLNet().cuda()

    # net.load_state_dict(torch.load(ckpt,map_location=lambda storage,loc: storage.cuda(0)))

    # net.eval()

    # name = img.split("/")[-1]

    # Open the path first: getbands()/convert() exist only on PIL images, not on str paths.
    if not isinstance(img, Image.Image):
        img = Image.open(Path(img))
    # convert("RGB") also drops an alpha channel, so no separate RGBA branch is needed.
    img = img.convert("RGB")
    # # # plt.figure(figsize=(12,6))
    # # plt.axis("off")
    # # plt.imshow(img)
    # with torch.no_grad():

    #     w, h = img.size
    #     img_var = Variable(transform(img).unsqueeze(0)).cuda()

    #     res = net(img_var)

    #     torch.cuda.synchronize()

    #     result = transforms.Resize((h, w))(to_pil(res.data.squeeze(0).cpu()))
    #     fig = plt.figure(figsize=(16, 8))
    #     fig.add_subplot(2,1,1)
    #     plt.figure(figsize=(12,6))
    #     plt.axis("off")
    #     plt.imshow(img)
    #     # fig.add_subplot(2,1,2)
    #     plt.figure(figsize=(12,6))
    #     plt.axis("off")
    #     plt.imshow(result)

In [23]:
import cv2
import math
import numpy as np

def ssim(img1, img2, data_range=255.0):
    """Mean structural similarity (SSIM) between two images.

    Uses an 11x11 Gaussian window (sigma 1.5) and the constants
    C1 = (0.01 * L)^2 and C2 = (0.03 * L)^2, where L = ``data_range``. A 5-pixel
    border is cropped so that only fully covered windows contribute. For a
    multi-channel array, ``cv2.filter2D`` filters each channel independently,
    and the result is the mean over all pixels and channels.

    Args:
        img1, img2: image arrays of identical shape, (H, W) or (H, W, C), with
            H and W of at least 11.
        data_range: dynamic range of the pixel values (255 for 8-bit images).

    Returns:
        The mean SSIM as a NumPy float.

    Raises:
        ValueError: if the shapes differ or the image is smaller than the window.
    """
    if img1.shape != img2.shape:
        raise ValueError('ssim: shape mismatch %r vs %r' % (img1.shape, img2.shape))
    if min(img1.shape[:2]) < 11:
        # The [5:-5] crop would leave nothing to average (mean of an empty array is NaN).
        raise ValueError('ssim: images must be at least 11x11, got %r' % (img1.shape,))

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

def psnr(img1, img2, data_range=255.0):
    """Peak signal-to-noise ratio, in dB, between two images.

    Inputs are converted to float64 before differencing, so uint8 images do not
    wrap around.

    Args:
        img1, img2: arrays (or array-likes) of identical shape.
        data_range: peak pixel value (255 for 8-bit images, 1.0 for [0, 1] floats).

    Returns:
        ``20 * log10(data_range / sqrt(MSE))``, or ``inf`` for identical images.

    Raises:
        ValueError: if the shapes differ, rather than silently broadcasting.
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    if img1.shape != img2.shape:
        raise ValueError('psnr: shape mismatch %r vs %r' % (img1.shape, img2.shape))
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(data_range / math.sqrt(mse))

In [None]:
# Compare a validation image with the generated output of the same name (PSNR in dB, mean SSIM).
reference_path = "/content/data/images/val/carla/1647760470.3255076_rain_H.png"
generated_path = "/content/drive/MyDrive/generated_test/1647760470.3255076_rain_H.png"

img1 = cv2.imread(reference_path)
img2 = cv2.imread(generated_path)
# cv2.imread returns None (it does not raise) when a file is missing or unreadable.
for path, loaded in ((reference_path, img1), (generated_path, img2)):
    if loaded is None:
        raise FileNotFoundError(path)

psnr(img1,img2),ssim(img1,img2)

In [80]:
import os
import time
import sys
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# from nets import DGNLNet

class RainRemoval:
    """Convenience wrapper that runs DGNLNet de-raining on single images.

    The network is built, and the checkpoint loaded, on the first ``infer`` call.
    It is then reused; it is reloaded only if ``self.model`` changes.

    Attributes:
        model: path to the DGNLNet state-dict checkpoint.
        img_infer: RGB PIL image of the most recent input (set by ``infer``).
        result: de-rained PIL image at the input's size (set by ``infer``).
        result_np: ``result`` as a NumPy array (set by ``infer``).
    """

    def __init__(self, model):
        self.model = model
        # Cached eval-mode network and the checkpoint path it was loaded from.
        self._net = None
        self._net_ckpt = None

    def _get_net(self):
        """Return the eval-mode DGNLNet for ``self.model``, loading it if needed."""
        if getattr(self, '_net', None) is None or getattr(self, '_net_ckpt', None) != self.model:
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            torch.manual_seed(2019)
            torch.cuda.set_device(0)

            net = DGNLNet().cuda()
            net.load_state_dict(torch.load(self.model, map_location=lambda storage, loc: storage.cuda(0)))
            net.eval()
            self._net, self._net_ckpt = net, self.model
        return self._net

    @staticmethod
    def _to_rgb_image(img):
        """Convert a path, PIL image, or H x W[ x C] uint8 array to an RGB PIL image."""
        if isinstance(img, (str, os.PathLike)):
            img = Image.open(img)
        elif not isinstance(img, Image.Image):
            img = np.asarray(img)
            # Only treat the last axis as channels for 3-D arrays (a 2-D image of width 4 is not RGBA).
            if img.ndim == 3 and img.shape[-1] == 4:
                img = img[:, :, :3]
            img = Image.fromarray(img)
        # Also drops alpha and expands grayscale/palette images to 3 channels.
        return img.convert('RGB')

    def infer(self, img, flag = True):
        """De-rain one image.

        Args:
            img: file path (str or os.PathLike), PIL image, or image array
                (H x W, H x W x 3, or H x W x 4; alpha is dropped).
            flag: if True, return a NumPy array; otherwise return a PIL image.

        Returns:
            The de-rained image, resized back to the input's size.
        """
        # The network runs at a fixed 512x1024 (h, w).
        transform = transforms.Compose([
            transforms.Resize([512, 1024]),
            transforms.ToTensor()])
        to_pil = transforms.ToPILImage()

        net = self._get_net()
        self.img_infer = self._to_rgb_image(img)

        with torch.no_grad():
            w, h = self.img_infer.size
            img_var = transform(self.img_infer).unsqueeze(0).cuda()

            res = net(img_var)

            torch.cuda.synchronize()

            # ToPILImage scales floats by 255 and casts to uint8; clamp so values can't wrap.
            self.result = transforms.Resize((h, w))(to_pil(res.squeeze(0).clamp(0, 1).cpu()))

            self.result_np = np.array(self.result)

        return self.result_np if flag else self.result

    def displayRes(self):
        """Show the last input and its de-rained result side by side."""
        if not hasattr(self, 'result'):
            raise RuntimeError('displayRes() called before infer()')

        fig = plt.figure(figsize=(14, 7))
        for position, image in enumerate((self.img_infer, self.result), start=1):
            fig.add_subplot(1, 2, position)
            plt.axis("off")
            plt.imshow(image)