# Train image model

> Train a convnet using fastai (on colab).

## References

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

https://dev.to/tkeyo/export-fastai-resnet-models-to-onnx-2gj7

## Running this notebook in colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pete88b/expoco/blob/main/11e_viseme_image_train_model.ipynb)

### Change runtime type to use GPU

### Run the following cell, then restart the runtime

In [2]:
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab

## After runtime restart

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import datetime, json, torchvision
from fastai.vision.all import *
from zipfile import ZipFile
path = Path('/content/data')
path.mkdir(exist_ok=True)

## Copy your `data.zip` and `metadata.json` to your google drive

Update the path in the following cell to point to your data

In [None]:
source_path = Path('/content/drive/MyDrive/Colab Notebooks/datasets/expoco/viseme_image_dataset_20220202')
zip_file_path = source_path/'data.zip'
assert zip_file_path.is_file(), f'{zip_file_path} not found'

In [None]:
with ZipFile(zip_file_path) as z: # close the archive once extraction finishes
    z.extractall(path)
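Because `parent_label` (used when building the `DataBlock` further down) takes each image's label from its parent directory name, it can be worth checking what the archive unpacked to before training. The following is a minimal sketch using only the standard library; the helper name `count_images_per_class` and the list of image extensions are assumptions, not part of the original notebook.

```python
from collections import Counter
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset.
IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg', '.bmp'}

def count_images_per_class(root):
    "Hypothetical helper: count image files under `root`, keyed by parent folder name"
    counts = Counter()
    for p in Path(root).rglob('*'):
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES:
            counts[p.parent.name] += 1
    return dict(counts)
```

Calling `count_images_per_class(path)` after extraction should give one entry per class folder; a missing or very small class is easier to spot here than after training.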

## Onnx helper functions

In [None]:
def now():
    "Return a timestamp string that can be used in file or directory names"
    return datetime.utcnow().strftime('%Y%m%d_%H%M%S')
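`datetime.utcnow()` is deprecated as of Python 3.12, so a newer runtime may emit a warning here. A timezone-aware equivalent is sketched below. `now_utc` is a hypothetical name, chosen so it does not replace the `now()` above, and the module is imported under the alias `dt` so the snippet does not depend on what the name `datetime` is bound to in the notebook.

```python
import datetime as dt

def now_utc():
    "Return a UTC timestamp string (YYYYMMDD_HHMMSS) that can be used in file or directory names"
    return dt.datetime.now(dt.timezone.utc).strftime('%Y%m%d_%H%M%S')
```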

In [None]:
def onnx_export():
    torch_model = learn.model.cpu().eval() # by the time this fn is called, learn will exist
    model = nn.Sequential(
        torchvision.transforms.Normalize(**stats),
        torch_model,
        nn.Softmax(dim=1)
    )
    batch_size = 2
    # Input to the model
    x = torch.randn(batch_size, 3, 256, 256, requires_grad=True)
    torch_out = model(x)

    model_id = now()
    output_path = path/f'model_{model_id}'
    print('output_path', output_path)
    output_path.mkdir()
    file_name = output_path/'resnet_3_256_256.onnx'

    # Export the model
    torch.onnx.export(model,                     # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      file_name,                 # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX opset version to target
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes = {'input': {0: 'batch_size'},    # variable length axes
                                      'output': {0: 'batch_size'}})
    print('Exported to', file_name)
    return file_name, x, torch_out

## Read `stats` from `metadata.json`

In [1]:
with open(source_path/'metadata.json') as f:
    metadata = json.load(f)
stats = metadata['stats']
stats

In [12]:
def stats_bgr_to_rgb(stats):
    "Convert channel-wise stats from cv2 order (BGR) to fastai/PyTorch order (RGB)"
    def _permute(a): return [a[2], a[1], a[0]]
    return {k:_permute(v) for k,v in stats.items()}
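To make the effect of `stats_bgr_to_rgb` concrete, here is a small self-contained example. The function is repeated so the snippet runs on its own, and the numbers are made up for illustration; the real values come from `metadata.json`.

```python
def stats_bgr_to_rgb(stats):
    "Reverse the channel order of every entry in `stats` (BGR -> RGB)"
    def _permute(a): return [a[2], a[1], a[0]]
    return {k: _permute(v) for k, v in stats.items()}

# Made-up values for illustration only, listed in B, G, R order.
example_stats = {'mean': [0.1, 0.2, 0.3], 'std': [0.4, 0.5, 0.6]}
converted = stats_bgr_to_rgb(example_stats)
print(converted)  # {'mean': [0.3, 0.2, 0.1], 'std': [0.6, 0.5, 0.4]}
```

Applying the function twice returns the original order, which makes for a quick sanity check.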

## Train a resnet with fastai

### TODO: Create a separate set of validation (and test) data

A random sample of images is too easy a validation set: nearly every validation image will have a very similar image in the training data, so validation accuracy will overstate how well the model generalizes.
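One possible way to tackle this TODO is to split by group instead of by individual image, so that near-duplicate frames always land on the same side of the split. The sketch below assumes each file can be mapped to a group key such as a capture session. Both `make_group_splitter` and the `'<session>_<frame>.png'` naming scheme are hypothetical and would need adapting to the real dataset.

```python
import random
from pathlib import Path

def make_group_splitter(group_of, valid_pct=0.2, seed=42):
    "Return a splitter that assigns whole groups of items to either train or valid"
    def _split(items):
        # Shuffle the distinct group keys reproducibly, then reserve a share for validation
        groups = sorted({group_of(o) for o in items})
        rng = random.Random(seed)
        rng.shuffle(groups)
        n_valid = max(1, int(len(groups) * valid_pct))
        valid_groups = set(groups[:n_valid])
        train_idx = [i for i, o in enumerate(items) if group_of(o) not in valid_groups]
        valid_idx = [i for i, o in enumerate(items) if group_of(o) in valid_groups]
        return train_idx, valid_idx
    return _split

# Hypothetical naming scheme: '<session>_<frame>.png', e.g. 's01_0001.png'
def session_of(p): return Path(p).stem.split('_')[0]
```

A `DataBlock` splitter is a callable that takes the items and returns train and validation indices, so `splitter=make_group_splitter(session_of)` should be able to replace `RandomSplitter()`. Treat that as untested against this dataset.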

In [None]:
data_block = DataBlock(blocks=(ImageBlock(cls=PILImage), CategoryBlock),
                       get_items=get_image_files,
                       splitter=RandomSplitter(),
                       get_y=parent_label,
                       batch_tfms=[Normalize.from_stats(**stats_bgr_to_rgb(stats))] + aug_transforms())
dls=data_block.dataloaders(path, bs=256)
dls.vocab

In [None]:
dls.show_batch(max_n=8,figsize=(14,4))

In [None]:
learn=cnn_learner(dls, resnet18, metrics=[accuracy], wd=1e-3) # newer fastai releases call this vision_learner

In [None]:
learn.lr_find()

In [None]:
learn.freeze()
base_lr = 1e-3 # <- usually fine, but adjust it if lr_find suggests a different value
learn.fit_one_cycle(10, base_lr)

In [None]:
learn.recorder.plot_loss()

In [None]:
pre_unfreeze_onnx_path, x, torch_out = onnx_export()

In [None]:
learn.unfreeze()
learn.fit_one_cycle(5, slice(base_lr/100, base_lr))

In [None]:
learn.recorder.plot_loss()

In [None]:
onnx_path, x, torch_out = onnx_export() # this is probably the one we want to keep

## Download resnet_3_256_256.onnx to your machine ...

... before the colab session times out

I also download the notebook as a record of how the model was trained - not the most robust experiment tracking (o:
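As an alternative to a manual download, the export directory could be copied to the mounted Drive so that it outlives the session. This is a minimal sketch using the standard library; `backup_dir` is a hypothetical helper, and the destination folder is whatever location on Drive suits you.

```python
import shutil
from pathlib import Path

def backup_dir(src_dir, dest_root):
    "Hypothetical helper: copy `src_dir` (and its contents) into `dest_root`, returning the new path"
    src_dir, dest_root = Path(src_dir), Path(dest_root)
    dest_root.mkdir(parents=True, exist_ok=True)
    dest = dest_root/src_dir.name
    shutil.copytree(src_dir, dest)  # raises FileExistsError if dest already exists
    return dest
```

For example, `backup_dir(onnx_path.parent, source_path/'models')` would copy the timestamped model directory next to the dataset on Drive; the `models` folder name is an assumption.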

In [None]:
interp = ClassificationInterpretation.from_learner(learn)
interp.print_classification_report()

In [None]:
interp.plot_confusion_matrix()

## Check the onnx model

In [None]:
!pip install onnx onnxruntime

In [None]:
import onnx, onnxruntime
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

In the cell below, `x` and `torch_out` were returned by a previous call to `onnx_export()`

In [None]:
ort_session = onnxruntime.InferenceSession(str(onnx_path))

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
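For reference, `np.testing.assert_allclose(actual, desired, rtol, atol)` passes when `abs(actual - desired) <= atol + rtol * abs(desired)` holds for every element. The pure-Python sketch below mirrors that criterion on flat lists; `within_tolerance` is a hypothetical name, and the probabilities are made up for illustration.

```python
def within_tolerance(actual, desired, rtol=1e-3, atol=1e-5):
    "Element-wise check mirroring the documented assert_allclose criterion"
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))

# Made-up softmax outputs: the first pair differs by 1e-4, the second by 1e-2
small_drift = within_tolerance([0.7001, 0.2999], [0.7000, 0.3000])
large_drift = within_tolerance([0.71, 0.29], [0.70, 0.30])
print(small_drift, large_drift)
```

With the tolerances used above, differences in the fourth decimal place of a probability pass, while differences in the second decimal place fail.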

## Terminate your colab session (o: just not before your model download has finished