In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

In [2]:
# Load the rotated-MNIST training split.
# `np.load` on an .npz returns a lazy NpzFile that keeps the archive open,
# so every array is read inside a `with` block to release the file handle.
with np.load("../data/mnist_rot_train.npz") as data:
    X_tr = torch.from_numpy(data["X"]).float()
    y_tr = torch.from_numpy(data["labels"])
    a_tr = torch.from_numpy(data["angles"])

# Flat (N, D) pixel rows -> (N, 1, 28, 28) single-channel images.
N, D = X_tr.shape
X_tr = X_tr.view(N, 1, 28, 28)

# Same procedure for the validation split.
with np.load("../data/mnist_rot_validation.npz") as data:
    X_val = torch.from_numpy(data["X"]).float()
    y_val = torch.from_numpy(data["labels"])
    a_val = torch.from_numpy(data["angles"])

N, D = X_val.shape
X_val = X_val.view(N, 1, 28, 28)

print(X_tr.shape, y_tr.shape, a_tr.shape)

torch.Size([60000, 1, 28, 28]) torch.Size([60000]) torch.Size([60000])


In [7]:
class Net(nn.Module):
    """Small CNN that maps single-channel 28x28 images to 10 class logits.

    Two padded convolutions keep the 28x28 spatial size, a 2x2 max-pool
    halves it to 14x14, and a three-layer MLP maps the flattened
    32 * 14 * 14 features to 10 outputs.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: (B, 1, 28, 28) -> (B, 32, 14, 14).
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(5, 5), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=(3, 3), padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(stride=2, kernel_size=2)
        )

        # Classifier head: flattened features -> 10 logits.
        self.fc1 = nn.Sequential(
            nn.Linear(32 * 14 * 14, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 10)
        )

    def forward(self, X):
        """Return unnormalised class scores (logits) of shape (B, 10).

        Args:
            X: float tensor of shape (B, 1, 28, 28).
        """
        h1 = self.conv1(X)
        # Keep the batch dimension, collapse channels and spatial dims.
        h1 = torch.flatten(h1, start_dim=1)

        h2 = self.fc1(h1)

        return h2

    def predict(self, X):
        """Return the arg-max class index for each input, shape (B,).

        Runs under ``torch.no_grad()``: the result is an integer index
        tensor that can never carry gradients, so building an autograd
        graph here would only waste memory (notably when called on a
        whole dataset at once).
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks run.
            logits = self(X)

        return torch.argmax(logits, dim=1)

In [8]:
def train(model, epochs=5, dataset=None, lr=1e-4, weight_decay=1e-4,
          batch_size=64, log_every=200):
    """Train ``model`` in place with Adam on a summed cross-entropy loss.

    Args:
        model: module mapping an input batch to class logits.
        epochs: number of full passes over the dataset.
        dataset: dataset whose items start with ``(input, label, ...)``;
            any further fields are ignored. When ``None``, the first
            10,000 examples of the notebook globals ``X_tr``, ``y_tr``
            and ``a_tr`` are used.
        lr: Adam learning rate.
        weight_decay: Adam weight-decay coefficient.
        batch_size: mini-batch size.
        log_every: print the batch loss every this many steps within an
            epoch; a falsy value disables logging.

    Returns:
        None. The model's parameters are updated in place.
    """
    if dataset is None:
        dataset = TensorDataset(X_tr[:10000], y_tr[:10000], a_tr[:10000])

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for _epoch in range(epochs):
        # `step` is kept distinct from the epoch counter; the step counter
        # restarts at 0 every epoch, so step 0 is always logged.
        for step, (X_batch, y_batch, *_) in enumerate(train_loader):
            # Call the module (not .forward) so registered hooks run.
            y_pred = model(X_batch)

            # Summed (not averaged) over the batch, so the logged value
            # scales with the batch size.
            loss = F.cross_entropy(y_pred, y_batch, reduction='sum')

            if log_every and step % log_every == 0:
                print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [9]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader
# shuffling are repeatable across Restart & Run All.
torch.manual_seed(0)

model = Net()

train(model, epochs=1)

534.8815307617188


In [10]:
def get_class_err(y_pred, y):
    """Return the fraction of entries where ``y_pred`` equals ``y``.

    NOTE(review): despite the name, this is the mean of the *matches*,
    i.e. an accuracy, not an error rate. The name is kept so existing
    callers keep working.
    """
    matches = y_pred.eq(y)
    return matches.to(torch.float32).mean()

# y_pred = model.predict(X_tr)
# print(y_pred.shape)
# print('Train acc', get_class_err(y_pred, y_tr))

# The whole validation set goes through the network in one call; disable
# autograd so no graph is kept for what is a pure inference pass.
with torch.no_grad():
    y_pred = model.predict(X_val)
print('Val acc', get_class_err(y_pred, y_val))

Val acc tensor(0.2216)
