In [None]:
#|default_exp callback.compiler

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai
# PyTorch - PyTorch BSD-style license - Copyright (c) 2013-present PyTorch contributors

# PyTorch Compile
> Callbacks and patches to integrate `torch.compile` into fastai

The `CompilerCallback` and `DynamoExplainCallback` provide an easy to use `torch.compile` integration for fastai.

For more information on `torch.compile` please read *[PyTorch's getting started](https://pytorch.org/docs/master/compile/get-started.html)* guide. For troubleshooting `torch.compile` refer to this [PyTorch Nightly guide](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas).

This module is not imported via any fastxtend all imports. You must import it separately after importing fastai and fastxtend as it modifies model saving, loading, and training:

```python
from fastxtend.callback import compiler
# or
from fastxtend.callback.compiler import *
```

To use, create a `fastai.learner.Learner` with a `torch.compile` compatible model and call [`compile`](#learner.compile) on the `Learner`, or pass `CompilerCallback` to the `Learner` or fit method callbacks.
```python
Learner(...).compile()
# or
Learner(..., cbs=CompilerCallback())
```

In [None]:
#|export
from __future__ import annotations
from typing import Dict

from copy import deepcopy
from enum import Enum
from pathlib import Path
import pickle
import warnings

from packaging.version import parse

import torch._dynamo as dynamo
from torch.serialization import FILE_LIKE

from fastai.learner import Learner, save_model, join_path_file, _cast_tensor
from fastai.callback import schedule
from fastai.callback.core import Callback, TrainEvalCallback, CancelFitException

import fastai
if parse(fastai.__version__) < parse('2.7.13'):
    from fastxtend.callback.amp import MixedPrecision
else:
    from fastai.callback.fp16 import MixedPrecision

try:
    from fastxtend.ffcv.loader import Loader
    FFCV = True
except ImportError:
    FFCV = False

from fastxtend.imports import *

In [None]:
#|hide
from fastai.vision.learner import vision_learner

In [None]:
#|hide
warnings.simplefilter('ignore')

In [None]:
#|exporti
# PyTorch version gates, evaluated once at import and reused by the callbacks and patches in this module.
_min_torch_20 = ismin_torch('2.0')
# True only when the 2.0 minimum is met and `notmax_torch('2.1.0')` also holds
_only_torch_20 = _min_torch_20 and notmax_torch('2.1.0')
_min_torch_21 = ismin_torch('2.1.0')

# Importing on an unsupported PyTorch is allowed, but the user is told up front that nothing will work
if not _min_torch_20:
    warn('Imported `fastxtend.callback.compiler`, which requires a minimum of PyTorch 2.0 to work.')

In [None]:
#|exporti
def is_compiled(model:nn.Module):
    "Return whether `model` has a compiled call implementation set or is a Dynamo `OptimizedModule`"
    # A missing `_compiled_call_impl` and one set to `None` are treated the same: not compiled in place
    if getattr(model, '_compiled_call_impl', None) is not None:
        return True
    # Otherwise the model only counts as compiled if it is the wrapper class itself
    return isinstance(model, dynamo.OptimizedModule)

In [None]:
#|export
class CompileMode(str, Enum):
    "All valid `torch.compile` modes for tab-completion and typo-proofing"
    default        = 'default'
    reduceoverhead = 'reduce-overhead'
    maxautotune    = 'max-autotune'
    # Accepted by `torch.compile` from PyTorch 2.1; without this member `CompileMode(mode)` rejects it
    maxautotunenocudagraphs = 'max-autotune-no-cudagraphs'

Currently, the 'reduce-overhead' mode doesn't appear to work with all models, and 'max-autotune' shouldn't be used per *[Compile troubleshooting and gotchas](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas)*.

In [None]:
#|export
class MatMulPrecision(str, Enum):
    "Float32 `matmul_precision` settings as a string enum, for tab-completion and typo-proofing"
    # Members are plain strings, so `.value` can be handed straight to `torch.set_float32_matmul_precision`
    highest = 'highest'
    high = 'high'
    medium = 'medium'

## CompilerCallback -

In [None]:
#|export
class CompilerCallback(Callback):
    "A callback for using `torch.compile` (beta) with fastai"
    order = TrainEvalCallback.order + 1 # Compiling needs to occur on the GPU, but before distributed training starts

    def __init__(self,
        fullgraph:bool=False, # Prevent breaking model into subgraphs
        dynamic:bool=False, # Use dynamic shape tracing
        backend:str|Callable='inductor', # `torch.compile` backend to use
        mode:str|CompileMode|None=None, # `torch.compile` mode to use
        options:Dict[str, Union[str,int,bool]]|None=None, # Extra options to pass to compile backend
        matmul_precision:str|MatMulPrecision='high', # Set Ampere and newer matmul precision
        recompile:bool=False, # Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
        verbose:bool=False, # Verbose output
    ):
        # Round-trip through the enums so an invalid string raises `ValueError` here, not during training
        mode = CompileMode(mode).value if mode is not None else mode
        matmul_precision = MatMulPrecision(matmul_precision).value
        if mode is not None and options is not None:
            raise ValueError(f"Both {mode=} or {options=} cannot be set at the same time.")
        store_attr(but='recompile')
        # Kept private: `before_fit` consumes and clears it, and `Learner.freeze_to` sets it again
        self._recompile = recompile

    def before_fit(self):
        "Set matmul precision, emit compatibility warnings, and compile (or recompile) the model"
        if not _min_torch_20:
            self.run = False
            warn("Attempting to use `CompilerCallback` without PyTorch 2.0 or greater. Disabling.")
            return

        # `get_device_capability` raises without a usable CUDA device, so check availability first
        if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) \
           and torch.get_float32_matmul_precision() != self.matmul_precision:
            if self.verbose and self.matmul_precision != 'highest':
                print(f"Your GPU has modern tensor cores, automatically enabling by setting `torch.set_float32_matmul_precision('{self.matmul_precision}')`")
            torch.set_float32_matmul_precision(self.matmul_precision)

        if hasattr(self.learn, 'progressive_resize') and _only_torch_20:
            warn("Using `ProgressiveResize` and `torch.compile` at the same time with PyTorch 2.0 will result in a new compile every size change.")
        msg = ""
        if self.dynamic and _only_torch_20:
            msg += "Using PyTorch 2.0 `compile` with dynamic shapes is under active development and might fail. Upgrade to PyTorch 2.1.\n"
        if self.mode == 'max-autotune':
            msg += "Using `torch.compile` with `mode='max-autotune'` is under active development and might fail\n"
        if msg != "":
            msg += "See https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas for more details"
            warn(msg)
        if self.mode == 'reduce-overhead':
            warn("Using `torch.compile` & fastai with `mode='reduce-overhead'` currently doesn't appear to train.")

        # A pending recompile request undoes the existing compile so the block below compiles afresh
        if self._recompile and is_compiled(self.learn.model):
            if self.verbose:
                print("Recompiling model")
            self._reset_compiled()
        self._recompile = False

        if not is_compiled(self.learn.model):
            if not isinstance(self.learn.model, nn.Module):
                warn("Model is not a `nn.Module`. `torch.compile` may not work as expected.")
            if hasattr(self.learn.model, 'compile'):
                # The model exposes its own `compile` method, so compile it in place
                self.learn.model.compile(fullgraph=self.fullgraph, dynamic=self.dynamic,
                                         backend=self.backend, mode=self.mode, options=self.options)
            else:
                # No `compile` method: compile via `torch.compile`, then keep the original module
                # and swap only its `forward`, so `learn.model` stays the unwrapped module
                compiled_model = torch.compile(self.learn.model, fullgraph=self.fullgraph,
                                               dynamic=self.dynamic, backend=self.backend,
                                               mode=self.mode, options=self.options)
                self.learn.model = compiled_model._orig_mod
                self.learn.model._orig_forward = self.learn.model.forward
                self.learn.model.forward = compiled_model.dynamo_ctx(self.learn.model.forward)
                # Marker read by `is_compiled`
                self.learn.model._compiled_call_impl = True
        else:
            warn("Model is already compiled. To recompile pass `recompile=True` to CompilerCallback.")

    def _reset_compiled(self):
        "Undo a previous compile: reset Dynamo, clear the compiled marker, and restore the original forward/module"
        if is_compiled(self.learn.model):
            dynamo.reset()
            self.learn.model._compiled_call_impl = None
            if hasattr(self.learn.model, '_orig_forward'):
                self.learn.model.forward = self.learn.model._orig_forward
            if isinstance(self.learn.model, dynamo.OptimizedModule):
                self.learn.model = self.learn.model._orig_mod

Using `torch.compile` with `mode='max-autotune'` is under active development and might fail. See *[Compile troubleshooting and gotchas](https://pytorch.org/docs/master/compile/index.html#troubleshooting-and-gotchas)* for more details.

By default, <code>CompilerCallback</code> will set matmul ops to use TensorFloat32 for supported GPUs, which is the recommended setting for `torch.compile`. Set `matmul_precision='highest'` to turn off or `matmul_precision='medium'` to enable `bfloat16` mode.

fastxtend provides the [`compile`](#learner.compile) convenience method for easily enabling `torch.compile`. Or you can pass <code>CompilerCallback</code> to the `cbs` argument of the `fastai.learner.Learner` or a fit method.

```python
learn = Learner(..., cbs=CompilerCallback())
learn.fine_tune(1)
```

## DynamoExplainCallback -

In [None]:
#|export
class DynamoExplainCallback(Callback):
    "A callback to automate finding graph breaks with PyTorch Compile's Dynamo Explain"
    order = MixedPrecision.order+1 # DynamoExplain occurs on the GPU before any training starts

    def __init__(self,
        print_results:bool=True, # Print enabled `torch._dynamo.explain` output(s)
        out_guards:bool=False, # Print the `out_guards` output
        ops_per_graph:bool=False, # Print the `ops_per_graph` output
        break_reasons:bool=False, # Print the `break_reasons` output
    ):
        self.print_results = print_results
        self.print_out_guards = out_guards
        self.print_ops_per_graph = ops_per_graph
        # NOTE(review): stored, but not read anywhere in this class
        self.print_break_reasons = break_reasons

    def before_fit(self):
        "Run `dynamo.explain` on one batch, optionally print the report, then cancel the fit"
        if not _min_torch_20:
            self.run = False
            warn("Attempting to use `DynamoExplainCallback` without PyTorch 2.0 or greater. Canceling training.")
            raise CancelFitException()

        self.explanation, self.out_guards, self.graphs, self.ops_per_graph, self.break_reasons, self.explanation_verbose = '','','','','',''
        # Snapshot the random states so drawing a batch here doesn't change a later seeded run
        states = get_random_states()
        try:
            if FFCV and isinstance(self.dls.train, Loader) and self.dls.train.async_tfms:
                # With `async_tfms`, `Loader` needs to initialize all `Loader.batches_ahead` Cuda streams
                # for the training dataloader. Since FFCV doesn't support seeded transforms and the reset
                # random state only seeds the dataset order, this shouldn't affect training outcome.
                b = self.dls.train.one_batch(batches_ahead=True)
            elif hasattr(self.dls.valid, 'one_batch'):
                b = self.dls.valid.one_batch()
            else:
                # Dataloader without `one_batch`: pull the first batch and move it to the model's device
                b = next(iter(self.dls.valid))
                model_device = next(self.model.parameters()).device
                b = to_device(b, model_device)
            # Split the batch into inputs and targets, using `dls.n_inp` when the dataloaders define it
            i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
            self.learn.xb, self.learn.yb = b[:i], b[i:]

            # Run explain under the `MixedPrecision` callback's autocast context when one is attached
            use_autocast = hasattr(self.learn, 'mixed_precision')
            if use_autocast:
                self.learn.mixed_precision.autocast.__enter__()
            try:
                if _only_torch_20:
                    # PyTorch 2.0: `explain` takes the inputs directly and returns a 6-tuple
                    self.explanation, self.out_guards, self.graphs, self.ops_per_graph, self.break_reasons, self.explanation_verbose \
                        = dynamo.explain(self.learn.model, *_cast_tensor(self.learn.xb))
                else:
                    # Later versions: `explain` returns a callable whose result is a single output object
                    self.explain_output = dynamo.explain(self.learn.model)(*_cast_tensor(self.learn.xb))
            finally:
                # Always leave autocast, even if `dynamo.explain` raises, so it cannot leak into later code
                if use_autocast:
                    self.learn.mixed_precision.autocast.__exit__(None, None, None)

            self.learn.opt.zero_grad()
        finally:
            set_random_states(**states)

        if self.print_results:
            print('\nDynamo Explain Report\n')
            if _min_torch_21:
                # Print a copy so disabled sections can be blanked without touching the stored output
                print_copy = deepcopy(self.explain_output)
                if not self.print_ops_per_graph:
                    print_copy.ops_per_graph = None
                if not self.print_out_guards:
                    print_copy.out_guards = None
                print(print_copy)
                print_copy = None
            else:
                # PyTorch 2.0 returns raw pieces, so the report is assembled by hand
                output = "Explanation:\n"
                output += f"  {self.explanation}\n"
                output += "Break Reasons:\n"
                for idx, break_reason in enumerate(self.break_reasons):
                    output += f"  Break Reason {idx+1}:\n"
                    output += f"    Reason: {break_reason.reason}\n"
                    output += "    User Stack:\n"
                    for frame_summary in break_reason.user_stack:
                        output += f"      {frame_summary}\n"

                if self.ops_per_graph is not None and self.print_ops_per_graph:
                    output += "Ops per Graph:\n"
                    for idx, ops in enumerate(self.ops_per_graph):
                        output += f"  Ops {idx+1}:\n"
                        for op in ops:
                            output += f"    {op}\n"

                if self.out_guards is not None and self.print_out_guards:
                    output += "Out Guards:\n"
                    for i, guard in enumerate(self.out_guards[0]):
                        output += f"  Guard {i+1}:\n"
                        output += f"    {str(guard)}"

                print(output)

            print('\n')

        # This callback is diagnostic only, so training never proceeds past `before_fit`
        raise CancelFitException()

`DynamoExplainCallback` automates finding graph breaks using `torch._dynamo.explain` per the *[Identifying the cause of a graph break](https://pytorch.org/docs/stable/torch.compiler_faq.html#identifying-the-cause-of-a-graph-break)* section in the *[PyTorch Compile FAQ](https://pytorch.org/docs/stable/torch.compiler_faq.html)*. <code>DynamoExplainCallback</code> uses one batch from the validation dataloader[^Loader] to generate the `_dynamo.explain` report(s) and then cancels training.

To use, pass <code>DynamoExplainCallback</code> to the `cbs` argument of the `fastai.learner.Learner` or fit method.

```python
learn = Learner(..., cbs=DynamoExplainCallback())
learn.fit(1)
```

By default, <code>DynamoExplainCallback</code> prints the basic explanation output from `_dynamo.explain`, with arguments to enable printing `out_guards` and/or `ops_per_graph`.

All `_dynamo.explain` outputs are stored as attributes in the callback for later reference. For example, to view the out_guards after running `Learner` with <code>DynamoExplainCallback</code>:

```python
# PyTorch 2.0
print(learn.dynamo_explain.out_guards)

# PyTorch 2.1
print(learn.dynamo_explain.explain_output.out_guards)
```

[^Loader]: Unless using the FFCV `Loader` with `async_tfms` enabled, in which case it uses the training dataloader. This doesn't affect seeded training as FFCV dataloaders do not seed transforms, only dataset order.

## Convenience Method

fastxtend adds a convenience method to `fastai.learner.Learner` to easily enable `torch.compile`.

In [None]:
#|export
@patch
def compile(self:Learner,
    fullgraph:bool=False, # Prevent breaking model into subgraphs
    dynamic:bool=False, # Use dynamic shape tracing. Sets to `False` if PyTorch < 2.1
    backend:str|Callable='inductor', # `torch.compile` backend to use
    mode:str|CompileMode|None=None, # `torch.compile` mode to use
    options:Dict[str, Union[str,int,bool]]|None=None, # Extra options to pass to compile backend
    matmul_precision:str|MatMulPrecision='high', # Set Ampere and newer matmul precision
    recompile:bool=False, # Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
    verbose:bool=False, # Verbose output
):
    "Attach a `CompilerCallback` built from these arguments so the `Learner` compiles its model with `torch.compile`"
    # Dynamic shape tracing is only forwarded on PyTorch 2.1 or newer
    if not _min_torch_21:
        dynamic = False
    compiler_cb = CompilerCallback(
        fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode,
        options=options, matmul_precision=matmul_precision,
        recompile=recompile, verbose=verbose,
    )
    return self.add_cb(compiler_cb)

`compile` only sets `dynamic` if using PyTorch 2.1 or later; for PyTorch 2.0 it's hardcoded to `False`. You can override this by passing `dynamic` directly to `CompilerCallback`.

To use, call the `compile` method after initializing a `fastai.learner.Learner`.

```python
learn = Learner(...).compile()
learn.fine_tune(1)
```

## Compatibility Patches

These patches integrate `torch.compile` with fastai exporting, loading, freezing, unfreezing, and fine tuning.

### Exporting and Loading

In [None]:
#|export
@patch
def export(self:Learner,
    fname:FILE_LIKE='export.pkl', # Learner export file name, path, bytes, or IO
    pickle_module:Any=pickle, # Module used for pickling metadata and objects
    pickle_protocol:int=2 # Pickle protocol used
):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if child proc
    self._end_cleanup()
    old_dbunch = self.dls
    empty_dls = self.dls.new_empty()
    state = self.opt.state_dict() if self.opt is not None else None
    compiled_forward = None
    try:
        # Strip the data items and the optimizer for the duration of the save
        self.dls, self.opt = empty_dls, None
        # torch.compiled models currently cannot be pickled.
        # Only swap when `_orig_forward` exists; otherwise there is no original forward to restore.
        if _only_torch_20 and is_compiled(self.model) and hasattr(self.model, '_orig_forward'):
            compiled_forward = self.model.forward
            self.model.forward = self.model._orig_forward
            delattr(self.model, '_orig_forward')
        with warnings.catch_warnings():
            # To avoid the warning that come from PyTorch about model not being checked
            warnings.simplefilter("ignore")
            torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    finally:
        # Restore the live `Learner` even if saving failed, so it is never left half-stripped
        if compiled_forward is not None:
            self.model._orig_forward = self.model.forward
            self.model.forward = compiled_forward
        try:
            self.create_opt()
            if state is not None: self.opt.load_state_dict(state)
        finally:
            self.dls = old_dbunch

In [None]:
#|export
def load_learner(
    fname:FILE_LIKE, # File name, path, bytes, or IO
    cpu:bool=True, # Load model to CPU
    pickle_module=pickle # Module used for unpickling metadata and objects
):
    "Load a `Learner` object in `fname`, by default putting it on the `cpu`"
    distrib_barrier()
    map_loc = 'cpu' if cpu else default_device()
    # NOTE: unpickling can execute arbitrary code; only load export files from a trusted source
    try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
    except AttributeError as e:
        # Re-raise with a hint: the pickle references classes/functions missing from the current namespace
        e.args = [f"Custom classes or functions exported with your `Learner` not available in namespace.\nRe-declare/import before loading:\n\t{e.args[0]}"]
        raise
    if cpu:
        res.dls.cpu()
        # Undo whichever reduced-precision / memory-format setup the exported `Learner` carried
        if hasattr(res, 'channels_last'):
            res = res.to_contiguous(to_fp32=True)
        elif hasattr(res, 'mixed_precision'):
            res = res.to_fp32()
        elif hasattr(res, 'non_native_mixed_precision'):
            res = res.to_non_native_fp32()
        # The compiler callback is only dropped on the CPU path
        if hasattr(res, 'compiler'):
            res = res.remove_cb(CompilerCallback)
    return res

When loading to the CPU (`cpu=True`, the default), `load_learner` will remove the `CompilerCallback`.

### Freezing and Unfreezing

In [None]:
#|export
@patch
def freeze_to(self:Learner, n:int):
    "Freeze parameter groups up to `n`, flagging an attached `CompilerCallback` to recompile a compiled model"
    if self.opt is None:
        self.create_opt()
    self.opt.freeze_to(n)
    self.opt.clear_state()

    # The remaining handling only applies to a model that has already been compiled
    if not (_min_torch_20 and is_compiled(self.model)):
        return
    if hasattr(self, 'compiler'):
        # The callback performs the recompile at the start of the next fit
        self.compiler._recompile = True
        return
    # No callback attached: the user has to recompile manually
    warn(("Freezing or unfreezing a compiled model isn't supported."
          "\nThe model must be recompiled to take effect."
          "\nPass `CompilerCallback(..., recompile=True)` to `Learner.cbs`"
          "\nor call `torch._dynamo.reset() and recompile model."))

Freezing and unfreezing models works, but they need to be recompiled after. `freeze_to` will set `CompilerCallback` to recompile the model or warn users they need to manually recompile.

### Training

In [None]:
#|export
@patch
@delegates(Learner.fit_one_cycle)
def fine_tune(self:Learner,
    epochs:int, # Number of unfrozen epochs to train
    base_lr:float=2e-3, # Base learning rate, model head unfrozen learning rate
    freeze_epochs:int=1, # Number of frozen epochs to train
    lr_mult:Numeric=100, # Model stem unfrozen learning rate: `base_lr/lr_mult`
    pct_start:float=0.3, # Start unfrozen learning rate cosine annealing
    div:Numeric=5.0, # Initial unfrozen learning rate: `base_lr/div`
    compile_frozen:bool=False, # Compile model during frozen finetuning if `CompilerCallback` is used
    **kwargs
):
    "Train frozen for `freeze_epochs`, then unfrozen for `epochs` with discriminative learning rates"
    self.freeze()
    compiler_attached = _min_torch_20 and hasattr(self, 'compiler')
    if compiler_attached and not compile_frozen:
        # Leave the compiler callback enabled for the frozen phase only if the model is already compiled
        self.compiler.run = is_compiled(self.model)
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    # Unfrozen phase: halve the base rate and spread it down to `base_lr/lr_mult`
    base_lr /= 2
    self.unfreeze()
    lr_range = slice(base_lr/lr_mult, base_lr)
    self.fit_one_cycle(epochs, lr_range, pct_start=pct_start, div=div, **kwargs)

By default, `fine_tune` will not compile the `freeze_epochs`, but this can be overridden by passing `compile_frozen=True`. If the model is already compiled, this will have no effect.

## Testing -

In [None]:
#|hide
#|cuda
from packaging.version import parse
import tempfile

import fastai

from fastcore.basics import num_cpus

if parse(fastai.__version__) < parse('2.7.11'):
    from fastxtend.callback.channelslast import *
else:
    from fastai.callback.channelslast import *
from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.models import resnet34

from fastxtend.metrics import Accuracy
from fastxtend.optimizer.fused import adam
from fastxtend.utils import *

In [None]:
#|hide
#|cuda
warnings.simplefilter("ignore")

In [None]:
#|hide
#|cuda
# Test: train a plain `resnet34(num_classes=dls.c)` in channels-last format with
# `Learner.compile(fullgraph=True)` for three one-cycle epochs.
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=Resize(128),
                       batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    learn = Learner(dls, resnet34(num_classes=dls.c), opt_func=adam(foreach=True),
                    metrics=Accuracy()).to_channelslast().compile(fullgraph=True)
    learn.fit_one_cycle(3, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.858504,2.472007,0.326624,00:38
1,1.4438,1.235949,0.59949,00:07
2,1.099457,0.968236,0.686879,00:07


In [None]:
#|hide
#|cuda
free_gpu_memory(learn, dls)

In [None]:
#|hide
#|cuda
# Test: the patched `fine_tune` on a `vision_learner` with no `CompilerCallback` attached
# (no `.compile()` call), i.e. the uncompiled path.
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=Resize(128),
                       batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                           metrics=Accuracy()).to_channelslast()
    learn.fine_tune(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.75506,0.175099,0.945223,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.248197,0.140395,0.957452,00:07
1,0.1766,0.125655,0.960764,00:07
2,0.111107,0.104562,0.967389,00:06


In [None]:
#|hide
#|cuda
free_gpu_memory(learn, dls)

In [None]:
#|hide
#|cuda
# Test: `fine_tune` on a learner with `.compile()` attached, passing `compile_frozen=True`
# so the frozen epoch is compiled as well. `learn` and `dls` are reused by the
# save/export/validate cells that follow.
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=Resize(128),
                       batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                           metrics=Accuracy()).to_channelslast().compile()
    learn.fine_tune(3, compile_frozen=True)

epoch,train_loss,valid_loss,accuracy,time
0,0.759313,0.172038,0.945987,00:40


epoch,train_loss,valid_loss,accuracy,time
0,0.24912,0.127831,0.959236,00:36
1,0.171835,0.124405,0.962038,00:07
2,0.108882,0.104857,0.967389,00:07


In [None]:
#|hide
#|cuda
# Reference loss/accuracy from the compiled learner, for comparison with the reloaded and exported learners below
learn.validate()

(#2) [0.10485660284757614,0.9673885107040405]

In [None]:
#|hide
#|cuda
# Temporary directories for the export file (under `learn.path`) and the saved weights
# (under `learn.path/learn.model_dir`). `tmp_pe`/`tmp_ps` hold just the directory names,
# so they can be passed to `export`/`save` as paths relative to those roots.
tmp_de = tempfile.TemporaryDirectory(dir=learn.path)
tmp_pe = Path(Path(tmp_de.name).stem)
tmp_ds = tempfile.TemporaryDirectory(dir=learn.path/learn.model_dir)
tmp_ps = Path(Path(tmp_ds.name).stem)

In [None]:
#|hide
#|cuda
# Save the compiled learner's weights into the temporary model directory
learn.save(tmp_ps/'test')

Path('models/tmpbpa6ct7y/test.pth')

In [None]:
#|hide
#|cuda
# Exercise the patched `Learner.export` on the compiled learner
learn.export(tmp_pe/'export.pkl')

In [None]:
#|hide
#|cuda
# Load the saved weights into a fresh learner that has no `CompilerCallback` and validate it;
# the metrics are expected to match the compiled learner's `validate()` output above.
saved = vision_learner(dls, resnet34, opt_func=adam(foreach=True),
                       metrics=Accuracy()).to_channelslast()
saved.load(tmp_ps/'test')
saved.validate()

(#2) [0.10484044998884201,0.9673885107040405]

In [None]:
#|hide
#|cuda
# Reload the exported learner. `export` pickles an emptied copy of the dataloaders,
# so the real ones are re-attached before validating.
exported = load_learner(tmp_pe/'export.pkl')
exported.dls = learn.dls

In [None]:
#|hide
#|cuda
# Remove the temporary weight and export directories; both artifacts were loaded in the cells above
tmp_ds.cleanup()
tmp_de.cleanup()

In [None]:
#|hide
#|cuda
# Validate the export/load round trip against the earlier `learn.validate()` result
exported.validate()

(#2) [0.10487388074398041,0.9673885107040405]

In [None]:
#|hide
#|cuda
free_gpu_memory(learn, dls)

In [None]:
#|hide
#|cuda
# Test: `DynamoExplainCallback` on an uncompiled `vision_learner`. The callback prints the
# Dynamo Explain report in `before_fit` and raises `CancelFitException`, so `fit(1)`
# produces no training epochs.
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                        splitter=GrandparentSplitter(valid_name='val'),
                        get_items=get_image_files, get_y=parent_label,
                        item_tfms=Resize(128),
                        batch_tfms=[Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus(), pin_memory=True)

    learn = vision_learner(dls, resnet34, opt_func=adam(foreach=True), metrics=Accuracy(),
                           cbs=DynamoExplainCallback()).to_channelslast()
    learn.fit(1)


Dynamo Explain Report

Graph Count: 1
Graph Break Count: 0
Op Count: 17
Break Reasons:
Compile Times: TorchDynamo compilation metrics:
Function                          Runtimes (s)
------------------------------  --------------
_compile                                0.8507
OutputGraph.call_user_compiler          0.0013



