In [28]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn as nn
import tqdm

In [4]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [5]:
learning_rate = 0.001
training_epochs = 15
batch_size = 128

In [8]:
mnist_train = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                          train=True, # True를 지정하면 훈련 데이터로 다운로드
                          transform=transforms.ToTensor(), # 텐서로 변환
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                         train=False, # False를 지정하면 테스트 데이터로 다운로드
                         transform=transforms.ToTensor(), # 텐서로 변환
                         download=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:08<00:00, 1139092.16it/s]


Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 134793.66it/s]


Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:05<00:00, 282331.80it/s]


Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 866799.93it/s]

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw






In [23]:
dataloader = torch.utils.data.DataLoader(dataset=mnist_train,
										batch_size=batch_size,
										shuffle=True,
										drop_last=True)

In [30]:
class CNN(nn.Module):
	def __init__(self):
		super(CNN, self).__init__()
		# conv2d, maxpool2d, relu,linear
		self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
		self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
		self.relu = nn.ReLU()
		self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
		self.linear = nn.Linear(7 * 7 * 64, 10)
	
	def forward(self, x):
		# Conv Block -> 14 * 14 * 32
		x = self.conv1(x)
		x = self.relu(x)
		x = self.pool(x)
		# Conv Block -> 7 * 7 * 64
		x = self.conv2(x)
		x = self.relu(x)
		x = self.pool(x)
		# Fully Connected Layer (128, 64, 7, 7) -> (128, 3136)
		x = x.view(-1, 7 * 7 * 64)
		x = self.linear(x)

		return x


In [37]:
model = CNN()
print(model)

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear): Linear(in_features=3136, out_features=10, bias=True)
)


In [38]:
model.to(device)

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear): Linear(in_features=3136, out_features=10, bias=True)
)

In [39]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [40]:
epochs = 10

for epoch in range(epochs):
	avg_loss = 0

	for X, y in tqdm.tqdm(dataloader):
		X = X.to(device)
		y = y.to(device)

		y_pred = model(X)

		loss = criterion(y_pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		avg_loss += loss / len(dataloader)
	print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_loss))


100%|██████████| 468/468 [00:10<00:00, 44.97it/s]


[Epoch:    1] cost = 0.249171093


100%|██████████| 468/468 [00:07<00:00, 59.07it/s]


[Epoch:    2] cost = 0.0659177005


100%|██████████| 468/468 [00:08<00:00, 53.78it/s]


[Epoch:    3] cost = 0.0477367043


100%|██████████| 468/468 [00:08<00:00, 53.89it/s]


[Epoch:    4] cost = 0.0390753821


100%|██████████| 468/468 [00:08<00:00, 55.53it/s]


[Epoch:    5] cost = 0.0332936272


100%|██████████| 468/468 [00:07<00:00, 61.61it/s]


[Epoch:    6] cost = 0.0281144045


100%|██████████| 468/468 [00:07<00:00, 62.23it/s]


[Epoch:    7] cost = 0.0242308788


100%|██████████| 468/468 [00:07<00:00, 62.22it/s]


[Epoch:    8] cost = 0.0218355972


100%|██████████| 468/468 [00:07<00:00, 61.11it/s]


[Epoch:    9] cost = 0.0185687207


100%|██████████| 468/468 [00:07<00:00, 60.82it/s]

[Epoch:   10] cost = 0.0151100401





In [41]:
mnist_test.test_data.shape # 10000, 1, 28, 28



torch.Size([10000, 28, 28])

In [42]:
with torch.no_grad(): # 가중치를 업데이트 하지 않겠다 dropout 같은 가중치에 영향을 미치는 것들은 다 제외
	# 테스트 데이터를 모델에 입력하기 위한 준비
	X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
	Y_test = mnist_test.test_labels.to(device)

	pred = model(X_test)

	correct_pred = torch.argmax(pred, 1) == Y_test

	acc = correct_pred.float().mean()

	print(f'Accuracy: {acc.item()}')
	



Accuracy: 0.9847999811172485
