In [1]:
import os
import warnings

# Run from the directory one level above the notebook (presumably the project
# root, so the local `model.` / `util.` packages imported below resolve).
# Remember where we started and only move once: without this guard, re-running
# the cell would climb one more directory each time and break those imports.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
    os.chdir('../')

# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch/tqdm.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = 'cuda:0'

Thu Jan 18 21:57:05 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   33C    P0    64W / 300W |   8433MiB / 80994MiB |      9%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   59C    P0    81W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Model Parts

In [3]:
from model.main.enc_prior_latent_dec import Model
from model.encoder.conv2d_vae_encoder import Encoder
from model.prior.normal_prior import Prior
from model.latent.vae_latent import Latent
from model.decoder.conv2d_decoder import Decoder

from tensorboardX import SummaryWriter
from util.util import *

### Model Initialization

In [4]:
from easydict import EasyDict

# Hyper-parameters shared by every model component (each is built with **hp).
hp = EasyDict(
    size=64,                         # images are resized to this size in the dataset cell
    in_dim=3,                        # input channels (presumably RGB)
    out_dim=3,                       # output channels (presumably RGB)
    z_dim=128,                       # latent dimensionality
    h_dims=[32, 64, 128, 256, 512],  # hidden widths, presumably one per conv stage — TODO confirm
    activation=F.sigmoid,            # presumably the decoder's output activation — TODO confirm
)

In [5]:
# Ten models with identical architecture, each moved to `device`. Their weights
# are replaced later by the checkpoint, which holds one state dict per model.
model_list = [
    Model(Encoder(**hp), Prior(**hp), Latent(**hp), Decoder(**hp)).to(device)
    for _ in range(10)
]

# Inference only: switch every model to evaluation mode.
for net in model_list:
    net.eval()

print('done')

done


### Dataset

In [6]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

class MyCelebA(CelebA):
    """torchvision ``CelebA`` with its file-integrity check switched off.

    A work-around for problems with torchvision's CelebA dataset class. The
    data is downloaded and extracted manually (link below) and loaded with
    ``download=False``. The integrity-check hook is overridden to always pass,
    presumably because the manually obtained files do not match what
    torchvision expects. TODO confirm.

    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        # Unconditionally report the dataset files as present and valid.
        return True

root = '/data'

# Evaluation preprocessing: square centre crop, resize to hp.size x hp.size,
# then convert to a float tensor in [0, 1].
eval_steps = [
    transforms.CenterCrop(148),
    transforms.Resize(hp.size),
    transforms.ToTensor(),
]
test_transforms = transforms.Compose(eval_steps)

test_dataset = MyCelebA(root=root, split='test', transform=test_transforms, download=False)
# Only the first batch is drawn below, so evaluation covers at most the first
# 10000 test images, in dataset order.
test_loader = DataLoader(dataset=test_dataset, batch_size=10000, shuffle=False)

print('done')

done


In [7]:
def preprocess(batch):
    """Move an ``(images, targets)`` DataLoader batch onto ``device``.

    Returns a dict with key ``'x'`` (images) and key ``'t'`` (targets). This is
    the input dict the models are called with.
    """
    images, targets = batch
    return {'x': images.to(device), 't': targets.to(device)}

# Take just the first (up to 10000-image) test batch as the evaluation set.
first_batch = next(iter(test_loader))
data = preprocess(first_batch)
print(data.keys())

dict_keys(['x', 't'])


### Load Checkpoint

In [8]:
# tqdm.auto picks the Jupyter widget when running in a notebook and falls back
# to the console bar otherwise; `tqdm_notebook` is a deprecated alias.
from tqdm.auto import tqdm

save_path = '/data/scpark/save/lse/train_celeba/train01.17-1/save_200000'

# Warm start: the checkpoint holds one state dict per model, indexed in the
# same order as model_list. Tensors are mapped to CPU first; load_state_dict
# then copies them into each model's parameters (already on `device`).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
models_state_dict = torch.load(save_path, map_location=torch.device('cpu'))['models_state_dict']

# Wrap the sized list (not enumerate(...), which has no len()) so the progress
# bar knows its total instead of showing "0it".
for i, model in enumerate(tqdm(model_list)):
    model.load_state_dict(models_state_dict[i], strict=True)

0it [00:00, ?it/s]

### Reconstruction Loss

In [9]:
# Run every model on the same evaluation batch, collecting its latent codes 'z'
# and its reconstruction loss.
#
# Each model gets its own shallow copy of `data`. The model returns a dict
# holding its outputs ('z', 'recon_loss') and may build it by writing into the
# dict it receives. The previous code rebound `data` to that output, so every
# model after the first was fed the previous model's outputs, and re-running
# the cell started from a mutated `data`. With a copy, every model sees only
# the preprocessed inputs and `data` is left unchanged.
z_list = []
recon_loss_list = []
with torch.no_grad():
    for model in tqdm(model_list):
        # M=1: presumably one latent sample per input — TODO confirm.
        out = model(dict(data), M=1)
        z_list.append(out['z'])
        recon_loss_list.append(out['recon_loss'].item())

print('Reconstruction Loss :', np.mean(recon_loss_list))

  0%|          | 0/10 [00:00<?, ?it/s]

Reconstruction Loss : 0.012373057287186384


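### Per-Dimension Latent Scale

For each model, compute the empirical standard deviation of every latent dimension of `z` across the samples in the batch (dim 0).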

In [10]:
# Per-dimension spread of each model's latent codes: std over dim 0 (the
# samples in the batch), presumably giving one length-z_dim vector per model.
scale_list = [codes.std(dim=0) for codes in z_list]

print(scale_list)

[tensor([1.0011, 0.9907, 0.9867, 0.9950, 0.9837, 0.9944, 0.9817, 1.0122, 1.0078,
        0.9851, 1.0019, 1.0087, 0.9962, 0.9823, 0.9962, 0.9939, 1.0086, 1.0008,
        1.0012, 0.9907, 1.0029, 1.0190, 0.9961, 1.0019, 1.0067, 1.0039, 1.0200,
        0.9795, 0.9987, 1.0116, 1.0119, 0.9905, 0.9900, 0.9937, 1.0076, 0.9825,
        0.9758, 0.9968, 0.9930, 0.9958, 0.9932, 0.9960, 0.9924, 0.9989, 0.9985,
        0.9947, 0.9992, 1.0009, 0.9975, 1.0031, 1.0009, 0.9958, 0.9960, 0.9994,
        0.9831, 0.9939, 0.9972, 1.0049, 0.9892, 0.9999, 0.9943, 0.9979, 0.9897,
        1.0056, 1.0082, 0.9925, 1.0011, 0.9903, 0.9977, 0.9919, 1.0030, 1.0036,
        1.0085, 1.0198, 0.9876, 1.0211, 0.9996, 1.0069, 0.9987, 0.9954, 0.9969,
        0.9938, 0.9950, 1.0076, 0.9845, 1.0011, 0.9963, 0.9872, 0.9947, 1.0031,
        1.0144, 1.0071, 1.0031, 0.9884, 1.0004, 0.9964, 0.9924, 0.9869, 0.9956,
        0.9915, 0.9922, 0.9910, 0.9673, 0.9998, 1.0032, 1.0195, 1.0044, 0.9941,
        0.9923, 0.9775, 0.9933, 0.9990,

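### MMD Test

Maximum mean discrepancy (MMD, RBF kernel) between each model's latent codes `z` and an equally sized sample from the prior, averaged over the models. The second cell repeats the test with the prior sample rescaled per dimension by the empirical std of `z` ("Corrected").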

In [11]:
from util.mmd_penalty import mmd_penalty

# Options forwarded to mmd_penalty.
opts = {
    'pz_scale': 1,
    'mmd_kernel': 'RBF',  # choices: 'IMQ', 'RBF'
    'pz': 'normal',       # choices: 'normal', 'sphere', 'uniform'
    'zdim': hp.z_dim,
}


def _mmd_to_prior(codes):
    """MMD between `codes` and an equally sized sample from the first model's prior."""
    reference = model_list[0].prior.sample(len(codes), hp.z_dim)
    return mmd_penalty(reference, codes, opts).item()


# One MMD value per model, averaged over models.
mmd_losses = [_mmd_to_prior(codes) for codes in tqdm(z_list)]

print('MMD Loss :', np.mean(mmd_losses))

  0%|          | 0/10 [00:00<?, ?it/s]

MMD Loss : 8.291006088256836e-05


In [12]:
from util.mmd_penalty import mmd_penalty

# Options forwarded to mmd_penalty (same as the uncorrected test).
opts = {'pz_scale': 1,
        'mmd_kernel': 'RBF', # 'IMQ', 'RBF'
        'pz': 'normal', # 'normal', 'sphere', 'uniform'
        'zdim': hp.z_dim
       }

# Same MMD test, except each prior sample is rescaled per dimension by the
# empirical std of that model's codes (scale_list). The prior then matches the
# codes' per-dimension spread.
mmd_losses = []
# zip() has no len(), so give tqdm the total explicitly; otherwise the bar
# only shows "0it".
n_pairs = min(len(z_list), len(scale_list))
for z, scale in tqdm(zip(z_list, scale_list), total=n_pairs):
    prior = model_list[0].prior.sample(len(z), hp.z_dim) * scale
    mmd_loss = mmd_penalty(prior, z, opts)
    mmd_losses.append(mmd_loss.item())

print('Corrected MMD Loss :', np.mean(mmd_losses))

0it [00:00, ?it/s]

Corrected MMD Loss : 9.118318557739258e-05


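### Cross NLL Test

For each model, `get_optimum_log_sigma` chooses `log_sigma` (within [-5, 5]) from two independent prior samples. `get_cross_nll` then scores the latent codes `z` against a fresh prior sample. Results are averaged over models and over 10 repetitions. The second cell repeats the procedure with prior samples rescaled per dimension by the empirical std of `z` ("Corrected").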

In [13]:
from util.loglikelihood import get_optimum_log_sigma, get_cross_nll


def _cross_nll_once(net, codes):
    """One cross-NLL estimate of `codes` against samples from `net`'s prior.

    `log_sigma` is chosen by get_optimum_log_sigma from two independent prior
    samples (within [-5, 5]). A third, fresh prior sample is then used to
    score `codes` with get_cross_nll.
    """
    n_codes = len(codes)
    fit_a = net.prior.sample(n_codes, hp.z_dim)
    fit_b = net.prior.sample(n_codes, hp.z_dim)
    log_sigma = get_optimum_log_sigma(fit_a, fit_b, min_log_sigma=-5, max_log_sigma=5)
    reference = net.prior.sample(n_codes, hp.z_dim)
    return get_cross_nll(reference, codes, log_sigma)


# 10 repetitions; each repetition averages the estimate over all models.
cross_nll_list = []
for _ in tqdm(range(10)):
    per_model = [_cross_nll_once(net, codes) for codes, net in zip(z_list, model_list)]
    cross_nll_list.append(np.mean(per_model))

print('Cross NLL :', np.mean(cross_nll_list))

  0%|          | 0/10 [00:00<?, ?it/s]

Cross NLL : 23.634900131225585


In [14]:
from util.loglikelihood import get_optimum_log_sigma, get_cross_nll

# Cross-NLL test again, with every prior sample rescaled per dimension by the
# empirical std of that model's codes (scale_list). 10 repetitions; each
# repetition averages over all models.
cross_nll_list = []
for _ in tqdm(range(10)):
    round_nlls = []
    for codes, net, dim_scale in zip(z_list, model_list, scale_list):
        n_codes = len(codes)
        # Two independent scaled prior samples used only to choose log_sigma.
        fit_a = net.prior.sample(n_codes, hp.z_dim) * dim_scale
        fit_b = net.prior.sample(n_codes, hp.z_dim) * dim_scale
        log_sigma = get_optimum_log_sigma(fit_a, fit_b, min_log_sigma=-5, max_log_sigma=5)
        # A fresh scaled prior sample against which the codes are scored.
        reference = net.prior.sample(n_codes, hp.z_dim) * dim_scale
        round_nlls.append(get_cross_nll(reference, codes, log_sigma))
    cross_nll_list.append(np.mean(round_nlls))

print('Corrected Cross NLL :', np.mean(cross_nll_list))

  0%|          | 0/10 [00:00<?, ?it/s]

Corrected Cross NLL : 23.57019254684448
