# Basic Convolutional Auto-encoder


In [9]:
import os
import cv2
import numpy as np
import imutils
from sklearn.metrics import classification_report

import collections
from collections import namedtuple, OrderedDict

import typing
from typing import List, Tuple, Union, Callable

import matplotlib
import matplotlib.pyplot as plt

import torch
from torch.optim import Adam
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchsummary import summary

from torch import nn
from torch.nn import (Sequential,
                      Module,
                      Conv2d,
                      ConvTranspose2d,
                      Upsample,
                      Linear,
                      MaxPool2d,
                      ReLU,
                      Sigmoid,
                      Flatten,
                      Unflatten,
                      LogSoftmax)

import torchvision.datasets
from torchvision.datasets import KMNIST
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.utils.data import Subset
# Select the compute device: CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

Device: cpu


## Model Definition

In [2]:
class AutoEncoderBase(Sequential):
    """Sequential container with the shape bookkeeping shared by encoder and decoder.

    Sub-classes call ``init_x_shapes`` to pre-compute the (channels, height,
    width) shape of the tensor after every conv and pooling stage; the result
    is exposed read-only through ``x_shapes``.
    """

    def __init__(self, *args, **kwargs):
        super(AutoEncoderBase, self).__init__(*args, **kwargs)

    @staticmethod
    def _as_pair(value: Union[Tuple[int, int], int]) -> Tuple[int, int]:
        """Expand an int to ``(value, value)``; unpack a 2-tuple as ``(y, x)``."""
        if isinstance(value, int):
            return value, value
        axis_y, axis_x = value
        return axis_y, axis_x

    def get_k_out_shape(self,
                        in_shape: Tuple[int, int, int],
                        module: Union[Conv2d, MaxPool2d, None] = None,
                        kernel: Union[Tuple[int, int], int, None] = None,
                        stride: Union[Tuple[int, int], int, None] = None,
                        padding: Union[Tuple[int, int], int, None] = None,
                        out_channels: Union[int, None] = None,
                        ) -> Tuple[int, int, int]:
        """Return the (channels, height, width) produced by a sliding-kernel layer.

        The kernel parameters are read from ``module`` when it is a ``Conv2d``
        or ``MaxPool2d``; otherwise ``kernel``, ``stride`` and ``padding`` must
        be given explicitly. Dilation and ``ceil_mode`` are not taken into
        account.

        Args:
            in_shape: Input shape as (channels, height, width).
            module: Optional layer to read kernel/stride/padding (and, for
                ``Conv2d``, the output channels) from.
            kernel: Kernel size, an int or a (y, x) tuple.
            stride: Stride, an int or a (y, x) tuple.
            padding: Padding per side, an int or a (y, x) tuple.
            out_channels: Output channel count; defaults to ``in_shape[0]``.

        Returns:
            The output shape as a (channels, height, width) tuple of ints.

        Raises:
            ValueError: If kernel, stride or padding is neither an int nor a
                tuple (for example a ``Conv2d`` built with string padding).
        """
        if isinstance(module, (Conv2d, MaxPool2d)):
            kernel = module.kernel_size
            stride = module.stride
            padding = module.padding
            if isinstance(module, Conv2d):
                out_channels = module.out_channels
        # Validate in every case, so that values taken from a module (such as
        # string padding) fail with the same clear message as explicit ones.
        if not all(isinstance(value, (tuple, int)) for value in (kernel, stride, padding)):
            raise ValueError("Arguments kernel, stride and padding must be ints or tuples of ints.")

        # Split every parameter into its y and x components.
        kernel_y, kernel_x = self._as_pair(kernel)
        stride_y, stride_x = self._as_pair(stride)
        padding_y, padding_x = self._as_pair(padding)

        # floor((in + 2p - k) / s) + 1, which equals ceil((in + 2p - (k-1)) / s)
        # for integer operands but avoids floating point arithmetic.
        h = (in_shape[1] + 2 * padding_y - kernel_y) // stride_y + 1
        w = (in_shape[2] + 2 * padding_x - kernel_x) // stride_x + 1
        if out_channels is None:
            out_channels = in_shape[0]
        return out_channels, h, w

    def init_x_shapes(self,
                      image_shape: Tuple[int, int, int],
                      init_channels: int,
                      layer_depth: int,
                      kernel_size: Union[Tuple[int, int], int],
                      stride: Union[Tuple[int, int], int],
                      padding: Union[Tuple[int, int], int],
                      fully_connected: bool,
                      latent_dim: int) -> None:
        """Pre-compute the tensor shape after every stage of the forward pass.

        The stored list starts with ``image_shape`` and then holds, for each of
        the ``layer_depth`` stages, the shape after the convolution followed by
        the shape after pooling. Pooling is assumed to halve height and width
        (floor division by 2), and the channel count doubles from one stage to
        the next, starting at ``init_channels``. When ``fully_connected`` is
        true, ``(flattened_features,)`` and ``(latent_dim,)`` are appended.

        Args:
            image_shape: Input shape as (channels, height, width).
            init_channels: Output channels of the first convolution.
            layer_depth: Number of conv + pool stages.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            fully_connected: Whether to append the flattened and latent shapes.
            latent_dim: Size of the latent vector.
        """
        shapes: List[Tuple[int, ...]] = [tuple(image_shape)]
        channels = init_channels
        for _ in range(layer_depth):
            conv_shape = self.get_k_out_shape(in_shape=shapes[-1],
                                              kernel=kernel_size,
                                              stride=stride,
                                              padding=padding,
                                              out_channels=channels)
            pooled_shape = (conv_shape[0], conv_shape[1] // 2, conv_shape[2] // 2)
            shapes += [conv_shape, pooled_shape]
            channels *= 2
        if fully_connected:
            # Store a plain int rather than a numpy scalar for the layer sizes.
            shapes += [(int(np.prod(shapes[-1])),), (latent_dim,)]
        self._x_shapes = shapes

    @property
    def x_shapes(self) -> list:
        """The list of shapes computed by ``init_x_shapes``."""
        return self._x_shapes

In [3]:
class ConvLayer(Sequential):
    """One encoder stage: 2-D convolution, then an activation, then max pooling."""

    def __init__(self,
                 conv_in_channels: int,
                 conv_out_channels: int,
                 conv_kernel_size: Union[int, Tuple[int, int]] = (5,5),
                 conv_stride: Union[int, Tuple[int, int]] = 1,
                 conv_padding: Union[int, Tuple[int, int]] = 0,
                 pool_kernel_size: Union[int, Tuple[int, int]] = (2,2),
                 pool_stride: Union[int, Tuple[int, int]] = (2,2),
                 pool_padding: Union[int, Tuple[int, int]] = 0,
                 activation_func: Callable[[], Module] = ReLU,
                 ) -> None:
        """Build the three sub-modules and register them in order.

        Args:
            conv_in_channels: Input channels of the convolution.
            conv_out_channels: Output channels of the convolution.
            conv_kernel_size: Convolution kernel size.
            conv_stride: Convolution stride.
            conv_padding: Convolution padding.
            pool_kernel_size: Max-pooling window size.
            pool_stride: Max-pooling stride.
            pool_padding: Max-pooling padding.
            activation_func: Zero-argument factory (e.g. the ``ReLU`` class)
                that is called once to create the activation module.
        """
        convolution = Conv2d(in_channels=conv_in_channels,
                             out_channels=conv_out_channels,
                             kernel_size=conv_kernel_size,
                             stride=conv_stride,
                             padding=conv_padding)
        activation = activation_func()
        pooling = MaxPool2d(kernel_size=pool_kernel_size,
                            stride=pool_stride,
                            padding=pool_padding)
        super().__init__(convolution, activation, pooling)

In [4]:
class Encoder(AutoEncoderBase):
    """Stack of ``ConvLayer`` stages, optionally followed by Flatten + Linear.

    NOTE: the shapes used to size the Linear layer come from
    ``init_x_shapes``, which halves height and width after every stage
    regardless of the pooling arguments given here.
    """

    def __init__(self,
                 image_shape: Tuple[int, int, int],
                 layer_depth: int,
                 latent_dim: int,
                 init_channels: int = 32,
                 conv_kernel_size: Union[int, Tuple[int, int]] = (5,5),
                 conv_stride: Union[int, Tuple[int, int]] = 1,
                 conv_padding: Union[int, Tuple[int, int]] = 0,
                 pool_kernel_size: Union[int, Tuple[int, int]] = (2, 2),
                 pool_stride: Union[int, Tuple[int, int]] = (2, 2),
                 pool_padding: Union[int, Tuple[int, int]] = 0,
                 activation_func: Callable[[], Module] = ReLU,
                 fully_connected: bool = True,
                 ) -> None:
        """Compute the per-stage shapes, then assemble the layer stack.

        Args:
            image_shape: Input shape as (channels, height, width).
            layer_depth: Number of conv + pool stages.
            latent_dim: Output size of the final Linear layer.
            init_channels: Output channels of the first convolution.
            conv_kernel_size: Convolution kernel size for every stage.
            conv_stride: Convolution stride for every stage.
            conv_padding: Convolution padding for every stage.
            pool_kernel_size: Pooling window size for every stage.
            pool_stride: Pooling stride for every stage.
            pool_padding: Pooling padding for every stage.
            activation_func: Factory for the activation used in every stage.
            fully_connected: When true, append Flatten and a Linear layer
                that maps the flattened features to ``latent_dim``.
        """
        # The shapes must exist before the stages can be sized.
        self.init_x_shapes(image_shape=image_shape,
                           init_channels=init_channels,
                           layer_depth=layer_depth,
                           kernel_size=conv_kernel_size,
                           stride=conv_stride,
                           padding=conv_padding,
                           fully_connected=fully_connected,
                           latent_dim=latent_dim)
        shapes = self.x_shapes

        # Stage i reads its input channels from shapes[2*i] and its output
        # channels from the post-convolution shape stored right after it.
        stages: List[Module] = [
            ConvLayer(conv_in_channels=shapes[2 * stage][0],
                      conv_out_channels=shapes[2 * stage + 1][0],
                      conv_kernel_size=conv_kernel_size,
                      conv_stride=conv_stride,
                      conv_padding=conv_padding,
                      pool_kernel_size=pool_kernel_size,
                      pool_stride=pool_stride,
                      pool_padding=pool_padding,
                      activation_func=activation_func)
            for stage in range(layer_depth)
        ]

        # Fully connected head: flatten, then project to the latent vector.
        if fully_connected:
            stages.append(Flatten())
            stages.append(Linear(in_features=shapes[-2][0],
                                 out_features=shapes[-1][0]))

        super().__init__(*stages)

In [5]:
class TConvLayer(Sequential):
    """One decoder stage: upsample to a fixed size, transposed convolution, activation."""

    def __init__(self,
                 upsample_size: Tuple[int,int],
                 tconv_in_channels: int,
                 tconv_out_channels: int,
                 tconv_kernel_size: Union[int, Tuple[int, int]] = (5,5),
                 tconv_stride: Union[int, Tuple[int, int]] = 1,
                 tconv_padding: Union[int, Tuple[int, int]] = 0,
                 activation_func: Callable[[], Module] = ReLU,
                 ) -> None:
        """Build the three sub-modules and register them in order.

        Args:
            upsample_size: Target (height, width) passed to ``Upsample``.
            tconv_in_channels: Input channels of the transposed convolution.
            tconv_out_channels: Output channels of the transposed convolution.
            tconv_kernel_size: Transposed-convolution kernel size.
            tconv_stride: Transposed-convolution stride.
            tconv_padding: Transposed-convolution padding.
            activation_func: Zero-argument factory (e.g. the ``ReLU`` class)
                that is called once to create the activation module.
        """
        upsampling = Upsample(size=upsample_size)
        transposed_conv = ConvTranspose2d(in_channels=tconv_in_channels,
                                          out_channels=tconv_out_channels,
                                          kernel_size=tconv_kernel_size,
                                          stride=tconv_stride,
                                          padding=tconv_padding)
        activation = activation_func()
        super().__init__(upsampling, transposed_conv, activation)

In [6]:
class Decoder(AutoEncoderBase):
    """Mirror of the encoder: optional Linear + Unflatten, then ``TConvLayer`` stages.

    The forward shapes are computed with ``init_x_shapes`` and reversed, so
    the decoder walks them from the latent end back to the image shape.
    """

    def __init__(self,
                 image_shape: Tuple[int,int,int],
                 layer_depth: int,
                 latent_dim: int,
                 kernel_size: Tuple[int,int] = (5,5),
                 stride: int = 1,
                 padding: int = 0,
                 init_channels: int = 32,
                 activation_func: Callable = ReLU,
                 out_act_func: Callable = Sigmoid,
                 fully_connected: bool = True,
                 ) -> None:
        """Compute the reversed per-stage shapes, then assemble the layer stack.

        Args:
            image_shape: Shape of the reconstructed image as
                (channels, height, width).
            layer_depth: Number of upsample + transposed-conv stages.
            latent_dim: Size of the latent input vector (used only when
                ``fully_connected`` is true).
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride.
            padding: Transposed-convolution padding.
            init_channels: Output channels of the first encoder convolution,
                from which the channel counts of all stages are derived.
            activation_func: Factory for the activation of every stage except
                the last.
            out_act_func: Factory for the activation of the last stage.
            fully_connected: When true, prepend a Linear layer that expands
                the latent vector and an Unflatten that restores the
                (channels, height, width) layout. When false, the decoder
                expects the encoder's last feature map as input directly.
        """
        # determine all shapes throughout feeding forward, then walk them backwards.
        self.init_x_shapes(image_shape=image_shape,
                           init_channels=init_channels,
                           layer_depth=layer_depth,
                           kernel_size=kernel_size,
                           stride=stride,
                           padding=padding,
                           fully_connected=fully_connected,
                           latent_dim=latent_dim)
        self._x_shapes.reverse()

        decoder_layers: List[Module] = []
        if fully_connected:
            # Reversed layout: [latent, flattened, pooled, conv, pooled, conv, ..., image]
            decoder_layers += [
                Linear(in_features=int(self.x_shapes[0][0]),
                       out_features=int(self.x_shapes[1][0])),
                Unflatten(1, tuple(self.x_shapes[2])),
            ]
            first_conv_idx = 3
        else:
            # Without the two fully connected entries the reversed list starts
            # at the last pooled shape, so the first conv shape sits at index 1.
            first_conv_idx = 1

        # construct each layer: upsample to the pre-pooling size, then undo the convolution.
        for i in range(layer_depth):
            idx = first_conv_idx + 2 * i
            is_last_stage = i + 1 == layer_depth
            decoder_layers.append(TConvLayer(
                upsample_size=self.x_shapes[idx][1:],
                tconv_in_channels=self.x_shapes[idx][0],
                tconv_out_channels=self.x_shapes[idx+1][0],
                tconv_kernel_size=kernel_size,
                tconv_stride=stride,
                tconv_padding=padding,
                activation_func=out_act_func if is_last_stage else activation_func
            ))
        # Sequential.
        super(Decoder, self).__init__(*decoder_layers)
        

In [10]:
class AutoEncoder(Sequential):
    """Convolutional auto-encoder: an ``Encoder`` followed by a ``Decoder``."""

    def __init__(self,
                 image_shape: Tuple[int, int, int],
                 layer_depth: int,
                 latent_dim: int,
                 init_channels: int = 32,
                 conv_kernel: Union[int, Tuple[int, int]] = (5, 5),
                 conv_stride: Union[int, Tuple[int, int]] = 1,
                 conv_padding: Union[int, Tuple[int, int]] = 0,
                 pool_kernel: Union[int, Tuple[int, int]] = (2, 2),
                 pool_stride: Union[int, Tuple[int, int]] = (2, 2),
                 pool_padding: Union[int, Tuple[int, int]] = 0,
                 conv_act_func: Callable[[], Module] = ReLU,
                 tconv_kernel: Union[int, Tuple[int, int]] = (5, 5),
                 tconv_stride: Union[int, Tuple[int, int]] = 1,
                 tconv_padding: Union[int, Tuple[int, int]] = 0,
                 tconv_act_func: Callable = ReLU,
                 out_act_func: Callable = Sigmoid,
                 fully_connected: bool = True
                 ):
        """Create both halves from the shared settings and chain them.

        Args:
            image_shape: Input (and reconstruction) shape as
                (channels, height, width).
            layer_depth: Number of stages in each half.
            latent_dim: Size of the latent vector between the halves.
            init_channels: Output channels of the first encoder convolution.
            conv_kernel: Encoder convolution kernel size.
            conv_stride: Encoder convolution stride.
            conv_padding: Encoder convolution padding.
            pool_kernel: Encoder pooling window size.
            pool_stride: Encoder pooling stride.
            pool_padding: Encoder pooling padding.
            conv_act_func: Factory for the encoder activations.
            tconv_kernel: Decoder transposed-convolution kernel size.
            tconv_stride: Decoder transposed-convolution stride.
            tconv_padding: Decoder transposed-convolution padding.
            tconv_act_func: Factory for the decoder's hidden activations.
            out_act_func: Factory for the decoder's final activation.
            fully_connected: Forwarded to both halves.
        """
        # The encoder is built first, then the decoder, before either is registered.
        encoder = Encoder(image_shape=image_shape,
                          layer_depth=layer_depth,
                          latent_dim=latent_dim,
                          init_channels=init_channels,
                          conv_kernel_size=conv_kernel,
                          conv_stride=conv_stride,
                          conv_padding=conv_padding,
                          pool_kernel_size=pool_kernel,
                          pool_stride=pool_stride,
                          pool_padding=pool_padding,
                          activation_func=conv_act_func,
                          fully_connected=fully_connected)
        decoder = Decoder(image_shape=image_shape,
                          layer_depth=layer_depth,
                          latent_dim=latent_dim,
                          kernel_size=tconv_kernel,
                          stride=tconv_stride,
                          padding=tconv_padding,
                          init_channels=init_channels,
                          activation_func=tconv_act_func,
                          out_act_func=out_act_func,
                          fully_connected=fully_connected)
        super().__init__(encoder, decoder)

In [11]:
# Input shape as (channels, height, width), the layout get_k_out_shape indexes into.
image_shape = (1,50,50)
# Number of conv + pool stages in the encoder (and of transposed-conv stages in the decoder).
layer_depth = 3
# Size of the latent vector produced by the encoder's final Linear layer.
latent_dim = 128

In [12]:
# Build the auto-encoder with the default kernel, stride, padding and activation settings.
ae_model = AutoEncoder(image_shape=image_shape,
                       layer_depth=layer_depth,
                       latent_dim=latent_dim)
# Print a per-layer table of output shapes and parameter counts for one input of image_shape.
# NOTE(review): torchsummary's summary() presumably returns None, so ae_stat may hold nothing useful — confirm.
ae_stat = summary(ae_model, image_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 46, 46]             832
              ReLU-2           [-1, 32, 46, 46]               0
         MaxPool2d-3           [-1, 32, 23, 23]               0
            Conv2d-4           [-1, 64, 19, 19]          51,264
              ReLU-5           [-1, 64, 19, 19]               0
         MaxPool2d-6             [-1, 64, 9, 9]               0
            Conv2d-7            [-1, 128, 5, 5]         204,928
              ReLU-8            [-1, 128, 5, 5]               0
         MaxPool2d-9            [-1, 128, 2, 2]               0
          Flatten-10                  [-1, 512]               0
           Linear-11                  [-1, 128]          65,664
           Linear-12                  [-1, 512]          66,048
        Unflatten-13            [-1, 128, 2, 2]               0
         Upsample-14            [-1, 12