# WeightWatcher PyTorch Demo

In [None]:
import numpy as np
import torch
import torchvision.models
import torch.nn as nn
from torchvision import datasets, transforms

import ww

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. a local checkout of `ww`) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

## VGG Architecture

In [None]:
# VGG16 with BatchNorm.
# Called without arguments -- presumably this yields randomly initialized
# weights (torchvision default; confirm for the installed version). Only the
# architecture matters for the graph drawn in the next cell.
model = torchvision.models.vgg16_bn()

In [None]:
# Draw graph
# The model is handed to ww together with a dummy all-zeros batch of one
# 3-channel 224x224 image (presumably used to trace the forward pass --
# confirm against the ww implementation).
dg = ww.build_pytorch_graph(model, torch.zeros([1, 3, 224, 224]))
dot = ww.draw_graph(dg)
# Leaving `dot` as the cell's last expression makes the notebook display it.
dot

## ResNet Architecture

In [None]:
# ResNet50.
# Rebinds `model`, replacing the VGG instance created earlier. Called without
# arguments -- presumably randomly initialized weights (torchvision default;
# confirm for the installed version).
model = torchvision.models.resnet50()

In [None]:
# Draw graph
# Same dummy input as for VGG: one all-zeros 3-channel 224x224 image.
dg = ww.build_pytorch_graph(model, torch.zeros([1, 3, 224, 224]))
# The value returned by draw_graph is the cell's last expression, so the
# notebook displays it directly; note that `dot` is NOT rebound here.
ww.draw_graph(dg)

## Visualize Training Progress

In [None]:
# CIFAR10 Dataset
# Images are only converted to tensors (ToTensor); no augmentation or
# normalization is applied. The data is downloaded into ./datasets on first use.
t = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10('datasets', train=True, download=True, transform=t)
test_dataset = datasets.CIFAR10('datasets', train=False, download=True, transform=t)

# Mini-batches of 50 samples, reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=True)

# Summarize the raw arrays behind both splits.
# torchvision's CIFAR10 exposes them as `.data` / `.targets` for the train and
# the test split alike; the older split-specific attributes (`train_data`,
# `train_labels`, `test_data`, `test_labels`) are no longer available on this
# dataset class and raise AttributeError. Requires torchvision >= 0.2.2.
ww.show("train_dataset.data", train_dataset.data)
ww.show("train_dataset.labels", train_dataset.targets)
ww.show("test_dataset.data", test_dataset.data)
ww.show("test_dataset.labels", test_dataset.targets)

In [None]:
# TODO: Moving model to GPU breaks graph drawing. Investigate
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Until that is resolved, the model and every batch are kept on the CPU.
device = "cpu"

# Simple Convolutional Network
class CifarModel(nn.Module):
    """Small VGG-style convolutional classifier for 3-channel images.

    ``features`` stacks three stages of two Conv(3x3, padding=1) -> BatchNorm
    -> ReLU units each (channels 3->16->16, 16->32->32, 32->32->32). The first
    two stages are followed by a 2x2 max pool; the last one by an adaptive max
    pool that reduces every feature map to a single value, so the spatial size
    of the input is not fixed by the architecture.

    ``classifier`` maps the resulting 32 features to 10 output scores through
    Linear(32, 32) -> ReLU -> Linear(32, 10). No softmax is applied.
    """

    @staticmethod
    def _conv_unit(in_channels, out_channels):
        """Return the [Conv2d, BatchNorm2d, ReLU] layers of one conv unit."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self):
        super(CifarModel, self).__init__()

        # Layers are created strictly in network order so that the indices
        # inside the Sequential (and thus the state_dict keys) stay stable.
        layers = []

        # Stage 1: 3 -> 16 -> 16 channels, then halve the resolution.
        layers += self._conv_unit(3, 16)
        layers += self._conv_unit(16, 16)
        layers.append(nn.MaxPool2d(2, 2))

        # Stage 2: 16 -> 32 -> 32 channels, then halve the resolution.
        layers += self._conv_unit(16, 32)
        layers += self._conv_unit(32, 32)
        layers.append(nn.MaxPool2d(2, 2))

        # Stage 3: 32 -> 32 -> 32 channels, then global max pooling to 1x1.
        layers += self._conv_unit(32, 32)
        layers += self._conv_unit(32, 32)
        layers.append(nn.AdaptiveMaxPool2d(1))

        self.features = nn.Sequential(*layers)

        head = [
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return a (batch, 10) tensor of class scores for the image batch ``x``."""
        pooled = self.features(x)
        # Collapse (batch, 32, 1, 1) into (batch, 32) for the linear head.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# Network instance, placed on the configured device.
model = CifarModel().to(device)

# Cross-entropy objective, minimized with SGD plus momentum.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [None]:
# Draw graph of the CIFAR model.
# The dummy batch (one all-zeros 3-channel 32x32 image) is created on `device`
# so that the model and its input always live on the same device; a CPU-only
# input stops working as soon as the model is moved to CUDA.
dg = ww.build_pytorch_graph(model, torch.zeros([1, 3, 32, 32], device=device))
dot = ww.draw_graph(dg)
# Leaving `dot` as the cell's last expression makes the notebook display it.
dot

In [None]:
# Watcher that receives the training metrics, plus the global step counter
# that the training loop cell increments once per mini-batch.
w = ww.Watcher()
step = 0

# Visual customizations: display names for the two tracked metrics.
w.legend = {
    "loss": "Training Loss",
    "accuracy": "Training Accuracy",
}

In [None]:
# Training loop
# Two passes over the training set. `step` counts mini-batches across epochs
# (and across re-runs of this cell, since it is initialized in another cell).
for epoch in range(2):
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report statistics every 100 steps; the truthiness check skips step 0.
        if step and step % 100 == 0:
            # Accuracy on the current mini-batch only. The tensors are copied
            # to host memory first because .numpy() works only on CPU tensors;
            # .cpu() is a no-op when they already live there.
            pred_labels = np.argmax(outputs.detach().cpu().numpy(), 1)
            accuracy = np.mean(pred_labels == labels.detach().cpu().numpy())

            w.step(step, loss=loss, accuracy=accuracy)
            # Both plots are issued inside one `with w:` block.
            with w:
                w.plot(["loss"])
                w.plot(["accuracy"])
        step += 1
