In [None]:
# Mount Google Drive so the project files under MyDrive are reachable from this notebook
# https://towardsdatascience.com/google-drive-google-colab-github-dont-just-read-do-it-5554d5824228
from google.colab import drive
ROOT = "/content/drive"  # mount point for Google Drive inside the Colab filesystem
drive.mount(ROOT)
# %pwd %ls
# run github settings
# NOTE(review): Colab_Helper.ipynb presumably defines GIT_PATH, which the push cell uses — confirm.
%run /content/drive/MyDrive/CNNStanford/pytorch/pytorch_sandbox/Colab_Helper.ipynb

In [None]:
# Commit message consumed by the `git commit -m "{MESSAGE}"` cell that follows.
MESSAGE = "clean file & gitignore again"
# Configure the git author identity used for commits made from this notebook.
!git config --global user.email "ronyginosar@mail.huji.ac.il"
!git config --global user.name "ronyginosar"
# Stage all changes under the current working directory.
!git add .

In [None]:
# Commit the staged changes using MESSAGE (set in the git config cell), then push.
# NOTE(review): GIT_PATH is not defined in any visible cell; presumably it comes from
# Colab_Helper.ipynb (loaded via %run in the first cell) — confirm before running.
!git commit -m "{MESSAGE}"
!git push "{GIT_PATH}"

In [None]:
import torch
import sys
import os
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from exercises.part5_train.ex1_train_cifar import tensor_show
from exercises.part3_nn_modules.ex1 import Ex1Net
import torchvision
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import torch.nn.functional as F


# Mini-batch size used by the CIFAR-10 test DataLoader.
batch_size = 4
# CIFAR-10 class names; position i is the name of integer label i.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Feature flags read by main():
#   showExample - print ground truth vs. prediction for one test batch.
#   pr_curve    - log per-class precision-recall curves to TensorBoard.
showExample = False
pr_curve = False  # TODO combine into general code


def add_pr_curve_tensorboard(writer, class_index, test_probs, test_label, global_step=0):
    """Log a one-vs-rest precision-recall curve for one class to TensorBoard.

    Adapted from
    https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html#assessing-trained-models-with-tensorboard

    Args:
        writer: SummaryWriter that receives the curve; it is closed before returning.
        class_index: index into the module-level ``classes`` tuple (0 to 9); the
            class name is used as the TensorBoard tag.
        test_probs: per-sample class probabilities; column ``class_index`` is used
            as the score (assumes shape (num_samples, num_classes) — TODO confirm).
        test_label: ground-truth label per sample; samples equal to
            ``class_index`` are the positives.
        global_step: TensorBoard step to record the curve under.
    """
    is_positive = test_label == class_index
    scores = test_probs[:, class_index]
    tag = classes[class_index]

    writer.add_pr_curve(tag, is_positive, scores, global_step=global_step)
    # NOTE(review): closing here means repeated calls each reopen the writer.
    writer.close()


def get_num_correct(preds, labels):
    """Count the samples whose highest-scoring class matches the true label.

    Based on
    https://towardsdatascience.com/a-complete-guide-to-using-tensorboard-with-pytorch-53cb2301e8c3
    (equivalent to ``preds.argmax(dim=1).eq(labels).sum().item()``).

    Args:
        preds: class scores indexed as (sample, class); ``argmax`` reduces over
            dim 1, the class axis (assumes a 2-D tensor — TODO confirm).
        labels: ground-truth class index for each sample.

    Returns:
        The number of correct predictions as a plain Python number (via
        ``Tensor.item()``), suitable for accumulating in a list.
    """
    # The predicted class is the position of the largest score along dim 1.
    predicted_labels = preds.argmax(dim=1)
    # Element-wise match against the ground truth; summing counts the True entries.
    return predicted_labels.eq(labels).sum().item()


def main(argv):
    """Evaluate the checkpointed Ex1Net on the CIFAR-10 test set.

    Loads ``checkpoints/latest.pt`` (relative to the current working directory),
    prints per-class and overall accuracy, and — when the module-level
    ``pr_curve`` flag is set — logs one precision-recall curve per class to a
    TensorBoard run under ``evals/<timestamp>``. When ``showExample`` is set,
    the first test batch is also printed as ground truth vs. prediction.

    Args:
        argv: command-line arguments (currently unused).
    """
    # TODO add Torchmetrics
    # skeleton from part 6 in ppt
    # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#test-the-network-on-the-test-data
    CHECKPOINT_PATH = os.path.join(os.getcwd(), "checkpoints", "latest.pt")
    current_run = f"{datetime.now():%Y.%m.%d_%H.%M}"
    writer = SummaryWriter(os.path.join(os.getcwd(), "evals", current_run))

    model = Ex1Net(in_channels=[3, 32, 64, 128],
                   out_channels=[32, 64, 128, 256],
                   pools=['max', 'max', 'max', 'avg'],
                   num_classes=10)
    # The model is never moved off the CPU, so map the checkpoint tensors there;
    # otherwise a checkpoint saved from CUDA fails to load on a CPU-only machine.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    # Must call model.eval() to put dropout and batch-norm layers into
    # evaluation mode before running inference.
    model.eval()

    normalize_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    root_path = os.path.join(os.getcwd(), "data")  # if needed: "exercises", "part5_train"
    testset = CIFAR10(root=root_path, train=False, download=True, transform=normalize_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    if showExample:
        # See the output on one example batch from the test set.
        print("On Example data from test set:")
        images, labels = next(iter(testloader))
        # tensor_show(torchvision.utils.make_grid(images), is_normalized=True, one_channel=False, is_show=True)
        # Use the real batch length rather than assuming 4 samples.
        num_shown = len(labels)
        print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(num_shown)))
        with torch.no_grad():
            outputs = model(images)
        _, predicted = torch.max(outputs, dim=1)
        print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                                      for j in range(num_shown)))

    print("Prediction on entire test set:")
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    # Per-batch softmax probabilities and labels; only filled when pr_curve is set.
    class_probs = []
    class_label = []
    # Inference only, so gradients are not needed.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            # The class with the highest score is the prediction.
            _, predicted = torch.max(outputs, 1)
            if pr_curve:
                # Collect once per batch (not once per sample) so each test image
                # contributes exactly one row to the precision-recall curves.
                class_probs.append(F.softmax(outputs, dim=1))
                class_label.append(labels)

            # Tally correct and total predictions per ground-truth class.
            for label, prediction in zip(labels, predicted):
                classname = classes[label]
                if label == prediction:
                    correct_pred[classname] += 1
                total_pred[classname] += 1

    # Print the accuracy for each class, guarding against classes with no samples.
    for classname, correct_count in correct_pred.items():
        class_total = total_pred[classname]
        if class_total == 0:
            print(f'Accuracy for class: {classname:5s} is n/a (no samples)')
            continue
        accuracy = 100 * float(correct_count) / class_total
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

    overall_total = sum(total_pred.values())
    if overall_total:
        print(f'Accuracy of the network on the overall test images: '
              f'{100 * sum(correct_pred.values()) // overall_total} %')
    else:
        print('Accuracy of the network on the overall test images: n/a (no samples)')

    # Log a precision-recall curve for each class.
    if pr_curve and class_probs:
        test_probs = torch.cat(class_probs)
        test_label = torch.cat(class_label)
        for classidx in range(len(classes)):
            add_pr_curve_tensorboard(writer, classidx, test_probs, test_label)

    # Close the writer on every path; closing an already-closed SummaryWriter is harmless.
    writer.close()

if __name__ == "__main__":
    # Run the evaluation, forwarding command-line arguments (excluding the program name).
    main(sys.argv[1:])
