In [None]:
# Notebook configuration.
# Display the value of the last expression or assignment of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload edited modules automatically before code is executed.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Show log records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
from typing import *

import numpy as np
import torch
from torch import Tensor, jit
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

# Print NumPy arrays with four fixed decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# NOTE(review): the generator is created without a seed, so draws differ between kernel restarts.
rng = np.random.default_rng()

In [None]:
# Interactive help lookup (IPython `?` syntax) for torch.jit.annotate.
?torch.jit.annotate

In [None]:
@jit.script
def aggregate_and(
    x: Tensor,
    dim: list[int],
    keepdim: bool = False,
) -> Tensor:
    r"""Compute logical ``AND`` across several dimensions.

    Args:
        x: Tensor to reduce.
        dim: Dimensions to reduce over. Negative indices count from the back and
            the order of the entries is irrelevant. An empty list returns ``x``
            unchanged. Entries are expected to be unique.
        keepdim: If ``True``, every reduced dimension is kept with size one.

    Returns:
        The result of applying ``torch.all`` over every dimension in ``dim``.
    """
    if keepdim:
        # Axes keep their position, so every index can be reduced as given.
        for d in dim:
            x = torch.all(x, dim=d, keepdim=True)
        return x

    # Without keepdim each reduction removes an axis and shifts the axes behind it.
    # Resolve negative indices first, then reduce in ascending order so that
    # subtracting the number of already removed axes yields the right position.
    rank = x.dim()
    resolved = torch.jit.annotate(List[int], [])
    for d in dim:
        if d < 0:
            resolved.append(d + rank)
        else:
            resolved.append(d)

    ordered = sorted(resolved)
    for i, d in enumerate(ordered):
        x = torch.all(x, dim=d - i, keepdim=False)

    return x

In [None]:
# Random boolean mask of shape (3, 4, 5).
m = torch.randn(3, 4, 5) > 0.1
# Reduce over an empty list of dimensions.
aggregate_and(m, dim=[])

In [None]:
import torch
from torch import BoolTensor, Tensor, jit
from torch.nn.utils.rnn import pack_sequence, pad_sequence

In [None]:
# Three 1-D sequences of lengths 1, 2 and 3.
tensors = [torch.randn(n + 1) for n in range(3)]
# Length of every sequence, collected into a tensor.
lengths = torch.tensor([len(t) for t in tensors])
tensors

In [None]:
# Pad all sequences to a common length; the batch dimension comes first.
batch = pad_sequence(tensors, batch_first=True)

In [None]:
# NOTE(review): this import rebinds `aggregate_and`, replacing any definition made earlier in the kernel.
from tsdm.utils.data import aggregate_and

# Random boolean mask of shape (3, 4, 5).
m = torch.randn(3, 4, 5) > 0.1
# Same reduction, this time passing an empty tuple for `dim`.
aggregate_and(m, dim=())

In [None]:
@jit.script
def torch_being_dumb(im: Union[None, int, list[int], tuple] = None) -> int:
    r"""Return ``len(im)``; scratch probe of how TorchScript treats a ``Union`` argument.

    NOTE(review): ``len`` is applied without narrowing ``im`` although the annotation
    also admits ``None`` and ``int`` -- presumably this cell only exists to reproduce
    a compiler complaint; confirm before relying on it.
    """
    return len(im)

In [None]:
# Call the probe with an empty tuple.
torch_being_dumb(())

In [None]:
# Plain-Python reference: the length of an empty tuple.
len(())

In [None]:
# A bare `raise` outside an `except` block fails with a RuntimeError,
# which halts a "Run All" at this cell.
raise

In [None]:
# NOTE(review): `unpad_sequence` is only defined/imported in cells further down.
unpad_sequence(batch, batch_first=True)

In [None]:
# NOTE(review): `batch_pad_packed` is not assigned at notebook level in any visible cell.
[x[:l] for x, l in zip(batch_pad_packed, lengths)]

In [None]:
batch_pad_packed

In [None]:
# Interactive help lookup (IPython `?` syntax).
?unpack_sequence

In [None]:
# Six sequences of lengths 3, 2, 1, 0, 1, 2 with three features each;
# the first time step of every non-empty sequence is overwritten with NaN.
tensors = []
for n in range(6):
    sequence = torch.randn(abs(n - 3), 3)
    if sequence.shape[0] > 0:
        sequence[0] = float("nan")
    tensors.append(sequence)
tensors

In [None]:
# Pad to a common length, using NaN as the fill value.
padded_seq = pad_sequence(tensors, batch_first=True, padding_value=float("nan"))
# Swap the last two axes for display.
padded_seq.swapaxes(-1, -2)

In [None]:
@torch.jit.script
def unpad_sequence(
    padded_seq: Tensor,
    batch_first: bool = False,
    lengths: Optional[Tensor] = None,
    padding_value: float = 0.0,
) -> list[Tensor]:
    r"""Reverse operation of `torch.nn.utils.rnn.pad_sequence`."""
    padded_seq: Tensor = padded_seq.swapaxes(0, 1) if not batch_first else padded_seq
    padding: Tensor = torch.tensor(
        padding_value, dtype=padded_seq.dtype, device=padded_seq.device
    )

    if lengths is not None:
        return [x[0:l] for x, l in zip(padded_seq, lengths)]

    # infer lengths from mask
    if torch.isnan(padding):
        mask = torch.isnan(padded_seq)
    else:
        mask = padded_seq == padding_value

    # all features are masked
    dims = list(range(2, padded_seq.ndim))
    agg = aggregate_and(mask, dim=dims)
    # count, starting from the back, until the first observation occurs.
    inferred_lengths = (~cumulative_and(agg.flip(dims=(1,)), dim=1)).sum(dim=1)

    return [x[0:l] for x, l in zip(padded_seq, inferred_lengths)]

In [None]:
from torch.nn.utils.rnn import pack_sequence, pad_sequence

from tsdm.utils.data import unpack_sequence, unpad_sequence

In [None]:
# Six sequences of lengths 4, 3, 2, 1, 2, 3 with three features each.
tensors = [torch.randn(1 + abs(n - 3), 3) for n in range(6)]

In [None]:
# enforce_sorted=False accepts sequences that are not ordered by decreasing length.
packed = pack_sequence(tensors, enforce_sorted=False)
# Round trip back to a list of tensors.
unpacked = unpack_sequence(packed)

In [None]:
def unpack_sequence(batch: PackedSequence) -> list[Tensor]:
    r"""Reverse operation of pack_sequence.

    Args:
        batch: Packed batch of variable-length sequences.

    Returns:
        The individual sequences, each cut back to its own length.
    """
    # With batch_first=True the rows of `padded` are the individual sequences,
    # so no axis needs to be swapped before slicing.
    padded, lengths = pad_packed_sequence(batch, batch_first=True)
    return [seq[:n] for seq, n in zip(padded, lengths.tolist())]

In [None]:
tensors

In [None]:
# No lengths are passed here; NaN is given as the padding value.
unpad_sequence(padded_seq, batch_first=True, padding_value=float("nan"))

In [None]:
# Random boolean mask of shape (4, 5, 6).
b = torch.randn(4, 5, 6) > 0.1
# Built-in reduction over several dimensions at once, for comparison.
b.all(dim=(1, 2))

In [None]:
import jax.numpy as jnp

# NOTE(review): exploratory attribute lookup on `jnp.add` -- confirm that the
# attribute exists in the installed JAX version.
jnp.add.aggregate

In [None]:
# Collect the attribute names that `jnp.add` exposes.
set(dir(jnp.add))

In [None]:
from torch import BoolTensor

In [None]:
@jit.script
def get_longest(x: Tensor, value: Tensor, reverse: bool = False) -> Tensor:
    """Take entries along the first axis for as long as they equal ``value``.

    Args:
        x: Tensor that is scanned along its first dimension.
        value: Single-element tensor to compare against; ``NaN`` is matched
            with ``torch.isnan``. An entry matches only if all of its elements do.
        reverse: Scan from the back instead of from the front.

    Returns:
        The leading run of matching entries, or the trailing run (in original
        order) if ``reverse`` is set. Empty if the first scanned entry differs.
    """
    y = torch.flip(x, dims=[0]) if reverse else x

    if bool(torch.isnan(value)):
        hits = torch.isnan(y)
    else:
        hits = y == value

    # An entry matches only if every one of its elements matches.
    while hits.dim() > 1:
        hits = hits.all(dim=-1)

    # The running product stays 1 exactly on the initial run of matches, so the
    # scan can neither run past the end nor fail on an empty input.
    count = int(torch.cumprod(hits.to(torch.int64), dim=0).sum())

    if reverse:
        # The run was counted from the back, so return the matching tail.
        return x[x.size(0) - count :]
    return x[:count]

In [None]:
@jit.script
def aggregate_and(
    x: BoolTensor,
    dim: Union[None, int, list[int]] = None,
    keepdim: bool = False,
) -> BoolTensor:
    r"""Compute logical ``AND`` across dim.

    Args:
        x: Tensor to reduce.
        dim: ``None`` reduces over every dimension, an ``int`` over a single one
            and a list over all listed dimensions. Negative indices count from
            the back and the order of a list is irrelevant; entries are expected
            to be unique.
        keepdim: If ``True``, every reduced dimension is kept with size one.

    Returns:
        The result of applying ``torch.all`` over the selected dimensions.
    """
    if dim is None:
        dims = list(range(x.dim()))
    elif isinstance(dim, int):
        dims = [dim]
    else:
        dims = dim

    if keepdim:
        # Axes keep their position, so every index can be reduced as given.
        for d in dims:
            x = torch.all(x, dim=d, keepdim=True)
        return x

    # Without keepdim each reduction removes an axis and shifts the axes behind it.
    # Resolve negative indices first, then reduce in ascending order so that
    # subtracting the number of already removed axes yields the right position.
    rank = x.dim()
    resolved = torch.jit.annotate(List[int], [])
    for d in dims:
        if d < 0:
            resolved.append(d + rank)
        else:
            resolved.append(d)

    ordered = sorted(resolved)
    for i, d in enumerate(ordered):
        x = torch.all(x, dim=d - i, keepdim=False)

    return x

In [None]:
# NaN mask of the padded batch.
x = torch.isnan(padded_seq)
# With `dim` left at its default, the reduction covers every dimension.
aggregate_and(x)