In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt

### Load Data

In [7]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Images become float tensors (the linear layers need floating-point input);
# labels become long tensors (the class-index dtype F.cross_entropy expects).
# Everything is moved to the GPU up front, so a CUDA device is required.
X_train, X_test = (
    torch.tensor(images, dtype=torch.float).cuda()
    for images in (X_train, X_test)
)
y_train, y_test = (
    torch.tensor(labels, dtype=torch.long).cuda()
    for labels in (y_train, y_test)
)

### Prepare Data

In [8]:
# Datasets: pair every image tensor with its label so they are indexed together.
train_ds, test_ds = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)
# Loaders: the training set is reshuffled every epoch and served in batches of
# 128; the test set is served in its stored order in batches of 100.
train_dl = DataLoader(dataset=train_ds, batch_size=128, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=100, shuffle=False)

### Create Model

In [9]:
class ModelA(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: Number of values per flattened sample (default 784).
        hidden: Width of the hidden layer (default 512).
        n_classes: Number of output scores per sample (default 10).

    The defaults reproduce the original fixed 784 -> 512 -> 10 architecture,
    and the layer attribute names (``fc1``, ``fc2``) are unchanged.
    """

    def __init__(self, in_features=784, hidden=512, n_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, n_classes)

    def forward(self, xb):
        """Return unnormalised class scores of shape (N, n_classes).

        ``xb`` may have any shape whose element count is a multiple of
        ``in_features``; it is flattened to (N, in_features) first.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as permuted or sliced tensors, copying only when it has to.
        xb = xb.reshape(-1, self.fc1.in_features)
        xb = F.relu(self.fc1(xb))
        # fc2 already yields a 2-D (N, n_classes) tensor; no further reshaping.
        return self.fc2(xb)

### Train Model

In [10]:
def realise(p):
    """Turn per-class scores into hard predictions.

    For each row of ``p``, returns the index of the largest entry along
    axis 1.
    """
    predicted = p.argmax(axis=1)
    return predicted
def accuracy(y1, y2):
    """Fraction of samples whose predicted class matches the target.

    ``y1`` holds per-class scores (reduced to class indices by ``realise``)
    and ``y2`` holds the target class indices. The result is the mean of the
    element-wise matches, returned as a tensor (call ``.item()`` for a float).
    """
    hits = realise(y1) == y2
    return hits.float().mean()

In [14]:
%%time
# Create Model
model = ModelA().cuda()
# Select Optimiser
opt = optim.RMSprop(model.parameters(), lr=0.001)
# Select Loss Function
loss_func = F.cross_entropy
# Train
epochs = 5
losses = []      # mean training loss of each epoch
val_losses = []  # mean validation loss of each epoch
for epoch in range(epochs):
    # Train
    model.train()
    train_loss = 0.0
    for xb, yb in train_dl:
        # Backprop
        loss = loss_func(model(xb), yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
        # Weight each batch loss by its size so the smaller final batch does
        # not skew the epoch average; detach so no autograd graph is retained.
        train_loss += loss.detach() * len(xb)
    train_loss = float(train_loss / len(train_ds))
    # Validate
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in test_dl:
            val_loss += loss_func(model(xb), yb) * len(xb)
    val_loss = float(val_loss / len(test_ds))
    # Statistic
    print(epoch, val_loss)
    # Record the training loss here. Previously this appended `loss`, which
    # the validation loop had already overwritten with its last batch loss.
    losses.append(train_loss)
    val_losses.append(val_loss)
# Final evaluation: inference only, so run in eval mode without building an
# autograd graph for the whole test set.
model.eval()
with torch.no_grad():
    print('Final Accuracy:', accuracy(model(X_test), y_test).item())

0 0.20993229746818542
1 0.21988829970359802
2 0.16058087348937988
3 0.22476844489574432
4 0.21238425374031067
Final Accuracy: 0.9603999853134155
Wall time: 30.5 s
