# Training a Denoising Autoencoder

Training a UNet to denoise MNIST images.
The net is trained on multiple noise levels.

We use Technique 3 in [Improved Techniques for Training Score-Based Generative Models](https://arxiv.org/abs/2006.09011).\
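I.e. for a noisy input $\tilde{x} = x + \sigma z$ with $z \sim \mathcal{N}(0, I)$, the UNet is trained to predict $-z$ (the negated, unscaled noise), so the score estimate is recovered as $s(\tilde{x}) \approx -z/\sigma$. The paper argues that this output rescaling by $1/\sigma$ removes the need to make the net noise-conditional.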

## Setup ##

In [None]:
# imports
import torch
from torch import nn, optim
import torch.nn.functional as F

# import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid  # NOTE(review): not used in the cells shown here

from torch.utils.data import DataLoader

import numpy as np  # NOTE(review): not used in the cells shown here
import random  # used to pick a random noise level for each image

import matplotlib.pyplot as plt
# Plot defaults for viewing MNIST digits: no grid lines, grayscale colormap.
plt.rcParams['axes.grid'] = False
plt.rcParams['image.cmap'] = 'gray'

from IPython import display  # NOTE(review): not used in the cells shown here

from google.colab import drive  # Colab-only; used below to save the trained weights to Google Drive

# Mini-batch size shared by the training and test DataLoaders.
batch_size = 64

In [None]:
# MNIST training split; ToTensor() yields float images scaled to [0, 1].
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Reshuffle each epoch so the optimizer does not see the training batches
# in the same fixed order on every pass.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 158535856.31it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 92540636.99it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 52393513.56it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18513633.40it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [None]:
# Held-out MNIST split used for validation; iterated in a fixed order.
test_data = datasets.MNIST(
    'data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),  # float tensors scaled to [0, 1]
)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

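Alternative training data loader with data augmentation (random rotation of up to ±15°, then a 1-pixel pad and a random crop back to 28×28). Uncomment it to use it in place of the plain training loader above.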

In [None]:
# training_data = datasets.MNIST(
#     root='data',
#     train=True,
#     download=True,
#     transform=transforms.Compose([
#                       transforms.RandomRotation(15),
#                       transforms.Pad(padding=1),
#                       transforms.RandomCrop([28, 28]),
#                       transforms.ToTensor()
#     ])
# )

## Variant of InstanceNorm that lets some of each channel's mean (DC component) through.
Modified from https://github.com/ermongroup/ncsn

In [None]:
class InstanceNorm2dPlus(nn.Module):
    """Instance normalization that re-injects a standardized copy of each channel's mean.

    Plain InstanceNorm2d removes each channel's spatial mean. This variant
    computes the per-channel spatial means, standardizes them across the channel
    axis (separately for each sample), and adds them back to the instance-normalized
    activations scaled by a learnable per-channel ``alpha``. A per-channel
    ``gamma`` scale (and, optionally, a ``beta`` offset) is applied last.

    Args:
        num_features: number of channels C of the (N, C, H, W) input.
        bias: if True, learn an additive per-channel offset ``beta``.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        # alpha and gamma start near 1, drawn from N(1, 0.02); alpha is drawn first.
        self.alpha = nn.Parameter(torch.empty(num_features).normal_(1, 0.02))
        self.gamma = nn.Parameter(torch.empty(num_features).normal_(1, 0.02))
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        n_ch = self.num_features
        channel_means = x.mean(dim=(2, 3))
        center = channel_means.mean(dim=-1, keepdim=True)
        spread = channel_means.var(dim=-1, keepdim=True)
        standardized = (channel_means - center) / (spread + 1e-5).sqrt()

        # Put back a learnable amount of the per-channel mean the norm removed.
        h = self.instance_norm(x) + standardized[..., None, None] * self.alpha[..., None, None]
        out = self.gamma.view(-1, n_ch, 1, 1) * h
        if self.bias:
            out = out + self.beta.view(-1, n_ch, 1, 1)
        return out

## Create UNet ##

In [None]:
# helper operations
def conv3x3(in_channels, out_channels):
    """Return a 3x3, stride-1 convolution with bias; padding 1 keeps H and W unchanged."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(3, 3),
        stride=1,
        padding=1,
        bias=True,
    )

def maxpool2x2():
    """Return a 2x2 max-pool with stride 2 (halves H and W; odd sizes are floored)."""
    pool = nn.MaxPool2d(2, stride=2, padding=0)
    return pool

class UpConv2x2(nn.Module):
    """Double the spatial resolution and halve the channel count.

    Nearest-neighbour upsampling by 2 is followed by a 2x2 convolution. One
    row/column of zeros is padded on the bottom/right so that the unpadded 2x2
    kernel produces exactly 2H x 2W outputs.

    Args:
        channels: input channel count; the output has ``channels // 2`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(channels, channels // 2, kernel_size=2, stride=1, padding=0, bias=True)

    def forward(self, x):
        # F.pad order for the last two dims: (left, right, top, bottom).
        return self.conv(F.pad(self.upsample(x), (0, 1, 0, 1)))

def concat(xh, xv):
    """Join two feature maps along the channel axis (dim 1), with ``xh`` first."""
    pair = (xh, xv)
    return torch.cat(pair, 1)


# unet blocks
class ConvBlock(nn.Module):
    """Two successive (conv3x3 -> InstanceNorm2dPlus -> ReLU) stages; H and W are preserved.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Creation order determines the random-initialization sequence and the
        # parameter/state_dict order, so it is deliberately left as is.
        self.conv1 = conv3x3(in_channels, out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = conv3x3(out_channels, out_channels)
        self.insta1 = InstanceNorm2dPlus(out_channels)
        self.insta2 = InstanceNorm2dPlus(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        stages = (
            (self.conv1, self.insta1, self.relu1),
            (self.conv2, self.insta2, self.relu2),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return x

class DownConvBlock(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W) followed by a ConvBlock.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels in the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.downsample = maxpool2x2()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(self.downsample(x))

class UpConvBlock(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, then a ConvBlock.

    ``UpConv2x2`` reduces the input to ``in_channels // 2`` channels. The skip
    tensor is expected to supply the remaining ``in_channels // 2`` channels, so
    that the concatenation feeds ``in_channels`` channels into the ConvBlock.

    Args:
        in_channels: number of channels in the lower-resolution input ``xh``.
        out_channels: number of channels in the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = UpConv2x2(in_channels)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, xh, xv):
        """
        Args:
            xh: activations from the lower-resolution feature map (green arrow in diagram)
            xv: activations from the same-resolution feature map (gray arrow in diagram)
        """
        upsampled = self.up(xh)
        merged = concat(upsampled, xv)
        return self.conv(merged)

class UNet(nn.Module):
    """Four-level UNet mapping a 1-channel image to a 1-channel output of the same size.

    Channel widths go 16, 32, 64, 128, 256 from the top level down to the
    bottleneck. Four 2x2 max-pools are applied, so H and W must be divisible by
    16 for the skip connections to line up (which is why the 28x28 MNIST
    digits are padded to 32x32 before being fed in).
    """

    def __init__(self):
        super().__init__()
        widths = [16, 32, 64, 128, 256]
        # Creation order fixes the init RNG sequence and state_dict keys; keep it.
        self.conv_in = ConvBlock(1, widths[0])
        self.dconv1 = DownConvBlock(widths[0], widths[1])
        self.dconv2 = DownConvBlock(widths[1], widths[2])
        self.dconv3 = DownConvBlock(widths[2], widths[3])
        self.dconv4 = DownConvBlock(widths[3], widths[4])
        self.uconv1 = UpConvBlock(widths[4], widths[3])
        self.uconv2 = UpConvBlock(widths[3], widths[2])
        self.uconv3 = UpConvBlock(widths[2], widths[1])
        self.uconv4 = UpConvBlock(widths[1], widths[0])
        self.conv_out = conv3x3(widths[0], 1)

    def forward(self, x):
        # Encoder: keep every level's activations for the skip connections.
        skips = [self.conv_in(x)]
        for down in (self.dconv1, self.dconv2, self.dconv3, self.dconv4):
            skips.append(down(skips[-1]))

        # Decoder: start at the bottleneck and consume the skips deepest-first.
        h = skips.pop()
        for up in (self.uconv1, self.uconv2, self.uconv3, self.uconv4):
            h = up(h, skips.pop())
        return self.conv_out(h)

## Train UNet ##

In [None]:
# Hyperparameters: Adam learning rate and number of passes over the training set.
eta = 0.001
nepoch = 50

# Build the denoiser (on the GPU when one is available) and train it with a
# mean-squared-error objective.
net = UNet()
if torch.cuda.is_available():
    net.cuda()
lossfn = nn.MSELoss()
optimizer = optim.Adam(params=net.parameters(), lr=eta)

#### Geometric sequence of noise scales (σ)

In [None]:
# Noise standard deviations: 11 levels spaced geometrically from 1.0 down to ~0.01.
gamma = 0.01 ** 0.1  # common ratio between consecutive levels
sigma = [pow(gamma, level) for level in range(11)]

#### Training loop
Monitor the validation loss (on the test set) after each epoch.

In [None]:
def _make_noisy_batch(X, rng=random, generator=None):
    """Prepare one MNIST batch and corrupt it with Gaussian noise.

    Each image is zero-padded from 28x28 to 32x32 (the UNet needs H and W
    divisible by 16) and shifted by -0.5. Every image then gets its own noise
    level, picked uniformly from the global ``sigma`` list.

    Args:
        X: image batch from a DataLoader (ToTensor() output, values in [0, 1]).
        rng: object with a ``choices`` method used to pick the noise levels
            (the ``random`` module, or a seeded ``random.Random``).
        generator: optional ``torch.Generator`` for the Gaussian noise draw.

    Returns:
        (noise, noisy): the unit-variance noise ``z`` and ``X + sigma * z``,
        both on the CPU.
    """
    X = F.pad(X, (2, 2, 2, 2)) - 0.5
    noise = torch.randn(X.shape, generator=generator)
    scale = torch.tensor(rng.choices(sigma, k=X.size(0))).view(-1, 1, 1, 1)
    return noise, X + scale * noise


# Follow the network's own device instead of re-checking CUDA availability.
device = next(net.parameters()).device

for epoch in range(nepoch):
    net.train()
    for X, _labels in train_dataloader:
        noise, noisy = _make_noisy_batch(X)
        noise, noisy = noise.to(device), noisy.to(device)

        # forward, backward, update
        optimizer.zero_grad()
        pred = net(noisy)
        # Target is the negated unit-variance noise (Technique 3 parameterization).
        loss = lossfn(pred, -noise)
        loss.backward()
        optimizer.step()

    # Validation: fixed seeds so every epoch is scored on the same noisy test set,
    # making the numbers comparable from one epoch to the next.
    net.eval()
    val_rng = random.Random(0)
    val_gen = torch.Generator().manual_seed(0)
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for X, _labels in test_dataloader:
            noise, noisy = _make_noisy_batch(X, rng=val_rng, generator=val_gen)
            noise, noisy = noise.to(device), noisy.to(device)
            pred = net(noisy)
            # lossfn returns a per-batch mean; weight by batch size so the short
            # final batch does not count as much as a full one.
            loss_sum += lossfn(pred, -noise).item() * X.size(0)
            n_seen += X.size(0)

    print(f"epoch {epoch + 1}/{nepoch}: mean validation loss {loss_sum / n_seen:.5f}")

tensor(15.9291, device='cuda:0')
tensor(14.1403, device='cuda:0')
tensor(12.7089, device='cuda:0')
tensor(12.0037, device='cuda:0')
tensor(11.8619, device='cuda:0')
tensor(11.5213, device='cuda:0')
tensor(11.6754, device='cuda:0')
tensor(11.7264, device='cuda:0')
tensor(10.9006, device='cuda:0')
tensor(11.1719, device='cuda:0')
tensor(10.7108, device='cuda:0')
tensor(10.9455, device='cuda:0')
tensor(10.7407, device='cuda:0')
tensor(10.8185, device='cuda:0')
tensor(10.8619, device='cuda:0')
tensor(10.4810, device='cuda:0')
tensor(10.5030, device='cuda:0')
tensor(10.4263, device='cuda:0')
tensor(10.5022, device='cuda:0')
tensor(10.3624, device='cuda:0')
tensor(10.3055, device='cuda:0')
tensor(10.4142, device='cuda:0')
tensor(10.1961, device='cuda:0')
tensor(10.4912, device='cuda:0')
tensor(10.3294, device='cuda:0')
tensor(10.1823, device='cuda:0')
tensor(10.3810, device='cuda:0')
tensor(10.1242, device='cuda:0')
tensor(10.2391, device='cuda:0')
tensor(10.2827, device='cuda:0')
tensor(10.

## Save model for later experiments

In [None]:
# Mount Google Drive so the trained weights can be written to it.
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Save only the learned weights (state_dict). Use the absolute path under the
# Drive mount point ('/content/gdrive') so saving does not depend on the
# notebook's current working directory.
torch.save(net.state_dict(), '/content/gdrive/MyDrive/testnoaugment')

In [None]:
# Flush pending writes (such as the saved weights) to Google Drive, then unmount it.
drive.flush_and_unmount()