In [3]:
import torch
from torch import nn

In [2]:
torch.__version__

'1.13.0'

In [31]:
from collections import OrderedDict

def init_weights(module):
    if type(module) == nn.Conv2d:
        nn.init.xavier_normal(module.weight)

def crop(image: torch.Tensor, new_shape: int):
    if image.shape[-1] == new_shape:
        return image

    pad = image.shape[-1] - new_shape

    if pad % 2 == 0:
        pad = pad // 2
        return image[..., pad:-pad, pad:-pad]

    return image[..., :-pad, :-pad]



class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, padding: str = "valid", sample: str = None):
        super().__init__()

        module_list = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), padding=padding),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), padding=padding),
            nn.ReLU(),
        ]

        if sample == "down":
            module_list.insert(0, nn.MaxPool2d(kernel_size=(2, 2), stride=2))
        elif sample == "up":
            module_list.append(nn.ConvTranspose2d(in_channels=out_channels, out_channels=(out_channels // 2), kernel_size=(2, 2), stride=2))

        self.block = nn.Sequential(*module_list)

    def forward(self, x):
        return self.block.forward(x)




class ImageEncoder(nn.Module):

    def __init__(self):
        super().__init__()

        self.compression_block = nn.ModuleList([
            ConvBlock(3, 64),
            ConvBlock(64, 128, sample="down"),
            ConvBlock(128, 256, sample="down"),
            ConvBlock(256, 512, sample="down"),
            ConvBlock(512, 1024, sample="down")
        ])

        self.compression_block.apply(init_weights)

    def forward(self, image):
        outputs = []
        for block in self.compression_block:
            x = block.forward(x)
            outputs.append(x.clone())

        return outputs[:-1], outputs[-1]

class ImageDecoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.decoder_sequence = nn.Sequential()

    def forward(self):
        pass


class CompressionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding='valid'):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )



SyntaxError: invalid syntax (2419382481.py, line 49)

In [21]:
model = ImageEncoder()

In [25]:
test = torch.randn(1, 3, 1200, 800)

In [26]:
test.shape

torch.Size([1, 3, 1200, 800])

In [27]:
model.forward(test).shape

torch.Size([1, 36, 198, 131])