# pytorch_mnist.ipynb
# WESmith 06/26/23
## Review of MNIST processing with pytorch.
## see https://nextjournal.com/gkoehler/pytorch-mnist

## 06/30/23 added capability to look at an embedding space in a modified network
## 07/02/23 added ability to look at embedding space of any desired network layer

In [None]:
# Root directory for the MNIST files; the settings cell further down assigns the same value again.
data_dir = 'data'

In [None]:
import torch
import torch.nn            as nn
import torch.nn.functional as F
import torch.optim         as optim
import torchvision
import numpy               as np
import matplotlib.pyplot   as plt
from sklearn.manifold import TSNE
import os

In [None]:
# --- paths and hyper-parameters shared by the cells below ---
data_dir         = 'data'                   # root directory handed to torchvision.datasets.MNIST
model_path       = 'results/model.pth'      # file that receives the network state_dict
optimizer_path   = 'results/optimizer.pth'  # file that receives the optimizer state_dict
batch_size_train = 64                       # batch size of train_loader
batch_size_test  = 1000                     # batch size of test_loader
learning_rate    = 0.01                     # SGD learning rate
momentum         = 0.5                      # SGD momentum
log_interval     = 10  # WS (was 10); NOTE(review): the training calls in this notebook pass log_interval=100 explicitly

random_seed      = 42
torch.backends.cudnn.enabled = False  # disable nondeterministic algorithms
_ = torch.manual_seed(random_seed)    # seed torch's RNG; presumably bound to _ to suppress the cell output

## SET UP DATA

In [None]:
# MNIST training split, downloaded into data_dir when not already present.
# Each image is converted to a tensor and then normalised.
train_data = torchvision.datasets.MNIST(
    data_dir, train=True, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # predefined global mean, std; both given as 1-tuples (one entry per image channel)
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))

In [None]:
# Mini-batch iterator over the training split; sample order is reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size_train, shuffle=True)

In [None]:
# MNIST test split, downloaded into data_dir when not already present.
# Uses the same tensor conversion and normalisation constants as the training split.
test_data = torchvision.datasets.MNIST(
    data_dir, train=False, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # predefined global mean, std; both given as 1-tuples (one entry per image channel)
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))

In [None]:
# Mini-batch iterator over the test split, in batches of batch_size_test.
# shuffle=True only changes the order in which the test samples are visited.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test, shuffle=True)

In [None]:
# Display the shapes of the raw image and label tensors of both splits.
train_data.data.shape, train_data.targets.shape, test_data.data.shape, test_data.targets.shape

In [None]:
# get the full test set for scatterplots below:
# a single batch as large as the dataset itself, so one draw returns every test sample
test_loader_full = torch.utils.data.DataLoader(test_data, batch_size=len(test_data), shuffle=True)

In [None]:
# Pull the first (and only) batch out of the full-size loader: images and their labels.
test_data_full, test_targ_full = next(iter(test_loader_full))

In [None]:
# Display the shapes of the full-test-set image batch and label vector.
test_data_full.shape, test_targ_full.shape

## CLASSES AND FUNCTIONS

In [None]:
# original model
class Net(nn.Module):
    '''
    Small MNIST classifier: two 5x5 convolutions, then two fully connected layers.
    forward() returns per-class log-probabilities (log_softmax over dim 1).
    '''

    def __init__(self):
        super().__init__()
        # layers are created in the same order as before, so seeded initial weights are unchanged
        self.conv1      = nn.Conv2d(in_channels=1,  out_channels=10, kernel_size=5)
        self.conv2      = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1        = nn.Linear(in_features=320, out_features=50)  # 4 x 4 x 20 = 320 inputs
        self.fc2        = nn.Linear(in_features=50,  out_features=10)  # one output per digit

    def forward(self, x):
        # stage 1: conv -> 2x2 max-pool -> relu
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # stage 2: conv -> channel dropout -> 2x2 max-pool -> relu
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        # flatten each sample to the 320 features expected by fc1
        flat   = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # functional dropout is only active while the module is in training mode
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)  # explicit dim avoids the implicit-dim warning

In [None]:
# WS model that bottlenecks at 2 neurons, to see if a 2D embedding space can be visualized
class Net2(nn.Module):
    '''
    Variant of the original network with an extra 2-unit linear layer ('emb')
    between fc1 and fc2, so the embedding can be plotted directly in 2D.
    forward() returns per-class log-probabilities (log_softmax over dim 1).
    '''

    def __init__(self):
        super().__init__()
        # layers are created in the same order as before, so seeded initial weights are unchanged
        self.conv1      = nn.Conv2d(in_channels=1,  out_channels=10, kernel_size=5)
        self.conv2      = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1        = nn.Linear(in_features=320, out_features=50)  # 4 x 4 x 20 = 320 inputs
        self.emb        = nn.Linear(in_features=50,  out_features=2)   # WS: 2D embedding layer
        self.fc2        = nn.Linear(in_features=2,   out_features=10)  # WS: 2 inputs, one output per digit

    def forward(self, x):
        # stage 1: conv -> 2x2 max-pool -> relu
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # stage 2: conv -> channel dropout -> 2x2 max-pool -> relu
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        # flatten each sample to the 320 features expected by fc1
        flat      = stage2.view(-1, 320)
        hidden    = F.relu(self.fc1(flat))
        embedding = self.emb(hidden)  # WS mod
        # NOTE(review): dropout is applied to the 2-unit embedding itself while training;
        # confirm this is intended rather than dropping out the fc1 activations
        embedding = F.dropout(embedding, training=self.training)
        logits    = self.fc2(embedding)
        return F.log_softmax(logits, dim=1)  # explicit dim avoids the implicit-dim warning

In [None]:
def train_over_epochs(n_epochs, model_path, optimizer_path, log_interval=100):
    '''
    Run n_epochs rounds of train() followed by test(), then plot the loss curves.

    n_epochs:       number of passes over the training set
    model_path:     forwarded to train() as the network checkpoint file
    optimizer_path: forwarded to train() as the optimizer checkpoint file
    log_interval:   forwarded to train()

    The global lists train_losses, train_counter and test_losses are reset at
    the start of every call, so the plot only covers the current run.
    '''
    global train_losses, train_counter, test_losses
    train_losses, train_counter, test_losses = [], [], []

    # x-positions of the test points: number of training samples seen after each epoch
    samples_per_epoch = len(train_loader.dataset)
    epochs = range(1, n_epochs + 1)
    test_counter = [samples_per_epoch * e for e in epochs]

    for epoch in epochs:
        train(epoch, model_path, optimizer_path, log_interval=log_interval)
        test()

    plt.figure(figsize=(14, 6))
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.grid()
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')

In [None]:
def train(epoch, model_path, optimizer_path, log_interval=100):
    '''
    Train the global `network` for one pass over the global `train_loader`.

    epoch:          1-based epoch number, used in the progress text and the sample counter
    model_path:     file that receives network.state_dict()
    optimizer_path: file that receives optimizer.state_dict()
    log_interval:   every `log_interval` batches the loss is printed and recorded,
                    and both state dicts are written to disk

    Appends to the global lists train_losses and train_counter, which must
    already exist (train_over_epochs() creates them).
    '''
    global train_losses
    global train_counter
    # torch.save() does not create missing directories, so make sure they exist
    for path in (model_path, optimizer_path):
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)

    network.train()
    ll = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss   = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx * len(data), ll,
                  100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # samples seen so far; uses the loader's batch size instead of a hard-coded 64
            train_counter.append((batch_idx * train_loader.batch_size) + ((epoch - 1) * ll))
            torch.save(network.state_dict(),   model_path)
            torch.save(optimizer.state_dict(), optimizer_path)

In [None]:
def test():
    '''
    Evaluate the global `network` on the global `test_loader`.

    Prints the average negative log likelihood loss and the accuracy, and
    appends the average loss to the global list test_losses. The list is
    created on demand, so test() can be called before any training run.
    '''
    # train_over_epochs() rebinds this global; create it here if it does not exist yet
    # so that a call made before training does not raise NameError
    losses = globals().setdefault('test_losses', [])
    network.eval()
    test_loss = 0.0
    correct   = 0
    with torch.no_grad():
        for data, target in test_loader:
            output     = network(data)
            # summed (not averaged) per batch; divided by the dataset size below
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred       = output.argmax(dim=1)
            # .item() keeps `correct` a plain int instead of a tensor
            correct   += pred.eq(target.view_as(pred)).sum().item()
    n_samples  = len(test_loader.dataset)
    test_loss /= n_samples
    losses.append(test_loss)
    print(f'test_losses inside test(): {losses}')
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, n_samples,
          100. * correct / n_samples))

In [None]:
def get_intermediate(model, name, input_tensor):
    '''
    access the output of an intermediate layer of a network

    model:        the network itself (an nn.Module instance)
    name:         name of the intermediate layer (str); must be a direct child of model
                  note: all layer names of a model can be seen by typing the model's name
                  in a cell and executing that cell
    input_tensor: batch that is passed unchanged to model(...)

    returns the detached output tensor of the layer, or None (after printing a
    message) when model has no direct child called `name`.
    The model's train/eval mode is not changed here.
    '''
    layer = dict(model.named_children()).get(name)
    if layer is None:
        print(f'layer {name} not found')
        return None

    captured = {}

    def hook(module, inputs, output):
        # keep a graph-free copy of what the layer produced
        captured[name] = output.detach()

    hook_handle = layer.register_forward_hook(hook)
    try:
        # only the hooked activation is needed, so skip building the autograd graph
        with torch.no_grad():
            model(input_tensor)
    finally:
        # always unregister, even if the forward pass raised
        hook_handle.remove()

    dd = captured[name]
    print(f'{name}: {layer}, with shape {list(dd.shape)} and size {dd.numel()}')

    return dd

In [None]:
def embedding_scatterplot(network, layer, data, targ, perplexity=10, 
                          figpath=None, verbose=1, wid=12, hei=12):
    '''
    Scatter-plot a 2D TSNE compression of one layer's activations, coloured by digit.

    network:    model handed to get_intermediate()
    layer:      name (str) of the layer whose output is visualised
    data:       input batch fed to the network
    targ:       tensor of integer labels 0..9, one per sample in data
    perplexity: TSNE perplexity
    figpath:    if given, the figure is also saved to this path
    verbose:    TSNE verbosity
    wid, hei:   figure size

    raises ValueError when get_intermediate() returns None for `layer`.
    '''
    out = get_intermediate(network, layer, data)
    if out is None:
        raise ValueError(f'layer {layer!r} not found in network')
    # one feature vector per sample; a no-op for 2D outputs, required for conv outputs
    vectors = out.flatten(start_dim=1)
    # NOTE(review): scales each vector by its own mean, which blows up when that mean
    # is close to zero; confirm this is the intended normalisation
    norm_vectors = vectors / vectors.mean(dim=1, keepdim=True)
    # compress the space to 2D for viewing
    # NOTE: init='random' works better than init='PCA' (which produces almost no clusters)
    pts = TSNE(n_components=2, perplexity=perplexity, 
               learning_rate='auto', init='random', 
               verbose=verbose).fit_transform(norm_vectors.cpu().numpy())
    
    targs = targ.numpy()
    class_colors = {0:'r', 1:'g', 2:'b', 3:'c', 4:'m', 5:'y', 6:'w', 
                    7:'#BBBBBB', 8:'#FFAAAA', 9:'#AAFFAA'}

    # one proxy marker per class for the legend
    legend_handles = [plt.Line2D([], [], marker='o', color='k',
                                 markerfacecolor=color, markersize=10)
                      for color in class_colors.values()]
    legend_labels  = list(class_colors)

    fig, ax = plt.subplots(figsize=(wid, hei))
    fig.patch.set_facecolor('#333333')
    ax.set_facecolor('black')
    ax.scatter(pts[:,0], pts[:,1], s=10, c=[class_colors[label] for label in targs])
    ax.grid()
    legend = ax.legend(legend_handles, legend_labels, facecolor='black')
    for text in legend.get_texts():
        text.set_color('white')
    ax.set_title(f'TSNE compression for layer {layer}', color='w', fontsize=16)
    if figpath:
        plt.savefig(figpath)
    plt.show()

In [None]:
def show_digits(truth, predictions, data, idx, nr=5, nc=5, scal = 2.5):
    '''
    Show a grid of digit images with their true and predicted labels.

    truth:       label per sample, indexable by sample number
    predictions: network output per sample (one score per class along dim 1)
    data:        image batch; data[j][0] is drawn for sample j
    idx:         sample numbers to display; at most nr*nc are shown
    nr, nc:      grid rows and columns
    scal:        figure size per grid cell

    Titles are cyan for correct predictions and red for errors.
    '''
    wid = int(nc * scal)
    hei = int(nr * scal)
    plt.figure(figsize=(wid, hei))
    # predicted class of every sample, computed once instead of once per image
    pred_labels = predictions.data.max(1)[1]
    # zip stops at the shorter input, so fewer than nr*nc indices is fine
    for i, j in zip(range(nr * nc), idx):
        j    = int(j)
        true = int(truth[j])
        pred = pred_labels[j].item()
        cc = 'c' if true == pred else 'r' # mark errors with red text
        plt.subplot(nr, nc, i+1)
        plt.imshow(data[j][0], cmap='gray', interpolation='none')
        plt.title(f'{j}, truth:{true}, pred:{pred}', c=cc)
        plt.xticks([])
        plt.yticks([])
    # lay out once, after all images and titles exist
    plt.tight_layout()

In [None]:
# WS create a precision matrix showing confusion errors
def precision(loader, model=None, n_classes=10):  # loader is train_loader or test_loader
    '''
    Count (truth, prediction) pairs over a loader and print the accuracy.

    loader:    iterable of (data, target) batches
    model:     network to evaluate; defaults to the global `network`
    n_classes: number of classes, i.e. the side length of the matrix

    returns an (n_classes, n_classes) float32 numpy array; row is truth, col is prediction.
    The model is switched to eval mode and left there.
    '''
    model = network if model is None else model
    dd = torch.zeros(n_classes, n_classes) # row is truth, col is prediction
    model.eval()
    with torch.no_grad():
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            # encode each (truth, pred) pair as one flat index and count a whole batch at once
            flat = target.view(-1) * n_classes + pred.view(-1)
            dd  += torch.bincount(flat, minlength=n_classes * n_classes).reshape(n_classes, n_classes)
    correct = dd.diagonal().sum().item()
    total   = dd.sum().item()
    # an empty loader gives nan rather than a division error
    accuracy = 100. * correct / total if total else float('nan')
    print(f'accuracy = {accuracy:.4f}')
    return dd.numpy()

In [None]:
def get_misclassified(output, target):
    '''
    Return the positions within a batch whose predicted class differs from the label.

    output: network output, one row of class scores per sample
    target: true labels, one per sample

    returns a tensor of indices (as produced by torch.nonzero) of the wrong predictions
    '''
    # max over dim 1 yields (values, indices); the indices are the predicted classes
    _, pred = output.data.max(1, keepdim=True)
    # give the labels the same column shape as the predictions
    truth = target.data.view_as(pred)
    # True where prediction and label disagree
    mismatch = pred.ne(truth).squeeze()
    return torch.nonzero(mismatch)

## DEFINE NETWORK, LOAD PREVIOUS TRAINING AND/OR DO NEW TRAINING

In [None]:
# Build the original network and an SGD optimizer over its parameters,
# using the learning rate and momentum from the settings cell.
network   = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

In [None]:
# Baseline evaluation of the freshly initialised network on the test set.
test() # before training, about 8% accuracy

In [None]:
# see if previous training exists, if so, load, otherwise train
# (both checkpoint files must be present for the saved state to be used)
if os.path.isfile(model_path) and os.path.isfile(optimizer_path):
    # restore the network weights
    network_state_dict = torch.load(model_path)
    network.load_state_dict(network_state_dict)
    # restore the optimizer state
    optimizer_state_dict = torch.load(optimizer_path)
    optimizer.load_state_dict(optimizer_state_dict)
else:
    n_epochs = 6
    # train() writes model and optimizer state to disk every log_interval batches
    train_over_epochs(n_epochs, model_path, optimizer_path, log_interval=100)

In [None]:
# Evaluate again after loading or training.
test()

In [None]:
# optional additional training
# note: train_over_epochs() resets the recorded loss lists, so its plot covers only this run
n_epochs = 1
train_over_epochs(n_epochs, model_path, optimizer_path, log_interval=100)

## EXAMINE INTERMEDIATE NETWORK LAYERS

In [None]:
# see all network layers
# see all network layers (the names shown are the ones get_intermediate() accepts)
network

In [None]:
# Settings for the embedding plot of layer fc1.
upper     = 5000 # number of test samples to plot; 10000 is max of test data
layer     = 'fc1'
acc       = '98%'  # previously measured accuracy with this test data
figpath   = f'mnist_layer_{layer}_acc_{acc}.png'  # output image name encodes layer and accuracy

In [None]:
# TSNE scatterplot of the chosen layer's output for the first `upper` samples of the full test batch.
embedding_scatterplot(network, layer, test_data_full[0:upper], test_targ_full[0:upper], figpath=figpath)

In [None]:
# Same plot settings, now for the output layer fc2; `upper` is reused from the earlier cell.
layer     = 'fc2'
acc       = '98%'
figpath   = f'mnist_layer_{layer}_acc_{acc}.png'

In [None]:
# TSNE scatterplot of the chosen layer's output for the first `upper` samples of the full test batch.
embedding_scatterplot(network, layer, test_data_full[0:upper], test_targ_full[0:upper], figpath=figpath)

## LOOK AT TRUTH AND PREDICTIONS

In [None]:
# Network output for the whole test batch, computed without gradient tracking.
with torch.no_grad():
    predictions = network(test_data_full)
truth = test_targ_full # convenience variable for what follows

In [None]:
# Display the shapes of the network output and the label vectors.
predictions.shape, test_targ_full.shape, truth.shape

In [None]:
# Show a 5 x 6 grid of randomly chosen test digits with truth and prediction.
nr, nc = (5, 6)
num = nr * nc
# random sample numbers into the full test batch (drawn with replacement)
idx = torch.randint(0, test_data_full.shape[0], (num,))
show_digits(truth, predictions, test_data_full, idx, nr=nr, nc=nc)

## GET ALL INCORRECT CLASSIFICATIONS FROM TEST DATA 
## AND DISPLAY SAMPLES

In [None]:
# Show a 5 x 6 grid of randomly chosen misclassified test digits.
nr, nc = (5, 6)
indices = get_misclassified(predictions, truth)  # sample numbers of all wrong predictions
num = nr * nc
# random order over the misclassified entries, truncated to at most `num`
random_indices = torch.randperm(len(indices))[:num]
# Select random elements from the misclassified list
idx = [indices[i].item() for i in random_indices]
show_digits(truth, predictions, test_data_full, idx, nr=nr, nc=nc)

In [None]:
# WS note: Net2() with 2-neuron bottleneck performance is down to 93% accuracy after 6 epochs,
# compared to 98% accuracy with original network

In [None]:
# example_data / example_targets are not assigned anywhere else in this notebook,
# so draw one batch from the test loader here before using them.
example_data, example_targets = next(iter(test_loader))
with torch.no_grad():
    output = network(example_data)

In [None]:
# Number of samples in the example batch output.
len(output)

In [None]:
# Print predicted class next to the true label for the first five examples.
# Relies on `output` and `example_targets` already existing in the kernel namespace.
for k in range(5):
    print(output[k].argmax().item(), example_targets[k].item())

In [None]:
# Truth-vs-prediction count matrix over the test set (also prints the accuracy).
test_matrix = precision(test_loader)

In [None]:
# Truth-vs-prediction count matrix over the training set (also prints the accuracy).
train_matrix = precision(train_loader)

In [None]:
# Log scale (with +1 so zero counts stay finite) keeps the off-diagonal errors visible.
plt.imshow(np.log(test_matrix + 1))
plt.show()

In [None]:
# Log scale (with +1 so zero counts stay finite) keeps the off-diagonal errors visible.
plt.imshow(np.log(train_matrix + 1))
plt.show()

In [None]:
#for k0, k1 in test_loader:
    #print(k0.shape, k1.shape)

## MODIFY THE NETWORK TO INCLUDE A 'BOTTLENECK' 2D EMBEDDING LAYER
## THAT CAN BE PLOTTED DIRECTLY WITHOUT TSNE

In [None]:
# Replace the global network with the 2-unit bottleneck model and give it a fresh SGD optimizer.
# From here on, train(), test() and precision() operate on this new network.
network   = Net2()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

In [None]:
# WS note: Net2() with 2-neuron bottleneck performance is down to 93% accuracy after 6 epochs,
# compared to 98% accuracy with original network

In [None]:
# see all layers of the bottleneck network
network

In [None]:
# Shape of a single example image; relies on `example_data` already existing in the kernel namespace.
example_data[0].shape

In [None]:
# Forward the first three examples through the current network and display the log-probabilities.
# NOTE(review): no training of Net2 is visible before this cell, and the call is made without
# torch.no_grad() or network.eval().
network(example_data[0:3])

In [None]:
# True labels of the same three examples; relies on `example_targets` already existing in the kernel namespace.
example_targets[0:3]

In [None]:
# display the layers of the current network once more
network