In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Hyper-parameters. Seeding torch's global RNG here, before the model is built
# and before the DataLoader starts shuffling, makes weight initialisation and
# batch order reproducible on Restart & Run All.
random_seed = 0
torch.manual_seed(random_seed)

num_epochs = 5
train_batch_size = 20
test_batch_size = 10
learning_rate = 0.001
# Display names for label indices 0-9; assumed to follow torchvision's
# FashionMNIST label order (cross-check against train_dataset.classes).
classes = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot'
]

In [4]:
# ToTensor converts each PIL image to a float tensor; Normalize then maps every
# pixel x to (x - 0.5) / 0.5. Normalize's mean/std are per-channel sequences and
# FashionMNIST images are single-channel, hence the 1-tuples -- note that
# `(0.5)` would be the bare float 0.5, not a tuple.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training data; evaluation order does not affect metrics.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

print(len(train_dataset))
print(len(test_dataset))

60000
10000


In [5]:
class ConvNet(nn.Module):
    """Small two-convolution CNN classifier for single-channel images.

    Pipeline: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
    -> 2x2 max-pool -> flatten -> linear(9216->128) -> ReLU -> linear(128->10)
    -> log_softmax over the class dimension.

    ``forward`` returns per-class log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and construction order are significant: they fix the
        # state_dict keys and the order in which parameters draw from the RNG.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # 9216 = 64 channels * 12 * 12, assuming 28x28 inputs: the two unpadded
        # 3x3 convs shrink 28 -> 26 -> 24 and the 2x2 pool halves that to 12.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        flat = torch.flatten(F.max_pool2d(features, 2), 1)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [6]:
def train(model, device, train_loader, optimizer, criterion, epochs, interval=500):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Args:
        model: network being optimised. Batches are moved to ``device`` but the
            model itself is not, so it should already live there.
        device: ``torch.device`` each image/label batch is moved to.
        train_loader: DataLoader yielding ``(images, labels)`` batches;
            ``len(train_loader)`` and ``len(train_loader.dataset)`` feed the log line.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss, called as ``criterion(model(images), labels)``.
        epochs: number of full passes over ``train_loader``.
        interval: log the current batch loss every ``interval`` batches;
            a value <= 0 disables the per-batch log line.

    Returns:
        None; ``model``'s parameters are updated in place.
    """
    # Enter training mode explicitly: it matters for dropout/batch-norm layers
    # and undoes any earlier model.eval().
    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)
    for epoch in range(epochs):
        print(f'Starting #{epoch}')
        # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
        # length and show percentage/ETA instead of a bare "3000it" counter.
        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}')):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Guard interval so 0/negative disables logging instead of raising
            # ZeroDivisionError from the modulo.
            if interval > 0 and batch_idx % interval == 0:
                # tqdm.write prints above the active bar instead of corrupting it.
                tqdm.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(images), num_samples,
                    100. * batch_idx / num_batches, loss.item()))
        print(f'Finishing Epoch #{epoch}')

In [7]:
def test(model, device, test_loader):
    correct_imgs = 0
    matrix = []
    for _ in range(10):
        matrix.append([0] * 10)
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct_imgs += predictions.eq(labels.view_as(predictions)).sum().item()
            for j, (label, pred) in enumerate(list(zip(labels.tolist(), predictions.tolist()))):   
                matrix[pred][label] += 1
        
    return correct_imgs / len(test_loader.dataset), matrix

In [8]:
model = ConvNet().to(device)
# ConvNet.forward already ends in log_softmax, so the matching loss is NLLLoss.
# CrossEntropyLoss would apply log_softmax a second time -- numerically harmless
# (log_softmax is idempotent) but redundant and misleading to readers.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)

In [9]:
# Train for the configured num_epochs rather than a literal, logging every 100 batches.
train(model, device, train_loader, optimizer, criterion, num_epochs, interval=100)

2it [00:00, 16.74it/s]

Starting #0


104it [00:04, 24.24it/s]



203it [00:08, 24.15it/s]



305it [00:12, 24.32it/s]



404it [00:16, 27.10it/s]



503it [00:20, 24.37it/s]



605it [00:24, 24.45it/s]



704it [00:28, 24.45it/s]



806it [00:33, 24.86it/s]



905it [00:37, 27.23it/s]



1004it [00:40, 24.73it/s]



1103it [00:44, 24.55it/s]



1205it [00:49, 24.46it/s]



1304it [00:53, 27.31it/s]



1403it [00:57, 23.57it/s]



1505it [01:01, 24.39it/s]



1604it [01:05, 24.20it/s]



1703it [01:09, 24.30it/s]



1805it [01:13, 24.23it/s]



1904it [01:17, 24.34it/s]



2003it [01:22, 24.17it/s]



2105it [01:26, 24.15it/s]



2204it [01:30, 24.34it/s]



2303it [01:34, 24.53it/s]



2405it [01:38, 24.12it/s]



2504it [01:42, 24.16it/s]



2603it [01:46, 24.23it/s]



2705it [01:51, 23.62it/s]



2804it [01:55, 24.48it/s]



2903it [01:59, 24.49it/s]



3000it [02:03, 24.32it/s]
3it [00:00, 21.46it/s]

Finishing Epoch #0
Starting #1


105it [00:04, 24.56it/s]



204it [00:08, 24.69it/s]



303it [00:12, 23.23it/s]



404it [00:17, 24.66it/s]



503it [00:21, 24.62it/s]



605it [00:25, 24.62it/s]



704it [00:29, 27.52it/s]



806it [00:32, 27.13it/s]



905it [00:36, 27.68it/s]



1004it [00:40, 27.44it/s]



1106it [00:43, 27.54it/s]



1205it [00:47, 27.48it/s]



1304it [00:51, 24.54it/s]



1406it [00:55, 24.75it/s]



1505it [00:59, 24.61it/s]



1604it [01:03, 24.45it/s]



1703it [01:07, 24.72it/s]



1805it [01:11, 24.45it/s]



1904it [01:16, 24.68it/s]



2006it [01:20, 24.83it/s]



2105it [01:24, 24.06it/s]



2204it [01:28, 24.56it/s]



2306it [01:32, 24.58it/s]



2405it [01:36, 24.66it/s]



2504it [01:40, 24.79it/s]



2606it [01:44, 24.64it/s]



2705it [01:48, 24.82it/s]



2804it [01:52, 25.92it/s]



2903it [01:56, 23.50it/s]



3000it [02:00, 24.85it/s]
3it [00:00, 21.68it/s]

Finishing Epoch #1
Starting #2


105it [00:04, 24.34it/s]



204it [00:08, 24.41it/s]



303it [00:12, 24.20it/s]



405it [00:16, 24.07it/s]



504it [00:20, 22.79it/s]



603it [00:24, 24.56it/s]



705it [00:29, 24.72it/s]



804it [00:33, 24.58it/s]



906it [00:37, 24.66it/s]



1005it [00:41, 24.53it/s]



1104it [00:45, 23.61it/s]



1203it [00:49, 24.56it/s]



1306it [00:53, 24.06it/s]



1405it [00:58, 24.30it/s]



1504it [01:02, 27.10it/s]



1606it [01:06, 24.53it/s]



1705it [01:10, 24.67it/s]



1804it [01:14, 24.64it/s]



1906it [01:18, 24.68it/s]



2005it [01:22, 24.74it/s]



2104it [01:26, 24.21it/s]



2206it [01:30, 24.58it/s]



2305it [01:34, 24.64it/s]



2404it [01:38, 24.65it/s]



2506it [01:43, 24.75it/s]



2602it [01:46, 24.67it/s]



2704it [01:51, 24.46it/s]



2805it [01:55, 24.53it/s]



2903it [02:00, 24.41it/s]



3000it [02:04, 24.18it/s]
3it [00:00, 26.25it/s]

Finishing Epoch #2
Starting #3


105it [00:03, 27.55it/s]



204it [00:07, 27.28it/s]



303it [00:11, 24.34it/s]



405it [00:15, 24.50it/s]



504it [00:19, 24.43it/s]



603it [00:23, 24.70it/s]



705it [00:27, 24.62it/s]



804it [00:31, 24.50it/s]



906it [00:35, 24.44it/s]



1005it [00:40, 24.57it/s]



1104it [00:44, 24.62it/s]



1206it [00:48, 24.56it/s]



1305it [00:52, 24.49it/s]



1404it [00:56, 23.18it/s]



1503it [01:01, 15.27it/s]



1604it [01:06, 18.29it/s]



1703it [01:11, 24.44it/s]



1805it [01:15, 23.96it/s]



1904it [01:19, 24.68it/s]



2006it [01:23, 24.59it/s]



2105it [01:27, 24.58it/s]



2204it [01:31, 24.80it/s]



2306it [01:35, 24.50it/s]



2405it [01:39, 24.60it/s]



2504it [01:43, 24.67it/s]



2603it [01:47, 24.56it/s]



2705it [01:51, 24.74it/s]



2804it [01:55, 24.74it/s]



2906it [02:00, 24.54it/s]



3000it [02:03, 24.20it/s]
3it [00:00, 22.53it/s]

Finishing Epoch #3
Starting #4


105it [00:04, 24.69it/s]



204it [00:08, 27.48it/s]



306it [00:11, 25.73it/s]



405it [00:15, 24.59it/s]



504it [00:20, 24.62it/s]



606it [00:24, 24.69it/s]



705it [00:28, 24.65it/s]



804it [00:32, 24.68it/s]



903it [00:36, 24.77it/s]



1005it [00:40, 24.56it/s]



1104it [00:44, 27.47it/s]



1206it [00:48, 24.81it/s]



1305it [00:52, 24.51it/s]



1404it [00:56, 24.45it/s]



1503it [01:00, 24.46it/s]



1605it [01:04, 24.58it/s]



1704it [01:08, 24.68it/s]



1806it [01:12, 24.58it/s]



1905it [01:16, 24.75it/s]



2004it [01:20, 24.70it/s]



2103it [01:25, 24.54it/s]



2205it [01:29, 24.69it/s]



2304it [01:33, 24.47it/s]



2406it [01:37, 27.18it/s]



2505it [01:40, 27.52it/s]



2604it [01:44, 23.97it/s]



2704it [01:49, 23.99it/s]



2805it [01:54, 22.68it/s]



2904it [01:58, 27.70it/s]



3000it [02:01, 24.63it/s]

Finishing Epoch #4





In [10]:
# Evaluate on the held-out split; the cell displays (accuracy, confusion matrix),
# where the matrix is indexed [predicted][true].
test(model, device, test_loader=test_loader)

100%|██████████| 1000/1000 [00:08<00:00, 116.57it/s]


(0.8722,
 [[872, 3, 22, 27, 1, 1, 195, 0, 4, 0],
  [0, 963, 0, 6, 2, 0, 1, 0, 2, 0],
  [17, 2, 819, 9, 101, 0, 101, 0, 4, 0],
  [35, 23, 12, 912, 42, 3, 32, 0, 5, 0],
  [6, 4, 75, 21, 772, 0, 73, 0, 3, 0],
  [2, 0, 1, 0, 0, 937, 0, 12, 1, 5],
  [55, 3, 65, 22, 78, 0, 583, 0, 8, 1],
  [0, 0, 0, 0, 0, 39, 0, 947, 5, 44],
  [13, 2, 6, 3, 4, 1, 15, 1, 968, 1],
  [0, 0, 0, 0, 0, 19, 0, 40, 0, 949]])