In [None]:
%load_ext autoreload
%autoreload 2

from src.models.trace_guide import TraceGuide
from src.models.initial_model import InitialModel
from src.data.datasets import FloorDataset
from src.utils import object_to_markdown

from IPython.display import display

import seaborn as sns
import matplotlib.pyplot as plt
import torch

site_id = "5a0546857ecc773753327266"
floor_id = "B1"

floor_data = FloorDataset(site_id, floor_id, wifi_threshold=200, sampling_interval=100)
sns.set(style="whitegrid")

# Initial Model



In the initial model, we include only the wifi signals. The model is implemented in `src/models/initial_model.py` as a method of the `InitialModel` class, which also holds various model attributes and the variational parameters. The source code of the generative model is shown below:

In [None]:
display(object_to_markdown(InitialModel.model))

Initially, we wanted the priors on the wifi locations and the initial trace positions to be uniform over the floor area. However, we ran into domain problems when attempting to define the variational posterior. We therefore relaxed the prior to a normal distribution with the same mean and variance.
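As an illustration of the moment matching described above, the following minimal sketch builds a normal distribution with the same mean and variance as a uniform distribution over a rectangle. The floor extent (`floor_min`, `floor_max`) is a hypothetical placeholder, not the actual floor dimensions, and this is not the code used in `InitialModel`.

```python
import torch

# Hypothetical floor extent; the real values would come from the floor data.
floor_min = torch.tensor([0.0, 0.0])
floor_max = torch.tensor([200.0, 100.0])

# A uniform distribution on [a, b] has mean (a + b) / 2 and variance (b - a)^2 / 12.
prior_mean = (floor_min + floor_max) / 2
prior_std = (floor_max - floor_min) / 12 ** 0.5

# Independent normal per coordinate, with support on all of R^2.
relaxed_prior = torch.distributions.Normal(prior_mean, prior_std)

# The uniform prior it replaces, with support only inside the rectangle.
uniform_prior = torch.distributions.Uniform(floor_min, floor_max)
```

Unlike the uniform distribution, the normal distribution assigns a finite log-density to every point in $\mathbb{R}^2$, so a guide with unbounded support cannot produce samples outside the prior's domain.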

We use mini-batches of padded sequences during training, so we also need to provide a mask of padded observations. This mask also serves as a mask of missing observations, since the two cases should be handled identically. We use the `torch` data loader protocol; the specific implementation can be seen in `src/data/datasets.py`. Below, we extract a mini-batch of 4 traces:

In [None]:
mini_batch = floor_data[torch.arange(16, 20)]
mini_batch_index = mini_batch[0]
mini_batch_length = mini_batch[1]
mini_batch_time = mini_batch[2]
mini_batch_position = mini_batch[3]
mini_batch_position_mask = mini_batch[4]
mini_batch_wifi = mini_batch[5]
mini_batch_wifi_mask = mini_batch[6]
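The padding and masking idea can be sketched independently of `FloorDataset`. The example below is only an illustration under assumed shapes: the traces are random placeholders, and the standard normal log-likelihood stands in for the model's actual observation distribution.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)

# Three hypothetical traces of different lengths, each a (length, 2) tensor.
traces = [torch.randn(n, 2) for n in (5, 3, 4)]
lengths = torch.tensor([len(t) for t in traces])

# Pad to a common length: shape (batch, max_length, 2).
padded = pad_sequence(traces, batch_first=True)

# True where a time step is real, False where it is padding.
mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

# A missing observation can be folded into the same mask.
mask[0, 2] = False

# Masked time steps contribute nothing to a per-step log-likelihood.
log_prob = torch.distributions.Normal(0.0, 1.0).log_prob(padded).sum(-1)
masked_log_prob = torch.where(mask, log_prob, torch.zeros_like(log_prob))
```

Because padded and missing entries are zeroed in the same way, downstream code does not need to distinguish between them.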

In [None]:
torch.manual_seed(123)
initial_model = InitialModel(floor_data)
x, wifi_location = initial_model.model(*mini_batch)

Below, we can see the corresponding samples. Since they are essentially random walks (each step is conditioned only on the previous one), they don't really resemble the trace data in structure yet.

In [None]:
plt.figure(figsize=(10,10))
plt.scatter(*wifi_location.T, marker=".", color="grey", label="wifi locations")
plt.plot(*x.T, label=[f"trace {i}" for i in range(x.shape[0])])
plt.xlim(x[...,0].min()-5, x[...,0].max()+5)
plt.ylim(x[...,1].min()-5, x[...,1].max()+5)
plt.legend()
plt.show()
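For intuition about why the prior samples look unstructured, a plain Gaussian random walk can be simulated in a few lines. This is a simplified stand-in, not the transition model in `InitialModel.model`; the step scale and initial spread are hypothetical values.

```python
import torch

torch.manual_seed(0)

# Hypothetical settings, chosen only for illustration.
n_traces, n_steps, step_std = 4, 100, 1.0

# Initial positions and independent Gaussian increments.
x0 = torch.randn(n_traces, 1, 2) * 10.0
steps = torch.randn(n_traces, n_steps - 1, 2) * step_std

# x_t = x_{t-1} + step_t, so each position depends only on the previous one.
walk = torch.cat([x0, x0 + steps.cumsum(dim=1)], dim=1)
```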

## Variational distribution
In the model, we sometimes observe $\hat{\boldsymbol x}_i$ as an observation of a latent variable $\boldsymbol x_i$.
In order to use variational inference, we need an approximate posterior for the latent variable $\boldsymbol x_i$.
Since the latent position is sampled quite frequently, a plain mean-field approximation could leave us with a very large number of variational parameters to train.
Instead, inspired by [Pyro's Deep Markov Model example](https://pyro.ai/examples/dmm.html), we use a parameterized function $\tilde{\boldsymbol x}_\theta : \mathbb{R}_{\geq 0} \to \mathbb{R}^2$ for each trace, which is meant to follow the latent path. We can then construct the variational posterior with variational parameters $\theta$ as
$$
q(\boldsymbol x_i | \theta) = \mathcal N (\boldsymbol x_i \mid \tilde { \boldsymbol x}_\theta(t_i), s \mathbf I )
$$
where the diagonal variance $s$ is also a variational parameter.

We want a class of functions $\tilde{\boldsymbol x}_\theta$ that is quite flexible, yet doesn't have too many parameters to fit. We decided to look at functions of the form: 
$$
\tilde{\boldsymbol x}_\theta(t) = \boldsymbol\beta_0 + \sum_{i=1}^{10} \boldsymbol \alpha_i h( t - \boldsymbol \gamma_i).
$$
where $h$ is the softplus function, $h(u) = \log(1 + e^{u})$.
This has been implemented as the `TraceGuide` module.

In [None]:
display(object_to_markdown(TraceGuide))
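The functional form $\boldsymbol\beta_0 + \sum_i \boldsymbol\alpha_i h(t - \gamma_i)$ can also be written out directly in a few lines. The sketch below is a standalone illustration and not the `TraceGuide` implementation; the function name `softplus_path` and the parameter values are hypothetical.

```python
import torch
import torch.nn.functional as F


def softplus_path(t, beta_0, alpha, gamma):
    """Evaluate beta_0 + sum_i alpha_i * softplus(t - gamma_i).

    t: (T,) time stamps, beta_0: (2,) offset,
    alpha: (n, 2) coefficients, gamma: (n,) time offsets.
    Returns a (T, 2) tensor of positions.
    """
    basis = F.softplus(t.unsqueeze(-1) - gamma)  # (T, n)
    return beta_0 + basis @ alpha  # (T, 2)


torch.manual_seed(0)
n = 10
t = torch.linspace(0.0, 10.0, 50)
beta_0 = torch.zeros(2)
alpha = torch.randn(n, 2)  # random coefficients, for illustration only
gamma = torch.linspace(0.0, 10.0, n)  # offsets spread over the time range

path = softplus_path(t, beta_0, alpha, gamma)
```

Each basis function is close to zero before its offset and roughly linear after it, so the sum behaves like a smoothed piecewise-linear path whose direction can change near each $\gamma_i$.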

In [None]:
torch.manual_seed(123)
tg = TraceGuide(time_min_max=(0, 10), n=5)
tt = torch.linspace(0, 15, 100)
with torch.autograd.no_grad():
    basis_functions = torch.nn.functional.softplus(tt.view(-1, 1, 1) - tg.time_offset)
    basis_w_coeffs = basis_functions * tg.coeffs

We can train the parameters $\theta=(\boldsymbol\beta_0, \boldsymbol\alpha_1, \dots, \boldsymbol\alpha_{10}, \boldsymbol\gamma_1, \dots, \boldsymbol\gamma_{10})$ to obtain quite flexible paths in $\mathbb{R}^2$. The basis for one of the dimensions is illustrated below.

In [None]:
fig, axes = plt.subplots(nrows=2, sharex=True, figsize=(10,10))

axes[0].plot(tt, basis_functions[..., 0])
axes[0].set_title("$h(t-\\gamma_i)$")
axes[1].plot(tt, basis_w_coeffs[..., 0])
axes[1].set_title("$\\alpha_i \\cdot h(t-\\gamma_i)$")
plt.show()

Below, a few traces for randomly generated basis coefficients are shown, together with the corresponding standard deviation, in this case initialized as $\exp(0) = 1$.

In [None]:
torch.manual_seed(126)
plt.figure(figsize=(10,10))
for i in range(5):
    tg = TraceGuide(time_min_max=(0, 10), n=5)
    with torch.autograd.no_grad():
        location, scale = tg(tt.view(-1,1))
    plt.plot(*location.T, "--", color=f"C{i}")
    ax = plt.gca()
    for j in range(19, len(tt), 10):
        ax.plot(*location[j], "x", color=f"C{i}")
        ax.add_patch(plt.Circle(location[j], scale, fill=False, color=f"C{i}"))
plt.axis("equal")   
plt.show()         

We used a mean-field approximation for the other variational distributions, e.g. the wifi locations and signal strengths. The corresponding guide function can be seen below:

In [None]:
display(object_to_markdown(InitialModel.guide))

## Model Training
Due to the size of the model (the number of parameters), the model is trained using the `src/models/initial_models.py` script. This allowed us to easily use the DTU HPC cluster for training. Below, we load a checkpoint of the trained model:

In [None]:
checkpoint = torch.load("../checkpoints/initial_model.pt")
initial_model.load_state_dict(checkpoint["model_state_dict"])
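For reference, a checkpoint with the two keys used here (`"model_state_dict"` and `"loss_history"`) could be written and read back roughly as follows. This is a minimal sketch and not the training script itself: the `torch.nn.Linear` module is a stand-in for the real model, the loss values are placeholders, and an in-memory buffer replaces the file path.

```python
import io

import torch

# Stand-in module; the real checkpoint stores the InitialModel state.
model = torch.nn.Linear(2, 2)
loss_history = [3.0, 2.0, 1.5]  # placeholder values, not real results

# Save the state dict and loss history together in one dictionary.
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(
    {"model_state_dict": model.state_dict(), "loss_history": loss_history},
    buffer,
)

# Load the checkpoint back and restore the parameters into a fresh module.
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu")
fresh_model = torch.nn.Linear(2, 2)
fresh_model.load_state_dict(restored["model_state_dict"])
```

Passing `map_location="cpu"` is useful when a checkpoint written on a GPU node is opened on a machine without a GPU.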

We can inspect the loss over the course of training:

In [None]:
plt.figure(figsize=(10,6))
plt.plot(checkpoint["loss_history"])
plt.xlabel("Epoch")
plt.ylabel("Loss (Negative ELBO)")
plt.show()