In [1]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import time

In [2]:
batch_size = 32
learning_rate = 1e-3
num_epoches = 50

In [3]:
train_dataset = datasets.MNIST(root='./data', train = True, transform=transforms.ToTensor(), download = True)
test_dataset = datasets.MNIST(root='./data', train = False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False)

In [4]:
class FCNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(FCNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)
    
    def forward(self, x):
#         print(x.size())
        x = x.view(x.size(0), -1)
        out_1 = self.layer1(x)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        return out_3

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FCNet(28*28, 300, 100, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 1e-5)

In [7]:
for epoch in range(num_epoches):
#     print('epoch {}'.format(epoch + 1))
#     print('*' * 10)
    for i, data in enumerate(train_loader, 1):
        img, label = data
        img = img.view(img.size(0), -1)
        img = Variable(img).to(device)
        label = Variable(label).to(device)
        
#         print('batch {}, img size'.format(i) + str(img.size()))
        out = model(img)
        loss = criterion(out, label)
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    eval_loss = 0
    eval_acc = 0
    for data in test_loader:
        img, label = data
        image = img.view(img.size(0), -1)
        img = Variable(img).to(device)
        label = Variable(label).to(device)
        out = model(img)
        loss = criterion(out, label)
        eval_loss += loss.data.item() * label.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred==label).sum()
        eval_acc += num_correct.data.item()

    print('[{}/{}] Test Loss: {:.6f}, Test Acc: {:.6f}'.format(epoch+1, num_epoches, eval_loss / len(test_dataset), eval_acc / len(test_dataset)))


[1/50] Test Loss: 0.295943, Test Acc: 0.917900
[2/50] Test Loss: 0.299297, Test Acc: 0.917700
[3/50] Test Loss: 0.283107, Test Acc: 0.919500
[4/50] Test Loss: 0.297940, Test Acc: 0.915800
[5/50] Test Loss: 0.299987, Test Acc: 0.915700


KeyboardInterrupt: 