In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args(args=[])
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

In [3]:
import jax.numpy as np
from jax.scipy.special import logsumexp
from jax.experimental import optimizers
import numpy as onp
from jax import random
from jax import jit

import torch
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

import tensorflow_datasets as tfds

key = random.PRNGKey(0)
batch_size = 11

mnist_data, info = tfds.load(name="mnist", batch_size=-1, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

#Formatting on training set
train_images, train_labels = train_data['image'], train_data['label']
train_images = torch.Tensor(train_images).permute(0,3,1,2)

#Formatting on testing set
test_images, test_labels = test_data['image'], test_data['label']
test_images = torch.Tensor(test_images).permute(0,3,1,2)

# Full train set
train_data = TensorDataset(train_images, torch.Tensor(train_labels).type(torch.LongTensor))
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Full test set
test_data = TensorDataset(test_images, torch.Tensor(test_labels).type(torch.LongTensor))
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])



In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        """
        define a cnn with two conv layers (experiment with the number of channels and kernal sizes) 
        followed by doing dropout, then two fully connected layers.
        """
        self.conv1 = nn.Sequential(nn.Conv2d(1, 10, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(kernel_size=2, stride=2),
                                   )
        self.conv2 = nn.Sequential(nn.Conv2d(10, 20, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(kernel_size=2, stride=2),
                                   nn.Dropout2d(p=0.25),
                                   nn.Flatten(),
                                   )
        self.lin1 = nn.Sequential(nn.Linear(20 * 7 * 7, 50),
                                  nn.ReLU(),
                                  nn.Dropout(p=0.25),
                                  )
        self.lin2 = nn.Linear(50,10)

    def forward(self, x):
        """
	define the forward pass of the cnn with a relu activation function for each hidden layer
        and dropout after the first fully connected layer.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.lin1(x)
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)
    


model = Net()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)

In [7]:
def train(epoch):
    model.train()
    correct = 0
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        train_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    train_loss /= len(train_loader)
    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))


In [11]:
n_epoch = 1
def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

for epoch in range(1, n_epoch + 1):
    train(epoch)
    test()


Train set: Average loss: 0.2169, Accuracy: 56477/60000 (94%)


Test set: Average loss: 0.1071, Accuracy: 9715/10000 (97%)

