In [None]:
# Mount Google Drive so the project files on Drive are reachable from this Colab VM.
# https://towardsdatascience.com/google-drive-google-colab-github-dont-just-read-do-it-5554d5824228
from google.colab import drive
ROOT = "/content/drive"  # mount point for Drive inside the Colab filesystem
drive.mount(ROOT)
# %pwd %ls
# Run the helper notebook with the GitHub settings used by the git cells below.
# NOTE(review): later cells use GIT_PATH, which is not defined in this notebook;
# presumably Colab_Helper.ipynb defines it (and possibly changes directory) — confirm.
%run /content/drive/MyDrive/CNNStanford/pytorch/pytorch_sandbox/Colab_Helper.ipynb

In [None]:
# Commit message consumed by the `git commit` cell that follows.
MESSAGE = "clean file & gitignore again"
# Set the git identity globally on this VM, then stage every change under the current working directory.
!git config --global user.email "ronyginosar@mail.huji.ac.il"
!git config --global user.name "ronyginosar"
!git add .

In [None]:
# Commit the staged changes with MESSAGE and push them.
# NOTE(review): GIT_PATH is not defined in this notebook — presumably set by the %run-ed Colab_Helper.ipynb; confirm.
!git commit -m "{MESSAGE}"
!git push "{GIT_PATH}"

In [None]:
import torch
import sys
import torchvision
from torchvision.datasets import CIFAR10
from exercises.part3_nn_modules.ex1 import Ex1Net
import torchvision.transforms as transforms
import os
# from torchinfo import summary
# from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.nn.functional as f
from datetime import datetime
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau


# ---- training hyper-parameters ----
num_epochs = 100         # training runs (or resumes) up to this many epochs
base_lr = 1e-4           # initial learning rate given to the optimizer
batch_size = 4           # samples per mini-batch (train and validation loaders)
mini_batch_print = 1000  # report training statistics every this many mini-batches
validation_split = 0.10  # fraction of the training set held out for validation

# CIFAR-10 class names, indexed by the integer label
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


def tensor_show(image_tensor, title=None, is_normalized=True, one_channel=False, is_show=False):
    """
    Draw a single image tensor with matplotlib on the current axes.

    :param image_tensor: channels-first image (presumably (C, H, W), as the
        permute(1, 2, 0) below implies); may live on any device.
    :param title: optional title for the current axes.
    :param is_normalized: undo a Normalize(mean=0.5, std=0.5) transform
        (x / 2 + 0.5) before drawing.
    :param one_channel: average over the channel dimension and draw the result
        with the "Greys" colormap.
    :param is_show: call plt.show() after drawing.
    """
    # matplotlib needs host memory without autograd history; previously only the
    # one-channel path called .cpu(), so colour images on the GPU failed to draw.
    image_tensor = image_tensor.detach().cpu()
    if one_channel:
        image_tensor = image_tensor.mean(dim=0)
    if is_normalized:
        image_tensor = image_tensor / 2 + 0.5  # un-normalize
    if one_channel:
        plt.imshow(image_tensor, cmap="Greys")
    else:
        # matplotlib expects channels-last (H, W, C)
        plt.imshow(image_tensor.permute(1, 2, 0))
    if title:
        plt.title(title)
    if is_show:
        plt.show()


# tensorboard helper functions

def images_to_probs(data_output):
    """
    Turn a batch of network outputs into predicted classes and their probabilities.

    Adapted from https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html,
    but takes the network's output directly instead of running the network on images.

    :param data_output: 2-D tensor of raw class scores, one row per sample
        (the class dimension is dim=1).
    :return: (preds, probs) — preds is a 1-D tensor of predicted class indices,
        probs is a list of Python floats holding the softmax probability of each
        predicted class.
    """
    # detach so the softmax below does not extend the autograd graph during training
    scores = data_output.detach()
    # torch.max(..., dim=1) already yields a 1-D tensor; the tutorial's extra
    # squeeze() collapsed a batch of one into a 0-d tensor that zip() cannot iterate.
    _, preds = torch.max(scores, dim=1)
    # softmax over the class dimension, then pick each row's predicted-class entry
    probs = f.softmax(scores, dim=1).gather(1, preds.unsqueeze(1)).squeeze(1)
    return preds, probs.tolist()


def plot_classes_preds(data_output, images, labels, max_images=4):
    """
    Build a matplotlib Figure showing the network's top prediction for the first
    images of a batch, with its probability and the true label. Titles are
    green when the prediction is correct and red otherwise.
    Uses "images_to_probs" and "tensor_show".
    Adapted from https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html

    :param data_output: network output (class scores) for the batch.
    :param images: the batch's input images.
    :param labels: the batch's ground-truth class indices.
    :param max_images: maximum number of images to draw (default 4).
    :return: the created matplotlib Figure.
    """
    preds, probs = images_to_probs(data_output)
    # never index past the batch: a loader's last batch can be smaller than max_images
    num_images = min(max_images, len(images))
    fig = plt.figure(figsize=(16, 8))
    for idx in range(num_images):
        pred_class = int(preds[idx])
        true_class = int(labels[idx])
        ax = fig.add_subplot(1, num_images, idx + 1, xticks=[], yticks=[])
        tensor_show(images[idx], is_normalized=True, one_channel=True)
        ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
            classes[pred_class],
            probs[idx] * 100.0,
            classes[true_class]),
            color=("green" if pred_class == true_class else "red"))
    return fig


def print_batch_statistics(epoch, running_loss, writer, loader_size, data_output, data, labels, i=0, mini_batch=False, training=True):
    """
    Print the averaged loss and log loss, a prediction figure and an image grid to TensorBoard.

    :param epoch: zero-based epoch index (printed as epoch + 1).
    :param running_loss: loss accumulated by the caller; divided here by
        ``mini_batch_print`` when ``mini_batch`` is True, otherwise by ``loader_size``.
    :param writer: TensorBoard SummaryWriter that receives the scalars, figure and image.
    :param loader_size: number of batches in the loader being iterated; also used
        for the global step ``epoch * loader_size + i``.
    :param data_output: network output for the last batch (used for the figure / accuracy).
    :param data: input images of the last batch.
    :param labels: ground-truth labels of the last batch.
    :param i: index of the current mini-batch within the epoch.
    :param mini_batch: True when reporting a window of ``mini_batch_print`` batches.
    :param training: selects the training or the validation tags and messages.
    """
    # TODO more metrics
    # for a validation epoch the whole loader is the averaging window
    div_factor = mini_batch_print if mini_batch else loader_size
    avg_loss = running_loss / div_factor
    global_step = epoch * loader_size + i

    if training:
        print(f'[{epoch + 1}, {i + 1:5d}] '
              f'loss: {avg_loss:.3f}')
        writer.add_scalar('training loss', avg_loss, global_step)
        # log a plt Figure showing the model's predictions on the current mini-batch
        writer.add_figure('predictions vs. actuals',
                          plot_classes_preds(data_output, data, labels),
                          global_step=global_step)
        # log at the same step as the figure instead of overwriting step 0 every time
        writer.add_image('images', torchvision.utils.make_grid(data), global_step)
    else:
        print(f'validation epoch [{epoch + 1}] '
              f'validation loss: {avg_loss:.3f}')
        # NOTE(review): loader_size is the validation loader's length here, so this
        # x-axis is not aligned with the training-loss axis.
        writer.add_scalar('validation loss', avg_loss, global_step)
        # log a plt Figure showing the model's predictions on a validation batch
        writer.add_figure('validation predictions vs. actuals',
                          plot_classes_preds(data_output, data, labels),
                          global_step=global_step)
        writer.add_image('validation images', torchvision.utils.make_grid(data), global_step)
        # only the last batch's outputs are available here, so say so explicitly
        batch_acc = accuracy(data_output, labels).item()
        print(f"validation accuracy (last batch): {batch_acc:.3f}")


def accuracy(data_output, labels):
    """
    Fraction of samples whose highest-scoring class equals the label.

    :param data_output: class scores, one row per sample (argmax taken over dim=1).
    :param labels: ground-truth class indices, one per row of ``data_output``.
    :return: 0-d tensor holding the accuracy in [0, 1]; an empty batch raises
        ZeroDivisionError.
    """
    predicted = torch.max(data_output, dim=1)[1]
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))

def train_epoch(net, train_loader, device, optimizer, calc_loss, epoch, writer):
    """
    Run one training epoch and return the loss averaged over all samples seen.

    Every ``mini_batch_print`` batches, the mean batch loss of that window is
    printed and logged through ``print_batch_statistics``.

    :param net: model to train (moved to ``device`` by the caller).
    :param train_loader: iterable of (data, labels) batches.
    :param device: device the batches are moved to.
    :param optimizer: optimizer stepping ``net``'s parameters.
    :param calc_loss: loss function called as ``calc_loss(output, labels)``.
    :param epoch: zero-based epoch index, used for logging.
    :param writer: TensorBoard SummaryWriter used for logging.
    :return: per-sample mean training loss of the epoch (0.0 if the loader is empty).
    """
    net.train()  # set net to train mode
    window_loss = 0.0     # sum of batch-mean losses since the last print
    epoch_loss_sum = 0.0  # sum of per-sample losses over the whole epoch
    num_samples = 0
    for i, (data, labels) in enumerate(train_loader):
        data = data.to(device)  # move to GPU if available
        labels = labels.to(device)  # move to GPU if available
        optimizer.zero_grad()  # zero gradients so they don't accumulate
        output = net(data)
        loss = calc_loss(output, labels)
        loss.backward()
        optimizer.step()

        # Assuming calc_loss averages over the batch (reduction='mean', the
        # CrossEntropyLoss default used in main), loss.item() is a per-sample mean.
        # Previously the printed loss summed loss * batch_size and divided by the
        # number of batches, inflating it by batch_size, and the returned loss
        # only covered the batches after the last print reset.
        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss_sum += batch_loss * data.size(0)
        num_samples += data.size(0)

        # print every mini_batch_print mini-batches; print_batch_statistics divides
        # window_loss by mini_batch_print, i.e. reports the window's mean batch loss
        if i % mini_batch_print == (mini_batch_print - 1):
            print_batch_statistics(epoch, window_loss, writer, len(train_loader), output, data, labels, i, mini_batch=True, training=True)
            window_loss = 0.0

    return epoch_loss_sum / num_samples if num_samples else 0.0


def validation_epoch(net, validation_loader, device, calc_loss, epoch, writer):
    """
    Evaluate ``net`` on the whole validation loader without gradients.

    :param net: model to evaluate (moved to ``device`` by the caller).
    :param validation_loader: iterable of (data, labels) batches.
    :param device: device the batches are moved to.
    :param calc_loss: loss function called as ``calc_loss(output, labels)``.
    :param epoch: zero-based epoch index, used for logging.
    :param writer: TensorBoard SummaryWriter used for logging.
    :return: validation loss averaged over all samples.
    :raises ValueError: if the loader yields no samples.
    """
    net.eval()  # set net to eval mode
    total_loss = 0.0  # sum of per-sample losses
    num_samples = 0
    output = data = labels = None
    with torch.no_grad():  # no gradients needed, just validation
        for data, labels in validation_loader:
            data = data.to(device)  # move to GPU if available
            labels = labels.to(device)  # move to GPU if available
            output = net(data)
            loss = calc_loss(output, labels)
            # assuming reduction='mean': loss.item() * batch size is the batch's summed loss
            total_loss += loss.item() * data.size(0)
            num_samples += data.size(0)

    if num_samples == 0:
        raise ValueError("validation_loader yielded no samples")

    # Previously the summed per-sample loss was divided by the number of batches,
    # inflating the result by the batch size.
    epoch_loss = total_loss / num_samples
    # print_batch_statistics divides by loader_size, so hand it the total that yields epoch_loss
    print_batch_statistics(epoch, epoch_loss * len(validation_loader), writer, len(validation_loader),
                           output, data, labels, training=False)
    return epoch_loss


def load_checkpoint(path, net, optimizer, scheduler, map_location="cpu"):
    """
    Restore model, optimizer and LR-scheduler state from a checkpoint file.

    :param path: file written by torch.save containing the keys
        'epoch', 'model', 'optimizer' and 'scheduler'.
    :param net: module whose state is overwritten in place.
    :param optimizer: optimizer whose state is overwritten in place.
    :param scheduler: LR scheduler whose state is overwritten in place.
    :param map_location: where tensors are deserialized before being loaded into
        ``net``/``optimizer``. Defaults to "cpu" so a checkpoint saved on a GPU
        machine also loads on a CPU-only one; ``load_state_dict`` then copies the
        values into the objects' existing parameters/state.
    :return: the epoch index stored in the checkpoint.
    """
    cp = torch.load(path, map_location=map_location)
    net.load_state_dict(cp['model'])
    optimizer.load_state_dict(cp['optimizer'])
    scheduler.load_state_dict(cp['scheduler'])
    return cp['epoch']


def _save_checkpoint(path, epoch, net, optimizer, scheduler, loss):
    """
    Write a resumable checkpoint with the keys read by ``load_checkpoint``.

    The state is written to ``path + ".tmp"`` first and then moved over ``path``
    with os.replace (atomic on POSIX filesystems). An interrupted save therefore
    does not leave a truncated checkpoint in place of the previous one.
    """
    save_dict = {'epoch': epoch,
                 'model': net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(),
                 'loss': loss}
    tmp_path = path + ".tmp"
    torch.save(save_dict, tmp_path)
    os.replace(tmp_path, path)


def main(argv):
    """
    Train Ex1Net on CIFAR-10 with a held-out validation split, logging to TensorBoard.

    Resumes from ``checkpoints/latest.pt`` (under the current working directory)
    when it exists, and overwrites it at the end of every epoch.

    :param argv: command-line arguments; currently unused.
    """
    # tensorboard: one log directory per run, named by start time
    current_run = f"{datetime.now():%Y.%m.%d_%H.%M}"
    writer = SummaryWriter(os.path.join(os.getcwd(), "runs", current_run))

    # Checkpoint location. torch.save does not create missing directories, so
    # make sure "checkpoints/" exists before the first epoch finishes.
    checkpoint_path = os.path.join(os.getcwd(), "checkpoints", "latest.pt")
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    resume = True
    start_epoch = 0
    # todo add load from init net

    # PIL image -> tensor in [0, 1] -> Normalize(0.5, 0.5) maps each channel to [-1, 1]
    normalize_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # CIFAR-10 training split; items are (image, class_index) pairs.
    # download is not requested, so the data must already be under ./data.
    root_path = os.path.join(os.getcwd(), "data")
    train_data = CIFAR10(root=root_path, transform=normalize_transform, train=True)

    # Validation-training split. A dedicated fixed-seed generator yields the same
    # split as seeding the global RNG with 43, without touching the global
    # (CPU and CUDA) RNG state.
    val_size = int(len(train_data) * validation_split)
    train_size = len(train_data) - val_size  # lengths must sum exactly to the dataset size
    split_generator = torch.Generator().manual_seed(43)
    train_ds, val_ds = random_split(train_data, [train_size, val_size], generator=split_generator)

    # pin_memory speeds up host-to-GPU copies when a GPU is used
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                                               num_workers=4, pin_memory=True)
    validation_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size,
                                                    num_workers=4, pin_memory=True)

    # Define the network, ex3.1
    net = Ex1Net(in_channels=[3, 32, 64, 128],
                 out_channels=[32, 64, 128, 256],
                 pools=['max', 'max', 'max', 'avg'],
                 num_classes=10)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # compare the device *type*: a torch.device never compares equal to a plain string,
    # so the previous `device == "cuda:0"` check always reported the CPU
    if device.type == "cuda":
        print("GPU Check: Using GPU CUDA")
    else:
        print("GPU Check: Using CPU")
    net = net.to(device)  # move entire net to GPU if available (instead of forcing it by .cuda())
    # TODO build net on GPU instead of moving

    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
    # halve the learning rate when the validation loss stops improving for 20 epochs
    lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=20)
    calc_loss = torch.nn.CrossEntropyLoss()

    # todo save initial net upon flag and load it
    if resume and os.path.exists(checkpoint_path):
        print("loading last epoch")
        last_epoch = load_checkpoint(checkpoint_path, net, optimizer, lr_scheduler)
        start_epoch = last_epoch + 1

    try:
        for epoch in range(start_epoch, num_epochs):
            print("starting training")
            training_loss = train_epoch(net, train_loader, device, optimizer, calc_loss, epoch, writer)
            validation_loss = validation_epoch(net, validation_loader, device, calc_loss, epoch, writer)
            # adjust schedule according to validation loss
            lr_scheduler.step(validation_loss)
            _save_checkpoint(checkpoint_path, epoch, net, optimizer, lr_scheduler, training_loss)
            print(f'Finished epoch {epoch}, saved state')
            # todo Save & update the best model (i.e. when validation loss reaches a new minimum)
    finally:
        # flush TensorBoard events even if training is interrupted
        writer.close()
    print('Finished Training')


if __name__ == "__main__":
    # Forward everything after the script name to main(), which currently ignores it.
    cli_args = sys.argv[1:]
    main(cli_args)

# %%
# # get some random training images
# dataiter = iter(train_loader)
# images, labels = next(dataiter)
# # show images
# tensor_show(torchvision.utils.make_grid(images), isNormalized = True)
