In [None]:
#|default_exp callback.compiler

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# PyTorch Compile [beta]
> Experimental callbacks and patches to integrate `torch.compile` into fastai

The `CompilerCallback` and `DynamoExplainCallback` are experiments to provide an easy to use `torch.compile` integration for fastai.

`torch.compile` with the default inductor backend [allows 30% to 2x speedups and 10% memory compression](https://github.com/pytorch/pytorch/issues/93794) for both training and inference.

For more information on `torch.compile` please read *[PyTorch's getting started](https://pytorch.org/docs/master/compile/get-started.html)* guide. For troubleshooting `torch.compile` refer to this [PyTorch Nightly guide](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas).

This module is not included in any of the fastxtend `all` imports. You must import it separately after importing fastai and fastxtend:

```python
from fastxtend.callback import compiler
# or
from fastxtend.callback.compiler import *
```

In [None]:
#|export
from __future__ import annotations
from typing import Dict

from enum import Enum
from pathlib import Path
import pickle
import warnings

from packaging.version import parse

import torch._dynamo as dynamo
from torch.serialization import FILE_LIKE

from fastai.learner import Learner, save_model, join_path_file, _cast_tensor
from fastai.callback import schedule
from fastai.callback.core import Callback, TrainEvalCallback, CancelFitException
from fastai.callback.fp16 import MixedPrecision

try:
    from fastxtend.ffcv.loader import Loader
    FFCV = True
except ImportError:
    FFCV = False

from fastxtend.imports import *

In [None]:
#|hide
from fastai.vision.learner import vision_learner

In [None]:
#|exporti
# Parse the installed PyTorch version once at import; the callbacks and patches below compare against these
_torch_version = parse(torch.__version__)
_torch_20 = parse('2.0')
_torch_21 = parse('2.1')

# Importing on PyTorch older than 2.0 only warns here; it does not raise
if _torch_version < _torch_20:
    warn('Imported `fastxtend.callback.compiler`, which requires a minimum of PyTorch 2.0 to work.')

In [None]:
#|export
class CompileMode(str, Enum):
    "Enumerates the `torch.compile` `mode` strings so they can be tab-completed instead of typed by hand"
    # Member names drop the hyphen because it is not valid in a Python identifier
    default = 'default'
    reduceoverhead = 'reduce-overhead'
    maxautotune = 'max-autotune'

Currently, the 'reduce-overhead' mode doesn't appear to train (the loss stagnates), and 'max-autotune' shouldn't be used per *[Compile troubleshooting and gotchas](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas)*.

In [None]:
#|export
class MatMulPrecision(str, Enum):
    "Enumerates the `matmul_precision` strings so they can be tab-completed instead of typed by hand"
    # Each member's value is the plain string of the same name
    highest = 'highest'
    high = 'high'
    medium = 'medium'

## CompilerCallback -

In [None]:
#|export
class CompilerCallback(Callback):
    "An experimental callback for `torch.compile` (beta) and fastai"
    order = TrainEvalCallback.order + 1 # Compiling needs to occur on the GPU, but before distributed training starts

    def __init__(self,
        fullgraph:bool=False, # Prevent breaking model into subgraphs
        dynamic:bool=False, # Use dynamic shape tracing
        backend:str|Callable='inductor', # `torch.compile` backend to use
        mode:str|CompileMode|None=None, # `torch.compile` mode to use
        options:Dict[str, Union[str,int,bool]]|None=None, # Extra options to pass to compile backend
        matmul_precision:str|MatMulPrecision='high', # Set Ampere and newer TF32 matmul precision
        recompile:bool=False, # Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
        verbose:bool=True, # Verbose output
    ):
        # Accept the enum helpers as well as plain strings; only the string values are stored.
        if isinstance(mode, CompileMode):
            mode = mode.value
        # Check `matmul_precision` itself (not `mode`) so an enum precision is converted to its string
        if isinstance(matmul_precision, MatMulPrecision):
            matmul_precision = matmul_precision.value
        # Raises `ValueError` when both `mode` and `options` are supplied
        if mode is not None and options is not None:
            raise ValueError(f"Both {mode=} or {options=} cannot be set at the same time.")
        # `store_attr` must be called directly from `__init__`; it stores every argument except `recompile`
        store_attr(but='recompile')
        # Stored privately: `before_fit` consumes and clears this flag, and `Learner.freeze_to` sets it again
        self._recompile = recompile

    def before_fit(self):
        "Set matmul precision, emit compatibility warnings, and compile `learn.model` if it isn't compiled yet"
        if _torch_version < _torch_20:
            self.run = False
            warn("Attempting to use `CompilerCallback` without PyTorch 2.0 or greater. Disabling.")
            return

        # Only query the device capability when CUDA is present; `get_device_capability` raises otherwise.
        # Capability (8, 0) and newer is treated as having TF32 tensor cores.
        if (torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
            and torch.get_float32_matmul_precision() != self.matmul_precision):
            if self.verbose and self.matmul_precision!='highest':
                print(f"Your GPU has modern tensor cores, automatically enabling by setting `torch.set_float32_matmul_precision('{self.matmul_precision}')`")
            torch.set_float32_matmul_precision(self.matmul_precision)

        if hasattr(self.learn, 'progressive_resize') and _torch_version < _torch_21:
            warn("Using `ProgressiveResize` and `torch.compile` at the same time will result in a new compile every size change.")
        # Collect the "under active development" warnings so they are emitted as a single message
        msg = ""
        if self.dynamic:
            msg += "Using `torch.compile` with dynamic shapes is under active development and might fail\n"
        if self.mode == 'max-autotune':
            msg += "Using `torch.compile` with `mode='max-autotune'` is under active development and might fail\n"
        if msg != "":
            msg += "See https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas for more details"
            warn(msg)
        if self.mode == 'reduce-overhead':
            warn("Using `torch.compile` & fastai with `mode='reduce-overhead'` currently doesn't appear to train.")

        # A forced recompile resets dynamo and unwraps the model so the block below compiles it again
        if self._recompile and isinstance(self.learn.model, dynamo.OptimizedModule):
            if self.verbose:
                print("Recompiling model")
            dynamo.reset()
            self.learn.model = self.learn.model._orig_mod
        # The flag is one-shot: it is cleared whether or not a recompile happened
        self._recompile = False

        # Already-compiled models are left untouched
        if not isinstance(self.learn.model, dynamo.OptimizedModule):
            self.learn.model = torch.compile(self.learn.model, fullgraph=self.fullgraph,
                                             dynamic=self.dynamic, backend=self.backend,
                                             mode=self.mode, options=self.options)

Using `torch.compile` with dynamic shapes and `mode='max-autotune'` is under active development and might fail. See *[Compile troubleshooting and gotchas](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas)* for more details.

By default, <code>CompilerCallback</code> will set matmul ops to use TensorFloat32 for supported GPUs, which is the recommended setting for `torch.compile`. Set `matmul_precision='highest'` to turn off or `matmul_precision='medium'` to enable `bfloat16` mode.

fastxtend provides the [compile](#learner.compile) convenience method for easily enabling `torch.compile`. Or you can pass <code>CompilerCallback</code> to the `cbs` argument of the `fastai.learner.Learner` or a fit method.

```python
learn = Learner(..., cbs=CompilerCallback())
learn.fine_tune(1)
```

## DynamoExplainCallback -

In [None]:
#|export
class DynamoExplainCallback(Callback):
    "An experimental callback to find graph breaks with `torch.compile` (beta)"
    order = MixedPrecision.order+1 # DynamoExplain occurs on the GPU before any training starts

    def __init__(self,
        print_results:bool=True, # Print enabled `torch._dynamo.explain` output(s)
        explanation:bool=True, # Print the `explanation` output
        out_guards:bool=False, # Print the `out_guards` output
        graphs:bool=False, # Print the `graphs` output
        ops_per_graph:bool=False, # Print the `ops_per_graph` output
        break_reasons:bool=False, # Print the `break_reasons` output
        explanation_verbose:bool=False, # Print the `explanation_verbose` output
    ):
        # The print switches are stored under `print_*` names because the unprefixed
        # attribute names are reused in `before_fit` to hold the explain outputs themselves.
        self.print_results = print_results
        self.print_explanation = explanation
        self.print_out_guards = out_guards
        self.print_graphs = graphs
        self.print_ops_per_graph = ops_per_graph
        self.print_break_reasons = break_reasons
        self.print_explanation_verbose = explanation_verbose

    def before_fit(self):
        "Run `torch._dynamo.explain` on one batch, store and optionally print the outputs, then cancel the fit"
        if _torch_version < _torch_20:
            self.run = False
            warn("Attempting to use `DynamoExplainCallback` without PyTorch 2.0 or greater. Canceling training.")
            raise CancelFitException()

        # Start from empty outputs so the attributes exist even if `explain` fails part way
        self.explanation, self.out_guards, self.graphs, self.ops_per_graph, self.break_reasons, self.explanation_verbose = '','','','','',''
        # Snapshot the random states so drawing a batch here is undone in the `finally` below
        states = get_random_states()
        try:
            if FFCV and isinstance(self.dls.train, Loader) and self.dls.train.async_tfms:
                # With `async_tfms`, `Loader` needs to initialize all `Loader.batches_ahead` Cuda streams
                # for the training dataloader. Since FFCV doesn't support seeded transforms and the reset
                # random state only seeds the dataset order, this shouldn't affect training outcome.
                b = self.dls.train.one_batch(batches_ahead=True)
            else:
                b = self.dls.valid.one_batch()
            # Split the batch into inputs and targets: use `dls.n_inp` when present, otherwise
            # treat every item but the last as an input (or the only item, for a one-item batch)
            i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
            self.learn.xb, self.learn.yb = b[:i], b[i:]

            # Enter the `MixedPrecision` autocast context when that callback is attached, and
            # always exit it again, even if `explain` raises, so autocast isn't left enabled.
            use_autocast = hasattr(self.learn, 'mixed_precision')
            if use_autocast:
                self.learn.mixed_precision.autocast.__enter__()
            try:
                # NOTE(review): unpacks six outputs from `dynamo.explain(model, *inputs)`; later
                # PyTorch releases may return a different object from `explain` - confirm before upgrading.
                self.explanation, self.out_guards, self.graphs, self.ops_per_graph, self.break_reasons, self.explanation_verbose \
                    = dynamo.explain(self.learn.model, *_cast_tensor(self.learn.xb))
            finally:
                if use_autocast:
                    self.learn.mixed_precision.autocast.__exit__(None, None, None)

            self.learn.opt.zero_grad()
        finally:
            set_random_states(**states)

        if self.print_results:
            print('\nDynamo Explain Report')
            # (enabled, heading, stored output) for each section, in print order
            sections = (
                (self.print_explanation,         'Explanation',          self.explanation),
                (self.print_out_guards,          'Out Guards',           self.out_guards),
                (self.print_graphs,              'Graphs',               self.graphs),
                (self.print_ops_per_graph,       'Operations per Graph', self.ops_per_graph),
                (self.print_break_reasons,       'Break Reasons',        self.break_reasons),
                (self.print_explanation_verbose, 'Verbose Explanation',  self.explanation_verbose),
            )
            for enabled, heading, output in sections:
                if enabled:
                    print(f'\n{heading}:\n')
                    print(output)
            print('\n')

        # The callback only reports; it always cancels the fit once the outputs are stored
        raise CancelFitException()

`DynamoExplainCallback` automates finding graph breaks using `torch._dynamo.explain` per the *[Identifying the cause of a graph break](https://pytorch.org/docs/stable/dynamo/faq.html#identifying-the-cause-of-a-graph-break)* section in the *[PyTorch Compile FAQ](https://pytorch.org/docs/stable/dynamo/faq.html)*. <code>DynamoExplainCallback</code> uses one batch from the validation dataloader[^Loader] to generate the `_dynamo.explain` report(s) and then cancels training.

To use, pass <code>DynamoExplainCallback</code> to the `cbs` argument of the `fastai.learner.Learner` or fit method.

```python
learn = Learner(..., cbs=DynamoExplainCallback())
learn.fit(1)
```

By default, <code>DynamoExplainCallback</code> prints the `explanation` output from `_dynamo.explain`, with arguments to enable printing `out_guards`, `graphs`, `ops_per_graph`, `break_reasons`, and/or `explanation_verbose`.

All `_dynamo.explain` outputs are stored as attributes in the callback for later reference. For example, to view the verbose explanation after running `Learner` with <code>DynamoExplainCallback</code>:

```python
print(learn.dynamo_explain.explanation_verbose)
```

[^Loader]: Unless using the FFCV `Loader`, then it uses the training dataloader. This doesn't affect seeded training as FFCV dataloaders do not seed transforms, only dataset order.

## Convenience Method

fastxtend adds a convenience method to `Learner` to easily enable `torch.compile`.

In [None]:
#|export
@patch
def compile(self:Learner,
    fullgraph:bool=False, # Prevent breaking model into subgraphs
    backend:str|Callable='inductor', # `torch.compile` backend to use
    mode:str|CompileMode|None=None, # `torch.compile` mode to use
    options:Dict[str, Union[str,int,bool]]|None=None, # Extra options to pass to compile backend
    matmul_precision:str|MatMulPrecision='high', # Set Ampere and newer TF32 matmul precision
    recompile:bool=False, # Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
    verbose:bool=True, # Verbose output
):
    "Add a `CompilerCallback` to `Learner` so the model is compiled with `torch.compile`."
    # `dynamic` is not forwarded, so the callback keeps its own default for it
    compiler_cb = CompilerCallback(
        fullgraph=fullgraph,
        backend=backend,
        mode=mode,
        options=options,
        matmul_precision=matmul_precision,
        recompile=recompile,
        verbose=verbose,
    )
    # Return whatever `add_cb` returns, so the `Learner(...).compile()` chaining keeps working
    return self.add_cb(compiler_cb)

`compile` does not expose `dynamic` since it's recommended not to be used with PyTorch 2.0. You can set it directly via `CompilerCallback`.

To use, call the `compile` method after initializing a `fastai.learner.Learner`.

```python
learn = Learner(...).compile()
learn.fine_tune(1)
```

## Compatibility Patches

These patches integrate `torch.compile` with fastai saving, loading, freezing, unfreezing, and fine tuning.

### Saving and Exporting

In [None]:
#|export
@patch
@delegates(save_model)
def save(self:Learner,
    file:FILE_LIKE, # Save file name, path, bytes, or IO
    save_compiled:bool=False, # Save compiled model
    **kwargs
):
    "Save the model, and the optimizer state if `with_opt`, to `self.path/self.model_dir/file`; returns the resolved file"
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    model = self.model
    # Unless `save_compiled` is set, save the wrapped original module rather than the compiled wrapper
    if _torch_version >= _torch_20 and isinstance(model, dynamo.OptimizedModule) and not save_compiled:
        model = model._orig_mod
    save_model(file, model, getattr(self,'opt',None), **kwargs)
    return file

Saving a compiled model is supported, but for maximum compatibility is turned off by default. Set `save_compiled=True` to save a compiled model.

In [None]:
#|export
@patch
def export(self:Learner,
    fname:FILE_LIKE='export.pkl', # Learner export file name, path, bytes, or IO
    pickle_module:Any=pickle, # Module used for pickling metadata and objects
    pickle_protocol:int=2 # Pickle protocol used
):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if child proc
    self._end_cleanup()
    old_dbunch = self.dls
    self.dls = self.dls.new_empty()
    state = self.opt.state_dict() if self.opt is not None else None
    self.opt = None
    # torch.compiled models currently cannot be pickled.
    if _torch_version >= _torch_20 and isinstance(self.model, dynamo.OptimizedModule):
        self.model = self.model._orig_mod
    with warnings.catch_warnings():
        # To avoid the warning that come from PyTorch about model not being checked
        warnings.simplefilter("ignore")
        torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    self.create_opt()
    if state is not None: self.opt.load_state_dict(state)
    self.dls = old_dbunch

As of PyTorch 2.0 and 2.1 Nightly, compiled models cannot be pickled, so `export` sets `Learner.model` as the non-compiled model.

In [None]:
#|export
def load_learner(
    fname:FILE_LIKE, # File name, path, bytes, or IO
    cpu:bool=True, # Load model to CPU
    pickle_module=pickle # Module used for unpickling metadata and objects
):
    "Load a `Learner` object in `fname`, by default putting it on the `cpu`"
    distrib_barrier()
    map_loc = 'cpu' if cpu else default_device()
    # NOTE(security): `torch.load` unpickles `fname` with `pickle_module`; only load exports from trusted sources.
    try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
    except AttributeError as e:
        # Re-raise the same exception with a hint prepended to the original message
        e.args = [f"Custom classes or functions exported with your `Learner` not available in namespace.\nRe-declare/import before loading:\n\t{e.args[0]}"]
        raise
    if cpu:
        res.dls.cpu()
        # Only the first matching conversion runs; the attributes are checked in this order
        if hasattr(res, 'channels_last'): res = res.to_contiguous(to_fp32=True)
        elif hasattr(res, 'mixed_precision'): res = res.to_fp32()
        elif hasattr(res, 'non_native_mixed_precision'): res = res.to_non_native_fp32()
        # The compiler callback is only removed on this CPU path; with `cpu=False` it is kept
        if hasattr(res, 'compiler'): res = res.remove_cb(CompilerCallback)
    return res

By default (`cpu=True`), `load_learner` will remove the `CompilerCallback`. When loading with `cpu=False`, the callback is kept.

### Freezing and Unfreezing

In [None]:
#|export
@patch
def freeze_to(self:Learner, n:int):
    "Freeze parameter groups up to `n`"
    if self.opt is None:
        self.create_opt()
    self.opt.freeze_to(n)
    self.opt.clear_state()
    if _torch_version >= _torch_20 and isinstance(self.model, dynamo.OptimizedModule):
        if hasattr(self, 'compiler'):
            self.compiler._recompile = True
        else:
            warn("Freezing or unfreezing a compiled model isn't supported."\
                 "\nThe model must be recompiled to take effect."\
                 "\nPass `CompilerCallback(..., recompile=True)` to `Learner.cbs`"\
                 "\nor call `torch._dynamo.reset() and recompile model.")

Freezing and unfreezing models works, but they need to be recompiled after. `freeze_to` will set `CompilerCallback` to recompile the model or warn users they need to manually recompile.

### Training

In [None]:
#|export
@patch
@delegates(Learner.fit_one_cycle)
def fine_tune(self:Learner,
    epochs:int, # Number of unfrozen epochs to train
    base_lr:float=2e-3, # Base learning rate, model head unfrozen learning rate
    freeze_epochs:int=1, # Number of frozen epochs to train
    lr_mult:Numeric=100, # Model stem unfrozen learning rate: `base_lr/lr_mult`
    pct_start:float=0.3, # Start unfrozen learning rate cosine annealing
    div:Numeric=5.0, # Initial unfrozen learning rate: `base_lr/div`
    freeze_compile:bool=False, # Compile the model for the frozen epochs instead of waiting for the unfrozen epochs
    **kwargs
):
    "Train frozen for `freeze_epochs`, then unfrozen for `epochs` with discriminative LR, compiling only the unfrozen stage by default."
    self.freeze()
    has_compiler = _torch_version >= _torch_20 and hasattr(self, 'compiler')
    if has_compiler and not freeze_compile:
        # Let `CompilerCallback` run during the frozen fit only if the model is already compiled;
        # an uncompiled model therefore stays uncompiled for the frozen epochs.
        self.compiler.run = isinstance(self.model, dynamo.OptimizedModule)
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    # The unfrozen stage trains at half the base learning rate
    base_lr /= 2
    self.unfreeze()
    unfrozen_lr = slice(base_lr/lr_mult, base_lr)
    self.fit_one_cycle(epochs, unfrozen_lr, pct_start=pct_start, div=div, **kwargs)

By default, `fine_tune` will not compile the `freeze_epochs`, but this can be overridden by passing `freeze_compile=True`. If the model is already compiled, this will have no effect.

## Testing -

In [None]:
#|hide
#|slow
from packaging.version import parse
import fastai

from fastcore.basics import num_cpus

if parse(fastai.__version__) < parse('2.7.11'):
    from fastxtend.callback.channelslast import *
else:
    from fastai.callback.channelslast import *
from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.models import resnet34

from fastxtend.metrics import Accuracy
from fastxtend.optimizer.fused import adam
from fastxtend.utils import *

In [None]:
#|hide
#|slow
warnings.simplefilter("ignore")

In [None]:
#|hide
#|slow
#|cuda
# Train a resnet34 from scratch on Imagenette with channels last and `Learner.compile`
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                        splitter=GrandparentSplitter(valid_name='val'),
                        get_items=get_image_files, get_y=parent_label,
                        item_tfms=Resize(128),
                        batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    # `.compile()` adds the `CompilerCallback`, which compiles the model when the fit starts
    learn = Learner(dls, resnet34(num_classes=dls.c), opt_func=adam(foreach=True),
                    metrics=Accuracy()).to_channelslast().compile()
    learn.fit_one_cycle(5, 1e-3)

Your GPU has modern tensor cores, automatically enabling by setting `torch.set_float32_matmul_precision('high')`


epoch,train_loss,valid_loss,accuracy,time
0,1.880623,2.615195,0.249427,00:27
1,1.511329,1.463864,0.526369,00:06
2,1.208427,1.33783,0.569427,00:06
3,0.993035,0.886476,0.717197,00:06
4,0.818099,0.817872,0.735032,00:06


In [None]:
#|hide
#|slow
#|cuda
# Presumably releases the GPU memory held by `learn` and `dls` before the next test - confirm against `fastxtend.utils`
free_gpu_memory(learn, dls)

In [None]:
#|hide
#|slow
#|cuda
# Reference run: `fine_tune` without `compile`, for comparison with the compiled run in the next test
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                        splitter=GrandparentSplitter(valid_name='val'),
                        get_items=get_image_files, get_y=parent_label,
                        item_tfms=Resize(128),
                        batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                           metrics=Accuracy()).to_channelslast()
    learn.fine_tune(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.754579,0.174962,0.945987,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.245339,0.129721,0.962293,00:06
1,0.183843,0.153452,0.951847,00:06
2,0.127661,0.117478,0.960509,00:06
3,0.08312,0.106381,0.96586,00:06
4,0.055219,0.101295,0.969936,00:06


In [None]:
#|hide
#|slow
#|cuda
# Presumably releases the GPU memory held by `learn` and `dls` before the next test - confirm against `fastxtend.utils`
free_gpu_memory(learn, dls)

In [None]:
#|hide
#|slow
#|cuda
# Same setup as the uncompiled `fine_tune` test, but with `.compile()` added to the learner
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                        splitter=GrandparentSplitter(valid_name='val'),
                        get_items=get_image_files, get_y=parent_label,
                        item_tfms=Resize(128),
                        batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                    metrics=Accuracy()).to_channelslast().compile()
    # The patched `fine_tune` skips compiling for the frozen epoch by default (`freeze_compile=False`)
    learn.fine_tune(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.754579,0.174962,0.945987,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.241853,0.124922,0.96,00:26
1,0.190665,0.142393,0.95414,00:06
2,0.12772,0.114811,0.963567,00:06
3,0.082946,0.100213,0.967643,00:06
4,0.057766,0.098956,0.967389,00:06


In [None]:
#|hide
#|slow
#|cuda
# Validation result of the compiled learner; the saved and exported learners below are validated the same way
learn.validate()

(#2) [0.09895562380552292,0.9673885107040405]

In [None]:
#|hide
#|slow
#|cuda
# Uses the patched `Learner.save`; with the default `save_compiled=False` the uncompiled module is saved
learn.save('test')

Path('models/test.pth')

In [None]:
#|hide
#|slow
#|cuda
# Uses the patched `Learner.export`, writing to the default `export.pkl`
learn.export()

In [None]:
#|hide
#|slow
#|cuda
# Load the weights saved from the compiled learner into a fresh, uncompiled learner and validate it
saved = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                       metrics=Accuracy()).to_channelslast() 
saved.load('test')
saved.validate()

(#2) [0.09896746277809143,0.9673885107040405]

In [None]:
#|hide
#|slow
#|cuda
# `export` replaced the dataloaders with an empty copy, so reattach the real ones before validating
exported = load_learner('export.pkl')
exported.dls = learn.dls

In [None]:
#|hide
#|slow
#|cuda
# Validate the learner reloaded from `export.pkl`
exported.validate()

(#2) [0.098935067653656,0.9673885107040405]

In [None]:
#|hide
#|slow
#|cuda
# Run `DynamoExplainCallback` with its default settings, which print only the `explanation` output
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                        splitter=GrandparentSplitter(valid_name='val'),
                        get_items=get_image_files, get_y=parent_label,
                        item_tfms=Resize(128),
                        batch_tfms=[Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True), metrics=Accuracy(),
                           cbs=DynamoExplainCallback()).to_channelslast()
    # The callback raises `CancelFitException` in `before_fit`, so no training happens here
    learn.fit(1)





Dynamo Explain Report


Explanation:


Dynamo produced 1 graphs with 0 graph break and 17 ops


In [None]:
#|hide
#|slow
#|cuda
# Outputs that were not printed in the report are still stored as attributes on the callback
print(learn.dynamo_explain.ops_per_graph)

[[<built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in function iadd>, <built-in method cat of type object at 0x7fb991162540>]]
