In [2]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
from torch.utils.data import DataLoader
import torch.optim as optim

In [3]:
# Fix the random seed so runs are repeatable.
torch.manual_seed(777)
if torch.cuda.is_available():
    # A GPU is available: run on it and seed every CUDA device as well.
    device = 'cuda'
    torch.cuda.manual_seed_all(777)
else:
    device = 'cpu'

In [4]:
# Training hyperparameters used by the cells below.
learning_rate = 1e-3  # step size handed to the optimizer
epochs = 30           # number of full passes over the training loader
batch_size = 256      # samples per mini-batch for both loaders

In [5]:
# Build the train split first, then the test split, from the same root
# directory; images are converted to tensors and downloaded when missing.
cifar10_train, cifar10_test = (
    dsets.CIFAR10(
        root='data/cifar10_data',
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10_data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [03:22<00:00, 841008.86it/s] 


Extracting data/cifar10_data\cifar-10-python.tar.gz to data/cifar10_data
Files already downloaded and verified


In [6]:
# Show the summary of each split, one after the other.
print(cifar10_train, cifar10_test, sep='\n')

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data/cifar10_data
    Split: Train
    StandardTransform
Transform: ToTensor()
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data/cifar10_data
    Split: Test
    StandardTransform
Transform: ToTensor()


In [7]:
# Training batches: shuffled, and the final partial batch is dropped so every
# batch holds exactly `batch_size` samples.
data_loader = DataLoader(
    dataset=cifar10_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Evaluation batches: NOT shuffled. Combined with drop_last=True, shuffling
# would evaluate a different random subset of the test set on every run and
# make the reported accuracy non-reproducible.
# drop_last=True is retained to keep the batch count unchanged; it means the
# final partial batch of the test split is not evaluated.
test_loader = DataLoader(
    dataset=cifar10_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

In [8]:
# Sanity check: show the tensor sizes of the first training batch only.
for X, Y in data_loader:
    print(X.size(), Y.size(), sep='\n')
    break

torch.Size([256, 3, 32, 32])
torch.Size([256])


In [9]:
# Sanity check: show the tensor sizes of the first test batch only.
for X, Y in test_loader:
    print(X.size(), Y.size(), sep='\n')
    break

torch.Size([256, 3, 32, 32])
torch.Size([256])


In [10]:
class CNN(nn.Module):
    """Small convolutional classifier for 3x32x32 images with 10 output classes.

    Three Conv -> ReLU -> MaxPool stages halve the spatial size each time
    (32 -> 16 -> 8 -> 4) while widening the channels (3 -> 32 -> 64 -> 128).
    The flattened 128*4*4 feature vector is then passed through a
    three-layer fully connected head that produces 10 raw logits.
    """

    def __init__(self):
        super().__init__()

        # Conv2d(3, 32): 3 is the number of input channels, 32 the number of
        # output channels.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # -> (32, 16, 16)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # -> (64, 8, 8)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # -> (128, 4, 4)
        )

        # Fully connected head: 2048 -> 128 -> 64 -> 10.
        self.fc1 = nn.Linear(4 * 4 * 128, 128, bias=True)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        # Explicit weight initialisation, e.g. nn.init.uniform_(self.fc1.weight),
        # can be omitted here.

    def forward(self, x):
        """Return one row of 10 unnormalised class scores (logits) per sample.

        Args:
            x: image batch; the layer sizes above expect (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10).
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # (N, 128, 4, 4) -> (N, 128*4*4)
        out = torch.flatten(out, start_dim=1)
        # A non-linearity between the linear layers is required: without it
        # fc1 -> fc2 -> fc3 collapses into a single affine map and the extra
        # layers add no representational power.
        out = torch.relu(self.fc1(out))
        out = torch.relu(self.fc2(out))
        # No activation on the last layer: the outputs are raw logits.
        return self.fc3(out)



In [11]:
# Instantiate the network and move it to the selected device.
model = CNN()
model = model.to(device)

# Loss function, placed on the same device as the model.
crit = nn.CrossEntropyLoss()
crit = crit.to(device)

# Adam over all model parameters with the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [12]:
# Number of batches each loader yields per pass.
cnt_train_batch, cnt_test_batch = len(data_loader), len(test_loader)
print(cnt_train_batch, cnt_test_batch, sep='\n')

195
39


In [None]:
# Train for `epochs` passes over the training loader and report the mean
# batch loss of each epoch.
for epoch in range(epochs):
    model.train()  # ensure training mode even if an evaluation cell ran before
    avg_cost = 0.0

    for X, Y in data_loader:
        X = X.to(device)
        Y = Y.to(device)

        y_hat = model(X)
        cost = crit(y_hat, Y)

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Accumulate a plain Python float. Adding the loss tensor itself would
        # carry autograd history across the whole epoch and make the printed
        # value a tensor repr instead of a number.
        avg_cost += cost.item() / len(data_loader)

    print("epoch : {}, cost : {}".format(epoch, avg_cost))

In [None]:
# Evaluate classification accuracy over the test loader.
model.eval()  # switch to inference mode
correct = 0   # correctly classified samples so far
total = 0     # samples evaluated so far
with torch.no_grad():
    for X, Y in test_loader:
        X = X.to(device)
        Y = Y.to(device)
        pred = model(X)
        # Add this batch's hits to the running total. Overwriting the
        # accumulator each batch and then doubling it would report only
        # the last batch's result.
        correct += (pred.argmax(dim=1) == Y).sum().item()
        # Count the samples actually seen instead of assuming full batches.
        total += Y.size(0)

# Fraction of evaluated samples predicted correctly (NaN if the loader was empty).
accuracy = correct / total if total else float('nan')
print(accuracy)
