In [1]:
import os
import warnings

# Work from the parent directory (presumably the repository root, so the local
# `model` / `util` packages imported below resolve). The notebook's own directory
# is remembered on the first execution, so re-running this cell is idempotent
# instead of climbing one more level per run.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

# NOTE(review): this silences *every* warning (including deprecations) for the whole session.
warnings.filterwarnings('ignore')

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="3"

Thu Jan 18 02:32:27 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:1B:00.0 Off |                    0 |
| N/A   36C    P0    64W / 300W |   8351MiB / 80994MiB |      7%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:1C:00.0 Off |                    0 |
| N/A   57C    P0   201W / 300W |  13598MiB / 80994MiB |     57%      Default |
|       

### Model Parts

In [4]:
from model.main.enc_prior_latent_dec import Model
from model.encoder.conv2d_encoder import Encoder
from model.prior.normal_prior import Prior
from model.latent.bmin_latent import Latent
from model.decoder.conv2d_decoder import Decoder

from tensorboardX import SummaryWriter
from util.util import *

### Model Initialization

In [5]:
from easydict import EasyDict

# Hyper-parameters shared by every model component; the whole dict is passed as
# **hp to the Encoder, Prior, Latent and Decoder constructors.
hp = EasyDict(
    size=64,                         # target image size (passed to transforms.Resize in the dataset cell)
    in_dim=3,                        # presumably input channels (RGB) — TODO confirm in Encoder
    out_dim=3,                       # presumably output channels — TODO confirm in Decoder
    z_dim=128,                       # latent dimensionality (also used when sampling from the prior)
    h_dims=[32, 64, 128, 256, 512],  # presumably per-stage conv widths — TODO confirm
    M=1024,                          # meaning defined by the model code (not visible here)
    N=256,                           # meaning defined by the model code (not visible here)
    # z_activation=F.tanh,           # disabled option, kept for reference
    activation=F.sigmoid,            # presumably the decoder output activation — TODO confirm
)

In [7]:
step = 0  # global training step: advanced by the training loop, reassigned when resuming
device = 'cuda:0'

# Build the model ensemble (currently one model), each paired with its own AdamW optimizer.
model_list = []
optimizer_list = []
n_models = 1
for _ in range(n_models):
    model = Model(Encoder(**hp), Prior(**hp), Latent(**hp), Decoder(**hp)).to(device)
    model_list.append(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    optimizer_list.append(optimizer)

# Report get_size() for the top-level components and their direct children
# (non-empty module names containing at most one '.').
for name, module in model.named_modules():
    if name and name.count('.') <= 1:
        print(name, get_size(module))

print('done')

encoder 6.99169921875
encoder.convs 5.9912109375
encoder.linear 1.00048828125
prior 0.0
latent 0.0
decoder 7.028697967529297
decoder.linear 1.0078125
decoder.convs 5.9820556640625
decoder.out_conv 0.038829803466796875
done


### Save Directory & Optional Resume

In [8]:
save_dir = '/data/scpark/save/lse/train_celeba/train01.17-12/'

!mkdir -p $save_dir
!ls -lt $save_dir

writer = SummaryWriter(save_dir)

if False:
    step, model_list, optimizer_list = load_model_list(save_dir, 1505, model_list, optimizer_list)

total 0


### Dataset

In [9]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

class MyCelebA(CelebA):
    """CelebA dataset whose file-integrity check always passes.

    Work-around for issues with torchvision's ``CelebA`` class: the
    ``_check_integrity`` hook is overridden to report success unconditionally,
    so a manually downloaded and extracted copy can be used with
    ``download=False``.

    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        # Skip verification entirely: the on-disk files are always treated as valid.
        files_ok = True
        return files_ok

root = '/data'


def _celeba_transforms(*augmentations):
    """Compose ``augmentations`` followed by the shared CenterCrop(148) -> Resize(hp.size) -> ToTensor steps."""
    return transforms.Compose([*augmentations,
                               transforms.CenterCrop(148),
                               transforms.Resize(hp.size),
                               transforms.ToTensor()])


# Training split: random horizontal flip in front of the shared pipeline, shuffled batches of 256.
train_transforms = _celeba_transforms(transforms.RandomHorizontalFlip())
train_dataset = MyCelebA(root, split='train', transform=train_transforms, download=False)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# Test split: shared pipeline only, fixed order, batches of 2048.
test_transforms = _celeba_transforms()
test_dataset = MyCelebA(root, split='test', transform=test_transforms, download=False)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)
print('done')

done


### Preprocess

In [10]:
def preprocess(batch, target_device=None):
    """Move an ``(x, t)`` batch onto a device and pack it into the model's input dict.

    Args:
        batch: 2-item ``(x, t)`` pair, as yielded by the DataLoaders in this notebook;
            both items must support ``.to(device)``.
        target_device: device to move the tensors to. When omitted (``None``) the
            notebook-level ``device`` global is used, looked up at call time, which
            preserves the original behaviour.

    Returns:
        dict with keys ``'x'`` and ``'t'`` holding the moved tensors.
    """
    x, t = batch
    # Only consult the global when no explicit device is given, so callers (and tests)
    # that pass target_device do not depend on notebook state.
    dev = device if target_device is None else target_device
    return {'x': x.to(dev), 't': t.to(dev)}

### Train

In [11]:
def plot(x, n=10):
    """Show up to ``n`` images from a batch side by side in a single row.

    Args:
        x: channels-first image batch (permuted to ``(N, H, W, C)`` for display);
            values presumably in [0, 1] since ``hp.activation`` is a sigmoid — TODO
            confirm for other inputs.
        n: maximum number of images to show (default 10, the original fixed count).
            Batches with fewer than ``n`` images are shown in full instead of raising.
    """
    # Local import: the original relied on a global `plt` that was only imported
    # inside the training loop, so calling plot() earlier raised NameError.
    import matplotlib.pyplot as plt

    images = x.permute(0, 2, 3, 1).data.cpu().numpy()
    count = min(n, len(images))
    plt.figure(figsize=[18, 4])
    for i in range(count):
        # Keep an n-column grid so the layout matches the original for full batches.
        plt.subplot(1, n, i + 1)
        plt.imshow(images[i])
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
from IPython import display
import matplotlib.pyplot as plt  # hoisted out of the loop; plot() also uses it

# Weight applied to every '*lse_loss*' output; all other '*loss*' outputs have weight 1.
LSE_LOSS_WEIGHT = 0.1

# Open-ended training: interrupt the kernel to stop, then run the save cell below.
while True:
    for batch in train_loader:
        print(step)

        loss_dict = {}
        for model, optimizer in zip(model_list, optimizer_list):
            data = preprocess(batch)

            # Forward (M comes from the shared hyper-parameters instead of a repeated literal)
            model.train()
            model.zero_grad()
            data = model(data, M=hp.M)

            # Backward: sum every output whose key contains 'loss', down-weighting LSE terms,
            # and record each term's scalar value for logging.
            loss = 0
            for key, value in data.items():
                if 'loss' not in key:
                    continue
                loss = loss + (value * LSE_LOSS_WEIGHT if 'lse_loss' in key else value)
                loss_dict.setdefault(key, []).append(value.item())
            if not torch.is_tensor(loss):
                # Otherwise `0.backward()` would fail with an opaque AttributeError.
                raise RuntimeError("model output has no '*loss*' entries; nothing to backpropagate")

            loss.backward()
            optimizer.step()

        # Log the per-key mean across models to TensorBoard and stdout.
        for key, values in loss_dict.items():
            mean_value = np.mean(values)
            writer.add_scalar(key, mean_value, step)
            print(key, mean_value)

        if step % 1000 == 0:
            display.clear_output()

            # Scatter consecutive pairs of latent dimensions on an 8x8 grid
            # (fewer panels if the latent has fewer than 128 dimensions).
            z = data['z_copy'].data.cpu().numpy()
            n_pairs = min(64, z.shape[1] // 2)
            plt.figure(figsize=[18, 15])
            for i in range(n_pairs):
                plt.subplot(8, 8, i + 1)
                plt.scatter(z[:, 2 * i], z[:, 2 * i + 1])
            plt.show()

            # Reconstructions on the first test batch (separate names so the
            # training `batch` / `data` are not shadowed).
            test_batch = next(iter(test_loader))
            test_data = preprocess(test_batch)

            model = model_list[0]
            model.eval()
            with torch.no_grad():
                test_data = model(test_data, M=hp.M)

            plot(test_data['x'])
            plot(test_data['y'])

            # Samples decoded from the prior.
            with torch.no_grad():
                z = model.prior.sample(10, hp.z_dim).to(device)
                y = model.sample(z)
                plot(y)

        if step % 10000 == 0:
            save_model_list(save_dir, step, model_list, optimizer_list)

        step += 1
        

113089
lse_loss 0.00024950093938969076
recon_loss 0.010060719214379787
113090
lse_loss 0.00027185940416529775
recon_loss 0.009587593376636505
113091
lse_loss 0.0002602505264803767
recon_loss 0.009427541866898537
113092
lse_loss 0.00026115309447050095
recon_loss 0.00987616740167141
113093
lse_loss 0.000252348167123273
recon_loss 0.009494731202721596
113094
lse_loss 0.00025119082420133054
recon_loss 0.009296436794102192
113095
lse_loss 0.000257954525295645
recon_loss 0.011202624067664146
113096
lse_loss 0.0002452055923640728
recon_loss 0.010217422619462013
113097
lse_loss 0.00025760853895917535
recon_loss 0.009042668156325817
113098
lse_loss 0.00026778984465636313
recon_loss 0.009636569768190384
113099
lse_loss 0.00025651627220213413
recon_loss 0.010485238395631313
113100
lse_loss 0.0002583444584161043
recon_loss 0.008787731640040874
113101
lse_loss 0.0002585042384453118
recon_loss 0.00931206438690424
113102
lse_loss 0.00025318667758256197
recon_loss 0.009977494366466999
113103
lse_loss 

113206
lse_loss 0.00025561387883499265
recon_loss 0.009794743731617928
113207
lse_loss 0.00024327629944309592
recon_loss 0.010961906984448433
113208
lse_loss 0.0002590271760709584
recon_loss 0.009724203497171402
113209
lse_loss 0.0002502911083865911
recon_loss 0.009902078658342361
113210
lse_loss 0.00025542627554386854
recon_loss 0.00998697616159916
113211
lse_loss 0.00026681608869694173
recon_loss 0.01137529592961073
113212
lse_loss 0.0002544757444411516
recon_loss 0.010058268904685974
113213
lse_loss 0.00025337975239381194
recon_loss 0.008647429756820202
113214
lse_loss 0.0002573329256847501
recon_loss 0.00938035175204277
113215
lse_loss 0.00025354616809636354
recon_loss 0.010048635303974152
113216
lse_loss 0.00026577431708574295
recon_loss 0.009658120572566986
113217
lse_loss 0.0002657126751728356
recon_loss 0.009779475629329681
113218
lse_loss 0.0002612458192743361
recon_loss 0.01102085318416357
113219
lse_loss 0.0002504382864572108
recon_loss 0.01030926313251257
113220
lse_loss 0.

113323
lse_loss 0.000259547057794407
recon_loss 0.011298294179141521
113324
lse_loss 0.00024244819360319525
recon_loss 0.009767304174602032
113325
lse_loss 0.0002617475111037493
recon_loss 0.01039495225995779
113326
lse_loss 0.0002489265170879662
recon_loss 0.009001867845654488
113327
lse_loss 0.0002595152473077178
recon_loss 0.009248926304280758
113328
lse_loss 0.0002523949951864779
recon_loss 0.009741079993546009
113329
lse_loss 0.00024724911781959236
recon_loss 0.01039237529039383
113330
lse_loss 0.0002569050993770361
recon_loss 0.008747952058911324
113331
lse_loss 0.000253664911724627
recon_loss 0.008501436561346054
113332
lse_loss 0.00024167938681785017
recon_loss 0.010250138118863106
113333
lse_loss 0.00024664137163199484
recon_loss 0.009644731879234314
113334
lse_loss 0.0002476838417351246
recon_loss 0.010289745405316353
113335
lse_loss 0.00024530920200049877
recon_loss 0.009272970259189606
113336
lse_loss 0.0002439187082927674
recon_loss 0.009569372981786728
113337
lse_loss 0.0

In [14]:
# Checkpoint the state reached when the training loop was interrupted, and push any
# TensorBoard scalars still buffered in the writer to disk.
save_model_list(save_dir, step, model_list, optimizer_list)
writer.flush()
print('done')

done
