In [None]:
class To_Uniform(nn.Module):
    """Autoencoder with a bounded latent code and a noisy decoder output.

    Encoder: three ``Linear -> encoder_act -> BatchNorm1d`` stages, then
    ``Linear -> final_encoder_act`` down to ``latent_dim``. With the default
    ``nn.Sigmoid`` the latent code ``z`` lies in (0, 1).
    Decoder: three ``Linear -> decoder_act`` stages (no batch norm), then
    ``Linear -> final_decoder_act`` up to ``output_dim``.

    Gaussian noise with std ``noise_std`` is added to every decoder output,
    in both training and eval mode.

    Inputs are presumably shaped ``(batch, input_dim)`` — TODO confirm. Note
    that ``BatchNorm1d`` in training mode rejects a batch of size 1.

    Args:
        input_dim: Number of features per input sample.
        latent_dim: Size of the latent code ``z``.
        output_dim: Number of features per reconstructed sample.
        encoder_hidden: Width of the encoder hidden layers.
        decoder_hidden: Width of the decoder hidden layers.
        encoder_act: Activation *class* (instantiated per layer) for the
            encoder hidden layers.
        decoder_act: Activation class for the decoder hidden layers.
        final_encoder_act: Activation class applied to the latent code.
        final_decoder_act: Activation class applied to the decoder output,
            before noise is added.
        noise_std: Standard deviation of the Gaussian noise added to decoder
            outputs. Defaults to 0.01 (the previously hard-coded value).
            Use 0.0 for noise-free outputs.
    """

    def __init__(self,
                 input_dim=2,
                 latent_dim=1,
                 output_dim=2,
                 encoder_hidden=64,
                 decoder_hidden=64,
                 encoder_act=nn.ELU,
                 decoder_act=nn.ELU,
                 final_encoder_act=nn.Sigmoid,
                 final_decoder_act=nn.Sigmoid,
                 noise_std=0.01):
        super().__init__()
        self.noise_std = noise_std

        # Encoder: input_dim -> latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, encoder_hidden), encoder_act(), nn.BatchNorm1d(encoder_hidden),
            nn.Linear(encoder_hidden, encoder_hidden), encoder_act(), nn.BatchNorm1d(encoder_hidden),
            nn.Linear(encoder_hidden, encoder_hidden), encoder_act(), nn.BatchNorm1d(encoder_hidden),
            nn.Linear(encoder_hidden, latent_dim), final_encoder_act()
        )

        # Decoder: latent_dim -> output_dim
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, decoder_hidden), decoder_act(),
            nn.Linear(decoder_hidden, decoder_hidden), decoder_act(),
            nn.Linear(decoder_hidden, decoder_hidden), decoder_act(),
            nn.Linear(decoder_hidden, output_dim), final_decoder_act()
        )

    def _add_noise(self, t):
        """Return ``t`` plus Gaussian noise with std ``self.noise_std``."""
        # randn_like matches t's shape, dtype and device. The noise therefore
        # fits the decoder output even when output_dim != input_dim, and the
        # model works on GPU.
        return t + torch.randn_like(t) * self.noise_std

    def forward(self, x):
        """Encode ``x`` and decode it back.

        Returns:
            ``(z, x_hat)``: ``z`` is the latent code from the encoder, and
            ``x_hat`` is the noisy reconstruction whose last dimension is
            ``output_dim``.
        """
        z = self.encoder(x)
        return z, self.decode(z)

    def decode(self, z):
        """Decode a latent code ``z`` and add output noise."""
        return self._add_noise(self.decoder(z))

    def criterion(self, z_pred, z_true, x_pred, x_true):
        """Combine a latent MSE term and a reconstruction MSE term.

        Returns ``-MSE(z_pred, z_true) + MSE(x_pred, x_true)`` (mean
        reduction) as a scalar tensor.
        """
        latent_mse = nn.functional.mse_loss(z_pred, z_true)
        recon_mse = nn.functional.mse_loss(x_pred, x_true)
        # NOTE(review): the latent term has weight -1, so minimizing this
        # loss pushes z_pred *away* from z_true. This is kept as originally
        # written — confirm the sign is intended.
        return -1.0 * latent_mse + 1.0 * recon_mse