In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
import xarray
from torch import Tensor
from torch.nn.utils.rnn import (
    PackedSequence,
    pack_sequence,
    pad_packed_sequence,
    pad_sequence,
)
from torch.utils.data import DataLoader, Dataset

In [None]:
import datetime
from typing import Union

import numpy as np
import pandas

# Union of the types treated as "timestamp-like": a string, a stdlib datetime,
# a NumPy datetime64 or a pandas Timestamp.
TimeStampLike = Union[str, datetime.datetime, np.datetime64, pandas.Timestamp]
# Union of the types treated as "timedelta-like": a string, a stdlib timedelta,
# a NumPy timedelta64 or a pandas Timedelta.
TimeDeltaLike = Union[str, datetime.timedelta, np.timedelta64, pandas.Timedelta]

In [None]:
import tsdm
from tsdm.datasets import InSilicoData

In [None]:
# Try out the timestamp sampler shipped with tsdm.
# NOTE(review): presumably returns 10 timestamps at a 15-second frequency —
# confirm against the tsdm documentation.
tsdm.random.sample_timestamps(size=10, freq="15s")

In [None]:
class VaryingLengthDataset(Dataset):
    """Toy dataset of integer sequences whose lengths differ from sample to sample.

    Every sample is an integer tensor of shape ``(length, 2)`` with entries in
    ``{0, ..., 9}``; ``length`` is drawn from ``[dmin, dmax)``.

    Parameters
    ----------
    n : int
        Number of samples to generate.
    dmin : int
        Smallest possible sequence length (inclusive).
    dmax : int
        Upper bound on the sequence length (exclusive).
    seed : int, optional
        When given, lengths and values are drawn from private generators
        seeded with this value, so two datasets built with the same arguments
        are identical and the global RNG state is left untouched. When
        ``None`` (default) the global NumPy and torch generators are used.
    """

    def __init__(self, n, dmin, dmax, seed=None):
        super().__init__()
        self.size = n
        if seed is None:
            # Unseeded path: draw from the global NumPy / torch generators.
            self.data = [
                torch.randint(0, 10, (np.random.randint(dmin, dmax), 2))
                for _ in range(n)
            ]
        else:
            # Seeded path: private generators make the dataset reproducible
            # without touching global RNG state.
            length_rng = np.random.default_rng(seed)
            value_rng = torch.Generator().manual_seed(seed)
            self.data = [
                torch.randint(
                    0,
                    10,
                    (int(length_rng.integers(dmin, dmax)), 2),
                    generator=value_rng,
                )
                for _ in range(n)
            ]

    def __len__(self):
        """Return the number of samples."""
        return self.size

    def __getitem__(self, index):
        """Return the sample at ``index``; a slice yields a list of samples."""
        return self.data[index]

    def __repr__(self):
        """Return the concatenated ``repr`` of every sample's shape."""
        return "".join(repr(x.shape) for x in self.data)


def myfun():
    """Placeholder function: does nothing and returns ``None``."""
    return None

In [None]:
# Convert every entry of ``InSilicoData.dataset`` with ``to_xarray`` and tag it
# with its key as the ``run_id`` coordinate; the results are collected in ``dsets``.
dsets = []

for key, df in InSilicoData.dataset.items():
    ds = df.to_xarray()
    # Label this run so the entries can be told apart once they are combined.
    ds = ds.assign_coords(run_id=key)
    dsets.append(ds)

In [None]:
# Concatenate the per-run datasets along ``run_id`` and convert the result to a DataFrame.
xarray.concat(dsets, dim="run_id").to_dataframe()

In [None]:
# Inspect the list of per-run datasets.
dsets

In [None]:
# Alternative representation: wrap every entry in a DataArray named after its key.
arrays = [xarray.DataArray(df, name=key) for key, df in InSilicoData.dataset.items()]

In [None]:
# Inspect the first of the named arrays.
arrays[0]

In [None]:
# Combine the named arrays with ``xarray.merge``.
xarray.merge(arrays)

### Batching this dataset with the default `collate_fn` won't work, because the samples have different lengths:

In [None]:
# Toy dataset: 100 samples whose lengths vary between 3 and 6.
MyDataset = VaryingLengthDataset(100, 3, 7)

# No ``collate_fn`` is passed here, so the DataLoader's default collation is used.
# Any error it raises is caught and printed, because the error is the point of
# this demonstration.
try:
    # Bind the batch before printing it; a bare ``print(x)`` would refer to a
    # name this cell never defines.
    x = next(iter(DataLoader(MyDataset, batch_size=10)))
    print(x)
except Exception as E:
    print(E)

### Let's try with a custom `collate_fn`

In [None]:
# Example: 20 random indices into the dataset (drawn independently, so repeats are possible).
np.random.randint(0, len(MyDataset), 20)

In [None]:
# Assemble a batch by hand from 5 randomly chosen samples.
batch = [MyDataset[idx] for idx in np.random.randint(0, len(MyDataset), 5)]
batch

In [None]:
# Pad all samples to the length of the longest one; with batch_first=True the
# batch dimension comes first in the resulting shape.
pad_sequence(batch, batch_first=True).shape

In [None]:
def collate_list(batch):
    """Pass-through ``collate_fn``: return the list of samples unchanged.

    Lets a DataLoader yield batches as plain lists instead of stacking the
    samples into a single tensor.
    """
    samples = batch
    return samples


# Fetch one batch of 5 samples using the pass-through collate function and
# print it; any error is printed instead of raised.
try:
    x = next(iter(DataLoader(MyDataset, batch_size=5, collate_fn=collate_list)))
    print(x)
except Exception as E:
    print(E)

In [None]:
def collate_packed(batch):
    """Collate variable-length samples into a ``PackedSequence``.

    Parameters
    ----------
    batch : list of Tensor
        Samples whose first dimension is the sequence length.

    Returns
    -------
    PackedSequence
        The samples packed in order of decreasing length.

    Notes
    -----
    The samples are ordered in a new list, so the caller's ``batch`` keeps its
    original order.
    """
    # pack_sequence expects the sequences longest-first by default.
    longest_first = sorted(batch, key=torch.Tensor.__len__, reverse=True)
    return pack_sequence(longest_first)


# Fetch one batch of 5 samples collated with ``collate_packed`` and print it;
# any error is printed instead of raised.
try:
    x = next(iter(DataLoader(MyDataset, batch_size=5, collate_fn=collate_packed)))
    print(x)
except Exception as E:
    print(E)

In [None]:
# Draw one batch as a plain list of samples and show every sample transposed.
batch = next(iter(DataLoader(MyDataset, batch_size=5, collate_fn=collate_list)))
[x.T for x in batch]

In [None]:
# Longest-first copy of the batch; ``batch`` itself keeps its order.
batch_sorted = sorted(batch, key=torch.Tensor.__len__, reverse=True)
[x.T for x in batch_sorted]

In [None]:
# Pack the length-sorted batch, then show the packed data (transposed) together
# with the ``batch_sizes`` attribute of the packed sequence.
batch_packed = pack_sequence(batch_sorted)
batch_packed.data.T, batch_packed.batch_sizes

In [None]:
# Invert the packing: recover a padded, batch-first tensor plus the length of
# every sequence. Axes 1 and 2 are swapped only for display.
batch_pad_packed, lengths = pad_packed_sequence(batch_packed, batch_first=True)
torch.swapaxes(batch_pad_packed, 1, 2), lengths

In [None]:
# Strip the padding again: keep only the first ``n_steps`` rows of every padded
# sequence (transposed for display).
unpacked_batch = [
    seq[:n_steps].T for seq, n_steps in zip(batch_pad_packed, lengths)
]
unpacked_batch

In [None]:
def reconstruct_batch(packed: PackedSequence) -> list[Tensor]:
    """Recover the individual sequences stored in a ``PackedSequence``.

    Parameters
    ----------
    packed : PackedSequence
        The packed batch, e.g. the result of ``pack_sequence``.

    Returns
    -------
    list of Tensor
        One tensor per sequence with the padding removed. If ``packed`` was
        created with ``enforce_sorted=False`` the sequences come back in their
        original order, otherwise in the (length-sorted) order they were
        packed in.
    """
    # Unpack into a padded, batch-first tensor plus the true length of each row.
    padded, lengths = pad_packed_sequence(packed, batch_first=True)
    # Slice every row down to its true length to drop the padding.
    return [seq[:length] for seq, length in zip(padded, lengths.tolist())]

In [None]:
# For now the "reconstructed" batch is simply another name for the padded
# tensor, so the padding is still present.
reconstructed_batch = batch_pad_packed
reconstructed_batch

In [None]:
# Shape of the padded batch.
batch_pad_packed.shape

In [None]:
def unpack(packed):
    """Placeholder: accepts ``packed``, does nothing and returns ``None``."""
    return None

In [None]:
# Take the first five samples and order them by increasing length.
a = sorted(MyDataset[:5], key=len)
a