In [1]:
import os
import time
import datetime
import re
import shutil
from collections import deque
import argparse

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from simulation.simulations.data_generator import DataGenerator
from TransformerMOT.util.misc import save_checkpoint, update_logs
from TransformerMOT.util.load_config_files import load_yaml_into_dotdict
from TransformerMOT.util.plotting import output_truth_plot, compute_avg_certainty, get_constrastive_ax, get_false_ax, \
    get_total_loss_ax, get_state_uncertainties_ax
from TransformerMOT.util.logger import Logger
from TransformerMOT.models.BOMTv1 import BOMT
from simulation.simulations.data_generator import DataGenerator, get_single_training_example


In [2]:
# --- Configuration -----------------------------------------------------------
# Root of the MultiTracking checkout. Set the MULTITRACKING_ROOT environment
# variable to run this notebook from a different machine/checkout instead of
# editing absolute paths; the default preserves the original location.
PROJECT_ROOT = os.environ.get(
    "MULTITRACKING_ROOT",
    r"C:\Users\chiny\OneDrive - Nanyang Technological University\Y3S2 (Internship)\MultiTracking",
)
task_params = os.path.join(PROJECT_ROOT, "configs", "tasks", "task1.yaml")
model_params = os.path.join(PROJECT_ROOT, "configs", "models", "BOMTv1.yaml")

# Task config first, then model config layered on top of it.
# NOTE(review): if the dotdict follows dict.update semantics this is a shallow
# merge — a top-level section present in both files (e.g. `training`) is
# replaced wholesale by the model file's version, not merged key-by-key. Confirm.
params = load_yaml_into_dotdict(task_params)
params.update(load_yaml_into_dotdict(model_params))

# When no seed is configured, draw a fresh 32-bit one (valid range for
# np.random.seed) and store it back into params so it is printed and can be
# copied into the config to reproduce this run.
if params.general.pytorch_and_numpy_seed is None:
    params.general.pytorch_and_numpy_seed = int.from_bytes(os.urandom(4), byteorder="big")
SEED = params.general.pytorch_and_numpy_seed

# Actually apply the seed. Previously it was only printed, so data generation
# and model weight initialisation were not reproducible from the logged value.
torch.manual_seed(SEED)  # seeds the CPU RNG and all CUDA devices
np.random.seed(SEED)
print(f'Using seed: {SEED}')

# Resolve 'auto' to a concrete device once; later cells should read
# params.training.device rather than hard-coding a device string.
if params.training.device == 'auto':
    params.training.device = 'cuda' if torch.cuda.is_available() else 'cpu'

Using seed: 30386090


In [3]:
# --- Simulated training batch ------------------------------------------------
# Build the simulator from the merged params and pull a single batch:
# a NestedTensor of measurements plus their labels and unique measurement ids.
# The displayed shape was (2, 230, 4) for this config — presumably
# (batch, measurements, features); TODO confirm against DataGenerator.
data_generator = DataGenerator(params=params)
(training_nested_tensor,
 labels,
 unique_measurement_ids) = data_generator.get_batch()

training_nested_tensor.tensors.shape

torch.Size([2, 230, 4])

In [4]:
# --- Model smoke test --------------------------------------------------------
# Run one forward pass of BOMT on the generated batch to check that the model
# builds and accepts the data. Use the device resolved in the config cell
# instead of hard-coding "cuda", which crashes on CPU-only machines.
device = params.training.device
model1 = BOMT(params).to(device)
res = model1(training_nested_tensor.to(device))
# NOTE(review): this cell ends in an assignment, so it displays nothing itself.
# The saved tuple-of-Prediction output below appears stale (from an earlier
# version of the cell), and the "1:", "2:", ... shape lines are presumably
# debug prints inside BOMT's forward — TODO remove them there.

1: torch.Size([2, 2, 64, 230])
2: torch.Size([2, 230, 2])
3: torch.Size([2, 230, 256])
preprocessed_measurements : torch.Size([230, 2, 256])
time_encoding : torch.Size([230, 2, 256])
4: torch.Size([230, 2, 256])


(<TransformerMOT.util.misc.Prediction at 0x1abbd3cd880>,
 [<TransformerMOT.util.misc.Prediction at 0x1abbd3cd8e0>,
  <TransformerMOT.util.misc.Prediction at 0x1abbd3abd60>,
  <TransformerMOT.util.misc.Prediction at 0x1abbd30cac0>,
  <TransformerMOT.util.misc.Prediction at 0x1abbd30c940>,
  <TransformerMOT.util.misc.Prediction at 0x1ab91d7f1c0>],
 <TransformerMOT.util.misc.Prediction at 0x1abbd3cd9a0>,
 {'contrastive_classifications': tensor([[[-1.0000e+08, -5.4650e+00, -5.4487e+00,  ..., -5.4251e+00,
            -5.4529e+00, -5.4173e+00],
           [-5.4668e+00, -1.0000e+08, -5.4231e+00,  ..., -5.3968e+00,
            -5.4449e+00, -5.4190e+00],
           [-5.4506e+00, -5.4232e+00, -1.0000e+08,  ..., -5.4424e+00,
            -5.4272e+00, -5.4386e+00],
           ...,
           [-5.4421e+00, -5.4121e+00, -5.4576e+00,  ..., -1.0000e+08,
            -5.4692e+00, -5.4108e+00],
           [-5.4493e+00, -5.4397e+00, -5.4218e+00,  ..., -5.4486e+00,
            -1.0000e+08, -5.4338e+00],
   