In [32]:
import numpy as np
import pandas as pd
import os
from dataset import OutGridDataset
from dotenv import load_dotenv
import json
import torch
from config import N_LABELS
from model import UNet
from tqdm import tqdm
from common import get_scene
from config import STATIC, index_to_label
from tracker import PointTrackList, tracking_main


from encoder import get_grid_encoder


# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to the project modules imported
# above are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load environment variables (presumably from a local .env file via
# python-dotenv); later cells read DATA_LOCATION with os.getenv.
# The call's return value is what this cell displays.
load_dotenv()

True

In [3]:
# --- Configuration ----------------------------------------------------------
results_folder = "results"
checkpoints_folder = "checkpoints"
# Root folder of the recorded sequences, taken from the environment.
detections_folder = os.getenv("DATA_LOCATION")
if detections_folder is None:
    # Fail early with a readable message instead of the TypeError that
    # os.path.join(None, ...) would raise further down.
    raise RuntimeError(
        "DATA_LOCATION is not set; define it in the environment or in the .env file"
    )
network_name = "deep3_unet_b4_DoppFilt_WCE01_LN"
epoch = 3
model_name = f"{network_name}_ep{epoch}"
sequence_id = "sequence_2"

# --- Model and weights ------------------------------------------------------
results_path = os.path.join(results_folder, model_name)
with open(os.path.join(checkpoints_folder, f"{network_name}_config.json")) as f:
    config = json.load(f)
model = UNet(chs=config["unet_chs"], n_classes=N_LABELS)
# The model is never moved to an accelerator in this notebook, so map every
# tensor in the checkpoint to CPU; this also lets a checkpoint that was saved
# on a GPU load on a CPU-only machine.
checkpoint = torch.load(
    os.path.join(checkpoints_folder, f"{model_name}.pth"),
    map_location="cpu",
)
model.load_state_dict(checkpoint["model_state_dict"])
# Inference mode: disables training-only behaviour such as dropout.
model.eval()

# --- Detections -------------------------------------------------------------
scene_fn = os.path.join(detections_folder, sequence_id, "scenes.json")
detections = get_scene(scene_fn)

In [33]:
# Run per-frame segmentation followed by tracking on the first N_FRAMES frames.
# Produces `all_tracks_df` (tracker output of all processed frames, stacked) and
# leaves `cur_dets` holding the labelled detections of the last processed frame.
N_FRAMES = 100  # number of leading timestamps to process

timestamps = np.unique(detections["timestamp"])

track_list = PointTrackList()
# Per-frame tracker outputs are collected here and concatenated once after the
# loop; concatenating inside the loop re-copies all previous rows every frame.
frame_tracks = []

for i, ts in enumerate(tqdm(timestamps[:N_FRAMES])):
    # .copy() gives an independent frame, so the column assignments below do
    # not write into a slice of `detections` (no SettingWithCopyWarning, hence
    # no need to silence pandas' chained-assignment check globally).
    cur_dets = detections[detections["timestamp"] == ts].copy()

    # SEGMENTATION
    # A fresh grid encoder is created for every frame and filled with the
    # detections of that frame only.
    grid_fl = get_grid_encoder()
    grid_fl.fill_grid(cur_dets)

    # Move the last grid axis to the front and add a batch axis for the network.
    input_data = torch.from_numpy(grid_fl.grid).permute(2, 0, 1).float()

    with torch.no_grad():
        output = model(input_data.unsqueeze(0))
        predicted_classes = torch.argmax(output, dim=1)

    # Per-cell class indices as a 2-D array (original note: 196 x 140).
    predicted_classes = predicted_classes.squeeze(0).cpu().numpy()

    # Look up the predicted class of the grid cell each detection falls into.
    # Labels are gathered positionally and assigned in one step.
    predicted_labels = []
    for _, det in cur_dets.iterrows():
        x_pos, y_pos = grid_fl.get_cell_id(det["x_cc"], det["y_cc"])
        predicted_labels.append(index_to_label[predicted_classes[x_pos, y_pos]])
    cur_dets["predicted_class"] = predicted_labels
    # ------------------------------------------

    # TRACKING
    # Only detections classified as non-static are handed to the tracker.
    moving_dets = cur_dets[cur_dets["predicted_class"] != STATIC].copy()
    moving_dets["associated_track_id"] = -1
    moving_dets["use_for_update"] = False
    # Time since the previous frame. The first frame has no predecessor:
    # timestamps[i - 1] would wrap around to the LAST timestamp there, so use
    # 0.0 instead. The 1e-6 factor presumably converts microseconds to
    # seconds - TODO confirm the timestamp unit.
    dt_s = 0.0 if i == 0 else (ts - timestamps[i - 1]) * 1e-6
    tracks_df = tracking_main(
        track_list, moving_dets, ts, dt_s
    )
    frame_tracks.append(tracks_df)
    # ------------------------------------------

# Concatenate the per-frame track tables once; pd.concat rejects an empty list,
# so fall back to an empty frame when no timestamps were processed.
all_tracks_df = (
    pd.concat(frame_tracks, ignore_index=True) if frame_tracks else pd.DataFrame()
)
    

100it [00:28,  3.52it/s]


In [34]:
# Display the accumulated tracker output of all processed frames
# (pandas abbreviates long frames in the rendered table).
all_tracks_df

Unnamed: 0,timestamp,track_id,x,y,vx,vy,age,status
0,158195645746,0,17.183104,0.409395,0.000000,0.000000,2,tentative
1,158195645746,1,18.282721,1.053692,0.000000,0.000000,2,tentative
2,158195645746,2,18.858540,2.842702,0.000000,0.000000,2,tentative
3,158195645746,3,23.228834,1.108121,0.000000,0.000000,2,tentative
4,158195645746,4,30.174366,0.716875,0.000000,0.000000,2,tentative
...,...,...,...,...,...,...,...,...
713,158202890075,82,14.511658,-7.685435,0.000000,0.000000,3,tentative
714,158202890075,83,14.729210,-7.653048,0.000000,0.000000,3,tentative
715,158202961787,0,28.409494,0.405987,2.022834,-0.598804,101,confirmed
716,158202961787,48,17.649345,0.155567,3.188711,5.125838,56,confirmed


In [24]:
# `cur_dets` is whatever the processing loop assigned last, i.e. only the most
# recently processed frame is inspected here. Show the detections whose
# predicted class differs from the annotated label.
label_mismatch = cur_dets["label_id"].ne(cur_dets["predicted_class"])
cur_dets[label_mismatch]

Unnamed: 0,timestamp,x_cc,y_cc,vr_compensated,rcs,label_id,predicted_class
2,158195645746,18.85854,2.842702,4.516439,8.326876,11,0
20,158195645746,30.174366,0.716875,4.975053,-2.50931,11,1
21,158195645746,33.225216,0.990729,4.969501,-5.583652,11,0
22,158195645746,41.784348,2.370693,9.340811,-3.566014,11,0
27,158195645746,33.636074,3.496989,9.27509,-8.978685,11,0
28,158195645746,71.951027,6.75107,-7.170452,2.054195,1,0
29,158195645746,79.385971,6.36893,-7.158795,2.286739,11,0
73,158195645746,21.954672,-3.377731,0.077225,-18.755455,0,11
77,158195645746,24.201195,-2.54364,1.191874,14.521302,0,11
78,158195645746,25.696009,-1.878942,1.301065,-5.82809,0,11
