# reentrancy-demo

![](https://i.imgur.com/NYowq6j.png)

<div style="background-color: #fff; box-shadow: 0 2px 4px 0 rgba(0,0,0,0.2); padding: 30px; padding-top: 24px; margin: 0px 40px">
This Jupyter notebook is the code companion to the blog posts <b><a href="https://spell.run/blog/automated-machine-failure-recovery-Xp3TEhEAACUAYwPM">Automating GPU machine failure recovery in Google Compute Engine</a></b>, <b><a href="https://spell.run/blog/reducing-gpu-model-training-costs-using-spot-XqtgJBAAACMAR6h8">Reducing GPU model training costs by 66% using spot instances</a></b>, and <b><a href="https://spell.ml/blog/spot-interruptions-XzQ5jRIAACIAK3h2">Making model training scripts robust to spot interruptions</a></b>.
</div>

<br/>

## prerequisites

You will need:
* An account on Spell.
* A copy of the [250 Segmented Bob Ross Images](https://www.kaggle.com/residentmario/segmented-bob-ross-images) dataset from Kaggle. Download this dataset, unzip it, and upload it to SpellFS by running e.g. `spell upload ~/Downloads/segmented-bob-ross-images`. The files should land in the `uploads/segmented-bob-ross-images` directory in Spell.
* The PyTorch (`torch`) and `spell` Python packages installed in your local environment. Alternatively, you can launch this notebook from a Spell workspace by running the following CLI command (requires having the `spell` package installed):

```bash
spell jupyter \
    --lab \
    --github-url https://github.com/spellrun/spell-examples.git \
    spot-demo-workspace
```

## a non-reentrant training script

A script or pipeline is said to be **reentrant** if it can safely be rerun after terminating or failing midway through execution.

To demonstrate what this means in the context of machine learning training, take a look at the following `pytorch` training script. This script trains a [UNet](https://arxiv.org/abs/1505.04597) image segmentation model on the [250 Segmented Bob Ross Images](https://www.kaggle.com/residentmario/segmented-bob-ross-images) dataset from Kaggle.

In [1]:
%%writefile unet.py
import torch
from torch import nn

class UNet(nn.Module):
    """UNet segmentation network with 9 output classes.

    Every 3x3 convolution is unpadded, so each one shrinks the height and
    width by 2 and the output map is smaller than the input: a 256x256 input
    yields a 164x164 map of per-pixel scores for the 9 classes. Before each
    skip-connection concatenation, the encoder activation is center-cropped
    to the spatial size of the upsampled decoder activation.
    """

    def __init__(self):
        super().__init__()
        # Looking a staticmethod up through the instance yields the plain
        # function, so the crop lambdas below do not capture ``self``.
        crop = self._center_crop

        self.conv_1_1 = nn.Conv2d(3, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_1.weight)
        self.relu_1_2 = nn.ReLU()
        self.norm_1_3 = nn.BatchNorm2d(64)
        self.conv_1_4 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_4.weight)
        self.relu_1_5 = nn.ReLU()
        self.norm_1_6 = nn.BatchNorm2d(64)
        self.pool_1_7 = nn.MaxPool2d(2)

        self.conv_2_1 = nn.Conv2d(64, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_1.weight)
        self.relu_2_2 = nn.ReLU()
        self.norm_2_3 = nn.BatchNorm2d(128)
        self.conv_2_4 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_4.weight)
        self.relu_2_5 = nn.ReLU()
        self.norm_2_6 = nn.BatchNorm2d(128)
        self.pool_2_7 = nn.MaxPool2d(2)

        self.conv_3_1 = nn.Conv2d(128, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_1.weight)
        self.relu_3_2 = nn.ReLU()
        self.norm_3_3 = nn.BatchNorm2d(256)
        self.conv_3_4 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_4.weight)
        self.relu_3_5 = nn.ReLU()
        self.norm_3_6 = nn.BatchNorm2d(256)
        self.pool_3_7 = nn.MaxPool2d(2)

        self.conv_4_1 = nn.Conv2d(256, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_1.weight)
        self.relu_4_2 = nn.ReLU()
        self.norm_4_3 = nn.BatchNorm2d(512)
        self.conv_4_4 = nn.Conv2d(512, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_4.weight)
        self.relu_4_5 = nn.ReLU()
        self.norm_4_6 = nn.BatchNorm2d(512)

        # deconv is the '2D transposed convolution operator'
        self.deconv_5_1 = nn.ConvTranspose2d(512, 256, (2, 2), 2)
        # 57x57 -> 48x48 center crop for a 256x256 input. forward() passes the
        # upsampled map's actual size, so the crop stays centered (and valid)
        # for other input sizes as well.
        self.c_crop_5_2 = lambda x, height=48, width=48: crop(x, height, width)
        self.concat_5_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_5_4 = nn.Conv2d(512, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_4.weight)
        self.relu_5_5 = nn.ReLU()
        self.norm_5_6 = nn.BatchNorm2d(256)
        self.conv_5_7 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_7.weight)
        self.relu_5_8 = nn.ReLU()
        self.norm_5_9 = nn.BatchNorm2d(256)

        self.deconv_6_1 = nn.ConvTranspose2d(256, 128, (2, 2), 2)
        # 122x122 -> 88x88 center crop for a 256x256 input
        self.c_crop_6_2 = lambda x, height=88, width=88: crop(x, height, width)
        self.concat_6_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_6_4 = nn.Conv2d(256, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_4.weight)
        self.relu_6_5 = nn.ReLU()
        self.norm_6_6 = nn.BatchNorm2d(128)
        self.conv_6_7 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_7.weight)
        self.relu_6_8 = nn.ReLU()
        self.norm_6_9 = nn.BatchNorm2d(128)

        self.deconv_7_1 = nn.ConvTranspose2d(128, 64, (2, 2), 2)
        # 252x252 -> 168x168 center crop for a 256x256 input
        self.c_crop_7_2 = lambda x, height=168, width=168: crop(x, height, width)
        self.concat_7_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_7_4 = nn.Conv2d(128, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_4.weight)
        self.relu_7_5 = nn.ReLU()
        self.norm_7_6 = nn.BatchNorm2d(64)
        self.conv_7_7 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_7.weight)
        self.relu_7_8 = nn.ReLU()
        self.norm_7_9 = nn.BatchNorm2d(64)

        # 1x1 conv ~= fc; n_classes = 9
        self.conv_8_1 = nn.Conv2d(64, 9, 1)

    @staticmethod
    def _center_crop(x, height, width):
        """Return the central ``height`` x ``width`` window of NCHW tensor ``x``."""
        top = (x.shape[-2] - height) // 2
        left = (x.shape[-1] - width) // 2
        return x[:, :, top:top + height, left:left + width]

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 9, H', W') class scores.

        For H = W = 256 the output is 164x164; the unpadded convolutions make
        the output smaller than the input.
        """
        x = self.conv_1_1(x)
        x = self.relu_1_2(x)
        x = self.norm_1_3(x)
        x = self.conv_1_4(x)
        x = self.relu_1_5(x)
        x_residual_1 = self.norm_1_6(x)
        x = self.pool_1_7(x_residual_1)

        x = self.conv_2_1(x)
        x = self.relu_2_2(x)
        x = self.norm_2_3(x)
        x = self.conv_2_4(x)
        x = self.relu_2_5(x)
        x_residual_2 = self.norm_2_6(x)
        x = self.pool_2_7(x_residual_2)

        x = self.conv_3_1(x)
        x = self.relu_3_2(x)
        x = self.norm_3_3(x)
        x = self.conv_3_4(x)
        x = self.relu_3_5(x)
        x_residual_3 = self.norm_3_6(x)
        x = self.pool_3_7(x_residual_3)

        x = self.conv_4_1(x)
        x = self.relu_4_2(x)
        x = self.norm_4_3(x)
        x = self.conv_4_4(x)
        x = self.relu_4_5(x)
        x = self.norm_4_6(x)

        # Each skip connection is cropped to the upsampled map's (H, W).
        x = self.deconv_5_1(x)
        x = self.concat_5_3(self.c_crop_5_2(x_residual_3, *x.shape[-2:]), x)
        x = self.conv_5_4(x)
        x = self.relu_5_5(x)
        x = self.norm_5_6(x)
        x = self.conv_5_7(x)
        x = self.relu_5_8(x)
        x = self.norm_5_9(x)

        x = self.deconv_6_1(x)
        x = self.concat_6_3(self.c_crop_6_2(x_residual_2, *x.shape[-2:]), x)
        x = self.conv_6_4(x)
        x = self.relu_6_5(x)
        x = self.norm_6_6(x)
        x = self.conv_6_7(x)
        x = self.relu_6_8(x)
        x = self.norm_6_9(x)

        x = self.deconv_7_1(x)
        x = self.concat_7_3(self.c_crop_7_2(x_residual_1, *x.shape[-2:]), x)
        x = self.conv_7_4(x)
        x = self.relu_7_5(x)
        x = self.norm_7_6(x)
        x = self.conv_7_7(x)
        x = self.relu_7_8(x)
        x = self.norm_7_9(x)

        x = self.conv_8_1(x)
        return x

Overwriting unet.py


In [2]:
%%writefile train.py
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
from pathlib import Path

from unet import UNet

# Number of full passes over the training data loader.
NUM_EPOCHS = 50

class BobRossSegmentedImagesDataset(Dataset):
    """Paired (painting, segmentation label) samples for UNet training.

    Paintings are read from ``dataroot/train/images`` and label maps from
    ``dataroot/train/labels``; the i-th file of each sorted listing form a
    pair. NOTE(review): assumes corresponding files sort into the same order
    in both directories -- confirm against the dataset's naming scheme.
    """

    def __init__(self, dataroot):
        super().__init__()
        self.dataroot = dataroot
        # rglob() yields files in filesystem order, which is not guaranteed to
        # agree between the two directories; sort so image i pairs with label i.
        self.imgs = sorted((self.dataroot / 'train' / 'images').rglob('*.png'))
        self.segs = sorted((self.dataroot / 'train' / 'labels').rglob('*.png'))
        # 164x164 resize + 46px reflect padding -> 256x256 network input.
        # No RandomHorizontalFlip here on purpose: a flip must be applied to
        # the image and its label map together, which __getitem__ does.
        self.transform = transforms.Compose([
            transforms.Resize((164, 164)),
            transforms.Pad(46, padding_mode='reflect'),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.459387, 0.46603974, 0.4336706),
                std=(0.06098535, 0.05802868, 0.08737113),
            ),
        ])
        # Raw label-pixel value -> contiguous class index (9 classes).
        self.color_key = {
            3: 0,
            5: 1,
            10: 2,
            14: 3,
            17: 4,
            18: 5,
            22: 6,
            27: 7,
            61: 8,
        }
        if len(self.imgs) != len(self.segs):
            raise ValueError(
                f'Found {len(self.imgs)} images but {len(self.segs)} label maps '
                f'under {self.dataroot}'
            )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(img, seg)`` for sample ``i``.

        ``img`` is the normalized 256x256 image tensor; ``seg`` is a 164x164
        int64 array of class indices for the image's center region. With
        probability 0.5 both are mirrored horizontally.
        """
        translate = np.vectorize(self.color_key.__getitem__)

        # Context managers close the file handles once the pixels are decoded.
        with Image.open(self.imgs[i]) as img_file:
            img = self.transform(img_file)

        with Image.open(self.segs[i]) as seg_file:
            seg = seg_file.resize((256, 256), Image.NEAREST)

        seg = translate(np.array(seg)).astype('int64')

        # Additionally, the original UNet implementation outputs a segmentation map
        # for a subset of the overall image, not the image as a whole! With this input
        # size the segmentation map targeted is a (164, 164) center crop.
        seg = seg[46:210, 46:210]

        # Flip image and target together. The crop above is centered, so
        # flipping after it matches flipping the full label map. The copy
        # removes the negative stride, which torch cannot convert to a tensor.
        if torch.rand(1).item() < 0.5:
            img = torch.flip(img, dims=[-1])
            seg = np.ascontiguousarray(seg[:, ::-1])

        return img, seg

dataroot = Path('/mnt/segmented-bob-ross-images/')
dataset = BobRossSegmentedImagesDataset(dataroot)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

model = UNet()
model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.5)
# scheduler.step() is called once per batch below, so T_0 counts batches.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=32)

for epoch in range(NUM_EPOCHS):
    for batch, segmap in dataloader:
        optimizer.zero_grad()

        batch = batch.cuda()
        segmap = segmap.cuda()

        output = model(batch)
        loss = criterion(output, segmap)
        loss.backward()
        optimizer.step()
        scheduler.step()

    print(f'Finished epoch {epoch}.')

# The weights are written to disk only here, after the last epoch: if the run
# is interrupted before this line, all training progress is lost.
torch.save(model.state_dict(), f'{NUM_EPOCHS}_net.pth')

Overwriting train.py


**This training script is not reentrant.** Saving the model to disk is the very last thing this training script does. If the run executing this training script gets interrupted, all training progress the model made prior to termination is lost forever!

It's unsafe to assume that the machine running the script will actually succeed in executing it from start to finish:

* The machine might experience a hardware failure, forcing early termination.
* If the machine is an on-demand Compute Engine instance on GCP, and it has been live for six hours or longer, it may be randomly forced into the `REPAIRING` state and self-terminate.
* If the machine is a spot instance on AWS or a preemptible instance on GCP, it may get reclaimed by the vendor.

If the model is small, it's cheap enough to simply launch a new training job from scratch when the old one fails.

If the model is large and complex, things are different. State-of-the-art user-facing deep learning models often cost hundreds (if not thousands) of dollars to train once. In these kinds of scenarios losing all of your progress is prohibitively expensive and potentially a huge source of delay for your project.

## reentrant training script
To protect against unexpected failure and/or take advantage of spot instance cost savings, make your training script reentrant:

In [3]:
%%writefile train_reentrant.py
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
from pathlib import Path

from unet import UNet

import argparse

NUM_EPOCHS = 50

# Training normally starts at epoch 0. If --from-checkpoint names a saved
# '<epoch>_net.pth' file, start at the epoch right after it instead.
parser = argparse.ArgumentParser()
parser.add_argument('--from-checkpoint', dest='checkpoint', type=str, default='')
args = parser.parse_args()
if not args.checkpoint:
    EPOCHS = range(NUM_EPOCHS)
else:
    first_remaining_epoch = 1 + int(args.checkpoint.split('_')[0])
    EPOCHS = range(first_remaining_epoch, NUM_EPOCHS)

class BobRossSegmentedImagesDataset(Dataset):
    """Paired (painting, segmentation label) samples for UNet training.

    Paintings are read from ``dataroot/train/images`` and label maps from
    ``dataroot/train/labels``; the i-th file of each sorted listing form a
    pair. NOTE(review): assumes corresponding files sort into the same order
    in both directories -- confirm against the dataset's naming scheme.
    """

    def __init__(self, dataroot):
        super().__init__()
        self.dataroot = dataroot
        # rglob() yields files in filesystem order, which is not guaranteed to
        # agree between the two directories; sort so image i pairs with label i.
        self.imgs = sorted((self.dataroot / 'train' / 'images').rglob('*.png'))
        self.segs = sorted((self.dataroot / 'train' / 'labels').rglob('*.png'))
        # 164x164 resize + 46px reflect padding -> 256x256 network input.
        # No RandomHorizontalFlip here on purpose: a flip must be applied to
        # the image and its label map together, which __getitem__ does.
        self.transform = transforms.Compose([
            transforms.Resize((164, 164)),
            transforms.Pad(46, padding_mode='reflect'),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.459387, 0.46603974, 0.4336706),
                std=(0.06098535, 0.05802868, 0.08737113),
            ),
        ])
        # Raw label-pixel value -> contiguous class index (9 classes).
        self.color_key = {
            3: 0,
            5: 1,
            10: 2,
            14: 3,
            17: 4,
            18: 5,
            22: 6,
            27: 7,
            61: 8,
        }
        if len(self.imgs) != len(self.segs):
            raise ValueError(
                f'Found {len(self.imgs)} images but {len(self.segs)} label maps '
                f'under {self.dataroot}'
            )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(img, seg)`` for sample ``i``.

        ``img`` is the normalized 256x256 image tensor; ``seg`` is a 164x164
        int64 array of class indices for the image's center region. With
        probability 0.5 both are mirrored horizontally.
        """
        translate = np.vectorize(self.color_key.__getitem__)

        # Context managers close the file handles once the pixels are decoded.
        with Image.open(self.imgs[i]) as img_file:
            img = self.transform(img_file)

        with Image.open(self.segs[i]) as seg_file:
            seg = seg_file.resize((256, 256), Image.NEAREST)

        seg = translate(np.array(seg)).astype('int64')

        # Additionally, the original UNet implementation outputs a segmentation map
        # for a subset of the overall image, not the image as a whole! With this input
        # size the segmentation map targeted is a (164, 164) center crop.
        seg = seg[46:210, 46:210]

        # Flip image and target together. The crop above is centered, so
        # flipping after it matches flipping the full label map. The copy
        # removes the negative stride, which torch cannot convert to a tensor.
        if torch.rand(1).item() < 0.5:
            img = torch.flip(img, dims=[-1])
            seg = np.ascontiguousarray(seg[:, ::-1])

        return img, seg

dataroot = Path('/mnt/segmented-bob-ross-images/')
dataset = BobRossSegmentedImagesDataset(dataroot)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# Instead of always initializing an empty model, initialize from the checkpoints
# file if one is available. Only model weights are checkpointed, so the
# optimizer and scheduler below are rebuilt from scratch on a restart.
model = UNet()
model.cuda()
if args.checkpoint:
    model.load_state_dict(torch.load(f'/mnt/checkpoints/{args.checkpoint}'))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.5)
# scheduler.step() is called once per batch below, so T_0 counts batches.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=32)

for epoch in EPOCHS:
    for batch, segmap in dataloader:
        optimizer.zero_grad()

        batch = batch.cuda()
        segmap = segmap.cuda()

        output = model(batch)
        loss = criterion(output, segmap)
        loss.backward()
        optimizer.step()
        scheduler.step()

    print(f'Finished epoch {epoch}.')

    # Save a checkpoint every 5 epochs, and always after the final epoch:
    # NUM_EPOCHS - 1 is not a multiple of 5, so without the second condition
    # the fully trained model would never be written.
    if epoch % 5 == 0 or epoch == NUM_EPOCHS - 1:
        checkpoint_path = f'{epoch}_net.pth'
        # Write under a temporary name, then rename. The rename is atomic on
        # POSIX, so an interruption mid-save cannot leave a truncated file
        # under the checkpoint's real name.
        partial_path = Path(checkpoint_path + '.tmp')
        torch.save(model.state_dict(), str(partial_path))
        partial_path.replace(checkpoint_path)
        print(f'Saved model to {checkpoint_path}.')

Overwriting train_reentrant.py


This updated script uses the [PyTorch model saving and loading facilities](https://pytorch.org/tutorials/beginner/saving_loading_models.html) to checkpoint the model weights every five epochs. The script is now parameterizable with a `--from-checkpoint` flag which, if set, names the checkpoint file (e.g. `35_net.pth`, looked up in the `/mnt/checkpoints/` directory) that model training is to be restarted from.

It's easy to use this updated training script hand-in-hand with the Spell API to ensure that you can restart right where you left off. Here's how. First, run:

In [7]:
# Launch train_reentrant.py on a V100, mounting the uploaded dataset at
# /mnt/segmented-bob-ross-images, the path the training script reads from.
!spell run --machine-type V100 \
    --github-url 'https://github.com/spellrun/spell-examples.git' \
    --mount uploads/segmented-bob-ross-images:/mnt/segmented-bob-ross-images \
    --pip Pillow \
    "python spot/train_reentrant.py"

[0m💫 Casting spell #307…
[0m✨ Stop viewing logs with ^C
[0m[K[0m[?25h[0m✨ Machine_Requested… done
[0m[K[0m[?25h[0m✨ Building… done tagged registry.spell:80/remote_content_307:d1958944e……0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m
[0m[K[0m[?25h[0m✨ Mounting… done
[0m✨ [0mRun is running
[0mFinished epoch 0.
[0mSaved model to 0_net.pth.
[0mFinished epoch 1.
[0mFinished epoch 2.
[0mFinished epoch 3.
[0mFinished epoch 4.
[0mFinished epoch 5.
[0mSaved model to 5_net.pth.
[0mFinished epoch 6.
[0mFinished epoch 7.
[0mFinished epoch 8.
[0mFinished epoch 9.
[0mFinished epoch 10.
[0mSaved model to 10_net.pth.
[0mFinished epoch 11.
[0mFinished epoch 12.
[0mFinished epoch 13.
[0mFinis

The Spell workspace will not log you in by default. If you are running this notebook from inside of a Spell workspace you will need to run the following command, replacing `YOUR_EMAIL` with your Spell email and `YOUR_PASSWORD` with your Spell password:

In [None]:
# !spell login --identity YOUR_EMAIL --password YOUR_PASSWORD

After running this model training for a while, simulate an early termination by [stopping](https://spell.run/docs/run_overview#interrupting-a-run) the run (replace `307` here with the ID that was assigned to your run):

In [8]:
# Stop the run to simulate an interruption (substitute your own run ID).
!spell stop 307

[0mStopping run 307. Use 'spell logs -f 307' to view logs while the job finishes.
[0m[0m

Check the run outputs that landed in SpellFS using `spell ls`:

In [9]:
# List the files the run wrote to SpellFS, including its *_net.pth checkpoints.
!spell ls runs/307

[0m30854380 Apr 27 23:50   0_net.pth
[0m30854380 Apr 27 23:50   10_net.pth
[0m30854380 Apr 27 23:50   15_net.pth
[0m30854380 Apr 27 23:50   20_net.pth
[0m30854380 Apr 27 23:50   25_net.pth
[0m30854380 Apr 27 23:50   30_net.pth
[0m30854380 Apr 27 23:50   35_net.pth
[0m30854380 Apr 27 23:50   5_net.pth
[0m[0m

And now to restart where you left off you simply run:

In [15]:
# Restart training: mount the interrupted run's outputs at /mnt/checkpoints/,
# the directory train_reentrant.py loads --from-checkpoint files from.
!spell run --machine-type V100 \
    --github-url 'https://github.com/spellrun/spell-examples.git' \
    --mount uploads/segmented-bob-ross-images:/mnt/segmented-bob-ross-images \
    --mount runs/307:/mnt/checkpoints/ \
    --pip Pillow \
    "python spot/train_reentrant.py --from-checkpoint '35_net.pth'"

[0m💫 Casting spell #309…
[0m✨ Stop viewing logs with ^C
[0m[K[0m[?25h[0m✨ Machine_Requested… done
[0m[K[0m[?25h[0m✨ Building… done tagged registry.spell:80/remote_content_309:bfde0d448……[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m[0m
[0m[K[0m[?25h[0m✨ Mounting… done
[0m✨ [0mRun is running
[0m^C

[0m✨ Your run is still running remotely.
[0m✨ Use 'spell kill 309' to terminate your run
[0m✨ Use 'spell logs 309' to view logs again
[0m[K[0m[?25h[0m[0m

Which will train the model to completion.

## resumable

A **resumable** training script goes one step further than a reentrant one.

Spell runs executed on spot instances can be launched with the `--auto-resume` flag. In the event the machine is reclaimed by the cloud provider, this flag instructs Spell to queue a new Spell run with the same run command and a copy of the previous run's disk image. Assuming your training script is resumable, this will make your training job robust to cloud interrupts. Almost as good as an on-demand instance at less than half the price!

Here's an example. Note the addition of the `--resume` flag, which automatically finds the most recent checkpoint file and resumes from that one.

In [22]:
%%writefile train_resumable.py
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
from pathlib import Path
import re
import os

from unet import UNet

import argparse

NUM_EPOCHS = 50

# Command-line interface: --from-checkpoint names a specific '<epoch>_net.pth'
# file to start from, while --resume picks the newest checkpoint on disk.
parser = argparse.ArgumentParser()
parser.add_argument('--from-checkpoint', dest='checkpoint', type=str, default='')
parser.add_argument('--resume', action='store_true', default=False)
args = parser.parse_args()

class BobRossSegmentedImagesDataset(Dataset):
    """Paired (painting, segmentation label) samples for UNet training.

    Paintings are read from ``dataroot/train/images`` and label maps from
    ``dataroot/train/labels``; the i-th file of each sorted listing form a
    pair. NOTE(review): assumes corresponding files sort into the same order
    in both directories -- confirm against the dataset's naming scheme.
    """

    def __init__(self, dataroot):
        super().__init__()
        self.dataroot = dataroot
        # rglob() yields files in filesystem order, which is not guaranteed to
        # agree between the two directories; sort so image i pairs with label i.
        self.imgs = sorted((self.dataroot / 'train' / 'images').rglob('*.png'))
        self.segs = sorted((self.dataroot / 'train' / 'labels').rglob('*.png'))
        # 164x164 resize + 46px reflect padding -> 256x256 network input.
        # No RandomHorizontalFlip here on purpose: a flip must be applied to
        # the image and its label map together, which __getitem__ does.
        self.transform = transforms.Compose([
            transforms.Resize((164, 164)),
            transforms.Pad(46, padding_mode='reflect'),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.459387, 0.46603974, 0.4336706),
                std=(0.06098535, 0.05802868, 0.08737113),
            ),
        ])
        # Raw label-pixel value -> contiguous class index (9 classes).
        self.color_key = {
            3: 0,
            5: 1,
            10: 2,
            14: 3,
            17: 4,
            18: 5,
            22: 6,
            27: 7,
            61: 8,
        }
        if len(self.imgs) != len(self.segs):
            raise ValueError(
                f'Found {len(self.imgs)} images but {len(self.segs)} label maps '
                f'under {self.dataroot}'
            )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(img, seg)`` for sample ``i``.

        ``img`` is the normalized 256x256 image tensor; ``seg`` is a 164x164
        int64 array of class indices for the image's center region. With
        probability 0.5 both are mirrored horizontally.
        """
        translate = np.vectorize(self.color_key.__getitem__)

        # Context managers close the file handles once the pixels are decoded.
        with Image.open(self.imgs[i]) as img_file:
            img = self.transform(img_file)

        with Image.open(self.segs[i]) as seg_file:
            seg = seg_file.resize((256, 256), Image.NEAREST)

        seg = translate(np.array(seg)).astype('int64')

        # Additionally, the original UNet implementation outputs a segmentation map
        # for a subset of the overall image, not the image as a whole! With this input
        # size the segmentation map targeted is a (164, 164) center crop.
        seg = seg[46:210, 46:210]

        # Flip image and target together. The crop above is centered, so
        # flipping after it matches flipping the full label map. The copy
        # removes the negative stride, which torch cannot convert to a tensor.
        if torch.rand(1).item() < 0.5:
            img = torch.flip(img, dims=[-1])
            seg = np.ascontiguousarray(seg[:, ::-1])

        return img, seg

dataroot = Path('/mnt/segmented-bob-ross-images/')
dataset = BobRossSegmentedImagesDataset(dataroot)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# Checkpoints are read from and written to this directory. Only files named
# exactly '<epoch>_net.pth' count as checkpoints; anything else (e.g. a
# leftover '.tmp' file from an interrupted save) is ignored.
CHECKPOINT_DIR = '/spell/checkpoints/'
CHECKPOINT_NAME = re.compile(r'^(\d+)_net\.pth$')

# Instead of always initializing an empty model, initialize from the checkpoints
# file if one is available.
model = UNet()
model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.5)
# scheduler.step() is called once per batch below, so T_0 counts batches.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=32)


def _latest_checkpoint_epoch():
    """Return the highest epoch with a checkpoint in CHECKPOINT_DIR, or None."""
    if not os.path.isdir(CHECKPOINT_DIR):
        return None
    matches = (CHECKPOINT_NAME.match(name) for name in os.listdir(CHECKPOINT_DIR))
    return max((int(m.group(1)) for m in matches if m), default=None)


if args.resume:
    checkpoint_epoch = _latest_checkpoint_epoch()
    if checkpoint_epoch is None:
        EPOCHS = range(NUM_EPOCHS)
    else:
        model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}{checkpoint_epoch}_net.pth'))
        first_remaining_epoch = checkpoint_epoch + 1
        EPOCHS = range(first_remaining_epoch, NUM_EPOCHS)
elif args.checkpoint:
    first_remaining_epoch = int(args.checkpoint.split('_')[0]) + 1
    EPOCHS = range(first_remaining_epoch, NUM_EPOCHS)
    model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}{args.checkpoint}'))
else:
    EPOCHS = range(NUM_EPOCHS)

# torch.save does not create missing directories, and on a fresh run the
# checkpoint directory may not exist yet.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in EPOCHS:
    for batch, segmap in dataloader:
        optimizer.zero_grad()

        batch = batch.cuda()
        segmap = segmap.cuda()

        output = model(batch)
        loss = criterion(output, segmap)
        loss.backward()
        optimizer.step()
        scheduler.step()

    print(f'Finished epoch {epoch}.')

    # Save a checkpoint every 5 epochs, and always after the final epoch
    # (NUM_EPOCHS - 1 is not a multiple of 5).
    if epoch % 5 == 0 or epoch == NUM_EPOCHS - 1:
        checkpoint_path = f'{CHECKPOINT_DIR}{epoch}_net.pth'
        # Save under a temporary name and atomically rename. Otherwise a spot
        # interruption mid-write could leave a truncated file that --resume
        # would pick as the newest checkpoint on every subsequent restart.
        torch.save(model.state_dict(), checkpoint_path + '.tmp')
        os.replace(checkpoint_path + '.tmp', checkpoint_path)
        print(f'Saved model to {epoch}_net.pth.')

Overwriting train_resumable.py


In [21]:
# Run the resumable script on a spot V100. --auto-resume makes Spell requeue the
# run if the machine is reclaimed; --resume makes the script continue from the
# newest checkpoint it finds.
!spell run --machine-type v100-spot \
    --github-url 'https://github.com/spellrun/spell-examples.git' \
    --github-ref 'train-resumable' \
    --mount uploads/segmented-bob-ross-images:/mnt/segmented-bob-ross-images \
    --auto-resume \
    "python spot/train_resumable.py --resume"

[0m💫 Casting spell #150…
[0m✨ Stop viewing logs with ^C
[1m[36m🌟[0m Machine_Requested… Run created -- waiting for a v100-spot machine.[0m[0m^C
[0m
[0m✨ Your run is still running remotely.
[0m✨ Use 'spell kill 150' to terminate your run
[0m✨ Use 'spell logs 150' to view logs again
[0m[K[0m[?25h[0m[0m