In [1]:
import os
import warnings

# Remember the directory the kernel started in. A bare relative
# `os.chdir('../')` is not idempotent: every re-run of this cell would climb
# one more level. Anchoring to the recorded start directory makes re-runs safe.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Work from the parent of the notebook directory (presumably the project root,
# so that the project-local packages imported later can be resolved).
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
# Expose only physical GPU 1 to this process. With a single visible device it
# is presumably addressed as 'cuda:0' in-process, which is what the model cell uses.
# NOTE(review): this runs after `import torch`; it should only take effect if
# CUDA has not been initialised yet — confirm, or move it above the import.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Mon Jan 22 11:02:39 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   27C    P0    41W / 300W |     35MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   49C    P0    71W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Model Parts

In [3]:
from model.main.enc_prior_latent_quantizer_dec import Model
from model.encoder.net_64_encoder import Encoder
from model.prior.dalle_randn_prior import Prior
from model.latent.dalle_lse_latent import Latent
from model.quantizer.dalle_nearest_quantizer import Quantizer
from model.decoder.net_64_decoder import Decoder

from tensorboardX import SummaryWriter
from util.util import *

### Model Init.

In [4]:
from easydict import EasyDict

# Hyper-parameters for the run. The whole set is forwarded as keyword
# arguments (**hp) to every model component when the model is built.
hp = EasyDict({
    'img_size': 64,
    'n_resblocks': 6,
    'z_dim': 64,
    'n_prior_embeddings': 512,
    'init_log_sigma': -3,
    'const_sigma': True,
    'quantize': False,
    'prior_mu': 0.999,
})

In [5]:
# Global training-step counter; replaced by the checkpoint value when resuming.
step = 0
device = 'cuda:0'

# One optimizer per model, kept in parallel lists.
model_list, optimizer_list = [], []
for _ in range(1):
    model = Model(
        Encoder(**hp),
        Prior(**hp),
        Latent(**hp),
        Quantizer(**hp),
        Decoder(**hp),
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    model_list.append(model)
    optimizer_list.append(optimizer)

# Report the size of the top two levels of sub-modules of the last model built.
for name, module in model.named_modules():
    is_root = len(name) == 0
    nesting_depth = name.count('.') + 1
    if not is_root and nesting_depth <= 2:
        print(name, get_size(module))

print('done')

encoder 1.2191162109375
encoder.encoder 1.2191162109375
prior 0.126953125
latent 3.814697265625e-06
quantizer 0.0
decoder 1.0611686706542969
decoder.decoder 1.0611686706542969
done


### Load

In [6]:
# Run directory: holds the checkpoints (save_<step>) and the TensorBoard event files.
save_dir = '/data/scpark/save/lse/train_dalle/train01.19-7_0.999_1e-3/'

# Create the run directory if it does not exist yet, then list its contents
# (most recently modified first) so the available checkpoints are visible.
!mkdir -p $save_dir
!ls -lt $save_dir

# TensorBoard writer; its event files go into the run directory.
writer = SummaryWriter(save_dir)

# Checkpoint step to resume from. Set to None to keep the freshly
# initialised models/optimizers and start from step 0.
RESUME_STEP = 88126

if RESUME_STEP is not None:
    # load_model_list hands back the restored step counter together with the
    # (re)populated model and optimizer lists.
    step, model_list, optimizer_list = load_model_list(save_dir, RESUME_STEP, model_list, optimizer_list)

total 85816
-rw-rw-r-- 1 scpark scpark 13081272  1월 22 11:02 events.out.tfevents.1705863177.GPUSVR11
-rw-rw-r-- 1 scpark scpark       40  1월 22 11:02 events.out.tfevents.1705888948.GPUSVR11
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 11:02 save_88126
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 10:23 save_80000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 09:34 save_70000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 08:45 save_60000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 07:56 save_50000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 07:08 save_40000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 06:19 save_30000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 05:30 save_20000
-rw-rw-r-- 1 scpark scpark  7477768  1월 22 04:41 save_10000
-rw-rw-r-- 1 scpark scpark  7470716  1월 22 03:53 save_0


### Dataset

In [7]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

class MyCelebA(CelebA):
    """
    CelebA dataset with the integrity check disabled.

    A work-around to address issues with PyTorch's (torchvision) CelebA
    dataset class: the data is downloaded and extracted by hand, and
    ``_check_integrity`` is overridden so that it always reports success.

    Download and extract manually from:
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """
    
    def _check_integrity(self) -> bool:
        """Always report the dataset files as intact, bypassing the parent's check."""
        return True

# Dataset root directory handed to the CelebA loader.
root = '/data'

# Per-image pipeline: random horizontal flip, 148 px centre crop,
# resize to the model resolution, convert to a tensor.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(148),
    transforms.Resize(hp.img_size),
    transforms.ToTensor(),
])

# Training split only; the data must already be on disk (download=False).
train_dataset = MyCelebA(
    root=root,
    split='train',
    transform=train_transforms,
    download=False,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
print('done')

done


In [8]:
def preprocess(batch):
    """Turn a ``(x, t)`` batch from the loader into the dict the model consumes.

    Both elements are moved to the module-level ``device`` and returned under
    the keys ``'x'`` and ``'t'`` respectively.
    """
    inputs, targets = batch
    return {
        'x': inputs.to(device),
        't': targets.to(device),
    }

### Train

In [9]:
def plot(x):
    """Show up to the first eight images of a batch side by side.

    Args:
        x: image batch tensor; it is permuted with (0, 2, 3, 1) before
           display, i.e. it is assumed to be channel-first (N, C, H, W)
           — TODO confirm against the model output.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # Imported here so the function does not depend on another cell having
    # imported matplotlib first.
    import matplotlib.pyplot as plt

    images = x.permute(0, 2, 3, 1).detach().cpu().numpy()

    # Keep the 1x8 layout, but never index past the end of a short batch.
    n_shown = min(8, len(images))

    plt.figure(figsize=[18, 4])
    for i in range(n_shown):
        plt.subplot(1, 8, i+1)
        plt.imshow(images[i])
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [10]:
def linear(start_value, end_value, current_index, start_index, end_index):
    """Linearly ramp a value between two indices, clamped outside the ramp.

    Args:
        start_value: value returned before ``start_index``.
        end_value: value returned at or after ``end_index``.
        current_index: position at which to evaluate the schedule.
        start_index: index where the ramp begins.
        end_index: index where the ramp ends.

    Returns:
        ``start_value`` before the ramp, ``end_value`` at or after its end,
        and the linear interpolation in between.
    """
    # `>=` rather than `>`: at the end index the result is end_value anyway,
    # and this keeps a zero-length ramp (start_index == end_index) from
    # reaching the division below with a zero denominator.
    if current_index >= end_index:
        return end_value
    if current_index < start_index:
        return start_value

    slope = (end_value - start_value) / (end_index - start_index)
    return start_value + slope * (current_index - start_index)

In [None]:
from IPython import display
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Ordered (substring, weight) table for the entries of the model output dict
# that contribute to the training loss. The first matching substring wins, so
# the generic 'loss' entry must stay last.
LOSS_WEIGHTS = (
    ('lse_loss', 1e-5),
    ('commit_loss', 1e-1),
    ('loss', 1.0),
)


def _loss_weight(key):
    """Return the weight of output entry `key`, or None if it is not a loss term."""
    for pattern, weight in LOSS_WEIGHTS:
        if pattern in key:
            return weight
    return None


# Runs until interrupted: each pass of the outer loop is one sweep over the loader.
while True:
    for batch in train_loader:
        print(step)

        # Raw (unweighted) loss values per key, one entry per model.
        loss_dict = {}
        for model, optimizer in zip(model_list, optimizer_list):
            data = preprocess(batch)

            # Forward
            model.train()
            model.zero_grad()
            data = model(data, latent_temp=1)

            # Backward: weighted sum of every loss term in the output dict.
            loss = 0
            for key in data.keys():
                weight = _loss_weight(key)
                if weight is None:
                    continue
                loss = loss + data[key] * weight
                loss_dict.setdefault(key, []).append(data[key].item())

            loss.backward()
            optimizer.step()

        # Log the mean of each raw loss term across models.
        for key in loss_dict:
            mean_value = np.mean(loss_dict[key])
            writer.add_scalar(key, mean_value, step)
            print(key, mean_value)

        if step % 1000 == 0:
            display.clear_output()

            # `model` and `data` are whatever the inner loop left behind: the
            # last model in model_list and its training-mode output dict,
            # which is fed back in as the evaluation input.
            model.eval()
            with torch.no_grad():
                data = model(data, latent_temp=1, quantize=True)

            writer.add_scalar('eval_recon_loss', data['recon_loss'].item(), step)

            # Inputs on the first row, reconstructions on the second.
            plot(data['x'])
            plot(data['y'])

            # Project the prior embeddings and the latents onto the first two
            # principal components of the prior embeddings.
            pca = PCA(n_components=2)
            e = model.prior.prior.data.cpu().numpy()
            pca.fit(e)
            e_pca = pca.transform(e)
            z_pca = pca.transform(data['z'].permute(0, 2, 3, 1).reshape(-1, hp.z_dim).data.cpu().numpy())
            plt.figure(figsize=[10, 10])
            plt.scatter(e_pca[:, 0], e_pca[:, 1], marker='x', alpha=1.0, color='black')
            plt.scatter(z_pca[:, 0], z_pca[:, 1], marker='o', alpha=0.01, color='blue')
            plt.grid()
            plt.show()

        if step % 10000 == 0:
            save_model_list(save_dir, step, model_list, optimizer_list)

        step += 1
        

88126
lse_loss -213.44845581054688
commit_loss 0.016221964731812477
recon_loss 0.0012915924889966846
88127
lse_loss -213.30722045898438
commit_loss 0.016213838011026382
recon_loss 0.0013434942811727524
88128
lse_loss -212.75125122070312
commit_loss 0.01701168157160282
recon_loss 0.0013340013101696968
88129
lse_loss -211.82382202148438
commit_loss 0.015155800618231297
recon_loss 0.001267680199816823
88130
lse_loss -211.04074096679688
commit_loss 0.015470776706933975
recon_loss 0.001349212951026857
88131
lse_loss -210.17965698242188
commit_loss 0.015260379761457443
recon_loss 0.001265909755602479
88132
lse_loss -209.31072998046875
commit_loss 0.015889441594481468
recon_loss 0.0013327670749276876
88133
lse_loss -208.41049194335938
commit_loss 0.015515090897679329
recon_loss 0.0013296551769599319
88134
lse_loss -207.74813842773438
commit_loss 0.018875829875469208
recon_loss 0.0013072984293103218
88135
lse_loss -207.46151733398438
commit_loss 0.01860397681593895
recon_loss 0.001289308536797

88208
lse_loss -213.21719360351562
commit_loss 0.012899987399578094
recon_loss 0.0011487010633572936
88209
lse_loss -213.17291259765625
commit_loss 0.012993435375392437
recon_loss 0.0012655354803428054
88210
lse_loss -212.09136962890625
commit_loss 0.01400159765034914
recon_loss 0.0011811068980023265
88211
lse_loss -212.80865478515625
commit_loss 0.013401884585618973
recon_loss 0.0012145726941525936
88212
lse_loss -212.29547119140625
commit_loss 0.014046686701476574
recon_loss 0.001222753431648016
88213
lse_loss -213.04095458984375
commit_loss 0.01250036247074604
recon_loss 0.0012891639489680529
88214
lse_loss -212.55941772460938
commit_loss 0.013353190384805202
recon_loss 0.0012902416056022048
88215
lse_loss -212.89620971679688
commit_loss 0.013566195033490658
recon_loss 0.0012385081499814987
88216
lse_loss -212.72760009765625
commit_loss 0.013011109083890915
recon_loss 0.0012284382246434689
88217
lse_loss -213.1553955078125
commit_loss 0.013622385449707508
recon_loss 0.00120116770267

In [None]:
# Manual checkpoint at the current step. The training loop above never exits
# on its own, so this cell is presumably run after interrupting it.
save_model_list(save_dir, step, model_list, optimizer_list)
print('done')