In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets

In [None]:
# MNIST training split; fetched into ./data when it is not already present,
# with every image converted to a tensor on access.
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
# Sanity check: number of examples in the training split.
len(train_dataset)

60000

In [None]:
# MNIST held-out split (train=False), read from the same ./data root,
# with every image converted to a tensor on access.
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

# Sanity check: number of examples in the test split.
len(test_dataset)

10000

In [None]:
# Number of examples per mini-batch; used by both the train and test loaders.
batch_size = 100

In [None]:
# Target number of optimizer steps; converted into whole epochs from this
# value, the training-set size and the batch size.
n_iters = 3000

In [None]:
# Epochs = requested iterations / iterations per epoch, truncated to an int.
num_epochs = int(n_iters / (len(train_dataset) / batch_size))
num_epochs

5

In [None]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Evaluation batches are served in a fixed order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class LogisticRegressionModel(nn.Module):
    """Multinomial logistic regression expressed as a single linear layer.

    ``forward`` returns the linear layer's raw outputs (logits); no softmax
    is applied inside the model.
    """

    def __init__(self, input_dim, output_dim):
        """Build the ``input_dim -> output_dim`` linear layer.

        Args:
            input_dim: Number of features in each input row.
            output_dim: Number of output scores (one per class).
        """
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

In [None]:
# One input feature per pixel of a 28x28 image; one output score per class.
input_dim, output_dim = 28 * 28, 10

In [None]:
# Instantiate the classifier with the dimensions configured above.
model = LogisticRegressionModel(input_dim=input_dim, output_dim=output_dim)

In [None]:
# Cross-entropy loss over the model's raw outputs. nn.CrossEntropyLoss applies
# log-softmax internally, which is why the model's forward() returns the
# linear layer's output without a softmax.
criterion = nn.CrossEntropyLoss()  

In [None]:
# Plain stochastic gradient descent over all model parameters.
learning_rate = 0.001
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# Mini-batch SGD training loop with a full test-set evaluation every 500 steps.
#
# Reads: num_epochs, train_loader, test_loader, model, criterion, optimizer.
# Prints one line per evaluation: step count, loss of the most recent training
# batch, and test accuracy in percent.

# Global step counter. Deliberately not called `iter`, which would shadow the
# Python builtin for the rest of the kernel session.
iteration = 0
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # Flatten every image in the batch to a 784-element row. Inputs do not
        # need gradients; only the model parameters are optimized.
        images = images.view(-1, 28*28)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass: raw class scores, then cross-entropy against labels.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        iteration += 1

        if iteration % 500 == 0:
            correct = 0
            total = 0

            # Evaluation is inference only: disable autograd so no graph is
            # built. Separate names keep the training batch's `images` and
            # `labels` from being overwritten.
            with torch.no_grad():
                for test_images, test_labels in test_loader:
                    test_images = test_images.view(-1, 28*28)
                    test_outputs = model(test_images)

                    # Index of the largest score in each row = predicted class.
                    _, predicted = torch.max(test_outputs, 1)

                    total += test_labels.size(0)
                    # .item() keeps `correct` a Python int so the accuracy
                    # below is an ordinary float rather than a tensor.
                    correct += (predicted == test_labels).sum().item()

            accuracy = 100 * correct / total

            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iteration, loss.item(), accuracy))

Iteration: 500. Loss: 1.8331000804901123. Accuracy: 68.1500015258789
Iteration: 1000. Loss: 1.6044321060180664. Accuracy: 75.56999969482422
Iteration: 1500. Loss: 1.2541276216506958. Accuracy: 78.87999725341797
Iteration: 2000. Loss: 1.175275444984436. Accuracy: 80.83000183105469
Iteration: 2500. Loss: 1.1568013429641724. Accuracy: 82.01000213623047
Iteration: 3000. Loss: 0.9194238781929016. Accuracy: 82.7699966430664


In [None]:
# Show the model's raw outputs (one row of class scores per image) for the
# first test batch.
#
# Only the first batch is displayed, so the loop stops after it instead of
# running the model over the whole test set. Autograd is disabled because this
# is inference only.
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten every image in the batch to a 784-element row.
        images = images.view(-1, 28*28)
        outputs = model(images)
        print('OUTPUTS')
        print(outputs)
        # Index of the largest score in each row = predicted class.
        _, predicted = torch.max(outputs, 1)
        break

OUTPUTS
tensor([[-4.4865e-01, -1.1810e+00, -3.9384e-01, -2.0149e-01,  1.4982e-01,
         -4.0730e-01, -1.0246e+00,  2.8202e+00, -3.7285e-01,  7.7710e-01],
        [ 5.5879e-01,  1.5292e-01,  1.4814e+00,  1.0624e+00, -1.9591e+00,
          8.2732e-01,  1.2868e+00, -1.7932e+00,  1.3278e-01, -1.6495e+00],
        [-9.8499e-01,  2.3290e+00,  1.5072e-01,  1.6640e-01, -6.8971e-01,
         -2.7085e-01, -3.0313e-01, -2.1233e-01,  2.8485e-01, -3.3337e-01],
        [ 2.9013e+00, -2.5296e+00, -8.2553e-02, -3.1091e-01, -1.1311e+00,
          6.5927e-01,  9.7033e-01,  1.5075e-01, -6.9668e-01, -5.1884e-01],
        [-2.1097e-01, -2.3518e+00,  3.3307e-01, -7.2100e-01,  1.7499e+00,
         -4.1194e-01,  1.5404e-01,  4.2941e-01, -6.0739e-02,  7.6049e-01],
        [-1.4048e+00,  2.8358e+00,  1.1883e-01,  2.7256e-01, -7.3437e-01,
         -4.2539e-01, -7.6897e-01, -4.0427e-02,  5.4660e-01, -2.1985e-01],
        [-1.2367e+00, -1.1731e+00, -5.9871e-01,  2.7296e-01,  1.6158e+00,
          5.0010e-01, -6

In [None]:
# Show the predicted class for the first image of the first test batch.
#
# Only the first batch is needed, so the loop stops after it instead of
# running the model over the whole test set. Autograd is disabled because this
# is inference only.
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten every image in the batch to a 784-element row.
        images = images.view(-1, 28*28)
        outputs = model(images)
        # Index of the largest score in each row = predicted class.
        _, predicted = torch.max(outputs, 1)
        print('PREDICTION')
        print(predicted[0])
        break

PREDICTION
tensor(7)


In [None]:
# Final evaluation: percentage of test images whose predicted class matches
# the label.
correct = 0
total = 0

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten every image in the batch to a 784-element row.
        images = images.view(-1, 28*28)
        outputs = model(images)
        # Index of the largest score in each row = predicted class.
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        # Accumulate as a Python int so the accuracy is a plain float.
        correct += (predicted == labels).sum().item()

accuracy = 100 * (correct / total)

print(accuracy)

82.77
