# Packed sequences


In [195]:
%load_ext autoreload
%autoreload 2


from src.data.datasets import SiteDataset, FloorDataset
from src.models.initial_model import InitialModel
from src.models.model_trainer import ModelTrainer

import torch
from torch.nn.utils.rnn import pack_sequence, PackedSequence

import pyro

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [196]:
# Seed torch's global RNG so re-running this notebook is repeatable.
SEED = 123456789
torch.manual_seed(SEED)

# Load data: the first floor of one site, with these preprocessing settings.
SITE_ID = "5a0546857ecc773753327266"
WIFI_THRESHOLD = 200
SAMPLING_INTERVAL = 100
site_data = SiteDataset(
    SITE_ID, wifi_threshold=WIFI_THRESHOLD, sampling_interval=SAMPLING_INTERVAL
)
floor = site_data.floors[0]

# Setup model
model = InitialModel(floor)

# Setup the optimizer.
# Commented-out alternative from earlier runs: {"lr": 1e-2, "betas": (0.95, 0.999)}
LEARNING_RATE = 1e-2
adam_params = {"lr": LEARNING_RATE}
optimizer = torch.optim.Adam(model.parameters(), **adam_params)

In [197]:
%%time
mt = ModelTrainer(
    model=model,
    optimizer=optimizer,
    n_epochs=1,
)
mt.train(floor)

epoch[0] ELBO: 34722207.9
CPU times: user 25.1 s, sys: 10.1 s, total: 35.1 s
Wall time: 21.1 s


In [211]:
class PackedFloorDataset(FloorDataset):
    """FloorDataset whose mini-batches are ``PackedSequence`` objects.

    ``__getitem__`` packs the selected items with
    :func:`torch.nn.utils.rnn.pack_sequence` instead of padding them.
    ``pack_sequence`` (with its default ``enforce_sorted=True``) requires the
    sequences in non-increasing length order, so the items are reordered
    longest-first and that order is returned alongside the packed tensors.
    """

    @staticmethod
    def _mask_like(packed, mask_data):
        """Wrap ``mask_data`` in a PackedSequence laid out exactly like ``packed``.

        ``mask_data`` must be row-aligned with ``packed.data``. Every layout
        field (``batch_sizes``, ``sorted_indices``, ``unsorted_indices``) is
        copied so the mask unpacks the same way as the sequence it describes.
        """
        return PackedSequence(
            mask_data,
            packed.batch_sizes,
            packed.sorted_indices,
            packed.unsorted_indices,
        )

    def __getitem__(self, indices):
        """Build a packed mini-batch from the items at ``indices``.

        Args:
            indices: Iterable of item indices (a list of ints in this
                notebook) into the per-item tensors returned by
                ``self._generate_tensors()``. An empty ``indices`` will make
                ``pack_sequence`` fail.

        Returns:
            Tuple ``(mini_batch_index, mini_batch_time, mini_batch_position,
            mini_batch_position_mask, mini_batch_wifi, mini_batch_wifi_mask)``:

            * ``mini_batch_index`` -- ``indices`` reordered by decreasing
              sequence length (ties keep their order from ``indices``); this
              is the order of the sequences inside every packed tensor.
            * ``mini_batch_time`` / ``mini_batch_position`` /
              ``mini_batch_wifi`` -- the packed sequences.
            * ``mini_batch_position_mask`` -- packed boolean mask, True for
              each packed position row with no NaN along its last dimension.
            * ``mini_batch_wifi_mask`` -- packed boolean mask, True for every
              wifi element that is not NaN.
        """
        # NOTE(review): runs on every __getitem__; confirm FloorDataset caches
        # _generate_tensors(), otherwise each access repeats that work.
        time_unpadded, position_unpadded, wifi_unpadded = self._generate_tensors()

        # Longest-first order for pack_sequence. Python's sort is stable even
        # with reverse=True, so equal-length items keep their input order
        # (torch.argsort is not stable by default and would leave the tie
        # order, and thus the batch layout, unspecified).
        mini_batch_index = sorted(
            indices, key=lambda i: len(time_unpadded[i]), reverse=True
        )

        mini_batch_time = pack_sequence([time_unpadded[i] for i in mini_batch_index])

        mini_batch_position = pack_sequence(
            [position_unpadded[i] for i in mini_batch_index]
        )
        # A position row counts as observed only if none of its entries is NaN.
        mini_batch_position_mask = self._mask_like(
            mini_batch_position,
            ~mini_batch_position.data.isnan().any(dim=-1),
        )

        mini_batch_wifi = pack_sequence([wifi_unpadded[i] for i in mini_batch_index])
        # Wifi readings are masked element by element.
        mini_batch_wifi_mask = self._mask_like(
            mini_batch_wifi,
            ~mini_batch_wifi.data.isnan(),
        )

        return (
            mini_batch_index,
            mini_batch_time,
            mini_batch_position,
            mini_batch_position_mask,
            mini_batch_wifi,
            mini_batch_wifi_mask,
        )

In [212]:
from src.models.packed_model import PackedInitialModel

In [213]:
# Load the same floor through the packed dataset, with the same wifi_threshold
# and sampling_interval values passed to SiteDataset above.
packed_floor_data = PackedFloorDataset(
    floor.site_id,
    floor.floor_id,
    wifi_threshold=200,
    sampling_interval=100,
)

In [214]:
# Eight consecutive dataset items, 8 through 15 inclusive.
indices = list(range(8, 16))
mini_batch = packed_floor_data[indices]

In [215]:
# Instantiate PackedInitialModel for the same floor the InitialModel above was built on.
packed_model = PackedInitialModel(floor)

In [228]:
# Call the packed model inside two plates: "A" (size 3) is entered first and
# "B" (size 13) second, exactly like the equivalent nested `with` blocks.
with pyro.plate("A", 3), pyro.plate("B", 13):
    x_0, x = packed_model.model(*mini_batch)

torch.Size([8, 13, 3, 2])


In [229]:
# Shape of the first value returned by packed_model.model under the two plates.
x_0.shape

torch.Size([8, 13, 3, 2])

In [231]:
# Inspect the second returned value; torch's repr elides the middle of large tensors with "...".
x

tensor([[[[ 38.2373, 272.6837],
          [ 69.7852,  54.6544],
          [ 57.5826, 233.3924]],

         [[149.2876,  38.6550],
          [ 65.6305, 233.1351],
          [106.5674,  28.1526]],

         [[ 74.9254, 200.9640],
          [111.6512,  53.8550],
          [ 98.6935, 114.8946]],

         ...,

         [[ 89.1065, 179.5762],
          [ 94.9299, 116.0197],
          [ 23.2459, 243.5230]],

         [[153.5782,  19.4106],
          [ 83.8926, -42.5179],
          [ 14.3921, 143.1124]],

         [[112.8395, 120.1057],
          [114.9165, 183.5617],
          [ 54.6879, 132.4061]]],


        [[[140.9142,  96.3889],
          [-13.8576, 176.7368],
          [125.1393, 236.9658]],

         [[ 34.6568, 122.2154],
          [104.3158, 173.1718],
          [144.8387,  -1.6508]],

         [[ 79.7721, 187.2497],
          [ 69.2973, 186.9853],
          [ 36.3596,  13.7312]],

         ...,

         [[ 90.8872, 239.7402],
          [ 31.1628, 215.2138],
          [124.7947,  