In [1]:
import torch
import torchvision
import torch.nn.functional as F

# Get Data
# Seed the global RNG before the first random draw so the shuffle below (and the
# randn-based weight initialisation later in this cell) gives the same result on
# every Restart & Run All.
torch.manual_seed(0)

train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
# One permutation is applied to both images and labels so each pair stays aligned.
rand = torch.randperm(train_dataset.data.shape[0])
# Flatten every image to a 784-element float row. Pixel values are left unscaled
# here; the training loop divides each mini-batch by 255.
X = train_dataset.data.reshape(-1,28*28)[rand].float()
Y = train_dataset.targets[rand]

def _scaled_normal(fan_in, fan_out):
    """Return a trainable (fan_in, fan_out) matrix of standard-normal draws divided by sqrt(fan_in)."""
    weights = torch.randn(fan_in, fan_out) / fan_in ** 0.5
    return weights.requires_grad_(True)


# Weight matrices for the 784 -> 128 -> 64 -> 10 network; only the output layer has a bias.
# The draws happen in the order w1, w2, w3.
w1 = _scaled_normal(784, 128)
w2 = _scaled_normal(128, 64)
w3 = _scaled_normal(64, 10)
b3 = torch.zeros(10).requires_grad_(True)

# Mini-batch size used by both the training loop and the test evaluation.
batch_size = 32

# Three passes over the shuffled training set with hand-written SGD:
# learning rate 0.1 for the first epoch, 0.01 for the remaining ones.
for epoch in range(3):
    lr = 0.01 if epoch >= 1 else 0.1
    for step, i in enumerate(range(0, len(X), batch_size)):
        # Forward: divide the raw pixel values by 255, then two bias-free
        # ReLU layers followed by a linear output layer with bias.
        batch_end = i + batch_size
        Xb = X[i:batch_end] / 255
        Yb = Y[i:batch_end]

        relu1 = F.relu(Xb @ w1)
        relu2 = F.relu(relu1 @ w2)
        logits = relu2 @ w3 + b3

        # Loss
        loss = F.cross_entropy(logits, Yb)

        # Backward
        loss.backward()

        # Update: plain gradient step, then reset the accumulated gradient
        # so the next backward() starts from zero.
        with torch.no_grad():
            for param in (w1, w2, w3, b3):
                param -= lr * param.grad
                param.grad.zero_()

        # Report the loss of every tenth mini-batch of the epoch.
        if step % 10 == 0:
            print(loss.item())


# Get Test Data
test_dataset = torchvision.datasets.MNIST(root='./', train=False, download=True)
rand = torch.randperm(len(test_dataset.data))
X_test = test_dataset.data[rand].reshape(-1, 28*28).float()
Y_test = test_dataset.targets[rand]

# Running tallies of correctly classified and seen examples.
correct = 0
total = 0

# Evaluation only: no gradients are needed for the forward passes below.
with torch.no_grad():
    for Xb, Yb in zip(X_test.split(batch_size), Y_test.split(batch_size)):
        # Same scaling and forward pass as in training.
        Xb = Xb / 255

        relu1 = F.relu(Xb @ w1)
        relu2 = F.relu(relu1 @ w2)
        logits = relu2 @ w3 + b3

        # The predicted class is the index of the largest logit in each row.
        prediction = logits.argmax(dim=1)
        correct += int((prediction == Yb).sum())
        total += Yb.shape[0]

accuracy = correct / total
print(f'Accuracy: {accuracy*100:.0f}%')

2.3675999641418457
2.1866016387939453
1.7654765844345093
1.5709749460220337
1.1464297771453857
0.9248833060264587
0.7113738656044006
0.9655783176422119
0.6777794361114502
0.48356083035469055
0.6293730735778809
0.4536769688129425
0.5627464652061462
0.7197120785713196
0.5772279500961304
0.3885633647441864
0.3402160108089447
0.3803543448448181
0.5404077768325806
0.5499785542488098
0.30291539430618286
0.6433289647102356
0.48161429166793823
0.3641708195209503
0.5442917346954346
0.5423845052719116
0.36179319024086
0.4866829216480255
0.5551265478134155
0.2520545721054077
0.15962845087051392
0.3887462019920349
0.718519926071167
0.24712133407592773
0.15473206341266632
0.5574661493301392
0.38798168301582336
0.5161125659942627
0.41792044043540955
0.09301386028528214
0.27888399362564087
0.2604405879974365
0.5430464744567871
0.3128535747528076
0.10670263320207596
0.3089912235736847
0.5799988508224487
0.48038095235824585
0.19319921731948853
0.36871960759162903
0.4818449318408966
0.25013482570648193


In [None]:
import torch
import torch.nn as nn
import torchvision

# Data: load MNIST, shuffle it once, and scale the whole set by 1/255 up front
# so the training loop can slice ready-to-use batches.
# Seed the global RNG first so the shuffle (and the nn.Linear initialisation
# that follows in this cell) is reproducible on every Restart & Run All.
torch.manual_seed(0)

train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
# One permutation is applied to both images and labels so each pair stays aligned.
rand = torch.randperm(train_dataset.data.shape[0])
# Flatten every image to a 784-element float row and divide by 255.
X = train_dataset.data.reshape(-1, 784)[rand].float() / 255
Y = train_dataset.targets[rand]

# Model: 784 -> 128 -> 64 -> 10 fully connected network with ReLU between the linear layers.
# The linear layers are created in input-to-output order.
hidden_1 = nn.Linear(784, 128)
hidden_2 = nn.Linear(128, 64)
readout = nn.Linear(64, 10)
model = nn.Sequential(hidden_1, nn.ReLU(), hidden_2, nn.ReLU(), readout)

# Plain SGD over every model parameter, starting at learning rate 0.1.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

# Train: three passes over the data in mini-batches of 32.
for epoch in range(3):
    # At the start of the second epoch, drop the learning rate to 0.01;
    # it stays there for the rest of training.
    if epoch == 1:
        optimizer.param_groups[0]['lr'] = 0.01
    for start in range(0, len(X), 32):
        stop = start + 32
        # Clear the gradients left from the previous step before computing new ones.
        optimizer.zero_grad()
        outputs = model(X[start:stop])
        loss = nn.functional.cross_entropy(outputs, Y[start:stop])
        loss.backward()
        optimizer.step()
