<a href="https://colab.research.google.com/github/vfrantc/deweather/blob/main/star_decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Box gives attribute-style access to the training options dict below.
%pip install python-box

In [None]:
# Mount Google Drive at /content/drive; the copy commands below read from
# /content/drive/MyDrive/..., so the mount point must be this absolute path.
# (The Colab helper lives in google.colab, not google.cloud.)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the training archive from Drive and extract it into the working directory
# (presumably it provides the ./input, ./slow and ./fast folders globbed later — verify).
# -o overwrites existing files without prompting so the cell can be re-run; -q trims the log.
!cp /content/drive/MyDrive/deweather2/split_star.zip .
!unzip -o -q split_star.zip

In [None]:
import os
import random
import time
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from box import Box
from PIL import Image
from torch.autograd import Variable
from tqdm.notebook import tqdm

In [None]:
# Training hyper-parameters. `seed` makes the random crops, augmentations, shuffling
# and weight initialisation reproducible across Restart & Run All.
opt = Box({'epochs': 100,
           'batch_size': 16,
           'patch_size': 96,
           'lr': 0.001,
           'seed': 0})

# Seed every RNG the training code draws from (python `random` for crops/flips/shuffle,
# numpy, and torch for parameter initialisation).
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)

In [None]:
class DecomNet(nn.Module):
    """Fully-convolutional net that splits an RGB image into two 3-channel maps, R and L.

    The 3-channel input is duplicated along the channel axis (3 -> 6 channels) before
    the first convolution. Both outputs are passed through a sigmoid, so they lie in (0, 1).
    Spatial size is preserved: every convolution uses "same" replicate padding.

    Args:
        channel: number of hidden feature channels.
        kernel_size: kernel size of the hidden/recon convolutions; the first convolution
            uses ``3 * kernel_size``. Must be odd so that padding keeps H and W unchanged.
        num_layers: number of Conv+ReLU pairs in the body (default 5, as originally built).

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, channel=64, kernel_size=3, num_layers=5):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f'kernel_size must be odd to preserve spatial size, got {kernel_size}')
        # "Same" padding derived from the kernel size (previously hard-coded to 4 and 1,
        # which only matched kernel_size == 3).
        pad = kernel_size // 2
        pad0 = (kernel_size * 3) // 2
        # Shallow feature extraction
        self.net1_conv0 = nn.Conv2d(6, channel, kernel_size * 3, padding=pad0, padding_mode='replicate')
        # Activated layers, created in Conv, ReLU, Conv, ReLU, ... order so the state_dict
        # keys stay net1_convs.0, net1_convs.2, ... and existing checkpoints still load.
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Conv2d(channel, channel, kernel_size, padding=pad, padding_mode='replicate'))
            layers.append(nn.ReLU())
        self.net1_convs = nn.Sequential(*layers)
        # Final recon layer: channels 0:3 become R, channels 3:6 become L.
        self.net1_recon = nn.Conv2d(channel, 6, kernel_size, padding=pad, padding_mode='replicate')

    def forward(self, input_im):
        """Return ``(R, L)``, each with the same batch and spatial size as ``input_im`` and 3 channels.

        Assumes ``input_im`` is (N, 3, H, W) — the channel duplication below produces the
        6 channels ``net1_conv0`` expects.
        """
        input_img = torch.cat((input_im, input_im), dim=1)
        feats0 = self.net1_conv0(input_img)
        featss = self.net1_convs(feats0)
        outs = self.net1_recon(featss)
        R = torch.sigmoid(outs[:, 0:3, :, :])
        L = torch.sigmoid(outs[:, 3:6, :, :])
        return R, L

In [None]:
net = DecomNet()
net = net.cuda()

# Step learning-rate schedule: opt.lr for the first 20 epochs, then opt.lr / 10.
lr = opt.lr * np.ones([opt.epochs])
lr[20:] = lr[0] / 10.0

# The three folders are paired by position after sorting, so they must line up one-to-one.
train_input_data_names = sorted(glob('./input/*.png'))
train_slow_data_names = sorted(glob('./slow/*.png'))
train_fast_data_names = sorted(glob('./fast/*.png'))
assert train_input_data_names, 'No training images found in ./input'
assert len(train_input_data_names) == len(train_slow_data_names) == len(train_fast_data_names), (
    'input/slow/fast must contain the same number of images, got '
    f'{len(train_input_data_names)}/{len(train_slow_data_names)}/{len(train_fast_data_names)}')

# Optimise the network built above (there is no `self` at notebook top level).
train_op = optim.Adam(net.parameters(), lr=lr[0], betas=(0.9, 0.999))

In [None]:
def load_training_triplet(name_lists, index):
    """Load the ``index``-th file from each list as float32 (H, W, 3) arrays scaled to [0, 1].

    Images are converted to RGB so grayscale, palette or RGBA PNGs still yield 3 channels.
    Files are opened in a ``with`` block so their handles are closed.
    """
    images = []
    for names in name_lists:
        with Image.open(names[index]) as img:
            images.append(np.asarray(img.convert('RGB'), dtype='float32') / 255.0)
    return images


def random_crop_and_augment(images, patch_size):
    """Apply one shared random crop, flip pattern and rotation to every image in ``images``.

    Returns a list of (3, patch_size, patch_size) arrays, one per input image.

    Raises:
        ValueError: if the first image is smaller than ``patch_size`` in either dimension.
    """
    h, w, _ = images[0].shape
    if h < patch_size or w < patch_size:
        raise ValueError(f'Image of size {h}x{w} is smaller than patch_size={patch_size}')
    x = random.randint(0, h - patch_size)
    y = random.randint(0, w - patch_size)
    patches = [im[x:x + patch_size, y:y + patch_size, :] for im in images]
    if random.random() < 0.5:
        patches = [np.flipud(p) for p in patches]
    if random.random() < 0.5:
        patches = [np.fliplr(p) for p in patches]
    rot_type = random.randint(1, 4)
    if random.random() < 0.5:
        patches = [np.rot90(p, rot_type) for p in patches]
    # HWC -> CHW, the layout the network consumes.
    return [np.transpose(p, (2, 0, 1)) for p in patches]


image_id = 0
iter_num = 0
numBatch = len(train_input_data_names) // int(opt.batch_size)
if numBatch == 0:
    raise ValueError(f'Need at least batch_size={opt.batch_size} training images, '
                     f'found {len(train_input_data_names)}')
start_time = time.time()
for epoch in range(opt.epochs):
    # Apply this epoch's learning rate from the precomputed schedule.
    for param_group in train_op.param_groups:
        param_group['lr'] = float(lr[epoch])

    for batch_id in range(numBatch):
        # Generate training data for a batch
        batch_shape = (opt.batch_size, 3, opt.patch_size, opt.patch_size)
        batch_input = np.zeros(batch_shape, dtype='float32')
        batch_slow = np.zeros(batch_shape, dtype='float32')
        batch_fast = np.zeros(batch_shape, dtype='float32')

        for patch_id in range(opt.batch_size):
            name_lists = (train_input_data_names, train_slow_data_names, train_fast_data_names)
            patch_input, patch_slow, patch_fast = random_crop_and_augment(
                load_training_triplet(name_lists, image_id), opt.patch_size)
            batch_input[patch_id] = patch_input
            batch_slow[patch_id] = patch_slow
            batch_fast[patch_id] = patch_fast

            image_id = (image_id + 1) % len(train_input_data_names)
            if image_id == 0:
                # Full pass done: reshuffle, keeping each (input, slow, fast) triplet together.
                # The list must be shuffled itself, not a temporary copy of it.
                triplets = list(zip(*name_lists))
                random.shuffle(triplets)
                train_input_data_names, train_slow_data_names, train_fast_data_names = (
                    list(names) for names in zip(*triplets))

        input_batch = torch.from_numpy(batch_input).cuda()
        target_slow = torch.from_numpy(batch_slow).cuda()
        target_fast = torch.from_numpy(batch_fast).cuda()

        out_fast, out_slow = net(input_batch)
        train_op.zero_grad()
        # The product of the two outputs should reconstruct the input; each output is
        # also supervised directly against its target.
        loss = (F.l1_loss(out_fast * out_slow, input_batch)
                + F.l1_loss(out_slow, target_slow)
                + F.l1_loss(out_fast, target_fast))
        loss.backward()
        train_op.step()

        print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" % (epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss.item()))
        iter_num += 1

print("Finished training...")

# Test model

In [None]:
    def save(self, iter_num, ckpt_dir):
        """Save the state_dict of the network for the current training phase.

        Writes ``<ckpt_dir>/<self.train_phase>/<iter_num>.tar`` and creates the directory if needed.

        Args:
            self: object exposing ``train_phase`` and, depending on its value, a
                ``DecomNet`` ('Decom') or ``RelightNet`` ('Relight') module.
            iter_num: iteration number used as the checkpoint file name.
            ckpt_dir: root checkpoint directory.

        Raises:
            ValueError: if ``self.train_phase`` is neither 'Decom' nor 'Relight'
                (this used to create the directory and silently save nothing).
        """
        if self.train_phase == 'Decom':
            model = self.DecomNet
        elif self.train_phase == 'Relight':
            model = self.RelightNet
        else:
            raise ValueError(f"Unknown train_phase: {self.train_phase!r}")
        save_dir = os.path.join(ckpt_dir, self.train_phase)
        os.makedirs(save_dir, exist_ok=True)
        save_name = os.path.join(save_dir, str(iter_num) + '.tar')
        torch.save(model.state_dict(), save_name)

In [None]:
def get_decom(trainable=True, ckpt_path='ckpts/Decom/9200.tar'):
  """Build a DecomNet on the GPU and load its weights from a checkpoint.

  Args:
    trainable: value assigned to ``requires_grad`` on every parameter
      (pass False for pure inference).
    ckpt_path: path of a state_dict saved with ``torch.save``.

  Returns:
    The loaded DecomNet, on CUDA.
  """
  net = DecomNet().cuda()
  # Load onto the CPU first so checkpoints saved from any device can be opened;
  # load_state_dict then copies the tensors into the model's CUDA parameters.
  ckpt_dict = torch.load(ckpt_path, map_location='cpu')
  net.load_state_dict(ckpt_dict)
  for p in net.parameters():
      p.requires_grad = trainable
  return net

In [None]:
def _to_hwc(t):
  """Take batch item 0 of a (1, C, H, W) tensor, clip to [0, 1], return (H, W, C) NumPy; C == 1 gives (H, W)."""
  arr = np.clip(t[0].detach().cpu().numpy(), 0, 1)
  arr = np.transpose(arr, (1, 2, 0))
  return arr[:, :, 0] if arr.shape[2] == 1 else arr


def decom_image(image, model=None):
  """Run the decomposition network on a single BGR image (as returned by cv2.imread).

  Args:
    image: (H, W, 3) BGR array with values in [0, 255]; a 4th (alpha) channel is ignored.
    model: network returning an (R, I) pair of (1, C, H, W) tensors;
      defaults to the global ``net``.

  Returns:
    (R, I) as NumPy arrays clipped to [0, 1], both in (H, W, C) layout so they can go
    straight to ``imshow`` (a single-channel output is returned as (H, W)).
  """
  if model is None:
    model = net
  # BGR -> RGB by reversing the first three channels (same as cv2.COLOR_BGR2RGB).
  rgb = image[:, :, 2::-1].astype(np.float32) / 255.0
  chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))
  # Run on whichever device holds the model's weights.
  first_param = next(iter(model.parameters()), None)
  device = first_param.device if first_param is not None else torch.device('cpu')
  batch = torch.from_numpy(chw).unsqueeze(0).to(device)
  # Inference only: no autograd graph needed.
  with torch.no_grad():
    R_low, I_low = model(batch)
  return _to_hwc(R_low), _to_hwc(I_low)

In [None]:
# Load the trained decomposition net for inference: freeze its weights so no autograd
# graph is built, and switch to eval mode.
net = get_decom(trainable=False)
net.eval()

In [None]:
# Copy and extract the evaluation images. -n never overwrites existing files, so the cell
# can be re-run without an interactive prompt and cannot clobber images already present
# under ./input (the training glob reads that folder); -q trims the log.
!cp /content/drive/MyDrive/deweather2/input.zip .
!unzip -n -q input.zip

In [None]:
FNAME = 'input/input/010.png'
dehazed_image = cv2.imread(FNAME)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if dehazed_image is None:
    raise FileNotFoundError(f'Could not read image: {FNAME}')
reflectance, illumination = decom_image(dehazed_image)

# imshow needs (H, W) or (H, W, C); move a leading 3-channel axis to the end if present.
if illumination.ndim == 3 and illumination.shape[0] == 3 and illumination.shape[-1] != 3:
    illumination = np.transpose(illumination, (1, 2, 0))

fig, axs = plt.subplots(2, figsize=(16, 8))
axs[0].imshow(reflectance)
axs[0].set_title('Reflectance')
axs[1].imshow(illumination, cmap='gray')  # cmap only affects single-channel maps
axs[1].set_title('Illumination')
for ax in axs:
    ax.axis('off')
plt.show()