[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shhommychon/mnist_generative_practice/blob/master/PretrainedClassfier.ipynb)

# Pretraining an MNIST Classifier

## Environment Setup

- The PyTorch downgrade cell below forces the session to shut down once. When that happens, simply run the notebook again.

In [1]:
# Check the current CUDA version
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0


In [2]:
# Pin the PyTorch version
import torch

# Check the current PyTorch version (ignoring local suffixes such as "+cu111")
if torch.__version__.split('+')[0] != "1.11.0":
    print(f"Current PyTorch version is {torch.__version__}, downgrading to 1.11.0")

    # Uninstall PyTorch
    !pip uninstall torch torchvision -y

    # Install PyTorch 1.11.0 (CUDA 11.1 build)
    !pip install torch==1.11.0 torchvision==0.12.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html

    # Force a runtime restart so the newly installed version is imported
    import os
    os.kill(os.getpid(), 9)
else:
    print("PyTorch version is already 1.11.0")

PyTorch version is already 1.11.0


In [3]:
!git clone https://github.com/shhommychon/mnist_generative_practice.git

import sys
sys.path = ["./mnist_generative_practice"] + sys.path

Cloning into 'mnist_generative_practice'...
remote: Enumerating objects: 16, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 16 (delta 4), reused 15 (delta 3), pack-reused 0
Receiving objects: 100% (16/16), 7.91 KiB | 7.91 MiB/s, done.
Resolving deltas: 100% (4/4), done.


## Practice Code
- Original code: [devnson/mnist_pytorch](https://github.com/devnson/mnist_pytorch)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
import torchvision.datasets as datasets

import numpy as np

### Downloading and Loading the MNIST Dataset

- Original download code: [devnson/mnist_pytorch](https://github.com/devnson/mnist_pytorch?tab=readme-ov-file#downloading-datasets)
- Original loading code: [devnson/mnist_pytorch](https://github.com/devnson/mnist_pytorch?tab=readme-ov-file#splitting-datasets)


In [5]:
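mnist_trainset = datasets.MNIST(root='./data', train=True, download=True,
                                transform=transforms.Compose([transforms.ToTensor()]))
mnist_testset = datasets.MNIST(root='./data', train=False, download=True,
                               transform=transforms.Compose([transforms.ToTensor()]))

# Split the official 10,000-image test set into 90% validation / 10% test.
# Taking the test size as the remainder guarantees the lengths sum to len(mnist_testset).
val_size = int(0.9 * len(mnist_testset))
test_size = len(mnist_testset) - val_size
mnist_valset, mnist_testset = torch.utils.data.random_split(mnist_testset, [val_size, test_size])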
train_dataloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(mnist_valset, batch_size=32, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(mnist_testset, batch_size=32, shuffle=False)

print("Training dataset size: ", len(mnist_trainset))
print("Validation dataset size: ", len(mnist_valset))
print("Testing dataset size: ", len(mnist_testset))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training dataset size:  60000
Validation dataset size:  9000
Testing dataset size:  1000
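
`random_split` shuffles with PyTorch's global random number generator, so the 9,000/1,000 validation/test split above changes every time the notebook runs. A minimal sketch of a reproducible split, assuming a fixed seed is acceptable; the seed value 42 and the dummy dataset standing in for the MNIST test set are illustrative choices, not part of the original notebook:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the 10,000-image MNIST test set (just the indices 0..9999).
dummy_testset = TensorDataset(torch.arange(10_000))

# Derive the second length from the first so the two always sum to len(dataset).
val_size = int(0.9 * len(dummy_testset))
test_size = len(dummy_testset) - val_size

# A seeded generator makes the shuffle, and therefore the split, identical across runs.
generator = torch.Generator().manual_seed(42)
valset, testset = random_split(dummy_testset, [val_size, test_size], generator=generator)

print(len(valset), len(testset))  # 9000 1000
```

Rerunning the split with a freshly seeded generator yields exactly the same indices, which keeps the reported test accuracy comparable between sessions.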


### Setting Up the Model, Loss Function, and Optimizer

In [6]:
from models.classifier import ConvFeatExtractor, SoftDropout, ClassifierHead, MNISTClassifier

model = MNISTClassifier()

if torch.cuda.is_available():
    model.cuda()

model

MNISTClassifier(
  (feature_extractor): ConvFeatExtractor(
    (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv_2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (max_pool2d): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (linear): Linear(in_features=3136, out_features=128, bias=True)
    (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (relu): ReLU()
  )
  (classifier): ClassifierHead(
    (dropout): SoftDropout(
      (dropouts): ModuleList(
        (0): Dropout(p=0.5, inplace=False)
        (1): Dropout(p=0.5, inplace=False)
        (2): Dropout(p=0.5, inplace=False)
        (3): Dropout(p=0.5, inplace=False)
        (4): Dropout(p=0.5, inplace=False)
      )
    )
    (relu): ReLU()
    (linear): Linear(in_features=128, out_features=10, bias=True)
  )
)
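
The source of `models/classifier.py` is not shown in this notebook, so the printout only tells us which layers exist, not the order in which `forward` applies them. The sketch below is a hypothetical reconstruction that checks the shape arithmetic behind `in_features=3136`: the padded 3x3 convolutions keep the spatial size, so the single `MaxPool2d(2)` must be applied twice to shrink 28x28 to 7x7, and 64 channels x 7 x 7 = 3136. It also assumes that `SoftDropout` implements multi-sample dropout (averaging the classifier output over several independent dropout masks). The class names `SketchFeatExtractor` and `SketchMultiSampleHead` are invented for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical re-creation of the printed layers; the real forward order may differ.
class SketchFeatExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.max_pool2d = nn.MaxPool2d(2)
        self.linear = nn.Linear(64 * 7 * 7, 128)  # 3136 input features
        self.batch_norm = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.max_pool2d(self.relu(self.conv_1(x)))  # (N, 32, 14, 14)
        x = self.max_pool2d(self.relu(self.conv_2(x)))  # (N, 64, 7, 7)
        x = torch.flatten(x, 1)                         # (N, 3136)
        return self.dropout(self.relu(self.batch_norm(self.linear(x))))

class SketchMultiSampleHead(nn.Module):
    """Assumed multi-sample dropout: average the logits over several dropout masks."""
    def __init__(self, num_samples=5, p=0.5):
        super().__init__()
        self.dropouts = nn.ModuleList(nn.Dropout(p) for _ in range(num_samples))
        self.relu = nn.ReLU()
        self.linear = nn.Linear(128, 10)

    def forward(self, x):
        # One forward pass per dropout mask, then the mean of the resulting logits
        logits = [self.linear(self.relu(dropout(x))) for dropout in self.dropouts]
        return torch.stack(logits).mean(dim=0)

model_sketch = nn.Sequential(SketchFeatExtractor(), SketchMultiSampleHead())
out = model_sketch(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```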

In [7]:
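loss = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

max_epochs = 100
best_val_loss = float("inf")  # any first validation loss counts as an improvement
max_patience = 3              # early-stopping patience, in epochs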
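
`nn.CrossEntropyLoss` applies `log_softmax` internally, which is why the model should return raw logits rather than probabilities. A small check, using random logits and labels as stand-in data, that the loss equals the mean negative log-probability of the true class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # stand-in raw outputs for 4 images, 10 classes
labels = torch.tensor([3, 0, 9, 1])    # stand-in ground-truth digits

ce = nn.CrossEntropyLoss()(logits, labels)

# The same value by hand: pick the log-probability of each true class, negate, average.
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()
print(torch.allclose(ce, manual))  # True
```

Adding an explicit `softmax` layer at the end of the model would apply softmax twice and flatten the gradients, so the logits are passed to the loss as they are.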

### Model Training Code

- Original training code: [devnson/mnist_pytorch](https://github.com/devnson/mnist_pytorch?tab=readme-ov-file#training)

In [8]:
patience = max_patience
for epoch in range(max_epochs):
    total_train_loss = 0
    total_val_loss = 0

    # training
    model.train()
    for image, label in train_dataloader:

        if torch.cuda.is_available():
            image = image.cuda()
            label = label.cuda()

        # reset the gradients accumulated during the previous step
        optimizer.zero_grad()

        pred = model(image)

        this_train_loss = loss(pred, label)
        total_train_loss += this_train_loss.item()

        this_train_loss.backward()
        optimizer.step()

    total_train_loss = total_train_loss / len(train_dataloader)

    # validation (no gradients are needed, which saves memory and time)
    model.eval()
    correct = 0
    with torch.no_grad():
        for image, label in val_dataloader:

            if torch.cuda.is_available():
                image = image.cuda()
                label = label.cuda()

            pred = model(image)

            this_val_loss = loss(pred, label)
            total_val_loss += this_val_loss.item()

            # softmax preserves ordering, so argmax over the raw logits gives the same prediction
            correct += (label == pred.argmax(dim=1)).sum().item()

    total_val_loss = total_val_loss / len(val_dataloader)
    accuracy = correct / len(mnist_valset)

    print(f"\nEpoch: {epoch+1}/{max_epochs}, Train Loss: {total_train_loss:.8f}, Val Loss: {total_val_loss:.8f}, Val Accuracy: {accuracy:.8f}")

    # keep the checkpoint with the lowest validation loss;
    # stop after max_patience consecutive epochs without improvement
    if total_val_loss < best_val_loss:
        best_val_loss = total_val_loss
        print(f"\tSaving the model state dictionary for Epoch: {epoch+1} with Validation loss: {total_val_loss:.8f}")
        torch.save(model.feature_extractor.state_dict(), "mnist_feature_extractor.dth")
        torch.save(model.classifier.state_dict(), "mnist_classifier.dth")
        patience = max_patience
    else:
        patience -= 1
        print(f"\tLoss did not decrease. Will wait for {patience} more epochs...")

    if patience <= 0:
        break



Epoch: 1/100, Train Loss: 2.28912675, Val Loss: 2.23135520, Val Accuracy: 0.29022222
	Saving the model state dictionary for Epoch: 1 with Validation loss: 2.23135520

Epoch: 2/100, Train Loss: 2.23171102, Val Loss: 2.06879649, Val Accuracy: 0.93411111
	Saving the model state dictionary for Epoch: 2 with Validation loss: 2.06879649

Epoch: 3/100, Train Loss: 2.13738816, Val Loss: 1.85302809, Val Accuracy: 0.97822222
	Saving the model state dictionary for Epoch: 3 with Validation loss: 1.85302809

Epoch: 4/100, Train Loss: 2.01617039, Val Loss: 1.61801096, Val Accuracy: 0.98377778
	Saving the model state dictionary for Epoch: 4 with Validation loss: 1.61801096

Epoch: 5/100, Train Loss: 1.87387433, Val Loss: 1.31445243, Val Accuracy: 0.98322222
	Saving the model state dictionary for Epoch: 5 with Validation loss: 1.31445243

Epoch: 6/100, Train Loss: 1.71667977, Val Loss: 1.06748210, Val Accuracy: 0.98433333
	Saving the model state dictionary for Epoch: 6 with Validation loss: 1.06748210
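
The early-stopping rule in the training loop can be isolated into a pure-Python helper to check its behaviour. This is a sketch that mirrors the loop's logic, starting from `float("inf")` as the best loss; `stop_epoch` is a hypothetical name. Replaying the six logged validation losses shows that none of them triggered early stopping, since every epoch improved on the previous best.

```python
def stop_epoch(val_losses, max_patience=3):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best_val_loss = float("inf")
    patience = max_patience
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_val_loss:
            best_val_loss = val_loss  # new best: the checkpoint would be saved here
            patience = max_patience   # reset the counter
        else:
            patience -= 1             # one more epoch without improvement
        if patience <= 0:
            return epoch
    return None

# Validation losses from the log above: still decreasing, so no early stop.
logged = [2.23135520, 2.06879649, 1.85302809, 1.61801096, 1.31445243, 1.06748210]
print(stop_epoch(logged))                        # None
print(stop_epoch([1.0, 0.9, 0.95, 0.96, 0.97]))  # 5
```

In the second call the best loss (0.9) is reached at epoch 2, and three non-improving epochs later the counter hits zero at epoch 5.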

### Loading the Pretrained Model

- Run this cell only if you want to skip the training code above.

In [8]:
# megatools provides the megadl command used below; -y skips the confirmation prompt
!apt install -y megatools

# mnist_feature_extractor.dth
!megadl 'https://mega.nz/file/DgUQyDyB#7Gyq_9kzCz8FcGZV659VD1Cq1_36wimGVOG2Eram3P8'

# mnist_classifier.dth
!megadl 'https://mega.nz/file/H49S3bTI#qsonzlkV3JMniTbyzV77BB9VLhwmh1OJLTgxuD4PEMM'

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  megatools
0 upgraded, 1 newly installed, 0 to remove and 31 not upgraded.
Need to get 207 kB of archives.
After this operation, 898 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 megatools amd64 1.10.3-1build1 [207 kB]
Fetched 207 kB in 0s (661 kB/s)
Selecting previously unselected package megatools.
(Reading database ... 121671 files and directories currently installed.)
Preparing to unpack .../megatools_1.10.3-1build1_amd64.deb ...
Unpacking megatools (1.10.3-1build1) ...
Setting up megatools (1.10.3-1build1) ...
Processing triggers for man-db (2.10.2-1) ...
Downloaded mnist_feature_extractor.dth
Downloaded mnist_classifier.dth


### Testing the Pretrained Model

- Original test code: [devnson/mnist_pytorch](https://github.com/devnson/mnist_pytorch?tab=readme-ov-file#testing-model)
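
Both `.dth` files, whether produced by the training cell or downloaded with `megadl`, are loaded with `load_state_dict`, i.e. they are `state_dict`s written by `torch.save`. Tensors saved from a GPU session keep their CUDA device, so passing `map_location="cpu"` to `torch.load` is the usual way to load them on a CPU-only runtime. A self-contained sketch that uses an in-memory buffer instead of the real files and a stand-in `nn.Linear(128, 10)` (the same shape as the printed classifier's `linear` layer) instead of the real module:

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for model.classifier; the real module comes from models/classifier.py.
src = nn.Linear(128, 10)

buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)  # same call the training cell uses, but in memory
buffer.seek(0)

# map_location="cpu" remaps any CUDA tensors to the CPU while unpickling.
state = torch.load(buffer, map_location="cpu")

dst = nn.Linear(128, 10)
result = dst.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)  # [] []
```

`load_state_dict` copies the loaded values into the module's existing parameters, so the weights end up on whatever device the model already lives on.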

In [9]:
# map_location="cpu" lets GPU-saved weights load on a CPU-only runtime as well
model.feature_extractor.load_state_dict(torch.load("mnist_feature_extractor.dth", map_location="cpu"))
model.classifier.load_state_dict(torch.load("mnist_classifier.dth", map_location="cpu"))
model.eval()

correct = 0
with torch.no_grad():  # inference only, no gradients needed
    for image, label in test_dataloader:

        if torch.cuda.is_available():
            image = image.cuda()
            label = label.cuda()

        pred = model(image)

        # argmax over the raw logits equals argmax over softmax(logits)
        correct += (label == pred.argmax(dim=1)).sum().item()

test_accuracy = correct / len(mnist_testset)
print(f"Test accuracy {test_accuracy:.8f}")

Test accuracy 0.98900000
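
The validation and test loops share the same accuracy computation, which can be factored into one reusable function. A sketch under the assumption that the model returns raw logits: it wraps inference in `torch.no_grad()` and takes `argmax` directly over the logits, which gives the same predictions as applying `softmax` first because softmax preserves ordering. `evaluate_accuracy` is a hypothetical helper name, and the toy check uses `nn.Identity` on random 10-dimensional inputs instead of MNIST:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, dataloader, device):
    """Fraction of samples whose highest logit matches the label."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for image, label in dataloader:
            image, label = image.to(device), label.to(device)
            logits = model(image)
            correct += (logits.argmax(dim=1) == label).sum().item()
            total += label.size(0)
    return correct / total

# Toy check: an identity "model" whose labels are, by construction, the argmax of each input.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
x = torch.randn(100, 10)
y = x.argmax(dim=1)
loader = DataLoader(TensorDataset(x, y), batch_size=32)
print(evaluate_accuracy(nn.Identity().to(device), loader, device))  # 1.0
```

Counting `total` from the batches, rather than dividing by `len(dataset)`, keeps the function correct even if a sampler or `drop_last=True` changes how many samples are actually evaluated.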
