In [None]:
import torch.nn.functional as F

class SimpleEncoder(nn.Module):
    """Convolutional encoder mapping an image batch to a latent vector in (0, 1).

    The layer sizes assume 32x32 spatial input: one size-preserving
    convolution followed by three stride-2 convolutions reduces it to 4x4,
    which is what ``flatten_size`` (64 * 4 * 4) expects.

    Args:
        input_channels: Number of channels of the input images.
        latent_dim: Size of the latent vector produced per sample.
    """

    def __init__(self, input_channels=3, latent_dim=256):
        super().__init__()

        # First layer: Cx32x32 -> 512x32x32 (spatial size preserved)
        self.conv1 = nn.Conv2d(input_channels, 512, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(512)

        # Downsampling 1: 512x32x32 -> 256x16x16
        self.conv2 = nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

        # Downsampling 2: 256x16x16 -> 256x8x8
        self.conv3 = nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

        # Downsampling 3: 256x8x8 -> 64x4x4
        self.conv4 = nn.Conv2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(64)

        # Flattened feature size: 64x4x4 -> 1024
        self.flatten_size = 64 * 4 * 4

        # Head: 1024 -> 512 -> latent_dim, squashed into (0, 1) by a sigmoid
        self.fc = nn.Sequential(
            nn.Linear(self.flatten_size, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),

            nn.Linear(512, latent_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (batch, input_channels, 32, 32).

        Returns:
            Tensor of shape (batch, latent_dim) with values in (0, 1).
        """
        # Each stage is conv -> ReLU -> BatchNorm (normalisation is applied
        # after the activation here).
        x = self.bn1(F.relu(self.conv1(x)))    # -> 512x32x32
        x = self.bn2(F.relu(self.conv2(x)))    # -> 256x16x16
        x = self.bn3(F.relu(self.conv3(x)))    # -> 256x8x8
        x = self.bn4(F.relu(self.conv4(x)))    # -> 64x4x4

        # flatten(1) instead of view(): view() raises when the activation is
        # not contiguous in NCHW order (e.g. channels_last inputs), whereas
        # flatten() copies only when it has to.
        x = x.flatten(1)                       # -> batch_size x 1024

        # Fully connected head with sigmoid activation
        encoded = self.fc(x)                   # -> batch_size x latent_dim, in (0, 1)

        return encoded

class SimpleDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A two-layer MLP expands the latent vector to a 64x4x4 feature map, three
    stride-2 transposed convolutions upsample it to 32x32, and a final
    convolution followed by a sigmoid yields values in [0, 1].

    Args:
        latent_dim: Size of the latent vector consumed per sample.
        output_channels: Number of channels of the produced images.
    """

    def __init__(self, latent_dim=256, output_channels=3):
        super().__init__()

        # Latent vector -> flat feature vector: latent_dim -> 512 -> 1024 (= 64*4*4)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 64 * 4 * 4),
        )

        # Every upsampling stage doubles the spatial resolution.
        upsample = dict(kernel_size=4, stride=2, padding=1)

        # Upsampling 1: 64x4x4 -> 256x8x8
        self.deconv1 = nn.ConvTranspose2d(64, 256, **upsample)
        self.bn1 = nn.BatchNorm2d(256)

        # Upsampling 2: 256x8x8 -> 256x16x16
        self.deconv2 = nn.ConvTranspose2d(256, 256, **upsample)
        self.bn2 = nn.BatchNorm2d(256)

        # Upsampling 3: 256x16x16 -> 512x32x32
        self.deconv3 = nn.ConvTranspose2d(256, 512, **upsample)
        self.bn3 = nn.BatchNorm2d(512)

        # Output projection: 512x32x32 -> output_channels x 32x32
        self.final_conv = nn.Conv2d(512, output_channels, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode a batch of latent vectors.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, output_channels, 32, 32) in [0, 1].
        """
        # Expand the latent vector and reshape it into a 64x4x4 feature map.
        features = self.fc(z)
        batch_size = features.size(0)
        features = features.view(batch_size, 64, 4, 4)

        # deconv -> BatchNorm -> ReLU: 4x4 -> 8x8 -> 16x16 -> 32x32
        stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
        )
        for deconv, norm in stages:
            features = F.relu(norm(deconv(features)))

        # Sigmoid keeps the reconstruction in the range [0, 1].
        logits = self.final_conv(features)
        return torch.sigmoid(logits)

class To_Uniform(nn.Module):
    """Autoencoder that pairs a ``SimpleEncoder`` with a ``SimpleDecoder``.

    Args:
        input_channels: Channel count forwarded to the encoder.
        latent_dim: Latent size shared by the encoder and the decoder.
        output_channels: Channel count forwarded to the decoder.
    """

    def __init__(self, input_channels=3, latent_dim=256, output_channels=3):
        super().__init__()

        self.latent_dim = latent_dim
        self.encoder = SimpleEncoder(input_channels, latent_dim)
        self.decoder = SimpleDecoder(latent_dim, output_channels)

    def forward(self, x):
        """Run the full encode/decode pass.

        Returns:
            A ``(encoded, reconstructed)`` tuple: the encoder output for ``x``
            and the decoder output for that encoding.
        """
        latent = self.encoder(x)
        return latent, self.decoder(latent)

    def encode(self, x):
        """Return only the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return only the decoder output for ``z``."""
        return self.decoder(z)

    def criterion(self, z_pred, z_true, x_pred, x_true):
        """Combine the reconstruction MSE and the latent MSE into one loss.

        Computes ``mse(x_pred, x_true) - mse(z_pred, z_true)``.

        NOTE(review): the latent term enters with a negative sign, so
        minimising this loss increases the distance between ``z_pred`` and
        ``z_true`` — confirm this is the intended objective.
        """
        latent_term = F.mse_loss(z_pred, z_true)
        reconstruction_term = F.mse_loss(x_pred, x_true)
        return reconstruction_term - latent_term