# How to write training scripts
**So you don't write a separate query script each time you miss a logging statement.**








































## Let's import PyTorch and define the Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

class Net(nn.Module):
    """Minimal classifier: a 2x2 max-pool followed by a single linear layer.

    The linear layer takes 196 features per sample, so each input sample must
    flatten to 196 values after pooling (e.g. a 1x28x28 image pooled down to
    1x14x14). The output has 10 values per sample.
    """

    def __init__(self):
        # Seed the global RNG before any layer is created so that every
        # Net() starts from the same initial weights.
        torch.manual_seed(1217)
        super(Net, self).__init__()
        self.fc1 = nn.Linear(196, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return a (batch, 10) tensor of non-negative scores for batch `x`."""
        x = self.pool(x)
        # Flatten per sample using the actual batch size instead of a
        # hard-coded 4, so partial batches and other batch sizes work too.
        x = x.view(x.size(0), -1)
        # NOTE(review): ReLU is applied to the final scores as well; kept
        # as-is so results recorded with this model remain reproducible.
        x = F.relu(self.fc1(x))
        return x

## Let's now load the data and initialize the model, optimizer, and loss criterion

In [3]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split. The loader is not shuffled, so batches arrive in the same
# order on every pass.
trainset = torchvision.datasets.MNIST(
    root='./mnist', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, num_workers=2)

# Held-out test split, loaded with the same settings.
testset = torchvision.datasets.MNIST(
    root='./mnist', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, num_workers=2)

In [4]:
def eval(net, loader=None):
    """Print and return the classification accuracy of `net` over `loader`.

    Note: this function intentionally keeps its original name, which shadows
    the built-in `eval` inside this notebook.

    Args:
        net: callable mapping a batch of images to per-class scores of shape
            (batch, num_classes).
        loader: iterable of (images, labels) batches. Defaults to the
            notebook-level `testloader`.

    Returns:
        Accuracy as a percentage (float between 0 and 100).

    Raises:
        ZeroDivisionError: if `loader` yields no samples.
    """
    if loader is None:
        loader = testloader

    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images)
            # Index of the highest score along the class dimension.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    # Report the number of images actually seen rather than a fixed count;
    # %d truncates the percentage to an integer for display.
    print('Accuracy of the network on the %d test images: %d %%' % (
        total, accuracy))
    return accuracy

## Train the network
1. Comment out code
2. Copy code

In [5]:
# Baseline training script: train for 2 epochs, checkpoint the model and the
# optimizer after every epoch, then evaluate on the test set.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(2):
    # Loss accumulated since the last progress line was printed.
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # Standard step: clear old gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Progress log: mean loss over the last 2000 mini-batches.
        running_loss += loss.item()
        if i % 2000 == 1999:    # true on every 2000th mini-batch (i is 0-based)
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
    # End of epoch: save model and optimizer state, one file pair per epoch,
    # keyed by the 0-based epoch number.
    torch.save(net.state_dict(), f'./mnist_net_{epoch}.pth')
    torch.save(optimizer.state_dict(), f'./mnist_opt_{epoch}.pth')
    print("ONCE PER EPOCH LOG STMT")
    eval(net)

print('Finished Training')


[1,  2000] loss: 1.278
[1,  4000] loss: 1.067
[1,  6000] loss: 0.980
[1,  8000] loss: 0.816
[1, 10000] loss: 0.765
[1, 12000] loss: 0.763
[1, 14000] loss: 0.755
Accuracy of the network on the 10000 test images: 80 %
[2,  2000] loss: 0.576
[2,  4000] loss: 0.596
[2,  6000] loss: 0.541
[2,  8000] loss: 0.567
[2, 10000] loss: 0.545
[2, 12000] loss: 0.550
[2, 14000] loss: 0.546
Accuracy of the network on the 10000 test images: 81 %
Finished Training


# The pattern
* What's the diff?

* Demo
    1. Training mode
    2. Skip mode: Log something outside the training loop
    3. Parallel mode: Log something inside the training loop

In [11]:
# CLI flags: run-mode switches read by the training cell below.
#   TRAINING         -> run the training loop and save a checkpoint per epoch
#   PROBING_TRAINING -> run the training loop; checkpoints are written only
#                       when TRAINING is also set
#   both False       -> skip the training loop and load each epoch's checkpoint
TRAINING, PROBING_TRAINING = False, True

# Epoch sub-range to execute: range(LO, HI). When LO is non-zero, state is
# first restored from the checkpoint of epoch LO - 1.
LO, HI = 1, 2

In [12]:
# Same training script as the baseline, restructured so that it can
#   (a) train and checkpoint                   (TRAINING),
#   (b) re-run training without checkpointing  (PROBING_TRAINING), or
#   (c) skip training and load checkpoints     (both flags False),
# over the epoch sub-range [LO, HI).
# Lines ending in a bare `#` are the ones that differ from the baseline.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Resume: when starting after epoch 0, restore model and optimizer state from
# the checkpoint written at the end of the previous epoch.
if LO:                                                                                      #
    net.load_state_dict(torch.load(f'./mnist_net_{LO - 1}.pth'))                            #
    optimizer.load_state_dict(torch.load(f'./mnist_opt_{LO - 1}.pth'))                      #

for epoch in range(LO, HI):                                                                 #
    # Loss accumulated since the last progress line was printed.
    running_loss = 0.0
    if TRAINING or PROBING_TRAINING:                                                        #
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Progress log: mean loss over the last 2000 mini-batches.
            running_loss += loss.item()
            if i % 2000 == 1999:    # true on every 2000th mini-batch (i is 0-based)
                # Example of a log statement added inside the training loop.
                print(f"ALSO EVERY STEP {epoch}:{i+1}")
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
        # Short-circuit guard: the save only runs when TRAINING is truthy, so
        # a probing re-run does not write checkpoint files.
        TRAINING and torch.save(net.state_dict(), f'./mnist_net_{epoch}.pth')               #
        TRAINING and torch.save(optimizer.state_dict(), f'./mnist_opt_{epoch}.pth')         #
    else:                                                                                   #
        # Skip mode: instead of training, load this epoch's saved state.
        net.load_state_dict(torch.load(f'./mnist_net_{epoch}.pth'))                         #
        optimizer.load_state_dict(torch.load(f'./mnist_opt_{epoch}.pth'))                   #
    # Example of a log statement added outside the training loop; it runs in
    # all three modes.
    print(f"ONCE PER EPOCH (Confusion Matrix) {epoch}")
    eval(net)

print('Finished Training')

ALSO EVERY STEP 1:2000
[2,  2000] loss: 0.576
ALSO EVERY STEP 1:4000
[2,  4000] loss: 0.596
ALSO EVERY STEP 1:6000
[2,  6000] loss: 0.541
ALSO EVERY STEP 1:8000
[2,  8000] loss: 0.567
ALSO EVERY STEP 1:10000
[2, 10000] loss: 0.545
ALSO EVERY STEP 1:12000
[2, 12000] loss: 0.550
ALSO EVERY STEP 1:14000
[2, 14000] loss: 0.546
ONCE PER EPOCH (Confusion Matrix) 1
Accuracy of the network on the 10000 test images: 81 %
Finished Training


# The Hindsight Logging&trade; Programming Pattern 
* Skip Retraining when possible
    - Use memoization: observe physical-logical equivalence
* Parallelize Retraining otherwise
    - Enable resuming from a checkpoint
    - Work Partitioning: Control the epoch sub-range from the command-line

# Beyond the Pattern

## Optimizing and Dynamically Controlling the cost of checkpointing

### Background checkpointing
![Background Materialization](img/backmat-simple.png)

### Adaptive Checkpointing
![Adaptive Checkpointing](img/adaptivity_zoomed.png)

## Re-execution is embarrassingly parallel given checkpoints
![Parallel Replay](img/initializations.png)

# If the pattern seems like too much trouble, we can instrument the code for you automatically

![Autoinstrument](img/changeset_example.png)

In [1]:
from IPython.display import IFrame
# Embed the arXiv PDF inline; the frame is 600 wide and its height is the
# width scaled by 8/6.
ratio = 8/6
width = 600
IFrame(src="https://arxiv.org/pdf/2006.07357.pdf", width=width, height=width * ratio)