In [1]:
import os
os.chdir('../')

import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Tue Feb 13 15:48:39 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:19:00.0 Off |                  Off |
| 42%   62C    P0              92W / 450W |     11MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:68:0

### Hyper-Parameters

In [3]:
from easydict import EasyDict
from diffusers import DDPMScheduler

hp = EasyDict()

# Data
hp.dataset = 'ffhq_256'
hp.data_root = '/home/scpark/data'
hp.test_eval = True
hp.image_size = 256
hp.image_channels = 3
hp.n_batch = 8

# Model
hp.custom_width_str = ""
hp.bottleneck_multiple = 0.25
hp.no_bias_above = 64
hp.num_mixtures = 10
hp.width = 512
hp.zdim = 16
hp.dec_blocks = "1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x21,64m32,64x13,128m64,128x7,256m128"
hp.enc_blocks = "256x3,256d2,128x8,128d2,64x12,64d2,32x17,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4"

# Train
hp.lr = 1e-4

# Diffusion
hp.scheduler = DDPMScheduler()
hp.diff_middle_width = 128
hp.diff_residual = True


### Model

In [4]:
from model.main.vdvae_latent import Model as VAE
from model.encoder.vdvae_encoder import Encoder
from model.decoder.vdvae_decoder import Decoder
from model.loss.dmol import Loss

from model.main.latent_diffusion import Model
from model.latent_diffusion.default_latent_diffusion import LatentDiffusion

from tensorboardX import SummaryWriter
from util.util import *

In [5]:
step = 0
device = 'cuda:0'

vae = VAE(Encoder(hp), Decoder(hp), Loss(hp)).to(device)
model = Model(LatentDiffusion(hp)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr)

for name, module in model.named_modules():
    if len(name.split('.')) <= 1 and len(name) > 0:
        print(name, get_size(module))

print('done')

latent_diffusion 75.9298095703125
done


### Load

In [6]:
save_dir = '/data/save/lse/train_latent/train02.13-1/'

!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model, optimizer = load(save_dir, 60000, model, optimizer)

total 0


In [7]:
checkpoint_path = '/data/checkpoint/ffhq256-iter-1700000-model-ema.th'
# Checkpoint 파일 로드
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

model_state_dict = vae.state_dict()
for key in checkpoint.keys():
    if key.startswith('encoder'):
        model_key = 'encoder.' + key
        if model_key in model_state_dict:
            model_state_dict[model_key] = checkpoint[key]
        else:
            print(model_key)
    if key.startswith('decoder'):
        if key.startswith('decoder.out_net'):
            model_key = 'loss.' + key[8:]
        else:
            model_key = 'decoder.' + key
            
        if model_key in model_state_dict:
            model_state_dict[model_key] = checkpoint[key]
        else:
            print(model_key)
            
vae.load_state_dict(model_state_dict)
print('done')

done


### Dataset

In [8]:
from torch.utils.data import DataLoader
from data.vdvae_data import set_up_data

hp, data_train, data_valid_or_test, preprocess_fn = set_up_data(hp)
train_loader = DataLoader(data_train, batch_size=hp.n_batch, drop_last=True, pin_memory=True)
print(train_loader)

DOING TEST
<torch.utils.data.dataloader.DataLoader object at 0x7f8d7bb90830>


### Train

In [9]:
def get_latent(data_input, vae):
    vae.eval()
    data = {'x': data_input}
    with torch.no_grad():
        stats = vae.get_latent(data, get_latents=True)
    return stats['stats']

def train_step(stats, model, optimizer):
    model.train()
    model.zero_grad()
    data = {'stats': stats}
    data = model(data)
    loss = 0
    for key in data:
        if 'loss' in key:
            loss = loss + data[key]
    loss.backward()
    optimizer.step()
    return loss

In [None]:
import matplotlib.pyplot as plt
from IPython import display
import torchvision.transforms as transforms

resize = transforms.Resize((hp.image_size, hp.image_size))

while True:
    for x in train_loader:
        # Get Latents from pretrained-VAE
        x[0] = resize(x[0].permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        data_input, target = preprocess_fn(x)
        data_input = data_input.to(device)
        target = target.to(device)
        stats = get_latent(data_input, vae)
        loss = train_step(stats, model, optimizer)

        if step % 10 == 0:
            print(step, 'loss', loss.item())
            writer.add_scalar('loss', loss.item(), step)
            
        if step % 1000 == 0:
            display.clear_output()
            
        if step % 10000 == 0:
            save(save_dir, step, model, optimizer)
                
        step += 1

10 loss 0.36587944626808167


In [None]:
save(save_dir, step, model, optimizer)
print('done')