# MNIST Transfer Learning

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 사용 가능한 모델 리스트

In [None]:
all_models = models.list_models()
all_models

['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'fasterrcnn_mobilenet_v3_large_320_fpn',
 'fasterrcnn_mobilenet_v3_large_fpn',
 'fasterrcnn_resnet50_fpn',
 'fasterrcnn_resnet50_fpn_v2',
 'fcn_resnet101',
 'fcn_resnet50',
 'fcos_resnet50_fpn',
 'googlenet',
 'inception_v3',
 'keypointrcnn_resnet50_fpn',
 'lraspp_mobilenet_v3_large',
 'maskrcnn_resnet50_fpn',
 'maskrcnn_resnet50_fpn_v2',
 'maxvit_t',
 'mc3_18',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mvit_v1_b',
 'mvit_v2_s',
 'quantized_googlenet',
 '

In [None]:
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

### Frozen

In [None]:
# 모든 레이어의 학습 중단
for param in model.parameters():
    param.requires_grad = False # layer frozen

# 마지막 fc layer 재학습으로 수정
for param in model.fc.parameters():
    param.requires_grad = True

In [None]:
for idx, (name, module) in enumerate(model.named_modules()):
    print(f"{idx + 1}. {name} : {module}")

1.  : ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

### 분류 구조 변경

In [None]:
model.fc

Linear(in_features=2048, out_features=1000, bias=True)

In [None]:
model.fc.in_features

2048

In [None]:
model.fc = (
    nn.Sequential(
        nn.Linear(in_features=model.fc.in_features, out_features=10),
        nn.LogSoftmax(dim=1)
    )
)
model.to(device)
for idx, (name, module) in enumerate(model.named_modules()):
    print(f"{idx + 1}. {name} : {module}")

1.  : ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [None]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=(224,224)),
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root="dataset",
                               train=True,
                               download=True,
                               transform=transform)
test_dataset = datasets.MNIST(root="dataset",
                               train=False,
                               download=True,
                               transform=transform)

In [None]:
train_idx, valid_idx = train_test_split(
    range(len(train_dataset)),
    stratify=train_dataset.targets, # 균등 분포
    test_size=0.2
)
train_dataset = Subset(dataset=train_dataset, indices=train_idx)
validation_dataset = Subset(dataset=train_dataset, indices=valid_idx)
print(f"{len(train_dataset)} {len(validation_dataset)} {len(test_dataset)}")

48000 12000 10000


### minibatch

In [None]:
batch_size = 512
train_batches = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
validation_batches = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=True)
test_batches = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
from tqdm import tqdm

optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.NLLLoss()
num_epochs = 1

losses = []

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for inputs, labels in tqdm(train_batches, desc=f"epoch {epoch + 1}"):

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = loss_function(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    losses.append(total_loss)
    print(f"Epoch {epoch + 1} / {num_epochs}, Loss : {total_loss / len(train_batches)}")

epoch 1: 100%|██████████| 94/94 [4:07:14<00:00, 157.81s/it]

Epoch 1 / 1, Loss : 1.461938461090656





### evaluation

In [15]:
model.eval()
total_loss, acc = 0, 0.
wrong_inputs, wrong_preds_indices, actual_preds_indices = [], [], []
with torch.no_grad():
    total_corrects = 0
    for inputs, labels in tqdm(test_batches, desc=f"eval"):

        inputs = inputs.to(device)
        labels = inputs.to(device)

        y_test_pred = model(inputs)
        test_loss = loss_function(y_test_pred, labels)
        total_loss += test_loss.item()

        pred_labels = torch.argmax(y_test_pred, dim=1)
        corrects = pred_labels == labels
        sum = corrects.sum().item()
        total_corrects += sum

        wrong_idx = pred_labels.ne(labels).nonzero()[:, 0].cpu().numpy().tolist()
        for index in wrong_idx:
            wrong_inputs.append(inputs[index].cpu()) # 잘못 예측한 X
            wrong_preds_indices.append(pred_labels[index].cpu()) # 잘못 예측한 Y
            actual_preds_indices.append(labels[index].cpu()) # 실제 Y

    acc = total_corrects / len(train_dataset)
total_loss = total_loss / len(test_batches)


In [None]:
import matplotlib.pyplot as plt

fig, axe_loss = plt.subplots()

axe_loss.plot(range(len(losses)), losses)
axe_loss.set_xlabel("Loss")

print(f"Accuracy: {acc}")

for idx in range(10):
    plt.subplot(10, 10, idx + 1)
    plt.axis("off")
    plt.imshow(wrong_inputs[idx][0,:,:].numpy(), cmap="gray")
    plt.title(f"""
    Pred : {str(wrong_preds_indices[idx])}
    Real : {str(wrong_preds_indices)}
    """)

plt.show()
