In [200]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torchvision.models import (alexnet,
                                AlexNet_Weights,
                                resnet18,
                                ResNet18_Weights,
                                ResNet)


# AlexNet with torchvision's default pretrained weights (``AlexNet_Weights.DEFAULT``);
# its convolutional ``features`` stack is kept under the name ``encoder``.
model = alexnet(weights=AlexNet_Weights.DEFAULT)
encoder = model.features



In [243]:
class NeuralImageCompressor(nn.Module):
    """Encoder -> bounded latent -> simulated quantization noise -> decoder.

    The encoder output is passed through ``normalising_activation`` (a sigmoid
    by default) and perturbed with additive noise scaled by ``0.5 ** B``, a
    differentiable stand-in for quantizing the latent to ``B`` bits. The noisy
    latent is then fed to ``decoder``.

    Args:
        encoder: module mapping an image batch to a latent tensor.
        decoder: module mapping the (noisy) latent back to image space.
        normalising_activation: activation applied to the encoder output.
            ``None`` (the default) creates a fresh ``nn.Sigmoid()`` for each
            instance, so instances never share one default module object.
        B: number of bits the noise emulates; the noise is scaled by ``0.5 ** B``.
    """

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 normalising_activation: Optional[nn.Module] = None,
                 B: int = 1):
        super().__init__()
        if normalising_activation is None:
            normalising_activation = nn.Sigmoid()
        self.encoder = encoder
        self.decoder = decoder
        self.normalising_activation = normalising_activation
        self.B = B

    def _get_quantization_error(self,
                                shape: Tuple[int, ...],
                                device: Optional[torch.device] = None,
                                dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Sample additive noise of ``shape``, scaled by ``0.5 ** self.B``.

        Before scaling, samples follow a normal distribution with mean -0.5
        and standard deviation 0.5.
        NOTE(review): the -0.5 mean biases the latent downward; confirm this
        Gaussian (rather than, e.g., uniform) noise model is intended.

        Args:
            shape: shape of the noise tensor.
            device: device to allocate on; ``None`` uses torch's default.
            dtype: dtype to allocate with; ``None`` uses torch's default.
                ``forward`` passes the latent's device and dtype so the noise
                can be added to GPU or reduced-precision latents.
        """
        noise = torch.randn(shape, device=device, dtype=dtype) * 0.5 - 0.5
        return 0.5 ** self.B * noise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, normalise, add quantization noise, and decode."""
        latent = self.normalising_activation(self.encoder(x))
        # Out-of-place add: sigmoid saves its output for the backward pass, so
        # modifying that tensor in place would make ``loss.backward()`` raise.
        latent = latent + self._get_quantization_error(
            latent.shape, device=latent.device, dtype=latent.dtype)
        return self.decoder(latent)

In [175]:
class AutoencoderAlexNet(NeuralImageCompressor):
    """Compressor with AlexNet's pretrained conv stack as encoder and a mirrored decoder.

    The decoder is derived from the encoder by walking its modules in reverse:
    each ``Conv2d`` becomes a ``ConvTranspose2d`` with swapped channel counts,
    each ``ReLU`` is kept, and each ``MaxPool2d`` becomes a nearest-neighbour
    ``Upsample`` by the pool's stride.

    NOTE(review): the mirror is not an exact inverse. Pooling/striding floors
    spatial sizes, so a 256x256 input reconstructs to 227x227. A
    reconstruction loss needs matching sizes (crop/resize, or adjust the decoder).

    Args:
        B: number of bits the latent noise emulates.
        normalising_activation: activation on the encoder output; ``None``
            (the default) creates a fresh ``nn.Sigmoid()`` per instance.
    """

    def __init__(self, B: int = 1, normalising_activation: Optional[nn.Module] = None):
        if normalising_activation is None:
            normalising_activation = nn.Sigmoid()
        encoder = alexnet(weights=AlexNet_Weights.DEFAULT).features
        decoder = self._get_decoder(encoder)
        super().__init__(encoder, decoder, normalising_activation, B)

    @staticmethod
    def _get_decoder(encoder: nn.Sequential) -> nn.Sequential:
        """Build a decoder that mirrors ``encoder`` layer by layer, in reverse order.

        Raises:
            ValueError: if ``encoder`` contains a module other than ``Conv2d``,
                ``ReLU`` or ``MaxPool2d``.
        """
        decoder_modules = []

        for module in reversed(encoder):
            if isinstance(module, nn.Conv2d):
                decoder_modules.append(nn.ConvTranspose2d(
                    in_channels=module.out_channels,
                    out_channels=module.in_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding
                ))
            elif isinstance(module, nn.ReLU):
                decoder_modules.append(nn.ReLU(inplace=True))
            elif isinstance(module, nn.MaxPool2d):
                # Undo the pool's downsampling by its own stride (2 for AlexNet).
                # nn.MaxUnpool2d would be a closer inverse, but it needs the
                # MaxPool indices, which are not saved here.
                decoder_modules.append(nn.Upsample(
                    scale_factor=module.stride,
                    mode='nearest'
                ))
            else:
                raise ValueError(f"unexpected module {module}")

        return nn.Sequential(*decoder_modules)

In [287]:
alexnet_autoencoder = AutoencoderAlexNet()

# Smoke test: one random 3x256x256 image, batched to size 1, and the output shape.
test_image = torch.rand((3, 256, 256))
alexnet_autoencoder(test_image[None]).shape

torch.Size([1, 3, 227, 227])

In [288]:
class AutoEncoderResNet(NeuralImageCompressor):
    """Compressor whose encoder is a ResNet's stem plus its four residual stages.

    Args:
        resnet: a torchvision ``ResNet``; its ``conv1`` ... ``layer4`` modules
            are reused (not copied) as the encoder.
        decoder: module mapping the latent back to image space.
        normalising_activation: activation on the encoder output; ``None``
            (the default) creates a fresh ``nn.Sigmoid()`` per instance.
        B: number of bits the latent noise emulates.
    """

    def __init__(self, resnet: ResNet, decoder: nn.Module,
                 normalising_activation: Optional[nn.Module] = None, B: int = 1):
        if normalising_activation is None:
            normalising_activation = nn.Sigmoid()
        encoder = self._get_resnet_encoder(resnet)
        super().__init__(encoder, decoder, normalising_activation, B)

    @staticmethod
    def _get_resnet_encoder(resnet: ResNet) -> nn.Sequential:
        """Chain ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``-``layer4``.

        Modules after ``layer4`` (torchvision's ``avgpool``/``fc`` head) are not
        included, so the encoder outputs the ``layer4`` feature map.
        """
        return nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4)

In [307]:
def decoder_8x_upsample_constructor(in_chan_num, out_chan_nums=(512, 256, 128, 64, 3)):
    """Build a decoder of (2x nearest upsample -> 3x3 conv -> ReLU) stages.

    Each stage doubles height and width, so the total upsampling factor is
    ``2 ** len(out_chan_nums)``. With the default five stages that is 32x (a
    2x2 latent becomes a 64x64 image), despite the "8x" in the name.

    Args:
        in_chan_num: channel count of the decoder input (the latent).
        out_chan_nums: output channel count of each stage, in order. The last
            entry is the number of channels of the produced image.

    Returns:
        ``nn.Sequential`` whose items are the per-stage ``nn.Sequential`` blocks.

    NOTE(review): the final stage also ends in ReLU, so outputs are
    non-negative but unbounded above; confirm this is intended for images.
    """
    stages = []
    for out_chan_num in out_chan_nums:
        stages.append(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels=in_chan_num, out_channels=out_chan_num,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        ))
        # The next stage consumes this stage's output channels.
        in_chan_num = out_chan_num
    return nn.Sequential(*stages)

In [308]:
def resnet_encoder_constructor(resnet):
    """Return a ResNet's stem and its four residual stages as one ``nn.Sequential``.

    The submodules are shared with ``resnet`` (not copied), and nothing after
    ``layer4`` is included.
    """
    stage_names = ("conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4")
    return nn.Sequential(*(getattr(resnet, name) for name in stage_names))

In [309]:
resnet18_backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_encoder = resnet_encoder_constructor(resnet18_backbone)

# Channel count of the encoder's output (layer4). The ResNet ``fc`` head consumes
# exactly that width, so read it from the model instead of hard-coding 512.
decoder_in_channels = resnet18_backbone.fc.in_features

decoder_8x_upsample = decoder_8x_upsample_constructor(decoder_in_channels)

In [310]:
resnet_autoencoder = NeuralImageCompressor(
    encoder=resnet_encoder,
    decoder=decoder_8x_upsample,
    normalising_activation=nn.Sigmoid(),
    B=1,
)

In [314]:
# Smoke test: the ResNet autoencoder's output shape for one random 3x64x64 image.
test_image = torch.rand((3, 64, 64))
resnet_autoencoder(test_image[None]).shape

torch.Size([1, 3, 64, 64])

In [None]:
# Invariant: the reconstruction has exactly the input batch's shape.
# no_grad, because only the shape is inspected and no autograd graph is needed.
test_image = torch.rand(3, 64, 64)
with torch.no_grad():
    reconstruction = resnet_autoencoder(test_image.unsqueeze(0))
assert reconstruction.shape == (1, *test_image.shape), reconstruction.shape
reconstruction.shape

In [227]:
# def get_output_channels_of_resnet_encoder(resnet_encoder):
#     num_out_channels = None
#     for module in reversed(resnet_encoder[-1][-1]):
#         if isinstance(module, nn.Conv2d):
#             last_conv = module
#             num_out_channels = module.out_channels
#     return num_out_channels

In [233]:
# list(resnet_encoder[-1][-1].children())

[Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)]

In [201]:
# Upsampling with transposed convolution

import torch
import torch.nn as nn

input_size = (3, 16, 16)  # Input tensor size (channels, height, width)
output_channels = 3  # Number of output channels
kernel_size = 4  # Kernel size in the transpose convolution
stride = 2  # Stride in the transpose convolution
padding = 1  # Padding in the transpose convolution

# Create a random input tensor (batch of 1)
input_tensor = torch.randn(1, *input_size)

# Define the transpose convolution layer
trans_conv = nn.ConvTranspose2d(
    in_channels=input_size[0],
    out_channels=output_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding
)

# Apply the transpose convolution to the input tensor
output_tensor = trans_conv(input_tensor)

# Check the size of the output tensor
print("Input size:", input_tensor.size())
print("Output size:", output_tensor.size())

# ConvTranspose2d output size per spatial dim (dilation=1, output_padding=0):
#   out = (in - 1) * stride - 2 * padding + kernel_size
# With kernel_size=4, stride=2, padding=1 this equals exactly 2 * in.
expected_hw = tuple((n - 1) * stride - 2 * padding + kernel_size for n in input_size[1:])
assert output_tensor.shape == (1, output_channels, *expected_hw), output_tensor.shape

Input size: torch.Size([1, 3, 16, 16])
Output size: torch.Size([1, 3, 32, 32])
