In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
import os
from pythae.models import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def get_figure_data(task,plot_figures=True):
    assert task in ['mnist','cifar']
    if task=='cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        cifar_trainset = datasets.CIFAR10(root='../../data', download=True, transform=transform)
        train_dataset = np.transpose(cifar_trainset.data[:-10000],(0,3,1,2))/255
        eval_dataset = np.transpose(cifar_trainset.data[-10000:],(0,3,1,2))/255
        all_dataset = np.transpose(cifar_trainset.data,(0,3,1,2))/255
        print('train_dataset shape:',train_dataset.shape)
        if plot_figures:
            fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))
            for i in range(5):
                for j in range(5):
                    axes[i][j].imshow(np.transpose(train_dataset[i*5+j],(1,2,0)))
                    axes[i][j].axis('off')
        return train_dataset,eval_dataset,all_dataset
    elif task=='mnist':
        mnist_trainset = datasets.MNIST(root='../../data', download=True, transform=None)
        train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.
        eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.
        all_dataset=mnist_trainset.data.reshape(-1, 1, 28, 28) / 255.
        print('train_dataset shape:',train_dataset.shape)
        if plot_figures:
            fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))
            for i in range(5):
                for j in range(5):
                    axes[i][j].imshow(train_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')
                    axes[i][j].axis('off')
        return train_dataset,eval_dataset,all_dataset

In [None]:
import logging
from typing import Optional, Union

from typing_extensions import Literal

from pythae.data.datasets import BaseDataset

logger = logging.getLogger(__name__)

# make it print to the console.
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


class DataProcessor:
    """
    This is a basic class which preprocesses the data.
    Basically, it takes messy data, detects potential nan, bad types end convert the
    data to a type handled by the VAE models (*i.e.* `torch.Tensor`). Moreover, if the
    data does not have the same shape, a reshaping is applied and data is resized to the
    **minimal shape**.
    """

    def __init__(self):
        pass

    def process_data(
        self, data: Union[np.ndarray, torch.Tensor], batch_size: int = 100
    ) -> torch.Tensor:
        """This function detects potential check the data type, detects nan in input data and
        preprocessed the data so it can be handled by the models.

        Args:
            data (Union[np.ndarray, torch.Tensor]): The data that need to be
                checked. Expected:

                    - | np.ndarray of shape `num_data x n_channels x [optional depth] x
                      | [optional height] x width x ...`
                    - | torch.Tensor of shape `num_data x n_channels x [optional depth] x
                      | [optional height] x width x ...`

            batch_size (int): The batch size used for data preprocessing

        Returns:
            clean_data (torch.tensor): The data that has been cleaned
        """

        if isinstance(data, np.ndarray) or torch.is_tensor(data):
            data = self._process_data_array(data, batch_size=batch_size)

        else:
            raise TypeError(
                "Wrong data type provided. Expected one of "
                "[np.ndarray, torch.Tensor]"
            )

        return data

    @staticmethod
    def to_dataset(data: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """This method converts a set of ``torch.Tensor`` to a
        :class:`~pythae.data.datasets.BaseDataset`

        Args:
            data (torch.Tensor): The set of data as a big torch.Tensor
            labels (torch.Tensor): The targets labels as a big torch.Tensor

        Returns:
            (BaseDataset): The resulting dataset
        """

        if labels is None:
            labels = torch.ones(data.shape[0])

        labels = DataProcessor.to_tensor(labels)
        dataset = BaseDataset(data, labels)

        return dataset

    def _process_data_array(self, data: np.ndarray, batch_size: int = 100):

        num_samples = data.shape[0]
        samples_shape = data.shape

        num_complete_batch = num_samples // batch_size
        num_in_last = num_samples % batch_size

        full_data = []

        for i in range(num_complete_batch):

            # Detect potential nan
            if DataProcessor.has_nan(data[i * batch_size : (i + 1) * batch_size]):
                raise ValueError("Nan detected in input data!")

            processed_data = DataProcessor.to_tensor(
                data[i * batch_size : (i + 1) * batch_size]
            )
            full_data.append(processed_data)

        if num_in_last > 0:
            # Detect potential nan
            if DataProcessor.has_nan(data[-num_in_last:]):
                raise ValueError("Nan detected in input data!")

            processed_data = DataProcessor.to_tensor(data[-num_in_last:])
            full_data.append(processed_data)

        processed_data = torch.cat(full_data)

        assert processed_data.shape == samples_shape, (data.shape, num_samples)

        return processed_data

    @staticmethod
    def to_tensor(data: np.ndarray) -> torch.Tensor:
        """Converts numpy arrays to `torch.Tensor` format

        Args:
            data (np.ndarray): The data to be converted

        Return:
            (torch.Tensor): The transformed data"""

        # check input type
        if not torch.is_tensor(data):
            if not isinstance(data, np.ndarray):
                raise TypeError(
                    " Data must be either of type "
                    f"< 'torch.Tensor' > or < 'np.ndarray' > ({type(data)} provided). "
                    f" Check data"
                )

            else:
                try:
                    data = torch.tensor(data).type(torch.float)

                except (TypeError, RuntimeError) as e:
                    raise TypeError(
                        str(e.args) + ". Potential issues:\n"
                        "- input data has not the same shape in array\n"
                        "- input data with unhandable type"
                    ) from e

        return data

    @staticmethod
    def has_nan(data: torch.Tensor) -> bool:
        """Detects potential nan in input data

        Args:
            data (torch.Tensor): The data to be checked

        Return:
            (bool): True if data contains :obj:`nan`
        """

        if (data != data).sum() > 0:
            return True
        else:
            return False

In [None]:
def get_figure_model(model_name,task):
    assert model_name in ['vae','rae_gp']
    assert task in ['mnist','cifar']
    if task=='mnist':
        input_dim=(1,28,28)
        latent_dim=16
        if model_name=='vae':
            from pythae.models import VAE, VAEConfig
            from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST
            model_config = VAEConfig(
                input_dim=input_dim,
                latent_dim=latent_dim
            )
            model = VAE(
                model_config=model_config,
                encoder=Encoder_ResNet_VAE_MNIST(model_config), 
                decoder=Decoder_ResNet_AE_MNIST(model_config) 
            )
        elif model_name=='rae_gp':
            from pythae.models import RAE_GP, RAE_GP_Config
            from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_AE_MNIST, Decoder_ResNet_AE_MNIST
            model_config = RAE_GP_Config(
                    input_dim=input_dim,
                    latent_dim=latent_dim, # mnist 16
                    embedding_weight=1e-2,
                    reg_weight=1e-4
                )
            model = RAE_GP(
                model_config=model_config,
                encoder=Encoder_ResNet_AE_MNIST(model_config), 
                decoder=Decoder_ResNet_AE_MNIST(model_config) 
            )
    elif task=='cifar':
        input_dim=(3,32,32)
        latent_dim=128
        if model_name=='vae':
            from pythae.models import VAE, VAEConfig
            from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_VAE_CIFAR, Decoder_ResNet_AE_CIFAR
            model_config = VAEConfig(
                input_dim=input_dim,
                latent_dim=latent_dim
            )
            model = VAE(
                model_config=model_config,
                encoder=Encoder_ResNet_VAE_CIFAR(model_config), 
                decoder=Decoder_ResNet_AE_CIFAR(model_config) 
            )
        elif model_name=='rae_gp':
            from pythae.models import RAE_GP, RAE_GP_Config
            from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_AE_CIFAR, Decoder_ResNet_AE_CIFAR
            model_config = RAE_GP_Config(
                    input_dim=input_dim,
                    latent_dim=latent_dim, # mnist 16
                    embedding_weight=1e-2,
                    reg_weight=1e-4
                )
            model = RAE_GP(
                model_config=model_config,
                encoder=Encoder_ResNet_AE_CIFAR(model_config), 
                decoder=Decoder_ResNet_AE_CIFAR(model_config) 
            )
    return model

training VAE

In [None]:
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
model_name='rae_gp'
task='mnist'
output_dir='my_model/{}/{}/'.format(task,model_name)
config= BaseTrainerConfig(
    output_dir=output_dir,
    learning_rate=1e-4,
    batch_size=4096,
    num_epochs=100, # Change this to train the model a bit more
)

train_dataset,eval_dataset,all_dataset=get_figure_data(task)
model=get_figure_model(model_name,task)

pipeline = TrainingPipeline(
    training_config=config,
    model=model
)
pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset
)

In [None]:
fid=FrechetInceptionDistance(feature=2048).to(device)
fid.eval()

def evaluate(pred, target):
    pred=pred.to(device)
    target=target.to(device)
    print('pred shape: ',pred.shape)
    print('target shape: ',target.shape)
    with  torch.no_grad():
        metric = {}
        batch_size = 512
        imgs_dist1 = (pred.mul(255).add(0.5).clamp(0, 255)).type(torch.uint8)
        imgs_dist2 = (target.mul(255).add(0.5).clamp(0, 255)).type(torch.uint8)
        if imgs_dist1.shape[1] == 1: 
            imgs_dist1 = imgs_dist1.repeat(1, 3, 1, 1)
        if imgs_dist2.shape[1] == 1:
            imgs_dist2 = imgs_dist2.repeat(1, 3, 1, 1)
        # fid.reset()
        for idx in tqdm(range(0, len(pred), batch_size)):
            fid.update(imgs_dist1[idx:idx+batch_size], real=False)
            fid.update(imgs_dist2[idx:idx+batch_size], real=True)
        metric["FID"] = fid.compute()
        print('FID : {}'.format(metric["FID"].cpu().item()))
        
def plot_figures(images,output_dir=None,file_name=None):
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(2, 2))
    for i in range(5):
        for j in range(5):
            if images.shape[1]==1:
                axes[i][j].imshow(images[i*5+j].cpu().squeeze(0),cmap='gray')
            else:
                axes[i][j].imshow(np.transpose(images[i*5+j].cpu(),(1,2,0)))
            axes[i][j].axis('off')
    plt.tight_layout(pad=0.)
    if output_dir is not None:
        fig.savefig(os.path.join(output_dir,file_name))
        
def get_sampler(sampler,model,train_dataset=None):
    assert sampler in ['normal','gmm']
    if sampler=='normal':
        from pythae.samplers import NormalSampler
        normal_samper = NormalSampler(model=model)
        return normal_samper
    elif sampler=='gmm':
        assert train_dataset is not None
        from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig,NormalSampler
        gmm_sampler_config = GaussianMixtureSamplerConfig(n_components=30)
        gmm_sampler = GaussianMixtureSampler(sampler_config=gmm_sampler_config,model=model)
        print('fitting gmm sampler, this process could take several mins')
        gmm_sampler.fit(train_dataset)
        return gmm_sampler

NCP_VAE

In [None]:
import torch
import torch.nn as nn

class SE_Block(nn.Module):
    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)
    
class ResidualBlockA(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlockA, self).__init__()
        assert out_channels==in_channels 
        self.conv1 = nn.Sequential(
                        nn.BatchNorm2d(in_channels),
                        nn.SiLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm2d(out_channels),
                        nn.SiLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        SE_Block(out_channels))   
        self.out_channels = out_channels
        self.sigmoid=nn.Sigmoid()
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        return out

class ResidualBlockB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlockB, self).__init__()
        assert out_channels==2*in_channels
        self.conv1 = nn.Sequential(
                        nn.BatchNorm2d(in_channels),
                        nn.SiLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm2d(out_channels),
                        nn.SiLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        SE_Block(out_channels))
        self.factorized_reduction=nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2,padding=0)
        self.out_channels = out_channels
        self.swish=SE_Block(in_channels)
    def forward(self, x):
        # print('residualB')
        # print('x shape',x.shape)
        residual = self.swish(x)
        residual=self.factorized_reduction(residual)
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        # print('out shape',out.shape)
        return out

class Binary_Classifier(nn.Module):
    def __init__(self, blockA, blockB, inchannel):
        super(Binary_Classifier, self).__init__()
        self.inplanes = inchannel
        # self.linear1=nn.Linear(16,self.inplanes*16*16)
        # self.conv1 = nn.Sequential(
        #                 nn.ConvTranspose2d(self.inplanes, self.inplanes, 3, 2, padding=1),
        #                 nn.ReLU())
        self.conv1 = nn.Sequential(
                        nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=1, padding=1),
                        nn.ReLU())
        self.layer0 = self._make_layer(blockA, self.inplanes, self.inplanes, 3)
        self.layer1 = self._make_layer(blockB, self.inplanes, self.inplanes * 2, 1)
        self.layer2 = self._make_layer(blockA, self.inplanes * 2, self.inplanes * 2, 3)
        self.layer3 = self._make_layer(blockB, self.inplanes * 2, self.inplanes * 4,1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(self.inplanes * 4, 1)
        self.sigmoid=nn.Sigmoid()
        
    def _make_layer(self, block, in_channels,out_channels, blocks):
        layers = []
        for i in range(blocks):
            layers.append(block(in_channels, out_channels))
        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        # x=self.linear1(x)
        # x=x.reshape(x.shape[0],self.inplanes,16,16)

        x = self.conv1(x)
        # print(x.shape)
        x = self.layer0(x)
        # print(x.shape)
        x = self.layer1(x)
        # print('layer1 out',x.shape)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x
        
class NCP_VAE(nn.Module):
    def __init__(self,sampler,bin_classifier):
        super(NCP_VAE, self).__init__()
        self.sampler=sampler
        self.bin_classifier=bin_classifier
        self.fid=FrechetInceptionDistance(feature=2048).to(device)
        self.fid.eval()
        self.bin_classifier.eval()
        self.device="cuda" if torch.cuda.is_available() else "cpu"

    @torch.no_grad()
    def get_reweight(self,z):
        x=self.sampler.model.decoder(z)["reconstruction"].detach()
        D=self.bin_classifier(x)
        r=D/(1-D)
        return r

    @torch.no_grad()
    def sample_from_p(self,num_samples):
        """sample from gaussian mixture"""
        try:
            z=torch.tensor(self.sampler.gmm.sample(num_samples)[0]).type(torch.float).to(self.device)
        except Exception:
            z=torch.randn(num_samples, self.sampler.model.latent_dim).to(self.device)
        return z

    @torch.no_grad()
    def sample_from_q(self,input):
        z=self.sampler.model(input).z.detach().to(device)
        return z

    @torch.no_grad()
    def sample_from_ncp(self,num_samples):
        z=self.sample_from_p(num_samples)
        r=self.get_reweight(z).reshape(-1,)
        r/=r.sum()
        idx=torch.multinomial(r,num_samples)
        return z[idx]

    @torch.no_grad()
    def sample(self,num_samples):
        z=self.sample_from_ncp(num_samples)
        return self.sampler.model.decoder(z)["reconstruction"].detach()

    @torch.no_grad()
    def evaluate(self,real_set,num_samples):
        target=real_set[:num_samples]
        pred=self.sample(num_samples)
        metric = {}
        batch_size = 16
        imgs_dist1 = (pred.mul(255).add(0.5).clamp(0, 255)).type(torch.uint8).to(self.device)
        imgs_dist2 = (target.mul(255).add(0.5).clamp(0, 255)).type(torch.uint8).to(self.device)
        if imgs_dist1.shape[1] == 1: 
            imgs_dist1 = imgs_dist1.repeat(1, 3, 1, 1)
        if imgs_dist2.shape[1] == 1:
            imgs_dist2 = imgs_dist2.repeat(1, 3, 1, 1)
        # fid.reset()
        for idx in tqdm(range(0, len(pred), batch_size)):
            fid.update(imgs_dist1[idx:idx+batch_size], real=False)
            fid.update(imgs_dist2[idx:idx+batch_size], real=True)
        # print("fake_samples: ", fid.fake_features_num_samples)
        # print("real_samples: ", fid.real_features_num_samples)
        metric["FID"] = fid.compute()
        return metric

In [None]:
model_name='rae_gp'
task='mnist'
output_dir='my_model/{}/{}/'.format(task,model_name)
last_training = sorted(os.listdir(output_dir))[-1]
output_dir=os.path.join(output_dir, last_training, 'final_model')
trained_model = AutoModel.load_from_folder(output_dir)
train_dataset,eval_dataset,all_dataset=get_figure_data(task)
sampler_gmm=get_sampler('gmm',trained_model,train_dataset)

train_dataset,eval_dataset,all_dataset=get_figure_data(task)
data_processor = DataProcessor()
train_data_processed = data_processor.process_data(train_dataset).to(device)
train_dataset_processed = data_processor.to_dataset(train_data_processed)
train_loader = DataLoader(dataset=train_dataset_processed, batch_size=1024, shuffle=False)

In [None]:
def train_eval_ncp(sampler,data_loader,epochs=10,bs=1024,lr=0.001):
    binary_classifier=Binary_Classifier(ResidualBlockA,ResidualBlockB,3 if task=='cifar' else 1).to(device)
    optimizer=torch.optim.Adam(binary_classifier.parameters(),lr=lr)
    criterion=nn.BCELoss()
    for epoch in range(epochs):
        acc=[]
        losses=[]
        for batch in data_loader:
            num_samples=len(batch['data'])
            q=sampler.model(batch).recon_x
            # p_=torch.tensor(gmm_sampler.gmm.sample(num_samples)[0]).type(torch.float).to(device)
            p=sampler.sample(num_samples)
            input=torch.cat((q,p),dim=0)
            labels=torch.cat((torch.ones(num_samples),torch.zeros(num_samples))).to(device)
            out=binary_classifier(input).squeeze()
            loss=criterion(out,labels) 
            acc.append(np.mean(np.array(out.cpu()>0.5)==np.array(labels.cpu())))
            optimizer.zero_grad()
            loss.backward()
            losses.append(loss.cpu().item())
            optimizer.step()
        print('Eopoch :{} | Acc : {:.3f} | Loss : {:.3f} '.format(epoch,np.mean(acc),np.mean(losses)))

    ncp_vae=NCP_VAE(sampler,binary_classifier)
    fid_scores=[]
    # print('acc: ',np.mean(acc))
    print("Evaluation ncp_vae")
    for i in range(1):
        fid_score=ncp_vae.evaluate(torch.tensor(all_dataset),int(0.1 * len(all_dataset)))['FID'].cpu().item()
        fid_scores.append(fid_score)
    print('FID : {}, avg: {:.3F}'.format(fid_scores,np.mean(fid_scores)))
    return ncp_vae

In [None]:
ncp_vae_sampler_gmm=train_eval_ncp(sampler_gmm,train_loader,50,lr=0.01)

plot_figures(ncp_vae_sampler_gmm.sample(25),output_dir,'ncp_rae_gmm_output')

test

## POWER Dataset Test

In [1]:
import copy
import pandas as pd
import torch
from my_utils import plot_hist_marginals
from pythae.data.datasets import DatasetOutput, BaseDataset

class TableDataset(torch.utils.data.Dataset):
    """Wraps a Table and yields each row to use in pythae."""
    
    def __init__(self, data, perc_miss = 0.5):
        '''
        Args:
            data: dataframe.
        '''
        super(TableDataset, self).__init__()
        self.tuples_np = np.stack([Discretize(c) for c in data.Columns()], axis=1)
        self.tuples = torch.as_tensor(self.tuples_np.astype(np.float32, copy=False))
        # self.data = self.tuples
        self.onehot_data_np = One_hot(self.tuples_np)
        self.onehot_data = torch.as_tensor(self.onehot_data_np, dtype=torch.float32)
        # self.data_one_hot = self.__encode_onehot(self.tuples_np, [c.DistributionSize() for c in data.columns])
        self.masks_np = Mask_row(self.onehot_data_np, perc_miss)
        self.masks = torch.as_tensor(self.masks_np)
        
    def size(self):
        return len(self.tuples)

    def __len__(self):
        return len(self.tuples)

    def __getitem__(self, idx):
        X = self.tuples[idx]
        X_one_hot = self.onehot_data[idx]
        masks = self.masks[idx]
        return DatasetOutput(
            data = X,
            data_one_hot = X_one_hot,
            masks = masks,
        )
        
    def __encode_onehot(self, data, input_bins):
        '''
        Args:
            data: ndarray.
        '''
        # print(f'data : {data.keys}')
        bs = data.shape[0]
        y_onehots = []
        data = data.astype(np.int32)
        # print(f'input_bins : {input_bins}')
        for i, coli_dom_size in enumerate(input_bins):
            if coli_dom_size <= 2:
                y_onehots.append(data[:, i].reshape(-1, 1).astype(np.float32))
            else:
                y_onehot = np.zeros((bs, coli_dom_size), dtype=np.int32)
                y_onehot[data[:, i].reshape(-1, 1)] = 1
                y_onehots.append(y_onehot)
        # [bs, sum(dist size)]
        return torch.as_tensor(np.concatenate(y_onehots, 1))    
    
    def show_histograms(self, split):

        data_split = getattr(self, split, None)
        if data_split is None:
            raise ValueError('Invalid data split')

        plot_hist_marginals(data_split.x)
        plt.show()
        
    def minmax_normalized(self):
        def normalize(x):
            return (x - x.min()) / (x.max() - x.min())
        self.data = self.data.apply(normalize)
    
    def spilt_train_valid(self, valid_rate):
        N_valid = int(valid_rate * self.data.shape[0])
        data_valid = self.data[-N_valid:]
        data = self.data[0:-N_valid]
        
        return data, data_valid    
    
def Discretize(col, data=None):
    """Transforms data values into integers using a Column's vocab.

    Args:
        col: the Column.
        data: list-like data to be discretized.  If None, defaults to col.data.

    Returns:
        col_data: discretized version; an np.ndarray of type np.int32.
    """
    # pd.Categorical() does not allow categories be passed in an array
    # containing np.nan.  It makes it a special case to return code -1
    # for NaN values.

    if data is None:
        data = col.data
    
    # pd.isnull returns true for both np.nan and np.datetime64('NaT').
    isnan = pd.isnull(col.all_distinct_values)
    if isnan.any():
        # We always add nan or nat to the beginning.
        assert isnan.sum() == 1, isnan
        assert isnan[0], isnan

        dvs = col.all_distinct_values[1:]
        bin_ids = pd.Categorical(data, categories=dvs).codes
        assert len(bin_ids) == len(data)

        # Since nan/nat bin_id is supposed to be 0 but pandas returns -1, just
        # add 1 to everybody
        bin_ids = bin_ids + 1
    else:
        # This column has no nan or nat values.
        dvs = col.all_distinct_values
        bin_ids = pd.Categorical(data, categories=dvs).codes
        # print(f'dvs : {len(dvs)} bin_ids : {len(np.unique(bin_ids))}')
        assert len(bin_ids) == len(data), (len(bin_ids), len(data))

    assert (bin_ids >= 0).all(), (col, data, bin_ids)
    return bin_ids

def One_hot(tuples_np):
    onehot_datas = []
    for i in range(tuples_np.shape[1]):
        onehot_data = pd.get_dummies(tuples_np[:, i]).values
        onehot_datas.append(onehot_data)
    
    return np.concatenate(onehot_datas, 1)

def Mask_row(row, perc_miss):
    def generate_mask(row):
        n = len(row)
        mask = np.zeros(n, dtype=bool)
        indices = np.random.choice(n, size=int(n * perc_miss), replace=False)
        mask[indices] = True
        mask[row] = False # the actual values should be observed
        return mask
    
    vectorized_generate_mask = np.vectorize(generate_mask, signature='(n)->(n)')
    return vectorized_generate_mask(row)

In [None]:
from data_tabular import CsvTable, power
from train_helpers_tabular import set_up_hyperparams

H, logprint = set_up_hyperparams()

In [None]:
original_data = power(H)
# print(f'{original_data.data[0:1]}')

In [None]:
original_dataset = TableDataset(
    data=original_data.data
)
original_dataset.minmax_normalized()
train_data, valid_data = original_dataset.spilt_train_valid(0.1)

train_dataset = TableDataset(
    data=train_data
)

eval_dataset = TableDataset(
    data=valid_data
)

In [None]:
train_dataset[0]['data']

In [None]:
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(ResBlock, self).__init__()
        assert out_channels==in_channels 
        self.conv1 = nn.Sequential(
                        nn.BatchNorm1d(in_channels),
                        nn.ReLU(),
                        nn.Conv1d(in_channels, middle_channels, kernel_size = 3, stride = 1, padding = 1))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm1d(middle_channels),
                        nn.ReLU(),
                        nn.Conv1d(middle_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))   
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        return out

class Encoder_Conv_VAE_Power(BaseEncoder):
    def __init__(self, args):
        BaseEncoder.__init__(self)

        self.input_dim = (1, 7)
        self.latent_dim = args.latent_dim
        self.n_channels = 1

        self.input_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(self.n_channels, 16, 3, 1, 1),
        )
        self.residual_layers = nn.Sequential(
            *[ResBlock(16, 4, 16) for _ in range(2)]
        )

        self.embedding = nn.Linear(16 * 7, args.latent_dim)
        self.log_var = nn.Linear(16 * 7, args.latent_dim)

    def forward(self, x: torch.Tensor):
        h1 = self.input_layers(x)
        h2 = self.residual_layers(h1)
        h3 = h2.reshape(x.shape[0], -1)
        output = ModelOutput(
            embedding=self.embedding(h3),
            log_covariance=self.log_var(h3)
        )
        return output
    
class Decoder_Conv_VAE_Power(BaseDecoder):
    def __init__(self, args):
        BaseDecoder.__init__(self)
        self.input_dim = (1, 7)
        self.latent_dim = args.latent_dim
        self.n_channels = 1

        self.fc = nn.Linear(args.latent_dim, 16 * 7)
        self.residual_layers = nn.Sequential(
            *[ResBlock(16, 4, 16) for _ in range(2)]
        )
        self.output_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(16, self.n_channels, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor):
        h1 = self.fc(z).reshape(z.shape[0], 16, 7)
        h2 = self.residual_layers(h1)
        h3 = self.output_layers(h2)
        output = ModelOutput(reconstruction=h3)

        return output

In [None]:
from pythae.models import VAEConfig

model_config = VAEConfig(
    input_dim=(1, 7),
    latent_dim=16
    )

encoder = Encoder_Conv_VAE_Power(model_config)
decoder= Decoder_Conv_VAE_Power(model_config)

from typing import Optional
from pythae.models import VAE
import torch.nn.functional as F

class VAE_CE(VAE):
    def __init__(self,
        model_config: VAEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,):
        
        super().__init__(model_config, encoder, decoder)

    def get_nll_batch(self, data, n_samples=1):
        """
        Function computed the estimate negative log-likelihood of the model. It uses importance
        sampling method with the approximate posterior distribution. This may take a while.

        Args:
            data (torch.Tensor): The input data from which the log-likelihood should be estimated.
                Data must be of shape [Batch x n_channels x ...]
            n_samples (int): The number of importance samples to use for estimation
        """
        log_p_x = []

        for i in range(n_samples):

            encoder_output = self.encoder(data)
            mu, log_var = encoder_output.embedding, encoder_output.log_covariance

            std = torch.exp(0.5 * log_var)
            z, _ = self._sample_gauss(mu, std)

            log_q_z_given_x = -0.5 * (
                log_var + (z - mu) ** 2 / torch.exp(log_var)
            ).sum(dim=-1, keepdim=True) # shape [B x 1]
            log_p_z = -0.5 * (z ** 2).sum(dim=-1, keepdim=True) # shape [B x 1]
            
            # print(f'log_q_z_given_x: {log_q_z_given_x.shape}')
            # print(f'log_p_z: {log_p_z.shape}')

            recon_x = self.decoder(z)["reconstruction"]

            if self.model_config.reconstruction_loss == "mse":

                log_p_x_given_z = -0.5 * torch.log(F.mse_loss(
                    recon_x.reshape(data.shape[0], -1),
                    data.reshape(data.shape[0], -1),
                    reduction="none",
                )).sum(dim=-1) - torch.tensor(
                    [np.prod(self.input_dim) / 2 * np.log(np.pi * np.e * 2)]
                ).to(
                    data.device
                )  # decoding distribution is assumed unit variance  N(mu, I)

            elif self.model_config.reconstruction_loss == "bce":

                log_p_x_given_z = -F.binary_cross_entropy(
                    recon_x.reshape(data.shape[0], -1),
                    data.reshape(data.shape[0], -1),
                    reduction="none",
                ).sum(dim=-1)

            log_p_x.append(
                    log_p_x_given_z + log_p_z - log_q_z_given_x
                )  # log(2*pi) simplifies

        log_p_x = torch.cat(log_p_x, dim=1) # [B x n_samples]
        # print(f'log_p_x: {log_p_x.shape}')
        
        return torch.logsumexp(log_p_x, dim=1) - np.log(n_samples)
        
model = VAE_CE(
    model_config=model_config,
    encoder=encoder,
    decoder=decoder
).cuda()

In [None]:
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines import TrainingPipeline

training_config = BaseTrainerConfig(
    output_dir='./saved_models/power_test/test_-1',
    learning_rate=1e-3,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    steps_saving=None,
    num_epochs=10)

pipeline = TrainingPipeline(
    model=model,
    training_config=training_config)

pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
)

In [None]:
# from pythae.models import AutoModel
# import os
# last_training = sorted(os.listdir('./saved_models/power_test/test_-1'))[-1]
# model_rec = AutoModel.load_from_folder(os.path.join('./saved_models/power_test/test_-1', last_training, 'final_model'))
# model.__class__ = VAE_CE

In [None]:
import math
import numpy as np
from my_utils import Card, ErrorMetric, GenerateQuery, estimate_probabilities, make_points

def eval_power(model):
    table = original_data
    rng = np.random.RandomState(1234)
    count = 0
    n_rows = table.data.shape[0] - int(table.data.shape[0] * 0.1) 
    # n_rows = table.data.shape[0] 
    qerrors = []

    for i in range(3000):
            
        cols, ops, vals = GenerateQuery(table.columns, rng, table.data)
        true_card = Card(table.data[:int(table.data.shape[0] * 0.9)], cols, ops, vals)
        # true_card = Card(table.data, cols, ops, vals)
        predicates = []
        for c, o, v in zip(cols, ops, vals):
            predicates.append((c, o, v))
                
        left_bounds = {}
        right_bounds = {}
        
        for idx, attr in enumerate(table.columns):
            col_name = attr.name   
            left_bounds[col_name] = table.mins[idx]
            right_bounds[col_name] = table.maxs[idx] 
                
        table_stats = (table.columns, table.name_to_index, right_bounds, left_bounds)
        
        # print(predicates)
        integration_domain = make_points(table_stats, predicates, table.bias, None, 'minmax')
        
        # print(integration_domain)
        def pdf(x):
            nll = model.get_nll_batch(x)
            
            # print(f'nll: {nll.shape}')
            return nll
        
        prob = estimate_probabilities(pdf, integration_domain, len(table.columns)).item()
        # print(f'prob: {prob}')
        
        if  math.isnan(prob):
            est_card = 1
            count += 1
        elif  math.isinf(prob):   
            est_card = n_rows if prob > 0 else 1
            count += 1
        else:
            est_card = max(prob * n_rows, 1)
            
            if est_card > n_rows:
                count += 1
                est_card = n_rows
                # print(f'prob {prob} true_card: {true_card}')
            
        qerror = ErrorMetric(est_card, true_card)
        qerrors.append(qerror)
        
        if i % 100 == 0:
            print(f'{i} queries done')
    
    return count, qerrors

In [None]:
count, qerrors = eval_power(model)

In [None]:
print(f'estimation failed times: {count}')
print('test results')
print(f"Median: {np.median(qerrors)}")
print(f"90th percentile: {np.percentile(qerrors, 90)}")
print(f"95th percentile: {np.percentile(qerrors, 95)}")
print(f"99th percentile: {np.percentile(qerrors, 99)}")
print(f"Max: {np.max(qerrors)}")
print(f"Mean: {np.mean(qerrors)}")

256
estimation failed times: 2275
test results
Median: 61.312145072691465
90th percentile: 13137.085710979565
95th percentile: 54245.64705882353
99th percentile: 461327.5539670102
Max: 1844352.0
Mean: 19325.10150461067

16
2层
estimation failed times: 1686
test results
Median: 16.94870361887311
90th percentile: 3575.8930232558123
95th percentile: 14070.845913264851
99th percentile: 154568.74266167908
Max: 1844352.0
Mean: 9211.902219659025
8层
estimation failed times: 2269
test results
Median: 55.58625678119349
90th percentile: 13733.809901449096
95th percentile: 48126.172612388735
99th percentile: 461088.0
Max: 1844352.0
Mean: 18189.851955382997

In [None]:
import torch
import torch.nn as nn

class SE_Block(nn.Module):
    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool1d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1)
        return x * y.expand_as(x)
    
class ResidualBlockA(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlockA, self).__init__()
        assert out_channels == in_channels 
        self.conv1 = nn.Sequential(
                        nn.BatchNorm1d(in_channels),
                        nn.SiLU(),
                        nn.Conv1d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm1d(out_channels),
                        nn.SiLU(),
                        nn.Conv1d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        SE_Block(out_channels))   
        self.out_channels = out_channels
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        return out

class ResidualBlockB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlockB, self).__init__()
        assert out_channels == 2 * in_channels
        self.conv1 = nn.Sequential(
                        nn.BatchNorm1d(in_channels),
                        nn.SiLU(),
                        nn.Conv1d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm1d(out_channels),
                        nn.SiLU(),
                        nn.Conv1d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        SE_Block(out_channels))
        self.factorized_reduction=nn.Conv1d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0)
        self.out_channels = out_channels
        self.swish = SE_Block(in_channels)
    def forward(self, x):
        # print('residualB')
        # print('x shape',x.shape)
        residual = self.swish(x)
        residual = self.factorized_reduction(residual)
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        # print('out shape',out.shape)
        return out

class Binary_Classifier(nn.Module):
    def __init__(self, blockA, blockB, inchannel):
        super(Binary_Classifier, self).__init__()
        self.inplanes = inchannel
        self.conv1 = nn.Sequential(
                        nn.Conv1d(self.inplanes, self.inplanes, kernel_size=3, stride=1, padding=1),
                        nn.ReLU())
        self.layer0 = self._make_layer(blockA, self.inplanes, self.inplanes, 3)
        self.layer1 = self._make_layer(blockB, self.inplanes, self.inplanes * 2, 1)
        self.layer2 = self._make_layer(blockA, self.inplanes * 2, self.inplanes * 2, 3)
        self.layer3 = self._make_layer(blockB, self.inplanes * 2, self.inplanes * 4, 1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(self.inplanes * 4, 1)
        self.sigmoid = nn.Sigmoid()
        
    def _make_layer(self, block, in_channels, out_channels, blocks):
        layers = []
        for i in range(blocks):
            layers.append(block(in_channels, out_channels))
        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        # x=self.linear1(x)
        # x=x.reshape(x.shape[0],self.inplanes,16,16)

        x = self.conv1(x)
        # print(x.shape)
        x = self.layer0(x)
        # print(x.shape)
        x = self.layer1(x)
        # print('layer1 out',x.shape)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x
        
class NCP_VAE(nn.Module):
    def __init__(self, sampler, bin_classifier):
        super(NCP_VAE, self).__init__()
        self.sampler = sampler
        self.bin_classifier = bin_classifier
        self.bin_classifier.eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @torch.no_grad()
    def get_reweight(self, z):
        x = self.sampler.model.decoder(z)["reconstruction"].detach()
        D = self.bin_classifier(x)
        r = D / (1 - D)
        return r

    @torch.no_grad()
    def sample_from_p(self, num_samples):
        """sample from gaussian mixture"""
        try:
            z = torch.tensor(self.sampler.gmm.sample(num_samples)[0]).type(torch.float).to(self.device)
        except Exception:
            z = torch.randn(num_samples, self.sampler.model.latent_dim).to(self.device)
        return z

    @torch.no_grad()
    def sample_from_q(self, input):
        z = self.sampler.model(input).z.detach().to(device)
        return z

    @torch.no_grad()
    def _sample_from_ncp(self, num_samples):
        z = self.sample_from_p(num_samples)
        r = self.get_reweight(z).reshape(-1,)
        r /= r.sum()
        idx = torch.multinomial(r, num_samples)
        return z[idx]

    @torch.no_grad()
    def sample(self, num_samples):
        z = self._sample_from_ncp(num_samples)
        return self.sampler.model.decoder(z)["reconstruction"].detach()
    
    @torch.no_grad()
    def get_nll_batch(self, data, n_samples=1):
        log_p_x = []

        for i in range(n_samples):

            encoder_output = self.sampler.model.encoder(data)
            mu, log_var = encoder_output.embedding, encoder_output.log_covariance

            std = torch.exp(0.5 * log_var)
            z, _ = self.sampler.model._sample_gauss(mu, std)
            r = self.get_reweight(z).reshape(-1,)
            r /= r.sum()
            idx = torch.multinomial(r, z.shape[0])
            z = z[idx]

            log_q_z_given_x = -0.5 * (
                log_var + (z - mu) ** 2 / torch.exp(log_var)
            ).sum(dim=-1, keepdim=True) # shape [B x 1]
            log_p_z = -0.5 * (z ** 2).sum(dim=-1, keepdim=True) # shape [B x 1]
            
            # print(f'log_q_z_given_x: {log_q_z_given_x.shape}')
            # print(f'log_p_z: {log_p_z.shape}')

            recon_x = self.sampler.model.decoder(z)["reconstruction"]

            if self.sampler.model.model_config.reconstruction_loss == "mse":

                log_p_x_given_z = -0.5 * torch.log(F.mse_loss(
                    recon_x.reshape(data.shape[0], -1),
                    data.reshape(data.shape[0], -1),
                    reduction="none",
                )).sum(dim=-1) - torch.tensor(
                    [np.prod(self.sampler.model.input_dim) / 2 * np.log(np.pi * np.e * 2)]
                ).to(
                    data.device
                )  # decoding distribution is assumed unit variance  N(mu, I)

            elif self.sampler.model.model_config.reconstruction_loss == "bce":

                log_p_x_given_z = -F.binary_cross_entropy(
                    recon_x.reshape(data.shape[0], -1),
                    data.reshape(data.shape[0], -1),
                    reduction="none",
                ).sum(dim=-1)

            log_p_x.append(
                    log_p_x_given_z + log_p_z - log_q_z_given_x
                )  # log(2*pi) simplifies

        log_p_x = torch.cat(log_p_x, dim=1) # [B x n_samples]
        # print(f'log_p_x: {log_p_x.shape}')
        
        return torch.logsumexp(log_p_x, dim=1) - np.log(n_samples)

In [None]:
def train_ncp(sampler, data_loader, epochs=10, lr=0.001):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    binary_classifier = Binary_Classifier(ResidualBlockA, ResidualBlockB, 1).to(device)
    optimizer = torch.optim.Adam(binary_classifier.parameters(), lr=lr)
    criterion = nn.BCELoss()
    for epoch in range(epochs):
        acc = []
        losses = []
        for batch in data_loader:
            num_samples = len(batch['data'])
            q = sampler.model(batch).recon_x
            p = sampler.sample(num_samples)
            input = torch.cat((q, p), dim=0)
            labels = torch.cat((torch.ones(num_samples), torch.zeros(num_samples))).to(device)
            out = binary_classifier(input).squeeze()
            loss = criterion(out, labels) 
            acc.append(np.mean(np.array(out.cpu() > 0.5) == np.array(labels.cpu())))
            optimizer.zero_grad()
            loss.backward()
            losses.append(loss.cpu().item())
            optimizer.step()
        print('Eopoch :{} | Acc : {:.3f} | Loss : {:.3f} '.format(epoch,np.mean(acc),np.mean(losses)))

    ncp_vae = NCP_VAE(sampler,binary_classifier)
    
    return ncp_vae

In [None]:
from torch.utils.data import DataLoader

labels = torch.ones(train_dataset.tuples.shape[0]).cuda()
dataset = BaseDataset(train_dataset.tuples.cuda(), labels)
train_loader = DataLoader(dataset=dataset, batch_size=1024, shuffle=False)

from pythae.samplers import NormalSampler

normal_samper = NormalSampler(model=model)
ncp_vae_sampler_normal = train_ncp(normal_samper, train_loader, lr=0.01)

In [None]:
count, qerrors = eval_power(ncp_vae_sampler_normal)

In [None]:
print(f'estimation failed times: {count}')
print('test results')
print(f"Median: {np.median(qerrors)}")
print(f"90th percentile: {np.percentile(qerrors, 90)}")
print(f"95th percentile: {np.percentile(qerrors, 95)}")
print(f"99th percentile: {np.percentile(qerrors, 99)}")
print(f"Max: {np.max(qerrors)}")
print(f"Mean: {np.mean(qerrors)}")

### CE based on Imputation

In [2]:
import torch
import numpy as np
import os
import data_tabular

device = "cuda" if torch.cuda.is_available() else "cpu"

def power():
    csv_file = os.path.join('../dataset/', 'household_power_consumption.txt')
    cols = ['Global_active_power','Global_reactive_power','Voltage','Global_intensity','Sub_metering_1','Sub_metering_2','Sub_metering_3']
    trX = data_tabular.CsvTable('power', csv_file, cols, sep=';', na_values=[' ', '?'], header=0, dtype=np.float64)  
    # print(trX.data.shape)    

    return trX

In [3]:
original_data = power()
# print(original_data.data.isna().any(axis=1))
input_bins = [c.DistributionSize() for c in original_data.columns]
# print(input_bins)
table = TableDataset(original_data)

In [None]:
# calss TableDataset's verification

# print(table[10000]['data'])
# print(table[10000]['data_one_hot'].shape)
# j = 0
# for i in range(len(input_bins)):
#     j += 0 if i == 0 else input_bins[int(i) - 1]
#     print(table[10000]['data_one_hot'][int(table[10000]['data'][i] + j)])
# print(table[10000]['masks'].shape)    
# print(table[10000]['masks'].sum())

In [4]:
from typing import Tuple, Union
from pydantic.dataclasses import dataclass
from pythae.config import BaseConfig
from typing_extensions import Literal

@dataclass
class MissIWAEConfig(BaseConfig):
    """MissIWAE model config class.

    Parameters:
        input_dim (tuple): The input_data dimension.
        latent_dim (int): The latent space dimension. Default: None.
        reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
        number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10
    """

    reconstruction_loss: Literal["bce", "mse", "obsce"] = "mse"
    number_samples: int = 100
    input_dim: int = 1
    output_dim: int = 1
    latent_dim: int = 10
    hidden_dim: int = 256
    input_bins: Union[Tuple[int, ...], None] = None
    perc_miss: float = 0.
    embedding_dim: int = 64
    

### MIWAE_FC

In [5]:
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
import torch.nn as nn
import torch
import torch.nn.functional as F


class ResBlock_FC(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(ResBlock_FC, self).__init__()
        assert out_channels==in_channels 
        self.linear_layer1 = nn.Sequential(
                        nn.BatchNorm1d(in_channels),
                        nn.ReLU(inplace=True),
                        nn.Linear(in_channels, middle_channels, bias=False))
        self.linear_layer2 = nn.Sequential(
                        nn.BatchNorm1d(middle_channels),
                        nn.ReLU(inplace=True),
                        nn.Linear(middle_channels, out_channels, bias=False))   
        
    def forward(self, x):
        residual = x
        out = self.linear_layer1(x)
        out = self.linear_layer2(out)
        out += residual
        return out
    
class Encoder_FC_MissIWAE_Power(BaseEncoder):
    def __init__(self, args):
        BaseEncoder.__init__(self)

        self.input_dim = args.input_dim
        self.hidden_dim = args.hidden_dim
        self.latent_dim = args.latent_dim
        self.embedding_dim = args.embedding_dim
        self.register_buffer('position_ids', torch.arange(self.input_dim) / self.input_dim)

        self.input_layer = nn.Sequential(
            nn.Linear(self.input_dim, self.embedding_dim, bias=False),
            nn.Linear(self.embedding_dim, self.hidden_dim, bias=False),
        )
        self.residual_layers = nn.Sequential(
            *[ResBlock_FC(self.hidden_dim, int(self.hidden_dim / 4), self.hidden_dim) for _ in range(2)]
        )

        self.posterior = nn.Linear(self.hidden_dim, 2 * self.latent_dim)

    def forward(self, x: torch.Tensor):
        
        out = self.input_layer(x + self.position_ids) # (bs, hidden_dim)
        out = self.residual_layers(out) # (bs, hidden_dim)
        out = self.posterior(out) # (bs, latent_dim * 2)
        embedding, log_covariance = torch.split(out, self.latent_dim, dim=-1)
        output = ModelOutput(
            embedding=embedding,
            log_covariance=log_covariance
        )
        return output
    
class Decoder_FC_MissIWAE_Power(BaseDecoder):
    def __init__(self, args):
        BaseDecoder.__init__(self)
        
        self.output_dim = args.output_dim
        self.hidden_dim = args.hidden_dim
        self.latent_dim = args.latent_dim

        self.prior = nn.Linear(self.latent_dim, self.hidden_dim, bias=False)
        self.residual_layers = nn.Sequential(
            *[ResBlock_FC(self.hidden_dim, int(self.hidden_dim / 4), self.hidden_dim) for _ in range(2)]
        )
        self.output_layer = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.output_dim),
            # nn.Sigmoid()
        )

    def forward(self, z: torch.Tensor):
        out = self.prior(z)
        out = self.residual_layers(out)
        out = self.output_layer(out)
        output = ModelOutput(reconstruction=out)

        return output



### MIWAE_Conv

slower

In [None]:
from pythae.models import VAE
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
from pythae.data.datasets import BaseDataset
import torch.nn as nn
from typing import Optional
import torch
import torch.nn.functional as F

class ResBlock_Conv(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(ResBlock_Conv, self).__init__()
        assert out_channels==in_channels 
        self.conv1 = nn.Sequential(
                        nn.BatchNorm1d(in_channels),
                        nn.ReLU(),
                        nn.Conv1d(in_channels, middle_channels, kernel_size = 3, stride = 1, padding = 1, bias=False))
        self.conv2 = nn.Sequential(
                        nn.BatchNorm1d(middle_channels),
                        nn.ReLU(),
                        nn.Conv1d(middle_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias=False))   
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += residual
        return out

class Encoder_Conv_MissIWAE_Power(BaseEncoder):
    def __init__(self, args):
        BaseEncoder.__init__(self)

        self.n_channels = 1
        self.input_dim = args.input_dim
        self.hidden_dim = args.hidden_dim
        self.latent_dim = args.latent_dim

        self.input_layers = nn.Linear(self.input_dim, self.hidden_dim, bias=False)
        self.fc2conv = nn.Conv1d(self.n_channels, self.hidden_dim, self.hidden_dim, bias=False)
        self.residual_layers = nn.Sequential(
            *[ResBlock_Conv(self.hidden_dim, int(self.hidden_dim / 4), self.hidden_dim) for _ in range(2)]
        )
        self.posterior = nn.Conv1d(self.hidden_dim, self.latent_dim * 2, 1, 1, 0)

    def forward(self, x: torch.Tensor):
        """
        The Encoder_Conv

        Args:
            x (Tensor): (bs, n_features,)

        Returns:
            ModelOutput: An instance of ModelOutput containing all the relevant parameters

        """
        out = self.input_layers(x).unsqueeze(1) # (bs, 1, hidden_dim)
        out = self.fc2conv(out) # (bs, hidden_dim, 1)
        out = self.residual_layers(out) # (bs, hidden_dim, 1)
        out = self.posterior(out) # (bs, latent_dim * 2, 1)
        embedding, log_covariance = torch.split(out, self.latent_dim, dim=1)
        output = ModelOutput(
            embedding=embedding.reshape(x.shape[0], -1),
            log_covariance=log_covariance.reshape(x.shape[0], -1)
        )
        return output
    
class Decoder_Conv_MissIWAE_Power(BaseDecoder):
    def __init__(self, args):
        BaseDecoder.__init__(self)
        
        self.n_channels = 1
        self.output_dim = args.output_dim
        self.hidden_dim = args.hidden_dim
        self.latent_dim = args.latent_dim

        self.prior = nn.Conv1d(self.latent_dim, self.hidden_dim, 1, 1, 0, bias=False)
        self.residual_layers = nn.Sequential(
            *[ResBlock_Conv(self.hidden_dim, int(self.hidden_dim / 4), self.hidden_dim) for _ in range(2)]
        )
        self.conv2fc = nn.ConvTranspose1d(self.hidden_dim, self.n_channels, self.hidden_dim)
        self.output_layer = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.output_dim),
            # nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor):
        """
        The Decoder_Conv

        Args:
            z (Tensor): (bs * n_samples, latent_dim,)

        Returns:
            ModelOutput: An instance of ModelOutput containing all the relevant parameters

        """
        
        out = self.prior(z.reshape(z.shape[0], self.latent_dim, -1)) # (bs * n_samples, hidden_dim, 1)
        out = self.residual_layers(out) # (bs * n_samples, hidden_dim, 1)
        out = self.conv2fc(out) # (bs * n_samples, 1, hidden_dim)
        out = self.output_layer(out.reshape(z.shape[0], -1)) # (bs * n_samples, output_dim)
        output = ModelOutput(reconstruction=out)

        return output

### MissIWAE

In [None]:
import time
from pythae.models import VAE
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
from pythae.data.datasets import BaseDataset
import torch.nn as nn
from typing import Optional
import torch
import torch.nn.functional as F

class MissIWAE(VAE):
    """
    MissIWAE is based on the Importance Weighted Autoencoder model, maximises a potentially tight lower bound of the log-likelihood of the observed data.

    Args:
        model_config (MissIWAEConfig): The IWAE configuration setting the main
            parameters of the model.

        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module` which
            plays the role of encoder. This argument allows you to use your own neural networks
            architectures if desired. If None is provided, a simple Multi Layer Preception
            (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module` which
            plays the role of decoder. This argument allows you to use your own neural networks
            architectures if desired. If None is provided, a simple Multi Layer Preception
            (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

    .. note::
        
    """

    def __init__(
        self,
        model_config: MissIWAEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):

        VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

        self.model_name = "MissIWAE"
        self.n_samples = model_config.number_samples
        self.mask = None
        self.input_bins = model_config.input_bins
        # print(f'input_bins {self.input_bins}')

    def forward(self, inputs: BaseDataset, **kwargs):
        """
        The VAE model

        Args:
            inputs (BaseDataset): The training dataset with labels

        Returns:
            ModelOutput: An instance of ModelOutput containing all the relevant parameters

        """
        
        x = inputs["data"] # (bs, n_dims)
        # print(f'x {x.device}')
        # xb = self._encode_onehot(x, self.input_bins) # (bs, n_features)
        xb = inputs["data_one_hot"]
        
        # miss_pattern = torch.randperm(xb.numel())[:int(xb.numel() * self.model_config.perc_miss)]
        # x_coords, y_coords = torch.div(miss_pattern, xb.shape[1], rounding_mode='floor'), miss_pattern % xb.shape[1]
        # mask = torch.zeros(xb.shape, dtype=torch.bool, device='cpu') # (bs, n_features)
        # mask[x_coords, y_coords] = 1
        # xbhat = xb.clone().detach()
        mask = inputs["masks"]
        
        # mask[xb.bool()] = 0 # the actual values should be observed
        xb[mask.bool()] = 0.5 # in xbhat, the missing values are represented by 0.5
        
        # self.mask = mask.unsqueeze(-1) # (bs, n_features, 1)
        self.mask = mask.unsqueeze(-1).to(device=xb.device)
        
        encoder_output = self.encoder(xb)

        mu, log_var = encoder_output.embedding, encoder_output.log_covariance

        mu = mu.unsqueeze(1).repeat(1, self.n_samples, 1)
        log_var = log_var.unsqueeze(1).repeat(1, self.n_samples, 1)
        # print(f'mu : {mu.shape}')
        # print(f'log_var : {log_var.shape}')
        std = torch.exp(0.5 * log_var)

        z, _ = self._sample_gauss(mu, std)

        recon_xb = self.decoder(z.reshape(-1, self.latent_dim))[
            "reconstruction"
        ].reshape(x.shape[0], -1, self.n_samples) # (bs, n_features, n_samples)
        # print(f'recon_xb : {recon_xb.shape}')
        
        loss, recon_loss, kld = self.loss_function(recon_xb, xb, mu, log_var, z) # binary_cross_entropy_with_logits
        # loss, recon_loss, kld = self.loss_function(recon_xb, x, mu, log_var, z) # cross_entropy
        
        recon_x = torch.zeros((x.shape[0], self.n_samples, x.shape[1]), device=x.device) # (bs, n_samples, n_dims)
        # for i, coli_dom_size in enumerate(self.input_bins):
        #     # colis = recon_xb[:, :, start : start + coli_dom_size]
        #     probs_is = torch.softmax(self.mask * recon_xb[:, start : start + coli_dom_size, :], 1).mean(-1) # (bs, input_bin in n_features)
        #     # print(f'prodbs_is: {probs_is.shape}')
        #     recon_x[:, :, i] = torch.multinomial(probs_is, num_samples=self.n_samples, replacement=True)  # (bs, num_samples, i in n_dims)
        #     start += coli_dom_size
       
        output = ModelOutput(
            recon_loss=recon_loss,
            reg_loss=kld,
            loss=loss,
            recon_x=recon_x.reshape(x.shape[0], self.n_samples, -1)[:, 0, :].reshape_as(
                x
            ),
            z=z.reshape(x.shape[0], self.n_samples, -1)[:, 0, :].reshape(
                -1, self.latent_dim
            ),
        )

        return output

    def loss_function(self, recon_x, x, mu, log_var, z):
        """
        loss function

        Args:
            recon_x (tensor): (bs, n_features, n_samples)
            x (tensor): (bs, n_features) -> binary_cross_entropy_with_logits or (bs, n_dims) -> cross_entropy
            mu (tensor): (bs, n_samples, latent_dim)
            log_var (tensor): (bs, n_samples, latent_dim)
            z (tensor): (bs, n_samples, latent_dim)

        Returns:
            elbo, recon_loss, kld

        """

        if self.model_config.reconstruction_loss == "mse":

            recon_loss = (
                0.5
                * F.mse_loss(
                    recon_x,
                    x.reshape(recon_x.shape[0], -1)
                    .unsqueeze(1)
                    .repeat(1, self.n_samples, 1),
                    reduction="none",
                ).sum(dim=-1)
            )

        elif self.model_config.reconstruction_loss == "bce":

            recon_loss = F.binary_cross_entropy(
                recon_x,
                x.reshape(recon_x.shape[0], -1)
                .unsqueeze(1)
                .repeat(1, self.n_samples, 1),
                reduction="none",
            ).sum(dim=-1)
            
        elif self.model_config.reconstruction_loss == "obsce":
            
            x = x.reshape(recon_x.shape[0], -1).unsqueeze(-1).repeat(1, 1, self.n_samples) # (bs, n_dims, n_samples) or (bs, n_features, n_samples)
            if x.shape == recon_x.shape:
                # mask ->(bs, n_features, 1) * entropy -> (bs, n_features, n_samples)
                # softmax each attr's feature then binary_cross_entropy
                start = 0
                for i in range(len(self.input_bins)):
                    recon_x[:, start: start + self.input_bins[i], :] = self._gumbel_softmax(recon_x[:, start: start + self.input_bins[i], :], 1).float()
                    start += self.input_bins[i]
                    
                recon_loss = (
                    ~self.mask
                    * F.binary_cross_entropy(
                        recon_x,
                        x,
                        reduction="none",   
                    ).float() 
                ).sum(dim=1) # (bs, n_samples)
                
            else:
                recon_loss = torch.zeros(recon_x.size()[0], self.n_samples, device=recon_x.device) # (bs, n_samples)
                start = 0
                recon_x = ~self.mask * recon_x
                for i in range(len(self.input_bins)):
                    xb = self._gumbel_softmax(recon_x[:, start: start + self.input_bins[i], :], dim=1).float()
                    recon_loss += F.cross_entropy(
                       xb,
                        x[:, i, :].long(),
                        reduction="none",
                    )
                    start += self.input_bins[i]
            
        log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / (log_var.exp() + 10e-7))).float().sum(dim=-1) # (bs, n_samples)
        log_p_z = -0.5 * (z ** 2).float().sum(dim=-1) # (bs, n_samples)

        KLD = -(log_p_z - log_q_z)

        log_w = -(recon_loss + KLD).float() # (bs, n_samples)

        # log_w_minus_max = log_w - log_w.max(1, keepdim=True)[0]
        # w = log_w_minus_max.exp()
        # w_tilde = (w / w.sum(axis=1, keepdim=True)).detach()
        w_tilde = F.log_softmax(log_w, dim=1).exp().detach()

        return (
            -(w_tilde * log_w).sum(1).mean(),
            recon_loss.mean(),
            KLD.mean(),
        )

    def _sample_gauss(self, mu, std):
        # Reparametrization trick
        # Sample N(0, I)
        eps = torch.randn_like(std)
        return mu + eps * std, eps
    
    def _gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
        gumbels = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        )  # ~Gumbel(0,1)
        gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
        y_soft = gumbels.softmax(dim)

        if hard:
            # Straight through.
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            # Reparametrization trick.
            ret = y_soft
        return ret

    def _encode_onehot(self, data, input_bins):
        '''
        Args:
            data: tensor.
        '''
        # print(f'data : {data.keys}')
        bs = data.size()[0]
        y_onehots = []
        data = data.long()
        # print(f'input_bins : {input_bins}')
        for i, coli_dom_size in enumerate(input_bins):
            if coli_dom_size <= 2:
                y_onehots.append(data[:, i].view(-1, 1).float())
            else:
                y_onehot = torch.zeros(bs, coli_dom_size, device=data.device)
                y_onehot.scatter_(1, data[:, i].view(-1, 1), 1)
                y_onehots.append(y_onehot)
        # [bs, sum(dist size)]
        return torch.cat(y_onehots, 1)    
    
    def probe(self, x, mask, n_samples, **kwargs):
        """
        probe

        Args:
            x (tensor): (bs, n_features) 
            mask (tensor): (bs, n_features)

        Returns:
            elbo, recon_loss, kld

        """

        # print(f'x : {x.device}')
        encoder_output = self.encoder(x)

        mu, log_var = encoder_output.embedding, encoder_output.log_covariance

        mu = mu.unsqueeze(1).repeat(1, n_samples, 1)
        log_var = log_var.unsqueeze(1).repeat(1, n_samples, 1)

        std = torch.exp(0.5 * log_var)

        z, _ = self._sample_gauss(mu, std)

        recon_x = self.decoder(z.reshape(-1, self.latent_dim))[
            "reconstruction"
        ].reshape(x.shape[0], n_samples, -1) # (bs, n_samples, n_features)
        
        start = 0
        for i in range(len(self.input_bins)):
            recon_x[:, start: start + self.input_bins[i], :] = F.softmax(recon_x[:, start: start + self.input_bins[i], :], 1).float()
            start += self.input_bins[i]
            
        recon_loss = (
            ~mask
            * F.binary_cross_entropy(
                recon_x,
                x.reshape(recon_x.shape[0], -1).unsqueeze(1).repeat(1, n_samples, 1),
                reduction="none",   
            ).float() 
        ).sum(dim=-1) # (bs, n_samples)
        
        # recon_loss = (
        #             ~mask
        #             * F.binary_cross_entropy_with_logits(
        #                 recon_x,
        #                 x.reshape(recon_x.shape[0], -1)
        #                 .unsqueeze(1)
        #                 .repeat(1, n_samples, 1),
        #                 reduction="none",
        #             )
        #         ).sum(dim=-1)
        log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / (log_var.exp() + 10e-8))).sum(dim=-1) # (bs, n_samples)
        log_p_z = -0.5 * (z ** 2).sum(dim=-1) # (bs, n_samples)
        
        imp_weights = F.softmax(recon_loss + log_p_z - log_q_z, 1).reshape(-1, recon_x.shape[0]) # (n_samples, bs)
        recon_x = recon_x.reshape(n_samples, recon_x.shape[0], -1) # (n_samples, bs, n_features)
        
        xm = torch.einsum('ki,kij->ij', imp_weights, recon_x) # (bs, n_features)
        # print(f'xm : {xm.shape}')
        recon_probe = torch.ones(recon_x.size()[0], 1, device=recon_x.device) # (bs, 1)
        start = 0
        for i in range(len(self.model_config.input_bins)):
            probs_i = F.softmax(xm[:, start : start + self.model_config.input_bins[i]], 1) * mask[:, start : start + self.model_config.input_bins[i]]
            probs_i = probs_i.sum(dim=1)
            # print(f'probs_i : {probs_i.shape}')
            recon_probe *= probs_i
            start += self.model_config.input_bins[i]
            
        return recon_probe.squeeze(-1)
        

In [13]:
model_config = MissIWAEConfig(
    input_dim = sum(input_bins),
    latent_dim = 16,
    output_dim = sum(input_bins),
    hidden_dim = 256,
    input_bins = tuple(input_bins),
    perc_miss = 0.5,
    reconstruction_loss = "obsce",
    number_samples = 50
    )

# print(f'input_bins : {tuple(input_bins)}')

encoder = Encoder_FC_MissIWAE_Power(model_config)
decoder = Decoder_FC_MissIWAE_Power(model_config)
'''
encoder = Encoder_Conv_MissIWAE_Power(model_config)
decoder = Decoder_Conv_MissIWAE_Power(model_config)
'''

device = "cuda" if torch.cuda.is_available() else "cpu"

model = MissIWAE(
    model_config = model_config,
    encoder = encoder,
    decoder = decoder
).cuda()

In [14]:
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines import TrainingPipeline
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

def linear_warmup(warmup_iters):
    def f(iteration):
        return 1.0 if iteration > warmup_iters else iteration / warmup_iters
    return f

''''''
training_config = BaseTrainerConfig(
    output_dir = './saved_models/power_test/imputated_ce',
    learning_rate = 1.5e-3,
    per_device_train_batch_size = 1024,
    per_device_eval_batch_size = 1024,
    steps_saving = None,
    num_epochs = 10,
    train_dataloader_num_workers = 4,
    eval_dataloader_num_workers = 4,
    optimizer_cls = "AdamW",
    optimizer_params = {
        "betas" : (0.99, 0.999),
        },
    # scheduler_cls = "OneCycleLR",
    # scheduler_params = {
    #     "pct_start" : 0.2,
    #     "max_lr" : 1e-3,
    #     "total_steps" : 20,
    #     },
    # amp = True # binary_cross_entropy can not support now
)

pipeline = TrainingPipeline(
    model = model,
    training_config = training_config
)

from pythae.trainers.training_callbacks import WandbCallback

callbacks = []
wandb_cb = WandbCallback()
wandb_cb.setup(
    training_config = training_config,
    model_config = model_config,
    project_name = "vae-ce",
    entity_name = "spice-neu-edu-cn"
)
callbacks.append(wandb_cb)

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# with torch.autograd.detect_anomaly():
#     pipeline(
#         train_data=table,
#         # eval_data=table,
#         callbacks=callbacks,
#     )
    
pipeline(
        train_data=table,
        # eval_data=table,
        callbacks=callbacks,
    )

'''
import torch.profiler as profiler
with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA],  # 分析 CPU 和 CUDA 活动
    schedule=torch.profiler.schedule(
        wait=1,  # 前1步不采样
        warmup=1,  # 第2步作为热身，不计入结果
        active=3,  # 采集后面3步的性能数据
        repeat=2),  # 重复2轮
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./runs'),  # 保存日志以供 TensorBoard 可视化
    record_shapes=True,  # 记录输入张量的形状
    profile_memory=True,  # 分析内存分配
    with_stack=True  # 记录操作的调用堆栈信息
) as profiler:
    pipeline(
        train_data=table,
        # eval_data=table,
    )
'''

Checking train dataset...
Using Base Trainer

! No eval dataset provided ! -> keeping best model on train.



ModelError: Error when calling forward method from model. Potential issues: 
 - Wrong model architecture -> check encoder, decoder and metric architecture if you provide yours 
 - The data input dimension provided is wrong -> when no encoder, decoder or metric provided, a network is built automatically but requires the shape of the flatten input data.
Exception raised: <class 'TypeError'> with message: empty_like(): argument 'input' (position 1) must be Tensor, not MissIWAE

#### result
    MVAE
    15 epoch
    --------------------------------------------------------------------------
    Train loss: 33.732
    --------------------------------------------------------------------------
    Training ended!

    2 layers 100 samples
    10 epoch 
    --------------------------------------------------------------------------
    Train loss: 33.5193
    --------------------------------------------------------------------------
    30 epoch 
    --------------------------------------------------------------------------
    Train loss: 33.3101
    --------------------------------------------------------------------------
    Training ended!

    2 layers 50 samples
    10 epoch
    --------------------------------------------------------------------------
    Train loss: 33.4454
    --------------------------------------------------------------------------
    15 epoch
    --------------------------------------------------------------------------
    Train loss: 33.2469
    --------------------------------------------------------------------------
    40 epoch
    --------------------------------------------------------------------------
    Train loss: 33.1056
    --------------------------------------------------------------------------

    4 layers 50 samples
    10 epoch
    --------------------------------------------------------------------------
    Train loss: 33.5834
    --------------------------------------------------------------------------
    15 epoch
    --------------------------------------------------------------------------
    Train loss: 33.3808
    --------------------------------------------------------------------------
    30 epoch 
    --------------------------------------------------------------------------
    Train loss: 33.1588
    --------------------------------------------------------------------------

    4 layers 100 samples
    30 epoch 
    --------------------------------------------------------------------------
    Train loss: 33.16
    --------------------------------------------------------------------------
    Training ended!


In [None]:
import json
import os
from typing import Any, Dict

from pydantic import ValidationError

from pythae.models.base.base_utils import CPU_Unpickler

def from_dict(config_dict: Dict[str, Any]) -> "BaseConfig":
        """Creates a :class:`~pythae.config.BaseConfig` instance from a dictionnary

        Args:
            config_dict (dict): The Python dictionnary containing all the parameters

        Returns:
            :class:`BaseConfig`: The created instance
        """
        try:
            config = MissIWAEConfig(**config_dict)
        except (ValidationError, TypeError) as e:
            raise e
        return config

def dict_from_json(json_path: Union[str, os.PathLike]) -> Dict[str, Any]:
        try:
            with open(json_path) as f:
                try:
                    config_dict = json.load(f)
                    return config_dict

                except (TypeError, json.JSONDecodeError) as e:
                    raise TypeError(
                        f"File {json_path} not loadable. Maybe not json ? \n"
                        f"Catch Exception {type(e)} with message: " + str(e)
                    ) from e

        except FileNotFoundError:
            raise FileNotFoundError(
                f"Config file not found. Please check path '{json_path}'"
            )

def from_json_file(json_path: str) -> "BaseConfig":
        """Creates a :class:`~pythae.config.BaseConfig` instance from a JSON config file

        Args:
            json_path (str): The path to the json file containing all the parameters

        Returns:
            :class:`BaseConfig`: The created instance
        """
        config_dict = dict_from_json(json_path)

        config_name = config_dict.pop("name")

        return from_dict(config_dict)

def load_model_config_from_folder(dir_path):
        file_list = os.listdir(dir_path)

        if "model_config.json" not in file_list:
            raise FileNotFoundError(
                f"Missing model config file ('model_config.json') in"
                f"{dir_path}... Cannot perform model building."
            )

        path_to_model_config = os.path.join(dir_path, "model_config.json")
        model_config = from_json_file(path_to_model_config)

        return model_config

def load_model_weights_from_folder(dir_path):
        file_list = os.listdir(dir_path)

        if "model.pt" not in file_list:
            raise FileNotFoundError(
                f"Missing model weights file ('model.pt') file in"
                f"{dir_path}... Cannot perform model building."
            )

        path_to_model_weights = os.path.join(dir_path, "model.pt")

        try:
            model_weights = torch.load(path_to_model_weights, map_location="cpu")

        except RuntimeError:
            RuntimeError(
                "Enable to load model weights. Ensure they are saves in a '.pt' format."
            )

        if "model_state_dict" not in model_weights.keys():
            raise KeyError(
                "Model state dict is not available in 'model.pt' file. Got keys:"
                f"{model_weights.keys()}"
            )

        model_weights = model_weights["model_state_dict"]

        return model_weights
    
def load_custom_encoder_from_folder(dir_path):

        file_list = os.listdir(dir_path)

        if "encoder.pkl" not in file_list:
            raise FileNotFoundError(
                f"Missing encoder pkl file ('encoder.pkl') in"
                f"{dir_path}... This file is needed to rebuild custom encoders."
                " Cannot perform model building."
            )

        else:
            with open(os.path.join(dir_path, "encoder.pkl"), "rb") as fp:
                encoder = CPU_Unpickler(fp).load()

        return encoder

def load_custom_decoder_from_folder(dir_path):

    file_list = os.listdir(dir_path)

    if "decoder.pkl" not in file_list:
        raise FileNotFoundError(
            f"Missing decoder pkl file ('decoder.pkl') in"
            f"{dir_path}... This file is needed to rebuild custom decoders."
            " Cannot perform model building."
        )

    else:
        with open(os.path.join(dir_path, "decoder.pkl"), "rb") as fp:
            decoder = CPU_Unpickler(fp).load()

    return decoder    

def load_from_folder(model, dir_path):
        """Class method to be used to load the model from a specific folder

        Args:
            dir_path (str): The path where the model should have been be saved.

        .. note::
            This function requires the folder to contain:

            - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided

            **or**

            - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
                ``decoder.pkl``) if a custom encoder (resp. decoder) was provided
        """

        model_config = load_model_config_from_folder(dir_path)
        model_weights = load_model_weights_from_folder(dir_path)
        
        encoder = load_custom_encoder_from_folder(dir_path)
        decoder = load_custom_decoder_from_folder(dir_path)
        
        model = MissIWAE(model_config, encoder=encoder, decoder=decoder)
        model.load_state_dict(model_weights)

        return model

last_training = sorted(os.listdir('./saved_models/power_test/imputated_ce'))[-1]
model = load_from_folder(model, os.path.join('./saved_models/power_test/imputated_ce', last_training, 'final_model')).to(device)
# print(f'model: {model.device}')

In [9]:
# valid
from torch.utils.data import DataLoader

model.eval()
valid_loss = []
with torch.no_grad():
    for inputs in DataLoader(table, batch_size=1024, num_workers=4, pin_memory=True):
        x = inputs["data"].cuda() # (bs, n_dims)
        xb = inputs["data_one_hot"].cuda()
        mask = inputs["masks"].cuda()
        
        xb[mask.bool()] = 0.5 # in xbhat, the missing values are represented by 0.5
        
        # self.mask = mask.unsqueeze(-1) # (bs, n_features, 1)
        model.mask = mask.unsqueeze(-1).to(device=xb.device)
        
        encoder_output = model.encoder(xb)

        mu, log_var = encoder_output.embedding, encoder_output.log_covariance

        mu = mu.unsqueeze(1).repeat(1, model.n_samples, 1)
        log_var = log_var.unsqueeze(1).repeat(1, model.n_samples, 1)
        # print(f'mu : {mu.shape}')
        # print(f'log_var : {log_var.shape}')
        std = torch.exp(0.5 * log_var)

        z, _ = model._sample_gauss(mu, std)

        recon_xb = model.decoder(z.reshape(-1, model.latent_dim))[
            "reconstruction"
        ].reshape(x.shape[0], -1, model.n_samples) # (bs, n_features, n_samples)
        # print(f'recon_xb : {recon_xb.shape}')
        x = x.reshape(recon_xb.shape[0], -1).unsqueeze(-1).repeat(1, 1, model.n_samples)
        recon_loss = torch.zeros(recon_xb.size()[0], model.n_samples, device=recon_xb.device) # (bs, n_samples)
        start = 0
        for i in range(len(model.input_bins)):
            recon_loss += F.cross_entropy(
                recon_xb[:, start: start + model.input_bins[i], :],
                x[:, i, :].long(),
                reduction="none",
            )
            start += model.input_bins[i]
                
        log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / (log_var.exp() + 10e-7))).float().sum(dim=-1) # (bs, n_samples)
        log_p_z = -0.5 * (z ** 2).float().sum(dim=-1) # (bs, n_samples)

        KLD = -(log_p_z - log_q_z)

        log_w = -(recon_loss + KLD).float() # (bs, n_samples)
        w_tilde = F.log_softmax(log_w, dim=1).exp().detach()

        valid_loss.append(-(w_tilde * log_w).sum(1).mean())
    
print(f'valid_loss : {torch.stack(valid_loss).mean().item() / np.log(2)}')

valid_loss : 32.13945187641754


conv valid_loss : 32.83912683113089 \
absolute pe valid_loss : 32.13945187641754

In [10]:
import math
import numpy as np
from my_utils import Card, ErrorMetric, FillInUnqueriedColumns, GenerateQuery

OPS = {
    '>': np.greater,
    '<': np.less,
    '>=': np.greater_equal,
    '<=': np.less_equal,
    '=': np.equal
}

def eval_power_discrete(model):
    rng = np.random.RandomState(1234)
    count = 0
    n_rows = table.tuples.shape[0]
    qerrors = []

    for i in range(3000):
            
        cols, ops, vals = GenerateQuery(original_data.columns, rng, original_data.data)
        true_card = Card(original_data.data, cols, ops, vals)
        # print(cols, ops, vals)
        columns, operators, vals = FillInUnqueriedColumns(original_data, cols, ops, vals)
                
        ncols = len(original_data.columns)
        
        mask_i_list = [None] * ncols  # None means all valid.
        for i in range(ncols):
            
            # Column i.
            op = operators[i]
            if op is not None:
                # There exists a filter.
                mask_i = OPS[op](columns[i].all_distinct_values,
                                  vals[i]).astype(np.float32, copy=False)
            else:
                mask_i = np.ones(len(columns[i].all_distinct_values), dtype=np.float32)
                
            mask_i_list[i] = torch.as_tensor(mask_i, dtype=torch.bool, device=device).view(1, -1)
            # print(f'mask_i: {mask_i_list[i].shape}')
        
        mask = torch.cat(mask_i_list, dim=1)
        # print(f'mask: {mask.shape}')
        xobs = torch.zeros(mask.size(), dtype=torch.float32, device=device)
        # print(f'xobs: {xobs.shape}')
        xobs[mask] = 0.5
        # print(mask)
        # break
        probs = model.probe(xobs, mask, 10).detach().cpu().numpy().tolist()
        prob = probs[0]
        # print(f'prob: {prob}')
        
        
        est_card = max(prob * n_rows, 1)
        
        if est_card > n_rows:
            count += 1
            est_card = n_rows
            # print(f'prob {prob} true_card: {true_card}')
            
        qerror = ErrorMetric(est_card, true_card)
        qerrors.append(qerror)
        
        if i % 100 == 0:
            print(f'{i} queries done')
    
    return count, qerrors

In [11]:
model.eval()
count, qerrors = eval_power_discrete(model)

print(f'estimation failed times: {count}')
print('test results')
print(f"Median: {np.median(qerrors)}")
print(f"90th percentile: {np.percentile(qerrors, 90)}")
print(f"95th percentile: {np.percentile(qerrors, 95)}")
print(f"99th percentile: {np.percentile(qerrors, 99)}")
print(f"Max: {np.max(qerrors)}")
print(f"Mean: {np.mean(qerrors)}")

estimation failed times: 8
test results
Median: 26.057259944782736
90th percentile: 2506.5918770686994
95th percentile: 10210.681563767039
99th percentile: 98248.27999999905
Max: 776636.0
Mean: 4299.89007198345


cross_entropy
10 epoch
estimation failed times: 2
test results
Median: 23.151113000960258
90th percentile: 2567.3999999999996
95th percentile: 10652.93222600411
99th percentile: 109012.11886992674
Max: 535608.0
Mean: 4093.639562321854

binary_cross_entropy
10 epoch 
estimation failed times: 0
test results
Median: 28.51926446395437
90th percentile: 2609.357427105609
95th percentile: 11461.255020661672
99th percentile: 124040.84624133495
Max: 533867.357491531
Mean: 4704.4605649417435
binary_cross_entropy absolute pe
10 epoch
estimation failed times: 8
test results
Median: 26.057259944782736
90th percentile: 2506.5918770686994
95th percentile: 10210.681563767039
99th percentile: 98248.27999999905
Max: 776636.0
Mean: 4299.89007198345
