In [None]:
# Display the value of the last expression *or* assignment of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension and use mode 2, which re-imports modules
# before code is executed, so edits to imported packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import xarray

import numpy as np
from zipfile import ZipFile

import linodenet
import matplotlib.pyplot as plt
import torch

from typing import Sized
from tqdm import tqdm, trange
from pathlib import Path
import pandas
from pandas import DataFrame

In [None]:
# Read the HDF5 file located in `data/` under the parent of the current
# working directory, then discard the "unit" column.
df = pandas.read_hdf(Path.cwd().parent / "data/some_data_from_db.h5").drop(
    columns="unit"
)

In [None]:
# Pick the entry labelled 8627, flatten its remaining index into columns and
# use the measurement time as the new index.
data = (
    df.loc[8627]
    .reset_index()
    .set_index("measurement_time")
)

In [None]:
from typing import Union, Optional, Callable, Iterator
from torch.utils.data import Sampler
from numpy.random import Generator


# TODO: add exclusive_args decorator
class SequentialSliceSampler(Sampler):
    r"""Endlessly sample random contiguous windows (slices) from a dataset.

    Default modus operandi:

    - Use fixed window size
    - Sample starting index uniformly from [0:-window]

    Should you want to sample windows of varying size, you may supply a
    callable as ``slice_sampler``; it is invoked once per draw and must return
    the window size. To control both the window size and the starting index,
    supply ``sampler`` instead.
    """

    def __init__(
        self,
        data,
        slice_sampler: Optional[Union[int, Callable[[], int]]] = None,
        sampler: Optional[Callable[[], tuple[int, int]]] = None,
    ):
        r"""Initialize the sampler.

        Parameters
        ----------
        data
            Dataset to draw windows from. It must support ``len`` and
            positional slicing via ``.iloc``.
        slice_sampler
            Either a fixed window size or a zero-argument callable returning
            the window size for the next draw.
        sampler
            Optional zero-argument callable returning the pair
            ``(window_size, start_index)``. When given, it replaces the default
            sampling strategy (and ``slice_sampler`` is then not consulted).

        Raises
        ------
        ValueError
            If neither ``slice_sampler`` nor ``sampler`` is provided.
        """
        super().__init__(data)
        if slice_sampler is None and sampler is None:
            raise ValueError("Provide either `slice_sampler` or `sampler`.")

        self.data = data
        self.rng = np.random.default_rng()
        self.idx = np.arange(len(data))
        # Normalise a fixed window size into a callable so both cases are
        # handled uniformly below.
        self.slice_sampler = (
            slice_sampler if callable(slice_sampler) else (lambda: slice_sampler)
        )

        def default_sampler() -> tuple[int, int]:
            window_size = self.slice_sampler()
            # NOTE: `idx[:-window_size]` stops one short of the last admissible
            # start (len(data) - window_size) and is empty when window_size is
            # 0 or >= len(data), in which case `rng.choice` raises.
            start_index = self.rng.choice(self.idx[:-window_size])
            return window_size, start_index

        # Honour a user-supplied sampler; previously the argument was shadowed
        # by the nested function and silently ignored.
        self.sampler = default_sampler if sampler is None else sampler

    def __iter__(self) -> Iterator:
        r"""Yield randomly positioned windows of ``data`` without ever stopping."""
        while True:
            # sample len and index
            window_size, start_index = self.sampler()
            # return slice
            yield self.data.iloc[start_index : start_index + window_size]

In [None]:
from tsdm.utils.dtypes import TimeDeltaLike, TimeStampLike

In [None]:
# TODO: add exclusive_args decorator
class TimeSliceSampler(Sampler):
    r"""Sample by time.

    Default modus operandi:

    - Use fixed window size
    - Sample starting index uniformly from [0:-window]

    Should you want to sample windows of varying size, you may supply a
    callable as ``slice_sampler``; it is invoked once per draw and must return
    the window size. To control both the window size and the starting index,
    supply ``sampler`` instead.
    """

    def __init__(
        self,
        data_source: Optional[Sized],
        slice_sampler: Optional[
            Union[TimeDeltaLike, Callable[[], TimeDeltaLike]]
        ] = None,
        sampler: Optional[Callable[[], tuple[TimeDeltaLike, TimeStampLike]]] = None,
    ):
        """Initialize Sampler.

        Parameters
        ----------
        data_source
            Dataset to draw windows from; it is wrapped in a ``DataFrame`` and
            its length determines the candidate start positions.
        slice_sampler
            Either a fixed window size or a zero-argument callable returning
            the window size for the next draw.
        sampler
            Optional zero-argument callable returning the pair
            ``(window_size, start_index)``. When given, it replaces the default
            sampling strategy (and ``slice_sampler`` is then not consulted).

        Raises
        ------
        ValueError
            If neither ``slice_sampler`` nor ``sampler`` is provided.
        """
        super().__init__(data_source)
        if slice_sampler is None and sampler is None:
            raise ValueError("Provide either `slice_sampler` or `sampler`.")

        self.data = DataFrame(data_source)
        self.idx = np.arange(len(data_source))

        self.rng = np.random.default_rng()
        # Normalise a fixed window size into a callable so both cases are
        # handled uniformly below.
        self.slice_sampler = (
            slice_sampler if callable(slice_sampler) else (lambda: slice_sampler)
        )

        def default_sampler() -> tuple[int, int]:
            window_size = self.slice_sampler()
            # NOTE(review): this slices an integer position array with
            # `-window_size`, so the default strategy appears to require an
            # integer window despite the TimeDeltaLike annotation - confirm.
            start_index = self.rng.choice(self.idx[:-window_size])
            return window_size, start_index

        # Honour a user-supplied sampler; previously the argument was shadowed
        # by the nested function and silently ignored.
        self.sampler = default_sampler if sampler is None else sampler

    def __iter__(self) -> Iterator:
        r"""Yield random window from dataset.

        Yields
        ------
        The positional slice ``data.iloc[start_index : start_index + window_size]``
        for each sampled ``(window_size, start_index)`` pair; the iterator
        never terminates on its own.
        """
        while True:
            # sample len and index
            window_size, start_index = self.sampler()
            # return slice
            yield self.data.iloc[start_index : start_index + window_size]

In [None]:
# Keep the complete frame as `full` and a small positional selection of it
# (six rows, three columns) as `slc`.
full, slc = data, data.iloc[[0, 1, 2, 7, 99, 101], [1, 5, 4]]

In [None]:
# Build a sampler over `full` and hand both to a DataLoader
# (`torch.utils.data.DataLoader` is the public alias of the same class).
sampler = torch.utils.data.SequentialSampler(full)
dloader = torch.utils.data.DataLoader(full, sampler=sampler)

In [None]:
# Wrap the frame in a named DataArray, convert it back to a DataFrame and
# report the memory usage of the result.
(
    xarray.DataArray(full, dims=["obs time", "obs val"], name="run 1")
    .to_dataframe()
    .memory_usage()
)

In [None]:
# Bundle both frames as coordinates of a single Dataset.
xarray.Dataset(
    coords={
        label: xarray.DataArray(frame)
        for label, frame in (("full", full), ("slc", slc))
    }
)

In [None]:
fig, ax = plt.subplots()

# `measurement_time` was moved into the index when `data` was built, so it is
# no longer a column (`data["measurement_time"]` raises KeyError); take the
# x-values from the index instead.
for col in ("Acetate", "Acid", "Base", "DOT"):
    # Label every line directly so the legend cannot drift out of sync.
    ax.plot(data.index[6:-100], data[col].iloc[6:-100], label=col)

ax.set_xlabel("measurement_time")
ax.legend()

In [None]:
# Convert the frame's underlying array into a torch tensor.
# NOTE(review): assumes every column of `data` is numeric - TODO confirm.
ds = torch.tensor(data.values)

In [None]:
# `BatchSampler` / `SequentialSampler` are never imported in this notebook, so
# reference them through the already-imported `torch` package.
sampler = torch.utils.data.BatchSampler(
    torch.utils.data.SequentialSampler(ds), batch_size=32, drop_last=True
)
from torch import Tensor


def collate_list(batch: list[Tensor]) -> list[Tensor]:
    r"""Collates list of tensors as list of tensors.

    Parameters
    ----------
    batch
        The list of tensors to collate.

    Returns
    -------
    list[Tensor]
        The very same list object, unchanged; no stacking or copying is done.
    """
    return batch

In [None]:
# `DataLoader` / `TensorDataset` are never imported in this notebook, so
# reference them through the already-imported `torch` package.
dloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(ds), batch_sampler=sampler
)

# Pull the first two batches for inspection.
idloader = iter(dloader)
next(idloader), next(idloader)

In [None]:
# `SequentialSampler` is never imported in this notebook; use the qualified
# name. `list` already iterates its argument, so no explicit `iter` is needed.
list(torch.utils.data.SequentialSampler(ds))

In [None]:
# Only `from tsdm.utils.dtypes import ...` appears earlier, which does not bind
# the name `tsdm`; import the package itself before inspecting it.
import tsdm

dir(tsdm)

In [None]:
import numpy

In [None]:
# List the attribute names of the numpy module.
dir(numpy)