# CH05.1. **Variational Auto Encoder(VAE)**

## 00. **작업 환경 설정**

#### 00.0. **사전 변수 설정**

In [1]:
# Seed value fed to every random number generator configured in the set-up cell
SEED_NUM = 2025
# Mini-batch size (not referenced in the visible cells — presumably consumed by the data loaders; TODO confirm)
BATCH_SIZE = 512
# Number of training epochs (not referenced in the visible cells — TODO confirm against the training loop)
EPOCH_NUM = 500
# 'Y'/'N' style flag; presumably toggles resuming from a saved checkpoint — verify against the training cell
USE_CHECKPOINT_YN = 'Y'
# File path for the model checkpoint (relative to the notebook's working directory)
MODEL_PTH = '../../model/AutoEncoder.pt'
# File path for the serialized training log (see `TrainLogger.log`, which saves to a caller-supplied path)
LOGGER_PTH = '../../log/AutoEncoderTrainLogger.pt'

#### 00.1. **라이브러리 호출 및 옵션 설정**

In [2]:
#(1) Import libraries
import os
import ipywidgets
import random
import tqdm
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import torch
import torchvision
import torchinfo

#(2) Set up options
# Feed the same seed to Python, NumPy and PyTorch (default and MPS generators),
# and ask PyTorch to use deterministic algorithms only.
os.environ['PYTHONHASHSEED'] = str(SEED_NUM)
random.seed(a=SEED_NUM)
np.random.seed(seed=SEED_NUM)
torch.use_deterministic_algorithms(mode=True)
torch.manual_seed(seed=SEED_NUM)
torch.mps.manual_seed(seed=SEED_NUM)

#(3) Set up device
# Use the MPS backend when it is available; otherwise fall back to the CPU.
device = torch.device(device='mps' if torch.backends.mps.is_available() else 'cpu')
print(f'>> Device : {device}')

#(4) Set up HTML tag
# display(ipywidgets.HTML(data=
# '''
# <style> 
#     .white-play button {
#         background-color: white !important; 
#         color: black !important;
#     } 
# </style>
# '''
# ))

>> Device : mps


#### 00.2. **사용자정의함수 정의**

In [3]:
def show_img(df:torch.utils.data.Dataset, index:int) -> None :
    '''
    Plot one sample of an image dataset with its target as the x-axis label.

    Args:
        df: Dataset whose items are indexable as (image, target, ...) and that
            exposes a `classes` sequence mapping a target to its class name.
            The image is expected to be a channel-first tensor.
        index: Position of the sample to display.

    Returns:
        None. The figure is rendered through `plt.show()`.
    '''
    # Fetch the sample once so the dataset transform is not evaluated twice.
    sample = df[index]
    img, target = sample[0], sample[1]
    # Map pixel values from [-1, 1] back to [0, 1]
    # (assumes the dataset was normalized to that range — TODO confirm).
    img = (img/2+0.5).numpy()
    channel_cnt = img.shape[0]
    if channel_cnt == 3 :
        # Channel-first (C, H, W) -> channel-last (H, W, C), as imshow expects.
        img = np.transpose(a=img, axes=(1, 2, 0))
        plt.imshow(X=img)
    elif channel_cnt == 1 :
        # Drop the single channel axis and draw as a grayscale image.
        img = np.squeeze(a=img, axis=0)
        plt.imshow(X=img, cmap='gray')
    # Any other channel count is left undrawn; only the label is shown.
    plt.xlabel(xlabel=f'Target : {target}({df.classes[target]})')
    plt.show()

#### 00.3. **클래스 정의**

In [None]:
#(1) Define `Flatten` class
class Flatten(torch.nn.Module) :
    '''
    Layer that collapses every non-batch dimension of its input into one.
    '''
    def forward(self, x:torch.Tensor) -> torch.Tensor :
        '''
            (batch size, channel size, height, width) -> (batch size, channel size * height * width)
        '''
        # Keep the leading (batch) dimension and let -1 absorb all the others.
        return x.reshape(x.size(0), -1)

#(2) Define `Unflatten` class
class Unflatten(torch.nn.Module) :
    '''
    Layer that restores a flat feature vector to a square, channel-first map.
    '''
    def __init__(self, channel_num:int) :
        '''
            channel_num : number of channels of the restored feature map
        '''
        super().__init__()
        self.channel_num = channel_num
    def forward(self, x:torch.Tensor) -> torch.Tensor :
        '''
            (batch size, channel size * side * side) -> (batch size, channel size, side, side)
        '''
        batch_size, feature_num = x.shape[0], x.shape[1]
        # Features per channel, then the side length of the square they form.
        pixel_num = feature_num // self.channel_num
        side = int(pixel_num ** 0.5)
        return x.reshape(batch_size, self.channel_num, side, side)
    
#(3) Define `MyConvAutoEncoder` class
class MyConvAutoEncoder(torch.nn.Module) :
    '''
    Convolutional auto-encoder for single-channel images.

    The encoder stacks three convolutions followed by a linear layer that
    produces a `class_num`-dimensional code; the decoder mirrors it with a
    linear layer and three transposed convolutions.

    Args:
        input_shape: (channel, height, width) of one input image. Only the
            height and width entries are read; the network itself is built
            for one input channel.
        channel_num: Output channels of the first convolution; the following
            convolutions use 2x and 4x this value.
        class_num: Size of the code produced by the encoder.
        device: Device the module and its inputs are moved to.
    '''
    def __init__(self, input_shape:list, channel_num:int, class_num:int, device:torch.device) :
        super().__init__()
        self.input_shape = input_shape
        self.channel_num = channel_num
        self.class_num = class_num
        self.device = device
        flattened_size = self._compute_flatten_size()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=channel_num, kernel_size=3, stride=2), 
            torch.nn.ReLU(), 
            torch.nn.Conv2d(in_channels=channel_num, out_channels=2*channel_num, kernel_size=3, stride=2),
            torch.nn.ReLU(), 
            torch.nn.Conv2d(in_channels=2*channel_num, out_channels=4*channel_num, kernel_size=3, stride=1),
            torch.nn.ReLU(),
            Flatten(),
            torch.nn.Linear(in_features=flattened_size, out_features=class_num), 
            torch.nn.ReLU()
        )
        self.decoder = torch.nn.Sequential(
            # Mirror of the encoder's linear layer: sizes follow the constructor
            # arguments instead of the former hard-coded (10, 1024).
            torch.nn.Linear(in_features=class_num, out_features=flattened_size),
            torch.nn.ReLU(),
            # NOTE: `Unflatten` rebuilds a square map, so the encoder's feature
            # map must be square for the decoder to work.
            Unflatten(channel_num=4*channel_num),
            torch.nn.ConvTranspose2d(in_channels=4*channel_num, out_channels=2*channel_num, kernel_size=3, stride=1),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=2*channel_num, out_channels=channel_num, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=channel_num, out_channels=1, kernel_size=3, stride=2, output_padding=1)
        )
        self.to(device=device)
    def _compute_flatten_size(self) -> int :
        '''
        Return the number of features the convolutional part of the encoder
        emits for one image, by pushing a dummy image through an identically
        shaped stack of layers.
        '''
        with torch.no_grad() :
            dummy_data = torch.zeros(size=(1, 1, self.input_shape[1], self.input_shape[2])).to(device=self.device)
            dummy_fn = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=1, out_channels=self.channel_num, kernel_size=3, stride=2), 
                torch.nn.ReLU(), 
                torch.nn.Conv2d(in_channels=self.channel_num, out_channels=2*self.channel_num, kernel_size=3, stride=2),
                torch.nn.ReLU(), 
                torch.nn.Conv2d(in_channels=2*self.channel_num, out_channels=4*self.channel_num, kernel_size=3, stride=1),
                torch.nn.ReLU()
            ).to(device=self.device) # use the instance's device, not a module-level global
            output = dummy_fn(dummy_data)
        output = output.reshape(shape=(1, -1)).shape[1]
        return output
    def forward(self, x:torch.Tensor) -> torch.Tensor :
        '''
        Encode `x` and decode it again, returning the reconstruction.
        '''
        # Move the input to the device the model was built for.
        x = x.to(device=self.device)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

#(4) Define `TrainLogger` class
class TrainLogger : 
    '''
    Collect per-epoch input/prediction pairs and persist them with `torch.save`.

    `train_log` is a dict of three parallel lists under the keys 'epoch',
    'inputs' and 'preds'.
    '''
    def __init__(self) :
        self.train_log = {
            'epoch'  : [],
            'inputs' : [],
            'preds'  : []
        }
    def log(self, epoch:int, inputs:torch.Tensor, preds:torch.Tensor, path:str) :
        '''
        Append one record and rewrite the whole log to `path`.

        Args:
            epoch: Epoch number of the record.
            inputs: Batch fed to the model.
            preds: Model output for `inputs`.
            path: Destination file; it is overwritten on every call.
        '''
        # NOTE(review): values are stored as passed in (not detached or moved);
        # callers may want to detach them first — confirm against the training loop.
        self.train_log['epoch'].append(epoch)
        self.train_log['inputs'].append(inputs)
        self.train_log['preds'].append(preds)
        torch.save(obj={'train_log':self.train_log}, f=path)
    def move_device(self, device:str) :
        '''
        Convert every logged input/prediction for the given device, in place.

        Args:
            device: 'cpu' turns the entries into NumPy arrays; any other value
                turns them into tensors on that device.

        The method can be called repeatedly: entries that are already NumPy
        arrays are accepted as well as tensors.
        '''
        if (device == 'cpu') :
            convert = self._to_numpy
        else :
            convert = lambda value : self._to_tensor(value, device)
        for key in ('inputs', 'preds') :
            # Slice assignment keeps the same list objects alive.
            self.train_log[key][:] = [convert(value) for value in self.train_log[key]]
    @staticmethod
    def _to_numpy(value) :
        ''' Return `value` as a NumPy array (tensors are detached and copied to host). '''
        if isinstance(value, torch.Tensor) :
            return value.detach().cpu().numpy()
        return np.asarray(value)
    @staticmethod
    def _to_tensor(value, device) :
        ''' Return `value` as a tensor placed on `device`. '''
        if not isinstance(value, torch.Tensor) :
            value = torch.as_tensor(value)
        return value.to(device=device)

#(5) Define `Visualizer` class
class Visualizer :    
    '''
    Interactive side-by-side viewer of logged inputs and predictions.

    Args:
        train_log: Dict with parallel lists under 'epoch', 'inputs' and
            'preds' (the structure built by `TrainLogger`). Each entry of
            'inputs'/'preds' is indexed by sample and must support
            `.squeeze()` and be drawable by `imshow`
            (presumably NumPy arrays — TODO confirm against the caller).
        fig_size: Size of the matplotlib figure.

    Raises:
        ValueError: If `train_log['epoch']` is empty.
    '''
    def __init__(self, train_log:dict, fig_size:tuple=(8, 8)) :
        if not train_log['epoch'] :
            raise ValueError('train_log has no logged epochs')
        self.train_log = train_log
        self.fig_size = fig_size
        self.epoch_min = min(self.train_log['epoch'])
        self.epoch_max = max(self.train_log['epoch'])
        self.epoch_num = len(train_log['epoch']) - 1
        # Highest valid sample index of the first logged batch.
        self.sample_num = train_log['inputs'][0].shape[0] - 1
        self.widget_output = ipywidgets.Output(
            layout=ipywidgets.Layout(
                width='auto', 
                height='auto', 
                margin='0px', 
                padding='0px'
            )
        )
        self.epoch_play = ipywidgets.Play(
            min=self.epoch_min,
            max=self.epoch_max,
            step=1,
            value=self.epoch_min, # start inside the [min, max] range
            interval=250,
            description='Epoch Play',
            disabled=False
        )
        self.epoch_play.add_class(className='white-play')
        self.epoch_slider = ipywidgets.IntSlider(
            min=self.epoch_min,
            max=self.epoch_max,
            step=1,
            value=self.epoch_min, # start inside the [min, max] range
            description='Epoch',
            continuous_update=True,
            orientation='horizontal',
            readout=True,
            readout_format='d'
        )
        self.sample_slider = ipywidgets.IntSlider(
            min=0,
            max=self.sample_num,
            step=1,
            value=0,
            description='Sample',
            continuous_update=True,
            orientation='horizontal',
            readout=True,
            readout_format='d'
        )
        # Keep the play button and the epoch slider in sync on the front end.
        ipywidgets.jslink(attr1=(self.epoch_play, 'value'), attr2=(self.epoch_slider, 'value'))
        self.epoch_slider.observe(handler=self.on_epoch_change, names='value')
        self.sample_slider.observe(handler=self.on_sample_change, names='value')
        with self.widget_output:
            self.fig, (self.ax1, self.ax2) = plt.subplots(nrows=1, ncols=2, figsize=self.fig_size)
            try:
                self.fig.canvas.header_visible = False
                self.fig.canvas.toolbar_visible = False
            except Exception:
                # Best effort: not every canvas backend accepts these attributes.
                pass
            plt.show()
        self.update_view()
    def on_epoch_change(self, change:dict) :
        ''' Redraw when the epoch slider value changes. '''
        self.update_view()
    def on_sample_change(self, change:dict) :
        ''' Redraw when the sample slider value changes. '''
        self.update_view()
    def _epoch_index(self, epoch_value:int) -> int :
        '''
        Return the position of the logged epoch closest to `epoch_value`.

        The slider moves in steps of 1, but epochs may have been logged at a
        coarser interval, so an exact lookup could fail. Ties and exact
        matches resolve to the earliest position.
        '''
        epochs = self.train_log['epoch']
        return min(range(len(epochs)), key=lambda i : abs(epochs[i] - epoch_value))
    def update_view(self) :
        ''' Draw the input and prediction selected by the two sliders. '''
        with self.widget_output:
            self.ax1.clear()
            self.ax2.clear()
            ep_idx = self._epoch_index(self.epoch_slider.value)
            sp_idx = self.sample_slider.value
            input_img = self.train_log['inputs'][ep_idx][sp_idx].squeeze()
            pred_img = self.train_log['preds'][ep_idx][sp_idx].squeeze()
            self.ax1.imshow(X=input_img, cmap='gray')
            self.ax1.set_title(label='Target', fontdict={'fontsize': 12})
            self.ax1.set_aspect(aspect='auto')
            self.ax1.axis('off')
            self.ax2.imshow(X=pred_img, cmap='gray')
            self.ax2.set_title(label='Prediction', fontdict={'fontsize': 12})
            self.ax2.set_aspect(aspect='auto')
            self.ax2.axis('off')
            self.fig.canvas.draw_idle()
    def plot_compare(self) -> None :
        '''
        Assemble the figure and the controls into one widget and display it.

        Relies on the notebook-provided `display` function.
        '''
        controls_box = ipywidgets.VBox(
            children=[
                self.epoch_slider,
                self.sample_slider,
                self.epoch_play
            ],
            layout=ipywidgets.Layout(
                align_items='center',
                margin='0% 0% 15% -5%'
            )
        )
        ui = ipywidgets.HBox(
            children=[self.widget_output, controls_box],
            layout=ipywidgets.Layout(
                justify_content='flex-start',
                align_items='center',
                width='auto',
                margin='0px',
                padding='0px'
            )
        )
        display(ui)

<b></b>