In [1]:
import os
import time
import datetime
import re
import shutil
from collections import deque
import argparse

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from simulation.simulations.data_generator import DataGenerator
from TransformerMOT.util.misc import save_checkpoint, update_logs
from TransformerMOT.util.load_config_files import load_yaml_into_dotdict
from TransformerMOT.util.plotting import output_truth_plot, compute_avg_certainty, get_constrastive_ax, get_false_ax, \
    get_total_loss_ax, get_state_uncertainties_ax
from TransformerMOT.util.logger import Logger
from TransformerMOT.models.BOMTv1 import BOMT
from simulation.simulations.data_generator import DataGenerator, get_single_training_example


In [2]:
# Locations of the task and model YAML configs. The project root defaults to the
# original checkout location and can be overridden through the MULTITRACKING_ROOT
# environment variable so the notebook is not tied to a single machine.
project_root = os.environ.get(
    "MULTITRACKING_ROOT",
    r"C:\Users\chiny\OneDrive - Nanyang Technological University\Y3S2 (Internship)\MultiTracking",
)
task_params = os.path.join(project_root, "configs", "tasks", "task1.yaml")
model_params = os.path.join(project_root, "configs", "models", "BOMTv1.yaml")

# Load the task config first, then fold the model config into the same object.
params = load_yaml_into_dotdict(task_params)
params.update(load_yaml_into_dotdict(model_params))

# No seed configured: draw 4 random bytes from the OS (fits in an unsigned 32-bit
# integer) and store it back in the config so the run can be reproduced later.
if params.general.pytorch_and_numpy_seed is None:
    random_data = os.urandom(4)
    params.general.pytorch_and_numpy_seed = int.from_bytes(random_data, byteorder="big")
print(f'Using seed: {params.general.pytorch_and_numpy_seed}')

# Actually apply the seed that was just reported; without these calls the printed
# value has no effect on NumPy's global RNG or on PyTorch's generators.
np.random.seed(params.general.pytorch_and_numpy_seed)
torch.manual_seed(params.general.pytorch_and_numpy_seed)

# Resolve the 'auto' placeholder to a concrete device, falling back to the CPU
# when CUDA is not available.
if params.training.device == 'auto':
    params.training.device = 'cuda' if torch.cuda.is_available() else 'cpu'

Using seed: 2792582266


In [3]:
# Build the data generator from the merged config and pull a single batch from it.
data_generator = DataGenerator(params=params)
batch = data_generator.get_batch()

# The batch unpacks into the input tensor wrapper, the labels and the measurement ids.
training_nested_tensor, labels, unique_measurement_ids = batch

# Cell output: shape of the tensor held by the batch
# (torch.Size([2, 159, 4]) in the recorded run).
training_nested_tensor.tensors.shape

torch.Size([2, 159, 4])

In [4]:
# Use the device resolved in the config cell ('auto' -> 'cuda' or 'cpu') instead of
# a hard-coded "cuda", so the forward pass also works on machines without a GPU
# and the model and its input always end up on the same device.
device = params.training.device
model1 = BOMT(params).to(device)

# Single forward pass on the sampled batch; the result is inspected in the next cell.
res = model1(training_nested_tensor.to(device))

In [5]:
# Display the raw result of the forward pass. In the recorded output it is a tuple of
# (Prediction, list of Prediction objects, Prediction, dict), and the dict's first key
# is 'contrastive_classifications'.
# NOTE(review): this dumps whole tensors into the notebook; consider inspecting
# individual fields or shapes instead.
res

(<TransformerMOT.util.misc.Prediction at 0x1662d2990d0>,
 [<TransformerMOT.util.misc.Prediction at 0x1662d26b490>,
  <TransformerMOT.util.misc.Prediction at 0x1662d26b9d0>,
  <TransformerMOT.util.misc.Prediction at 0x1662d0d79d0>,
  <TransformerMOT.util.misc.Prediction at 0x1662d0d7580>,
  <TransformerMOT.util.misc.Prediction at 0x1662d0b6e20>],
 <TransformerMOT.util.misc.Prediction at 0x1662d26bfd0>,
 {'contrastive_classifications': tensor([[[-1.0000e+08, -4.8390e+00, -4.8603e+00,  ..., -1.0000e+08,
            -1.0000e+08, -1.0000e+08],
           [-4.8406e+00, -1.0000e+08, -4.8707e+00,  ..., -1.0000e+08,
            -1.0000e+08, -1.0000e+08],
           [-4.8720e+00, -4.8808e+00, -1.0000e+08,  ..., -1.0000e+08,
            -1.0000e+08, -1.0000e+08],
           ...,
           [-5.0689e+00, -5.0689e+00, -5.0689e+00,  ..., -5.0689e+00,
            -5.0689e+00, -5.0689e+00],
           [-5.0689e+00, -5.0689e+00, -5.0689e+00,  ..., -5.0689e+00,
            -5.0689e+00, -5.0689e+00],
   