# Notebook Title

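Template for training an image classifier with PyTorch Lightning, logging to TensorBoard, and inspecting validation metrics and example predictions. Fill in the `...` placeholders in `Net` (loss, `forward`, datasets) before running.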

## Install PyTorch Lightning (if necessary)

In [None]:
# !pip install pytorch-lightning

## Install PyTorch/XLA (only needed for TPU training)

In [None]:
# VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION

## Imports

In [1]:
import torch
from torch import nn
from torch.optim import lr_scheduler

import pytorch_lightning as pl

import torchvision
import torchvision.models as models
from torchvision import transforms

from torch.utils.data import DataLoader
from torch.nn import functional as F
from pytorch_lightning.metrics import Accuracy, Recall, Precision, ROC, AUC

from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np

## Configs

In [3]:
TENSORBOARD_DIRECTORY = "logs/"
SEED = 42

# Seed the torch and NumPy RNGs so repeated runs start from the same random state.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Create PyTorch Lightning Model

In [None]:
class Net(pl.LightningModule):
    """Template LightningModule for a classifier trained with Adam + StepLR.

    Fill in ``self.criterion``, ``forward`` and the datasets in ``prepare_data``
    (the ``...`` placeholders) before training.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        num_workers: Worker processes used by each DataLoader.
    """

    def __init__(self, batch_size=64, num_workers=4):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # TODO: replace the placeholder with the training loss.
        self.criterion = ...
        # nn.ModuleDict (instead of a plain dict) registers the metrics as submodules,
        # so they follow the model when Lightning moves it between devices.
        self.metrics = nn.ModuleDict({"accuracy": Accuracy(), "recall": Recall()})

    def forward(self, x):
        """Return the model output (logits) for a batch ``x`` -- to be implemented."""
        raise NotImplementedError()

    def prepare_data(self):
        """Create ``self.trainset`` and ``self.valset`` (placeholders to fill in).

        NOTE(review): Lightning runs ``prepare_data`` on a single process, so state
        assigned here is not visible to other processes in distributed training;
        move these assignments to ``setup`` if you train on more than one device.
        """
        # transform = transforms.Compose(
        #     [transforms.Resize(224),
        #      transforms.ToTensor(),
        #      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        # )

        self.trainset = ...
        self.valset = ...

    def train_dataloader(self):
        """Training batches, reshuffled every epoch."""
        # Without shuffle=True every epoch would present the samples in the same order.
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Validation batches in a fixed order."""
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def configure_optimizers(self):
        """Adam (lr=1e-3); StepLR multiplies the lr by 0.1 every 7 scheduler steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the loss for one ``(x, y)`` batch and log it as ``train_loss``."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        """Return ``val_loss`` plus ``val_<metric>`` for every metric on one batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        outputs = {"val_loss": loss}
        for name, metric in self.metrics.items():
            outputs[f"val_{name}"] = metric(logits, y)
        return outputs

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results and log them.

        Logged tags keep the same ``val_`` prefix as the per-step keys
        (``val_loss``, ``val_accuracy``, ...). ``val_loss`` is also returned at the
        top level so callbacks that monitor ``val_loss`` find it directly;
        ``avg_val_loss`` is kept for backward compatibility.
        """
        avg_loss = torch.stack([out["val_loss"] for out in outputs]).mean()

        tensorboard_logs = {
            f"val_{name}": torch.stack([out[f"val_{name}"] for out in outputs]).mean()
            for name in self.metrics
        }
        tensorboard_logs["val_loss"] = avg_loss

        return {'val_loss': avg_loss, 'avg_val_loss': avg_loss, 'log': tensorboard_logs}

## Lightning Trainer

In [None]:
# TensorBoard logger for the Lightning Trainer; runs are written below
# TENSORBOARD_DIRECTORY in a sub-directory named after `name`.
logger = TensorBoardLogger(save_dir=TENSORBOARD_DIRECTORY, name="logger_name")

In [None]:
BATCH_SIZE = 1024  # samples per batch for both the training and validation loaders

net = Net(batch_size=BATCH_SIZE)

In [None]:
MAX_EPOCHS = 30

# Train on one GPU when CUDA is available and on the CPU otherwise, matching the
# `device` fallback in the config cell (requesting a GPU fails on CPU-only machines).
# early_stop_callback=True enables Lightning's default early stopping on `val_loss`.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                     logger=logger,
                     gpus=1 if torch.cuda.is_available() else 0,
                     early_stop_callback=True)

trainer.fit(net)

## Add visualizations to Tensorboard

In [None]:
writer = SummaryWriter(TENSORBOARD_DIRECTORY)

# Depending on the Lightning version the model may be left on CPU or GPU after
# `fit`; put it on `device` so it matches the sample batch used below.
net.to(device)

valloader = net.val_dataloader()
inputs, labels = next(iter(valloader))
inputs, labels = inputs.to(device), labels.to(device)


In [4]:
# Optionally also log a grid of sample images:
# grid = torchvision.utils.make_grid(inputs[:25])
# writer.add_image('images', grid, 0)

# Trace the model on the sample batch and log its graph. Close the writer even if
# tracing fails, so buffered events are flushed and the event file is released.
try:
    writer.add_graph(net, inputs)
finally:
    writer.close()

NameError: name 'writer' is not defined

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# %reload_ext tensorboard
%tensorboard --logdir logs

## Show model metrics

In [None]:
net.eval()

# Run the model once per validation batch (not once per metric), without autograd,
# with each batch moved to whatever device the model's weights currently live on.
model_device = next(net.parameters()).device
batch_scores = {name: [] for name in net.metrics}
with torch.no_grad():
    for x, y in valloader:
        x, y = x.to(model_device), y.to(model_device)
        logits = net(x)
        for name, metric in net.metrics.items():
            batch_scores[name].append(metric(logits, y))

# Unweighted mean over batches -- the same aggregation validation_epoch_end uses.
final_metrics = {name: torch.stack(scores).mean() for name, scores in batch_scores.items()}

final_metrics

## Show Example Predictions

In [None]:
net.eval()

# Inference only: skip autograd bookkeeping for the sample batch.
with torch.no_grad():
    outputs = net(inputs)

# Predicted class index = position of the largest output along the class dimension.
predictions = outputs.argmax(dim=1)

In [None]:
# Display name for each class index (class_names[i] is the name of label i).
# NOTE(review): these are the Fashion-MNIST classes -- confirm they match the dataset in Net.
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

In [None]:
GRID_SIZE = 5  # show up to GRID_SIZE x GRID_SIZE validation images
n_shown = min(GRID_SIZE * GRID_SIZE, len(inputs))

fig, axes = plt.subplots(GRID_SIZE, GRID_SIZE, figsize=(4 * GRID_SIZE, 4 * GRID_SIZE))
for i, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    if i >= n_shown:
        # Batch smaller than the grid: leave the remaining panels empty.
        ax.axis("off")
        continue

    # Undo Normalize(mean=0.5, std=0.5) -- assumes that transform was applied to the data.
    img = inputs[i] / 2 + 0.5
    # (C, H, W) tensor -> (H, W, C) array for matplotlib, clipped to the valid [0, 1] range.
    npimg = np.clip(np.transpose(img.cpu().numpy(), (1, 2, 0)), 0, 1)
    if npimg.shape[-1] == 1:
        # imshow expects (H, W) for single-channel data; render it in grayscale.
        npimg = npimg[..., 0]
    ax.imshow(npimg, cmap="gray", vmin=0, vmax=1)
    ax.set_xlabel(f"Real: {class_names[labels[i].item()]}, "
                  f"Pred: {class_names[predictions[i].item()]}")
plt.show()