## CNN PyTorch Example
Some adaptation from https://www.kaggle.com/juiyangchang/cnn-with-pytorch-0-995-accuracy.

### Import modules

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import lr_scheduler
from torch.autograd import Variable
import math

### Data preparation

In [2]:
data_path = '/Users/rebeccawillison/Documents/research/stat-9340/hw5/mnist0_train_b.txt'
mnist0 = np.array(pd.read_fwf(data_path, header = None)).transpose()
data_path = '/Users/rebeccawillison/Documents/research/stat-9340/hw5/mnist9_train_b.txt'
mnist9 = np.array(pd.read_fwf(data_path, header = None)).transpose()

In [3]:
# subset training dataset & standardize
train_X = np.c_[mnist0[:,0:1000], mnist9[:,0:1000]].transpose()
train_Y = np.append(np.full(1000, 0), np.full(1000, 1))
train = pd.DataFrame(np.insert(train_X, 1, train_Y, axis=1)).rename(columns={0:'label'})

# create test dataset & standardize
test_X = np.c_[mnist0[:,2500:3000], mnist9[:,2500:3000]].transpose()
test_Y = np.append(np.full(500, 0), np.full(500, 1))
test = pd.DataFrame(test_X)

train_file = '/Users/rebeccawillison/Documents/research/stat-9340/hw5/mnist_train.csv'
train.to_csv(train_file, index = False)
test_file = '/Users/rebeccawillison/Documents/research/stat-9340/hw5/mnist_test.csv'
test.to_csv(test_file, index = False)

### Data loading functions

In [9]:
class MNIST_data(Dataset):
    def __init__(self, file_path, 
                 transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), 
                     transforms.Normalize(mean=(0.5,), std=(0.5,))])
                ):

        df = pd.read_csv(file_path)

        if len(df.columns) == n_pixels:
            # test data
            self.X = df.values.reshape((-1,28,28)).astype(np.uint8)[:,:,:,None]
            self.y = None
        else:
            # training data
            self.X = df.iloc[:,1:].values.reshape((-1,28,28)).astype(np.uint8)[:,:,:,None]
            self.y = torch.from_numpy(df.iloc[:,0].values.astype(int))

        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if self.y is not None:
            return self.transform(self.X[idx]), self.y[idx]
        else:
            return self.transform(self.X[idx])

In [10]:
batch_size = 64
n_pixels = 28*28

train_dataset = MNIST_data(train_file, transform = transforms.Compose(
                            [transforms.ToTensor(), 
                             transforms.Normalize(mean=(0.5,), std=(0.5,))]))
test_dataset = MNIST_data(test_file)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size = 1000,
                                          shuffle=False)

### Define model

In [11]:
class myCNN(nn.Module):
    def __init__(self):
        super(myCNN, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(32 * 7 * 7, 30),
            nn.Sigmoid(),
            nn.Linear(30, 2),
            nn.Sigmoid()
        )
        
        for m in self.features.children():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        
        for m in self.classifier.children():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

model = myCNN()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

### Functions for training and model evaluation

In [12]:
def train(epoch):
    model.train()
    exp_lr_scheduler.step()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def evaluate(data_loader):
    model.eval()
    loss = 0
    correct = 0
    
    for data, target in data_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        loss += F.cross_entropy(output, target, size_average=False).data[0]
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
    loss /= len(data_loader.dataset)
    
def prediction(data_loader):
    model.eval()
    test_pred = torch.LongTensor()
    
    for i, data in enumerate(data_loader):
        data = Variable(data)
        output = model(data)
        
    return output

### Train model

In [13]:
n_epochs = 10
for epoch in range(n_epochs):
    train(epoch)

### Get test predictions and check accuracy

In [14]:
preds = prediction(test_loader)

In [15]:
test_pred = preds[:,1].detach().numpy()
test_pred[test_pred > np.mean(test_pred)] = 1
test_pred[test_pred != 1] = 0
print('Accuracy:', sum(test_pred == test_Y)/1000)

Accuracy: 0.715
