In [None]:
import numpy as np

# Random bipartite user-item interaction matrix (0/1 entries).
U = 1000  # number of users (rows)
I = 100  # number of items (columns)
rng = np.random.default_rng(0)  # seeded so Restart & Run All is reproducible
R = rng.integers(0, 2, size=(U, I))
dU = R.sum(axis=1)  # user degrees, shape (U,)
dI = R.sum(axis=0)  # item degrees, shape (I,)

# Diagonals of D_U^{-1/2} and D_I^{-1/2} kept as vectors. Nodes with degree 0
# get weight 0 instead of producing inf/nan.
dU_invsqrt = np.divide(1.0, np.sqrt(dU), out=np.zeros(U), where=dU > 0)
dI_invsqrt = np.divide(1.0, np.sqrt(dI), out=np.zeros(I), where=dI > 0)

# R_tilde = D_U^{-1/2} R D_I^{-1/2}, computed by broadcasting instead of
# multiplying with dense (U x U) and (I x I) diagonal matrices.
R_tilde = dU_invsqrt[:, None] * R * dI_invsqrt[None, :]

# Only the singular values are needed, so skip computing U and Vh.
singular_values = np.linalg.svd(R_tilde, compute_uv=False)
singular_values

In [None]:
DI.shape

In [None]:
import numpy as np

In [None]:
x0 = y0[..., 0]
v0 = y0[..., 1]

# mod out period
period = 4 / abs(v0)
t = np.mod.outer(t, period)
t = np.moveaxis(t, 0, -1)  # move time axis to back

next_wall = np.sign(v0)  # the next wall the ball will hit
time_to_wall = (next_wall - x0) / v0

In [None]:
np.moveaxis(t, 0, -1).shape

In [None]:
next_wall = np.sign(v0)  # the next wall the ball will hit
time_to_wall = (next_wall - x0) / v0

In [None]:
t.shape

In [None]:
v0 * t

In [None]:
t.shape

In [None]:
def solve_ivp(t, *, y0: np.ndarray) -> np.ndarray:
    """Solve the initial value problem of a ball bouncing between walls at ±1.

    The ball starts at ``x0 = y0[..., 0]`` with velocity ``v0 = y0[..., 1]``,
    moves at constant speed and is reflected elastically at ``x = -1`` and
    ``x = +1``. The motion is periodic with period ``4 / |v0|``.

    Signature: ``[(N,), (..., 2)] -> (..., N)``

    Args:
        t: Time points, shape ``(N,)``. Negative times are allowed.
        y0: Initial states ``[x0, v0]``, shape ``(..., 2)``, with ``|x0| <= 1``.

    Returns:
        Positions, shape ``(..., N)``, all within ``[-1, +1]``. Balls with
        ``v0 == 0`` stay at ``x0``.

    Raises:
        ValueError: If any initial position ``x0`` lies outside ``[-1, +1]``.
    """
    y0 = np.asarray(y0)
    x0 = y0[..., 0]
    v0 = y0[..., 1]

    if np.any(np.abs(x0) > 1):
        raise ValueError("initial positions y0[..., 0] must lie in [-1, +1]")

    # A ball at rest would cause divisions by zero below; compute with a dummy
    # velocity and overwrite its trajectory with the constant x0 at the end.
    at_rest = v0 == 0
    v = np.where(at_rest, 1.0, v0)

    # mod out period (a full round trip covers distance 4 at speed |v|)
    period = 4 / np.abs(v)
    half_period = 2 / np.abs(v)
    tau = np.mod.outer(t, period)  # shape (N, ...), values in [0, period)

    next_wall = np.sign(v)  # the next wall the ball will hit
    t1 = (next_wall - x0) / v  # time of hitting next_wall, in [0, half_period]
    t2 = t1 + half_period  # time of hitting the opposite wall

    x = np.select(
        [
            tau <= t1,
            (tau > t1) & (tau <= t2),
            tau > t2,
        ],
        [
            x0 + v * tau,  # moving towards next_wall
            next_wall - v * (tau - t1),  # reflected at next_wall
            -next_wall + v * (tau - t2),  # reflected at the opposite wall
        ],
    )
    x = np.where(at_rest, x0, x)
    # Remove floating-point round-off such as 1 + 2e-16 right at a wall; a
    # strict bounds assert could otherwise fail spuriously on such values.
    x = np.clip(x, -1.0, +1.0)

    # move time axis to the back
    return np.moveaxis(x, 0, -1)

In [None]:
x0 = y0[..., 0]
v0 = y0[..., 1]

# mod out period
period = 4 / abs(v0)
half_period = 2 / abs(v0)
t = np.mod.outer(t, period)

next_wall = np.sign(v0)  # the next wall the ball will hit

t1 = (next_wall - x0) / v0
t2 = t1 + half_period
t3 = t2 + half_period

x = np.select(
    [
        t <= t1,
        (t > t1) & (t <= t2),
        t > t2,
    ],
    [
        x0 + v0 * t,
        next_wall - v0 * (t - t1),
        -next_wall + v0 * (t - t2),
    ],
)
assert x.min() >= -1 and x.max() <= +1

In [None]:
t2.min(), t2.max()

In [None]:
abs(t2) < half_period

In [None]:
v0 * half_period

In [None]:
x.min()

In [None]:
(v0 * half_period).max()

In [None]:
x.max()

In [None]:
from scipy.stats import truncnorm

In [None]:
truncnorm

In [None]:
y0

In [None]:
import re

In [None]:
# Matches (optional leading whitespace +) the opening triple quotes of a
# docstring that directly follows a line ending in ":" or an empty line.
# Raw string: "\s" inside a normal string literal is an invalid escape
# sequence (SyntaxWarning since Python 3.12). The regex engine interprets
# "\n" itself, so the pattern is unchanged.
pattern = re.compile(r'(?<=[:\n]\n)\s*"""')

# NOTE: `file` is assigned in the following cell; run that one first.
pattern.findall(file)

In [None]:
file = r'''
#!/usr/bin/env python
"""Check whether attributes in annotations shadow directly imported symbols.

Example:
    >>> import collections.abc as abc
    >>> from collections.abc import Sequence
    >>>
    >>> def foo(x: abc.Sequence) -> abc.Sequence:
    >>>     return x

    Would raise an error because `pd.DataFrame` shadows directly imported `DataFrame`.
"""

__all__ = [
    "get_pure_attributes",
    "get_full_attribute_parent",
    "get_imported_symbols",
    "get_imported_attributes",
    "check_file",
    "main",
]


import argparse
import ast
import logging
import sys
from ast import AST, Attribute, Name
from collections.abc import Iterator
from pathlib import Path
from typing import TypeGuard

from assorted_hooks.utils import get_python_files

__logger__ = logging.getLogger(__name__)


def is_pure_attribute(node: AST, /) -> TypeGuard[Attribute]:
    """Check whether a node is a pure attribute."""
    return isinstance(node, Attribute) and (
        isinstance(node.value, Name) or is_pure_attribute(node.value)
    )


def get_pure_attributes(tree: AST, /) -> Iterator[Attribute]:
    """Get all nodes that consist only of attributes."""
    for node in ast.walk(tree):
        if is_pure_attribute(node):
            yield node


def get_full_attribute_parent(node: Attribute | Name, /) -> tuple[Name, str]:
    """Get the parent of an attribute node."""
    if isinstance(node, Attribute):
        if not isinstance(node.value, Attribute | Name):
            raise ValueError(
                f"Expected Attribute or Name, got {type(node.value)} {vars(node.value)=}"
            )
        parent, string = get_full_attribute_parent(node.value)
        return parent, f"{string}.{node.attr}"

    if not isinstance(node, Name):
        raise ValueError(f"Expected ast.Name, got {type(node)=}  {vars(node.value)=}")

    return node, node.id


def get_imported_symbols(tree: AST, /) -> dict[str, str]:
    """Get all imported symbols."""
    imported_symbols = {}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported_symbols[alias.asname or alias.name] = alias.name
        elif isinstance(node, ast.ImportFrom):
            module_name = node.module
            if module_name is not None:
                for alias in node.names:
                    full_name = f"{module_name}.{alias.name}"
                    imported_symbols[alias.asname or alias.name] = full_name

    return imported_symbols


def get_imported_attributes(tree: AST, /) -> Iterator[tuple[Attribute, Name, str]]:
    """Finds attributes that can be replaced by directly imported symbols."""
    imported_symbols = get_imported_symbols(tree)

    for node in get_pure_attributes(tree):
        if node.attr in imported_symbols:
            # parent = get_full_attribute_string(node)
            parent, string = get_full_attribute_parent(node)

            head, tail = string.split(".", maxsplit=1)
            assert head == parent.id

            # e.g. DataFrame -> pandas.DataFrame
            matched_symbol = imported_symbols[node.attr]
            is_match = matched_symbol == string

            # need to check if parent is imported as well to catch pd.DataFrame
            if parent.id in imported_symbols:
                parent_alias = imported_symbols[parent.id]  # e.g. pd -> pandas
                is_match |= matched_symbol == f"{parent_alias}.{tail}"

            if is_match:
                yield node, parent, string


def check_file(file_path: Path, /, *, debug: bool = False) -> bool:
    """Finds shadowed attributes in a file."""
    passed = True

    # Your code here
    with open(file_path, "r", encoding="utf8") as file:
        tree = ast.parse(file.read())

    # find all violations
    for node, _, string in get_imported_attributes(tree):
        passed = False
        print(
            f"{file_path!s}:{node.lineno!s}"
            f" use directly imported {node.attr!r} instead of {string!r}"
        )

    if not passed and debug:
        imported_symbols = get_imported_symbols(tree)
        pad = " " * 4
        max_key_len = max(map(len, imported_symbols), default=0)
        print(pad, "Imported symbols:")
        for key, value in imported_symbols.items():
            print(2 * pad, f"{key:{max_key_len}} -> {value}")

    return passed


def main() -> None:
    """Main function."""
    parser = argparse.ArgumentParser(
        description="Checks that Bar is used instead of foo.Bar if both foo and Bar are imported.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "files",
        nargs="+",
        type=str,
        help="One or multiple files, folders or file patterns.",
    )
    parser.add_argument(
        "--debug",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="Print debug information.",
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
        __logger__.debug("args: %s", vars(args))

    # find all files
    files: list[Path] = get_python_files(args.files)

    # apply script to all files
    passed = True
    for file in files:
        __logger__.debug('Checking "%s:0"', file)
        try:
            passed &= check_file(file, debug=args.debug)
        except Exception as exc:
            raise RuntimeError(f"{file!s}: Checking file failed!") from exc

    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()
'''

In [None]:
truncnorm.rvs(-1, +1, loc=y0, scale=0.1)

In [None]:
# Simulate a few random bouncing-ball trajectories and plot all of them.
t = np.linspace(-10, 20, 2000)
rng = np.random.default_rng(42)  # seeded so the figure is reproducible
y0 = rng.uniform(-0.5, 0.7, size=(5, 2))  # rows are initial states [x0, v0]
x = solve_ivp(t, y0=y0)  # shape (5, len(t))

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for i, trajectory in enumerate(x):
    ax.plot(t, trajectory, label=f"ball {i}")
ax.set(xlabel="time t", ylabel="position x", title="Bouncing ball trajectories")
ax.legend();

In [None]:
x.shape

In [None]:
t.shape

In [None]:
import argparse
import os
from pathlib import Path

import numpy as np
from tqdm.auto import tqdm


def generate_sequence(low=-1, high=1.0, vel=None, num_steps=300):
    y = np.random.uniform(low=low, high=high)
    if vel is None:
        vel = np.random.uniform(low=0.05, high=0.5) * np.random.choice([-1, 1])
    noise_scale = 0.05
    points = [y + noise_scale * np.random.randn(1)]
    step_size = 0.1
    for i in range(num_steps - 1):
        y = y + vel * step_size
        points.append(y + noise_scale * np.random.randn(1))
        if y <= low or y >= high:
            vel = -vel
    return np.stack(points)


def generate_sequences(num_samples, vels=None, num_steps=300):
    """Simulate ``num_samples`` independent noisy bouncing-ball trajectories.

    Args:
        num_samples: Number of trajectories to generate.
        vels: Optional collection of velocities. If given, each trajectory uses
            one of them, drawn with ``np.random.choice``; otherwise
            ``generate_sequence`` draws a random velocity.
        num_steps: Number of time steps per trajectory.

    Returns:
        The trajectories from ``generate_sequence`` stacked along a new first
        axis.
    """

    def one_sequence():
        velocity = None if vels is None else np.random.choice(vels)
        return generate_sequence(vel=velocity, num_steps=num_steps)

    return np.stack([one_sequence() for _ in tqdm(range(num_samples))])


def generate_dataset(
    seed=42,
    vels=None,
    dataset_path=None,
    n_train=5000,
    n_val=500,
    n_test=500,
    n_timesteps=300,
    file_prefix="",
):
    """Generate train/val/test bouncing-ball splits and save them as ``.npz`` files.

    The global NumPy RNG is seeded with ``seed``. The splits are generated in
    the order train, val, test, and all three are generated before anything is
    written. Each file stores its array under the key ``target``.

    Args:
        seed: Seed for ``np.random.seed``.
        vels: Optional fixed velocities, forwarded to ``generate_sequences``.
        dataset_path: Output directory (created if missing); defaults to
            ``"./bouncing_ball/"``.
        n_train, n_val, n_test: Number of trajectories per split.
        n_timesteps: Number of time steps per trajectory.
        file_prefix: String prepended to each file name.
    """
    out_dir = "./bouncing_ball/" if dataset_path is None else dataset_path
    os.makedirs(out_dir, exist_ok=True)

    np.random.seed(seed=seed)

    # Unpacking a generator evaluates it in order: train, then val, then test.
    train, val, test = (
        generate_sequences(count, vels=vels, num_steps=n_timesteps)
        for count in (n_train, n_val, n_test)
    )

    outputs = {
        f"{file_prefix}train.npz": train,
        f"{file_prefix}val.npz": val,
        f"{file_prefix}test.npz": test,
    }
    for filename, target in outputs.items():
        np.savez(os.path.join(out_dir, filename), target=target)


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--num_vels",
#         type=int,
#         default=0,
#         help="Number of fixed velocities, 0 indicates random velocity for every sample",
#         choices=[0, 1, 2, 5],
#     )
#     args = parser.parse_args()

#     data_path = str(Path(__file__).resolve().parent / "bouncing_ball")
#     print(f"Saving dataset to: {data_path}.")
#     vels = None
#     if args.num_vels == 0:
#         vels = None
#         print("Generating dataset with random velocities...")
#         generate_dataset(file_prefix="rv_", vels=vels, dataset_path=data_path)
#     elif args.num_vels == 1:
#         vels = [0.2]
#         print("Generating dataset with 1 velocity...")
#         generate_dataset(file_prefix="1fv_", vels=vels, dataset_path=data_path)
#     elif args.num_vels == 2:
#         vels = [0.2, 0.4]
#         print("Generating dataset with 2 velocities...")
#         generate_dataset(file_prefix="2fv_", vels=vels, dataset_path=data_path)
#     elif args.num_vels == 5:
#         vels = [0.1, 0.2, 0.3, 0.4, 0.5]
#         print("Generating dataset with 5 velocities...")
#         generate_dataset(file_prefix="5fv_", vels=vels, dataset_path=data_path)

In [None]:
sequences = generate_sequences(5000, None, 300)

In [None]:
sequences[0].max()

In [None]:
class Foo: ...

In [None]:
Foo.__class__ = int

In [None]:
import torch
from torch import jit, nn

In [None]:
slc = slice(1, 10, 0)

In [None]:
[1:3] = [ , , ]

In [None]:
import numpy as np

In [None]:
np.random.uniform()

In [None]:
gen = np.random.default_rng()

In [None]:
dir(gen)

In [None]:
from abc import ABC, ABCMeta, abstractmethod
from typing import Protocol


class Foo:
    """Experiment: ``@abstractmethod`` on a class that does not use ``ABCMeta``.

    ``Foo`` neither inherits from ``ABC`` nor sets ``metaclass=ABCMeta``, so the
    decorator below is not enforced and ``Foo()`` can still be instantiated.
    """

    @abstractmethod
    def bar(self) -> str:
        """Print the runtime class of ``self``.

        NOTE(review): annotated ``-> str`` but the body returns ``None``.
        """
        print(self.__class__)

In [None]:
issubclass(Protocol, ABCMeta)

In [None]:
stats.norm.rvs(size=())

In [None]:
class Foo:
    """Empty placeholder class (the original ``class Foo()`` lacked a colon and body)."""

In [None]:
class Bar(Foo): ...

In [None]:
Bar()

In [None]:
l = list(range(10))
l[len(l) + 1] = -1

In [None]:
from scipy import stats

In [None]:
stats.norm(0, 1).stats(moments="mvsk")

In [None]:
from typing_extensions import is_protocol

In [None]:
import typing
from collections import abc

In [None]:
is_protocol(abc.Sized)

In [None]:
l = list(range(10))
l[len(l) + 7 :] = [-1]  # works

In [None]:
l

In [None]:
items = [0, 1, 2, 3]
items[-17:2], items[-17:0], items[-17:-1]

In [None]:
items = [0, 1, 2, 3]
items[-17:-16] = [-1]
items

In [None]:
r = range(slc.start, slc.stop, 1)

In [None]:
l = list(range(10))
l[1:3] = [-1, -1, -1]  # lhs: len=2, rhs: len=3
print(l)  # [0, -1, -1, -1, 3, 4, 5, 6, 7, 8, 9]
len(l)

In [None]:
l = list(range(10))
l[1:3] = [-1, -1, -1]  # lhs: len=2, rhs: len=3

In [None]:
l = list(range(10))
l[1:3:2] = [-1, -1, -1]  # errors

In [None]:
l = list(range(10))
l[1:5:2] = [-1, -1, -1]  # errors

In [None]:
l = list(range(10))
l[1:6:2] = [-1, -1, -1]  # ✔

In [None]:
l[1:3]

In [None]:
list(r)

In [None]:
list(r)

In [None]:
from collections.abc import Sequence

In [None]:
set(dir(Sequence)) - set(dir(m))

In [None]:
m = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

In [None]:
jit.script(m).extend

In [None]:
import tsdm

In [None]:
import logging
from typing import NamedTuple

import numpy
import pyarrow as pa
import torch
from numpy import ndarray
from numpy.typing import NDArray
from pandas import DataFrame, Index, Series
from torch import Tensor
from typing_extensions import get_args, get_protocol_members, get_type_hints

from tsdm.types.protocols import Array, NTuple, SupportsShape

In [None]:
get_protocol_members(Array)

In [None]:
import dataclasses
from collections.abc import Iterator, Mapping, Sequence
from typing import (
    Any,
    NamedTuple,
    Protocol,
    TypeGuard,
    TypeVar,
    get_type_hints,
    overload,
    runtime_checkable,
)

from typing_extensions import Self, SupportsIndex, get_original_bases

from tsdm.types.variables import any_co as T_co
from tsdm.types.variables import key_contra, scalar_co, value_co

In [None]:
class Array(Protocol[scalar_co]):
    r"""Structural protocol for array-like objects (tensors with a single data type).

    Generic in ``scalar_co``, the type returned by ``dtype``. Requires the
    properties ``ndim``, ``dtype`` and ``shape`` plus ``__len__``,
    ``__getitem__`` and ``__iter__``.

    Expected to match

    - `numpy.ndarray`
    - `torch.Tensor`
    - `tensorflow.Tensor`
    - `jax.numpy.ndarray`
    - `pandas.Series`

    Expected not to match

    - `pandas.DataFrame`
    - `pyarrow.Table`
    - `pyarrow.Array`

    NOTE(review): the class is not decorated with ``@runtime_checkable``, so
    ``isinstance`` checks against it raise ``TypeError``. The lists above refer
    to static type checking.
    """

    @property
    def ndim(self) -> int:
        r"""Number of dimensions."""
        ...

    @property
    def dtype(self) -> scalar_co:
        r"""Data type of the array's elements."""
        ...

    @property
    def shape(self) -> Sequence[int]:
        """Size of the array along each dimension."""
        ...

    def __len__(self) -> int:
        """Number of elements along the first axis."""
        ...

    def __getitem__(self, key: Any) -> Self:
        """Return an element/slice of the array."""
        ...

    def __iter__(self) -> Iterator[Self]:
        """Iterate over the first dimension."""

In [None]:
import inspect

In [None]:
a = dict(inspect.getmembers(Protocol))
# print(a)
b = dict(inspect.getmembers(Array))
# print(b)

for k, v in b.items():
    if k not in a or v != a[k]:
        print("Not inherited:", k, v)

In [None]:
from typing_extensions import get_protocol_members

In [None]:
get_protocol_members(Array)

In [None]:
{key for key in Array.__dict__ if key not in Protocol.__dict__}

In [None]:
Array.__dict__

In [None]:
set(dir(Array)) - set(dir(Protocol)) - {"__orig_bases__", "__weakref__", "__dict__"}

In [None]:
dict(inspect.getmembers(Array))

In [None]:
Array.__parameters__

In [None]:
Protocol.__dict__

In [None]:
set(dir(Array)) - set(dir(Protocol))

In [None]:
dir(Array)

In [None]:
data = [1, 2, 3]
# One instance of each array-like type under comparison.
# `numpy.ndarray(data)` would allocate an *uninitialized* array of shape
# (1, 2, 3); `numpy.array` is the constructor that wraps the given values.
arrays = {
    "torch_tensor": torch.tensor(data),
    "numpy_ndarray": numpy.array(data),
    "pandas_series": Series(data),
    "pyarrow_array": pa.array(data),
}
# Attributes available on every one of the objects above.
shared_attrs = set.intersection(*(set(dir(arr)) for arr in arrays.values()))
# Attributes exposed by the Array protocol class.
array_attrs = set(dir(Array))
array_attrs = set(dir(Array))

In [None]:
missing_attrs = array_attrs - shared_attrs

In [None]:
missing_attrs

In [None]:
get_type_hints(Array)

In [None]:
Array.__slots__

In [None]:
get_type_hints(Array)

In [None]:
dir(Array)