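# Developing and testing `TimeSeriesDataset` and `TimeSeriesBatch`

Code for developing and testing the `TimeSeriesDataset` and `TimeSeriesBatch` classes, together with the `cat_time_series_batches` concatenation function.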

In [1]:
# Automatically reload edited modules (e.g. janelia_core) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.utils

from janelia_core.ml.datasets import TimeSeriesDataset
from janelia_core.ml.datasets import TimeSeriesBatch
from janelia_core.ml.datasets import cat_time_series_batches

## Create some `TimeSeriesBatch` objects and test the concatenation function `cat_time_series_batches`

In [None]:
# Batch 1 spans rows 0..9.
#   group 1: (10, 2) integer tensor; both columns hold the row index
#   group 2: (10, 1) float tensor holding the row index + 0.5
b1_data_g1 = torch.arange(10).view(10, 1).repeat(1, 2)
b1_data_g2 = torch.arange(10).view(10, 1).float() + .5

# Every y index is the row immediately after its x index.
b1_i_x = torch.arange(9)
b1_i_y = torch.arange(1, 10)

# Presumably the rows' positions in the original series — confirm against TimeSeriesBatch.
b1_i_orig = torch.arange(10)

b1 = TimeSeriesBatch(
    data=[b1_data_g1, b1_data_g2],
    i_x=b1_i_x,
    i_y=b1_i_y,
    i_orig=b1_i_orig,
)


In [None]:
# Batch 2 spans rows 5..12 (8 rows), so its i_orig overlaps b1's on 5..9.
#   group 1: (8, 2) integer tensor; both columns hold the row value
#   group 2: (8, 1) float tensor holding the row value + 0.5
b2_data_g1 = torch.arange(5, 13).view(8, 1).repeat(1, 2)
b2_data_g2 = torch.arange(5, 13).view(8, 1).float() + .5

# Every y index is the row immediately after its x index.
b2_i_x = torch.arange(7)
b2_i_y = torch.arange(1, 8)

b2_i_orig = torch.arange(5, 13)

b2 = TimeSeriesBatch(
    data=[b2_data_g1, b2_data_g2],
    i_x=b2_i_x,
    i_y=b2_i_y,
    i_orig=b2_i_orig,
)

In [None]:
# Concatenate b1 and b2; the result is inspected below via .data and .i_y.
c = cat_time_series_batches([b1, b2])

In [None]:
# Group-2 values at the concatenated batch's y indices; compare with the
# per-batch results in the next two cells.
c.data[1][c.i_y]

In [None]:
# Group-2 values at b1's own y indices, for comparison with the concatenation.
b1.data[1][b1.i_y]

In [None]:
# Group-2 values at b2's own y indices, for comparison with the concatenation.
b2.data[1][b2.i_y]

## Create a `TimeSeriesDataset` and sample from it with a `DataLoader`

In [None]:
# Build a dataset over the same two data groups used to construct b1.
ds = TimeSeriesDataset(ts_data=[b1_data_g1, b1_data_g2])

In [None]:
# Draw 3 samples per batch, in order, using 2 worker processes.
# cat_time_series_batches merges the per-sample items into one batch object.
# pin_memory=True asks the loader to pin each batch; for a custom batch type
# this relies on the type supporting pinning, which the next cell checks.
loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=3,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=cat_time_series_batches,
)

In [None]:
# One pass over the loader. For each batch, print a header, the group-1 rows
# at its x indices, and whether its i_orig tensor is in pinned memory.
for b_i, b in enumerate(loader):
    print(f'******** Batch {b_i} ********')
    print(b.data[0][b.i_x])
    print(b.i_orig.is_pinned())

## Create a `TimeSeriesBatch` object and sample from it efficiently (`efficient_get_item`)

In [4]:
# Single-column variant of b1 used for the sampling tests:
#   group 1: (10, 1) integer tensor holding the row index
#   group 2: the same values as floats, offset by 0.5
b_smp_data_g1 = torch.arange(10).view(10, 1)
b_smp_data_g2 = b_smp_data_g1.float() + .5

# Every y index is the row immediately after its x index.
b_smp_i_x = torch.arange(9)
b_smp_i_y = torch.arange(1, 10)

b_smp_i_orig = torch.arange(10)

b_smp = TimeSeriesBatch(
    data=[b_smp_data_g1, b_smp_data_g2],
    i_x=b_smp_i_x,
    i_y=b_smp_i_y,
    i_orig=b_smp_i_orig,
)

In [66]:
# Sample entries 0, 1, 3 and 0 (index 0 appears twice) via the efficient path.
# For comparison, the plain indexing path is: b_smp[np.asarray([0, 1, 3, 0])]
smp = b_smp.efficient_get_item(np.array([0, 1, 3, 0]))

In [67]:
# Group-1 values at the sample's x indices (output below).
smp.data[0][smp.i_x]
# To inspect the y rows instead, use: smp.data[0][smp.i_y]

tensor([[0],
        [1],
        [3],
        [0]])

In [68]:
# cat_time_series_batches takes a list of batches (as in the [b1, b2] call
# above, and as DataLoader passes to collate_fn), so wrap the single batch.
smp2 = cat_time_series_batches([smp])

In [69]:
# Index smp2's data with smp2's own indices; using smp's indices only works if
# the two batches happen to share the same data layout.
# Only a cell's last expression is displayed, so show the x rows explicitly.
display(smp2.data[0][smp2.i_x])
smp2.data[0][smp2.i_y]

tensor([[1],
        [2],
        [4],
        [1]])

In [70]:
# Raw data groups held by smp2. The saved output shows 5 rows (0-4), compared
# with the 10 rows in b_smp.
smp2.data

[tensor([[0],
         [1],
         [2],
         [3],
         [4]]), tensor([[0.5000],
         [1.5000],
         [2.5000],
         [3.5000],
         [4.5000]])]