In [1]:
from torchrl.data import TensorDict
from torchrl.data.tensordict.tensordict import TensorDictBase
import functorch
from torch import nn
import torch
from copy import copy, deepcopy

_RESET_OLD_TENSORDICT = True

In [2]:
from functorch._src.vmap import _add_batch_dim, tree_unflatten, tree_flatten

In [3]:
class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.
    """

    def __init__(self, stateless_model):
        super(FunctionalModule, self).__init__()
        self.stateless_model = stateless_model

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = deepcopy(model)
        param_tensordict = extract_weights(model_copy)
        if disable_autograd_tracking:
            tensordict_weights.apply(lambda x: x.requires_grad_(False), inplace=True)
        return FunctionalModule(model_copy), param_tensordict

    def forward(self, params, *args, **kwargs):
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(self.stateless_model, params, return_old_tensordict=_RESET_OLD_TENSORDICT)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            if _RESET_OLD_TENSORDICT:
                _swap_state(self.stateless_model, old_state)


In [4]:
def extract_weights(model):
    tensordict = TensorDict({}, [])
    for name, param in list(model.named_parameters(recurse=False)):
        setattr(model, name, None)
        tensordict[name] = param
    for name, module in model.named_children():
        module_tensordict = extract_weights(module)
        if module_tensordict is not None:
            tensordict[name] = module_tensordict
    if len(tensordict.keys()):
        return tensordict
    else:
        return None

def _swap_state(model, tensordict, return_old_tensordict=False):
#     if return_old_tensordict:
#         old_tensordict = tensordict.clone(recursive=False)
#         old_tensordict.batch_size = []
    
    if return_old_tensordict:
        old_tensordict = TensorDict({}, [], device=tensordict._device_safe())

    for key, value in list(tensordict.items()):
        if isinstance(value, TensorDictBase):
            _swap_state(getattr(model, key), value)
        else:
            if return_old_tensordict:
                old_attr = getattr(model, key)
                if old_attr is None:
                    old_attr = torch.tensor([]).view(*value.shape, 0)
            delattr(model, key)
            setattr(model, key, value)
            if return_old_tensordict:
                old_tensordict.set(key, old_attr)
    if return_old_tensordict:
        return old_tensordict

In [5]:
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3), nn.Sequential(nn.Linear(3, 4)))
print(model)

Sequential(
  (0): Linear(in_features=1, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=3, bias=True)
  (2): Sequential(
    (0): Linear(in_features=3, out_features=4, bias=True)
  )
)


In [6]:
tensordict_weights = extract_weights(model)
print(tensordict_weights)

TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(torch.Size([2]), dtype=torch.float32),
                weight: Tensor(torch.Size([2, 1]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Tensor(torch.Size([3]), dtype=torch.float32),
                weight: Tensor(torch.Size([3, 2]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        2: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(torch.Size([4]), dtype=torch.float32),
                        weight: Tensor(torch.Size([4, 3]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([])

In [7]:
# accessing weights
tensordict_weights["0", "bias"]

Parameter containing:
tensor([-0.3050,  0.3137], requires_grad=True)

In [8]:
tensordict_weights["0"]["bias"]

Parameter containing:
tensor([-0.3050,  0.3137], requires_grad=True)

In [9]:
# flatten - unflatten
tensordict_weights_flatten = tensordict_weights.flatten_keys(separator=".", inplace=False)
print(tensordict_weights_flatten)

TensorDict(
    fields={
        0.bias: Tensor(torch.Size([2]), dtype=torch.float32),
        0.weight: Tensor(torch.Size([2, 1]), dtype=torch.float32),
        1.bias: Tensor(torch.Size([3]), dtype=torch.float32),
        1.weight: Tensor(torch.Size([3, 2]), dtype=torch.float32),
        2.0.bias: Tensor(torch.Size([4]), dtype=torch.float32),
        2.0.weight: Tensor(torch.Size([4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)


In [10]:
tensordict_weights_unflatten = tensordict_weights_flatten.unflatten_keys(separator=".", inplace=False)
print(tensordict_weights_unflatten)

TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(torch.Size([2]), dtype=torch.float32),
                weight: Tensor(torch.Size([2, 1]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Tensor(torch.Size([3]), dtype=torch.float32),
                weight: Tensor(torch.Size([3, 2]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        2: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(torch.Size([4]), dtype=torch.float32),
                        weight: Tensor(torch.Size([4, 3]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([])

In [11]:
# BatchedTensor
t = TensorDict({"a": torch.randn(3, 1), "b": TensorDict({"c": torch.randn(3, 1)}, [])}, [])
t = t.apply(lambda x: _add_batch_dim(x, 0, 0))
t["b", "c"]

BatchedTensor(lvl=0, bdim=0, value=
    tensor([[-0.5027],
            [-0.5578],
            [ 0.0191]])
)

In [12]:
# requires_grad to False
tensordict_weights.apply(lambda x: x.requires_grad_(False), inplace=True)
tensordict_weights["0", "bias"]

Parameter containing:
tensor([-0.3050,  0.3137])

In [13]:
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3), nn.Sequential(nn.Linear(3, 4)))

fmodel, params = FunctionalModule._create_from(model)
params

TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(torch.Size([2]), dtype=torch.float32),
                weight: Tensor(torch.Size([2, 1]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Tensor(torch.Size([3]), dtype=torch.float32),
                weight: Tensor(torch.Size([3, 2]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        2: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(torch.Size([4]), dtype=torch.float32),
                        weight: Tensor(torch.Size([4, 3]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([])

In [14]:
fmodel(params, torch.randn(1))

tensor([ 1.1301,  0.8452,  0.6819, -0.7968], grad_fn=<AddBackward0>)

In [15]:
fmodel(params, torch.randn(1, 1))

tensor([[ 0.8535,  1.2261,  1.0852, -0.6828]], grad_fn=<AddmmBackward0>)

In [16]:
functorch.vmap(torch.add, (0, 0))(torch.ones(10, 1), torch.ones(10, 1)).shape

torch.Size([10, 1])

In [17]:
x = torch.randn(10, 1, 1)
functorch.vmap(fmodel, (None, 0))(params, x)  # works

tensor([[[ 0.8134,  1.2813,  1.1436, -0.6662]],

        [[ 0.3356,  1.9392,  1.8399, -0.4693]],

        [[ 0.8799,  1.1897,  1.0466, -0.6937]],

        [[ 0.7426,  1.3787,  1.2467, -0.6371]],

        [[ 0.1707,  2.1663,  2.0803, -0.4013]],

        [[ 0.7977,  1.3030,  1.1665, -0.6597]],

        [[ 0.3470,  1.9236,  1.8234, -0.4740]],

        [[ 0.5987,  1.5769,  1.4564, -0.5778]],

        [[ 0.4835,  1.7356,  1.6245, -0.5302]],

        [[ 1.1662,  0.7954,  0.6293, -0.8117]]], grad_fn=<AddBackward0>)

In [18]:
functorch.vmap(fmodel, (0, 0))(params.expand(10), x)  # works

tensor([[[ 0.8134,  1.2813,  1.1436, -0.6662]],

        [[ 0.3356,  1.9392,  1.8399, -0.4693]],

        [[ 0.8799,  1.1897,  1.0466, -0.6937]],

        [[ 0.7426,  1.3787,  1.2467, -0.6371]],

        [[ 0.1707,  2.1663,  2.0803, -0.4013]],

        [[ 0.7977,  1.3030,  1.1665, -0.6597]],

        [[ 0.3470,  1.9236,  1.8234, -0.4740]],

        [[ 0.5987,  1.5769,  1.4564, -0.5778]],

        [[ 0.4835,  1.7356,  1.6245, -0.5302]],

        [[ 1.1662,  0.7954,  0.6293, -0.8117]]], grad_fn=<AddBackward0>)

In [19]:
# benchmarking
from functorch._src.make_functional import FunctionalModule as FunctionalModule_orig

model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3), nn.Sequential(nn.Linear(3, 4)))
%timeit FunctionalModule_orig._create_from(model)
%timeit FunctionalModule._create_from(model)

460 µs ± 26.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.99 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
module_orig, params_orig = FunctionalModule_orig._create_from(model)
module, params = FunctionalModule._create_from(model)

# fair comparison
_RESET_OLD_TENSORDICT = True
x = torch.randn(1)
%timeit module_orig(params_orig, x)
%timeit module(params, x)

192 µs ± 7.79 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
134 µs ± 5.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [21]:
# unfair comparison -- does not swap back the params
_RESET_OLD_TENSORDICT = False
x = torch.randn(1)
%timeit module_orig(params_orig, x)
%timeit module(params, x)

197 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
106 µs ± 1.85 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
