In [1]:
from torch import nn # type: ignore
import torch # type: ignore


In [2]:
class Encoder(nn.Module):
    """Encoder half of a model: runs a fixed sequence of layers in order.

    Args:
        layers: Iterable of ``nn.Module`` instances; they are wrapped in an
            ``nn.Sequential`` and applied in the order given.
    """

    def __init__(self, layers):
        super().__init__()
        # Stored under the attribute name ``encoder``, so parameter keys read
        # ``encoder.<index>.<param>``; renaming it would change state_dict keys.
        stack = nn.Sequential(*layers)
        self.encoder = stack

    def forward(self, x):
        """Feed ``x`` through the wrapped layers and return the final output."""
        encoded = self.encoder(x)
        return encoded

In [3]:
class Decoder(nn.Module):
    """Decoder half of a model: runs a fixed sequence of layers in order.

    Args:
        layers: Iterable of ``nn.Module`` instances; they are wrapped in an
            ``nn.Sequential`` and applied in the order given.
    """

    def __init__(self, layers):
        super().__init__()
        # Stored under the attribute name ``decoder``, so parameter keys read
        # ``decoder.<index>.<param>``; renaming it would change state_dict keys.
        stack = nn.Sequential(*layers)
        self.decoder = stack

    def forward(self, x):
        """Feed ``x`` through the wrapped layers and return the final output."""
        decoded = self.decoder(x)
        return decoded

In [19]:
class SimpleAutoencoderCNN3D(nn.Module):
    """Fully-convolutional 3D autoencoder that keeps depth/height/width unchanged.

    Every convolution uses ``kernel_size=3, stride=1, padding=1``, so only the
    channel count changes from stage to stage. With the default arguments:

        encoder: 1 -> 16 -> 32 -> 64 -> 128 -> 256 -> 1   (Conv3d)
        decoder: 1 -> 256 -> 128 -> 64 -> 32 -> 16 -> 1   (ConvTranspose3d)

    Each stage is ``Conv -> BatchNorm3d -> ReLU``, except the last stage of each
    half, which is ``Conv -> BatchNorm3d`` with no activation.

    Input is laid out as ``nn.Conv3d`` expects: ``(N, in_channels, D, H, W)``.
    The output has the same shape as the input.

    Args:
        in_channels: Channels of the input volume and of the reconstruction.
        hidden_channels: Widths of the intermediate encoder stages; the decoder
            uses them in reverse order.
        latent_channels: Channels of the bottleneck between encoder and decoder.

    NOTE(review): there is no spatial downsampling, so the bottleneck is
    ``latent_channels x D x H x W`` -- the same spatial size as the input.
    NOTE(review): the decoder's final stage ends in BatchNorm3d, so in training
    mode the reconstruction is batch-normalised before its learned affine;
    confirm this is intended for a reconstruction output.
    """

    def __init__(self, in_channels=1, hidden_channels=(16, 32, 64, 128, 256), latent_channels=1):
        super().__init__()
        hidden_channels = list(hidden_channels)

        encoder_layers = self._conv_stack(
            nn.Conv3d, in_channels, hidden_channels + [latent_channels]
        )
        decoder_layers = self._conv_stack(
            nn.ConvTranspose3d, latent_channels, hidden_channels[::-1] + [in_channels]
        )

        self.encoder = Encoder(encoder_layers)
        self.decoder = Decoder(decoder_layers)

    @staticmethod
    def _conv_stack(conv_cls, in_channels, out_channels_list):
        """Build ``[conv, BatchNorm3d, ReLU, ..., conv, BatchNorm3d]`` as a list.

        Args:
            conv_cls: ``nn.Conv3d`` or ``nn.ConvTranspose3d``.
            in_channels: Channels entering the first stage.
            out_channels_list: Output channels of each stage, in order.

        Returns:
            A flat list of layers, ready to be wrapped in ``nn.Sequential``.
        """
        layers = []
        last_index = len(out_channels_list) - 1
        for index, out_channels in enumerate(out_channels_list):
            # kernel 3 / stride 1 / padding 1 preserves D, H and W.
            layers.append(conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm3d(out_channels))
            # The final stage gets no ReLU. It is identified by position, not
            # by channel count, so a 1-channel intermediate stage keeps its ReLU.
            if index != last_index:
                layers.append(nn.ReLU())
            in_channels = out_channels
        return layers

    def forward(self, x):
        """Encode ``x`` to the bottleneck, then decode it back to input shape."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [20]:
# Seed torch's global RNG so this dummy input -- and the weight init of any
# model built after it -- is reproducible on Restart & Run All.
torch.manual_seed(0)

# Dummy batch for a shape smoke test, in the layout Conv3d expects:
# (batch=8, channels=1, depth=6, height=100, width=100), values in [0, 1).
x = torch.rand(8, 1, 6, 100, 100)

In [21]:
# Freshly initialised autoencoder. nn.Module instances start in training mode,
# so BatchNorm layers use batch statistics and update their running stats on
# every forward pass until model.eval() is called.
model = SimpleAutoencoderCNN3D()

In [22]:
# Shape smoke test only: no_grad stops autograd from keeping every intermediate
# activation (up to 256 channels x 6 x 100 x 100 per sample) alive for a
# backward pass that never happens.
# NOTE: the model is still in training mode, so BatchNorm running statistics
# are updated by this random batch; call model.eval() first if that matters.
with torch.no_grad():
    out = model(x)

In [23]:
# An autoencoder must reconstruct a tensor of exactly the input's shape.
assert out.shape == x.shape, f"shape mismatch: {tuple(out.shape)} vs {tuple(x.shape)}"
out.shape

torch.Size([8, 1, 6, 100, 100])

In [56]:
# Experiment: collapse the depth axis to 1 with a single Conv3d whose kernel
# spans the full depth. padding=(0, 1, 1) keeps height and width unchanged, so
# (N, 1, D, H, W) -> (N, 1, 1, H, W). The singleton depth axis is kept; use
# t.squeeze(2) to drop it.
x = torch.rand(8, 1, 6, 100, 100)

# Kernel depth is taken from the data (x.shape[2]) instead of hard-coding 6.
final_layers = nn.Conv3d(1, 1, kernel_size=(x.shape[2], 3, 3), stride=1, padding=(0, 1, 1))
t = final_layers(x)

assert t.shape == (x.shape[0], 1, 1, x.shape[3], x.shape[4])
t.shape

torch.Size([8, 1, 1, 100, 100])
