# CH03.3. **Transpose CNN**

## 00. **작업 환경 설정**

#### 00.0. **사전 변수 설정**

In [None]:
# Global run configuration
SEED_NUM: int = 2025                            # seed for random / numpy / torch RNGs
BATCH_SIZE: int = 512                           # mini-batch size (DataLoaders and model summary)
EPOCH_NUM: int = 500                            # last epoch trained (inclusive)
USE_CHECKPOINT_YN: str = 'Y'                    # 'Y' -> try to resume from MODEL_PTH
MODEL_PTH: str = '../../model/TransposeCNN.pt'  # checkpoint file loaded and saved by training

#### 00.1. **라이브러리 호출 및 옵션 설정**

In [None]:
#(1) Import libraries
import os
import ipywidgets
import random
import tqdm
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import torch
import torchvision
import torchinfo

#(2) Set up options : seed every random source used by the notebook and force
#    deterministic torch kernels so that runs are reproducible.
#    (PYTHONHASHSEED is read at interpreter start-up, so setting it here only
#    affects subprocesses started afterwards.)
os.environ['PYTHONHASHSEED'] = str(SEED_NUM)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.mps.manual_seed) :
    seed_fn(SEED_NUM)
torch.use_deterministic_algorithms(mode=True)

#(3) Set up device : Apple-silicon GPU (MPS) when available, otherwise CPU
device = torch.device(device='mps' if torch.backends.mps.is_available() else 'cpu')
print(f'>> Device : {device}')

#(4) Set up HTML tag
# display(ipywidgets.HTML(data=
# '''
# <style> 
#     .white-play button {
#         background-color: white !important; 
#         color: black !important;
#     } 
# </style>
# '''
# ))

#### 00.2. **사용자정의함수 정의**

In [None]:
def show_img(df:torch.utils.data.Dataset, index:int, denormalize:bool=True) -> None :
    """Display one sample of an image-classification dataset with its label.

    Args:
        df: dataset whose items are ``(image_tensor, target)`` pairs with the image laid out
            channel-first, and which exposes a ``classes`` list (e.g. torchvision MNIST).
        index: position of the sample to display.
        denormalize: when True (default, the original behaviour) pixel values are mapped with
            ``x/2 + 0.5``, i.e. a tensor normalized to [-1, 1] is brought back to [0, 1].
            Pass False for tensors that are already in [0, 1] (a plain ``ToTensor()`` pipeline).

    Returns:
        None. The figure is drawn with ``plt.show()``.
    """
    # Fetch the sample once so image and label come from the same call
    # (matters when the dataset applies random transforms).
    img, target = df[index]
    if isinstance(target, torch.Tensor) :
        target = target.item()
    img = img.detach().cpu().numpy()
    if denormalize :
        # Undo a (x - 0.5) / 0.5 normalization: [-1, 1] -> [0, 1].
        img = img / 2 + 0.5
    channel_cnt = img.shape[0]
    if channel_cnt == 3 :
        # (C, H, W) -> (H, W, C) as expected by imshow for RGB images.
        plt.imshow(X=np.transpose(a=img, axes=(1, 2, 0)))
    elif channel_cnt == 1 :
        plt.imshow(X=np.squeeze(a=img, axis=0), cmap='gray')
    else :
        # Nothing is drawn for other channel counts (as before), but say why the plot is empty.
        print(f'>> show_img : unsupported channel count {channel_cnt}, image not drawn')
    plt.xlabel(xlabel=f'Target : {target}({df.classes[target]})')
    plt.show()

#### 00.3. **클래스 정의**

In [None]:
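# NOTE(review): this cell is empty, but `MyConvAutoEncoder`, `TrainLogger` and `Visualizer`
# are used in later cells without being defined anywhere visible; their class definitions
# presumably belong here.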

<b></b>

## 01. **데이터셋 전처리 및 로드**

#### 01.1. **이미지 전처리 파이프라인 정의**

In [None]:
# Preprocessing pipeline: only convert each image to a float tensor (no Normalize step,
# so pixel values stay in the [0, 1] range produced by ToTensor).
img_tf = torchvision.transforms.Compose(transforms=[torchvision.transforms.ToTensor()])

#### 01.2. **데이터셋 로드**

In [None]:
# Load (downloading if necessary) the MNIST train split first, then the test split,
# both with the same preprocessing pipeline.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(root='../../data', train=is_train, download=True, transform=img_tf)
    for is_train in (True, False)
)

#### 01.3. **EDA**

In [None]:
#(1) Print number of training samples (the cell shows the dataset length, not a sample)
len(mnist_train)

In [None]:
#(2) Display image
show_img(df=mnist_train, index=5)

In [None]:
#(3) Check `input_shape` : shape of the first training image tensor, as a plain list
input_shape = [*mnist_train[0][0].shape]

#(4) Print `input_shape`
input_shape

In [None]:
#(5) Print frequency of target class
# Datasets such as torchvision MNIST keep their raw labels in `.targets`; counting those
# avoids decoding and transforming every image. Fall back to indexing the dataset when
# `.targets` is missing or a `target_transform` could change the labels.
raw_targets = getattr(mnist_train, 'targets', None)
if raw_targets is not None and getattr(mnist_train, 'target_transform', None) is None :
    labels = torch.as_tensor(raw_targets).tolist()
else :
    labels = []
    for i in range(len(mnist_train)) :
        label = mnist_train[i][1]
        labels.append(label.item() if isinstance(label, torch.Tensor) else label)
target_freq = collections.Counter(labels)
pd.DataFrame(data=list(target_freq.items()), columns=['class', 'count']).sort_values(by='class')

#### 01.4. **데이터로더 변환**

In [None]:
# Shuffle the training set so each epoch visits mini-batches in a new order (the flags were
# previously inverted). The test set needs no shuffling, so keep its order deterministic.
# Note: the training loop logs the last batch of each epoch, so with shuffling the logged
# images differ from epoch to epoch.
mnist_train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=True)
mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=False)

<b></b>

## 02. **모델 구축 및 학습**

#### 02.1. **하이퍼 파라미터 정의**

In [None]:
# Hyper-parameter forwarded to `MyConvAutoEncoder(channel_num=...)`; its exact role is defined
# by that class (presumably the base number of conv channels — TODO confirm).
channel_num: int = 16

#### 02.1. **모델 정의**

In [None]:
#(1) Define `model` and cast its floating-point parameters to float32
#    (class_num=10 presumably matches the 10 MNIST digit classes — TODO confirm its use)
model = MyConvAutoEncoder(
    input_shape=input_shape,
    channel_num=channel_num,
    class_num=10,
    device=device,
)
model = model.to(dtype=torch.float32)

#(2) Display `model` : layer summary for one batch of shape [BATCH_SIZE, *input_shape]
torchinfo.summary(model=model, input_size=[BATCH_SIZE, *input_shape], device=device)

In [None]:
#(4) Define loss function : mean-squared error between the reconstruction and the input
criterion = torch.nn.MSELoss()

#(5) Define optimizer : Adam with a very small weight decay
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-7,
)

#(6) Define Scheduler (currently disabled)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.1)

#(7) Define logger : receives one `log(...)` call per training epoch
logger = TrainLogger()

#### 02.2. **모델 체크포인트 로드**

In [None]:
# Resume from MODEL_PTH when USE_CHECKPOINT_YN == 'Y'. The training loop writes a checkpoint
# only when an epoch's cost improves, so the stored history's minimum is its best cost.
init_epoch = 0
train_cost_hist = []
best_train_cost = float('inf')
checkpoint = None
if USE_CHECKPOINT_YN == 'Y' :
    try :
        checkpoint = torch.load(f=MODEL_PTH, map_location=device)
    except FileNotFoundError :
        # Expected on the first run: no checkpoint yet, so train from scratch.
        print(f'>> No checkpoint found at {MODEL_PTH} : training from scratch')
if checkpoint is not None :
    # Other failures (corrupt file, architecture mismatch, missing key) are no longer swallowed:
    # previously e.g. a failing optimizer load left the restored model weights in place while
    # training silently restarted from epoch 0.
    model.load_state_dict(state_dict=checkpoint['model'])
    optimizer.load_state_dict(state_dict=checkpoint['optimizer'])
    init_epoch = checkpoint['best_epoch']
    train_cost_hist = checkpoint['train_cost_hist']
    best_train_cost = min(train_cost_hist, default=float('inf'))
print(f'>> Last Epoch={init_epoch}, Last Train Loss={best_train_cost}')

#### 02.3. **모델 학습**

In [None]:
# Train epochs init_epoch+1 .. EPOCH_NUM. After each epoch the mean batch loss is appended to
# `train_cost_hist`, the epoch's last batch is handed to `logger`, and a checkpoint is written
# whenever the epoch's cost beats the best seen so far.
# NOTE(review): `LOGGER_PTH` is not defined in any visible cell — define it before running.
best_epoch = init_epoch
batch_len = len(mnist_train_loader)
progress_bar = tqdm.trange(init_epoch+1, EPOCH_NUM+1)
for epoch in progress_bar : 
    train_cost = 0.0
    model.train()
    for inputs, _targets in mnist_train_loader :
        # Move the batch once so the forward pass and the reconstruction target share a
        # device (previously only the target was moved explicitly).
        inputs_dev = inputs.to(device=device)
        optimizer.zero_grad() 
        preds = model(x=inputs_dev)
        # Autoencoder objective: reconstruct the input itself; the labels are unused.
        loss = criterion(input=preds, target=inputs_dev)
        loss.backward()
        optimizer.step()
        train_cost += loss.item()
    # Mean of per-batch losses (a final partial batch, if any, is weighted like the others).
    train_cost = train_cost / batch_len
    train_cost_hist.append(train_cost)
    # Log the epoch's last batch. `preds` is detached so the logged tensor does not keep a
    # reference to the autograd graph across epochs.
    logger.log(
        epoch=epoch, 
        inputs=inputs,
        preds=preds.detach(),
        path=LOGGER_PTH
    )
    if train_cost < best_train_cost :
        best_epoch = epoch
        best_train_cost = train_cost
        torch.save(
            obj={
                'model'           : model.state_dict(),
                'optimizer'       : optimizer.state_dict(),
                'best_epoch'      : epoch,
                'train_cost_hist' : train_cost_hist,
            }, 
            f=MODEL_PTH
        )
    progress_bar.set_postfix(ordered_dict={
        'last_epoch'      : epoch, 
        'last_train_cost' : train_cost, 
        'best_epoch'      : best_epoch, 
        'best_train_cost' : best_train_cost
    }) 

<b></b>

## 03. **모델 평가**

#### 03.1. **최적 성능 모델 로드**

In [None]:
#(1) Reload the best checkpoint written by the training loop
checkpoint = torch.load(MODEL_PTH, map_location=device)
model.load_state_dict(checkpoint['model'])

#(2) Report the (1-based) epoch and value of the lowest recorded training loss
print(
    f'>> Best Epoch : {np.argmin(checkpoint["train_cost_hist"]) + 1}, '
    f'Best Train Loss : {np.min(checkpoint["train_cost_hist"])}'
)

#### 03.2. **과소 적합 확인**

In [None]:
# Plot the training loss per epoch. `train_cost_hist[k]` belongs to epoch k+1 (training starts
# at epoch 1), so the x-axis uses real epoch numbers; the dashed line marks the best epoch
# stored in the loaded checkpoint, on the same 1-based scale.
epochs = range(1, len(train_cost_hist)+1)
best_epoch_in_ckpt = int(np.argmin(a=checkpoint["train_cost_hist"])) + 1
plt.figure(figsize=(12, 6))
plt.xlabel(xlabel='epoch')
plt.ylabel(ylabel='loss')
plt.plot(epochs, train_cost_hist, label='Training Loss')
plt.axvline(x=best_epoch_in_ckpt, color='grey', linestyle='--', linewidth=0.6, label='Best Training Loss')
plt.legend(loc='upper right')
plt.show()

#### 03.3. **(에포크 별) 학습 과정 확인**

In [None]:
#(0) Set up interactive mode : switch matplotlib to the `widget` (ipympl) backend so the
#    figure drawn in the next cell is interactive; the classic `notebook` backend is kept
#    commented out as an alternative.
# %matplotlib notebook
%matplotlib widget  

#(1) Move device : presumably moves the tensors collected by `logger` to the CPU before
#    plotting. NOTE(review): `TrainLogger` is not defined in the visible cells — confirm.
logger.move_device(device='cpu')

In [None]:
#(2) Define `visualizer` from the per-epoch data collected by `logger`
viz = Visualizer(
    train_log=logger.train_log,
    fig_size=(8, 4),
)

#(3) Draw the interactive comparison plot
viz.plot_compare()

#### 03.4. **일반화 성능 확인**

In [None]:
# TODO: the generalization check is not implemented yet; this cell is only a placeholder
# (e.g. evaluate `criterion` over `mnist_test_loader` with the reloaded best model).
pass