## MNIST example (also CIFAR example)
taken from: https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/convolutional_neural_network/main.py#L35-L56

and from cs231n assignment 2

In [None]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Device configuration: run on the first CUDA GPU when one is available,
# otherwise fall back to the CPU. The model and the data batches are moved
# to `device` further down.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### PyTorch Tensors: Flatten Function
A PyTorch Tensor is conceptually similar to a numpy array: it is an n-dimensional grid of numbers, and, like numpy, PyTorch provides many functions to efficiently operate on Tensors. As a simple example, we provide a `flatten` function below which reshapes image data for use in a fully-connected neural network.

Recall that image data is typically stored in a Tensor of shape N x C x H x W, where:

* N is the number of datapoints
* C is the number of channels
* H is the height of the intermediate feature map in pixels
* W is the width of the intermediate feature map in pixels

This is the right way to represent the data when we are doing something like a 2D convolution, which needs a spatial understanding of where the intermediate features are relative to each other. When we use fully connected affine layers to process the image, however, we want each datapoint to be represented by a single vector -- it's no longer useful to segregate the different channels, rows, and columns of the data. So, we use a "flatten" operation to collapse the `C x H x W` values per datapoint into a single long vector. The flatten function below first reads in the batch size N from a given batch of data, and then returns a "view" of that data. "View" is analogous to numpy's "reshape" method: it reshapes x's dimensions to be N x ??, where ?? is allowed to be anything (in this case, it will be C x H x W, but we don't need to specify that explicitly).

In [None]:
def flatten(x):
    """Collapse every dimension after the first into one vector per datapoint.

    Args:
        x: tensor whose first dimension indexes the datapoints, e.g. an image
            batch of shape (N, C, H, W).

    Returns:
        Tensor of shape (N, C*H*W). For a contiguous input this is a view that
        shares memory with `x`; for a non-contiguous input (e.g. the result
        of `permute` or `transpose`) the data is copied.
    """
    N = x.shape[0]  # only the batch size is read; the remaining size is inferred via -1
    # reshape, unlike view, does not raise on non-contiguous inputs.
    return x.reshape(N, -1)  # "flatten" the C * H * W values into a single vector per image
    
def test_flatten():
    """Print a small 2 x 1 x 3 x 2 tensor before and after applying `flatten`."""
    sample = torch.arange(12).view(2, 1, 3, 2)  # 2 datapoints, 1 channel, 3 x 2 "pixels"
    print('Before flattening: ', sample)
    flat = flatten(sample)                      # expected shape: 2 x 6
    print('After flattening: ', flat)


# Run the demo right away so the cell shows both tensors
test_flatten()

## Define loaders, CNN

In [None]:
# Hyper parameters
num_epochs = 5          # full passes over the training set
num_classes = 10        # digits 0-9
batch_size = 100        # images per mini-batch
learning_rate = 0.001   # step size handed to the optimizer

# Dataset selection: swap the two lines below to train on CIFAR10 instead of MNIST
#dataset_cls = torchvision.datasets.CIFAR10  #for CIFAR10
dataset_cls = torchvision.datasets.MNIST

data_root = '../../data/'            # where the dataset files are stored / downloaded to
to_tensor = transforms.ToTensor()    # converts each image to a tensor

# Training split (downloaded on first use) and held-out test split
train_dataset = dataset_cls(root=data_root, train=True, transform=to_tensor, download=True)
test_dataset = dataset_cls(root=data_root, train=False, transform=to_tensor)

# Data loaders: training batches are reshuffled every epoch, test batches keep their order
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)




# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    """Small CNN classifier: two (conv 5x5 -> batch norm -> ReLU -> 2x2 max-pool) blocks
    followed by a single fully connected layer.

    Args:
        num_classes: number of scores produced per image.
        in_channels: number of channels of the input images
            (1 for MNIST, 3 for CIFAR10).
        image_size: height (= width) of the square input images
            (28 for MNIST, 32 for CIFAR10).

    The defaults reproduce the MNIST network; for CIFAR10 use
    ``ConvNet(num_classes, in_channels=3, image_size=32)``.
    """

    def __init__(self, num_classes=10, in_channels=1, image_size=28):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        # The padded 5x5 convolutions keep the spatial size and each 2x2 pool
        # halves it (rounding down), so the final feature map is
        # (image_size // 4) x (image_size // 4) with 32 channels:
        # 7*7*32 for MNIST (28x28), 8*8*32 for CIFAR10 (32x32).
        feature_size = image_size // 4
        self.fc = nn.Linear(feature_size * feature_size * 32, num_classes)

    def forward(self, x):
        """Return the class scores, shape (N, num_classes), for a batch x of shape (N, C, H, W)."""
        out = self.layer1(x)
        out = self.layer2(out)
        #print (out.shape)                    # shape after last conv layer before flattening - (100, 32, 7, 7) for MNIST

        # Flatten each datapoint into one vector so it matches the linear (fc) layer.
        # One reshape is enough (it does the same job as the notebook's flatten() helper).
        out = out.reshape(out.size(0), -1)   # out.size(0) is the number of images N

        #print (out.shape)                    # new flat shape is: (100, 1568) for MNIST
        out = self.fc(out)                   # calling fc layer
        #print (out.shape)                    # shape after fc layer is (100, 10)
        return out



## get number of trainable parameters

In [None]:
# Build the network and move its parameters to the selected device.
model = ConvNet(num_classes).to(device)
def get_train_params_num(model):
    """
    Return the number of trainable parameters of a neural network model.
    You may want to call it after you create your model to see how many parameters the model has.

    Args:
        model - neural net to examine. NOTE: this is an instance of an
                nn.Module subclass (e.g. ConvNet), not the class itself

    Returns:
        Total number of elements, as a Python int, over all parameters with
        requires_grad set (0 when the model has none).
    """
    # numel() counts elements exactly -- including 0-dim (scalar) parameters --
    # and yields plain ints, so the total never silently becomes a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Count the trainable parameters of the model created above; as the last
# expression of the cell, the result is presumably shown as the cell output.
get_train_params_num(model)

## Train model

In [None]:
# Fresh network for this training run, placed on the selected device
model = ConvNet(num_classes).to(device)

# Loss and optimizer: cross-entropy on the class scores, minimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)  # number of mini-batches per epoch
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # Drop the gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Progress report every 100 mini-batches
        if step % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, step, total_step, loss.item()))

# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
correct, total = 0, 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = torch.max(outputs.data, 1)[1]  # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

### Playing around to get the feel for the data and network op.

In [None]:
from matplotlib import pyplot as plt
import PIL.Image as pil

# Inspect a single test batch: print the labels, the raw network scores and
# the predicted classes, then display one of the input images.
# The accuracy counters of the evaluation cell are not touched here.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        print (labels.shape)
        print (labels)
        print()
        outputs = model(images)
        print(outputs.shape)
        print (outputs[0])               # scores of the first image, one per class
        print (torch.max(outputs[0]))    # highest score of the first image
        print()
        _, predicted = torch.max(outputs.data, 1)  # predicted is the index where maximum is obtained
        print (predicted.shape)
        print (predicted)
        print()

        print (images.shape)
        #display mnist image
        # matplotlib cannot read a tensor that lives on the GPU, so copy the
        # image back to the CPU first (a no-op when already on the CPU).
        pixels = images[2].cpu().reshape((28, 28))
        plt.imshow(pixels, cmap='gray')
        plt.show()
        break  # only the first batch is inspected