In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch import Tensor
from torch.nn.utils.rnn import pack_sequence, pad_sequence, pack_padded_sequence, pad_packed_sequence, PackedSequence

In [2]:
class VaryingLengthDataset(Dataset):
    """Toy dataset of random integer sequences whose lengths differ.

    Every item is a tensor of shape ``(length, 2)`` holding integers drawn
    from ``[0, 10)``. Each ``length`` is drawn independently with
    ``np.random.randint(dmin, dmax)``, i.e. ``dmin <= length < dmax``.

    Args:
        n: Number of sequences to generate.
        dmin: Smallest possible sequence length (inclusive).
        dmax: Upper bound on the sequence length (exclusive).
    """

    def __init__(self, n, dmin, dmax):
        super().__init__()
        self.size = n
        samples = []
        for _ in range(n):
            # Draw the length first, then the values, one item at a time.
            length = np.random.randint(dmin, dmax)
            samples.append(torch.randint(0, 10, (length, 2)))
        self.data = samples

    def __len__(self):
        """Return the number of sequences requested at construction."""
        return self.size

    def __getitem__(self, index):
        """Return whatever the underlying list yields for ``index``."""
        return self.data[index]

    def __repr__(self):
        """Return the shapes of all items, concatenated without separator."""
        shapes = (sample.shape for sample in self.data)
        return "".join(map(repr, shapes))

### Batching this Dataset with the default `collate_fn` won't work, because the items have different lengths:

In [3]:
MyDataset = VaryingLengthDataset(100, 3, 7)

# Demonstration: ask a DataLoader with its default collate function for one
# batch of ten variable-length items. This is expected to fail, so the
# exception is caught and its message printed instead of aborting the run.
try:
    # Bind the batch before printing; without the assignment `x` would be
    # undefined on a fresh kernel and mask the real outcome with a NameError.
    x = next(iter(DataLoader(MyDataset, batch_size=10)))
    print(x)
except Exception as E:
    print(E)

### Let's try with a custom collate_fn

In [4]:
def collate_list(batch):
    """Identity collate function: return ``batch`` exactly as received.

    No stacking, padding or copying is performed; the very same object
    that was passed in is handed back.
    """
    return batch

# Pull a single batch of five items through the pass-through collate
# function and print it; any failure is reported by printing the exception.
try:
    x = next(
        iter(
            DataLoader(
                MyDataset,
                batch_size=5,
                collate_fn=collate_list,
            )
        )
    )
    print(x)
except Exception as err:
    print(err)

In [5]:
def collate_packed(batch):
    """Collate variable-length sequences into one ``PackedSequence``.

    The sequences are ordered longest first before packing, because
    ``pack_sequence`` is called with its default ``enforce_sorted=True``.
    Ties keep their relative order (the sort is stable).

    Args:
        batch: Sequence of tensors; ``len(tensor)`` is used as the sort key.
            The caller's container is not modified.

    Returns:
        The ``PackedSequence`` produced by ``pack_sequence``.
    """
    # sorted() builds a new list, so the list handed in by the caller keeps
    # its original order (list.sort() would have reordered it in place).
    longest_first = sorted(batch, key=len, reverse=True)
    return pack_sequence(longest_first)

# Pull a single batch of five items through the packing collate function
# and print it; any failure is reported by printing the exception.
try:
    x = next(
        iter(
            DataLoader(
                MyDataset,
                batch_size=5,
                collate_fn=collate_packed,
            )
        )
    )
    print(x)
except Exception as err:
    print(err)

In [6]:
# Take the first batch of five items, left un-collated by collate_list,
# and display each item transposed.
batch = next(
    iter(DataLoader(MyDataset, batch_size=5, collate_fn=collate_list))
)
[seq.T for seq in batch]

In [7]:
# Longest-first copy of the batch; `batch` itself keeps its original order.
batch_sorted = sorted(batch, key=len, reverse=True)
[seq.T for seq in batch_sorted]

In [8]:
# Pack the already length-sorted sequences (enforce_sorted=True is the
# default, spelled out here), then show the packed data transposed next to
# the per-step batch sizes.
batch_packed = pack_sequence(batch_sorted, enforce_sorted=True)
(batch_packed.data.T, batch_packed.batch_sizes)

In [9]:
# Turn the packed batch back into a padded tensor (batch dimension first)
# plus the per-sequence lengths; display it with dims 1 and 2 exchanged.
batch_pad_packed, lengths = pad_packed_sequence(
    batch_packed,
    batch_first=True,
)
(batch_pad_packed.transpose(1, 2), lengths)

In [10]:
# Cut every padded row back to its own length and transpose it for display.
unpacked_batch = [
    padded_seq[:seq_len].T
    for padded_seq, seq_len in zip(batch_pad_packed, lengths)
]
unpacked_batch

In [11]:
def reconstruct_batch(packed: PackedSequence) -> list[Tensor]:
    """Recover the individual variable-length sequences from a packed batch.

    The packed batch is first padded back to a dense tensor with the batch
    dimension first, then every row is trimmed to its true length.

    Args:
        packed: The ``PackedSequence`` to unpack.

    Returns:
        One tensor per sequence, in the order ``pad_packed_sequence`` yields
        them, each trimmed to its own length. The tensors are slices of the
        padded tensor, so they keep its dtype and device.
    """
    padded, lengths = pad_packed_sequence(packed, batch_first=True)
    # Slicing with plain Python ints drops the padding at the end of each row.
    return [row[:length] for row, length in zip(padded, lengths.tolist())]

In [12]:
# Binds a second name to the padded batch tensor; no data is copied.
# NOTE(review): looks like a placeholder for a real reconstruction — confirm.
reconstructed_batch = batch_pad_packed

In [13]:
# Display the shape of the padded batch tensor.
# Presumably (batch, longest length, 2), given batch_first=True — TODO confirm.
batch_pad_packed.shape

In [15]:
def unpack(packed):
    """Placeholder that is not implemented yet.

    Accepts ``packed``, does nothing with it and returns ``None``.
    """
    return None

In [16]:
# First five items of the dataset, ordered shortest first. The slice yields
# a new list, so the dataset's own ordering is left alone.
a = sorted(MyDataset[:5], key=len)
a