# PyTorch Tiny ImageNet GCNN Denoising

Resources:
- [Tiny Imagenet](https://tiny-imagenet.herokuapp.com/)
- [GCNN Paper](https://arxiv.org/abs/1905.12281)
- [Building KNN index](https://github.com/rusty1s/pytorch_cluster)

Notes:
- [unfold](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.unfold) / [Fold](https://pytorch.org/docs/stable/nn.html#torch.nn.Fold) / [how-to-use](https://stackoverflow.com/questions/53972159/how-does-pytorchs-fold-and-unfold-work)
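As a quick illustration of the `unfold` link above, here is a minimal sketch that cuts 64x64 images into non-overlapping 32x32 patches with `torch.Tensor.unfold`. The tensor sizes are illustrative assumptions, and this is not the patch-splitting code used later in the notebook (which works on numpy arrays).

```python
import torch

# A fake batch of two 3-channel 64x64 images (illustrative sizes).
images = torch.arange(2 * 3 * 64 * 64, dtype=torch.float32).reshape(2, 3, 64, 64)

# unfold(dim, size, step) slides a window of `size` along `dim` with stride `step`
# and appends the window contents as a new trailing dimension.
patches = images.unfold(2, 32, 32).unfold(3, 32, 32)
print(patches.shape)  # (batch, channels, patch_rows, patch_cols, 32, 32)

# patches[:, :, i, j] is the 32x32 patch at patch-grid position (i, j).
top_right = patches[:, :, 0, 1]
```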

## Experiment Results

- [Experiment logs on comet.ml](https://www.comet.ml/topilskiyak/fastrino)

### TLDR

Trained a model with the architecture from the [gcnn paper](https://arxiv.org/abs/1905.12281), with 24 hidden features, 8 nearest neighbours, using a batch size of 48, for 10 epochs (**default**).

Compared the effectiveness of the graph convolution layers to a model which uses a conv5x5 layer instead of a non-local aggregation with knn (**conv5x5**).

Compared to *conv5x5*, the *default* model used far more memory (\~x8, 9.2GB vs 1.1GB), was much slower (\~x15) and had a worse validation MSE (\~x3).

We fixed this by approximating the non-local aggregation with a dense layer (**approx_aggr**).  
The resulting model used far less memory (\~x6), was about twice as fast (\~x2) and had slightly better quality than *default*. With more tuning it could potentially surpass *conv5x5* in quality.

| memory | speed | mse_train | mse_val | link |
| ------ | ----- | --------- | ------- | ---- |
| 9.2GB  | 1.4it/s | 0.029 | 0.034 | [default](https://www.comet.ml/topilskiyak/fastrino/6f32ba37cf94416886e6088dbd3300bc) |
| 1.5GB  | 3.1it/s | - | - | approx_aggr |
| 1.1GB  | 21.0it/s | 0.015 | 0.012 | [conv5x5](https://www.comet.ml/topilskiyak/fastrino/9ea5cbfc90384e2c8bb750ca5db53004) |

The main hurdle is the large slowdown due to building the knn graph.  

A potential solution is to switch from the discrete knn graph to an attention map over all pixels. This avoids the computational cost of building a knn graph while still allowing the model to learn the neighbourhood structure through the attention weights. However, this approach strays from the main focus of this work, which is exploring GCNNs.
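The attention idea can be sketched roughly as follows. This is a hypothetical layer, not something that was trained in these experiments: the class name `PixelAttentionAggregation`, the `attn_features` parameter and the scaled dot-product form are all assumptions made for illustration.

```python
import torch
from torch import nn


class PixelAttentionAggregation(nn.Module):
    """Hypothetical stand-in for knn aggregation: every pixel attends to all pixels."""

    def __init__(self, in_features, out_features, attn_features=16):
        super().__init__()
        self.query = nn.Linear(in_features, attn_features)
        self.key = nn.Linear(in_features, attn_features)
        self.value = nn.Linear(in_features, out_features)
        self.scale = attn_features ** -0.5

    def forward(self, x):
        # x: (B, C, H, W) -> flatten the pixel grid into HW feature vectors.
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)                           # (B, HW, C)
        scores = self.query(tokens) @ self.key(tokens).transpose(1, 2)  # (B, HW, HW)
        weights = torch.softmax(scores * self.scale, dim=-1)            # soft "neighbours"
        out = weights @ self.value(tokens)                              # (B, HW, C_out)
        return out.transpose(1, 2).reshape(b, -1, h, w)                 # (B, C_out, H, W)


layer = PixelAttentionAggregation(in_features=24, out_features=24)
y = layer(torch.randn(2, 24, 8, 8))
print(y.shape)
```

Note that the attention map itself has B x HW x HW entries, so for 32x32 patches it holds about a million weights per image; whether that is cheaper in practice than the knn graph would need to be measured.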


### GCNN Quality: NonLocalAggregationByKNN vs Conv5x5

The **default GCNN** setup is:
- the architecture from the [gcnn paper](https://arxiv.org/abs/1905.12281)
- 24 hidden features
- 8 nearest neighbours
- 48 batch size

Training for 10 epochs on 1/10 of the dataset.  
MSE loss on noise (units are $10^{-3}$):  

| std | train | test | val | link |
| --- | ----- | ---- | --- | ---- |
| 0.1 |    14 |   17 |  17 | [std=0.1](https://www.comet.ml/topilskiyak/fastrino/ccde75780048412fb155070ddec8bd7a) |
| 0.2 |    29 |   35 |  34 | [std=0.2](https://www.comet.ml/topilskiyak/fastrino/6f32ba37cf94416886e6088dbd3300bc) |
| 1.0 |    89 |  108 | 105 | [std=1.0](https://www.comet.ml/topilskiyak/fastrino/ccde75780048412fb155070ddec8bd7a) |

Note that the test loss being huge at the start is normal - at that point the model hasn't seen that many training batches yet.

Comparison to the case where non-local aggregation is swapped for a simple conv5x5 layer (the **conv5x5** setup).  
This is about 15 times faster (21.0 it/s vs 1.4 it/s) and yields better results, but its best achievable quality given unlimited training time would *probably* be lower.

| std | train | test | val | link |
| --- | ----- | ---- | --- | ---- |
| 0.1 |   4.4 |  5.5 | 5.7 | [std=0.1](https://www.comet.ml/topilskiyak/fastrino/851dd85f73f549829ebce24aeeb95ed6) |
| 0.2 |    15 |   12 |  12 | [std=0.2](https://www.comet.ml/topilskiyak/fastrino/9ea5cbfc90384e2c8bb750ca5db53004) |
| 1.0 |    68 |   87 |  88 | [std=1.0](https://www.comet.ml/topilskiyak/fastrino/f667de2b49fc44ec85e902da163a66da) |

### GCNN NonLocalAggregation: GPU Memory Limitations

GCNN hits GPU memory limitations pretty hard.

The main culprit is probably the intermediate graph convolution aggregation tensor of size BHWKCC:
- B - batch
- HW - height-width
- K - number of nearest neighbours taken
- CC - matrix converting hidden features to new hidden features
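To make the shape of that tensor concrete, here is a small sketch of the aggregation step with toy sizes. The einsum pattern is the one that appears (commented out) in the `NonLocalAggregation` cell further down; the random tensors and the tiny dimensions are assumptions for illustration only.

```python
import torch

b, h, w, k, c_in, c_out = 2, 4, 4, 8, 6, 6

neighbours = torch.randn(b, h, w, k, c_in)
# One C_out x C_in matrix per pixel and per neighbour: this is the B x HW x K x C x C tensor.
weights = torch.randn(b, h, w, k, c_out, c_in)

# Apply each matrix to its neighbour's feature vector, then average over the K neighbours.
aggregated = torch.einsum("bhwkqc,bhwkc->bhwq", weights, neighbours) / k
print(weights.numel(), aggregated.shape)
```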

In the original paper: H,W = 32, K = 8, C = 66, so HWKCC $\sim 4 \cdot 10^7$.  
With a batch size of just 16 this results in BHWKCC of $\sim 0.6 \cdot 10^9$.  
Now consider that each value is a float32 (4 bytes, the PyTorch default), which gives roughly 2.3GB per tensor, and that tensors of this kind are used several times in the model ($\sim$ 10 times). This quickly exceeds the memory limit of the GPUs used (10-12GB).

So we have to balance the size of the hidden features and the batch size.
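The trade-off can be checked with a few lines of arithmetic. The helper below is a hypothetical name introduced for this sketch, and `bytes_per_value` is an assumption to adjust (4 for float32, 8 for float64).

```python
def aggregation_tensor_bytes(batch, height, width, k, channels, bytes_per_value=4):
    """Size in bytes of one dense B x H x W x K x C x C aggregation-weight tensor."""
    num_values = batch * height * width * k * channels * channels
    return num_values * bytes_per_value


# Paper-sized hidden features with a small batch vs. the default setup used here.
for batch, hidden in [(16, 66), (48, 24)]:
    size = aggregation_tensor_bytes(batch, 32, 32, 8, hidden)
    print(f"batch={batch}, hidden={hidden}: {size / 2**30:.2f} GiB per tensor")
```

Because the size grows with the square of the hidden features but only linearly with the batch size, halving the hidden size frees room for roughly four times the batch.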

Another problem is that each batch itself trains very slowly, so going through the whole training/testing/validation datasets takes a very long time.


Results on a Tesla K80:  

| hidden | max batch | opt batch | train (1b) | eval (1b) |
| ------ | --------- | --------- | ---------- | --------- |
|   66   |         8 |         - |       2.0s |      1.5s |
|   48   |        16 |         - |       1.8s |      1.2s |
|   24   |        56 |        48 |       2.1s |      1.2s |
|   12   |       160 |         - |       3.9s |      2.2s |

### GCNN NonLocalAggregation: Approximation using a dense layer

To overcome GPU memory limitations and for a potential speed-up, we tried to emulate the graph convolution aggregation (to get rid of the BHWKCC tensor).

So, we tried using a 1-layer dense NN for the following conversion:
$$
(F_c - F_{n_1}, ..., F_c - F_{n_k}) \rightarrow F'_c
$$
where $F_c$ is the current feature vector, $F_{n_i}$ are the feature vectors of all the neighbours and $F'_c$ is the resulting feature vector.
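A minimal sketch of this mapping, assuming the neighbour features have already been gathered into a `(B, H, W, K, C)` tensor. The class name `DenseNeighbourAggregation` is hypothetical, and the layer actually used in the notebook's `NonLocalAggregation` cell may differ in its details.

```python
import torch
from torch import nn


class DenseNeighbourAggregation(nn.Module):
    """One dense layer mapping the K concatenated deltas to a new feature vector."""

    def __init__(self, in_features, out_features, k):
        super().__init__()
        self.dense = nn.Linear(k * in_features, out_features)

    def forward(self, features, neighbours):
        # features:   (B, H, W, C)     current feature vectors F_c
        # neighbours: (B, H, W, K, C)  neighbour feature vectors F_{n_i}
        deltas = features.unsqueeze(-2) - neighbours       # F_c - F_{n_i}
        b, h, w, k, c = deltas.shape
        return self.dense(deltas.reshape(b, h, w, k * c))  # (B, H, W, C_out)


layer = DenseNeighbourAggregation(in_features=24, out_features=24, k=8)
out = layer(torch.randn(2, 32, 32, 24), torch.randn(2, 32, 32, 8, 24))
print(out.shape)
```

The layer stores a single `(C_out, K * C)` weight matrix shared by all pixels, so no per-pixel B x HW x K x C x C tensor is ever materialised.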

We compared memory consumption and speed for different hidden (h) and batch sizes (b).
    
| memory | speed | epoch time | run | GPU |
| ------ | ----- | ---------- | --- | --- |
| 9.2GB | 1.4it/s | - | h=24,b=48 (default) | - |
| 1.5GB | 3.1it/s | 300s | [h=24,b=48](https://www.comet.ml/topilskiyak/fastrino/67475c227e0e4e2c8cb2a41817cce841) | Tesla T4 |
| 2.6GB | 1.3it/s | 260s | [h=24,b=128](https://www.comet.ml/topilskiyak/fastrino/4477bc26a91a4e19a520dbc019f6ec8a) | Tesla T4 |
| 4.4GB | 0.7it/s | 220s | [h=24,b=256](https://www.comet.ml/topilskiyak/fastrino/e5f0b174a1e74f9693374a1cbef262d1?experiment-tab=systemMetrics) | Tesla T4 |
| 1.4GB | 3.2it/s | 290s | [h=24,b=48](https://www.comet.ml/topilskiyak/fastrino/3ec1b6f40e43437c97511c14b5a52444) | Tesla P100 |
| 2.6GB | 1.0it/s | 900s | [h=66,b=48](https://www.comet.ml/topilskiyak/fastrino/41334b49a5d64222835066de46f38208) | Tesla P100 |
| 5.2GB | 0.38it/s | 910s | [h=66,b=128](https://www.comet.ml/topilskiyak/fastrino/41334b49a5d64222835066de46f38208) | Tesla P100 |

Compared to the default GCNN, the dense approximation uses much less memory and is ~2 times faster.  
This allows for larger hidden sizes (for potentially better quality) and larger batch sizes (for potentially faster and more stable training).  

Although not observed here, this method potentially trades peak quality for efficiency. That said, reaching the peak quality of the default setup might be infeasible anyway due to memory and time constraints.

Ideas to try out:
- use a 2-layer network
- use both the deltas ($F_c - F_{n_i}$) and the neighbours themselves ($F_{n_i}$)


### GCNN NonLocalAggregation: Speed and Memory impact

Using the default setup (h=24, n=8, b=48) as the baseline, we looked into how much the aggregation step impacts speed and memory consumption.

| memory | speed | name |
| ------ | ----- | ---- |
| 9.2GB  | 1.4it/s | default |
| 9.2GB  | 1.4it/s | no bias in aggr_weights |
| 8.6GB  | 1.5it/s | single-layer aggr_weights |
| 8.7GB  | 1.3it/s | single-layer aggr_weights + use einsum |
| 1.5GB  | 3.1it/s | approximation using a single-layer NN |
| 2.0GB  | 2.4it/s | conv5x5, but still count aggr_weights (and knn) |
| 1.1GB  | 21.0it/s | conv5x5 (no aggr_weights/knn) |
| 0.9GB  | 35.0it/s | no aggregation or conv5x5 |
| 1.0GB  | 36.0it/s | dncnn |

GPUs used: Tesla P100-PCIE/K80/T4

The table shows that:
- small changes in how aggregation weights are calculated do not impact speed/memory much  
- applying the aggregation weights has the largest memory footprint (~7GB = 9.2GB - 2.0GB)
- approximation of aggregation using a single-layer NN has the best improvement in both speed and memory

Ideas to try out:
- ECCConv for faster tensor multiplication

### GCNN KNN Graph Construction: Speed impact

We'll continue exploring how to make the model more efficient by looking into how long it takes to build the knn graph.

It turns out that the memory impact of KNN Graph construction isn't as large as that of the NonLocalAggregation, so we'll only explore speed in this section.

The following table shows that constructing the knn graph causes a 7-fold slowdown (compared to conv5x5):  

| speed (it/s) | name |
| ------------ | ---- |
| 01.37        | default gcnn |
| 03.00        | gcnn conv5x5 with knn-graph |
| 21.00        | gcnn conv5x5 w/out knn-graph |  

The other (2-fold) slowdown is caused by applying the weight aggregation.  
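The dense-layer approximation removes almost all of that cost (3.1 it/s, about the same as conv5x5 with the knn graph), so speed-wise the knn graph construction is the remaining bottleneck.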
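The two slowdown factors quoted above follow directly from the table; a quick check using the measured speeds (the dictionary keys are just labels for this sketch):

```python
# Speeds in it/s, copied from the table above.
speeds = {
    "default gcnn": 1.37,
    "conv5x5 with knn-graph": 3.00,
    "conv5x5 without knn-graph": 21.00,
}

knn_slowdown = speeds["conv5x5 without knn-graph"] / speeds["conv5x5 with knn-graph"]
aggregation_slowdown = speeds["conv5x5 with knn-graph"] / speeds["default gcnn"]
total_slowdown = knn_slowdown * aggregation_slowdown

print(f"knn graph:   x{knn_slowdown:.1f}")
print(f"aggregation: x{aggregation_slowdown:.1f}")
print(f"total:       x{total_slowdown:.1f}")
```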

Varying the number of nearest neighbours doesn't seem to impact speed that much:  

| speed (it/s) | name |
| ------------ | ---- |
| 03.00        | gcnn conv5x5 with knn-graph (nn=8) |
| 03.30        | gcnn conv5x5 with knn-graph (nn=4) |

Note: even though knn_graph uses scipy.spatial.cKDTree for CPU knn construction, on GPU it uses a dedicated GPU implementation, so tensors don't have to be moved from/to the GPU.

Ideas to try out:
- write your own knn using [scipy's pure-Python KDTree](https://github.com/scipy/scipy/blob/v1.4.1/scipy/spatial/kdtree.py#L185-L942)  
- write your own knn using [topk](https://discuss.pytorch.org/t/k-nearest-neighbor-in-pytorch/59695)?
- use [faiss](https://github.com/facebookresearch/faiss) (only supports float32)
- some more pytorch links:
  - [approx knn layer](https://discuss.pytorch.org/t/approximate-nearest-neighbors-layer/31466)
  - [fast pytorch knn](https://discuss.pytorch.org/t/fastest-way-to-find-nearest-neighbor-for-a-set-of-points/5938)
  - [pytorch loss function](https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/2)
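For the `topk` idea, a brute-force version could look like the sketch below. The function name `knn_topk` is hypothetical, it was not benchmarked against `knn_graph`, and it builds a full B x N x N distance matrix, so it is only plausible for small patches.

```python
import torch


def knn_topk(features, k):
    """Brute-force k nearest neighbours, computed independently per batch element.

    features: (B, N, C) tensor, e.g. the N = H * W pixel feature vectors of each image.
    Returns a (B, N, k) tensor of neighbour indices, excluding each point itself.
    """
    n = features.shape[1]
    distances = torch.cdist(features, features)                  # (B, N, N) Euclidean
    self_mask = torch.eye(n, dtype=torch.bool, device=features.device)
    distances = distances.masked_fill(self_mask, float("inf"))   # no self-loops
    return distances.topk(k, dim=-1, largest=False).indices


# Four points on a line: 0, 1, 3, 7.
points = torch.tensor([[[0.0], [1.0], [3.0], [7.0]]])
nearest = knn_topk(points, k=1).squeeze(-1)
print(nearest)
```

Unlike `knn_graph`, the indices returned here are local to each image (0 to N-1) rather than global across the flattened batch.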







### General ideas to try out

- use gcnn only in some blocks?
- [checkpointing](https://pytorch.org/docs/stable/checkpoint.html)
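A minimal sketch of what checkpointing a block could look like. The block below is a toy stand-in rather than one of the notebook's GCNN blocks, and `use_reentrant=False` assumes a reasonably recent PyTorch release that accepts this keyword.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy block; BatchNorm is left out on purpose, since recomputing the forward pass
# would update its running statistics a second time.
block = nn.Sequential(
    nn.Conv2d(24, 24, kernel_size=3, padding=1),
    nn.LeakyReLU(1e-2),
    nn.Conv2d(24, 24, kernel_size=3, padding=1),
)

x = torch.randn(4, 24, 32, 32, requires_grad=True)

# Intermediate activations inside `block` are not kept for backward;
# they are recomputed when gradients are needed, trading compute for memory.
y = checkpoint(block, x, use_reentrant=False)
y.mean().backward()
```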



## Experiments

### System Setup

In [None]:
# Check that GPU has enough RAM (>10GB)
!nvidia-smi

In [0]:
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"

### Comet_ML

NOTE: comet_ml is not installed in google colab by default

In [0]:
!pip -q install comet_ml

In [0]:
comet_ml_settings = dict(
    api_key=None,
    project_name='fastrino',
    workspace=None,
)

assert comet_ml_settings["api_key"] is not None, "set your comet_ml api_key"
assert comet_ml_settings["workspace"] is not None, "set your comet_ml workspace"

In [0]:
from comet_ml import Experiment

### Import

In [0]:
# knn_graph
!pip -q install torch-cluster
# cv2
!pip -q install opencv-python
# debug
!pip -q install psutil

In [0]:
import os
from functools import partial
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data.dataset import Dataset

import torchvision
from torchvision import transforms

from torch_cluster import knn_graph

device = 'cuda'

In [0]:
%%writefile torchmemdebug.py

# https://discuss.pytorch.org/t/how-pytorch-releases-variable-garbage/7277/2

import gc
import psutil
import os
import sys

import torch

def mem_report():
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            print(type(obj), obj.size())
    
def cpu_stats():
    print(sys.version)
    print(psutil.cpu_percent())
    print(psutil.virtual_memory())  # physical memory usage
    pid = os.getpid()
    py = psutil.Process(pid)
    memoryUse = py.memory_info()[0] / 2. ** 30  # resident set size in bytes -> GiB
    print('memory GB:', memoryUse)

Overwriting torchmemdebug.py


In [0]:
import torchmemdebug

In [0]:
%%writefile pytorch_ssim.py

# https://github.com/Po-Hsun-Su/pytorch-ssim

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            
            self.window = window
            self.channel = channel


        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    
    return _ssim(img1, img2, window, window_size, channel, size_average)

Overwriting pytorch_ssim.py


In [0]:
import pytorch_ssim

### Data

Download TinyImagenet

In [0]:
if not Path('tiny-imagenet-200.zip').exists():
    !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    !unzip -q tiny-imagenet-200.zip

Full TinyImagenet

In [0]:
TINY_IMAGENET_DIR = Path('tiny-imagenet-200')
TINY_IMAGENET_PARTS = ['train', 'test', 'val']

In [0]:
def imgdir_to_array(imgdir, take_part=0.1):
    images = []
    for path in imgdir.iterdir():
        img = cv2.imread(str(path))
        images.append(img)
    stacked = np.stack(images)
    take_every = max(1, int(1 / take_part))
    taken = stacked[::take_every]
    return taken

In [0]:
def dump_cache(source_dict, cache_suffix):
    for part, array in source_dict.items():
        np.save(part + cache_suffix, array)
    print("Dumped to cache: " + cache_suffix)    


def load_cache(target_dict, cache_suffix):
    for part in TINY_IMAGENET_PARTS:
        target_dict[part] = np.load(part + cache_suffix + '.npy')

In [0]:
class TinyImagenet:
    CACHE_SUFFIX = '_full'

    def __init__(self, use_cache=True):
        self._load_arrays(use_cache)
    
    def _load_arrays(self, use_cache):
        self.arrays = {}
        if use_cache:
            load_cache(self.arrays, self.CACHE_SUFFIX)
            return
        
        self._load_train()
        self._load_test()
        self._load_val()

        dump_cache(self.arrays, self.CACHE_SUFFIX)
    
    def _load_train(self):
        train_dir = TINY_IMAGENET_DIR / 'train'

        image_arrays = []
        for imgdir in train_dir.iterdir():
            image_arrays.append(imgdir_to_array(imgdir / 'images'))
        self.arrays['train'] = np.concatenate(image_arrays, axis=0)

    def _load_test(self):
        self.arrays['test'] = imgdir_to_array(TINY_IMAGENET_DIR / 'test' / 'images')
    
    def _load_val(self):
        self.arrays['val'] = imgdir_to_array(TINY_IMAGENET_DIR / 'val' / 'images')

In [0]:
def test_ti(use_cache):
    ti = TinyImagenet(use_cache)

    for name, arr in ti.arrays.items():
        print(name + '\t', arr.shape)
    
    for img in [arr[100] for arr in ti.arrays.values()]:
        plt.imshow(img)
        plt.show()

In [0]:
#test_ti(False)
#test_ti(True)

In [0]:
def print_statistics():
    ti = TinyImagenet(use_cache=True)
    # Arrays are NHWC (as stacked from cv2.imread), so reduce over N, H, W
    # to get one value per channel (BGR order).
    ars = ti.arrays['train'] / 255
    print('mean=', ars.mean(axis=(0, 1, 2)))
    print('std=', ars.std(axis=(0, 1, 2)))

#print_statistics()

TinyImagenet split into patches

In [0]:
def split_into_patches(array, kernel_size=32, stride=32):
    _, max_x, max_y, _ = array.shape
    patches = []
    for x_end in range(kernel_size, max_x + 1, stride):
        for y_end in range(kernel_size, max_y + 1, stride):
            x_start = x_end - kernel_size
            y_start = y_end - kernel_size
            patch = array[:, x_start:x_end, y_start:y_end, :]
            patches.append(patch)
    return np.concatenate(patches, axis=0)

In [0]:
NORMALIZATION_MEAN = (0.39750364, 0.44806704, 0.48023694)
NORMALIZATION_STD  = (0.28158993, 0.26886327, 0.27643643)

DEFAULT_TRANSFORMS = \
    transforms.Compose([
        transforms.ToTensor(),  # ALSO IMPLICITLY DIVIDES BY 255 AND DOES HWC->CHW
        transforms.Normalize(mean=NORMALIZATION_MEAN, std=NORMALIZATION_STD)
    ])


class ArrayDataset(Dataset):
    def __init__(self, array, transform=None):
        if transform is None:
            transform = DEFAULT_TRANSFORMS
        
        self.array = array
        self.transform = transform
    
    def __len__(self):
        return len(self.array)
    
    def __getitem__(self, index):
        return self.transform(self.array[index])


class TinyImagenetPatches:
    CACHE_SUFFIX = '_patches'

    def __init__(self, tiny_imagenet=None, initial_transform=None):
        self._make_patches(tiny_imagenet)
        self._make_datasets(initial_transform)
    
    def _make_patches(self, tiny_imagenet):
        self.patches = {}
        if tiny_imagenet is None:
            load_cache(self.patches, self.CACHE_SUFFIX)
            return
        
        for name, array in tiny_imagenet.arrays.items():
            self.patches[name] = split_into_patches(array)

        dump_cache(self.patches, self.CACHE_SUFFIX)
    
    def _make_datasets(self, transform):
        self.datasets = {}
        for name, patches in self.patches.items():
            self.datasets[name] = ArrayDataset(patches, transform)

In [0]:
def test_tip(use_cache):
    if use_cache:
        tip = TinyImagenetPatches()
    else:
        ti = TinyImagenet(use_cache=True)
        tip = TinyImagenetPatches(ti)

    for name, arr in tip.patches.items():
        print(name + '\t', arr.shape)
    
    for img in [arr[100] for arr in tip.patches.values()]:
        plt.imshow(img)
        plt.show()

In [0]:
#test_tip(use_cache=False)
#test_tip(use_cache=True)

Denoising Dataset

In [0]:
class DenoisingDataset(Dataset):
    def __init__(self, dataset, noise_std):
        self.dataset = dataset
        self.noise_std = noise_std
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        image = self.dataset[index]
        noise = self.noise_std * torch.randn_like(image)
        noisy_image = image + noise
        return noisy_image, noise


class TinyImagenetPatchesDenoising:
    def __init__(self, noise_std, tiny_imagenet=None):
        if tiny_imagenet is None:
            tiny_imagenet = TinyImagenetPatches()
        self.tiny_imagenet = tiny_imagenet
        self.noise_std = noise_std

        self._make_datasets()

    def _make_datasets(self):
        self.datasets = {}
        for name, dataset in self.tiny_imagenet.datasets.items():
            self.datasets[name] = DenoisingDataset(dataset, self.noise_std)
    
    def get_loaders(self, batch_size):
        train = torch.utils.data.DataLoader(
            self.datasets['train'], batch_size=batch_size, shuffle=True)
        test = torch.utils.data.DataLoader(
            self.datasets['test'], batch_size=batch_size, shuffle=False)
        val = torch.utils.data.DataLoader(
            self.datasets['val'], batch_size=batch_size, shuffle=False)
        return train, test, val

In [0]:
def test_tipd():
    loaders = TinyImagenetPatchesDenoising(0.1).get_loaders(128)

    for name, loader in zip(TINY_IMAGENET_PARTS, loaders):
        batch_image, batch_noise = next(iter(loader))
        print(name)
        print(batch_image.shape, f'mean={batch_image.mean()}', f'std={batch_image.std()}')
        print(batch_noise.shape, f'mean={batch_noise.mean()}', f'std={batch_noise.std()}')

In [0]:
#test_tipd()

Cache

In [0]:
def cache_exists():
    path = Path('.')
    for cache_suffix in [TinyImagenet.CACHE_SUFFIX, TinyImagenetPatches.CACHE_SUFFIX]:
        for part in TINY_IMAGENET_PARTS:
            file_ = path / (part + cache_suffix + '.npy')
            if not file_.exists():
                return False
    return True

if not cache_exists():
    ti = TinyImagenet(use_cache=False)
    TinyImagenetPatches(ti)
    assert cache_exists()
    os._exit(0)  # Restart

### DnCNN Model ([Paper](https://arxiv.org/abs/1608.03981))

In [0]:
class DnCNNBlock(nn.Module):
    def __init__(self, in_, out):
        super(DnCNNBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_, out, kernel_size=3, padding=1),
            nn.BatchNorm2d(out),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.block(x)

class DnCNN(nn.Module):
    def __init__(self, num_blocks=4, input_image_shape=(3, 32, 32), block_num_filters=64):
        super(DnCNN, self).__init__()
        
        self.name = 'dncnn'
        
        num_input_channels, *_ = input_image_shape

        self.input_convrelu = nn.Sequential(
            nn.Conv2d(num_input_channels, block_num_filters, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        blocks = [DnCNNBlock(block_num_filters, block_num_filters) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*blocks)

        self.output_conv = nn.Conv2d(block_num_filters, num_input_channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        input_ = self.input_convrelu(x)
        blocks = self.blocks(input_)
        output = self.output_conv(blocks)
        return output

In [0]:
def get_dncnn_model():
    model = DnCNN().to(device)
    return model

### GCNN Model ([Paper](https://arxiv.org/abs/1905.12281))

Graph Convolution

In [0]:
class NonLocalGraph(nn.Module):
    def __init__(self, features, near_neigh):
        super(NonLocalGraph, self).__init__()
        self.features = features
        self.k = near_neigh
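        self._init_neighbours()
    
    def _init_neighbours(self):
        # Experiment toggle: knn graph construction is currently disabled
        # (as in the conv5x5 runs). Remove the next two lines to build the graph.
        self.neighbours = 1
        return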

        features = self.features.permute((0, 2, 3, 1))
        b, h, w, c = features.shape
        k = self.k

        flat_features = features.reshape((-1, c))

        batch_indices = torch.arange(b, device=device)\
                             .unsqueeze(-1).unsqueeze(-1).expand((b, h, w))\
                             .reshape(-1)
        
        flat_nn = knn_graph(flat_features, batch=batch_indices, k=k, loop=False)
        flat_nn = flat_nn[0, :]  # assume that we always find k neighbours
        
        self.neighbours = flat_nn.reshape((b, h, w, k))
    
    def get_neighbours(self):
        return self.neighbours

In [0]:
class AggregationWeights(nn.Module):
    def __init__(self, input_features, output_features, leakyrelu_alpha):
        super(AggregationWeights, self).__init__()
        self.c_in = input_features
        self.c_out = output_features
        self.output = nn.Sequential(
            #nn.Linear(input_features, input_features),
            #nn.LeakyReLU(leakyrelu_alpha),
            nn.Linear(input_features, output_features * input_features)
        )  # single-layer variant = less memory; pass bias=False to nn.Linear to also drop the bias

    def forward(self, input_):
        b, h, w, k, c_in = input_.shape
        c_out = self.c_out

        output = self.output(input_)
        return output.reshape((b, h, w, k, c_out, c_in))

In [0]:
class NonLocalAggregation(nn.Module):
    def __init__(self, input_features, output_features, near_neigh, leakyrelu_alpha):
        super(NonLocalAggregation, self).__init__()
        self.linear = nn.Linear(input_features, output_features)
        # self.aggregation_weights = AggregationWeights(input_features, output_features, leakyrelu_alpha)
        self.aggregation = nn.Linear(near_neigh * input_features, output_features)
        self.activation = nn.LeakyReLU(leakyrelu_alpha)
    
    def forward(self, input_):
        input_, non_local_graph = input_
        input_ = input_.permute((0, 2, 3, 1))

        b, h, w, c = input_.shape
        k = non_local_graph.k

        indices = non_local_graph.get_neighbours()
        indices = indices.reshape(-1).unsqueeze(-1).expand((-1, c))
        gathered = torch.gather(input_.reshape((-1, c)), 0, indices)
        neighbours = gathered.reshape((b, h, w, k, c))

        # aggregation approximation by dense layer
        aggregation = self.aggregation(neighbours.view((b, h, w, k * c)))

        # default aggregation
        # delta = neighbours - input_.unsqueeze(-2)
        # weights = self.aggregation_weights(delta)
        # aggregation = torch.einsum(
        #     "bhwkqc,bhwkc->bhwq", weights, neighbours
        # ) / k

        linear = self.linear(input_)
        output = self.activation(aggregation + linear)
        return output.permute((0, 3, 1, 2))

In [0]:
class GraphConvolution(nn.Module):
    def __init__(self, input_features, output_features, near_neigh, leakyrelu_alpha):
        super(GraphConvolution, self).__init__()

        self.near_neigh = near_neigh

        self.conv1x1 = nn.Conv2d(input_features, output_features, kernel_size=1, padding=0)
        self.conv3x3 = nn.Conv2d(input_features, output_features, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(input_features, output_features, kernel_size=5, padding=2)
        #self.non_local_aggregation = NonLocalAggregation(input_features, output_features, near_neigh, leakyrelu_alpha)

    def forward(self, input_):
        if isinstance(input_, list):
            input_, non_local_graph = input_
        else:
            non_local_graph = NonLocalGraph(input_, self.near_neigh)

        scales = [
            self.conv1x1(input_),
            self.conv3x3(input_),
            self.conv5x5(input_), #self.non_local_aggregation([input_, non_local_graph])
        ]
        output = torch.mean(torch.stack(scales), dim=0)
        return output

In [0]:
class GraphConvolutionBlock(nn.Module):
    def __init__(self, hidden_features, near_neigh, leakyrelu_alpha):
        super(GraphConvolutionBlock, self).__init__()

        self.output = nn.Sequential(
            GraphConvolution(hidden_features, hidden_features, near_neigh, leakyrelu_alpha),
            nn.BatchNorm2d(hidden_features),
            nn.LeakyReLU(leakyrelu_alpha)
        )
    
    def forward(self, input_):
        return self.output(input_)

Residual Block

In [0]:
class GCNNResidualBlock(nn.Module):
    def __init__(self, hidden_features, near_neigh, leakyrelu_alpha):
        super(GCNNResidualBlock, self).__init__()

        self.near_neigh = near_neigh
        num_graphconv_blocks = 3

        self.blocks = []
        for block_num in range(num_graphconv_blocks):
            block = GraphConvolutionBlock(
                hidden_features, near_neigh, leakyrelu_alpha)
            self.blocks.append(block)
            self.add_module(f'block_{block_num}', block)
    
    def forward(self, input_):
        non_local_graph = NonLocalGraph(input_, self.near_neigh)

        output = input_
        for block in self.blocks:
            output = block([output, non_local_graph])

        return input_ + output

Preprocessing

In [0]:
class GCNNPreprocessingBlockSingleScale(nn.Module):
    def __init__(self, kernel_size, input_features, hidden_features, near_neigh, leakyrelu_alpha):
        super(GCNNPreprocessingBlockSingleScale, self).__init__()

        self.output = nn.Sequential(
            nn.Conv2d(input_features, hidden_features, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.LeakyReLU(leakyrelu_alpha),
            GraphConvolutionBlock(hidden_features, near_neigh, leakyrelu_alpha)
        )
    
    def forward(self, input_):
        return self.output(input_)


class GCNNPreprocessingBlock(nn.Module):
    def __init__(self, input_features, hidden_features, near_neigh, leakyrelu_alpha, kernel_sizes):
        super(GCNNPreprocessingBlock, self).__init__()
        hidden_features_scale = hidden_features // len(kernel_sizes)

        self.scales = []
        for kernel_size in kernel_sizes:
            scale = GCNNPreprocessingBlockSingleScale(
                kernel_size, input_features, hidden_features_scale, near_neigh, leakyrelu_alpha)
            self.scales.append(scale)
            self.add_module(f'conv{kernel_size}x{kernel_size}', scale)
    
    def forward(self, input_):
        scales = [scale(input_) for scale in self.scales]
        output = torch.cat(scales, dim=1)
        return output

GCNN

In [0]:
class GCNN(nn.Module):
    def __init__(self, input_features=3, hidden_features=66, near_neigh=8, leakyrelu_alpha=1e-2):
        super(GCNN, self).__init__()
        
        self.name = f'gcnn{hidden_features}'
        
        num_residual_blocks = 2
        kernel_sizes = [3, 5, 7]
        num_kernels = len(kernel_sizes)
        hidden_features = ((hidden_features + num_kernels - 1) // num_kernels) * num_kernels

        self.preprocessing = GCNNPreprocessingBlock(
            input_features, hidden_features, near_neigh, leakyrelu_alpha, kernel_sizes)

        residual_blocks = []
        for _ in range(num_residual_blocks):
            residual_blocks.append(GCNNResidualBlock(hidden_features, near_neigh, leakyrelu_alpha))
        self.residual_blocks = nn.Sequential(*residual_blocks)

        self.output_graph_conv = GraphConvolution(hidden_features, input_features, 
                                                  near_neigh, leakyrelu_alpha)
    
    def forward(self, input_):
        preprocess = self.preprocessing(input_)
        residual_blocks = self.residual_blocks(preprocess)
        output = self.output_graph_conv(residual_blocks)
        return input_ + output

In [0]:
def get_gcnn_model(hidden_features=66):
    model = GCNN(hidden_features=hidden_features).to(device)
    return model

In [0]:
def test_gcnn():
    model = get_gcnn_model()
    criterion = nn.MSELoss()

    model.eval()

    loader, *_ = TinyImagenetPatchesDenoising(noise_std=0.1).get_loaders(batch_size=8)    
    image, target = next(iter(loader))
    pred = model(image.to(device))

    print('loss=', criterion(pred, target.to(device)).item())

In [0]:
#test_gcnn()

### Load/Save Model

In [0]:
MODEL_DIR = Path('model')

def save_model(model, name):
    MODEL_DIR.mkdir(exist_ok=True)
    torch.save(model.state_dict(), MODEL_DIR / name)

def load_model(model, name):
    model.load_state_dict(torch.load(MODEL_DIR / name))

### Loss and Optimizer

In [0]:
def get_criterion():
    criterion = nn.MSELoss()
    return criterion

def get_optimizer(model):
    optimizer = torch.optim.Adam(model.parameters())
    return optimizer

### Training

In [0]:
tiny_imagenet = TinyImagenetPatches()

In [0]:
def train_step(train_loader, model, criterion, optimizer, experiment, epoch):
    with experiment.train():
        model.train()

        tqdm_notebook_train = tqdm_notebook(
            train_loader, desc='train loop', leave=False)

        for image, target in tqdm_notebook_train:
            pred = model(image.to(device))
            loss = criterion(pred, target.to(device))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            experiment.log_metric("loss", loss.item())

            
def test_step(test_loader, model, criterion, experiment, epoch):
    def _step():
        model.eval()

        mse_losses = []
        tqdm_notebook_test = tqdm_notebook(
            test_loader, desc='test loop', leave=False)
        for image, target in tqdm_notebook_test:
            pred = model(image.to(device))
            mse_losses.append(criterion(pred, target.to(device)).item())

        experiment.log_metric("loss", np.mean(mse_losses), epoch=epoch)

    with torch.no_grad():
        with experiment.test():
            _step()

            
def validate_step(val_loader, model, criterion, experiment):
    def _step():
        model.eval()

        mse_losses = []
        tqdm_notebook_val = tqdm_notebook(
            val_loader, desc='val loop', leave=False)
        for image, target in tqdm_notebook_val:
            pred = model(image.to(device))
            mse_losses.append(criterion(pred, target.to(device)).item())

        experiment.log_metric("loss", np.mean(mse_losses))
        
    with torch.no_grad():
        with experiment.validate():
            _step()


def train(get_model, noise_std=0.1, batch_size=48, num_epochs=10, model_name_prefix=''):
    train_loader, test_loader, val_loader = \
        TinyImagenetPatchesDenoising(noise_std, tiny_imagenet).get_loaders(batch_size)

    model = get_model()
    criterion = get_criterion()
    optimizer = get_optimizer(model)

    experiment = Experiment(**comet_ml_settings)

    for epoch in tqdm_notebook(range(num_epochs), desc='Epoch loop'):
        train_step(train_loader, model, criterion, optimizer, experiment, epoch)
        test_step(test_loader, model, criterion, experiment, epoch)
        # round() rather than int(): int(0.29 * 100) truncates to 28
        save_model(model, f'{model_name_prefix}{model.name}_{round(noise_std * 100)}')
        if epoch % 10 == 9:
            validate_step(val_loader, model, criterion, experiment)

    experiment.end()

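Validation in `train` only runs when `epoch % 10 == 9`, so the number of epochs determines whether the last epochs are ever validated. A small sketch of which zero-based epochs trigger validation; `validation_epochs` is a hypothetical helper used only for illustration:

```python
def validation_epochs(num_epochs, every=10):
    """Zero-based epochs for which `epoch % every == every - 1` holds."""
    return [epoch for epoch in range(num_epochs) if epoch % every == every - 1]


print(validation_epochs(10))  # [9]
print(validation_epochs(25))  # [9, 19]
```

With the default `num_epochs=10`, validation runs exactly once, after the final epoch. With 25 epochs, the last five epochs would finish without a validation pass.
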
In [0]:
train(partial(get_gcnn_model, hidden_features=24),
      noise_std=0.2, batch_size=48, num_epochs=10)

### Full inference

In [0]:
class InferenceByPatch:
    """Run a model trained on fixed-size patches over larger images.

    The image is tiled with non-overlapping patches; extra strips flush with the
    bottom and right edges cover the remainder, and overlapping predictions are averaged.
    """
    def __init__(self, model, patch_size, batch_size):
        self.model = model
        self.patch_size = patch_size
        self.batch_size = batch_size
    
    def feed_patches(self, patches):
        loader = torch.utils.data.DataLoader(
            patches, batch_size=self.batch_size, shuffle=False)
        
        self.model.eval()
        with torch.no_grad():
            results = []
            for batch in loader:
                result = self.model(batch)
                results.append(result)
        return torch.cat(results, dim=0)
    
    def feed_inner(self, input_):
        ph, pw = self.patch_size
        
        insplit = input_.unfold(2, ph, ph).unfold(3, pw, pw)
        b, c, nh, nw, ph, pw = insplit.shape
        
        patches = insplit.permute(0, 2, 3, 1, 4, 5).reshape(-1, c, ph, pw)
        results = self.feed_patches(patches)
        outsplit = results.reshape(b, nh, nw, c, ph, pw).permute(0, 3, 1, 2, 4, 5)
        
        output = outsplit.permute(0, 1, 2, 4, 3, 5)\
                         .contiguous().view(b, c, nh * ph, nw * pw)
        return output
    
    def __call__(self, input_):
        b, c, h, w = input_.shape
        ph, pw = self.patch_size
        
        output = torch.zeros_like(input_)
        overlap_counts = torch.zeros_like(input_)
        
        top_left = self.feed_inner(input_)
        b, c, tlh, tlw = top_left.shape
        output[:, :, :tlh, :tlw] += top_left
        overlap_counts[:, :, :tlh, :tlw] += 1
        
        bottom = self.feed_inner(input_[:, :, -ph:, :])
        b, c, ph, bw = bottom.shape
        output[:, :, -ph:, :bw] += bottom
        overlap_counts[:, :, -ph:, :bw] += 1
        
        right = self.feed_inner(input_[:, :, :, -pw:])
        b, c, rh, pw = right.shape
        output[:, :, :rh, -pw:] += right
        overlap_counts[:, :, :rh, -pw:] += 1
        
        bottom_right = self.feed_inner(input_[:, :, -ph:, -pw:])
        b, c, ph, pw = bottom_right.shape
        output[:, :, -ph:, -pw:] += bottom_right
        overlap_counts[:, :, -ph:, -pw:] += 1
        
        output /= overlap_counts
        return output

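The tiling strategy of `InferenceByPatch` is easier to follow in one dimension. The sketch below is a simplified pure-Python analogue, not the class itself: `tile_starts` and `infer_1d` are hypothetical names, and the default model is an identity function standing in for the network.

```python
def tile_starts(length, patch):
    """Start offsets of the non-overlapping tiles plus one tile flush with the end."""
    starts = list(range(0, length - patch + 1, patch))
    starts.append(length - patch)  # edge tile, may overlap the last regular tile
    return starts


def infer_1d(signal, patch, model=lambda chunk: chunk):
    """Run `model` on each tile and average wherever tiles overlap."""
    output = [0.0] * len(signal)
    counts = [0] * len(signal)
    for start in tile_starts(len(signal), patch):
        result = model(signal[start:start + patch])
        for offset, value in enumerate(result):
            output[start + offset] += value
            counts[start + offset] += 1
    return [total / count for total, count in zip(output, counts)], counts


signal = [float(i) for i in range(10)]
restored, counts = infer_1d(signal, patch=4)
print(counts)              # [1, 1, 1, 1, 1, 1, 2, 2, 1, 1]
print(restored == signal)  # True
```

As in the 2-D class, the edge tile is always processed, even when the length is an exact multiple of the patch size. The overlap counts here are 1 or 2 (1, 2, or 4 in the 2-D case), so averaging identical copies should be exact in floating point, which is why an identity-model test can compare input and output for exact equality.
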
In [0]:
class Id(nn.Module):
    def __init__(self, shape):
        super(Id, self).__init__()
        self.shape = shape
        
    def forward(self, input_):
        for i, (d_t, d_in) in enumerate(zip(self.shape, input_.shape)):
            if i == 0:
                assert d_in <= d_t  # batch: the last batch may be smaller
            else:
                assert d_in == d_t  # channels, height, width must match exactly
        return input_

def test_inference(shape_input=(11, 5, 113, 313), shape_model=(3, 5, 7, 9)):
    b, c, h, w = shape_model
    
    infer = InferenceByPatch(Id(shape_model), (h, w), b)
    
    input_ = torch.randn(shape_input)
    output = infer(input_)
    assert ((input_ - output) == 0).all()

In [0]:
#test_inference()

In [0]:
NORMALIZATION_MEAN_TENSOR = torch.Tensor(NORMALIZATION_MEAN)[None,:,None,None].to(device)
NORMALIZATION_STD_TENSOR  = torch.Tensor(NORMALIZATION_STD)[None,:,None,None].to(device)

def denormalize(images):
    # Undoes the dataset normalization in place and returns the same tensor
    return images.mul_(NORMALIZATION_STD_TENSOR).add_(NORMALIZATION_MEAN_TENSOR)

mse_loss = nn.MSELoss()

def psnr(pred, target, max_=1.0):
    mse = mse_loss(pred, target).item()
    return 20 * np.log10(max_) - 10 * np.log10(mse)

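To get a feel for the scale of the `psnr` values, here is a minimal standalone sketch of the same formula using only the standard library. `psnr_from_mse` is a hypothetical name, and the MSE values are illustrative inputs, not results from the experiments.

```python
import math


def psnr_from_mse(mse, max_value=1.0):
    """PSNR in dB: 20 * log10(max_value) - 10 * log10(mse)."""
    return 20 * math.log10(max_value) - 10 * math.log10(mse)


for mse in (0.1, 0.01, 0.001):
    print(f'mse={mse} -> psnr={psnr_from_mse(mse):.2f} dB')
```

Each tenfold reduction in MSE adds 10 dB. Because the logarithm is nonlinear, averaging per-batch PSNR values (as `inference` does) is generally not the same as computing PSNR from the mean MSE over all batches.
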
In [0]:
def inference(get_model, load_model_name, 
              dataset, dataset_name,
              return_images=False,
              patch_size=(32, 32), model_batch_size=48,
              noise_std=0.1, batch_size=128):
    model = get_model()
    load_model(model, load_model_name)
    infer = InferenceByPatch(model, patch_size, model_batch_size)
    
    denoising_dataset = DenoisingDataset(dataset, noise_std=noise_std)    
    loader = torch.utils.data.DataLoader(
        denoising_dataset, batch_size=batch_size, shuffle=False)
    
    criterion = get_criterion()
    
    losses = []
    psnrs = []
    ssims = []
    images_pred = []
    images_target = []

    progress_bar = tqdm_notebook(loader, desc='Inference', leave=True)
    for img, noise_target in progress_bar:
        img = img.to(device)
        noise_target = noise_target.to(device)

        noise_pred = infer(img)
        losses.append(criterion(noise_pred, noise_target).item())

        img_pred = denormalize(img - noise_pred)
        img_target = denormalize(img - noise_target)
        psnrs.append(psnr(img_pred, img_target).item())
        ssims.append(pytorch_ssim.ssim(img_pred, img_target).item())

        if return_images:
            images_pred.append(img_pred.to('cpu').numpy())
            images_target.append(img_target.to('cpu').numpy())

    print(f'Results for {dataset_name} for {load_model_name}:')
    print(f'loss = {np.mean(losses):.5f}')
    print(f'psnr = {np.mean(psnrs):.5f}')
    print(f'ssim = {np.mean(ssims):.5f}')

    if return_images:
        return np.concatenate(images_pred), np.concatenate(images_target)

In [0]:
#inference(partial(get_gcnn_model, hidden_features=24),
#          load_model_name='gcnn24_20',
#          dataset=ArrayDataset(np.load('test' + TinyImagenet.CACHE_SUFFIX + '.npy')), 
#          dataset_name='test',
#          noise_std=0.2)

In [0]:
#inference(partial(get_gcnn_model, hidden_features=24),
#          load_model_name='gcnn24_100',
#          dataset=ArrayDataset(np.load('test' + TinyImagenet.CACHE_SUFFIX + '.npy')), 
#          dataset_name='test',
#          noise_std=1.0)