In [1]:
import os
# Move the working directory one level up.
# NOTE(review): presumably this is so the project-local `model` package imported
# further down resolves — confirm. The path is relative, so re-running this cell
# climbs one more directory each time it is executed.
os.chdir('../')

import warnings
# Install a filter that ignores every warning from this point on.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# IPython shell escape: print the GPU status table into the cell output.
!nvidia-smi
# Expose only GPU index 0 to CUDA for this process.
# NOTE(review): this is set after `import torch`; it is assumed to still take effect
# because no CUDA call has been made yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Tue Feb 13 20:05:47 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:19:00.0 Off |                  Off |
| 32%   44C    P5              75W / 200W |     11MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:68:0

### Hyper-Parameters

In [48]:
from easydict import EasyDict
from diffusers import DDPMScheduler

# Every hyper-parameter lives in one attribute-style dictionary. The entries are
# declared in a single literal, grouped by purpose, in the same order as before.
hp = EasyDict({
    # Data
    'dataset': 'ffhq_256',
    'data_root': '/home/scpark/data',
    'test_eval': True,
    'image_size': 256,
    'image_channels': 3,
    'n_batch': 8,

    # Model
    'custom_width_str': "",
    'bottleneck_multiple': 0.25,
    'no_bias_above': 64,
    'num_mixtures': 10,
    'width': 512,
    'zdim': 16,
    'dec_blocks': "1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x21,64m32,64x13,128m64,128x7,256m128",
    'enc_blocks': "256x3,256d2,128x8,128d2,64x12,64d2,32x17,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4",

    # Train
    'lr': 1e-4,

    # Diffusion
    'scheduler': DDPMScheduler(),
    'diff_middle_width': 128,
    'diff_residual': True,
})

# Configure the stored scheduler for 10 sampling timesteps.
hp.scheduler.set_timesteps(10)


### Model

In [49]:
from model.main.vdvae_latent import Model as VAE
from model.encoder.vdvae_encoder import Encoder
from model.decoder.vdvae_diffusion_decoder import Decoder
from model.loss.dmol import Loss

from model.main.latent_diffusion import Model as LD
from model.latent_diffusion.denorm_latent_diffusion import LatentDiffusion

In [50]:
# Build the latent-diffusion wrapper from the hyper-parameters, then move it to
# the target device in a separate step.
device = 'cuda:0'
ld = LD(LatentDiffusion(hp))
ld = ld.to(device)
print('done')

done


In [51]:
def _make_draw_samples(backbone_index):
    """Return a sampler that is bound to one specific diffusion backbone.

    A closure defined directly inside a ``for`` loop looks up the loop variable
    when it is *called*, not when it is defined, so every sampler would end up
    using whatever value ``i`` holds at call time (the last index once the loop
    has finished). Going through this factory gives each sampler its own
    ``backbone_index``.

    Args:
        backbone_index: Position of the backbone inside
            ``ld.latent_diffusion.backbones`` that the sampler must use.

    Returns:
        A function ``draw_samples(pm, pv)``.
    """
    def draw_samples(pm, pv):
        """Run the reverse diffusion loop and return de-normalized latents.

        Args:
            pm: Tensor used as the shift of the de-normalization; the returned
                tensor has the same shape.
            pv: Tensor that is exponentiated to obtain the scale.
                NOTE(review): presumably prior mean / log-std — confirm with caller.

        Returns:
            ``latents * exp(pv) + pm`` after iterating over every scheduler timestep.
        """
        # exp(pv) does not change inside the loop, so compute it once.
        scale = torch.exp(pv)
        latents = torch.randn_like(pm) * hp.scheduler.init_noise_sigma
        for t in hp.scheduler.timesteps:
            inp = hp.scheduler.scale_model_input(latents, t)
            # One timestep value per batch element.
            t_tensor = torch.ones((len(inp),)).to(device) * t
            # The backbone is fed (and predicts) in de-normalized space ...
            inp = inp * scale + pm
            pred = ld.latent_diffusion.backbones[backbone_index](inp, t_tensor)
            # ... while the scheduler update is applied in normalized space.
            pred = (pred - pm) / scale
            latents = hp.scheduler.step(pred, t, latents)['prev_sample']
        return latents * scale + pm
    return draw_samples


# One sampler per backbone, in backbone order.
draw_samples_list = [
    _make_draw_samples(i) for i in range(len(ld.latent_diffusion.backbones))
]

hp.draw_samples_list = draw_samples_list
print('done')

done


In [52]:
# Assemble the VAE from its encoder, decoder and loss modules (all configured
# from `hp`), then move it to the target device in a separate step.
vae = VAE(Encoder(hp), Decoder(hp), Loss(hp))
vae = vae.to(device)
print('done')

done


### Load

In [53]:
# Restore the latent-diffusion weights; tensors are mapped to CPU while loading.
# NOTE(review): depending on the torch version's `weights_only` default,
# torch.load may unpickle arbitrary objects — only load trusted checkpoint files.
checkpoint_path = '/data/save/lse/train_latent/train02.13-2/save_10000'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
ld.load_state_dict(checkpoint['model_state_dict'])
print('done')

done


In [54]:
# Load the VDVAE checkpoint and copy its tensors into `vae` under remapped names.
# NOTE(review): depending on the torch version's `weights_only` default,
# torch.load may unpickle arbitrary objects — only load trusted checkpoint files.
checkpoint_path = '/data/checkpoint/ffhq256-iter-1700000-model-ema.th'
checkpoint = torch.load(checkpoint_path, map_location='cpu')

model_state_dict = vae.state_dict()
for key, tensor in checkpoint.items():
    # Translate the checkpoint key into the name used inside `vae`:
    #   encoder*          -> 'encoder.' + key
    #   decoder.out_net*  -> 'loss.' + key without its 'decoder.' prefix
    #   decoder*          -> 'decoder.' + key
    # Keys with any other prefix are ignored.
    if key.startswith('encoder'):
        model_key = 'encoder.' + key
    elif key.startswith('decoder.out_net'):
        model_key = 'loss.' + key[len('decoder.'):]
    elif key.startswith('decoder'):
        model_key = 'decoder.' + key
    else:
        continue

    # Copy the tensor over when the target name exists; otherwise report the
    # unmatched target name.
    if model_key in model_state_dict:
        model_state_dict[model_key] = tensor
    else:
        print(model_key)

vae.load_state_dict(model_state_dict)
print('done')

done


In [55]:
# Switch to evaluation mode and draw 5 samples without tracking gradients.
vae.eval()
with torch.no_grad():
    samples = vae.sample(5)
print(samples.shape)

66it [00:00, 82.00it/s] 

(5, 256, 256, 3)





In [None]:
def show_samples(samples):
    """Display every element of ``samples`` side by side in a single row.

    Args:
        samples: Sized, indexable collection of images accepted by ``plt.imshow``.
    """
    import matplotlib.pyplot as plt

    n_columns = len(samples)
    plt.figure(figsize=[18, 4])
    for position, image in enumerate(samples, start=1):
        # Subplot positions are 1-based.
        plt.subplot(1, n_columns, position)
        plt.imshow(image)
        # Hide the tick marks on both axes.
        plt.xticks([])
        plt.yticks([])
    plt.show()


show_samples(samples)