In [1]:
import torch
import torch.nn as nn

In [183]:
encoder = nn.Sequential(nn.Linear(1000,50), nn.ReLU(), nn.Linear(50,10))
decoder = nn.Sequential(nn.Linear(10,50), nn.ReLU(), nn.Linear(50,1000))
missing_model = nn.Sequential(nn.Linear(1000, 1000), nn.Sigmoid())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli, Independent

import numpy as np


class notMIWAE(nn.Module):

    def __init__(self, encoder, decoder, missing_model, encoder_input_dim, encoder_output_dim, decoder_output_dim, latent_dim):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.missing_model = missing_model

        self.encoder_input_dim = encoder_input_dim
        self.encoder_output_dim = encoder_output_dim
        self.decoder_output_dim = decoder_output_dim
        self.latent_dim = latent_dim

        self.q_mu = nn.Linear(encoder_output_dim, latent_dim)
        self.q_logvar = nn.Linear(encoder_output_dim, latent_dim)
        self.p_mu = nn.Linear(decoder_output_dim, encoder_input_dim)
        self.p_logvar = nn.Linear(decoder_output_dim, encoder_input_dim)

        # Prior distribution
        self.p_z = Independent(Normal(torch.zeros(1), torch.ones(1)),1)


    def forward(self, x : torch.Tensor, s : torch.Tensor, K : int = 1) -> torch.Tensor:
        """
        Computes the not-MIWAE loss
        
        Inputs:
        ------
        - x: Tensor of shape (batch_size, input_dim)
        - s: Tensor of shape (batch_size, input_dim) with 1s where x is observed and 0s where x is missing
        - K: Number of samples to draw from the posterior (default = 1)

        Outputs:
        -------
        - loss: Tensor of shape (1) with the not-MIWAE loss
        """
        # s[i,j] = 1 iff x[i,j] is observed. If x[i,j] is missing, pad with 0.
        x_observed = s * x                                                # Size (bs,input_dim)

        # Encoder: q(z|x_observed)
        h = self.encoder(x_observed)
        mu_z = self.q_mu(h)
        logvar_z = self.q_logvar(h)

        # Sampling z from q(z|x_observed) K times
        q_z_given_x = Independent(Normal(loc = mu_z, scale = torch.exp(0.5 * logvar_z)), 1)
        z = q_z_given_x.rsample([K])                                      # Size (K,bs,latent_dim)
        
        # Decoder: p(x|z_k) for all k = 1, ..., K
        h = self.decoder(z)
        mu_x = self.p_mu(h)
        logvar_x = self.p_logvar(h)
        p_x_given_z = Independent(Normal(mu_x, torch.exp(0.5 * logvar_x)),1)

        # Sample missing data
        x_missing = p_x_given_z.rsample() * (1 - s)
        x_imputed = x_observed + x_missing

        # Missing model: p(s|x_observed, x_missing)
        p_s_given_x = Independent(Bernoulli(self.missing_model(x_imputed)),1)

        # Log probabilities:
        x_observed_k = torch.Tensor.repeat(x_observed, [K,1,1])            # Size (K,bs,input_dim)
        s_k = torch.Tensor.repeat(s, [K,1,1])                              
        
        log_p_x_given_z = p_x_given_z.log_prob(x_observed_k)               # Size (K, bs)
        log_q_z_given_x = q_z_given_x.log_prob(z)
        log_p_s_given_x = p_s_given_x.log_prob(s_k)
        log_p_z = self.p_z.log_prob(z)

        loss = -torch.mean(torch.logsumexp(log_p_x_given_z + log_p_s_given_x - log_q_z_given_x + log_p_z - np.log(K), dim = 0))
     
        return loss

    def impute(self, x : torch.Tensor, s : torch.Tensor, K : int = 1) -> torch.Tensor:
            """ 
            Imputes missing values in x using the trained model
            
            Inputs:
            ------
            - x: Tensor of shape (batch_size, input_dim)
            - s: Tensor of shape (batch_size, input_dim) with 1s where x is observed and 0s where x is missing
            - K: Number of samples to draw from the posterior (default = 1)

            Outputs:
            -------
            - x_imputed: Tensor of shape (batch_size, input_dim) with the imputed values and 0s where x is observed
            """
            # s[i,j] = 1 iff x[i,j] is observed. If x[i,j] is missing, pad with 0.
            x_observed = s * x                                                # Size (bs,input_dim)

            # Encoder: q(z|x_observed)
            h = self.encoder(x_observed)
            mu_z = self.q_mu(h)
            logvar_z = self.q_logvar(h)

            # Sampling z from q(z|x_observed) K times
            q_z_given_x = Independent(Normal(loc = mu_z, scale = torch.exp(0.5 * logvar_z)), 1)
            z = q_z_given_x.rsample([K])                                      # Size (K,bs,latent_dim)
            
            # Decoder: p(x|z_k) for all k = 1, ..., K
            h = self.decoder(z)
            mu_x = self.p_mu(h)
            logvar_x = self.p_logvar(h)
            p_x_given_z = Independent(Normal(mu_x, torch.exp(0.5 * logvar_x)),1)

            # Sample missing data
            x_missing = p_x_given_z.rsample() * (1 - s)
            x_imputed = x_observed + x_missing

            # Missing model: p(s|x_observed, x_missing)
            p_s_given_x = Independent(Bernoulli(self.missing_model(x_imputed)),1)

            # Importance weights:
            x_observed_k = torch.Tensor.repeat(x_observed, [K,1,1])            # Size (K,bs,input_dim)
            s_k = torch.Tensor.repeat(s, [K,1,1])                              
            
            log_p_x_given_z = p_x_given_z.log_prob(x_observed_k)               # Size (K, bs)
            log_q_z_given_x = q_z_given_x.log_prob(z)
            log_p_s_given_x = p_s_given_x.log_prob(s_k)
            log_p_z = self.p_z.log_prob(z)

            imp_weights = F.softmax(log_p_x_given_z + log_p_s_given_x - log_q_z_given_x + log_p_z - np.log(K), dim = 0)

            return torch.einsum('ki,kij->ij', imp_weights, x_missing)

        






In [318]:
model = notMIWAE(encoder, decoder, missing_model, 1000, 10, 1000, 10)

In [319]:
x = torch.randn(64, 1000)
s = torch.randint(0, 2, (64, 1000)).float()

In [320]:
model.impute(x, s, K = 2)

tensor([[ 1.6836,  0.0000, -0.0872,  ..., -0.3183, -1.2457,  0.0909],
        [ 0.0000,  0.4767,  0.0000,  ..., -0.5510, -0.7413,  0.5663],
        [ 0.0000,  0.0000, -0.8638,  ..., -1.2282,  0.0000,  0.8220],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.2372,  0.7566, -0.5400],
        [ 0.0000, -0.9929,  1.1133,  ...,  0.0000,  0.6700,  0.0000],
        [-0.5146,  0.0000, -0.3997,  ...,  0.0000,  0.1830,  0.0000]],
       grad_fn=<ViewBackward0>)