In [51]:
from typing import Tuple


import torch
import torch.nn as nn
from torchvision.models import alexnet, AlexNet_Weights


model = alexnet(weights=AlexNet_Weights.DEFAULT)
encoder  = model.features



In [61]:
class UNet(nn.Module):
    def __init__(self, B: int = 1):
        super(UNet, self).__init__()
        self.encoder = alexnet(weights=AlexNet_Weights.DEFAULT).features
        self.decoder = self._get_decoder(self.encoder)
        self.B = B
    
    @staticmethod
    def _get_decoder(encoder: nn.Sequential):
        decoder_modules = []

        for module in reversed(encoder):
            if isinstance(module, nn.Conv2d):
                trans_conv = nn.ConvTranspose2d(
                    in_channels=module.out_channels,
                    out_channels=module.in_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding
                )
                decoder_modules.append(trans_conv)
            elif isinstance(module, nn.ReLU):
                decoder_modules.append(nn.ReLU(inplace=True))
            elif isinstance(module, nn.MaxPool2d):
                trans_conv = nn.MaxUnpool2d(
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding
                )
                decoder_modules.append(trans_conv)
            else:
                raise ValueError(f"unexpected module {module}")
            
            decoder = nn.Sequential(*decoder_modules)
            return decoder
        
    def _get_quantization_error(self, shape: Tuple[int, ...]):
        mean = torch.full(shape, -0.5)
        std = torch.full(shape, 0.5)
        return 0.5**self.B * torch.normal(mean = mean, std = std)
    
    def forward(self, x):
        out = self.encoder(x)
        out += self._get_quantization_error(out.shape)
        out = self.decoder(out)
        return out

In [62]:
unet = UNet()

In [66]:
test_image = torch.rand(3, 256, 256)

In [67]:
unet(test_image.unsqueeze(0))

TypeError: MaxUnpool2d.forward() missing 1 required positional argument: 'indices'