In [None]:
import torch.nn.functional as F

class SimpleEncoder(nn.Module):
    """Convolutional encoder: square image -> latent vector with values in [0, 1].

    A 3x3 stem convolution keeps the spatial size, then four 4x4 stride-2
    convolutions each halve it (``H -> H // 2``), ending in a
    ``512 x (input_size // 16) x (input_size // 16)`` feature map. A
    two-layer MLP with a final sigmoid projects the flattened map to
    ``latent_dim`` values.

    Args:
        input_channels: Number of channels in the input image.
        latent_dim: Length of the output latent vector.
        input_size: Height and width of the square input image. Defaults to
            64, the size this architecture was originally hard-coded for;
            any value >= 16 is supported.

    Raises:
        ValueError: If ``input_size`` is below 16, which would shrink the
            final feature map to zero pixels.
    """

    def __init__(self, input_channels=3, latent_dim=256, input_size=64):
        super().__init__()

        if input_size < 16:
            raise ValueError(f"input_size must be at least 16, got {input_size}")

        # Stem: C x H x H -> 32 x H x H (3x3 conv with padding 1 keeps the size)
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        # Downsampling 1: 32 x H x H -> 64 x H/2 x H/2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Downsampling 2: -> 128 x H/4 x H/4
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Downsampling 3: -> 256 x H/8 x H/8
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Downsampling 4: -> 512 x H/16 x H/16
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(512)

        # A k=4, s=2, p=1 conv maps a side of length H to floor(H / 2), so after
        # four of them the side is input_size // 16 (4 for the default 64,
        # giving 512 * 4 * 4 = 8192 features).
        final_side = input_size // 16
        self.flatten_size = 512 * final_side * final_side

        # Last layer: flatten_size -> 2048 -> latent_dim, squashed by a sigmoid.
        # NOTE: BatchNorm1d here means a training-mode forward pass needs a
        # batch of at least 2 samples.
        self.fc = nn.Sequential(
            nn.Linear(self.flatten_size, 2048),
            nn.SiLU(),
            nn.BatchNorm1d(2048),
            nn.Linear(2048, latent_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(batch, input_channels, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)`` with values in [0, 1].
        """
        # NOTE(review): BatchNorm is applied *after* SiLU here, whereas
        # SimpleDecoder uses deconv -> BN -> SiLU. Left unchanged to preserve
        # behaviour of existing models; confirm the ordering is intentional.
        x = self.bn1(F.silu(self.conv1(x)))  # -> 32 x H x H
        x = self.bn2(F.silu(self.conv2(x)))  # -> 64 x H/2 x H/2
        x = self.bn3(F.silu(self.conv3(x)))  # -> 128 x H/4 x H/4
        x = self.bn4(F.silu(self.conv4(x)))  # -> 256 x H/8 x H/8
        x = self.bn5(F.silu(self.conv5(x)))  # -> 512 x H/16 x H/16

        x = x.view(x.size(0), -1)  # -> (batch, flatten_size)

        encoded = self.fc(x)  # -> (batch, latent_dim)

        return encoded

class SimpleDecoder(nn.Module):
    """Decoder mapping a latent vector back to a 64x64 image in [0, 1].

    A two-layer MLP expands the latent code to 512*4*4 features, which are
    reshaped to a 512x4x4 map. Four stride-2 transposed convolutions double
    the spatial size at each step (4 -> 8 -> 16 -> 32 -> 64) while halving the
    channel count, and a final 3x3 convolution plus sigmoid yields the image.

    Args:
        latent_dim: Length of the incoming latent vector.
        output_channels: Number of channels in the reconstructed image.
    """

    def __init__(self, latent_dim=256, output_channels=3):
        super().__init__()

        # Latent -> 2048 -> 512*4*4 flat features (reshaped in forward).
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 2048),
            nn.SiLU(),
            nn.Linear(2048, 512 * 4 * 4),
        )

        def upsample(in_ch, out_ch):
            # k=4, s=2, p=1 transposed conv exactly doubles height and width.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

        # Attribute names and creation order are load-bearing: they fix the
        # state_dict keys and the order of random weight initialisation.
        self.deconv1, self.bn1 = upsample(512, 256), nn.BatchNorm2d(256)  # 4x4 -> 8x8
        self.deconv2, self.bn2 = upsample(256, 128), nn.BatchNorm2d(128)  # 8x8 -> 16x16
        self.deconv3, self.bn3 = upsample(128, 64), nn.BatchNorm2d(64)    # 16x16 -> 32x32
        self.deconv4, self.bn4 = upsample(64, 32), nn.BatchNorm2d(32)     # 32x32 -> 64x64

        # 3x3 conv with padding 1 keeps 64x64 and maps 32 channels to the
        # output channels; the sigmoid is applied in forward.
        self.final_conv = nn.Conv2d(32, output_channels, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode latent vectors of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, output_channels, 64, 64)`` with values
            in [0, 1].
        """
        flat = self.fc(z)                          # latent_dim -> 8192
        h = flat.view(flat.size(0), 512, 4, 4)     # -> 512 x 4 x 4

        # Each stage: transposed conv (doubles H and W) -> BatchNorm -> SiLU.
        stages = (
            (self.deconv1, self.bn1),  # -> 256 x 8 x 8
            (self.deconv2, self.bn2),  # -> 128 x 16 x 16
            (self.deconv3, self.bn3),  # -> 64 x 32 x 32
            (self.deconv4, self.bn4),  # -> 32 x 64 x 64
        )
        for up, norm in stages:
            h = F.silu(norm(up(h)))

        # Sigmoid squashes pixel values into [0, 1].
        return torch.sigmoid(self.final_conv(h))

class To_Uniform(nn.Module):
    """Autoencoder pairing a SimpleEncoder with a SimpleDecoder.

    Both the latent code and the reconstruction are sigmoid outputs, so their
    values lie in [0, 1].

    Args:
        input_channels: Channels of the input images.
        latent_dim: Size of the latent code shared by encoder and decoder.
        output_channels: Channels of the reconstructed images.
    """

    def __init__(self, input_channels=3, latent_dim=256, output_channels=3):
        super().__init__()

        self.encoder = SimpleEncoder(input_channels, latent_dim)
        self.decoder = SimpleDecoder(latent_dim, output_channels)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Encode ``x`` and decode the resulting latent code.

        Returns:
            ``(encoded, reconstructed)``: the ``(batch, latent_dim)`` latent
            code and the decoder's reconstruction, both in [0, 1].
        """
        # Encode: image -> (batch, latent_dim), scaled to [0, 1] by a sigmoid
        encoded = self.encoder(x)

        # Decode: latent -> image, scaled to [0, 1] by a sigmoid
        reconstructed = self.decoder(encoded)

        return encoded, reconstructed

    def encode(self, x):
        """Encode only."""
        return self.encoder(x)

    def decode(self, z):
        """Decode only."""
        return self.decoder(z)

    def criterion(self, z_pred, z_true, x_pred, x_true,
                  latent_weight=-1.0, recon_weight=1.0):
        """Return a weighted sum of the latent and reconstruction MSE losses.

        Computes
        ``latent_weight * MSE(z_pred, z_true) + recon_weight * MSE(x_pred, x_true)``.
        The defaults reproduce the original hard-coded weights (-1.0 and 1.0).

        NOTE(review): with the default negative ``latent_weight``, minimising
        this loss *maximises* the latent MSE, which pushes ``z_pred`` away from
        ``z_true``. Confirm this is intended. Pass ``latent_weight=1.0`` to pull
        the latent code towards ``z_true`` instead.

        Args:
            z_pred, z_true: Predicted and target latent codes.
            x_pred, x_true: Reconstructed and target images.
            latent_weight: Multiplier applied to the latent MSE term.
            recon_weight: Multiplier applied to the reconstruction MSE term.

        Returns:
            Scalar loss tensor.
        """
        latent_loss = F.mse_loss(z_pred, z_true)
        recon_loss = F.mse_loss(x_pred, x_true)
        return latent_weight * latent_loss + recon_weight * recon_loss