In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [3]:
DEVICE

device(type='cpu')

In [4]:
EPOCHS = 30
BATCH_SIZE = 64

In [12]:
transform = transforms.Compose([
    transforms.ToTensor()
])

trainset = datasets.FashionMNIST(
    root = './data/',
    train = True,
    download = True,
    transform = transform,
)

testset = datasets.FashionMNIST(
    root = './data/',
    train = False,
    download = True,
    transform = transform,
)
batch_size = 16
train_loader = torch.utils.data.DataLoader(
    dataset = trainset,
    batch_size = batch_size,
)
test_loader = torch.utils.data.DataLoader(
    dataset = testset,
    batch_size = batch_size,
)

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [6]:
model = Net().to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [7]:
def train(model, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        

In [16]:
def evaluate(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
        
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy

In [17]:
for epoch in range(1, EPOCHS + 1):
    train(model, train_loader, optimizer)
    test_loss, test_acc = evaluate(model, test_loader)
    
    print(f"[{epoch} Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}]")

[1 Test Loss: 0.4555, Accuracy: 83.7700]
[2 Test Loss: 0.4274, Accuracy: 84.6000]
[3 Test Loss: 0.4085, Accuracy: 85.4400]
[4 Test Loss: 0.3932, Accuracy: 86.0400]
[5 Test Loss: 0.3810, Accuracy: 86.3200]
[6 Test Loss: 0.3726, Accuracy: 86.5200]
[7 Test Loss: 0.3650, Accuracy: 86.8600]
[8 Test Loss: 0.3594, Accuracy: 87.1800]
[9 Test Loss: 0.3553, Accuracy: 87.3200]
[10 Test Loss: 0.3521, Accuracy: 87.4400]
[11 Test Loss: 0.3489, Accuracy: 87.7200]
[12 Test Loss: 0.3435, Accuracy: 87.9000]
[13 Test Loss: 0.3414, Accuracy: 87.8200]
[14 Test Loss: 0.3402, Accuracy: 88.0600]
[15 Test Loss: 0.3371, Accuracy: 88.1600]
[16 Test Loss: 0.3362, Accuracy: 88.3200]
[17 Test Loss: 0.3366, Accuracy: 88.2900]
[18 Test Loss: 0.3337, Accuracy: 88.5700]
[19 Test Loss: 0.3338, Accuracy: 88.6700]
[20 Test Loss: 0.3354, Accuracy: 88.5800]
[21 Test Loss: 0.3340, Accuracy: 88.7400]
[22 Test Loss: 0.3341, Accuracy: 88.7000]
[23 Test Loss: 0.3357, Accuracy: 88.5100]
[24 Test Loss: 0.3342, Accuracy: 88.6700]
[

In [26]:
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        output = model(data)
#         print(output)
        pred = output.max(1, keepdim=True)[1]
#         print(output.max(1)[1])
        print(target)
        print(target.view_as(pred))
        break

tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1])
tensor([[9],
        [2],
        [1],
        [1],
        [6],
        [1],
        [4],
        [6],
        [5],
        [7],
        [4],
        [5],
        [7],
        [3],
        [4],
        [1]])
