# FiftyOne Walkthrough: Model Training with a Custom Model

This walkthrough shows how FiftyOne can be used in a model training workflow when your model definition code is custom. It covers the following concepts:
* Integrating your existing model-training loop with FiftyOne
* Adding predictions from your model to your FiftyOne dataset
* Visualizing aspects of your dataset based on your newly trained model

This walkthrough is self-contained and uses a custom model definition provided alongside it in `simple_resnet.py`.

## Preliminaries

This walkthrough requires PyTorch and torchvision, and it reads images with `imageio`. Install them in your shell or virtual environment if necessary:

```
pip install torch
pip install torchvision
pip install imageio
```
The PyTorch requirement is also listed in the README.

Let's set up some basic variables and structures we need for the walkthrough.  Nothing in here is specific to FiftyOne; you probably don't need to understand it in detail.

In [None]:
import time

# Settings. The training code below requires a CUDA GPU (it uses half
# precision and torch.cuda.synchronize). The defaults target a GPU with about
# 4 GB of memory; on a smaller GPU, lower batch_size further
settings = {}
# For a powerful GPU with more than 6 GB of memory, use instead:
#settings['batch_size'] = 512
# and train on all of the training samples
#settings['take'] = None
# These values work on GPUs with 4 GB of memory; you may need to lower them
# further to run the walkthrough
settings['batch_size'] = 36
# Use a random subset of 10,000 samples from the training set
settings['take'] = 10000
# 24 epochs reaches good accuracy with these settings
settings['epochs'] = 24
# Where to save the trained model weights
settings['model_path'] = './model.pth'

localtime = lambda: time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

# Dataset parameters for CIFAR-10: the per-channel (R, G, B) pixel mean and
# standard deviation on the 0-255 scale, followed by the label map
cifar10_mean, cifar10_std = [
    (125.31, 122.95, 113.87), 
    (62.99, 62.09, 66.70), 
]

cifar10_map = "airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck".split(', ')
cifar10_rev = {name: index for index, name in enumerate(cifar10_map)}
N_labels = 10
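The mean and standard deviation above are per-channel (R, G, B) statistics on the 0-255 pixel scale. Normalising therefore subtracts and divides channel by channel. Below is a minimal NumPy sketch of that operation, together with a round trip through the label map. `normalise_image` is a hypothetical helper for illustration only; the training code itself uses `normalise` from `simple_resnet.py`.

```python
import numpy as np

# Per-channel (R, G, B) statistics on the 0-255 scale, copied from the cell above
cifar10_mean = np.array((125.31, 122.95, 113.87), dtype=np.float32)
cifar10_std = np.array((62.99, 62.09, 66.70), dtype=np.float32)

cifar10_map = "airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck".split(", ")
cifar10_rev = {name: index for index, name in enumerate(cifar10_map)}


def normalise_image(image):
    """Hypothetical helper: standardise an HWC image channel by channel."""
    # Broadcasting over the last axis uses each channel's own mean and std
    return (image.astype(np.float32) - cifar10_mean) / cifar10_std


# An image whose every pixel equals the dataset mean normalises to all zeros
mean_image = np.broadcast_to(cifar10_mean, (32, 32, 3))
print(np.abs(normalise_image(mean_image)).max())

# Class names and integer targets map back and forth
print(cifar10_rev["cat"], cifar10_map[3])
```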

## Set up the data

We use the CIFAR-10 dataset for simplicity and load it from the FiftyOne Dataset Zoo. The same steps apply to any FiftyOne dataset.

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("cifar10")
print(dataset)

That was pretty easy! The dataset has 60,000 samples, each tagged `train` or `test`. Let's check that the splits line up: there should be 50,000 training samples and 10,000 test samples, which we use as the validation set below.

In [None]:
# match_tags() returns a view containing only the samples with the given tag
train_view = dataset.match_tags("train")
print("train samples: %d" % len(train_view))
valid_view = dataset.match_tags("test")
print("validation samples: %d" % len(valid_view))

### Cache the data for training

Note that a FiftyOne `Dataset` stores references to your media files on disk, along with labels and metadata, rather than loading or caching the media itself. This keeps FiftyOne fast and lightweight and gives you maximum flexibility. Because we are about to train a model and this example data is small, we load all of the images into memory up front.

In [None]:
import imageio 
import numpy as np

# Produces _train_images/_valid_images (lists of HWC uint8 image arrays) and
# _train_labels/_valid_labels (lists of integer class indices)

if train_view is None:
    raise ValueError(
        "train expects 'train_view' in the global namespace. See README.md"
    )
if valid_view is None:
    raise ValueError(
        "train expects 'valid_view' in the global namespace. See README.md"
    )

def update_progress(progress):
    # progress is [0,1]
    t = 51
    i = int(progress*t)
    r = t-i
    print("\r[%s%s] %.1f%%" % ("#"*i, " "*r,  progress*100), end="")
    
    
if settings['take']:
    train_view = train_view.take(settings['take'])
    print(f"using a subset of the data for the model training")
    print(f"updated train set: {len(train_view)} samples")

    
print("Training images")
_train_images = []
_train_labels = []
for index, sample in enumerate(train_view.iter_samples()):
    image = np.array(imageio.imread(sample.filepath))
    label = cifar10_rev[sample["ground_truth"].label]
    _train_images.append(image)
    _train_labels.append(label)

    if index % 100 == 0:
        update_progress(index / len(train_view))
update_progress(1)
print()


print("Validation images")
_valid_images = []
_valid_labels = []
for index, sample in enumerate(valid_view.iter_samples()):
    image = np.array(imageio.imread(sample.filepath))
    label = cifar10_rev[sample["ground_truth"].label]
    _valid_images.append(image)
    _valid_labels.append(label)

    if index % 100 == 0:
        update_progress(index / len(valid_view))
update_progress(1)
print()
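The rest of the walkthrough relies on this cache, so a quick sanity check can help. Stacking the images confirms that they all share one shape, and counting labels per class shows the class balance. `summarize_cache` is a hypothetical helper. It is shown here on synthetic stand-ins for `_train_images` and `_train_labels` so that the example runs on its own.

```python
import numpy as np


def summarize_cache(images, labels, n_classes=10):
    """Hypothetical helper: stack cached images and count samples per class."""
    data = np.stack(images)  # (N, H, W, C); raises if image shapes differ
    counts = np.bincount(np.asarray(labels), minlength=n_classes)
    return data.shape, data.dtype, counts


# Synthetic stand-ins for _train_images / _train_labels
rng = np.random.default_rng(0)
fake_images = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(20)]
fake_labels = [i % 10 for i in range(20)]

shape, dtype, counts = summarize_cache(fake_images, fake_labels)
print(shape, dtype, counts)
```

With the real cache and the default `take` of 10,000, you would expect a shape of `(10000, 32, 32, 3)`. The class counts should be roughly, not exactly, balanced, because `take` samples at random.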

## Train the Model

Using the model defined in `./simple_resnet.py`, let's now train a model and save it to disk. The code uses PyTorch as the ML backend and implements a small ResNet. We train it on the images we cached from the FiftyOne dataset above.
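The training cell below first preprocesses the data once. It pads each image by 4 pixels, normalises it with the CIFAR-10 statistics, and transposes it from NHWC to NCHW. Then, each epoch, it applies a random 32x32 crop, a horizontal flip and an 8x8 cutout (`Crop(32, 32)`, `FlipLR()`, `Cutout(8, 8)`).

The NumPy sketch here illustrates what those steps do. It rests on three assumptions that have not been checked against `simple_resnet.py`: reflect padding, a flip probability of 0.5, and a cutout square that lies fully inside the crop. `preprocess_batch` and `augment` are hypothetical names, not functions from that module.

```python
import numpy as np

MEAN = np.array((125.31, 122.95, 113.87), dtype=np.float32)
STD = np.array((62.99, 62.09, 66.70), dtype=np.float32)


def preprocess_batch(images, border=4):
    """One-time preprocessing: (N, H, W, C) uint8 -> (N, C, H+2b, W+2b) float32."""
    # Pad height and width only (reflect mode is an assumption here)
    x = np.pad(images, [(0, 0), (border, border), (border, border), (0, 0)], mode="reflect")
    x = (x.astype(np.float32) - MEAN) / STD  # per-channel standardisation
    return x.transpose(0, 3, 1, 2)  # NHWC -> NCHW, the layout PyTorch convolutions expect


def augment(x, rng, crop=32, cutout=8):
    """Per-sample random augmentation of one (C, H, W) array: crop, flip, cutout."""
    _, h, w = x.shape
    top, left = rng.integers(0, h - crop + 1), rng.integers(0, w - crop + 1)
    x = x[:, top:top + crop, left:left + crop]  # random crop back to 32x32
    if rng.random() < 0.5:
        x = x[:, :, ::-1]  # horizontal (left-right) flip
    x = x.copy()  # make the view contiguous and writable
    cy, cx = rng.integers(0, crop - cutout + 1), rng.integers(0, crop - cutout + 1)
    # Cutout: zero in normalised space corresponds to the dataset mean colour
    x[:, cy:cy + cutout, cx:cx + cutout] = 0.0
    return x


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
pre = preprocess_batch(batch)
print(pre.shape)
print(augment(pre[0], rng).shape)
```

Padding to 40x40 and then cropping back to 32x32 yields small random translations of each image. This is why padding happens once, in preprocessing, while cropping happens per epoch.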

In [None]:
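from functools import partial

import torch

# simple_resnet.py provides the network definition and the training helpers used
# below (Network, simple_resnet, preprocess, train_epoch, ...), as well as the
# `device` that the DataLoader wrapper moves batches to
from simple_resnet import *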

class DataLoader():
    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last
        )

    def __iter__(self):
        if self.set_random_choices:
            self.dataset.set_random_choices()
        return ({'input': x.to(device).half(), 'target': y.to(device).long()} for (x,y) in self.dataloader)

    def __len__(self):
        return len(self.dataloader)

# Set up the dataset structures for our model code and preprocess the data.
whole_dataset = {
    'train': {
        'data': np.asarray(_train_images),
        'targets': np.asarray(_train_labels)
    },
    'valid': {
        'data': np.asarray(_valid_images),
        'targets': np.asarray(_valid_labels)
    }
}

print("Preprocessing training data")
transforms = [
    partial(normalise, mean=np.array(cifar10_mean, dtype=np.float32), std=np.array(cifar10_std, dtype=np.float32)),
    partial(transpose, source='NHWC', target='NCHW'),
]
train_set = list(zip(*preprocess(whole_dataset['train'], [partial(pad, border=4)] + transforms).values()))
valid_set = list(zip(*preprocess(whole_dataset['valid'], transforms).values()))
print(f"Finished preprocessing")

print(f"train set: {len(train_set)} samples")
print(f"valid set: {len(valid_set)} samples")


# Set up the variables for training the model. 
lr_schedule = PiecewiseLinear([0, 5, settings['epochs']], [0, 0.4, 0])
train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]
total_N = len(train_set)

print(f'Starting the model training at {localtime()}')

model = Network(simple_resnet()).to(device).half()
logs, state = Table(), {MODEL: model, LOSS: x_ent_loss}

valid_batches = DataLoader(valid_set, settings['batch_size'], shuffle=False, drop_last=False)

train_batches = DataLoader(
        Transform(train_set, train_transforms),
        settings['batch_size'], shuffle=True, set_random_choices=True, drop_last=True
)
lr = lambda step: lr_schedule(step/len(train_batches))/settings['batch_size']
opts = [
    SGD(trainable_params(model).values(),
    {'lr': lr, 'weight_decay': Const(5e-4*settings['batch_size']), 'momentum': Const(0.9)})
]
state[OPTS] = opts

for epoch in range(settings['epochs']):
    logs.append(union({'epoch': epoch+1}, train_epoch(state, Timer(torch.cuda.synchronize), train_batches, valid_batches)))
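# Summarize the final epoch's accuracies (printed explicitly, because this is
# not the last expression in the cell and would otherwise not be displayed)
print(logs.df().query(f'epoch=={settings["epochs"]}')[['train_acc', 'valid_acc']].describe())

if settings['model_path']:
    torch.save(model.state_dict(), settings['model_path'])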
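The learning rate follows `PiecewiseLinear([0, 5, epochs], [0, 0.4, 0])`: a linear warm-up from 0 to 0.4 over the first 5 epochs, then a linear decay back to 0 at the final epoch. The `lr` lambda converts the step count (batches seen so far) into fractional epochs and divides by the batch size, while the weight decay is multiplied by it. This is presumably because the loss in this training code is summed rather than averaged over the batch.

The sketch below reproduces that arithmetic with `np.interp`, assuming that `PiecewiseLinear` interpolates linearly between its knots. `piecewise_linear` is a hypothetical stand-in, and the constants mirror the default settings.

```python
import numpy as np


def piecewise_linear(knots, vals):
    """Return a schedule t -> value that linearly interpolates between (knot, val) pairs."""
    return lambda t: float(np.interp(t, knots, vals))


epochs, batch_size, n_train = 24, 36, 10000
lr_schedule = piecewise_linear([0, 5, epochs], [0, 0.4, 0])
steps_per_epoch = n_train // batch_size  # drop_last=True discards the final partial batch


def lr(step):
    # step counts batches; dividing by steps_per_epoch converts it to fractional epochs
    return lr_schedule(step / steps_per_epoch) / batch_size


for epoch in (0, 2.5, 5, 14.5, 24):
    step = int(epoch * steps_per_epoch)
    print(f"epoch {epoch:>4}: schedule {lr_schedule(epoch):.3f}, per-step lr {lr(step):.6f}")
```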

## Wrap-Up

And, we're done! If you used the larger-GPU settings (a batch size of 512 and all training samples), the `valid_acc` column of the training log above should report about 94% accuracy. With the default, reduced settings your results will vary, but expect around 85%. This is only the beginning: here we trained a toy model on the CIFAR-10 dataset, with FiftyOne supplying the data.

With this model, we could now, for example, explore the training and validation sets for sample uniqueness and possible label mistakes.
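A simple first pass at finding possible label mistakes is to rank validation samples by how little probability the trained model assigns to their ground-truth class. The sketch below shows that heuristic on synthetic logits; with the real model you would collect logits by running it over `valid_batches`. `most_suspicious` is a hypothetical helper, and this ranking is a plain stand-in, not FiftyOne Brain's own mistakenness method.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def most_suspicious(logits, targets, k=5):
    """Hypothetical helper: the k samples whose ground-truth class gets the lowest probability."""
    probs = softmax(logits)
    p_true = probs[np.arange(len(targets)), targets]
    order = np.argsort(p_true)[:k]  # ascending: least confident in the label first
    return order, p_true[order]


rng = np.random.default_rng(0)
targets = rng.integers(0, 10, size=100)
logits = rng.normal(size=(100, 10))
logits[np.arange(100), targets] += 5.0  # a "good" model favours the labelled class...
logits[7, targets[7]] -= 10.0           # ...except on sample 7, which looks mislabeled
idx, p = most_suspicious(logits, targets, k=3)
print(idx, p)
```

The validation cache was built in `valid_view.iter_samples()` order, and `valid_batches` does not shuffle. The returned indices should therefore map back to samples in `valid_view`, which you could then inspect in the FiftyOne App.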

In the sister walkthrough, XXX LINK, we use PyTorch network definitions to train a model directly on a larger dataset.
