<a href="https://colab.research.google.com/github/zahrabayramli/Self-guided-network-for-fast-image-denoising/blob/master/SGN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from google.colab import drive
# Mount Google Drive at /gdrive so the dataset, checkpoints and logs under
# args.root are reachable from this runtime.
drive.mount(mountpoint='/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [None]:
from tqdm import tqdm
from pathlib import Path
import numpy as np
from numpy import random
import time
import math
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.io
import torchvision.transforms.functional as F2

In [None]:
class SubNetwork(nn.Module):
    """One pyramid level of the self-guided network (SGN).

    Args:
        g: number of 3x3 convolutions in the residual block (a ReLU sits
            between consecutive ones). Used by every type except 'BOTTOM'.
        m: number of Conv+ReLU pairs before the output convolution. Used only
            by 'BOTTOM'.
        c_in: channels of the input ``x``.
        c_k: working channel width of this level.
        c_out: output channels of the 'BOTTOM' level. Other types output
            ``c_k`` channels.
        sub_network_type: 'TOP', 'MIDDLE' or 'BOTTOM' (any other value is
            built like 'MIDDLE').
            * every type except 'TOP' gets ``conv2`` (``c_k // 2 * 3`` -> ``c_k``
              channels) to fuse the pixel-shuffled ``upper_features`` with the
              ``conv1`` output;
            * every type except 'BOTTOM' gets a residual block followed by a
              final Conv+ReLU;
            * 'BOTTOM' instead ends with ``m`` Conv+ReLU pairs and a Conv to
              ``c_out`` channels (no final activation).
    """

    def __init__(self, g, m, c_in, c_k, c_out, sub_network_type):
        super(SubNetwork, self).__init__()

        # ReLU has no parameters, so one instance is shared by every position.
        self.act = nn.ReLU()
        self.conv1 = nn.Conv2d(c_in, c_k, kernel_size=3, stride=1, padding=1, bias=False)

        if sub_network_type != 'TOP':
            # Input = c_k conv1 channels + pixel-shuffled guidance from the level above.
            self.conv2 = nn.Conv2d(c_k // 2 * 3, c_k, kernel_size=3, stride=1, padding=1, bias=False)

        if sub_network_type != 'BOTTOM':
            # Residual branch: g convolutions with ReLUs in between (no trailing ReLU;
            # the activation is applied after the skip addition in forward).
            self.res_block = nn.ModuleList()
            for _ in range(g - 1):
                self.res_block.append(nn.Conv2d(c_k, c_k, kernel_size=3, stride=1, padding=1, bias=False))
                self.res_block.append(self.act)
            self.res_block.append(nn.Conv2d(c_k, c_k, kernel_size=3, stride=1, padding=1, bias=False))

            self.conv_last = nn.ModuleList([nn.Conv2d(c_k, c_k, kernel_size=3, stride=1, padding=1, bias=False)])
            self.conv_last.append(self.act)

        else:
            self.res_block = None
            self.conv_last = nn.ModuleList()
            for _ in range(m):
                self.conv_last.append(nn.Conv2d(c_k, c_k, kernel_size=3, stride=1, padding=1, bias=False))
                self.conv_last.append(self.act)

            self.conv_last.append(nn.Conv2d(c_k, c_out, kernel_size=3, stride=1, padding=1, bias=False))

    def forward(self, x, upper_features=None):
        """Run this level.

        Args:
            x: input tensor with ``c_in`` channels, shape (N, c_in, H, W).
            upper_features: optional output of the next-coarser level at half
                resolution. It is pixel-shuffled by 2 (channels / 4, H and W x 2)
                and concatenated with the ``conv1`` features, so for even
                ``c_k`` it is expected to carry ``2 * c_k`` channels. Must be
                None for a 'TOP' level, which has no ``conv2``.

        Returns:
            (N, c_k, H, W) for 'TOP'/'MIDDLE', (N, c_out, H, W) for 'BOTTOM'.
        """
        x = self.act(self.conv1(x))

        # Identity test: `!=` on a tensor is an element-wise op that only
        # happens to fall back to identity for None.
        if upper_features is not None:
            x = torch.concat((x, F.pixel_shuffle(upper_features, 2)), 1)
            x = self.act(self.conv2(x))

        if self.res_block is not None:
            y = x
            for layer in self.res_block:
                y = layer(y)
            x = self.act(torch.add(x, y))

        for layer in self.conv_last:
            x = layer(x)

        return x

class SGN(nn.Module):
    """Self-guided network: a pixel-unshuffle pyramid of SubNetworks.

    The input is pixel-unshuffled ``K + 1`` times. The coarsest level ('TOP')
    runs first; its features guide each 'MIDDLE' level from coarse to fine,
    and finally the full-resolution 'BOTTOM' level, whose output is returned.

    ``args`` must provide ``in_channels``, ``start_channels``,
    ``out_channels``, ``m_block``, ``g`` and ``K``. Each level up the pyramid
    multiplies the input channels by 4 and the working width by 2.
    """

    def __init__(self, args):
        super().__init__()

        c_in = args.in_channels
        c_0 = args.start_channels
        c_out = args.out_channels
        m = args.m_block
        g = args.g
        self.K = args.K

        self.bottom = SubNetwork(g, m, c_in, c_0, c_out, 'BOTTOM')

        # Build the K middle levels, finest first.
        c_k = c_0
        levels = []
        for _ in range(self.K):
            c_in, c_k = c_in * 4, c_k * 2
            levels.append(SubNetwork(g, m, c_in, c_k, c_out, 'MIDDLE'))
        self.middle = nn.ModuleList(levels)

        self.top = SubNetwork(g, m, c_in * 4, c_k * 2, c_out, 'TOP')

    def forward(self, x):
        """Denoise ``x``; H and W must be divisible by ``2 ** (K + 1)``."""
        pyramid = [x]
        for _ in range(self.K + 1):
            pyramid.append(F.pixel_unshuffle(pyramid[-1], 2))

        guide = self.top(pyramid[-1])
        # Walk down from pyramid level K to level 1, each guided by the level above.
        for level, middle in zip(range(self.K, 0, -1), reversed(self.middle)):
            guide = middle(pyramid[level], guide)

        return self.bottom(x, guide)

In [None]:
# Configurations & Hyper-parameters

from easydict import EasyDict as edict

# Set manual seeds for reproducibility. The training augmentation draws its
# rotation angle from numpy's global RNG (`random` is `numpy.random` in this
# notebook), so numpy must be seeded as well as torch.
torch.manual_seed(470)
torch.cuda.manual_seed(470)
np.random.seed(470)

args = edict()

# basic options
args.root = '/gdrive/My Drive/CS492I/project' 
args.name = 'main'                   # experiment name.
args.ckpt_dir = 'ckpts'              # checkpoint directory name.
args.ckpt_iter = 1000                # save a checkpoint every this many training steps.
args.ckpt_reload = 'best'            # which checkpoint to re-load.
args.gpu = True                      # whether or not to use gpu. 

# network options
args.in_channels = 3                 # channels of the input image.
args.out_channels = 3                # channels of the denoised output.
args.start_channels = 32             # working width of the full-resolution (bottom) level.
args.m_block = 2                     # Conv+ReLU pairs at the end of the bottom level.
args.g = 3                           # convolutions per residual block in the upper levels.
args.K = 2                           # number of middle levels in the pyramid.

# data options
args.train_data_root = '/DIV2K_train_HR'   # appended to args.root.
args.test_data_root = '/DIV2K_valid_HR'    # appended to args.root.
args.crop_size = 256                 # train crop size / test resize size.
args.mu = 0.0                        # mean of the additive Gaussian noise (0..255 scale).
args.sigma = 30.0                    # std of the additive Gaussian noise (0..255 scale).

# training options
args.epoch = 100                     # training epoch.
args.batch_size = 8                  # number of mini-batch size.
args.lr = 0.0001                     # learning rate.
args.betas = (0.9, 0.999)            # Adam betas.
args.weight_decay = 0                # Adam weight decay.

# tensorboard options
args.tensorboard = True              # whether or not to use tensorboard logging.
args.log_dir = 'logs'                # to which tensorboard logs will be saved.
args.log_iter = 10                   # how frequently logs are saved.

In [None]:
# Basic settings: select the compute device and make sure the results folder exists.
if torch.cuda.is_available() and args.gpu:
    device = 'cuda'
else:
    device = 'cpu'

result_dir = Path(args.root).joinpath('results')
result_dir.mkdir(exist_ok=True, parents=True)

# Module-level counters, initialised to zero (train_net tracks its own local copies).
global_step = 0
best_psnr = 0.

In [None]:
# Define train/test data loaders
# Use data augmentation in training set to mitigate overfitting.

class DatasetClass(Dataset):
    """Pairs of (noisy, clean) images for Gaussian denoising.

    Images are collected recursively from ``args.root + args.train_data_root``
    (training) or ``args.root + args.test_data_root`` (testing).

    Training samples are randomly cropped to ``args.crop_size``, randomly
    flipped horizontally and rotated by a random multiple of 90 degrees. Test
    samples are resized to ``crop_size x crop_size``.

    Gaussian noise with mean ``args.mu`` and std ``args.sigma`` is added on the
    0..255 pixel scale, the noisy image is clipped to [0, 255], and both images
    are then normalized with ``(v - 128) / 128``.
    """

    def __init__(self, is_train=True):
        self.is_train = is_train
        if is_train:
            path = args.root + args.train_data_root
            self.cropper = transforms.RandomCrop(args.crop_size)
            self.flipper = transforms.RandomHorizontalFlip(0.5)
        else:
            path = args.root + args.test_data_root
            self.resizer = transforms.Resize((args.crop_size, args.crop_size))

        # Sorted so that sample indices (and the test-set order) do not depend
        # on the filesystem's os.walk ordering.
        self.imglist = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(path)
            for name in files
        )

    def __getitem__(self, index):
        """Return ``(noisy_img, img)`` as float tensors of shape (3, H, W)."""
        # Force RGB so grayscale or RGBA files still yield 3 channels.
        img = torchvision.io.read_image(
            self.imglist[index], mode=torchvision.io.ImageReadMode.RGB)

        if self.is_train:
            img = self.cropper(img)
            img = self.flipper(img)
            # numpy randint's upper bound is exclusive: 0, 90, 180 or 270 degrees.
            img = F2.rotate(img, 90.0 * random.randint(0, 4))
        else:
            img = self.resizer(img)

        # read_image yields uint8; convert before any arithmetic, otherwise
        # `img - 128` wraps around for every pixel below 128.
        img = img.float()

        noise = torch.normal(torch.full(img.shape, args.mu), torch.full(img.shape, args.sigma))
        # Clip to the valid 8-bit range [0, 255].
        noisy_img = torch.clamp(img + noise, 0.0, 255.0)

        img = (img - 128) / 128
        noisy_img = (noisy_img - 128) / 128

        return noisy_img, img

    def __len__(self):
        return len(self.imglist)

train_dataset, test_dataset = DatasetClass(is_train=True), DatasetClass(is_train=False)

# Training: shuffled, incomplete last batch dropped. Testing: every image, in order.
train_dataloader = DataLoader(train_dataset, args.batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False, drop_last=False)

In [None]:
# Setup tensorboard.
# When enabled: import SummaryWriter and start the inline TensorBoard viewer
# (IPython magics, notebook-only) on the results directory. The writer object
# itself is created later, in the training cell.
if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter 
    %load_ext tensorboard
    %tensorboard --logdir "/gdrive/My Drive/{str(result_dir).replace('/gdrive/My Drive/', '')}"
else:
    # train_net skips all TensorBoard logging when writer is None.
    writer = None

In [None]:
def _batch_psnr_sum(y_pred, y):
    """Return the sum of per-image PSNRs (dB, peak 255) over a batch.

    Inputs are in the dataset's normalized scale, so ``(v + 1) * 128`` maps
    them back to pixel values before comparing.
    """
    total = 0.0
    for a, b in zip(y_pred, y):
        mse = F.mse_loss((a + 1) * 128, (b + 1) * 128).item()
        total += 20 * math.log10(255) - 10 * math.log10(mse)
    return total


def _evaluate(net, criterion):
    """Return ``(mean loss, mean PSNR)`` of ``net`` over ``test_dataloader``."""
    net.eval()
    test_loss = 0.
    test_psnr = 0.
    test_num_data = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device=device)
            y = y.to(device=device)

            y_pred = net(x)
            # criterion averages over the batch; re-weight by batch size so the
            # final mean is per image even when the last batch is smaller.
            test_loss += criterion(y_pred, y).item() * x.shape[0]
            test_psnr += _batch_psnr_sum(y_pred, y)
            test_num_data += x.shape[0]

    if test_num_data == 0:
        raise RuntimeError('test_dataloader is empty; cannot compute test metrics')
    return test_loss / test_num_data, test_psnr / test_num_data


def train_net(net, optimizer, scheduler, writer):
    """Train ``net`` for ``args.epoch`` epochs, evaluating after every epoch.

    Relies on the module-level ``args``, ``device``, ``train_dataloader``,
    ``test_dataloader`` and ``ckpt_dir``. Every ``args.ckpt_iter`` steps the
    weights are saved to ``{ckpt_dir}/{step}.pt``; whenever the mean test PSNR
    improves they are saved to ``{ckpt_dir}/best.pt``.

    Args:
        net: model mapping noisy images to clean images.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        writer: TensorBoard ``SummaryWriter``, or None to disable logging.

    Returns:
        The best mean test PSNR (dB) observed during training.
    """
    criterion = nn.MSELoss()
    global_step = 0
    best_psnr = 0

    for epoch in range(args.epoch):
        net.train()
        for x, y in train_dataloader:
            global_step += 1

            x = x.to(device=device)
            y = y.to(device=device)

            loss = criterion(net(x), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if writer is not None and global_step % args.log_iter == 0:
                writer.add_scalar('train_loss', loss.item(), global_step)

            if global_step % args.ckpt_iter == 0:
                torch.save(net.state_dict(), f'{ckpt_dir}/{global_step}.pt')

        test_loss, test_psnr = _evaluate(net, criterion)

        # Progress is printed whether or not TensorBoard logging is enabled.
        print(f'Test result of epoch {epoch}/{args.epoch} || loss : {test_loss:.3f} psnr : {test_psnr:.3f} ')

        if writer is not None:
            writer.add_scalar('test_loss', test_loss, global_step)
            writer.add_scalar('test_psnr', test_psnr, global_step)
            writer.flush()

        if test_psnr > best_psnr:
            best_psnr = test_psnr
            torch.save(net.state_dict(), f'{ckpt_dir}/best.pt')

        scheduler.step()
    return best_psnr

In [None]:
# Function for weight initialization.
def weight_init(m):
    """Initialize one module in place; intended for ``net.apply(weight_init)``.

    Linear / Conv2d: weights drawn from N(0, 0.02), bias (if present) zeroed.
    BatchNorm2d: weight set to 1, bias to 0. Any other module is left as is.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return

    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

In [None]:
# Create directory name: the first results/trial_<k> that does not exist yet.
num_trial = 0
parent_dir = result_dir.joinpath(f'trial_{num_trial}')
while parent_dir.is_dir():
    num_trial = int(parent_dir.name.split('_')[-1])
    parent_dir = result_dir.joinpath(f'trial_{num_trial + 1}')
print(f'Logs and ckpts will be saved in : {parent_dir}')

# Define network, move it to the selected device and apply the custom init.
net = SGN(args).to(device)
net.apply(weight_init)

Logs and ckpts will be saved in : /gdrive/My Drive/CS492I/project/results/trial_10


SGN(
  (bottom): SubNetwork(
    (act): ReLU()
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv2): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv_last): ModuleList(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): ReLU()
      (4): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
  (middle): ModuleList(
    (0): SubNetwork(
      (act): ReLU()
      (conv1): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (res_block): ModuleList(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_siz

In [None]:
import traceback

final_psnr = 0
# Defined up front so the `finally` clause below can always inspect it; it is
# replaced by a real SummaryWriter only when TensorBoard logging is enabled.
writer = None

# Start training
try:
    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    # NOTE(review): train_net steps the scheduler once per epoch, so with the
    # configured args.epoch (100) a milestone of 5000 is never reached and the
    # learning rate never decays. Confirm the intended schedule.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5000], gamma=0.1)

    # Create directories for logs and checkpoints.
    ckpt_dir = parent_dir / args.ckpt_dir
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = parent_dir / args.log_dir
    log_dir.mkdir(parents=True, exist_ok=True)

    # Create tensorboard writer.
    if args.tensorboard:
        writer = SummaryWriter(log_dir)

    # Call the train & test function.
    t1 = time.time()
    final_psnr = train_net(net, optimizer, scheduler, writer)
    t = time.time() - t1
    print(f'Best psnr: {final_psnr:.3f} took {t:.3f} secs')
except Exception:
    # Keep the notebook going, but show the full traceback, not just the message.
    traceback.print_exc()
finally:
    # Flush any pending events and release the event file.
    if writer is not None:
        writer.close()

# Print final best accuracies of the model.
print(f'Best psnr = {final_psnr:.2f}')

In [None]:
# Aggregating experimental results

ckpt_dir = parent_dir / args.ckpt_dir

# Load weights from the best checkpoint. map_location lets a checkpoint saved
# from a GPU run be loaded on a CPU-only runtime (and vice versa).
ckpt_path = f'{ckpt_dir}/best.pt'
try:
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
except Exception as e:
    # Best effort: report the failure and evaluate the weights currently in `net`.
    print(f'Could not load {ckpt_path}; evaluating current weights instead: {e}')

# Measure test performance: mean per-image PSNR in dB (peak value 255).
net.eval()
test_psnr = 0.
test_num_data = 0
with torch.no_grad():
    for x, y in test_dataloader:
        # Send `x` and `y` to either cpu or gpu using `device` variable.
        x = x.to(device=device)
        y = y.to(device=device)

        y_pred = net(x)

        # Undo the dataset's (v - 128) / 128 normalization before computing PSNR.
        for a, b in zip(y_pred, y):
            mse = F.mse_loss((a + 1) * 128, (b + 1) * 128).item()
            test_psnr += 20 * math.log10(255) - 10 * math.log10(mse)

        test_num_data += x.shape[0]

if test_num_data:
    test_psnr /= test_num_data
else:
    print('Test set is empty; no PSNR computed.')

In [None]:
# Printing final results: the mean test PSNR (dB) computed in the previous cell.
print(f'{test_psnr}')