In [1]:
import torch
from torch_geometric.transforms import ToDevice
from tqdm import tqdm

from model import Net
from motclass import MotDataset, build_graph
from utilities import get_best_device, save_graph

# get_best_device() presumably selects the preferred available backend
# (on the machine that produced this notebook it reported "[INFO] Using MPS.").
device = get_best_device()

# Seed torch's RNGs up front so weight initialisation and any random sampling
# are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

mot20_path = "data/MOT20"


[INFO] Using MPS.


In [15]:
# MOT20 train split with ground-truth detections and a 15-frame linkage window
# (no subtrack_len given; the log shows the whole sequence being read).
# Build the graph of dataset item 3 and save it to disk.
mot20 = MotDataset(
	dataset_path=mot20_path,
	split='train',
	linkage_window=15,
	detections_file_folder='gt',
	detections_file_name='gt.txt',
	device=device,
	dl_mode=False,
	dtype=torch.float32,
)
track = mot20[3]
graph = build_graph(**track.get_data(), device=device, dtype=torch.float32)
save_graph(graph, track)


[INFO] Track has 429 frames


Reading frame data/MOT20/train/MOT20-01/img1/000429.jpg: 100%|██████████| 429/429 [00:39<00:00, 10.87it/s]
Linking all nodes: 100%|██████████| 429/429 [00:33<00:00, 12.92it/s]


[INFO] 429 frames
[INFO] 26647 total nodes
[INFO] 52432408 total edges
[INFO] 90 gt trajectories found
[INFO] 53114 total gt edges (90 trajectories)
[INFO] Graph saved as saves/MOT20/track_MOT20-01/window_16.pickle


In [5]:
# MOT20 train split with ground-truth detections, cut into 15-frame subtracks.
# Build the graph of dataset item 3 (a single 15-frame subtrack, per the log)
# and save it to disk; `graph` is reused by the smoke-test cell below.
mot20 = MotDataset(
	dataset_path=mot20_path,
	split='train',
	subtrack_len=15,
	detections_file_folder='gt',
	detections_file_name='gt.txt',
	device=device,
	dl_mode=False,
	dtype=torch.float32,
)
track = mot20[3]
graph = build_graph(**track.get_data(), device=device, dtype=torch.float32)
save_graph(graph, track)


[INFO] Batch #3 | track #1 (frames 3/2405 - 18/2405)
[INFO] Track has 15 frames
[TQDM] Reading frame data/MOT20/train/MOT20-03/img1/000018.jpg: 100%|██████████| 15/15 [00:01<00:00,  9.09it/s]
[TQDM] Linking all nodes: 100%|██████████| 15/15 [00:00<00:00, 26.73it/s]
[INFO] 15 frames
[INFO] 1402 total nodes
[INFO] 1834544 total edges
[INFO] 95 gt trajectories found
[INFO] 2614 total gt edges (95 trajectories)


[INFO] Graph saved as saves/MOT20/track_MOT20-03/subtrack_3_len_15.pickle


---

In [6]:
# Try out the network: one forward pass over the graph built above, as a smoke test.
model = Net(backbone='resnet50', channels=128).to(device)
graph = ToDevice(device.type)(graph)

# Inference only: eval mode plus no autograd bookkeeping. With grad tracking on, every
# intermediate activation of this (large) graph would be kept alive for a backward
# pass that never happens.
model.eval()
with torch.no_grad():
	out = model(graph)


In [None]:
def train(model, train_loader, loss_function, optimizer, epochs, device):
	"""Train an edge-classification model and report the mean loss of each epoch.

	Args:
		model: network to train; it is moved to ``device`` and put in training mode.
		train_loader: iterable of graph batches; each batch is passed to ``model``
			and must expose the ground-truth edge labels as ``.y``.
		loss_function: callable ``(pred_edges, gt_edges)`` returning a scalar tensor.
		optimizer: optimizer already bound to ``model``'s parameters.
		epochs: number of full passes over ``train_loader``.
		device: ``torch.device`` to train on (its ``.type`` is given to ``ToDevice``).

	Returns:
		list[float]: mean training loss of every epoch, in order
		(``nan`` for an epoch in which ``train_loader`` yielded no batches).
	"""
	model = model.to(device)
	model.train()

	# The transfer transform does not depend on the batch: build it once.
	to_device = ToDevice(device.type)
	epoch_losses = []

	pbar = tqdm(range(epochs))
	for epoch in pbar:
		epoch_loss = 0.0
		num_batches = 0

		for data in train_loader:
			data = to_device(data)

			# Forward pass
			pred_edges = model(data)  # Predicted edge labels
			gt_edges = data.y  # Ground-truth edge labels

			loss = loss_function(pred_edges, gt_edges)

			# Backward and optimize
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			epoch_loss += loss.item()
			num_batches += 1

		# Average over the number of batches actually seen. Dividing by the last
		# enumerate index (num_batches - 1) was off by one and crashed on 0 or 1 batches.
		mean_loss = epoch_loss / num_batches if num_batches else float('nan')
		epoch_losses.append(mean_loss)

		# tqdm.write prints above the progress bar instead of breaking it like print().
		tqdm.write(f'Epoch {epoch + 1}, Loss: {mean_loss}')

	return epoch_losses


# Hyperparameters
backbone = 'resnet50'
l_size = 128
epochs = 10
learning_rate = 0.001

# Keyword arguments, as in the smoke-test cell, so the call does not depend on Net's
# positional parameter order.
model = Net(backbone=backbone, channels=l_size)

# NOTE(review): BCELoss expects probabilities in [0, 1] and a float target of the same
# shape. This assumes Net ends with a sigmoid and data.y holds 0/1 floats - TODO confirm
# (switch to BCEWithLogitsLoss if Net returns raw logits).
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# dl_mode=True presumably makes the dataset yield training batches of short (3-frame)
# subtracks - confirm in motclass.
mot20_train_dl = MotDataset(
	dataset_path=mot20_path,
	split='train',
	subtrack_len=3,
	linkage_window=0,
	detections_file_folder='gt',
	detections_file_name='gt.txt',
	device=device,
	dl_mode=True,
	dtype=torch.float32,
)

# `train` does not return an (x, y) pair, so the former `x, y = train(...)` unpacking
# raised a TypeError; keep its return value as a single object instead.
train_result = train(model, mot20_train_dl, loss_function, optimizer, epochs, device)


---

In [37]:
# MOT17 train split with ground-truth detections, cut into 15-frame subtracks
# with slide=15 (presumably non-overlapping windows - confirm slide semantics in motclass).
mot17 = MotDataset(
	dataset_path='data/MOT17',
	split='train',
	subtrack_len=15,
	slide=15,
	detections_file_folder='gt',
	detections_file_name='gt.txt',
	device=device,
	dl_mode=False,
	dtype=torch.float32,
)


In [None]:
# Count how many whole SUBTRACK_LEN-frame windows fit in every MOT17 track.
# SUBTRACK_LEN must match the subtrack_len/slide used to build `mot17`; the per-track
# count nf // SUBTRACK_LEN assumes non-overlapping windows (slide == subtrack_len).
SUBTRACK_LEN = 15
num_subtracks = sum(nf // SUBTRACK_LEN for nf in mot17.frames_per_track)
num_subtracks