In [3]:
import torch
from torch_geometric.transforms import ToDevice
from torch_geometric.utils import to_networkx

from model import Net
from motclass import MotDataset
from utilities import get_best_device
from utilities import load_model_pkl

device = get_best_device()

mot_train_dl = MotDataset(dataset_path='/media/dmmp/vid+backup/Data/MOT17',
                          split='train',
                          subtrack_len=15,
                          slide=15,
                          linkage_window=5,
                          detections_file_folder='gt',
                          detections_file_name='gt.txt',
                          dl_mode=True,
                          knn_pruning_args={'k': 20, 'cosine': False},
                          device=device,
                          dtype=torch.float32,
                          classification=True)

# model = load_model_pkl("base_500_resnet50-backbone.pkl", device=device)  # regression
# model.mps_fallback = True

model = Net(backbone="ResNet50",
            layer_tipe="base",
            layer_size=500,
            dtype=torch.float32,
            edge_features_dim=6,
            heads=6,
            concat=False,
            dropout=0.3,
            add_self_loops=False,
            steps=6,
            device=device)

model.eval()

model = model.to(device)


[INFO] Using CUDA.


In [4]:
data = mot_train_dl[0]
data = ToDevice(device.type)(data)
preds = model(data)

In [5]:
data.y.sum()

tensor(550., device='cuda:0')

---

In [6]:
nodes_dict = {}  # frame, bbox


def build_trajectory_rec(node_idx, pyg_graph, nx_graph, node_dists, det_id, nodes_todo, nodes_dict, out, depth=0):
    all_out_edges = nx_graph.out_edges(node_idx)

    # Remove edges going in the past
    all_out_edges = [item for item in all_out_edges if item[0] < item[1]]

    if len(all_out_edges) == 0:

        if depth == 0:  # it's an orphan node

            ### CHECK IF NEEDED ###

            orp_coords = pyg_graph.detections_coords[node_idx]
            orp_frame = pyg_graph.times[node_idx].tolist()[0]

            if (orp_frame, *orp_coords.tolist()) not in nodes_dict.keys():
                nodes_dict[(orp_frame, *orp_coords.tolist())] = det_id

            orp_id = nodes_dict[(orp_frame, *orp_coords.tolist())]

            out.append(
                {'frame': orp_frame,
                 'id': orp_id,
                 'bb_left': orp_coords[0].item(),
                 'bb_top': orp_coords[1].item(),
                 'bb_width': orp_coords[2].item(),
                 'bb_height': orp_coords[3].item(),
                 'conf': -1,
                 'x': -1,
                 'y': -1,
                 'z': -1}
            )

            ### CHECK IF NEEDED ###

        return

    # Find best edge to keep
    # TODO: find best peso
    best_edge_idx = torch.tensor([node_dists[n1, n2] for n1, n2 in all_out_edges]).argmin()
    best_edge = list(all_out_edges)[best_edge_idx]

    # Remove all other edges
    nx_graph.remove_edges_from(
        [item for item in list(all_out_edges) if item != list(all_out_edges)[best_edge_idx]]
    )

    n1, n2 = best_edge

    # Remove nodes from to-visit set if there
    if n1 in nodes_todo: nodes_todo.remove(n1)
    if n2 in nodes_todo: nodes_todo.remove(n2)
    # TODO: make more light

    n1_coords = pyg_graph.detections_coords[n1]
    n2_coords = pyg_graph.detections_coords[n2]

    n1_frame = pyg_graph.times[n1].tolist()[0]
    n2_frame = pyg_graph.times[n2].tolist()[0]

    # n1_coords = box_convert(n1_coords, in_fmt='xyxy', out_fmt='xywh')
    # n2_coords = box_convert(n2_coords, in_fmt='xyxy', out_fmt='xywh')

    if (n1_frame, *n1_coords.tolist()) not in nodes_dict.keys():
        nodes_dict[(n1_frame, *n1_coords.tolist())] = det_id

    if (n2_frame, *n2_coords.tolist()) not in nodes_dict.keys():
        nodes_dict[(n2_frame, *n2_coords.tolist())] = det_id

    n1_id = nodes_dict[(n1_frame, *n1_coords.tolist())]

    # <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>

    out.append(
        {'frame': n1_frame,
         'id': n1_id,
         'bb_left': n1_coords[0].item(),
         'bb_top': n1_coords[1].item(),
         'bb_width': n1_coords[2].item(),
         'bb_height': n1_coords[3].item(),
         'conf': -1,
         'x': -1,
         'y': -1,
         'z': -1}
    )

    # print(
    #     n1_frame, n1_id, n1_coords[0].item(), n1_coords[1].item(), n1_coords[2].item(), n1_coords[3].item(),
    #     -1, -1, -1, -1
    # )

    depth += 1

    build_trajectory_rec(n2, pyg_graph, nx_graph, node_dists, det_id, nodes_todo, nodes_dict, out, depth=depth)


def build_trajectories(graph, preds, ths=.33):
    global nodes_dict

    pyg_graph = graph.clone().detach()

    mask = torch.where(preds > float(ths), True, False)

    masked_preds = preds[mask]  # Only for regression

    pred_edges = pyg_graph.edge_index.t()[mask]

    out = []

    det_id = 0

    node_dists = torch.cdist(pyg_graph.pos, pyg_graph.pos, p=2)

    nodes_todo = list(range(pyg_graph.num_nodes))  # Used as a stack

    # Create a NetwrokX graph
    pyg_graph.edge_index = pred_edges.t()
    nx_graph = to_networkx(pyg_graph)

    while len(nodes_todo) > 0:
        build_trajectory_rec(nodes_todo.pop(), pyg_graph, nx_graph, node_dists, det_id, nodes_todo, nodes_dict, out)
        det_id += 1
        print(f"\rRemaining nodes to visit: {len(nodes_todo)}     ", end="")

    return out


In [7]:
# out = build_trajectories(data, preds)
out = build_trajectories(data, data.y)


AttributeError: 'GlobalStorage' object has no attribute 'detections_coords'

In [128]:
import pandas as pd

df = pd.DataFrame.from_dict(out)

df

Unnamed: 0,frame,id,bb_left,bb_top,bb_width,bb_height,conf,x,y,z
0,14,0,679.0,451.0,713.0,536.0,-1,-1,-1,-1
1,14,1,1033.0,450.0,1057.0,518.0,-1,-1,-1,-1
2,14,2,595.0,427.0,613.0,469.0,-1,-1,-1,-1
3,14,3,578.0,429.0,598.0,472.0,-1,-1,-1,-1
4,14,4,1003.0,453.0,1021.0,514.0,-1,-1,-1,-1
...,...,...,...,...,...,...,...,...,...,...
4140,9,238,912.0,484.0,1009.0,593.0,-1,-1,-1,-1
4141,10,199,912.0,484.0,1009.0,593.0,-1,-1,-1,-1
4142,11,159,912.0,484.0,1009.0,593.0,-1,-1,-1,-1
4143,12,119,912.0,484.0,1009.0,593.0,-1,-1,-1,-1
