In [None]:
#|default_exp callback.casttotensor

In [None]:
#|hide
from nbdev.showdoc import *

In [None]:
#|export
from packaging.version import parse

from fastcore.dispatch import cast

import fastai
from fastai.callback.core import Callback
from fastai.basics import defaults

from fastxtend.imports import *

# Cast To Tensor Backport
> A callback to cast model inputs to `Tensor` as a workaround for a PyTorch performance bug

For use in fastai 2.6.x or older. Import globally:

```python
from fastxtend.vision.all import *
```

or individually:

```python
from fastxtend.callback import casttotensor
```

In [None]:
#|exporti
def _cast_tensor(x): 
    if isinstance(x, tuple): return tuple(_cast_tensor(x_) for x_ in x)
    else: return cast(x, Tensor) if isinstance(x,torch.Tensor) else x

In [None]:
#|export
class CastToTensorBackport(Callback):
    "Cast Subclassed Tensors to `Tensor`"
    order = 9  # MixedPrecision uses order 10, so this runs immediately before it

    def before_batch(self):
        # Replace the batch inputs and targets with plain-`Tensor` versions.
        # Tensors nested in tuples are cast too; other containers are left as-is.
        learner = self.learn
        inputs = _cast_tensor(learner.xb)
        targets = _cast_tensor(learner.yb)
        learner.xb, learner.yb = inputs, targets

Workaround for a bug in PyTorch where subclassed tensors, such as `TensorBase`, train up to ~20% slower than `Tensor` when passed to a model. Added to `Learner` by default if using fastai 2.6.x or older.

`CastToTensorBackport` is identical to the `CastToTensor` callback released in fastai 2.7.0.

`CastToTensorBackport`'s order places it right before `MixedPrecision`. Callbacks that run before it can therefore still use fastai's tensor subclasses.

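`CastToTensorBackport` only casts subclassed tensors and (possibly nested) tuples of tensors. For other input structures, such as lists or dictionaries of tensors, you may need to cast the inputs in `Learner.xb` and `Learner.yb` to `Tensor` yourself. Do this in your own callback or in the dataloader, before `Learner` performs the forward pass.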

If the `CastToTensorBackport` workaround interferes with custom code, it can be removed:

```python
learn = Learner(...)
learn.remove_cb(CastToTensorBackport)
```

If `CastToTensorBackport` is removed, verify that your inputs are of type `Tensor`, or implement a cast to `Tensor` via a custom callback or dataloader.

In [None]:
#|export
# Register the backport as a default `Learner` callback, but only on fastai
# versions older than 2.7.0 (which ship their own `CastToTensor`), and never twice.
if not (parse(fastai.__version__) >= parse('2.7.0') or CastToTensorBackport in defaults.callbacks):
    defaults.callbacks.append(CastToTensorBackport)