In [177]:
%load_ext autoreload
%autoreload 2

from pyro import sample, plate

from scipy.linalg import null_space, lstsq

from src.data.datasets import SiteDataset
import torch
import seaborn as sns
import pyro
from src.data.datasets import SiteDataset
from src.models.initial_model import InitialModel, TraceGuide
from src.models.batched_model import BatchedModel
from pyro import distributions as dist

from torch.nn.utils.rnn import pad_sequence

# Site and floor under study -- change these to model a different building / floor.
SITE_ID = "5a0546857ecc773753327266"
FLOOR_INDEX = 0

site_data = SiteDataset(SITE_ID)
floor = site_data.floors[FLOOR_INDEX]
height, width = floor.info["map_info"]["height"], floor.info["map_info"]["width"]

# Uniform distribution over a 2-D point in the box [0, height] x [0, width].
# `.to_event(1)` treats the two coordinates as a single event, so log_prob yields
# one value per point rather than one per coordinate.
# `float(...)` keeps `high` in the same floating dtype as `low` even if the map
# metadata stores integer sizes.
# NOTE(review): bounds are ordered (height, width); confirm this matches the column
# order of the trace "position" matrices before using this as a position prior.
floor_uniform = dist.Uniform(
    low=torch.zeros(2),
    high=torch.tensor([float(height), float(width)]),
).to_event(1)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


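# Model Definition

Build a padded mini-batch from the floor's first `batch_size` traces, with masks that mark missing and padded entries. Then run the model and the guide once each as a smoke test before training.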

In [178]:
# Number of traces, taken from the start of the floor's trace list, that form one mini-batch.
batch_size = 12
traces = list(floor.traces[:batch_size])

# The floor may hold fewer than `batch_size` traces, so size the index by what was actually taken.
mini_batch_index = torch.arange(len(traces))
# Number of time steps in each trace, taken from its "time" matrix.
mini_batch_length = torch.tensor([len(t.matrices["time"]) for t in traces])


def _within_length(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a bool mask of shape (len(lengths), max_len), True where step < that row's length.

    `pad_sequence` pads with 0.0 rather than NaN, so padded steps cannot be told apart
    from real values; this mask marks them explicitly.
    """
    return torch.arange(max_len)[None, :] < lengths[:, None]


# NOTE(review): float32 keeps only ~7 significant digits. If "time" holds absolute epoch
# timestamps, this cast loses precision -- confirm the units, or subtract each trace's
# start time before casting.
mini_batch_time = pad_sequence(
    [torch.tensor(t.matrices["time"], dtype=torch.float32) for t in traces],
    batch_first=True,
)

mini_batch_position = pad_sequence(
    [torch.tensor(t.matrices["position"], dtype=torch.float32) for t in traces],
    batch_first=True,
)
# A step's position counts as observed when no coordinate is NaN and the step is not padding.
mini_batch_position_mask = ~mini_batch_position.isnan().any(dim=-1) & _within_length(
    mini_batch_length, mini_batch_position.shape[1]
)
# Replace unobserved entries with 0 (presumably so NaNs cannot propagate through
# downstream arithmetic); the mask carries the observed/unobserved information.
mini_batch_position[~mini_batch_position_mask] = 0

# Union of every BSSID seen in the batch's WiFi scans (presumably one WiFi column per BSSID).
# NOTE(review): iteration order of a set of strings depends on PYTHONHASHSEED. If
# `_get_matrices` orders its columns by iterating `bssids`, column order can change
# between kernel restarts -- consider a sorted collection if that method accepts one.
bssids = set()
for t in traces:
    bssids.update(t.data["TYPE_WIFI"]["bssid"].unique())

mini_batch_wifi_unpadded = [
    torch.tensor(t._get_matrices(bssids=bssids)["wifi"], dtype=torch.float32)
    for t in traces
]
mini_batch_wifi = pad_sequence(mini_batch_wifi_unpadded, batch_first=True)
# A WiFi entry counts as observed when it is not NaN and its step is not padding.
mini_batch_wifi_mask = ~mini_batch_wifi.isnan() & _within_length(
    mini_batch_length, mini_batch_wifi.shape[1]
)[..., None]
mini_batch_wifi[~mini_batch_wifi_mask] = 0

_, T, K = mini_batch_wifi.shape

model = BatchedModel(floor, K)

# One set of keyword arguments shared by the model and the guide.
mini_batch = dict(
    mini_batch_index=mini_batch_index,
    mini_batch_length=mini_batch_length,
    mini_batch_time=mini_batch_time,
    mini_batch_position=mini_batch_position,
    mini_batch_position_mask=mini_batch_position_mask,
    mini_batch_wifi=mini_batch_wifi,
    mini_batch_wifi_mask=mini_batch_wifi_mask,
)

# Smoke test: run the model and the guide once on this batch before training.
model.model(**mini_batch)
model.guide(**mini_batch)




In [179]:
from pyro.infer import MCMC, NUTS, HMC, SVI, Trace_ELBO
from pyro.optim import Adam, ClippedAdam

# Fix the RNG seed so the stochastic training run is reproducible.
pyro.set_rng_seed(0)

# Start from fresh parameters: Pyro's param store is global and would otherwise keep
# values from a previous run of this cell.
pyro.clear_param_store()

# Number of SVI gradient steps, and how often to report the loss.
n_steps = 1000
log_every = 100

# Optimizer settings.
adam_params = {"lr": 0.01}
optimizer = Adam(adam_params)

# ELBO objective estimated from 3 particles (guide samples) per step.
elbo = Trace_ELBO(num_particles=3)
svi = SVI(model.model, model.guide, optimizer, loss=elbo)

# NOTE(review): the last recorded run of this cell failed with
# `TypeError: guide() missing 7 required positional arguments`. SVI forwards
# *args/**kwargs to both the model and the guide, so the failing call presumably
# originates inside BatchedModel (e.g. a guide invoked without arguments) -- TODO confirm.
mini_batch_kwargs = dict(
    mini_batch_index=mini_batch_index,
    mini_batch_length=mini_batch_length,
    mini_batch_time=mini_batch_time,
    mini_batch_position=mini_batch_position,
    mini_batch_position_mask=mini_batch_position_mask,
    mini_batch_wifi=mini_batch_wifi,
    mini_batch_wifi_mask=mini_batch_wifi_mask,
)

# `svi.step` returns the loss, i.e. the negative ELBO estimate. Keep the whole curve
# for plotting, but only print it periodically.
losses = []
for step in range(n_steps):
    loss = svi.step(**mini_batch_kwargs)
    losses.append(loss)
    if step % log_every == 0 or step == n_steps - 1:
        print("[%d] loss (-ELBO): %.1f" % (step, loss))

TypeError: guide() missing 7 required positional arguments: 'mini_batch_index', 'mini_batch_length', 'mini_batch_time', 'mini_batch_position', 'mini_batch_position_mask', 'mini_batch_wifi', and 'mini_batch_wifi_mask'

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
# (The original referenced an undefined name `cuda`, which raises NameError.)
# NOTE: the mini-batch tensors built earlier are CPU tensors; move them to `device`
# as well before training on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device=device)