In [None]:
from dataset.data import train_data
import dataset.preprocess as preprocess
import dataset.data as data
import dataset.simulation as simulation
import torch
import torch.nn.functional as F
import pandas as pd

def plot(pred, gt, xlim=None, ylim=(20, 40), index=None, title=None):
    """Print the MSE of ``pred`` against ``gt`` and plot both series.

    Args:
        pred: model predictions; a tensor (or array-like) that flattens to 1-D.
        gt: ground-truth values with the same number of elements as ``pred``.
        xlim: optional x-axis limits, e.g. ``('2022-1-1', '2022-1-2')``.
        ylim: y-axis limits passed to ``DataFrame.plot``.
        index: x-axis index for the plotted frame. Defaults to
            ``train_data.index[1:]`` (the first row is dropped, presumably
            because targets are shifted one step relative to inputs — TODO
            confirm).
        title: optional figure title; also prefixed to the printed MSE line.

    Raises:
        ValueError: if ``pred`` and ``gt`` have different element counts, or
            the index length does not match them.
    """
    # Flatten both series on the CPU in a common dtype. Without this, an
    # (N, 1) target against an (N,) prediction would be broadcast to (N, N)
    # by ``F.mse_loss`` (only a warning) and produce a meaningless MSE.
    pred = torch.as_tensor(pred).detach().cpu().double().reshape(-1)
    gt = torch.as_tensor(gt).detach().cpu().double().reshape(-1)
    if pred.numel() != gt.numel():
        raise ValueError(
            f"pred has {pred.numel()} elements but gt has {gt.numel()}"
        )

    if index is None:
        index = train_data.index[1:]
    if len(index) != pred.numel():
        raise ValueError(
            f"index has {len(index)} entries but pred/gt have {pred.numel()}"
        )

    mse = F.mse_loss(pred, gt).item()
    prefix = f"{title} " if title else ""
    print(f"{prefix}MSE: {mse:.6f}")

    pd.DataFrame(
        {'pred': pred.numpy(), 'gt': gt.numpy()},
        index=index,
    ).plot(
        figsize=(15, 4),
        xlim=xlim,
        ylim=ylim,
        grid=True,
        title=title,
    )


# Split: a separate `MLPModel` per target (`sec_back_t`, `indoor`)

In [None]:
# Stand-alone MLP trained only on the `sec_back_t` target.
sec_back_t_model = simulation.MLPModel.load_from_checkpoint(
    'lightning_logs/split.sec_back_t/checkpoints/epoch=816-step=90687.ckpt'
)
sec_back_t_model.eval()  # inference behaviour for layers such as dropout/batch-norm, if present

# Inference only, so no autograd graph is recorded.
with torch.no_grad():
    sec_back_t_pred = (
        sec_back_t_model(preprocess.X_sec_back_t.to(sec_back_t_model.device))
        .cpu()
        .squeeze()
    )

# Add xlim=('2022-1-1', '2022-1-2') to zoom into a single day.
plot(pred=sec_back_t_pred, gt=preprocess.y_sec_back_t, ylim=(32, 40))


In [None]:
# Stand-alone MLP trained only on the `indoor` target.
indoor_model = simulation.MLPModel.load_from_checkpoint(
    'lightning_logs/split.indoor/checkpoints/epoch=814-step=90465.ckpt'
)
indoor_model.eval()  # inference behaviour for layers such as dropout/batch-norm, if present

# Inference only, so no autograd graph is recorded.
with torch.no_grad():
    indoor_pred = (
        indoor_model(preprocess.X_indoor.to(indoor_model.device))
        .cpu()
        .squeeze()
    )

# Add xlim=('2022-1-1', '2022-1-2') to zoom into a single day.
plot(pred=indoor_pred, gt=preprocess.y_indoor, ylim=(20, 28))


# Branch: one `BranchModel` producing a separate output per target

In [None]:

# One network that returns a separate output per target, unpacked below as
# (sec_back_t, indoor).
# NOTE(review): unlike the `split.*` and `joint` runs, this checkpoint lives in
# the auto-numbered `version_5` directory — confirm it is the intended run.
branch_model = simulation.BranchModel.load_from_checkpoint("lightning_logs/version_5/checkpoints/epoch=664-step=73815.ckpt")

branch_model.eval()
with torch.no_grad():
    # Move each output to the CPU in one pass so every name is bound to its
    # OWN output. Assigning the indoor output to both names would make the
    # sec_back_t plot and MSE below silently report indoor predictions.
    sec_back_t_pred, indoor_pred = (
        out.cpu().squeeze()
        for out in branch_model(preprocess.X_branch.to(branch_model.device))
    )

plot(
    sec_back_t_pred,
    preprocess.y_sec_back_t,
    # xlim=('2022-1-1', '2022-1-2'),
    ylim=(32, 40)
)

plot(
    indoor_pred,
    preprocess.y_indoor,
    # xlim=('2022-1-1', '2022-1-2'),
    ylim=(20, 28)
)

# Joint: one `MLPModel` predicting both targets as output columns

In [None]:

# A single MLP whose output holds one column per target; judging by the
# unpacking order below, column 0 is sec_back_t and column 1 is indoor.
joint_model = simulation.MLPModel.load_from_checkpoint(
    "lightning_logs/joint/checkpoints/epoch=698-step=77589.ckpt"
)
joint_model.eval()

with torch.no_grad():
    pred = joint_model(preprocess.X_branch.to(joint_model.device))
    # Slice the output into single-column tensors and drop the size-1 column
    # dimension from each.
    sec_back_t_pred, indoor_pred = (
        column.squeeze() for column in pred.cpu().split(1, dim=1)
    )

# Add xlim=('2022-1-1', '2022-1-2') to either call to zoom into a single day.
plot(pred=sec_back_t_pred, gt=preprocess.y_sec_back_t, ylim=(32, 40))
plot(pred=indoor_pred, gt=preprocess.y_indoor, ylim=(20, 28))