In [None]:
#|default_exp patches

# Compatibility Patches
> PyTorch version compatibility and fastai backport patches for fastxtend

In [None]:
#|export
from packaging.version import parse

import fastai

from fastxtend.imports import *

In [None]:
#|exporti
# Compare on the release number only (`base_version` drops pre/dev/local segments), so builds
# such as nightly '2.0.0.dev20230101' or NGC's '2.1.0a0+29c30b1' count as their base release
# instead of sorting before it.
_torch_version = parse(parse(torch.__version__).base_version)
_torch_20  = parse('2.0')
_torch_113 = parse('1.13')
_torch_112 = parse('1.12')

## fastai Backports -

In [None]:
#|exporti
# This should be upstreamed in fastai 2.7.14
if parse(fastai.__version__) < parse('2.7.14'):
    from fastai.data.core import TfmdDL, DataLoader

    def device_get(self) -> torch.device|None:
        "Getter for `DataLoader.device`: the device last stored by `device_set`"
        return self._device

    def device_set(self, device:int|str|torch.device|None):
        "Setter for `DataLoader.device`: normalize and store `device`, then move `after_batch` transforms to it"
        # `_parse_to` is a private PyTorch helper that parses `.to(...)`-style arguments; keep only the device
        self._device, *_ = torch._C._nn._parse_to(device=device)
        # Only loaders with an `after_batch` pipeline exposing `.fs` have transforms to move
        if hasattr(self, 'after_batch') and hasattr(self.after_batch, 'fs'):
            for tfm in self.after_batch.fs:
                # Check that tfm.to is callable as TabularPandas & transforms set tfm.to as an object
                if hasattr(tfm, 'to') and callable(tfm.to):
                    tfm.to(device)
                else:
                    # Otherwise move each attribute named in `tfm.parameters` that supports `.to`
                    for a in L(getattr(tfm, 'parameters', None)):
                        if hasattr(getattr(tfm, a), 'to'):
                            setattr(tfm, a, getattr(tfm, a).to(device))

    # apply property patch to DataLoader
    setattr(DataLoader, 'device', property(device_get, device_set))

    @patch
    def to(self:TfmdDL, device):
        "Move this `TfmdDL` (and its `after_batch` transforms) to `device`, returning `self`"
        self.device = device
        # Return `self` so `dl = dl.to(device)` and chained calls work, like `Tensor.to`/`Module.to`
        return self

## PyTorch 1.12 and 1.13 -

In [None]:
#|export
if parse(fastai.__version__) < parse('2.7.12'):
    @patch
    def clone(self:TensorBase, *, memory_format=None):
        "Clone via plain `Tensor.clone`, then cast the result back to `type(self)`"
        cls = type(self)
        return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)

    # A single variadic definition serves both `new_empty(size)` and `new_empty(*size)` callers;
    # patching `new_empty` twice would just make the later definition replace the earlier one.
    @patch
    def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
        "Like `Tensor.new_empty`; on PyTorch >= 1.12 the result is cast back to `type(self)`"
        cls = type(self)
        # NOTE(review): presumably PyTorch < 1.13 needs an explicit layout here, hence the strided default
        if _torch_version < _torch_113 and layout is None:
            layout = torch.strided
        # PyTorch < 1.12: delegate directly to the parent class implementation
        if _torch_version < _torch_112:
            return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
        return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)

In [None]:
#|hide
from copy import deepcopy

In [None]:
#|hide
x = TensorBase(torch.rand(4,3,16,16))
x.test = 'test metadata'
y = deepcopy(x)
# Metadata stored on the instance must survive the copy
assert hasattr(y, 'test') and y.test == x.test
# The copy must also keep the `TensorBase` subclass and the tensor values
assert type(y) is type(x)
assert torch.equal(y, x)

## PyTorch 2.0+ -

In [None]:
#|exporti
if _torch_version >= _torch_20 and parse(fastai.__version__) < parse('2.7.12'):
    from fastai.callback.training import ProgressCallback

    @patch
    def __reduce_ex__(self:TensorBase, proto):
        "Pickle using the next `__reduce_ex__` in `TensorBase`'s MRO (presumably `torch.Tensor`'s)"
        parent_reduce = super(TensorBase, self).__reduce_ex__
        return parent_reduce(proto)

    @patch
    def after_batch(self:ProgressCallback):
        "Advance the progress bar, and show the smoothed loss once it is available"
        self.pbar.update(self.iter+1)
        if not hasattr(self, 'smooth_loss'):
            return
        # `.item()` yields a plain Python number, which accepts the `:.4f` format spec
        loss_value = self.smooth_loss.item()
        self.pbar.comment = f'{loss_value:.4f}'

## Hugging Face MutableMapping Objects

In [None]:
#|exporti
# This was upstreamed in fastai 2.7.13
if parse(fastai.__version__) < parse('2.7.13'):
    from collections.abc import MutableMapping
    from fastcore.dispatch import retain_type
    from fastai.basics import defaults
    from fastai.learner import Learner

    def apply(func, x, *args, **kwargs):
        "Apply `func` recursively to `x`, passing on args"
        if is_listy(x):
            return type(x)([apply(func, o, *args, **kwargs) for o in x])
        # Dicts and other `MutableMapping`s (presumably e.g. Hugging Face `BatchEncoding`) are rebuilt as plain `dict`s
        if isinstance(x, (dict, MutableMapping)):
            return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
        res = func(x, *args, **kwargs)
        return res if x is None else retain_type(res, x)

    def to_device(b, device=None, non_blocking=False):
        "Recursively put `b` on `device`."
        # `defaults.use_cuda == False` forces CPU regardless of the requested `device`
        if defaults.use_cuda==False:
            device='cpu'
        elif device is None:
            device=default_device()
        def _inner(o):
            # Only tensors are moved; every other leaf is returned unchanged
            if isinstance(o,Tensor):
                return o.to(device, non_blocking=non_blocking)
            return o
        return apply(_inner, b)

    @patch
    def _set_device(self:Learner, b):
        "Move batch `b` to the device of `self.model`, or to the `DataLoaders` device if the model has no parameters"
        param = next(self.model.parameters(), None)
        if param is not None:
            # Whether or not the `DataLoaders` device matches, the batch belongs with the model's weights.
            # Skipping the `DataLoaders` lookup also avoids an unneeded `default_device()` call.
            return to_device(b, param.device)
        # A parameterless model gives no device hint (and `next()` without a default would raise
        # `StopIteration`), so fall back to the `DataLoaders` device
        return to_device(b, getattr(self.dls, 'device', default_device()))