# Masked Evidence Lower Bound

In [60]:
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

import numpy as np
import pandas as pd

import io
import os
import tqdm

In [61]:
def load_data(filepath: str):
    buffer = None
    filepath = os.path.expanduser(filepath)
    with open(filepath, "r") as f:
        buffer = io.StringIO(f.read().replace("TRUE", "1").replace("FALSE", "0"))
    return pd.read_csv(buffer).to_numpy()

In [62]:
mar60_mask = load_data("../Data Simulation/missingness mechanism level data/MCAR60_mask.csv")
mar60_full = load_data("../Data Simulation/missingness mechanism level data/MCAR60_X_full.csv")

In [74]:
np.random.seed(42)
test_train_mask = np.array(np.random.binomial(1, 1/3, (3000,)), dtype=np.bool)

In [75]:
mar60_full_test = mar60_full[test_train_mask, :]
mar60_full_train = mar60_full[test_train_mask ^ True, :]
mar60_mask_test = mar60_mask[test_train_mask, :]
mar60_mask_train = mar60_mask[test_train_mask ^ True, :]

In [76]:
# Model Hyperparameters

cuda = False
DEVICE = torch.device("cuda" if cuda else "cpu")

batch_size = 100

x_dim = 20
hidden_dim = 400
latent_dim = 200

lr = 1e-3

epochs = 30

In [77]:
kwargs = {'num_workers': 1, 'pin_memory': True}
test_loader = DataLoader((mar60_full_test, mar60_mask_test), batch_size=batch_size, shuffle=False, **kwargs)
train_loader = DataLoader((mar60_full_train, mar60_mask_train), batch_size=batch_size, shuffle=True, **kwargs)

In [78]:
class Encoder(nn.Module):
    """Gaussian MLP Encoder."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.training = True

    def forward(self, x):
        h_ = self.LeakyReLU(self.FC_input(x))
        h_ = self.LeakyReLU(self.FC_input2(h_))
        mean = self.FC_mean(h_)
        log_var = self.FC_var(h_)

        return mean, log_var

In [79]:
class Decoder(nn.Module):
    """Gaussian MLP Decoder."""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = self.LeakyReLU(self.FC_hidden(x))
        h = self.LeakyReLU(self.FC_hidden2(h))
        x_hat = torch.sigmoid(self.FC_output(h))
        return x_hat

In [80]:
class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparametrization(self, mean, var):
        # Sample from standard normal
        epsilon = torch.randn_like(var).to(DEVICE)
        # Reparametrization trick
        z = mean + var * epsilon
        return z
    
    def forward(self, x):
        mean, log_var = self.encoder(x)
        # Takes exponential function (log_var -> var)
        z = self.reparametrization(mean, torch.exp(.5 * log_var))
        x_hat = self.decoder(z)
        return x_hat, mean, log_var

In [81]:
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

model = Model(encoder=encoder, decoder=decoder).to(DEVICE)

### 3. Define Loss Function

In [82]:
from torch.optim import Adam

BCE_loss = nn.BCELoss()


def loss_function(x, x_hat, mean, log_var, mask):
    reproduction_loss = nn.functional.binary_cross_entropy(
    x_hat,
    x.float(),
    reduction='none'
)
    masked_repr_loss = (x_hat * mask).sum()
    KLD = -.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return masked_repr_loss + KLD


optimizer = Adam(model.parameters(), lr=lr)

### 4. Train the Model

In [83]:
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, masks), in enumerate(train_loader):
        original_x = x.clone()  # Save original x for loss computation
        
        x = x.to(DEVICE).float()
        # Rescaling the data from {0,1} to {-1,1}
        x = x * 2 - 1
        masks = masks.to(DEVICE).float()

        x = x * masks
        x = x.to(DEVICE).float()

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(original_x, x_hat, mean, log_var, masks)

        overall_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f"\tEpoch {epoch + 1} complete.\tAverage Loss: {overall_loss / (batch_idx + 1)}")
print("Finish.")



	Epoch 1 complete.	Average Loss: 14039.408203125
	Epoch 2 complete.	Average Loss: 12025.8388671875
	Epoch 3 complete.	Average Loss: 10837.6630859375
	Epoch 4 complete.	Average Loss: 9767.6904296875
	Epoch 5 complete.	Average Loss: 8607.30078125
	Epoch 6 complete.	Average Loss: 7400.46533203125
	Epoch 7 complete.	Average Loss: 6090.17822265625
	Epoch 8 complete.	Average Loss: 4996.54638671875
	Epoch 9 complete.	Average Loss: 3587.431396484375
	Epoch 10 complete.	Average Loss: 2555.0224609375
	Epoch 11 complete.	Average Loss: 1869.779052734375
	Epoch 12 complete.	Average Loss: 1253.478271484375
	Epoch 13 complete.	Average Loss: 790.1278076171875
	Epoch 14 complete.	Average Loss: 509.42083740234375
	Epoch 15 complete.	Average Loss: 341.7733154296875
	Epoch 16 complete.	Average Loss: 226.39454650878906
	Epoch 17 complete.	Average Loss: 156.79315185546875
	Epoch 18 complete.	Average Loss: 120.42412567138672
	Epoch 19 complete.	Average Loss: 99.82807922363281
	Epoch 20 complete.	Average Loss

In [84]:
model.eval()

total_test_loss = 0.0
total_samples = 0

with torch.no_grad():
    for x, masks in tqdm.tqdm(test_loader):
        original_x = x.clone()  # Save original x for loss computation
        
        x = x.to(DEVICE).float()
        # Rescaling the data from {0,1} to {-1,1}
        x = x * 2 - 1
        masks = masks.to(DEVICE).float()

        x = x * masks
        x = x.to(DEVICE).float()

        x_hat, mean, log_var = model(x)
 
        # Calculate loss for the batch
        batch_loss = loss_function(original_x, x_hat, mean, log_var, masks)

        total_test_loss += batch_loss.item() * x.size(0)
        total_samples += x.size(0)

# Calculate final average error metric
average_loss = total_test_loss / total_samples
print(f"Test Loss: {average_loss:.4f}")

100%|██████████| 1/1 [00:00<00:00,  9.34it/s]

Test Loss: 12.9288





In [47]:
mar60_full_test

array([[0, 1, 1, ..., 1, 0, 0],
       [1, 1, 0, ..., 1, 1, 1],
       [0, 1, 1, ..., 1, 0, 1],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0],
       [0, 1, 0, ..., 1, 0, 1]], shape=(992, 20))

In [86]:
(mar60_full_test * 2 - 1) * mar60_mask_test

array([[-1,  1,  0, ...,  0, -1, -1],
       [ 0,  1, -1, ...,  1,  1,  1],
       [-1,  0,  0, ..., -1,  0, -1],
       ...,
       [ 0, -1,  0, ..., -1, -1, -1],
       [-1,  1,  0, ...,  1, -1,  0],
       [-1,  1,  1, ...,  1,  0,  1]], shape=(1003, 20))

In [None]:
with torch.no_grad():
    