In [1]:
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Pytorch Tiny Imagenet GCNN Denoising

Resources:
- [Tiny Imagenet](https://tiny-imagenet.herokuapp.com/)
- [GCNN Paper](https://arxiv.org/abs/1905.12281)
- [Building KNN index](https://github.com/rusty1s/pytorch_cluster)

Notes:
- [unfold](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.unfold) / [Fold](https://pytorch.org/docs/stable/nn.html#torch.nn.Fold) / [how-to-use](https://stackoverflow.com/questions/53972159/how-does-pytorchs-fold-and-unfold-work)

### GCNN Limitations

GCNN hits GPU memory limitations pretty hard.

Main culprit is probably the intermediate graph convolution aggregation tensor of size BHWKCC:
- B - batch
- HW - height-width
- K - number of nearest neighbours taken
- CC - matrix converting hidden features to new hidden features

In the original paper: H,W = 32, K = 8, C = 66, HWKCC $\sim 4 \cdot 10^7$.  
Note that tensors of this kind are used 10 times in the model, so the overall impact is $\sim 4 \cdot 10^8$ elements per sample. This means that even 16 is already a large batch size - these tensors will hold $\sim 6.4 \cdot 10^9$ elements (activations, not trainable parameters), so approx 6GB even at one byte per element, and about four times that in 32-bit floats. In comparison, the memory limit of the GPUs used is 10-12GB.

So we have to balance the size of hidden features and the batch size.

Another problem is that each batch itself trains very slowly, so going through the whole training/testing/validation datasets takes a very long time.


Results on a Tesla K80:  

| hidden | max batch | opt batch | train (1b) | eval (1b) |
| ------ | --------- | --------- | ---------- | --------- |
|   66   |         8 |         - |       2.0s |      1.5s |
|   48   |        16 |         - |       1.8s |      1.2s |
|   24   |        56 |        48 |       2.1s |      1.2s |
|   12   |       160 |         - |       3.9s |      2.2s |

### Comet_ML

NOTE: comet_ml is not installed in google colab by default

In [5]:
!pip install comet_ml



In [6]:
# Keyword arguments forwarded verbatim to comet_ml's `Experiment(...)` when a
# training run starts. `api_key` and `workspace` are intentionally left as
# None placeholders: fill them in before running the notebook.
comet_ml_settings = dict(
    api_key=None,
    project_name='fastrino',
    workspace=None,
)

# Fail fast, before any expensive data loading, if the placeholders above
# have not been replaced.
assert comet_ml_settings["api_key"] is not None, "set your comet_ml api_key"
assert comet_ml_settings["workspace"] is not None, "set your comet_ml workspace"

In [7]:
from comet_ml import Experiment

### Import

In [8]:
# knn_graph
!pip install --user torch-cluster

# cv2
!pip install --user opencv-python

# debug
!pip install --user psutil



In [9]:
import os
from functools import partial
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm_notebook

import torch
import torch.functional as F
from torch import nn
from torch.utils.data.dataset import Dataset

import torchvision
from torchvision import transforms

from torch_cluster import knn_graph

# Restrict CUDA to the GPU with index 2 on the host.
# NOTE(review): this is set after `import torch`; presumably it still takes
# effect because no CUDA call has been made yet — confirm on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Module-level device string used by the model factories and training loops
# below; there is no CPU fallback.
device = 'cuda'

In [10]:
%%writefile torchmemdebug.py

# https://discuss.pytorch.org/t/how-pytorch-releases-variable-garbage/7277/2

import gc
import psutil
import os
import sys

import torch

def mem_report():
    """Print the type and size of every tensor the garbage collector tracks."""
    live_tensors = filter(torch.is_tensor, gc.get_objects())
    for tensor in live_tensors:
        print(type(tensor), tensor.size())
    
def cpu_stats():
    """Print interpreter version, CPU load, system memory and this process's memory use."""
    print(sys.version)
    print(psutil.cpu_percent())
    print(psutil.virtual_memory())  # system-wide physical memory usage
    this_process = psutil.Process(os.getpid())
    bytes_per_gib = 2. ** 30
    # First field of memory_info(), scaled to GiB.
    # NOTE(review): presumably the resident set size in bytes — confirm against psutil docs.
    memoryUse = this_process.memory_info()[0] / bytes_per_gib
    print('memory GB:', memoryUse)

Overwriting torchmemdebug.py


In [11]:
import torchmemdebug

### Data

Download TinyImagenet

In [12]:
# Download and unpack the dataset only when the archive is not already in the
# working directory. The check is on the .zip file, not on the extracted
# folder, so a present archive skips both the download and the unzip.
if not Path('tiny-imagenet-200.zip').exists():
    !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    !unzip -q tiny-imagenet-200.zip

Full TinyImagenet

In [13]:
# Root folder of the extracted dataset, relative to the working directory.
TINY_IMAGENET_DIR = Path('tiny-imagenet-200')
# Names of the dataset parts; used as dictionary keys and as the prefix of
# the cache file names.
TINY_IMAGENET_PARTS = ['train', 'test', 'val']

In [14]:
def imgdir_to_array(imgdir, take_part=0.1):
    """Read an evenly spaced subsample of the images in ``imgdir`` into one array.

    Args:
        imgdir: ``pathlib.Path`` of a directory whose entries are image files.
        take_part: approximate fraction of the files to keep; every
            ``max(1, int(1 / take_part))``-th file is read.

    Returns:
        ``np.ndarray`` obtained by stacking the decoded images along a new
        leading axis (all images must therefore share one shape).

    Raises:
        ValueError: if ``take_part`` is not positive, if a selected file cannot
            be decoded by OpenCV, or (from ``np.stack``) if no file is selected.
    """
    if take_part <= 0:
        raise ValueError(f"take_part must be positive, got {take_part!r}")
    take_every = max(1, int(1 / take_part))

    # iterdir() yields entries in arbitrary, filesystem-dependent order; sorting
    # makes the subsample reproducible across machines and re-runs.
    # Subsampling the paths *before* decoding avoids reading images that would
    # be thrown away immediately afterwards.
    selected_paths = sorted(imgdir.iterdir())[::take_every]

    images = []
    for path in selected_paths:
        img = cv2.imread(str(path))
        if img is None:
            # cv2.imread signals an unreadable file by returning None instead
            # of raising; surface it here rather than inside np.stack.
            raise ValueError(f"cv2 could not read image: {path}")
        images.append(img)
    return np.stack(images)

In [15]:
def dump_cache(source_dict, cache_suffix):
    """Save each array of ``source_dict`` with ``np.save`` under ``<key><cache_suffix>``.

    ``np.save`` appends the ``.npy`` extension itself; files land in the
    current working directory.
    """
    for part in source_dict:
        np.save(part + cache_suffix, source_dict[part])
    print("Dumped to cache: " + cache_suffix)


def load_cache(target_dict, cache_suffix):
    """Fill ``target_dict`` with the cached array of every dataset part.

    For each name in ``TINY_IMAGENET_PARTS`` the file
    ``<name><cache_suffix>.npy`` is loaded from the working directory and
    stored under that name.
    """
    for part in TINY_IMAGENET_PARTS:
        cache_file = part + cache_suffix + '.npy'
        target_dict[part] = np.load(cache_file)

In [16]:
class TinyImagenet:
    """Tiny ImageNet images, one stacked array per dataset part.

    After construction ``self.arrays`` maps ``'train'``, ``'test'`` and
    ``'val'`` to arrays. They are either read back from the ``*_full.npy``
    cache files or rebuilt from the image folders and then written to cache.
    """

    CACHE_SUFFIX = '_full'

    def __init__(self, use_cache=True):
        self._load_arrays(use_cache)

    def _load_arrays(self, use_cache):
        """Populate ``self.arrays`` from cache, or from disk followed by a cache dump."""
        self.arrays = {}
        if not use_cache:
            for load_part in (self._load_train, self._load_test, self._load_val):
                load_part()
            dump_cache(self.arrays, self.CACHE_SUFFIX)
        else:
            load_cache(self.arrays, self.CACHE_SUFFIX)

    def _load_train(self):
        """Concatenate the images of every sub-folder of ``train/`` (``<sub-folder>/images``)."""
        class_dirs = (TINY_IMAGENET_DIR / 'train').iterdir()
        per_class = [imgdir_to_array(class_dir / 'images') for class_dir in class_dirs]
        self.arrays['train'] = np.concatenate(per_class, axis=0)

    def _load_flat(self, part):
        """Read the single ``<part>/images`` folder into ``self.arrays[part]``."""
        self.arrays[part] = imgdir_to_array(TINY_IMAGENET_DIR / part / 'images')

    def _load_test(self):
        self._load_flat('test')

    def _load_val(self):
        self._load_flat('val')

In [17]:
def test_ti(use_cache):
    """Smoke check for TinyImagenet: print each part's shape and show image 100 of each."""
    dataset = TinyImagenet(use_cache)

    for part in dataset.arrays:
        print(part + '\t', dataset.arrays[part].shape)

    # Index every part first, then display, so an out-of-range index fails
    # before any figure is drawn.
    samples = [part_array[100] for part_array in dataset.arrays.values()]
    for sample in samples:
        plt.imshow(sample)
        plt.show()

In [18]:
#test_ti(False)
#test_ti(True)

In [19]:
def print_statistics():
    """Print the per-channel mean and standard deviation of the cached training images.

    Pixel values are scaled to [0, 1] first. The training array is channel-last
    (images come from ``cv2.imread`` and are stacked along a new leading axis),
    so the statistics are reduced over the sample, row and column axes, leaving
    one value per channel.
    """
    ti = TinyImagenet(use_cache=True)
    ars = ti.arrays['train'] / 255

    # Reduce over everything except the trailing channel axis.
    non_channel_axes = (0, 1, 2)

    mean = ars.mean(axis=non_channel_axes)
    print('mean=', mean)

    # In-place updates avoid allocating further full-size temporaries.
    ars -= mean  # broadcasts over the trailing channel axis
    ars **= 2
    # The mean of squared deviations is the variance; take the root for the std.
    print('std=', np.sqrt(ars.mean(axis=non_channel_axes)))

#print_statistics()

TinyImagenet split into patches

In [20]:
def split_into_patches(array, kernel_size=32, stride=32):
    """Cut square windows out of a 4-D array along its two middle axes.

    Windows of side ``kernel_size`` are taken every ``stride`` positions along
    axes 1 and 2; only windows that fit completely are kept. The windows are
    concatenated along axis 0, window by window (axis-1 position outermost).
    """
    _, height, width, _ = array.shape
    row_starts = range(0, height - kernel_size + 1, stride)
    col_starts = range(0, width - kernel_size + 1, stride)
    windows = [
        array[:, row:row + kernel_size, col:col + kernel_size, :]
        for row in row_starts
        for col in col_starts
    ]
    return np.concatenate(windows, axis=0)

In [21]:
class ArrayDataset(Dataset):
    """Map-style dataset that returns ``transform(array[index])``."""

    def __init__(self, array, transform):
        self.array, self.transform = array, transform

    def __len__(self):
        return len(self.array)

    def __getitem__(self, index):
        raw_item = self.array[index]
        return self.transform(raw_item)
        

class TinyImagenetPatches:
    """Patches cut from the Tiny ImageNet arrays, wrapped as datasets.

    ``self.patches`` maps each part name to the array of patches and
    ``self.datasets`` maps it to an ``ArrayDataset`` applying the transform.
    Passing ``tiny_imagenet=None`` reads the patches from the
    ``*_patches.npy`` cache; otherwise they are recomputed and cached.
    """

    CACHE_SUFFIX = '_patches'

    def __init__(self, tiny_imagenet=None, initial_transform=None):
        self._make_patches(tiny_imagenet)
        self._make_datasets(initial_transform)

    def _make_patches(self, tiny_imagenet):
        """Fill ``self.patches`` from cache or by splitting ``tiny_imagenet.arrays``."""
        self.patches = {}
        if tiny_imagenet is not None:
            for part, images in tiny_imagenet.arrays.items():
                self.patches[part] = split_into_patches(images)
            dump_cache(self.patches, self.CACHE_SUFFIX)
        else:
            load_cache(self.patches, self.CACHE_SUFFIX)

    def _make_datasets(self, transform):
        """Wrap every patch array in an ``ArrayDataset``; default transform is tensor + normalise."""
        if transform is None:
            to_tensor = transforms.ToTensor()  # ALSO IMPLICITLY DIVIDES BY 255 AND DOES HWC->CHW
            # NOTE(review): the images come from cv2.imread, so these per-channel
            # statistics are presumably in BGR order — confirm.
            normalize = transforms.Normalize(
                mean=(0.39750364, 0.44806704, 0.48023694),
                std=(0.28158993, 0.26886327, 0.27643643),
            )
            transform = transforms.Compose([to_tensor, normalize])

        self.datasets = {
            part: ArrayDataset(part_patches, transform)
            for part, part_patches in self.patches.items()
        }

In [22]:
def test_tip(use_cache):
    """Smoke check for TinyImagenetPatches: print patch-array shapes and show patch 100 of each part.

    With ``use_cache`` the patches are read from cache; otherwise they are
    rebuilt from the cached full-size images.
    """
    source = None if use_cache else TinyImagenet(use_cache=True)
    tip = TinyImagenetPatches() if source is None else TinyImagenetPatches(source)

    for part in tip.patches:
        print(part + '\t', tip.patches[part].shape)

    samples = [part_patches[100] for part_patches in tip.patches.values()]
    for sample in samples:
        plt.imshow(sample)
        plt.show()

In [23]:
#test_tip(use_cache=False)
#test_tip(use_cache=True)

Denoising Dataset

In [24]:
class DenoisingDataset(Dataset):
    """Wrap a dataset of image tensors and add fresh Gaussian noise on every access.

    Each item is the pair ``(noisy_image, noise)``: the second element is the
    noise that was added, not the clean image.
    """

    def __init__(self, dataset, noise_std):
        self.dataset, self.noise_std = dataset, noise_std

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        clean = self.dataset[index]
        # Noise is re-sampled on every call, so the same index yields a
        # different noisy image each time.
        noise = self.noise_std * torch.randn_like(clean)
        return clean + noise, noise


class TinyImagenetPatchesDenoising:
    """Denoising datasets and data loaders built on top of ``TinyImagenetPatches``.

    ``self.datasets`` maps each part name to a ``DenoisingDataset`` with the
    given ``noise_std``. When no patch container is supplied, one is created
    with its default arguments.
    """

    def __init__(self, noise_std, tiny_imagenet=None):
        self.tiny_imagenet = TinyImagenetPatches() if tiny_imagenet is None else tiny_imagenet
        self.noise_std = noise_std

        self._make_datasets()

    def _make_datasets(self):
        """Wrap every patch dataset in a ``DenoisingDataset``."""
        self.datasets = {
            part: DenoisingDataset(clean_dataset, self.noise_std)
            for part, clean_dataset in self.tiny_imagenet.datasets.items()
        }

    def get_loaders(self, batch_size):
        """Return ``(train, test, val)`` data loaders; only the train loader shuffles."""
        def make_loader(part, shuffle):
            return torch.utils.data.DataLoader(
                self.datasets[part], batch_size=batch_size, shuffle=shuffle)

        return (
            make_loader('train', True),
            make_loader('test', False),
            make_loader('val', False),
        )

In [25]:
def test_tipd():
    """Smoke check for the denoising loaders: print shape, mean and std of one batch per part."""
    denoising = TinyImagenetPatchesDenoising(0.1)
    loaders = denoising.get_loaders(128)

    # get_loaders returns (train, test, val), the same order as TINY_IMAGENET_PARTS.
    for name, loader in zip(TINY_IMAGENET_PARTS, loaders):
        first_batch = next(iter(loader))
        batch_image, batch_noise = first_batch
        print(name)
        print(batch_image.shape, f'mean={batch_image.mean()}', f'std={batch_image.std()}')
        print(batch_noise.shape, f'mean={batch_noise.mean()}', f'std={batch_noise.std()}')

In [26]:
#test_tipd()

Cache

In [27]:
def cache_exists():
    """Return True only if every full-image and patch cache file is in the working directory."""
    cache_suffixes = (TinyImagenet.CACHE_SUFFIX, TinyImagenetPatches.CACHE_SUFFIX)
    return all(
        (Path('.') / (part + cache_suffix + '.npy')).exists()
        for cache_suffix in cache_suffixes
        for part in TINY_IMAGENET_PARTS
    )

# First run only: build the full-image cache and the patch cache from the raw
# image folders, check that all files were written, then terminate the kernel
# process so the next run starts fresh and reads from cache.
if not cache_exists():
    ti = TinyImagenet(use_cache=False)
    TinyImagenetPatches(ti)
    assert cache_exists()
    # os._exit ends the process immediately, so no later cell runs in this session.
    os._exit(0)  # Restart

### DnCNN Model

[Paper](https://arxiv.org/abs/1608.03981)

In [28]:
class DnCNNBlock(nn.Module):
    """One DnCNN stage: 3x3 convolution (padding 1), batch norm, ReLU."""

    def __init__(self, in_, out):
        super().__init__()
        layers = (
            nn.Conv2d(in_, out, kernel_size=3, padding=1),
            nn.BatchNorm2d(out),
            nn.ReLU(),
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        features = self.block(x)
        return features

class DnCNN(nn.Module):
    """DnCNN: input conv + ReLU, a stack of ``DnCNNBlock`` stages, output conv.

    Only the first entry of ``input_image_shape`` (the channel count) is used;
    the output has the same number of channels and spatial size as the input.
    """

    def __init__(self, num_blocks=4, input_image_shape=[3, 32, 32], block_num_filters=64):
        super().__init__()

        # Used to build the checkpoint file name.
        self.name = 'dncnn'

        image_channels, *_ = input_image_shape

        head_conv = nn.Conv2d(image_channels, block_num_filters, kernel_size=3, padding=1)
        self.input_convrelu = nn.Sequential(head_conv, nn.ReLU())

        self.blocks = nn.Sequential(
            *(DnCNNBlock(block_num_filters, block_num_filters) for _ in range(num_blocks))
        )

        self.output_conv = nn.Conv2d(block_num_filters, image_channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = self.input_convrelu(x)
        hidden = self.blocks(hidden)
        return self.output_conv(hidden)

In [29]:
def get_dncnn_model():
    """Build a DnCNN with default arguments and move it to the module-level ``device``."""
    return DnCNN().to(device)

### GCNN Model

[Paper](https://arxiv.org/abs/1905.12281)

Graph Convolution

In [30]:
class NonLocalGraph(nn.Module):
    """k-nearest-neighbour graph over the pixels of a feature map.

    For every pixel the ``near_neigh`` nearest pixels *of the same sample* are
    searched in feature space with ``knn_graph`` (self-loops excluded). The
    result is stored as flat indices into the ``(b * h * w, c)`` view of the
    channel-last feature map.

    Attributes:
        features: the ``(b, c, h, w)`` tensor the graph was built from.
        k: number of neighbours per pixel.
        neighbours: integer tensor of shape ``(b, h, w, k)``.
    """

    def __init__(self, features, near_neigh):
        super(NonLocalGraph, self).__init__()
        self.features = features
        self.k = near_neigh
        self._init_neightbours()

    def _init_neightbours(self):
        """Compute ``self.neighbours`` from ``self.features``."""
        features = self.features.permute((0, 2, 3, 1))
        b, h, w, c = features.shape
        k = self.k

        flat_features = features.reshape((-1, c))

        # Sample id of every flattened pixel, so neighbours are never taken
        # across samples. Built on the features' own device so the graph works
        # wherever the model lives, independent of any module-level setting.
        batch_indices = torch.arange(b, device=flat_features.device)\
                             .unsqueeze(-1).unsqueeze(-1).expand((b, h, w))\
                             .reshape(-1)

        flat_nn = knn_graph(flat_features, batch=batch_indices, k=k, loop=False)
        # Row 0 is kept as the neighbour index of each edge.
        # NOTE(review): relies on knn_graph's default edge direction and on edges
        # being grouped per query pixel — confirm against torch_cluster docs.
        flat_nn = flat_nn[0, :]

        # The reshape below needs exactly k neighbours per pixel; fail with an
        # explicit message instead of an opaque shape error when a sample has
        # too few pixels to provide them.
        expected = b * h * w * k
        if flat_nn.numel() != expected:
            raise RuntimeError(
                f"knn_graph returned {flat_nn.numel()} neighbours, expected {expected} "
                f"({k} per pixel); each sample needs more than {k} pixels"
            )

        self.neighbours = flat_nn.reshape((b, h, w, k))

    def get_neighbours(self):
        """Return the ``(b, h, w, k)`` tensor of flat neighbour indices."""
        return self.neighbours

In [31]:
class AggregationWeights(nn.Module):
    """Two-layer MLP that turns a feature vector into a ``(c_out, c_in)`` matrix.

    The network is Linear(c_in, c_in) -> LeakyReLU -> Linear(c_in, c_out * c_in)
    without bias on the last layer; ``forward`` reshapes the last axis of the
    result into the matrix.
    """

    def __init__(self, input_features, output_features, leakyrelu_alpha):
        super().__init__()
        self.c_in, self.c_out = input_features, output_features
        layers = [
            nn.Linear(input_features, input_features),
            nn.LeakyReLU(leakyrelu_alpha),
            nn.Linear(input_features, output_features * input_features, bias=False),
        ]
        self.output = nn.Sequential(*layers)

    def forward(self, input_):
        """Map a ``(b, h, w, k, c_in)`` tensor to ``(b, h, w, k, c_out, c_in)``."""
        b, h, w, k, c_in = input_.shape
        flat_matrices = self.output(input_)
        matrix_shape = (b, h, w, k, self.c_out, c_in)
        return flat_matrices.reshape(matrix_shape)

In [32]:
class NonLocalAggregation(nn.Module):
    """Graph aggregation over each pixel's neighbours plus a per-pixel linear term.

    ``forward`` takes the pair ``[features, graph]`` where ``features`` is
    ``(b, c_in, h, w)`` and ``graph`` exposes ``k`` and ``get_neighbours()``.
    Each neighbour is multiplied by a matrix predicted from its feature
    difference to the centre pixel; the products are averaged over the
    neighbours, added to a linear map of the centre pixel and passed through a
    LeakyReLU. The result is ``(b, c_out, h, w)``.
    """

    def __init__(self, input_features, output_features, leakyrelu_alpha):
        super().__init__()
        self.linear = nn.Linear(input_features, output_features)
        self.aggregation_weights = AggregationWeights(input_features, output_features, leakyrelu_alpha)
        self.activation = nn.LeakyReLU(leakyrelu_alpha)

    def forward(self, input_):
        feature_map, non_local_graph = input_
        pixels = feature_map.permute((0, 2, 3, 1))  # channel-last: (b, h, w, c)

        b, h, w, c = pixels.shape
        k = non_local_graph.k

        # Neighbour ids index the flattened (b*h*w, c) pixel table; repeat each
        # id across the channel axis so gather picks whole feature vectors.
        flat_pixels = pixels.reshape((-1, c))
        gather_index = non_local_graph.get_neighbours()
        gather_index = gather_index.reshape(-1).unsqueeze(-1).expand((-1, c))
        neighbours = torch.gather(flat_pixels, 0, gather_index).reshape((b, h, w, k, c))

        # One (c_out, c_in) matrix per edge, predicted from the feature difference.
        delta = neighbours - pixels.unsqueeze(-2)
        edge_matrices = self.aggregation_weights(delta)

        per_neighbour = torch.matmul(edge_matrices, neighbours.unsqueeze(-1))
        per_neighbour = per_neighbour.squeeze(-1)

        non_local_term = per_neighbour.mean(-2)  # average over the k neighbours
        local_term = self.linear(pixels)

        activated = self.activation(non_local_term + local_term)
        return activated.permute((0, 3, 1, 2))  # back to channel-first

In [33]:
class GraphConvolution(nn.Module):
    """Average of a 1x1 convolution, a 3x3 convolution and a non-local aggregation.

    ``forward`` accepts either a feature tensor ``(b, c_in, h, w)`` or a pair
    ``(features, graph)`` / ``[features, graph]`` carrying a pre-built
    neighbour graph. When only a tensor is given, a new ``NonLocalGraph`` with
    ``near_neigh`` neighbours is built from it. The output is
    ``(b, c_out, h, w)``.
    """

    def __init__(self, input_features, output_features, near_neigh, leakyrelu_alpha):
        super(GraphConvolution, self).__init__()

        self.near_neigh = near_neigh

        self.conv1x1 = nn.Conv2d(input_features, output_features, kernel_size=1, padding=0)
        self.conv3x3 = nn.Conv2d(input_features, output_features, kernel_size=3, padding=1)
        self.non_local_aggregation = NonLocalAggregation(input_features, output_features, leakyrelu_alpha)

    def forward(self, input_):
        # Accept tuples as well as lists for the (features, graph) pair; an
        # exact `type(...) is list` test would send a tuple down the
        # graph-building path, where it fails on the missing tensor methods.
        if isinstance(input_, (list, tuple)):
            input_, non_local_graph = input_
        else:
            non_local_graph = NonLocalGraph(input_, self.near_neigh)

        scales = [
            self.conv1x1(input_),
            self.conv3x3(input_),
            self.non_local_aggregation([input_, non_local_graph])
        ]
        # Element-wise mean of the three branches.
        output = torch.mean(torch.stack(scales), dim=0)
        return output

In [34]:
class GraphConvolutionBlock(nn.Module):
    """GraphConvolution (same width in and out) followed by batch norm and LeakyReLU."""

    def __init__(self, hidden_features, near_neigh, leakyrelu_alpha):
        super().__init__()

        graph_conv = GraphConvolution(hidden_features, hidden_features, near_neigh, leakyrelu_alpha)
        batch_norm = nn.BatchNorm2d(hidden_features)
        activation = nn.LeakyReLU(leakyrelu_alpha)
        self.output = nn.Sequential(graph_conv, batch_norm, activation)

    def forward(self, input_):
        # input_ is handed unchanged to the first layer, so it may be whatever
        # GraphConvolution.forward accepts.
        result = self.output(input_)
        return result

Residual Block

In [35]:
class GCNNResidualBlock(nn.Module):
    """Three graph-convolution blocks sharing one neighbour graph, with a skip connection.

    The neighbour graph is built once from the block's input and reused by all
    three blocks; the block returns ``input + blocks(input)``.
    """

    def __init__(self, hidden_features, near_neigh, leakyrelu_alpha):
        super().__init__()

        self.near_neigh = near_neigh
        num_graphconv_blocks = 3

        # The blocks live in a plain list and are registered one by one under
        # the names block_0, block_1, ... (these names appear in the state dict).
        self.blocks = [
            GraphConvolutionBlock(hidden_features, near_neigh, leakyrelu_alpha)
            for _ in range(num_graphconv_blocks)
        ]
        for block_num, block in enumerate(self.blocks):
            self.add_module(f'block_{block_num}', block)

    def forward(self, input_):
        shared_graph = NonLocalGraph(input_, self.near_neigh)

        hidden = input_
        for block in self.blocks:
            hidden = block([hidden, shared_graph])

        return input_ + hidden

Preprocessing

In [36]:
class GCNNPreprocessingBlockSingleScale(nn.Module):
    """One preprocessing branch: k x k convolution, LeakyReLU, graph-convolution block.

    The convolution pads by ``kernel_size // 2``.
    """

    def __init__(self, kernel_size, input_features, hidden_features, near_neigh, leakyrelu_alpha):
        super().__init__()

        conv = nn.Conv2d(
            input_features, hidden_features,
            kernel_size=kernel_size, padding=kernel_size // 2)
        activation = nn.LeakyReLU(leakyrelu_alpha)
        graph_block = GraphConvolutionBlock(hidden_features, near_neigh, leakyrelu_alpha)
        self.output = nn.Sequential(conv, activation, graph_block)

    def forward(self, input_):
        result = self.output(input_)
        return result


class GCNNPreprocessingBlock(nn.Module):
    """Multi-scale preprocessing: one branch per kernel size, concatenated over channels.

    Each branch gets ``hidden_features // len(kernel_sizes)`` output channels.
    """

    def __init__(self, input_features, hidden_features, near_neigh, leakyrelu_alpha, kernel_sizes):
        super().__init__()
        features_per_scale = hidden_features // len(kernel_sizes)

        # Branches are kept in a plain list and registered individually under
        # the names conv3x3, conv5x5, ... (these names appear in the state dict).
        self.scales = []
        for kernel_size in kernel_sizes:
            branch = GCNNPreprocessingBlockSingleScale(
                kernel_size, input_features, features_per_scale, near_neigh, leakyrelu_alpha)
            self.add_module(f'conv{kernel_size}x{kernel_size}', branch)
            self.scales.append(branch)

    def forward(self, input_):
        branch_outputs = []
        for branch in self.scales:
            branch_outputs.append(branch(input_))
        return torch.cat(branch_outputs, dim=1)

GCNN

In [37]:
class GCNN(nn.Module):
    """Graph-convolutional denoiser: multi-scale preprocessing, two residual blocks, output graph conv.

    ``hidden_features`` is rounded up to a multiple of the number of
    preprocessing kernel sizes (3); ``self.name`` keeps the value that was
    requested. ``forward`` returns ``input_ + output``.
    NOTE(review): the training target in this notebook is the added noise, so
    this outer skip connection is presumably intentional — confirm.
    """

    def __init__(self, input_features=3, hidden_features=66, near_neigh=8, leakyrelu_alpha=1e-2):
        super().__init__()

        # Used to build the checkpoint file name; set before the rounding below.
        self.name = f'gcnn{hidden_features}'

        num_residual_blocks = 2
        kernel_sizes = [3, 5, 7]
        num_kernels = len(kernel_sizes)

        # Round up to the next multiple of num_kernels so the channels split evenly.
        remainder = hidden_features % num_kernels
        if remainder:
            hidden_features += num_kernels - remainder

        self.preprocessing = GCNNPreprocessingBlock(
            input_features, hidden_features, near_neigh, leakyrelu_alpha, kernel_sizes)

        self.residual_blocks = nn.Sequential(*(
            GCNNResidualBlock(hidden_features, near_neigh, leakyrelu_alpha)
            for _ in range(num_residual_blocks)
        ))

        self.output_graph_conv = GraphConvolution(
            hidden_features, input_features, near_neigh, leakyrelu_alpha)

    def forward(self, input_):
        hidden = self.preprocessing(input_)
        hidden = self.residual_blocks(hidden)
        correction = self.output_graph_conv(hidden)
        return input_ + correction

In [38]:
def get_gcnn_model(hidden_features=66):
    """Build a GCNN with the given hidden width and move it to the module-level ``device``."""
    return GCNN(hidden_features=hidden_features).to(device)

In [39]:
def test_gcnn():
    """Smoke check for the GCNN: run one batch of 8 through a fresh model and print the MSE loss."""
    model = get_gcnn_model()
    criterion = nn.MSELoss()

    model.eval()

    loader, *_ = TinyImagenetPatchesDenoising(noise_std=0.1).get_loaders(batch_size=8)
    image, target = next(iter(loader))

    # This is inference only: disabling autograd avoids keeping the large
    # intermediate tensors alive for a backward pass that is never run.
    with torch.no_grad():
        pred = model(image.to(device))
        loss = criterion(pred, target.to(device)).item()

    print('loss=', loss)

In [40]:
#test_gcnn()

### Load/Save Model

In [41]:
# Folder (relative to the working directory) where checkpoints are stored.
MODEL_DIR = Path('model')

def save_model(model, name):
    """Write ``model.state_dict()`` to ``MODEL_DIR / name``, creating the folder if needed."""
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), MODEL_DIR / name)

def load_model(model, name):
    """Load the state dict stored at ``MODEL_DIR / name`` into ``model`` in place.

    The checkpoint is deserialised onto the CPU first. ``load_state_dict``
    then copies the values into the model's existing parameters on whatever
    device they live on, so a checkpoint saved on a GPU can also be restored
    on a CPU-only machine and no second copy is allocated on the GPU.
    """
    state_dict = torch.load(MODEL_DIR / name, map_location='cpu')
    model.load_state_dict(state_dict)

### Loss and Optimizer

In [42]:
def get_criterion():
    """Return the training loss: mean squared error with default settings."""
    return nn.MSELoss()

def get_optimizer(model):
    """Return an Adam optimizer with default hyper-parameters over all of ``model``'s parameters."""
    trainable = model.parameters()
    return torch.optim.Adam(trainable)

### Training

In [43]:
# Load the cached patch datasets once at module level; `train` reads this
# global, so repeated training runs reuse the same in-memory arrays.
tiny_imagenet = TinyImagenetPatches()

In [44]:
def train_step(train_loader, model, criterion, optimizer, experiment, epoch):
    """Run one training epoch and log the loss of every batch to ``experiment``.

    ``epoch`` is accepted for signature symmetry with ``test_step`` but is not
    used here.
    """
    with experiment.train():
        model.train()

        progress = tqdm_notebook(train_loader, desc='train loop', leave=False)

        for noisy_batch, target_batch in progress:
            prediction = model(noisy_batch.to(device))
            batch_loss = criterion(prediction, target_batch.to(device))

            batch_loss.backward()
            optimizer.step()
            # Gradients are cleared after the update so the next batch starts from zero.
            optimizer.zero_grad()

            experiment.log_metric("loss", batch_loss.item())

            
def test_step(test_loader, model, criterion, experiment, epoch):
    """Evaluate on the test loader and log the mean batch loss for ``epoch``.

    Runs with autograd disabled, inside the experiment's test context and with
    the model in eval mode.
    """
    with torch.no_grad(), experiment.test():
        model.eval()

        progress = tqdm_notebook(test_loader, desc='test loop', leave=False)
        mse_losses = [
            criterion(model(image.to(device)), target.to(device)).item()
            for image, target in progress
        ]

        experiment.log_metric("loss", np.mean(mse_losses), epoch=epoch)

            
def validate_step(val_loader, model, criterion, experiment):
    """Evaluate on the validation loader and log the mean batch loss.

    Runs with autograd disabled, inside the experiment's validate context and
    with the model in eval mode. No epoch is attached to the logged metric.
    """
    with torch.no_grad(), experiment.validate():
        model.eval()

        progress = tqdm_notebook(val_loader, desc='val loop', leave=False)
        mse_losses = [
            criterion(model(image.to(device)), target.to(device)).item()
            for image, target in progress
        ]

        experiment.log_metric("loss", np.mean(mse_losses))


def train(get_model, noise_std=0.1, batch_size=48, num_epochs=10, model_name_prefix=''):
    """Train a denoising model and log the run to comet_ml.

    Args:
        get_model: zero-argument callable returning the model to train; the
            model must expose a ``name`` attribute.
        noise_std: standard deviation of the Gaussian noise added to the patches.
        batch_size: batch size of all three data loaders.
        num_epochs: number of passes over the training set.
        model_name_prefix: string prepended to the checkpoint file name.

    After every epoch the model is evaluated on the test loader and a
    checkpoint named ``<prefix><model.name>_<noise_std in percent>`` is
    written. Validation runs on epochs 0, 10, 20, ... only.
    """
    train_loader, test_loader, val_loader = \
        TinyImagenetPatchesDenoising(noise_std, tiny_imagenet).get_loaders(batch_size)

    model = get_model()
    criterion = get_criterion()
    optimizer = get_optimizer(model)

    # round() rather than a bare int(): values such as 0.29 * 100 evaluate to
    # 28.999..., which truncation would turn into 28.
    noise_percent = int(round(noise_std * 100))
    checkpoint_name = f'{model_name_prefix}{model.name}_{noise_percent}'

    experiment = Experiment(**comet_ml_settings)
    try:
        for epoch in tqdm_notebook(range(num_epochs), desc='Epoch loop'):
            train_step(train_loader, model, criterion, optimizer, experiment, epoch)
            test_step(test_loader, model, criterion, experiment, epoch)
            save_model(model, checkpoint_name)
            if epoch % 10 == 0:
                validate_step(val_loader, model, criterion, experiment)
    finally:
        # Close the experiment even when training is interrupted or fails.
        experiment.end()

In [46]:
# Train a GCNN with 24 hidden features for 3 epochs at noise_std=0.1.
# `partial` turns the factory into the zero-argument callable `train` expects.
train(partial(get_gcnn_model, hidden_features=24),
      noise_std=0.1, batch_size=48, num_epochs=3)