In [None]:
from os import listdir
from os.path import join
from random import choice
from random import uniform

from PIL import Image

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from torch.nn.init import orthogonal_
from torch.nn.init import constant_

from torchvision.transforms import RandomCrop
from torchvision.transforms import ToPILImage
from torchvision.transforms import ToTensor
from torchvision.transforms import PILToTensor

from perlin import rand_perlin_2d
from perlin import rand_perlin_2d_octaves

# Tensor <-> PIL image converters (torchvision transforms).
to_tensor = ToTensor()
pil_to_tensor = PILToTensor()
to_pil_image = ToPILImage()

# Last expression of the cell, so the notebook displays the PyTorch version.
torch.__version__

In [None]:
class DnCNN(nn.Module):
    """DnCNN-style convolutional stack.

    ``self.features`` is an ``nn.Sequential`` laid out as::

        Conv -> ReLU
        (Conv -> BatchNorm2d -> ReLU) repeated (depth - 2) times
        Conv

    Every convolution uses a square ``kernel_size`` kernel, stride 1,
    ``padding`` on each spatial side, and a bias. The layer order determines
    the ``features.<idx>.*`` state-dict keys, so it must stay fixed for saved
    checkpoints to keep loading.

    Args:
        depth: Number of convolution layers (for depth >= 2).
        n_channels: Channel width of the hidden feature maps.
        image_channels: Channels of both the input and the output.
        kernel_size: Side length of every convolution kernel.
        padding: Zero padding added on each spatial side of every conv.
    """

    def __init__(
        self,
        depth=17,
        n_channels=64,
        image_channels=1,
        kernel_size=3,
        padding=1,
    ):
        super().__init__()

        def make_conv(in_channels, out_channels):
            # All convolutions in the stack share the same geometry.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(kernel_size, kernel_size),
                stride=(1, 1),
                padding=(padding, padding),
                bias=True,
            )

        stack = [make_conv(image_channels, n_channels), nn.ReLU(inplace=True)]
        for _ in range(depth - 2):
            stack.extend([
                make_conv(n_channels, n_channels),
                nn.BatchNorm2d(
                    n_channels,
                    eps=1e-05,
                    momentum=0.1,
                    affine=True,
                    track_running_stats=True,
                ),
                nn.ReLU(inplace=True),
            ])
        stack.append(make_conv(n_channels, image_channels))

        self.features = nn.Sequential(*stack)
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the raw output."""
        return self.features(x)

    def _initialize_weights(self):
        """Orthogonal conv weights with zero biases; BatchNorm scale 1, shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                constant_(module.weight, 1)
                constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                orthogonal_(module.weight)
                if module.bias is not None:
                    constant_(module.bias, 0)


def get_noise(image, level=None, res=None):
    """Build a binary Perlin-noise mask for ``image``.

    The mask is 0.0 (black) where the Perlin field is below ``level`` and
    1.0 (white) elsewhere. Wherever ``image`` itself is 0.0, the mask is
    forced to 1.0.

    Args:
        image: 3-D tensor unpacked as ``(_, height, width)`` (e.g. 1xHxW).
        level: Threshold applied to the Perlin field. Drawn from
            ``uniform(-1, 1)`` when None.
        res: ``(height_res, width_res)`` passed to ``rand_perlin_2d``. When
            None, each entry is a randomly chosen divisor of the matching
            image dimension.

    Returns:
        Tensor with ``image``'s dtype and device containing only 0.0 / 1.0.
        Its shape is the broadcast of ``image`` with the Perlin field
        (presumably 1xHxW for a 1xHxW image — depends on rand_perlin_2d).
    """
    _, height, width = image.shape
    shape = (height, width)

    if level is None:
        level = uniform(-1, 1)

    if res is None:
        # Only divisors are offered, presumably because rand_perlin_2d needs
        # shape to be a multiple of res — NOTE(review): confirm in perlin.py.
        hght_res = choice([i for i in range(1, height + 1) if height % i == 0])
        wdth_res = choice([i for i in range(1, width + 1) if width % i == 0])
        res = (hght_res, wdth_res)

    black = torch.tensor(0.0, dtype=image.dtype, device=image.device)
    white = torch.tensor(1.0, dtype=image.dtype, device=image.device)

    perlin = rand_perlin_2d(shape, res).to(dtype=image.dtype, device=image.device)
    noise = torch.where(perlin < level, black, white)
    # Force the mask to white wherever the image is already black.
    noise = torch.where(image == 0.0, white, noise)

    return noise


def get_observation(image, noise):
    """Apply a binary noise mask to ``image``.

    Computes ``image - (1 - noise)``: where ``noise`` is 1 the pixel is kept,
    where it is 0 one unit is subtracted. Neither argument is modified.

    Args:
        image: Clean image (tensor or number).
        noise: Mask broadcast-compatible with ``image``.

    Returns:
        The corrupted observation.
    """
    inverted_mask = 1 - noise
    return image - inverted_mask

In [None]:
# Get DnCNN.
mdl_path = 'models/model-513.pth'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# depth must match the checkpoint's architecture, or the (strict) load below fails.
model = DnCNN(depth=30).eval().to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint)
# onto `device`, so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load(mdl_path, map_location=device))

In [None]:
# Get a clean image, noisy observation, and pepper noise.
img_dir = 'datasets/test'
img_nm = 'TIP BRISQUE-01.png'
nis_lvl = -0.5   # threshold passed to get_noise as `level`
hght_res = 440   # Perlin resolution along height, passed as res[0]
wdth_res = 425   # Perlin resolution along width, passed as res[1]

img_path = join(img_dir, img_nm)
# Use a context manager so the file handle is closed once the
# binarised ('1' mode) copy has been made.
with Image.open(img_path) as pil_img:
    image = to_tensor(pil_img.convert('1'))
noise = get_noise(image, level=nis_lvl, res=(hght_res, wdth_res))
observation = get_observation(image, noise)

to_pil_image(observation)

In [None]:
# Get the denoised image and residual image.
# The model lives on `device`, but `observation` was built on image's device
# (CPU here), so move the input over for inference and bring the prediction
# back before mixing it with `observation`.
with torch.no_grad():
    logits = model(observation.unsqueeze(0).to(device))
    pred = torch.sigmoid(logits).round().squeeze(0).to(observation.device)

white = torch.tensor(1.0, dtype=image.dtype, device=image.device)
residual = observation + (1 - pred)
# Black observation pixels become 1 - pred (black where pred rounds to 1,
# white where it rounds to 0); all other pixels become white.
denoised = torch.where(observation == 0.0, residual, white)

to_pil_image(denoised)