In [None]:
from os import listdir
from os import makedirs
from os.path import join
from random import choice
from random import uniform

import torch
from PIL import Image
from torch import nn
from torch.nn.init import constant_
from torch.nn.init import orthogonal_
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import PILToTensor
from torchvision.transforms import RandomCrop
from torchvision.transforms import ToPILImage
from torchvision.transforms import ToTensor

from perlin import rand_perlin_2d
from perlin import rand_perlin_2d_octaves

# Shared, reusable torchvision transform instances. `to_tensor` is the one
# applied to cropped patches in DenoisingDataset.__getitem__.
to_pil_image = ToPILImage()
to_tensor = ToTensor()
pil_to_tensor = PILToTensor()

# Last expression of the cell, so the notebook displays the installed torch version.
torch.__version__

In [None]:
class DenoisingDataset(Dataset):
    """Dataset yielding (observation, noise) pairs built from random image patches.

    Every file name returned by ``listdir(img_dir)`` is one item. On each
    access the image is opened, converted to PIL mode ``'1'``, randomly
    cropped to ``ptch_sz`` x ``ptch_sz`` and converted to a tensor; a fresh
    noise map is then generated for it, so repeated accesses of the same
    index return different samples.

    Args:
        img_dir: Directory whose entries are the source images.
        ptch_sz: Side length of the square random crop.
        transform: Optional callable applied to the observation.
        target_transform: Optional callable applied to the noise target.
    """

    def __init__(self, img_dir, ptch_sz=40, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.cropper = RandomCrop(size=(ptch_sz, ptch_sz))
        self.img_dir = img_dir
        self.img_nms = listdir(img_dir)

    def __len__(self):
        """Return the number of directory entries found at construction time."""
        return len(self.img_nms)

    def __getitem__(self, idx):
        """Return ``(observation, noise)`` for the image at position ``idx``."""
        file_name = self.img_nms[idx]
        patch = Image.open(join(self.img_dir, file_name)).convert('1')
        clean = to_tensor(self.cropper(patch))

        # The observation is always derived from the untransformed noise map;
        # the optional transforms are applied afterwards, observation first.
        target = get_noise(clean)
        sample = get_observation(clean, target)
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target


class DnCNN(nn.Module):
    """Plain convolutional stack: Conv-ReLU, (depth - 2) x Conv-BN-ReLU, Conv.

    All convolutions use stride 1 and the same square kernel and padding, so
    with the defaults (kernel 3, padding 1) the output has the same spatial
    size as the input. The final convolution maps back to ``image_channels``.

    Args:
        depth: Total number of convolution layers.
        n_channels: Width of the hidden feature maps.
        image_channels: Channels of the input and of the output.
        kernel_size: Side length of every (square) convolution kernel.
        padding: Zero padding applied on each side by every convolution.
    """

    def __init__(
        self,
        depth=17,
        n_channels=64,
        image_channels=1,
        kernel_size=3,
        padding=1,
    ):
        super().__init__()

        def conv(in_channels, out_channels):
            # Every convolution in the network shares these hyper-parameters.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(kernel_size, kernel_size),
                stride=(1, 1),
                padding=(padding, padding),
                bias=True,
            )

        # Head: lift the image into feature space.
        stack = [conv(image_channels, n_channels), nn.ReLU(inplace=True)]

        # Body: repeated Conv -> BatchNorm -> ReLU triples.
        for _ in range(depth - 2):
            stack += [
                conv(n_channels, n_channels),
                nn.BatchNorm2d(
                    n_channels,
                    eps=1e-05,
                    momentum=0.1,
                    affine=True,
                    track_running_stats=True,
                ),
                nn.ReLU(inplace=True),
            ]

        # Tail: project back to the image channel count (raw outputs, no activation).
        stack.append(conv(n_channels, image_channels))

        self.features = nn.Sequential(*stack)
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the sequential layer stack."""
        return self.features(x)

    def _initialize_weights(self):
        """Orthogonal conv weights with zero bias; BatchNorm scale 1, shift 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                orthogonal_(module.weight)
                if module.bias is not None:
                    constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                constant_(module.weight, 1)
                constant_(module.bias, 0)


class PepperWithLogitsLoss(nn.Module):
    """Binary cross-entropy with logits, restricted to the dark pixels of the input.

    The element-wise ``BCEWithLogitsLoss`` between ``pred`` and ``y`` is
    weighted by ``1 - X``, so positions where the observation ``X`` equals 1
    contribute nothing to the loss.

    Args:
        weight: Forwarded to ``nn.BCEWithLogitsLoss``.
        reduction: ``'none'`` returns the masked element-wise loss, ``'sum'``
            its sum, and ``'mean'`` the sum divided by the sum of the mask.
        pos_weight: Forwarded to ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, weight=None, reduction='mean', pos_weight=None):
        super().__init__()
        self.reduction = reduction
        # The inner loss must stay element-wise so it can be masked before reducing.
        self.loss = nn.BCEWithLogitsLoss(
            weight=weight, reduction='none', pos_weight=pos_weight
        )

    def forward(self, pred, y, X):
        """Return the masked loss of logits ``pred`` against targets ``y``.

        Args:
            pred: Raw logits predicted by the model.
            y: Target tensor, same shape as ``pred``.
            X: Observation fed to the model; ``1 - X`` is used as the mask.

        Raises:
            ValueError: If ``self.reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
        """
        mask = 1 - X
        output = mask * self.loss(pred, y)

        if self.reduction == 'none':
            return output
        if self.reduction == 'sum':
            return output.sum()
        if self.reduction == 'mean':
            total = mask.sum()
            # A batch with an all-zero mask would otherwise evaluate 0 / 0 and
            # yield a NaN loss with NaN gradients. Only an exactly-zero
            # denominator is replaced, so every other input is unaffected.
            safe_total = torch.where(total == 0, torch.ones_like(total), total)
            return output.sum() / safe_total

        raise ValueError(f'{self.reduction} is not a valid value for reduction')


def learn(
    training_data,
    test_data,
    device,
    model,
    loss_fn,
    optimizer,
    batch_size=64,
    shuffle=False,
    epoch=0,
    epochs=5,
    mdl_dir=None,
    verbose=False,
):
    """Run the train/evaluate loop and optionally checkpoint after every epoch.

    Args:
        training_data: Dataset used to build the training ``DataLoader``.
        test_data: Dataset used to build the evaluation ``DataLoader``.
        device: Device label; only used here for the verbose banner.
        model: Network passed to ``train`` and ``test``.
        loss_fn: Loss callable passed to ``train`` and ``test``.
        optimizer: Optimizer passed to ``train``.
        batch_size: Batch size for both loaders.
        shuffle: Whether both loaders shuffle their data.
        epoch: Zero-based index of the first epoch to run (for resuming).
        epochs: Exclusive upper bound of the epoch index.
        mdl_dir: If set, ``model-<n>.pth`` is written there after epoch ``n``.
        verbose: Print shapes, the model, progress and checkpoint paths.
    """
    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)

    if mdl_dir:
        # Create the checkpoint directory up front: otherwise a missing
        # directory is only discovered by torch.save after a full epoch.
        makedirs(mdl_dir, exist_ok=True)

    if verbose:
        # Peek at a single batch to report tensor shapes.
        for X, y in test_dataloader:
            print(f"Shape of X [N, C, H, W]: {X.shape}")
            print(f"Shape of y: {y.shape} {y.dtype}")
            break
        print(f"Using {device} device")
        print(model)

    for t in range(epoch, epochs):
        if verbose:
            print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, verbose)
        test(test_dataloader, model, loss_fn, verbose)
        if mdl_dir:
            # Checkpoints are numbered from 1, matching the printed epoch number.
            mdl_path = join(mdl_dir, f"model-{t+1}.pth")
            torch.save(model.state_dict(), mdl_path)
            if verbose:
                print(f"Saved PyTorch Model State to {mdl_path}")
    if verbose:
        print("Done!")


def train(dataloader, model, loss_fn, optimizer, verbose=False, device=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a ``.dataset``.
        model: Network to optimise; switched to training mode.
        loss_fn: Callable invoked as ``loss_fn(pred, y, X)``.
        optimizer: Optimizer stepping the model parameters.
        verbose: Print the loss every 10th batch.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter, so the function does not
            depend on a module-level ``device`` variable.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        # A parameter-free model gives no device hint; batches stay where they are.
        device = first_param.device if first_param is not None else None

    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        if device is not None:
            X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y, X)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0 and verbose:
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn, verbose=True):
    if verbose:
        num_batches = len(dataloader)
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y, X).item()
        test_loss /= num_batches
        print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")


def get_noise(image, level=None, res=None):
    """Build a binary noise map for ``image`` by thresholding 2-D Perlin noise.

    The returned tensor holds 0.0 where the Perlin field is below ``level``
    and 1.0 elsewhere, except that every position where ``image`` is exactly
    0.0 is forced to 1.0.

    Args:
        image: Tensor of shape ``(channels, height, width)``.
        level: Threshold applied to the Perlin field. Drawn uniformly from
            ``[-1, 1]`` when ``None``.
        res: ``(height_res, width_res)`` pair passed to ``rand_perlin_2d``.
            When ``None``, each entry is a randomly chosen divisor of the
            corresponding image dimension (presumably because
            ``rand_perlin_2d`` needs the shape to be divisible by ``res`` —
            TODO confirm against the ``perlin`` module).

    Returns:
        Tensor with the dtype and device of ``image``, broadcast to its shape.
    """
    _, height, width = image.shape
    shape = (height, width)

    if level is None:
        level = uniform(-1, 1)

    if res is None:
        hght_res = choice([i for i in range(1, height + 1) if height % i == 0])
        wdth_res = choice([i for i in range(1, width + 1) if width % i == 0])
        res = (hght_res, wdth_res)

    black = torch.tensor(0.0, dtype=image.dtype).to(image.device)
    white = torch.tensor(1.0, dtype=image.dtype).to(image.device)

    noise = rand_perlin_2d(shape, res).to(image.dtype).to(image.device)
    noise = torch.where(noise < level, black, white)
    # Pixels that are already 0.0 in the image never carry noise.
    noise = torch.where(image == 0.0, white, noise)

    return noise


def get_observation(image, noise):
    """Return ``image`` with the inverted noise map subtracted from it.

    Positions where ``noise`` is 1 leave ``image`` unchanged; positions where
    ``noise`` is 0 have 1 subtracted.
    """
    return image - (1 - noise)

In [None]:
# Pick the accelerator first so the model can be placed on it right away.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Both splits draw random square patches of the same size.
patch_size = 70
training_data = DenoisingDataset('datasets/training', ptch_sz=patch_size)
test_data = DenoisingDataset('datasets/test', ptch_sz=patch_size)

model = DnCNN(depth=30)
model = model.to(device)
loss_fn = PepperWithLogitsLoss()
optimizer = torch.optim.RMSprop(model.parameters())

# Options for the run: start from epoch 0, checkpoint every epoch into 'models'.
run_options = dict(
    batch_size=128,
    shuffle=True,
    epoch=0,
    epochs=2000,
    mdl_dir='models',
    verbose=True,
)
learn(training_data, test_data, device, model, loss_fn, optimizer, **run_options)