In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import jacobian

In [2]:


class VAE(nn.Module):
    """Variational autoencoder with one hidden layer on each side.

    The encoder maps an input to the mean and log-variance of a diagonal
    Gaussian over the latent space; the decoder maps a latent sample back to
    the input space through a purely linear output layer (no sigmoid), so
    reconstructions are unbounded real values.

    Args:
        input_dim: Number of features in each input row.
        hidden_dim: Width of the hidden layer used by encoder and decoder.
        latent_dim: Number of latent dimensions.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Encoder: one shared hidden layer feeding two parallel linear heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # posterior mean head
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance head

        # Decoder: latent -> hidden -> input space.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        sigma = (0.5 * logvar).exp()       # log-variance -> standard deviation
        noise = torch.randn_like(sigma)    # fresh standard-normal noise on every call
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` back to input space; the output layer is linear."""
        hidden = torch.relu(self.fc3(z))
        return self.fc4(hidden)

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it; return ``(x_recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

def loss_function(recon_x, x, mu, logvar, kl_weight=0.1):
    """Computes the VAE loss: summed squared error plus a weighted KL term.

    Args:
        recon_x: Reconstruction produced by the decoder.
        x: Target tensor with the same shape as ``recon_x``.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.
        kl_weight: Multiplier applied to the KL term. Defaults to 0.1, the
            value that was previously hard-coded; pass 0 to train as a plain
            autoencoder or 1 for the unweighted ELBO-style sum.

    Returns:
        Scalar tensor ``MSE + kl_weight * KLD``. Both terms are summed (not
        averaged) over every element, so callers normalise by batch size.
    """
    # Reconstruction loss: squared error summed over all elements.
    MSE = F.mse_loss(recon_x, x, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(logvar)) and the N(0, I) prior.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) * kl_weight

    return MSE + KLD


In [3]:
# Experiment configuration.
seed = 1   # seed for NumPy's global RNG
n = 5000   # number of rows drawn for the data matrices
d = 10     # number of features per row

np.random.seed(seed)

# Draw order is significant for reproducibility: both scale vectors first,
# then both offset vectors.
A1_vector, A2_vector = (np.random.normal(loc=0, scale=1, size=d) for _ in range(2))
A1, A2 = np.diag(A1_vector), np.diag(A2_vector)  # diagonal-matrix form of the scales

B1, B2 = (np.random.normal(loc=0, scale=1, size=d) for _ in range(2))

In [4]:
# Reset the global RNG so the draws below do not depend on earlier cells.
# NOTE(review): this reuses the seed that generated A1_vector/A2_vector/B1/B2,
# so the first rows of X appear to replay those same draws — confirm that
# this overlap between parameters and data is acceptable.
np.random.seed(seed)
X = np.random.normal(loc=0, scale=1, size=(n, d))

# Per-feature affine transforms of X: scale each column, then shift it.
Y1 = A1_vector * X + B1
Y2 = A2_vector * X + B2

# Independent standard-normal matrix with the same shape as X.
Y = np.random.normal(loc=0, scale=1, size=(n, d))


In [5]:
# Convert the targets to float32 tensors.
# torch.as_tensor accepts NumPy arrays and tensors alike, so re-running this
# cell once Y1/Y2 are already tensors is a no-op rather than the TypeError
# that torch.from_numpy raises on non-ndarray input.
Y1 = torch.as_tensor(Y1, dtype=torch.float32)
Y2 = torch.as_tensor(Y2, dtype=torch.float32)

In [8]:
# Trains a VAE on Y1, the float tensor of shape (n, d) built in the cells above.
from torch.utils.data import DataLoader, TensorDataset

# Seed torch's global RNG so weight initialisation, batch shuffling and the
# reparameterisation noise are repeatable from a fresh kernel.
torch.manual_seed(seed)

# Hyperparameters
input_dim = Y1.size(1)   # Dimension of the input data
hidden_dim = 100        # Size of the hidden layer
latent_dim = 10         # Dimension of the latent space
batch_size = 100        # Batch size for training
learning_rate = 1e-3    # Learning rate
epochs = 200             # Number of training epochs

# Create DataLoader
dataset = TensorDataset(Y1)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the VAE model
model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for data_batch in dataloader:
        data = data_batch[0]  # TensorDataset yields 1-tuples
        # Use the real mini-batch size: the last batch is shorter whenever
        # len(dataset) is not a multiple of batch_size.
        current_batch_size = data.size(0)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss = loss / current_batch_size  # per-sample loss for the gradient step
        loss.backward()
        train_loss += loss.item() * current_batch_size  # accumulate the summed loss
        optimizer.step()
    average_loss = train_loss / len(dataset)
    print(f'Epoch {epoch + 1}, Average Loss: {average_loss:.4f}')


Epoch 1, Average Loss: 14.4744
Epoch 2, Average Loss: 6.3806
Epoch 3, Average Loss: 4.0787
Epoch 4, Average Loss: 3.2962
Epoch 5, Average Loss: 3.0169
Epoch 6, Average Loss: 2.8513
Epoch 7, Average Loss: 2.6544
Epoch 8, Average Loss: 2.4790
Epoch 9, Average Loss: 2.3632
Epoch 10, Average Loss: 2.2644
Epoch 11, Average Loss: 2.2217
Epoch 12, Average Loss: 2.1924
Epoch 13, Average Loss: 2.1677
Epoch 14, Average Loss: 2.1385
Epoch 15, Average Loss: 2.1153
Epoch 16, Average Loss: 2.0865
Epoch 17, Average Loss: 2.0704
Epoch 18, Average Loss: 2.0459
Epoch 19, Average Loss: 2.0320
Epoch 20, Average Loss: 2.0149
Epoch 21, Average Loss: 1.9999
Epoch 22, Average Loss: 1.9994
Epoch 23, Average Loss: 1.9834
Epoch 24, Average Loss: 1.9813
Epoch 25, Average Loss: 1.9714
Epoch 26, Average Loss: 1.9643
Epoch 27, Average Loss: 1.9580
Epoch 28, Average Loss: 1.9475
Epoch 29, Average Loss: 1.9442
Epoch 30, Average Loss: 1.9362
Epoch 31, Average Loss: 1.9352
Epoch 32, Average Loss: 1.9297
Epoch 33, Averag

In [9]:
# Mean squared reconstruction error on whatever `recon_batch` / `data` the
# training loop assigned last (i.e. a single mini-batch, not the full dataset).
F.mse_loss(input=recon_batch, target=data, reduction='mean')



tensor(0.0555, grad_fn=<MseLossBackward0>)

In [14]:
# Scratch cell: a random 1 x input_dim tensor with gradient tracking enabled.
torch.randn((1, input_dim), requires_grad=True)

tensor([[ 0.1430,  1.2515,  0.2960, -0.2840, -1.0920,  2.6523, -0.7730,  0.4186,
          0.9832, -0.6907]], requires_grad=True)

In [11]:
def encoder_function(x):
    """Return only the posterior mean that the global ``model`` encodes for ``x``.

    The log-variance returned by ``model.encode`` is discarded, so the result
    is a deterministic function of ``x`` and the current model weights.
    """
    posterior_mean = model.encode(x)[0]
    return posterior_mean

In [15]:
# First row of Y1, sliced (not indexed) so the leading dimension of size 1 is kept.
x = Y1[:1]

torch.Size([1, 10])

In [51]:
# Fresh (n, d) standard-normal matrix drawn from NumPy's global RNG; this
# rebinds the Y created in the earlier data-generation cell.
Y = np.random.normal(loc=0, scale=1, size=(n, d))