In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- BLOCKS -----
class Module(object):
    def __call__(self, input):
        return self.forward(input)

    def forward(self, input):
        return input

    def backward(self, gradwrtoutput):
        pass

    def get_params(self):
        return []

    def set_params(self, params):
        pass

    def cuda(self):
        pass

    def params(self):
        return []

    def eval(self):
        pass

    def train(self):
        pass

In [3]:
def compare_fcts(gd, scratch):
    x = torch.rand(5, 3, 32, 32, requires_grad=True)
    y_gd = gd(x)
    with torch.no_grad():
        y_scratch = scratch(x)
        y_gd.backward(torch.ones_like(y_gd))
        y_grad_gd = x.grad
        y_grad_scatch = scratch.backward(torch.ones_like(y_scratch))
    
    print("max error forward pass: ", torch.max(torch.abs(y_gd - y_scratch)))
    print("max error backward pass: ", torch.max(torch.abs(y_grad_gd - y_grad_scatch)))

# Activation

# ReLU

In [4]:
class ReLU(Module):
    def __init__(self):
        self.zero_mask = None

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        self.zero_mask = x > 0
        return x * self.zero_mask

    def backward(self, gradwrtoutput):
        result = self.zero_mask * gradwrtoutput
        return result

compare_fcts(nn.ReLU(), ReLU())

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [5]:
class LeakyRelu(Module):
    def __init__(self, slope):
        self.slope = slope
        self.zero_mask = None

    def forward(self, x):
        zero_mask = x > 0
        self.zero_mask = zero_mask.float()
        result = zero_mask * (1 - self.slope)
        result += self.slope
        return x * result

    def backward(self, gradwrtoutput):
        result = self.zero_mask * (1 - self.slope)
        result += self.slope 
        return result
    

compare_fcts(nn.LeakyReLU(), LeakyRelu(0.1))

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [6]:
class Sigmoid(Module):
    def __init__(self):
        self.__e = torch.e
        self.forward_sigm = None
        self.mode = "train"

    def __sig(self, x):
        return 1 / (1 + self.__e ** -x)

    def forward(self, x):
        sig = self.__sig(x)
        if self.mode == "train":
            self.forward_sigm = sig
        return sig

    def backward(self, gradwrtoutput):
        result = gradwrtoutput * self.forward_sigm * (1 - self.forward_sigm)
        return result
    

compare_fcts(nn.Sigmoid(), Sigmoid())

max error forward pass:  tensor(1.1921e-07, grad_fn=<MaxBackward1>)
max error backward pass:  tensor(5.9605e-08)


In [28]:
class Conv2d(Module):
    def __init__(self, channels_in=1, channels_out=1, kernel_size=(3, 3), stride=1, padding=0, dilation=1,
                 bias=True):
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.is_bias = bias
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.device = "cpu"
        self.mode = "train"
        self.input_unfolded = None
        self.input_shape = None

        self._initialize_parameters()
    
    def _initialize_parameters(self):
        xavier_bound = (2 / (self.channels_out + self.channels_in) ** 0.5)
        self.weight = torch.empty((self.channels_out, self.channels_in, *self.kernel_size)).uniform_(-xavier_bound, xavier_bound).to(self.device)
        self.bias = torch.empty(self.channels_out).uniform_(-xavier_bound, xavier_bound).to(self.device) if self.is_bias else None
        
        self.dB = self.bias.clone() if self.is_bias else None
        self.dW = self.weight.clone()
    
    def forward(self, input):
        input = input.to(self.device)
        N, _, _, _ = input.shape
        self.input_unfolded = torch.nn.functional.unfold(input, kernel_size=self.kernel_size, padding=self.padding,
                                                         stride=self.stride, dilation=self.dilation).to(self.device)
        wxb = self.weight.view(self.channels_out, -1) @ self.input_unfolded + self.bias.view(1, -1, 1)
        result = wxb.view(N, self.channels_out,
                          ((input.shape[2] - self.dilation * self.kernel_size[
                              0] + 2 * self.padding) // self.stride + 1),
                          ((input.shape[3] - self.dilation * self.kernel_size[
                              1] + 2 * self.padding) // self.stride + 1))

        if self.mode == "train":
            self.input_shape = input.shape

        return result

    def backward(self, gradwrtoutput):
        N, C, H, W = self.input_shape

        # compute gradient w.r.t. weights
        grad_reshaped = gradwrtoutput.permute(1, 2, 3, 0).reshape(self.channels_out, -1)
        dW = grad_reshaped @ self.input_unfolded.permute(1, 2, 0).reshape(self.input_unfolded.shape[1], -1).T
        self.dW.zero_()
        self.dW += dW.reshape(self.weight.shape)

        # compute gradient w.r.t. biases
        self.dB.zero_()
        self.dB += gradwrtoutput.sum(axis=(0, 2, 3)).view(-1)

        # compute gradient w.r.t. inputs
        weight_reshaped = self.weight.view(self.channels_out, -1)
        dX_col = weight_reshaped.T @ grad_reshaped
        out_h = (H - self.dilation * (self.kernel_size[0] - 1) + 2 * self.padding - 1) // self.stride + 1
        out_w = (W - self.dilation * (self.kernel_size[1] - 1) + 2 * self.padding - 1) // self.stride + 1
        dX_col_reshaped = dX_col.view(C * self.kernel_size[0] * self.kernel_size[1], out_w * out_h, N)
        dX = torch.nn.functional.fold(dX_col_reshaped.permute(2, 0, 1),
                                      self.input_shape[2:], kernel_size=self.kernel_size, padding=self.padding,
                                      stride=self.stride, dilation=self.dilation)
        
        return dX


    def _backward(self, grad_output):
        N, C, H, W = self.input_shape
        grad_output_unfolded = F.unfold(grad_output, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation)

        # compute gradient w.r.t. weights
        grad_w = grad_output_unfolded.transpose(0, 1).matmul(self.input_unfolded.transpose(1, 2))
        grad_w = grad_w.transpose(0, 1).view_as(self.weight)

        # compute gradient w.r.t. biases
        grad_b = grad_output.sum(dim=(0, 2, 3))
        if self.bias is not None:
            grad_b = grad_b.view_as(self.bias)

        # compute gradient w.r.t. inputs
        grad_output_reshaped = grad_output_unfolded.transpose(0, 1)
        grad_output_reshaped = grad_output_reshaped.view(C * self.kernel_size[0] * self.kernel_size[1], -1)
        grad_input_unfolded = self.weight.view(self.weight.size(0), -1).matmul(grad_output_reshaped)
        grad_input_unfolded = grad_input_unfolded.view(-1, self.input_unfolded.size(1), self.input_unfolded.size(2))
        grad_input = F.fold(grad_input_unfolded, output_size=self.input_shape[2:], kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation)

        return grad_input, grad_w, grad_b

    
    def get_params(self):
        return [self.weight, self.bias, self.is_bias]

    def set_params(self, params):
        self.weight = params[0].to(self.device)
        self.bias = params[1].to(self.device)
        self.is_bias = params[2]

    def cuda(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.weight = self.weight.to(self.device)
        self.bias = self.bias.to(self.device)
        self.dW = self.dW.to(self.device)
        self.dB = self.dB.to(self.device)
        self.prev_update_conv = self.prev_update_conv.to(self.device)
        self.prev_update_bias = self.prev_update_bias.to(self.device)

    def params(self):
        return [[self.weight, self.dW], [self.bias, self.dB]]

conv2_gd = Conv2d(3, 16, 3, 1, 1)
conv2_scratch = Conv2d(3, 16, 3, 1, 1)
conv2_gd.weight = conv2_scratch.weight.clone()
conv2_gd.bias = conv2_scratch.bias.clone()
compare_fcts(conv2_gd, conv2_scratch)

RuntimeError: The size of tensor a (144) must match the size of tensor b (5) at non-singleton dimension 0

In [7]:
%%timeit
a = torch.randn((500,500,3,3))
for i in range(1000):
    a.zero_()

58.7 ms ± 3.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
%%timeit
a = torch.randn((500,500,3,3))
for i in range(1000):
    a *= 0

107 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
class MaxPool2d(Module):
    def __init__(self):
        self.index = None
        self.mask = None

    def forward(self, x):
        n, c, W, H = x.shape
        xmax = x.view(n, c, W // 2, 2, H // 2, 2).max(5).values.max(3)
        xmax_max = xmax.values.repeat_interleave(2, axis=2).repeat_interleave(2, axis=3)
        self.mask = xmax_max == x
        return xmax.values

    def backward(self, gradwrtoutput):
        result = self.mask * gradwrtoutput.repeat_interleave(2, axis=2).repeat_interleave(2, axis=3)
        return result    


compare_fcts(nn.MaxPool2d(2), MaxPool2d())

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [9]:

class UpSampling2D(Module):
    """ Nearest neighbor up sampling of the input. Repeats the rows and columns of the data by size[0] and size[1] respectively. Parameters: ----------- size: tuple (size_y, size_x) - The number of times each axis will be repeated. """

    def __init__(self, scale_factor=2):
        self.scale_factor = scale_factor
        self.device = "cpu"

    def forward(self, x):
        return x.repeat_interleave(self.scale_factor, dim=2).repeat_interleave(self.scale_factor, dim=3)
        
    def backward(self, dy):
        N, C, H, W = dy.shape
        dx = torch.zeros((N, C, H//self.scale_factor, W//self.scale_factor)).to(self.device)
        for i in range(H//self.scale_factor):
            for j in range(W//self.scale_factor):
                dx[:, :, i, j] = dy[:, :, i*self.scale_factor:(i+1)*self.scale_factor, j*self.scale_factor:(j+1)*self.scale_factor].sum(dim=(-1, -2))
        return dx


    

compare_fcts(nn.Upsample(scale_factor=2, mode='nearest'), UpSampling2D(scale_factor=2))

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [10]:
class Sequential(Module):
    def __init__(self, *functions):
        self.functions = functions

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        for func in self.functions:
            x = func.forward(x)
        return x

    def backward(self, grad):
        for func in reversed(self.functions):
            grad = func.backward(grad)
        return grad

    def get_params(self):
        return [func.get_params() for func in self.functions]

    def set_params(self, params):
        for i, param in enumerate(params):
            self.functions[i].set_params(param)

    def cuda(self):
        for func in self.functions:
            func.cuda()

    def parameters(self):
        return [func.params() for func in self.functions]

    def eval(self):
        for func in self.functions:
            func.eval()

    def train(self):
        for func in self.functions:
            func.train()

In [11]:
seq_test_scratch = Sequential(UpSampling2D(scale_factor=2), UpSampling2D(scale_factor=2), ReLU())
seq_test_gd = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Upsample(scale_factor=2, mode='nearest'), nn.ReLU())

compare_fcts(seq_test_gd, seq_test_scratch)

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [12]:
# skip connection block

class SkipConnection(Module):

    def __init__(self, module):
        self.module = module
        self.size_in = None
        self.size_out = None

    def forward(self, x):
        self.size_in = x.shape
        skip_connection = x
        x = self.module.forward(x)
        x =  torch.cat((x, skip_connection), dim=1)
        self.size_out = x.shape
        return x
    
    def backward(self, dz):
        N, C, H, W = dz.shape
        dy = dz[:, :C - self.size_in[1], :, :]
        dx = dz[:, C - self.size_in[1]:, :, :]
        dy = self.module.backward(dy)
        dx += dy
        return dx

    def get_params(self):
        return self.module.get_params()

    def set_params(self, params):
        self.module.set_params(params)

    def cuda(self):
        self.module.cuda()

    def params(self):
        return self.module.params()

    def eval(self):
        self.module.eval()

    def train(self):
        self.module.train()

class SkipConnection_gd(nn.Module):
    
        def __init__(self, module):
            super(SkipConnection_gd, self).__init__()
            self.module = module
    
        def forward(self, x):
            return torch.cat((self.module.forward(x), x), dim=1)

conv2_gd = Conv2d(3, 16, 3, 1, 1)
conv2_scratch = Conv2d(3, 16, 3, 1, 1)
conv2_gd.weight = conv2_scratch.weight.clone()
conv2_gd.bias = conv2_scratch.bias.clone()

skip_test_scratch = SkipConnection(conv2_scratch)
skip_test_gd = SkipConnection_gd(conv2_gd)

compare_fcts(skip_test_gd, skip_test_scratch)



max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [13]:
gd = SkipConnection_gd(conv2_gd)
scratch = SkipConnection(conv2_scratch)

x = torch.rand(1, 3, 25, 25, requires_grad=True)
y_gd = gd(x)
with torch.no_grad():
    grad = torch.rand_like(y_gd)
    y_scratch = scratch(x)
    y_gd.backward(grad)
    y_grad_gd = x.grad
    y_grad_scatch = scratch.backward(grad)

print("max error forward pass: ", torch.max(torch.abs(y_gd - y_scratch)))
print("max error backward pass: ", torch.max(torch.abs(y_grad_gd - y_grad_scatch)))

max error forward pass:  tensor(0., grad_fn=<MaxBackward1>)
max error backward pass:  tensor(0.)


In [14]:
class MSELoss:
    def __init__(self):
        self.grad = None

    def __call__(self, model_output, ground_truth):
        return self.forward(model_output, ground_truth)

    def forward(self, model_output, ground_truth):
        self.grad = 2 * (model_output - ground_truth) / model_output.numel()
        return ((ground_truth - model_output) ** 2).mean()

    def backward(self):
        return self.grad

In [15]:
class Optimizer(object):
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def step(self):
        for param in self.params:
            for elem in param:
                elem[0] -= self.lr * elem[1]



In [16]:
# Noise2Noise network

def convblock(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, activation=LeakyRelu(0.1)):
    return Sequential(
        Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
        activation
    )


class Noise2Noise(Module):
    
    def __init__(self) -> None:
        super().__init__()
        self.encoder01 = Sequential(
            convblock(in_channels=3, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=2, padding=1, bias=False, activation=LeakyRelu(0.1)),
            MaxPool2d()
        )

        self.encoder2 = Sequential(
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            MaxPool2d()
        )

        self.encoder3 = Sequential(
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            MaxPool2d()
        )

        self.encoder4 = Sequential(
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            MaxPool2d()
        )

        self.encoder56 = Sequential(
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            MaxPool2d(),
            convblock(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            UpSampling2D(scale_factor=2),

        )

        self.decoder5ab = Sequential(
            convblock(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            UpSampling2D(scale_factor=2),
        )

        self.decoder4ab = Sequential(
            convblock(in_channels=144, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            convblock(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            UpSampling2D(scale_factor=2),
        )

        self.decoder3ab = Sequential(
            convblock(in_channels=144, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            convblock(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            UpSampling2D(scale_factor=2),
        )

        self.decoder2ab = Sequential(
            convblock(in_channels=144, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            convblock(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            UpSampling2D(scale_factor=2),
        )

        self.decoder1abc = Sequential(
            convblock(in_channels=99, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            convblock(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False, activation=LeakyRelu(0.1)),
            Conv2d(channels_in=32, channels_out=3, kernel_size=3, stride=1, padding=1, bias=True),
        )


        self.skip1 = SkipConnection(self.encoder56)
        self.skip2 = SkipConnection(Sequential(self.encoder4, self.skip1, self.decoder5ab))
        self.skip3 = SkipConnection(Sequential(self.encoder3, self.skip2, self.decoder4ab))
        self.skip4 = SkipConnection(Sequential(self.encoder2, self.skip3, self.decoder3ab))
        self.skip5 = SkipConnection(Sequential(self.encoder01, self.skip4, self.decoder2ab))
        self.net = Sequential(self.skip5, self.decoder1abc)

        self.lr = 6.7

        self.loss = MSELoss()
        self.optimizer = Optimizer(self.net.params(), lr=self.lr)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"


    def load_pretrained_model(self, model_path) -> None:
        ## This loads the parameters saved in bestmodel .pth into the model
        params = torch.load(model_path, map_location=self.device)
        self.net.set_params(params)

    def train(self, train_input, train_target, num_epochs) -> None:
        train_input = (train_input.float() / 255).to(self.device)
        train_target = (train_target.float() / 255).to(self.device)
        if torch.cuda.is_available():
            self.net.cuda()
            self.optimizer = Optimizer(self.net.params(), lr=self.lr)
        self.net.train()
        n_data = len(train_input)
        batch_size = 64
        for epoch in range(num_epochs):
            print(f'EPOCH: {epoch + 1}/{num_epochs}')
            loss = []
            for first in range(0, n_data, batch_size):
                with torch.no_grad():
                    last = first + batch_size
                    x_batch, y_batch = train_input[first:last], train_target[first:last]

                    results = self.net.forward(x_batch)
                    loss_ = self.loss(results, y_batch)
                    loss.append(loss_)
                    self.net.backward(self.loss.backward())
                    self.optimizer.step()

            sum = 0
            for val in loss:
                sum += val
            print(
                f'{(first / batch_size)} / {(n_data // batch_size)} | loss: {(sum / (first / batch_size + 1))}')

    def predict(self, test_input) -> torch.Tensor:
        #: test_input : tensor of size (N1 , C, H, W) with values in range 0 -255 that has tobe denoised by the trained or the loaded network
        #: returns a tensor of the size (N1 , C, H, W) with values in range 0 -255.
        self.net.eval()

        def normalization_cut(imgs):
            imgs_shape = imgs.shape
            imgs = imgs.flatten()
            imgs[imgs < 0] = 0
            imgs[imgs > 1] = 1
            imgs = imgs.view(imgs_shape)
            return imgs

        return 255 * normalization_cut(self.net(test_input.float() / 255).to(self.device)).cpu()

    def save(self, model_path):
        pck_file = self.net.get_params()
        torch.save(pck_file, model_path)
    


In [17]:
model = Noise2Noise()

model.cuda()

optimizer = Optimizer(model.parameters(), lr=0.001)

AttributeError: 'Noise2Noise' object has no attribute 'parameters'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms as transforms

import torch
import numpy as np

from matplotlib import pyplot as plt



class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset1, dataset2):
        self.datasets = torch.cat([dataset1[:, None], dataset2[:, None]], dim=1)
        self.transforms = torch.nn.Sequential(
            # transforms.TrivialAugmentWide()
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            # transforms.RandomCrop(size=(32, 32)),
            # transforms.ColorJitter(brightness=.5, hue=.3)
        )

    def __getitem__(self, i):
        if torch.rand(1) > 0.5:
            return self.transforms(self.datasets[i])
        else:
            return self.transforms(self.datasets[i, [1, 0]])

    def __len__(self):
        return len(self.datasets)

def display_rgb(img):
    plt.imshow(img.permute(1,2,0))
    plt.show()

def psnr_eval(model, noised, ground_truth, must_randomize=True):
    def psnr(denoised, ground_truth):
        mse = torch.mean((denoised.cpu() - ground_truth.cpu()) ** 2)
        return -10 * torch.log10(mse + 10 ** -8)
    clean_imgs = ground_truth.clone()
    noised = noised.float()
    ground_truth = ground_truth.float()

    denoised = model.predict(noised) / 255

    psnr_result = psnr(denoised, (ground_truth / 255)).item()
    print(f'PSNR result: {psnr_result}dB')

    nb_images = 3

    f, axarr = plt.subplots(nb_images, 3)

    if must_randomize:
        nb_index = np.random.choice(len(noised), nb_images)
    else:
        nb_index = np.arange(nb_images)
    axarr[0, 0].set_title("Noisy Images")
    axarr[0, 1].set_title("Denoised")
    axarr[0, 2].set_title("Ground Truth")

    for i, index in enumerate(nb_index):
        axarr[i, 0].imshow(noised[index].permute(1,2,0).int())
        axarr[i,0].get_yaxis().set_visible(False)
        axarr[i,0].get_xaxis().set_visible(False)
        axarr[i, 1].imshow(denoised[index].cpu().detach().permute(1,2,0))
        axarr[i, 1].get_yaxis().set_visible(False)
        axarr[i, 1].get_xaxis().set_visible(False)
        axarr[i, 2].imshow(clean_imgs[index].permute(1,2,0))
        axarr[i, 2].get_yaxis().set_visible(False)
        axarr[i, 2].get_xaxis().set_visible(False)
    plt.show()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'The model will be loaded on the {"GPU" if device == "cuda" else "cpu"}.')

noisy_imgs_1, noisy_imgs_2 = torch.load('train_data.pkl')
noisy_imgs, clean_imgs = torch.load('val_data.pkl')



In [None]:
psnr_eval(model, noisy_imgs, clean_imgs, must_randomize=False)
