<a href="https://colab.research.google.com/github/ravit-cohen-segev/ravit-cohen-segev/blob/main/DnCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import time
import tensorflow as tf
import torch
from torch import nn
import torch.utils.data as td
import torch.nn.functional as F

from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter

#additional packages needed for image pre-processing
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gc

In [2]:
class NeuralNetwork(nn.Module, ABC):
    """Abstract base class for the neural networks in this notebook.

    Subclasses must override ``forward``, which makes a prediction for its
    input, and ``criterion``, which evaluates the fit between a prediction and
    a desired output.

    Compared with a plain ``nn.Module`` this class:

    * overrides ``named_parameters`` so that only parameters with
      ``requires_grad=True`` are yielded;
    * exposes a ``device`` property returning the device of the first such
      parameter (assuming all parameters live on the same device).
    """

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``.

        Raises:
            StopIteration: if ``self.parameters()`` yields nothing.
        """
        # Kept as a property (not an attribute) because the device can change
        # at any time when the user calls ``net.to(new_device)``.
        return next(self.parameters()).device

    def named_parameters(self, recurse=True, **kwargs):
        """Yield ``(name, parameter)`` pairs for trainable parameters only.

        Arguments:
            recurse (boolean, optional): if False, only parameters that are
                direct members of this module are considered. (default: True)
            **kwargs: forwarded unchanged to ``nn.Module.named_parameters``
                (for example ``prefix``), so calls written against the base
                class signature using keywords keep working.
        """
        # ``recurse`` used to be accepted but silently ignored; it is now
        # forwarded to the base implementation.
        for name, param in super().named_parameters(recurse=recurse, **kwargs):
            if param.requires_grad:
                yield name, param

    @abstractmethod
    def forward(self, x):
        """Return the prediction of the network for input ``x``."""

    @abstractmethod
    def criterion(self, y, d):
        """Return the loss between prediction ``y`` and desired output ``d``."""

In [3]:
class ResidualDenoiser(NeuralNetwork):
    """Fully convolutional network built as a stack of 3x3 convolutions.

    Layout of ``self.model`` (an ``nn.Sequential``):

    1. ``Conv2d(in_channels -> num_features)`` followed by ``ReLU``;
    2. ``expansions`` blocks of ``Conv2d -> BatchNorm2d -> ReLU`` operating on
       ``num_features`` channels;
    3. a final ``Conv2d(num_features -> in_channels)``.

    Every convolution uses a 3x3 kernel with ``padding=1``, so the output has
    the same height and width as the input.

    Arguments:
        in_channels (integer, optional): channels of the input and of the
            output. (default: 3)
        num_features (integer, optional): channels of the hidden layers.
            (default: 64)
        expansions (integer, optional): number of hidden
            ``Conv2d -> BatchNorm2d -> ReLU`` blocks. (default: 18)
    """

    def __init__(self, in_channels=3, num_features=64, expansions=18):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=num_features,
                      kernel_size=(3, 3), padding=1),
            nn.ReLU(),
        ]

        # Each hidden block needs its own module instances. Re-using a single
        # ``[conv, bn, relu]`` list for every block would register the same
        # objects repeatedly, making all hidden blocks share one set of weights.
        for _ in range(expansions):
            layers += [
                nn.Conv2d(in_channels=num_features, out_channels=num_features,
                          kernel_size=(3, 3), padding=1),
                nn.BatchNorm2d(num_features),
                nn.ReLU(),
            ]

        layers.append(
            nn.Conv2d(in_channels=num_features, out_channels=in_channels,
                      kernel_size=(3, 3), padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the result."""
        return self.model(x)

    def criterion(self, y, d):
        """Return the mean squared error between ``y`` and ``d``."""
        return nn.functional.mse_loss(y, d)

In [27]:
class Experiment(object):
    """
    A class meant to run a neural network learning experiment.

    After being instantiated, the experiment can be run using the method
    ``run``. At each epoch, a checkpoint is written to the directory
    ``output_dir``. Two files will be present: ``checkpoint.pth.tar``, a binary
    file containing the state of the experiment, and ``config.txt``, an ASCII
    file describing the setting of the experiment. If ``output_dir`` does not
    exist, it will be created. If it already holds a ``config.txt``, the last
    checkpoint is loaded, except if the stored setting differs from
    ``repr(self)`` (in that case a ``ValueError`` is raised). The loaded
    experiment is continued from where it stopped when calling ``run``.

    Attributes/Properties:
        epoch (integer): the number of performed epochs.

    Arguments:
        net (NeuralNetwork): a neural network. It is converted in place to
            float64 because the training loop feeds it float64 tensors.
        train_set (iterable): iterated directly once per epoch; must yield
            ``(input, target)`` tensor pairs (e.g. a ``DataLoader``).
        val_set (iterable): same contract as ``train_set``, used for
            validation.
        optimizer: the optimizer updating the parameters of ``net``.
        output_dir (string, optional): path where to load/save checkpoints. If
            None, ``output_dir`` is set to "experiment_TIMESTAMP" where
            TIMESTAMP is the current time stamp as returned by ``time.time()``.
            (default: None)
        batch_size (integer, optional): only recorded in the setting; batching
            itself is determined by ``train_set``/``val_set``. (default: 1)
        perform_validation_during_training (boolean, optional): if True, the
            validation set is evaluated after every training epoch.
            (default: True)
    """

    def __init__(self, net, train_set, val_set, optimizer,
                 output_dir=None, batch_size=1, perform_validation_during_training=True):

        # Define checkpoint paths
        if output_dir is None:
            output_dir = 'experiment_{}'.format(time.time())
        os.makedirs(output_dir, exist_ok=True)

        # The loops below feed float64 tensors to the network, so convert it
        # once here. (This conversion used to happen as a side effect of
        # ``setting()``/``__repr__``.)
        self.net = net.double()
        self.optimizer = optimizer
        self.train_set = train_set
        self.val_set = val_set
        # The sets are iterated as-is; no DataLoader is created here.
        self.train_loader = train_set
        self.val_loader = val_set
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.perform_validation_during_training = perform_validation_during_training
        self.checkpoint_path = os.path.join(output_dir, "checkpoint.pth.tar")
        self.config_path = os.path.join(output_dir, "config.txt")

        # Initialize epochs
        self.epoch = 0

        # Load checkpoint and check compatibility
        if os.path.isfile(self.config_path):
            with open(self.config_path, 'r') as f:
                # ``save`` writes ``repr(self)`` followed by a newline, hence
                # the ``[:-1]``.
                if f.read()[:-1] != repr(self):
                    raise ValueError(
                        "Cannot create this experiment: "
                        "I found a checkpoint conflicting with the current setting.")
            self.load()
        else:
            self.save()

    def setting(self):
        """Returns the setting of the experiment."""
        return {'Net': self.net,
                'TrainSet': self.train_set,
                'ValSet': self.val_set,
                'Optimizer': self.optimizer,
                'BatchSize': self.batch_size,
                'PerformValidationDuringTraining': self.perform_validation_during_training}

    def __repr__(self):
        """Pretty printer showing the setting of the experiment. This is what
        is displayed when doing ``print(experiment)``. This is also what is
        saved in the ``config.txt`` file.
        """
        return ''.join('{}({})\n'.format(key, val)
                       for key, val in self.setting().items())

    def state_dict(self):
        """Returns the current state of the experiment."""
        return {'Net': self.net.state_dict(),
                'Optimizer': self.optimizer.state_dict(),
                'Epoch': self.epoch}

    def load_state_dict(self, checkpoint):
        """Loads the experiment from the input checkpoint."""
        self.net.load_state_dict(checkpoint['Net'])
        self.optimizer.load_state_dict(checkpoint['Optimizer'])
        self.epoch = checkpoint['Epoch']

        # The following loops are used to fix a bug that was
        # discussed here: https://github.com/pytorch/pytorch/issues/2830
        # (it is supposed to be fixed in recent PyTorch version)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.net.device)

    def save(self):
        """Saves the experiment on disk, i.e, create/update the last checkpoint."""
        torch.save(self.state_dict(), self.checkpoint_path)
        with open(self.config_path, 'w') as f:
            print(self, file=f)

    def load(self):
        """Loads the experiment from the last checkpoint saved on disk."""
        checkpoint = torch.load(self.checkpoint_path,
                                map_location=self.net.device)
        self.load_state_dict(checkpoint)
        del checkpoint

    def _to_net(self, tensor):
        """Returns ``tensor`` as float64 on the device of the network."""
        return tensor.to(device=self.net.device, dtype=torch.double)

    def run(self, num_epochs):
        """Runs the experiment, i.e., trains the network using backpropagation
        based on the optimizer and the training set. The mean training loss of
        each epoch is added to TensorBoard under ``Loss/train``.

        Arguments:
            num_epochs (integer): the epoch count to reach. Training resumes
                from ``self.epoch``, so nothing is trained if that many epochs
                were already performed.
        """
        self.net.train()
        start_epoch = self.epoch
        print("Start/Continue training from epoch {}".format(start_epoch))

        #For Tensorboard
        writer = SummaryWriter('logs/Tensorboard_exp')
        try:
            for epoch in range(start_epoch, num_epochs):
                s = time.time()
                total_loss, num_batches = 0.0, 0
                for x, d in self.train_loader:
                    x, d = self._to_net(x), self._to_net(d)
                    self.optimizer.zero_grad()
                    y = self.net(x)
                    loss = self.net.criterion(y, d)
                    loss.backward()
                    self.optimizer.step()
                    total_loss += loss.item()
                    num_batches += 1
                # One point per epoch: logging every batch with the same step
                # would stack all of them on a single x position.
                if num_batches:
                    writer.add_scalar("Loss/train", total_loss / num_batches, epoch)
                #perform validation
                if self.perform_validation_during_training:
                    self.evaluate(writer)
                print("Epoch {} (Time: {:.2f}s)".format(
                    self.epoch, time.time() - s))
                self.epoch += 1
                self.save()
        finally:
            # Make sure buffered events reach the disk even if training fails.
            writer.flush()
            writer.close()
        print("Finish training for {} epochs".format(num_epochs))

    def evaluate(self, writer):
        """Evaluates the experiment, i.e., forward propagates the validation set
        through the network and adds the mean validation loss to TensorBoard
        under ``Loss/val`` at step ``self.epoch``.

        Arguments:
            writer: object providing ``add_scalar(tag, value, step)``, e.g. a
                ``SummaryWriter``.
        """
        self.net.eval()
        total_loss, num_batches = 0.0, 0
        try:
            with torch.no_grad():
                for x, d in self.val_loader:
                    x, d = self._to_net(x), self._to_net(d)
                    y = self.net(x)
                    total_loss += self.net.criterion(y, d).item()
                    num_batches += 1
        finally:
            # Always hand the network back in training mode.
            self.net.train()
        if num_batches:
            writer.add_scalar("Loss/val", total_loss / num_batches, self.epoch)
        return

In [5]:
def GaussianNoise(arr, mean=0., std=0.1):
    """Draw Gaussian noise shaped like ``arr``.

    Only the noise is returned; it is not added to ``arr``, whose values are
    never read (only ``arr.shape`` is used).

    Arguments:
        arr: object exposing a ``shape`` attribute.
        mean (float, optional): mean of the distribution. (default: 0.)
        std (float, optional): standard deviation. (default: 0.1)
    """
    noise_shape = arr.shape
    return np.random.normal(loc=mean, scale=std, size=noise_shape)

In [6]:
class Custom_Dataset():
    """Loads the ``train`` and ``val`` image folders found under ``path`` and
    builds ``(noisy image, noise)`` tensor pairs from them.

    Attributes:
        train_noisy, val_noisy (list): noisy image tensors, one per file.
        train_residual, val_residual (list): the noise tensor that was added
            to the corresponding image.
        max_height, max_width (integer): largest height/width over all images
            of both folders.
        train_batch, val_batch (list): ``(noisy, noise)`` pairs in which every
            tensor has been zero-padded at the bottom/right to
            ``max_height`` x ``max_width``.

    Arguments:
        path (string): directory containing the ``train`` and ``val`` folders.
    """

    def __init__(self, path):
        self.path = path
        self.train_noisy, self.train_residual = self.__getitem__('train')
        self.val_noisy, self.val_residual = self.__getitem__('val')

        all_noisy = [*self.train_noisy, *self.val_noisy]
        self.max_height = max(img.shape[1] for img in all_noisy)
        self.max_width = max(img.shape[2] for img in all_noisy)

        #pad the images
        self.train_batch = list(zip(self.__padding__(self.train_noisy),
                                    self.__padding__(self.train_residual)))
        self.val_batch = list(zip(self.__padding__(self.val_noisy),
                                  self.__padding__(self.val_residual)))

    def __getitem__(self, folder_name):
        """Loads every file of ``path/folder_name`` as an image and returns
        ``(noisy_images, residual_noise)``, two lists of tensors of equal
        length.

        Arguments:
            folder_name (string): name of the sub-folder, e.g. ``'train'``.
        """
        full_path = os.path.join(self.path, folder_name)
        noisy_images = [] #input
        residual_noise = [] #output
        to_tensor = transforms.ToTensor()

        # ``os.listdir`` returns entries in arbitrary order; sort them so the
        # order of the samples is reproducible.
        for file in sorted(os.listdir(full_path)):
            file_path = os.path.join(full_path, file)
            # The context manager closes the file once the pixels are copied.
            with Image.open(file_path) as image:
                # Scale the pixels to [0, 1] before adding noise: ``ToTensor``
                # does not rescale float arrays, so without this the noise
                # (std 0.1) would be added to values in [0, 255].
                # NOTE(review): assumes 8-bit images -- confirm for the data set.
                pixels = np.asarray(image, dtype=np.float64) / 255.0

            gaus_noise = GaussianNoise(pixels)

            residual_noise.append(to_tensor(gaus_noise))
            noisy_images.append(to_tensor(pixels + gaus_noise))
        return noisy_images, residual_noise

    def __padding__(self, batch):
        """Returns the tensors of ``batch`` zero-padded on the right and at the
        bottom up to ``max_width`` x ``max_height``."""
        # The needed padding is the difference between the
        # max width/height and the image's actual width/height.
        return [
            F.pad(img, [0, self.max_width - img.shape[2],
                        0, self.max_height - img.shape[1]])
            for img in batch]


In [7]:
def remove_noise(noisy_image, noise):
    '''Plots the denoised image ``noisy_image - noise``.

    input: two tensors, the noisy image tensor and the noise tensor. They may
    live on any device and may be attached to an autograd graph.
    '''
    #remove noise and reduce output tensor to 3 dimensions
    image_tensor = torch.sub(noisy_image, noise).squeeze()

    # Detach from autograd and move to the CPU so the tensor can be converted
    # to an image, then clamp to [0, 1]: out-of-range values would otherwise
    # wrap around when the float tensor is converted to 8-bit pixels.
    image_tensor = image_tensor.detach().cpu().clamp(0.0, 1.0)

    #plot image
    plt.imshow(transforms.ToPILImage()(image_tensor))

#Load data

In [8]:
# Mount Google Drive inside the Colab runtime at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [9]:
# Extract the data set archive into the current working directory; it unpacks
# to smaller_dataset/train/... (see the listing printed below).
!unzip /content/smaller_dataset.zip

Archive:  /content/smaller_dataset.zip
   creating: smaller_dataset/
   creating: smaller_dataset/train/
  inflating: smaller_dataset/train/100075.jpg  
  inflating: smaller_dataset/train/100080.jpg  
  inflating: smaller_dataset/train/100098.jpg  
  inflating: smaller_dataset/train/103041.jpg  
  inflating: smaller_dataset/train/104022.jpg  
  inflating: smaller_dataset/train/105019.jpg  
  inflating: smaller_dataset/train/105053.jpg  
  inflating: smaller_dataset/train/106020.jpg  
  inflating: smaller_dataset/train/106025.jpg  
  inflating: smaller_dataset/train/108041.jpg  
  inflating: smaller_dataset/train/108073.jpg  
  inflating: smaller_dataset/train/109034.jpg  
  inflating: smaller_dataset/train/112082.jpg  
  inflating: smaller_dataset/train/113009.jpg  
  inflating: smaller_dataset/train/113016.jpg  
  inflating: smaller_dataset/train/113044.jpg  
  inflating: smaller_dataset/train/117054.jpg  
  inflating: smaller_dataset/train/118020.jpg  
  inflating: smaller_dataset/tr

In [10]:
# Root folder of the extracted data set; Custom_Dataset reads its 'train' and
# 'val' sub-folders.
path = '/content/smaller_dataset'


In [29]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not crash on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = ResidualDenoiser()
# ``parameters()`` returns a one-shot generator; it is handed to the optimizer
# when the Experiment is created.
params = net.parameters()
net.to(device)

ResidualDenoiser(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [20]:
# Load both folders, add noise and pad; each *_batch is a list of
# (noisy image, noise) tensor pairs.
custom_data = Custom_Dataset(path)
train_data, val_data = custom_data.train_batch, custom_data.val_batch 

In [25]:
# Despite their names, train_set/val_set are DataLoaders serving one
# (noisy image, noise) pair per batch in shuffled order.
train_set = td.DataLoader(train_data, batch_size=1, shuffle=True)
val_set = td.DataLoader(val_data, batch_size=1, shuffle=True)

In [30]:
# Plain SGD (lr=0.003) over the parameters of ``net``. No output_dir is given,
# so the Experiment creates an 'experiment_<timestamp>' directory.
exp = Experiment(net, train_set=train_set, val_set=val_set, optimizer= torch.optim.SGD(params=params, lr=0.003))

In [31]:
# Drop the notebook-level names first and only then run the collector, so it
# can reclaim whatever the deletions made unreachable.
# NOTE: ``exp`` still references the two loaders (and through them the data),
# so this frees the names, not the tensors used for training.
del train_data
del val_data
del train_set
del val_set
gc.collect()

In [None]:
# Train until 25 epochs have been performed; a checkpoint is saved after
# every epoch.
exp.run(num_epochs=25)

Start/Continue training from epoch 0
Epoch 0 (Time: 486.08s)
Epoch 1 (Time: 484.70s)
Epoch 2 (Time: 486.42s)
Epoch 3 (Time: 486.54s)
Epoch 4 (Time: 486.52s)
Epoch 5 (Time: 486.43s)


In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic)
%load_ext tensorboard

In [None]:
# Open TensorBoard on the log directory; Experiment.run writes its events to
# 'logs/Tensorboard_exp' relative to the working directory.
%tensorboard --logdir '/content/logs'

In [None]:
# Rebuild the data set because the earlier train_data/val_data/val_set names
# were deleted. Loading draws fresh random noise, so these samples differ from
# the ones used during training.
custom_data = Custom_Dataset(path)
train_data, val_data = custom_data.train_batch, custom_data.val_batch 

val_set = td.DataLoader(val_data, batch_size=1, shuffle=True)

In [None]:
# Take one batch from the validation loader: ex_im[0] is the noisy image and
# ex_im[1] the noise that was added to it.
it=iter(val_set)
ex_im = next(it)
# Remove the size-1 batch dimension so the tensors can be shown as images.
squeezed_im = np.squeeze(ex_im[0])
residual_im = np.squeeze(ex_im[1])

In [None]:
# Show the noisy input image taken from the validation set.
plt.imshow(transforms.ToPILImage()(squeezed_im))

In [None]:
# Show the noise that was added to that image.
plt.imshow(transforms.ToPILImage()(residual_im))

In [None]:
# Inference: switch to evaluation mode so BatchNorm uses its running
# statistics (and does not update them), and disable gradient tracking.
exp.net.eval()
with torch.no_grad():
    # Follow the device the network is actually on instead of assuming CUDA.
    ex_image = ex_im[0].to(exp.net.device)
    res_image = exp.net(ex_image)

In [None]:
# Subtract the predicted noise from the noisy image and display the result.
remove_noise(ex_image, res_image)
