# CH05.1. **전이 학습(Transfer Learning)**

## 00. **작업 환경 설정**

#### 00.0. **사전 변수 설정**

In [11]:
# Run configuration shared by every cell below.
SEED_NUM = 2025                          # seed for the python / numpy / torch RNGs
BATCH_SIZE = 32                          # mini-batch size used by both data loaders
EPOCH_NUM = 25                           # exclusive upper bound of the epoch loop
USE_PRETRAIN_YN = 'N'                    # 'Y' -> resume training from the checkpoint at MODEL_PTH
MODEL_PTH = '../../model/cifaResNet.pt'  # checkpoint file (written on every new best train loss)

#### 00.1. **라이브러리 호출 및 옵션 설정**

In [12]:
#(1) Import libraries
import os
import random
import tqdm
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import sklearn
import torch
import torchvision
import torchinfo

#(2) Set options : seed every random source used below, and force deterministic torch kernels
# NOTE: PYTHONHASHSEED set at runtime does not change hashing of the already-running
# interpreter; it is inherited by child processes only.
os.environ['PYTHONHASHSEED'] = str(SEED_NUM)
random.seed(a=SEED_NUM)
np.random.seed(seed=SEED_NUM)
torch.use_deterministic_algorithms(mode=True)
torch.manual_seed(seed=SEED_NUM)
torch.mps.manual_seed(seed=SEED_NUM)

#(3) Define device(hardware) : Apple-silicon MPS when available, otherwise CPU
device = torch.device(device='mps' if torch.backends.mps.is_available() else 'cpu')
print(f'>> Device : {device}')

>> Device : mps


#### 00.2. **사용자정의함수 정의**

In [13]:
#(1) Define `show_img()` function
def show_img(df:torchvision.datasets.VisionDataset, index:int) :
    """Show sample `index` of dataset `df` with its class id and class name.

    Assumes the image tensor was normalised with mean=std=0.5 (as `img_tf` does), so
    `x / 2 + 0.5` maps it back to the displayable [0, 1] range.
    """
    # Fetch the sample exactly once: indexing `df` twice would run its (random)
    # transform twice and do twice the decoding work.
    img, target = df[index]
    img = img / 2 + 0.5
    # Tensor is channels-first; matplotlib's imshow expects channels-last (H, W, C).
    img = np.transpose(a=img.numpy(), axes=(1, 2, 0))
    # Clip tiny floating-point overshoots outside [0, 1], which imshow would warn about.
    img = np.clip(img, 0.0, 1.0)
    plt.imshow(X=img)
    plt.xlabel(xlabel=f'Target : {target}({df.classes[target]})')
    plt.show()

#(2) Define `compute_metrics()` function
def compute_metrics(model:torch.nn.Module, loader:torch.utils.data.DataLoader) :
    """Evaluate a classifier on every batch of `loader` and summarise the results.

    Args:
        model: classifier; the argmax of its output over dim 1 is taken as the predicted class.
        loader: iterable of `(inputs, targets)` batches.

    Returns:
        pandas DataFrame with columns `metricName` / `value` holding accuracy and the
        support-weighted precision, recall and F1 score.

    The model runs in eval mode under `torch.no_grad()`. Its previous train/eval mode is
    restored afterwards, even if evaluation raises.
    """
    # `import sklearn` alone does not guarantee the `metrics` submodule is loaded.
    import sklearn.metrics

    # Run inference on whatever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    _preds = []
    _targets = []
    was_training = model.training
    model.eval()
    try :
        with torch.no_grad() :
            for inputs, targets in loader :
                inputs = inputs.to(device=model_device)
                preds = model(x=inputs)
                preds = torch.argmax(input=preds, dim=1)
                _preds.extend(preds.cpu().numpy())
                # Targets are only compared on the host, so they never need to visit the device.
                _targets.extend(targets.cpu().numpy())
    finally :
        model.train(mode=was_training)

    accuracy = sklearn.metrics.accuracy_score(y_true=_targets, y_pred=_preds)
    # zero_division=0 yields the same value as sklearn's default ('warn') without the warning
    # emitted when some class is never predicted.
    precision = sklearn.metrics.precision_score(y_true=_targets, y_pred=_preds, average='weighted', zero_division=0)
    recall = sklearn.metrics.recall_score(y_true=_targets, y_pred=_preds, average='weighted', zero_division=0)
    f1 = sklearn.metrics.f1_score(y_true=_targets, y_pred=_preds, average='weighted', zero_division=0)
    output = pd.DataFrame(data={
        'metricName' : ['accuracy', 'precision', 'recall', 'f1'], 
        'value'      : [accuracy, precision, recall, f1] 
    })
    return output

#### 00.3. **클래스 정의**

In [14]:
pass  # Placeholder: this notebook defines no custom classes.

<b></b>

## 01. **데이터셋 전처리 및 로드**

#### 01.1. **이미지 전처리 파이프라인 정의**

In [15]:
# Training-time preprocessing, applied in order:
#   1. RandomCrop(32, padding=4) : pad each side by 4 px, then take a random 32x32 crop (augmentation)
#   2. ToTensor                  : convert the image to a float tensor
#   3. Normalize(0.5, 0.5)       : per channel, (x - 0.5) / 0.5
img_tf_steps = [
    torchvision.transforms.RandomCrop(size=32, padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
img_tf = torchvision.transforms.Compose(transforms=img_tf_steps)

#### 01.2. **데이터셋 로드**

In [16]:
#(1) Training split : `img_tf` includes RandomCrop augmentation
cifa_train = torchvision.datasets.CIFAR10(root='../../data', train=True, download=True, transform=img_tf)

#(2) Test split : deterministic preprocessing only (the same ToTensor + Normalize steps as
#    `img_tf`, without RandomCrop). Evaluation then sees the real images, and the test
#    metrics are reproducible. Keep the normalisation constants in sync with `img_tf`.
cifa_test_tf = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
cifa_test = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True, transform=cifa_test_tf)

#### 01.3. **EDA**

In [17]:
#(1) Print sample of train
# len(cifa_train)

In [18]:
#(2) Print image shape 
# cifa_train[0][0].shape

In [19]:
#(3) Print frequency of target class
# target_freq = collections.Counter()
# for i in range(len(cifa_train)):
#     input, target = cifa_train[i]
#     if isinstance(target, torch.Tensor) :
#         target = target.item()
#     target_freq[target] += 1
# pd.DataFrame(data=list(target_freq.items()), columns=['class', 'count']).sort_values(by='class')

In [20]:
#(4) Display image
# show_img(df=cifa_train, index=5)

#### 01.4. **데이터로더 변환**

In [21]:
# Shuffle only the training data. Evaluation order does not change the metrics, and a fixed
# order keeps test-set evaluation deterministic without consuming the global torch RNG.
cifa_train_loader = torch.utils.data.DataLoader(dataset=cifa_train, batch_size=BATCH_SIZE, shuffle=True)
cifa_test_loader = torch.utils.data.DataLoader(dataset=cifa_test, batch_size=BATCH_SIZE, shuffle=False)

<b></b>

## 02. **모델 구축 및 학습**

#### 02.1. **모델 정의**

In [22]:
#(1) Define `model` : ResNet-18 created with torchvision's default pretrained weights
model = torchvision.models.resnet18(weights='DEFAULT')

#(2) Customize layer
# Stem conv replaced by a 3x3 / stride-1 / padding-1 conv, so 32x32 CIFAR inputs keep their
# spatial size through conv1. As a newly created module, its weights are randomly
# initialised rather than pretrained.
model.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# New classification head with 10 outputs (CIFAR-10 classes). in_features is read from the
# existing head, so it always matches the backbone's feature width.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)

#(3) Cast parameters to float32 and move them to the selected device
model = model.to(dtype=torch.float32, device=device)

#(4) Display `model`
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

#### **(`PLUS`)** 모델 동결(Model Freezing)하기 : `param.requires_grad=False`

In [23]:
pass  # Placeholder: no layers are frozen here; every parameter stays trainable (e.g. `for p in model.parameters(): p.requires_grad = False` would freeze them).

In [24]:
#(4) Summarize `model` on a random dummy batch shaped like one training batch.
# Use the sample's shape as-is: squeezing dim 0 would be a no-op for 3-channel images,
# but would drop the channel axis of 1-channel datasets and break the input.
sample_shape = tuple(cifa_train[0][0].shape)
dummy = torch.randn(size=(BATCH_SIZE, *sample_shape)).to(device=device)
torchinfo.summary(model=model, input_data=dummy)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [32, 10]                  --
├─Conv2d: 1-1                            [32, 64, 32, 32]          1,728
├─BatchNorm2d: 1-2                       [32, 64, 32, 32]          128
├─ReLU: 1-3                              [32, 64, 32, 32]          --
├─MaxPool2d: 1-4                         [32, 64, 16, 16]          --
├─Sequential: 1-5                        [32, 64, 16, 16]          --
│    └─BasicBlock: 2-1                   [32, 64, 16, 16]          --
│    │    └─Conv2d: 3-1                  [32, 64, 16, 16]          36,864
│    │    └─BatchNorm2d: 3-2             [32, 64, 16, 16]          128
│    │    └─ReLU: 3-3                    [32, 64, 16, 16]          --
│    │    └─Conv2d: 3-4                  [32, 64, 16, 16]          36,864
│    │    └─BatchNorm2d: 3-5             [32, 64, 16, 16]          128
│    │    └─ReLU: 3-6                    [32, 64, 16, 16]          --
│

In [25]:
#(5) Define loss function : cross-entropy on the raw logits produced by `model`
criterion = torch.nn.CrossEntropyLoss()

#(6) Define optimizer(optimization method)
# NOTE(review): lr=1e-2 is high for Adam when fine-tuning pretrained weights — confirm intended.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2, weight_decay=1e-7)

#(7) Define Scheduler
# Multiply the learning rate by 0.1 every EPOCH_NUM // 3 epochs. The step is derived from
# EPOCH_NUM so the decay actually happens within the run; a fixed step_size larger than
# EPOCH_NUM would never fire.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=max(1, EPOCH_NUM // 3), gamma=0.1)

#### 02.2. **학습 전 변수 정의**

In [26]:
# Number of mini-batches per epoch; used to average the epoch's summed loss.
batch_cnt = len(cifa_train_loader)
if USE_PRETRAIN_YN == 'Y' :
    # Resume from the saved checkpoint. map_location puts the tensors on the current device,
    # even if the checkpoint was written from a different one (e.g. MPS vs CPU).
    checkpoint = torch.load(f=MODEL_PTH, map_location=device)
    model.load_state_dict(state_dict=checkpoint['model'])
    optimizer.load_state_dict(state_dict=checkpoint['optimizer'])
    # Older checkpoints may not carry scheduler state; restore it only when present.
    if 'scheduler' in checkpoint :
        scheduler.load_state_dict(state_dict=checkpoint['scheduler'])
    # The checkpoint stores the index of an epoch that already finished, so training
    # continues from the next epoch instead of repeating that one.
    epoch = checkpoint['epoch'] + 1
    loss_hist = checkpoint['loss_hist']
    # The checkpoint is written only on a new best loss, so its last entry is the best loss.
    best_loss = loss_hist[-1]
else :
    epoch = 0
    loss_hist = []
    best_loss = float('inf')
print(f">> Epoch={epoch}, Train Loss={best_loss}")

>> Epoch=0, Train Loss=inf


#### 02.3. **모델 학습**

In [27]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs(name=os.path.dirname(MODEL_PTH) or '.', exist_ok=True)

# Train from `epoch` (0, or the resume point) up to EPOCH_NUM - 1, tracking the mean batch loss.
progress_bar = tqdm.trange(epoch, EPOCH_NUM)
for epoch in progress_bar : 
    running_loss = 0.0
    model.train()
    for inputs, targets in cifa_train_loader :
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)
        optimizer.zero_grad() 
        preds = model(x=inputs)
        loss = criterion(input=preds, target=targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    running_loss_avg = running_loss / batch_cnt
    loss_hist.append(running_loss_avg)
    # Step the scheduler before checkpointing, so the saved optimizer/scheduler state is
    # exactly what the next epoch should start from when training resumes.
    scheduler.step()
    # Checkpoint only when the epoch's average training loss improves on the best so far.
    if running_loss_avg < best_loss :
        best_loss = running_loss_avg
        torch.save(
            obj={
                'epoch'     : epoch,
                'loss_hist' : loss_hist,
                'model'     : model.state_dict(),
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, 
            f=MODEL_PTH
        )
    progress_bar.set_postfix(ordered_dict={'epoch':epoch+1, 'loss':running_loss_avg}) 

  8%|▊         | 2/25 [03:53<44:46, 116.81s/it, epoch=2, loss=1.52]


KeyboardInterrupt: 

<b></b>

## 03. **모델 평가**

#### 03.1. **최적 성능 모델 로드**

In [None]:
# Load the best checkpoint written during training. map_location puts the tensors on the
# current `device`, even if the checkpoint was saved from a different one.
checkpoint = torch.load(f=MODEL_PTH, map_location=device)
model.load_state_dict(state_dict=checkpoint['model'])
# The checkpoint's last loss_hist entry is the training loss of the epoch it was saved at.
print(f'>> Epoch : {checkpoint["epoch"]}, Loss : {checkpoint["loss_hist"][-1]}')

#### 03.2. **학습 손실(Training Loss) 확인**

In [None]:
#(1) Check metrics on the training split.
# NOTE: `cifa_train_loader` serves `img_tf`-transformed (RandomCrop-augmented) images, so
# these numbers describe augmented training data.
compute_metrics(model, cifa_train_loader)

In [None]:
#(2) Plot training loss : one point per entry of `loss_hist` (the mean batch loss of an epoch)
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title(label='Training Loss')
ax.set_xlabel(xlabel='epoch')
ax.set_ylabel(ylabel='loss')
ax.plot(loss_hist)
plt.show()

#### 03.3. **성능 평가**

In [None]:
# Final evaluation on the held-out CIFAR-10 test split.
compute_metrics(model, cifa_test_loader)