In [1]:
import os
os.chdir('../')

import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Fri Jan 19 18:59:58 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   45C    P0    60W / 300W |     35MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   49C    P0    48W / 300W |     35MiB / 80994MiB |      0%      Default |
|       

### Model Parts

In [3]:
from model.main.enc_prior_latent_quantizer_dec import Model
from model.encoder.dalle_encoder import Encoder
from model.prior.multi_randn_prior import Prior
from model.latent.dalle_multi_lse_latent import Latent
from model.quantizer.dalle_nearest_multi_quantizer import Quantizer
from model.decoder.dalle_decoder import Decoder

from tensorboardX import SummaryWriter
from util.util import *

### Model Initialization

In [4]:
from easydict import EasyDict

# Hyper-parameters shared by every model component; each sub-module receives
# the whole mapping as keyword arguments (e.g. `Encoder(**hp)`).
hp = EasyDict(
    img_size=256,
    n_blk_per_group=2,
    n_hid=128,
    n_latents=13,
    dim_per_latent=16,
)
# Total latent width: n_latents groups of dim_per_latent channels each.
hp['z_dim'] = hp['n_latents'] * hp['dim_per_latent']
hp['n_prior_embeddings'] = 2
hp['init_log_sigma'] = -1
hp['const_sigma'] = True

In [5]:
step = 0
device = 'cuda:0'   # the single GPU left visible by CUDA_VISIBLE_DEVICES
N_MODELS = 1        # number of independently initialised models trained side by side
LEARNING_RATE = 1e-3

model_list = []
optimizer_list = []
for _ in range(N_MODELS):
    model = Model(Encoder(**hp), Prior(**hp), Latent(**hp), Quantizer(**hp), Decoder(**hp))
    model = model.to(device)
    model_list.append(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    optimizer_list.append(optimizer)

# Report the size of each top-level sub-module (names with at most one dot,
# skipping the unnamed root) of the first model; all models share the same
# architecture. Units are whatever `get_size` returns (presumably MB — confirm
# in util.util).
for name, module in model_list[0].named_modules():
    if name and name.count('.') <= 1:
        print(name, get_size(module))

print('done')

encoder 36.15240478515625
encoder.encoder 36.15240478515625
prior 0.0015869140625
latent 3.814697265625e-06
quantizer 0.0
decoder 41.2302360534668
decoder.decoder 41.2302360534668
done


### Load

In [6]:
save_dir = '/data/scpark/save/lse/train_dalle/train01.19-2/'

!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model_list, optimizer_list = load_model_list(save_dir, 2195, model_list, optimizer_list)

total 239092
-rw-rw-r-- 1 scpark scpark    329444  1월 19 18:59 events.out.tfevents.1705658112.GPUSVR11
-rw-rw-r-- 1 scpark scpark 243615153  1월 19 18:55 save_0
-rw-rw-r-- 1 scpark scpark        40  1월 19 18:54 events.out.tfevents.1705658034.GPUSVR11
-rw-rw-r-- 1 scpark scpark        40  1월 19 18:53 events.out.tfevents.1705657979.GPUSVR11
-rw-rw-r-- 1 scpark scpark     69764  1월 19 18:52 events.out.tfevents.1705657898.GPUSVR11
-rw-rw-r-- 1 scpark scpark    517316  1월 19 18:51 events.out.tfevents.1705657434.GPUSVR11
-rw-rw-r-- 1 scpark scpark    222884  1월 19 18:43 events.out.tfevents.1705657076.GPUSVR11
-rw-rw-r-- 1 scpark scpark     16351  1월 19 18:37 events.out.tfevents.1705657007.GPUSVR11
-rw-rw-r-- 1 scpark scpark        40  1월 19 18:36 events.out.tfevents.1705656986.GPUSVR11
-rw-rw-r-- 1 scpark scpark        40  1월 19 18:36 events.out.tfevents.1705656751.GPUSVR11
-rw-rw-r-- 1 scpark scpark        40  1월 19 18:32 events.out.tfevents.1705656674.GPUSVR11
-rw-rw-r-- 1 scpar

### Dataset

In [7]:
from torch.utils.data import DataLoader, Dataset
from data.imagenet_dataset import ImagenetDataset

# NOTE(review): this trains on the ImageNet *validation* split — confirm intended.
train_data_dir = '/data/imagenet/val'
train_dataset = ImagenetDataset(img_size=hp.img_size, root_dir=train_data_dir)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=1,
)
print('done')

done


In [8]:
def preprocess(batch, target_device=None):
    """Move an ``(x, t)`` batch onto a device and wrap it in a dict.

    Args:
        batch: pair ``(x, t)`` as yielded by the DataLoader — presumably images
            and labels from ImagenetDataset; confirm against the dataset.
        target_device: device to move the tensors to. Defaults to the
            notebook-level ``device``, so existing ``preprocess(batch)`` calls
            behave as before.

    Returns:
        dict with keys ``'x'`` and ``'t'`` holding the moved tensors; this is
        what gets passed to ``model(data)``.
    """
    # `device` (the notebook global) is only looked up when no override is given.
    dev = target_device if target_device is not None else device
    x, t = batch
    return {'x': x.to(dev), 't': t.to(dev)}

### Train

In [9]:
def plot(x, n=10):
    """Show up to ``n`` images of a batch side by side in a single row.

    Args:
        x: batch tensor with 4 dims, assumed (batch, channel, height, width);
            it is permuted to (batch, height, width, channel) for ``imshow``.
        n: number of column slots in the row (default 10, the original fixed
            count). If the batch has fewer than ``n`` images, only those are
            drawn instead of raising IndexError.
    """
    # Local import: elsewhere the notebook only imports pyplot inside the
    # training loop, so this function would otherwise fail if called first.
    import matplotlib.pyplot as plt

    images = x.permute(0, 2, 3, 1).data.cpu().numpy()
    plt.figure(figsize=[18, 4])
    for i in range(min(n, len(images))):
        plt.subplot(1, n, i + 1)
        plt.imshow(images[i])
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [10]:
from IPython import display
import matplotlib.pyplot as plt  # hoisted out of the loop; `plot` draws with pyplot too

PLOT_EVERY = 1000   # steps between visual diagnostics (images + per-latent ratios)
SAVE_EVERY = 10000  # steps between checkpoints

# Runs until interrupted by hand (KeyboardInterrupt); `step` keeps counting
# across epochs and across re-runs of this cell.
while True:
    for batch in train_loader:
        print(step)

        # key -> per-model scalar values for this step, averaged for logging
        loss_dict = {}
        for model, optimizer in zip(model_list, optimizer_list):
            data = preprocess(batch)

            # Forward
            model.train()
            model.zero_grad()
            data = model(data)

            # Backward: every output whose key contains 'loss' (this includes
            # the 'lse_loss' terms) is summed into the objective and logged.
            loss = 0
            for key, value in data.items():
                if 'loss' in key:
                    loss = loss + value
                    loss_dict.setdefault(key, []).append(value.item())

            if not torch.is_tensor(loss):
                raise RuntimeError("model output has no '*loss*' entries to optimise")
            loss.backward()
            optimizer.step()

        for key, values in loss_dict.items():
            mean_value = np.mean(values)
            writer.add_scalar(key, mean_value, step)
            print(key, mean_value)

        if step % PLOT_EVERY == 0:
            display.clear_output()

            # Inputs vs. reconstructions of the last model in model_list.
            plot(data['x'])
            plot(data['y'])

            # Mean selected prior index per latent; if min_indices holds
            # indices in {0, 1} (hp.n_prior_embeddings == 2) this is the
            # fraction of samples choosing embedding 1.
            # NOTE(review): assumes min_indices flattens latent-major, so row i
            # holds latent i's choices — confirm against the Latent/Quantizer.
            min_indices = data['min_indices'].reshape(hp.n_latents, -1)
            ratios = min_indices.float().mean(dim=1).tolist()

            plt.figure(figsize=[18, 4])
            plt.bar(range(len(ratios)), ratios)
            plt.grid()
            plt.show()

        if step % SAVE_EVERY == 0:
            save_model_list(save_dir, step, model_list, optimizer_list)

        step += 1
        

0


KeyboardInterrupt: 

In [None]:
# Checkpoint the current models/optimizers under the current step
# (e.g. after manually interrupting the training loop).
checkpoint_args = (save_dir, step, model_list, optimizer_list)
save_model_list(*checkpoint_args)
print('done')