<a href="https://colab.research.google.com/github/piriram/AlgoAzazat/blob/main/%08%EB%A8%B8%EC%8B%A0%EB%9F%AC%EB%8B%9D/ML1_MNIST_DecisionTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets 
import numpy as np
from sklearn import svm
from sklearn.metrics import accuracy_score 
from sklearn.tree import DecisionTreeClassifier 
from scipy.stats import randint
from sklearn.model_selection import GridSearchCV

In [2]:

# 학습 데이터와 테스트 데이터를 위한 전처리 방식 정의
mnist_train_transform = transforms.Compose([transforms.ToTensor()])
mnist_test_transform = transforms.Compose([transforms.ToTensor()])

# MNIST 데이터셋을 다운로드하고 전처리 방식 적용하여 데이터셋을 만듦
trainset_mnist = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_train_transform)
testset_mnist = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_test_transform)

# 데이터셋을 DataLoader에 넣어 데이터를 batch 단위로 불러옴
MNIST_train = DataLoader(trainset_mnist, batch_size=32, shuffle=True, num_workers=2)
MNIST_test = DataLoader(testset_mnist, batch_size=32, shuffle=False, num_workers=2)

# MNIST 학습 데이터셋 이미지와 라벨 정보를 리스트에 담음
MNIST_train_images = []
MNIST_train_labels = []
for batch in MNIST_train:
    images, labels = batch
    images_flat = images.view(images.shape[0], -1)
    MNIST_train_images.append(images_flat.numpy())
    MNIST_train_labels.append(labels.numpy())

# MNIST 학습 데이터셋 이미지 리스트를 numpy 배열로 변환
MNIST_train_images = np.vstack(MNIST_train_images)
MNIST_train_labels = np.concatenate(MNIST_train_labels)

# MNIST 테스트 데이터셋 이미지와 라벨 정보를 리스트에 담음
MNIST_test_images = []
MNIST_test_labels = []
for batch in MNIST_test:
    images, labels = batch
    images_flat = images.view(images.shape[0], -1)
    MNIST_test_images.append(images_flat.numpy())
    MNIST_test_labels.append(labels.numpy())

# MNIST 테스트 데이터셋 이미지 리스트를 numpy 배열로 변환
MNIST_test_images = np.vstack(MNIST_test_images)
MNIST_test_labels = np.concatenate(MNIST_test_labels)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 71880924.77it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 86278984.21it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 26370161.89it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2097613.83it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [3]:
# 의사 결정 트리 모델 생성 함수
def decisionTree(max_depth):
    dt_model = DecisionTreeClassifier(max_depth=max_depth)

    # GridSearchCV를 사용해서 적절한 하이퍼 파라미터를 검색(교차 유효성 검사 k를 5로 설정)
    params_grid = {
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4],
        'max_leaf_nodes': [5, 10, None]
    }
    grid_search = GridSearchCV(dt_model, params_grid, cv=5, n_jobs=-1)
    grid_search.fit(MNIST_train_images, MNIST_train_labels)


    # 교육 데이터셋에 대한 정확도 계산
    train_score = grid_search.score(MNIST_train_images, MNIST_train_labels)
    print("Training set accuracy: ", train_score)

    # 테스트 데이터셋에 대한 정확도 계산
    test_score = grid_search.score(MNIST_test_images, MNIST_test_labels)
    print("Test set accuracy: ", test_score)

In [4]:
decisionTree(3)

Training set accuracy:  0.49151666666666666
Test set accuracy:  0.4953


In [5]:
decisionTree(6)

Training set accuracy:  0.73825
Test set accuracy:  0.7415


In [6]:
decisionTree(9)

Training set accuracy:  0.8661
Test set accuracy:  0.8502


In [7]:
decisionTree(12)

Training set accuracy:  0.9491666666666667
Test set accuracy:  0.8798
