In [2]:
import torch
import torch.nn as nn
import tonic
from tonic import datasets, transforms
import os
import numpy as np
from tqdm.auto import tqdm

## Load dataset
We use the training set of the Spiking Speech Commands (SSC) dataset. We downsample it to 100 channels, bin its events into time steps of width `dt`, and cut each sample to at most 250 time steps.

In [3]:
# Configuration for the preprocessing below:
#   dt           -- time-bin width; Downsample is called with time_factor=1/dt
#                   (units follow the dataset's timestamps, presumably
#                   microseconds -- TODO confirm).
#   encoding_dim -- number of channels after spatial downsampling, i.e. the
#                   width of every raster produced by ToRaster.
dt, encoding_dim = 4000, 100

class ToRaster:
    """Turn an event stream into a dense count raster of shape (time_steps, encoding_dim).

    Every event adds 1 to the cell at (events["t"], events["x"]), so ``events``
    must be a structured array whose "t" and "x" fields are integer bin indices
    (presumably produced by the Downsample transform applied before this one).
    Each sample is cut to at most ``sample_length`` time steps.

    Args:
        encoding_dim: number of channels, i.e. the width of the raster.
        sample_length: maximum number of time steps kept per sample
            (default 250); events at later time steps are dropped.
    """

    def __init__(self, encoding_dim, sample_length=250):
        self.encoding_dim = encoding_dim
        self.sample_length = sample_length

    def __call__(self, events):
        """Rasterise ``events``.

        Returns:
            An int array of shape (min(max t + 1, sample_length), encoding_dim)
            holding per-(time step, channel) event counts. An empty event stream
            yields an array with zero time steps instead of raising on ``max()``.
        """
        if len(events) == 0:
            return np.zeros((0, self.encoding_dim), dtype=int)
        # Length is the last time step + 1, capped at sample_length (the same
        # length the old "allocate everything, then slice" approach produced).
        n_steps = min(int(events["t"].max()) + 1, self.sample_length)
        # Drop events past the cut *before* allocating. Slicing a larger raster
        # afterwards would return a view that keeps the whole oversized buffer
        # alive, which is wasteful when samples are cached in memory.
        kept = events[events["t"] < n_steps]
        tensor = np.zeros((n_steps, self.encoding_dim), dtype=int)
        # np.add.at accumulates correctly when several events hit the same cell.
        np.add.at(tensor, (kept["t"], kept["x"]), 1)
        return tensor

# Preprocessing pipeline applied to every sample:
#   1. Downsample: scale timestamps by 1/dt and channels by encoding_dim/700
#      (700 is presumably SSC's native channel count -- confirm).
#   2. ToRaster: bin the downsampled events into a dense (time, channel) array.
downsample = transforms.Downsample(time_factor=1/dt, spatial_factor=encoding_dim/700)
transform = transforms.Compose([downsample, ToRaster(encoding_dim)])

In [4]:
# SSC samples are read from ./data and passed through `transform` on access.
dataset = datasets.SSC('./data', transform=transform)
n_samples = len(dataset)
print(f"This dataset has {n_samples} samples.")

This dataset has 75466 samples.


We create a disk-cached dataset for two reasons. The rasterized dataset is too large to comfortably fit in memory (see the memory-cached variant below). We also don't want to re-apply the `Downsample` and `ToRaster` transforms (see above) at every epoch.

In [5]:
batch_size = 128

# Cache location keyed by the preprocessing parameters. The components are
# joined with os.path.join so the separator matches the host OS (the original
# passed an already "/"-joined string, which made the join a no-op).
# NOTE(review): the key does not encode ToRaster's time-step cut; delete the
# cache directory by hand if that setting changes, or stale samples are served.
cached_dataset = tonic.DiskCachedDataset(
    dataset=dataset,
    cache_path=os.path.join("cache", "ssc", "train", str(encoding_dim), str(dt)),
)
cached_dataloader = torch.utils.data.DataLoader(
    cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    # ToRaster outputs have different numbers of time steps, so samples must be
    # padded to a common length to be stacked into one batch.
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    # Drop the final, smaller batch so every batch has exactly batch_size samples.
    drop_last=True,
)

In [6]:
# MemoryCachedDataset keeps its cache inside the dataset object. DataLoader
# worker processes each operate on their own copy of the dataset and, with the
# default persistent_workers=False, are shut down after every epoch, so anything
# cached inside a worker would be thrown away (and duplicated per worker).
# Loading in the main process (num_workers=0) keeps a single cache that
# survives across epochs; only the first epoch pays the transform cost.
memory_dataset = tonic.MemoryCachedDataset(
    dataset=dataset,
)
memory_dataloader = torch.utils.data.DataLoader(
    memory_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Samples differ in length; pad them so they can be stacked into a batch.
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    drop_last=True,
)

We define a dummy model here. It is only used to run forward and backward passes over the data.

In [7]:
# Dummy MLP used only to exercise forward/backward passes. The first layer
# consumes the last axis of the rasters, whose width ToRaster sets to
# `encoding_dim`; tying it to that constant keeps the model valid if the
# channel count is changed.
network = nn.Sequential(
    nn.Linear(encoding_dim, 100),
    nn.Linear(100, 100),
    nn.Linear(100, 100),
    nn.Linear(100, 100),
    nn.Linear(100, 10),
).cuda()

Iterating over the memory-cached dataset takes about 20 GB of RAM, so beware! The cell below is therefore disabled by default.

In [None]:
# Opt-in benchmark over the memory-cached loader. It needs roughly 20 GB of
# RAM, so it is disabled by default; set the flag to True to run it.
RUN_MEMORY_CACHED_BENCHMARK = False

if RUN_MEMORY_CACHED_BENCHMARK:
    for data, target in tqdm(memory_dataloader):
        output = network(data.cuda())
        output.sum().backward()
        target = target.cuda()

Iterating over the disk-cached dataset takes a while to get through all the data. Run this cell at least twice to generate all the cached samples. The cache takes about 1.4 GB of disk space (the tensors are stored compressed).

In [9]:
for data, target in tqdm(cached_dataloader):
    # Copy the batch to the GPU and run one forward/backward pass; the summed
    # output only serves as a scalar to backpropagate from.
    output = network(data.cuda())
    loss = output.sum()
    loss.backward()
    # Labels are moved to the GPU as well (presumably so the timing includes
    # the complete host-to-device transfer).
    target = target.cuda()

  0%|          | 0/589 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fde182445e0>
Traceback (most recent call last):
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    Exception ignored in: if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7fde182445e0>

  File "/home/gregorlenz/anaconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/gregorlenz/anaconda3/lib/python3.8/multiprocessing/proce

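## Generate a sparse tensor dataset that sits on the GPU

This will take about 7–8 GB of GPU memory! We load the actual batches from disk, convert each one to a sparse tensor, and move it to the GPU. We store them in a big list, which we'll use as a new dataset. Because the batches are materialised only once, their composition and order (drawn from the shuffled loader) stay fixed on every pass over this list.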

In [10]:
# Materialise every batch once: convert it to a coalesced sparse tensor on the
# host, then move it to the GPU. Labels are stored as uint8 on the GPU.
data_list, target_list = [], []
for data, targets in tqdm(cached_dataloader):
    sparse_batch = data.to_sparse().coalesce()
    data_list.append(sparse_batch.cuda())
    target_list.append(targets.byte().cuda())

# List of (sparse_data, targets) pairs, iterated like a dataset below.
sparse_tensor_dataset = list(zip(data_list, target_list))

  0%|          | 0/589 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fde182445e0>
Traceback (most recent call last):
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/gregorlenz/anaconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fde182445e0>
Traceback (most recent call last):
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/gregorlenz/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", l

Now let's see how fast we can do forward and backward passes without any host memory or disk involved. The sparse tensors are inflated (via `to_dense()`) directly on the GPU.

In [11]:
for data, target in tqdm(sparse_tensor_dataset):
    # Both tensors already live on the GPU; densify there and run one
    # forward/backward pass without touching host memory or disk.
    dense_batch = data.to_dense()
    output = network(dense_batch)
    output.sum().backward()

  0%|          | 0/589 [00:00<?, ?it/s]