# Recursive norm of ragged nested tensors: PyTorch vs. JAX (eager and jit) timing

In [None]:
# 'last_expr_or_assign' displays the cell's last expression, or the value bound
# by a trailing assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Mode 2: reload all (non-excluded) modules before executing each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Root logger at INFO so library/notebook log messages show in the output.
logging.basicConfig(level=logging.INFO)

In [None]:
from random import randint
from typing import Optional, Union

import jax
import jax.numpy as np
import torch
from numpy.random import randn
from torch import Tensor, nn

In [None]:
# NOTE(review): `tf` (TensorFlow) is never imported in this notebook, so IPython cannot
# resolve this help lookup as written; add `import tensorflow as tf` first or drop the cell.
?tf.RaggedTensor.from_nested_value_rowids

In [None]:
class RecursiveNorm(nn.Module):
    """Sum of 2-norms over an arbitrarily nested list of tensors.

    A bare tensor maps to ``torch.linalg.norm(x)``; a list maps to the sum of
    the recursive norms of its items. The module holds no parameters.
    """

    def forward(self, x: Union[list, Tensor]) -> Tensor:
        """Return the recursive norm of ``x``.

        Args:
            x: A tensor, or a (possibly nested) ``list`` whose leaves are tensors.

        Returns:
            A tensor. An empty list yields ``torch.zeros(())`` instead of the
            plain Python ``int`` 0 that ``sum([])`` would produce.
        """
        if not isinstance(x, list):
            return torch.linalg.norm(x)
        norms = [self.forward(item) for item in x]
        if not norms:
            # sum([]) is the int 0, which would break the Tensor return type.
            return torch.zeros(())
        # Keep the implicit int start of sum() so non-empty results keep the
        # leaves' dtype (a float32 start tensor would upcast e.g. float16 norms).
        return sum(norms)


module = RecursiveNorm()

In [None]:
def recursive_norm(x: Union[list[np.ndarray], np.ndarray]) -> np.ndarray:
    """JAX counterpart of ``RecursiveNorm``: sum of 2-norms over a nested list.

    Args:
        x: An array, or a (possibly nested) ``list`` whose leaves are arrays.
            The annotation spells out only one level of nesting; deeper lists
            are handled by the recursion.

    Returns:
        A JAX array. An empty list yields ``np.zeros(())`` (``np`` is
        ``jax.numpy`` here) instead of the Python int 0 from ``sum([])``.
    """
    if not isinstance(x, list):
        return np.linalg.norm(x)
    norms = [recursive_norm(item) for item in x]
    if not norms:
        # sum([]) is the int 0, which would break the array return type.
        return np.zeros(())
    return sum(norms)


# jax.jit treats the nested list as a pytree and specializes on its structure
# and leaf shapes, so each new nesting layout triggers a fresh trace/compile.
jitted_recursive_norm = jax.jit(recursive_norm)

In [None]:
# NOTE(review): `data` is only created in a later cell (`data = random_nested()`), so a
# top-to-bottom run raises NameError here. Per the JAX docs, pmap maps over the leading
# axis of every leaf; that axis must match across leaves and fit the device count. The
# randomly shaped leaves from random_nested are unlikely to satisfy this. TODO: confirm
# intent or remove this cell.
jax.pmap(recursive_norm)(data)

In [None]:
# Draw a random shape: rank in [1, max_rank], every dimension in [1, max_length].
# random.randint is inclusive at BOTH ends, so the maxima themselves are the
# upper bounds; adding 1 would allow rank max_rank + 1 and size max_length + 1.
max_length: int = 16
max_rank: int = 5
n = randint(1, max_rank)
shape = tuple(randint(1, max_length) for _ in range(n))
shape

In [None]:
def random_tensor(max_length: int = 16, max_rank: int = 5):
    """Return an array of standard-normal samples with a random shape.

    Args:
        max_length: Inclusive upper bound on each dimension's size; sizes are
            drawn uniformly from ``[1, max_length]``.
        max_rank: Inclusive upper bound on the number of dimensions; the rank
            is drawn uniformly from ``[1, max_rank]``.

    Returns:
        A ``jax.numpy`` array (``np`` is ``jax.numpy`` in this notebook) built
        from ``numpy.random.randn`` samples, so JAX's default float dtype
        applies (float32 unless x64 mode is enabled).
    """
    # The size bound was previously read from a misspelled parameter, so the
    # body silently used a global instead; it now honours the argument.
    rank = randint(1, max_rank)
    dims = tuple(randint(1, max_length) for _ in range(rank))
    return np.array(randn(*dims))


def random_nested(max_length=5, max_depth=3, cur_depth=0):
    """Build a random ragged nested list whose leaves are ``random_tensor()`` arrays.

    Each level holds between 1 and ``max_length`` items (inclusive). With
    probability 1/2, and only while ``cur_depth < max_depth``, a level holds
    nested sub-lists; otherwise it holds tensors. A single level never mixes
    the two.

    Args:
        max_length: Inclusive upper bound on the number of items per level.
        max_depth: Maximum number of list levels below the top-level list.
        cur_depth: Depth of the current call; callers normally leave it at 0.

    Returns:
        A non-empty ``list`` of either lists or arrays.
    """
    length = randint(1, max_length)
    if randint(0, 1) and cur_depth < max_depth:
        # Forward the limits so nested levels honour the caller's bounds.
        return [
            random_nested(max_length, max_depth, cur_depth + 1)
            for _ in range(length)
        ]
    return [random_tensor() for _ in range(length)]


def to_torch(x, device: Optional[torch.device] = None):
    """Convert an array, or a nested list of arrays, into torch tensors.

    Args:
        x: A JAX/NumPy array, or a (possibly nested) ``list`` of them.
        device: Target device for every resulting tensor. ``None`` keeps
            ``torch.tensor``'s default device.

    Returns:
        A tensor, or nested lists of tensors mirroring the structure of ``x``.
        Data is copied.
    """
    if isinstance(x, list):
        # Forward ``device`` so nested leaves land on the requested device too.
        return [to_torch(item, device) for item in x]
    # jax.device_get yields a host NumPy array (NumPy input works too); the
    # legacy ``.to_py()`` method is not available on current ``jax.Array``.
    return torch.tensor(jax.device_get(x), device=device)


# Smoke test: draw one random tensor, keep it as `tensor`, and display its shape.
(tensor := random_tensor()).shape

In [None]:
# One random ragged structure (`data`) plus its PyTorch copy (`torch_data`),
# shared by the timing cells below.
torch_data = to_torch(data := random_nested())

In [None]:
%%timeit
recursive_norm(data)

In [None]:
%%timeit
jitted_recursive_norm(data)

In [None]:
%%timeit
# PyTorch baseline on the same nested values. torch_data was built without a device
# argument, so (unless the default device was changed) it lives on CPU and runs
# synchronously; no explicit sync is needed, unlike the JAX cells.
module(torch_data)