In [1]:
import os
import warnings

# Work from the parent directory (presumably the project root) so the project-local
# `model.*` / `util.*` imports below resolve. The notebook's own directory is
# remembered on the first run, so re-running this cell always lands in the same
# parent directory instead of climbing one level higher each time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.dirname(NOTEBOOK_DIR))

# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="4"
device = 'cuda:0'

Thu Jan 18 21:57:16 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   33C    P0    65W / 300W |   8433MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   56C    P0    77W / 300W |   1035MiB / 80994MiB |      0%      Default |
|       

### Model Parts

In [3]:
from model.main.enc_prior_latent_dec import Model
from model.encoder.conv2d_encoder import Encoder
from model.prior.normal_prior import Prior
from model.latent.mmd_latent import Latent
from model.decoder.conv2d_decoder import Decoder

from tensorboardX import SummaryWriter
from util.util import *

### Model Init.

In [4]:
from easydict import EasyDict
hp = EasyDict()

# Image geometry.
hp.size = 64            # target image side length (used by transforms.Resize)
hp.in_dim = 3           # input channels (presumably RGB)
hp.out_dim = 3          # output channels (presumably RGB)

# Network widths / latent size.
hp.z_dim = 128
hp.h_dims = [32, 64, 128, 256, 512]
hp.activation = F.sigmoid

# MMD options: kernel and prior family (same keys that util.mmd_penalty receives).
hp.opts = dict(
    pz_scale=1,
    mmd_kernel='IMQ',   # 'IMQ', 'RBF'
    pz='normal',        # 'normal', 'sphere', 'uniform'
    zdim=hp.z_dim,
)

In [5]:
def _build_eval_model():
    """Assemble one encoder/prior/latent/decoder model on `device`, in eval mode."""
    net = Model(Encoder(**hp), Prior(**hp), Latent(**hp), Decoder(**hp))
    net = net.to(device)
    net.eval()
    return net


# Ten independently initialised models; the Load cell fills them from the checkpoint by index.
model_list = [_build_eval_model() for _ in range(10)]

print('done')

done


### Dataset

In [6]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

class MyCelebA(CelebA):
    """torchvision ``CelebA`` with its file-integrity check switched off.

    This is a work-around for problems with torchvision's built-in CelebA
    dataset class: ``_check_integrity`` always reports success, so a manually
    downloaded and extracted copy of the data is used as-is.

    Download and extract the data from:
    https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        # Nothing is verified; the files on disk are trusted unconditionally.
        files_ok = True
        return files_ok

# Directory containing the extracted CelebA data. Override it with the CELEBA_ROOT
# environment variable instead of editing the notebook; the default is unchanged.
root = os.environ.get('CELEBA_ROOT', '/data')

CROP_SIZE = 148          # square center crop applied before resizing to hp.size
EVAL_BATCH_SIZE = 10000  # only the first batch is used by the evaluation cells

test_transforms = transforms.Compose([
    transforms.CenterCrop(CROP_SIZE),
    transforms.Resize(hp.size),
    transforms.ToTensor(),
])
test_dataset = MyCelebA(root, split='test', transform=test_transforms, download=False)
# shuffle=False keeps the evaluated subset (the first batch) fixed across runs.
test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
print('done')

done


In [7]:
def preprocess(batch):
    """Move an ``(x, t)`` loader batch onto ``device``.

    Returns a dict with key ``'x'`` (the images) and ``'t'`` (the dataset
    targets), both transferred to the module-level ``device``.
    """
    images, targets = batch
    return {'x': images.to(device), 't': targets.to(device)}

# Take only the first batch of the (unshuffled) test loader; the reconstruction,
# MMD and cross-NLL numbers below are all computed from this single batch.
batch = next(iter(test_loader))
data = preprocess(batch)
print(data.keys())

dict_keys(['x', 't'])


### Load

In [8]:
from tqdm import tqdm_notebook as tqdm

# Checkpoint from the CelebA training run; it holds one state dict per model.
save_path = '/data/scpark/save/lse/train_celeba/train01.17-3/save_200000'

# Load trained weights for evaluation (not a warm start for further training).
# Tensors are mapped to CPU first; load_state_dict copies them into each model's
# existing parameters, which already live on `device`.
# NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
models_state_dict = torch.load(save_path, map_location=torch.device('cpu'))['models_state_dict']

# Every model needs exactly one entry; otherwise weights would be silently skipped
# or indexing would fail part-way through the loop.
assert len(models_state_dict) == len(model_list), (
    f'checkpoint holds {len(models_state_dict)} state dicts, '
    f'but {len(model_list)} models were built'
)
for i, model in tqdm(enumerate(model_list), total=len(model_list)):
    model.load_state_dict(models_state_dict[i], strict=True)

0it [00:00, ?it/s]

### Reconstruction Loss

In [9]:
# Run every model on the same preprocessed batch. Previously each model's returned
# dict was fed to the next model (and overwrote `data`), so later models saw earlier
# models' outputs in their input and the cell was not idempotent. Each model now
# receives a shallow copy of the untouched input dict.
z_list = []            # latent codes, one tensor per model
recon_loss_list = []   # scalar reconstruction loss, one per model
with torch.no_grad():
    for model in tqdm(model_list):
        outputs = model(dict(data), M=1)
        z_list.append(outputs['z'])
        recon_loss_list.append(outputs['recon_loss'].item())

print('Reconstruction Loss :', np.mean(recon_loss_list))

  0%|          | 0/10 [00:00<?, ?it/s]

Reconstruction Loss : 0.005149052850902081


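### Per-dimension Latent Scale

For each model, the standard deviation of its latent codes `z` along dimension 0; the "corrected" metrics below multiply prior samples by this scale.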

In [10]:
# One scale vector per model: std of the latents reduced over dim 0
# (presumably the batch axis, leaving one value per latent dimension).
scale_list = [latent.std(dim=0) for latent in z_list]

print(scale_list)

[tensor([1.0104, 0.9947, 1.0068, 1.0042, 1.0190, 0.9891, 1.0047, 0.9822, 1.0136,
        0.9970, 1.0094, 1.0129, 1.0228, 0.9911, 1.0169, 0.9928, 1.0032, 1.0241,
        1.0259, 0.9939, 1.0301, 1.0086, 0.9857, 0.9859, 1.0025, 0.9890, 0.9874,
        1.0045, 1.0076, 0.9998, 1.0142, 0.9968, 1.0246, 1.0138, 1.0193, 1.0155,
        0.9972, 0.9952, 1.0061, 0.9956, 1.0300, 1.0124, 1.0090, 1.0003, 1.0031,
        1.0156, 1.0193, 1.0094, 0.9923, 1.0110, 0.9946, 1.0020, 1.0072, 1.0061,
        0.9932, 1.0104, 1.0042, 1.0055, 1.0262, 1.0018, 1.0087, 1.0048, 1.0108,
        1.0033, 0.9853, 1.0039, 0.9889, 1.0126, 0.9858, 1.0125, 1.0067, 0.9789,
        0.9958, 1.0108, 0.9860, 1.0299, 1.0263, 1.0136, 1.0083, 0.9905, 1.0047,
        1.0054, 0.9986, 1.0213, 0.9788, 1.0350, 1.0064, 1.0167, 1.0090, 0.9885,
        0.9987, 1.0080, 1.0027, 0.9971, 1.0044, 1.0016, 1.0231, 0.9937, 1.0004,
        1.0120, 0.9966, 0.9812, 0.9858, 1.0095, 1.0044, 1.0013, 0.9988, 0.9904,
        0.9969, 1.0144, 1.0231, 1.0082,

### MMD Test

In [11]:
from util.mmd_penalty import mmd_penalty

# Evaluation options: the model's own MMD options (hp.opts) with the kernel
# switched to RBF.
opts = {**hp.opts, 'mmd_kernel': 'RBF'}  # 'IMQ', 'RBF'

# Score each model's latents against an equally sized draw from the first
# model's prior.
reference_prior = model_list[0].prior
mmd_losses = []
for latents in tqdm(z_list):
    prior_draw = reference_prior.sample(len(latents), hp.z_dim)
    mmd_losses.append(mmd_penalty(prior_draw, latents, opts).item())

print('MMD Loss :', np.mean(mmd_losses))

  0%|          | 0/10 [00:00<?, ?it/s]

MMD Loss : 0.0004493951797485352


In [12]:
from util.mmd_penalty import mmd_penalty

# Same evaluation options as the plain MMD test: hp.opts with an RBF kernel.
opts = {**hp.opts, 'mmd_kernel': 'RBF'}  # 'IMQ', 'RBF'

# "Corrected" MMD: each prior draw is rescaled per dimension by that model's
# empirical latent std, factoring out a per-dimension scale mismatch between z and
# the prior. Assumes sample(...) is (len(z), z_dim) so it broadcasts with the
# (z_dim,) scale vector; TODO confirm.
assert len(z_list) == len(scale_list), 'one scale vector is needed per latent tensor'
mmd_losses = []
# zip() has no len(), so pass total= explicitly or tqdm shows only "0it".
for z, scale in tqdm(zip(z_list, scale_list), total=len(z_list)):
    prior = model_list[0].prior.sample(len(z), hp.z_dim) * scale
    mmd_losses.append(mmd_penalty(prior, z, opts).item())

print('Corrected MMD Loss :', np.mean(mmd_losses))

0it [00:00, ?it/s]

Corrected MMD Loss : 0.0004256129264831543


### Cross NLL Test

In [None]:
from util.loglikelihood import get_optimum_log_sigma, get_cross_nll

N_REPEATS = 10  # independent re-draws of the prior samples

cross_nll_list = []
for _ in tqdm(range(N_REPEATS)):
    per_model_nll = []
    for z, model in zip(z_list, model_list):
        n_latents = len(z)
        # Choose log_sigma (presumably a kernel/Parzen bandwidth) from two
        # independent prior draws, searched within [-5, 5] ...
        fit_a = model.prior.sample(n_latents, hp.z_dim)
        fit_b = model.prior.sample(n_latents, hp.z_dim)
        log_sigma = get_optimum_log_sigma(fit_a, fit_b, min_log_sigma=-5, max_log_sigma=5)
        # ... then score the model's latents against a fresh, third prior draw.
        held_out = model.prior.sample(n_latents, hp.z_dim)
        per_model_nll.append(get_cross_nll(held_out, z, log_sigma))
    # NOTE(review): np.mean assumes get_cross_nll returns CPU/Python scalars — confirm.
    cross_nll_list.append(np.mean(per_model_nll))

print('Cross NLL :', np.mean(cross_nll_list))

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
from util.loglikelihood import get_optimum_log_sigma, get_cross_nll


def _scaled_prior_sample(prior, n_samples, scale):
    """Draw `n_samples` prior samples and rescale them per dimension by `scale`."""
    return prior.sample(n_samples, hp.z_dim) * scale


cross_nll_list = []
for _ in tqdm(range(10)):
    cross_nlls = []
    for z, model, scale in zip(z_list, model_list, scale_list):
        # Bandwidth is fitted on two scaled draws; the latents are then scored
        # against a third, independent scaled draw.
        fit_a = _scaled_prior_sample(model.prior, len(z), scale)
        fit_b = _scaled_prior_sample(model.prior, len(z), scale)
        log_sigma = get_optimum_log_sigma(fit_a, fit_b, min_log_sigma=-5, max_log_sigma=5)
        held_out = _scaled_prior_sample(model.prior, len(z), scale)
        cross_nlls.append(get_cross_nll(held_out, z, log_sigma))
    cross_nll_list.append(np.mean(cross_nlls))

print('Corrected Cross NLL :', np.mean(cross_nll_list))

In [None]:
# Final cell: seeing this output confirms Run All reached the end of the notebook.
print('done')