In [None]:
#|default_exp patches

# Compatibility Patches
> PyTorch version compatibility and fastai backport patches for fastxtend

In [None]:
#|export
from packaging.version import parse

import fastai
from fastai.torch_core import _rebuild_from_type
from fastai.data.core import TfmdDL
from fastai.callback.training import ProgressCallback

from fastxtend.imports import *

In [None]:
#|exporti
# Compare torch versions on their release segment only. Pre-release builds such as
# nightlies (`2.0.0.dev20230101+cu117`) or NGC builds (`1.13.0a0+git...`) otherwise
# sort *before* the release they track, so version-gated patches below would be skipped.
_torch_version = parse(parse(torch.__version__).base_version)
_torch_20  = parse('2.0')
_torch_113 = parse('1.13')
_torch_112 = parse('1.12')

## fastai Backports -

In [None]:
#|exporti
# Backport: this fix was upstreamed in fastai 2.7.11
if parse(fastai.__version__) < parse('2.7.11'):
    @patch
    def to(self:TfmdDL, device):
        "Record `device` on `self` and move every `after_batch` transform onto it; returns `self`"
        self.device = device
        for transform in self.after_batch.fs:
            # TabularPandas & some transforms store a non-callable object under `to`,
            # so only call it when it really is a method
            move = getattr(transform, 'to', None)
            if callable(move):
                move(device)
                continue
            # Otherwise move each attribute listed in `transform.parameters` (if any) individually
            for attr_name in L(getattr(transform, 'parameters', None)):
                setattr(transform, attr_name, getattr(transform, attr_name).to(device))
        return self

## PyTorch 1.12 and 1.13 -

In [None]:
#|export
if parse(fastai.__version__) < parse('2.7.12'):
    @patch
    def clone(self:TensorBase, *, memory_format=None):
        "Clone `self` as a plain `Tensor`, then cast the copy back to `type(self)`"
        cls = type(self)
        return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)

    # Only one `new_empty` is patched: a second `@patch`ed definition with the same name would
    # simply overwrite the first on `TensorBase`. The `*size` form covers both call styles,
    # `new_empty(2, 3)` and `new_empty((2, 3))`, because `size` is forwarded unchanged as
    # varargs to `Tensor.new_empty`, which accepts either form.
    @patch
    def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
        "Like `Tensor.new_empty`; on torch >= 1.12 the result is cast back to `type(self)`"
        cls = type(self)
        # NOTE(review): presumably older torch rejects an explicit `layout=None`, so pass the default
        if _torch_version < _torch_113 and layout is None:
            layout = torch.strided
        if _torch_version < _torch_112:
            # Older torch: call the parent class' `new_empty` directly on `self`
            return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
        return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)

In [None]:
#|hide
from copy import deepcopy

In [None]:
#|hide
x = TensorBase(torch.rand(4,3,16,16))
x.test = 'test metadata'

# deepcopy must keep both the subclass and the attached metadata
y = deepcopy(x)
assert isinstance(y, TensorBase)
assert hasattr(y, 'test') and y.test == x.test

# the patched `clone` and `new_empty` must return `TensorBase`, for both `new_empty` call styles
assert isinstance(x.clone(), TensorBase)
for empty in (x.new_empty(2, 3), x.new_empty((2, 3))):
    assert isinstance(empty, TensorBase) and empty.shape == (2, 3)

## PyTorch 2.0 Nightly -

In [None]:
#|exporti
if parse(fastai.__version__) < parse('2.7.12') and _torch_version >= _torch_20:
    @patch
    def __reduce_ex__(self:TensorBase, proto):
        "Pickle `self` using the `__reduce_ex__` of the class after `TensorBase` in the MRO"
        parent = super(TensorBase, self)
        return parent.__reduce_ex__(proto)

    @patch
    def after_batch(self:ProgressCallback):
        "Advance the progress bar; if a `smooth_loss` is reachable, show it as the bar comment"
        self.pbar.update(self.iter + 1)
        if not hasattr(self, 'smooth_loss'):
            return
        # `.item()` converts the loss to a Python number before formatting
        self.pbar.comment = f'{self.smooth_loss.item():.4f}'