In [141]:
%load_ext autoreload
%autoreload 2

from new_model import GNLightning
from utils import load_config
import torch

torch.cuda.empty_cache()

model = GNLightning(
    d_model=128,
    lr=1e-3)

device = torch.device('cuda:0')
model.to(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


GNLightning(
  (gnet): GraphNetwork(
    (conv1): GCNConv(8, 128)
    (conv2): GCNConv(128, 128)
    (lin1): Linear(in_features=256, out_features=128, bias=True)
    (lin_final): Linear(in_features=128, out_features=3, bias=True)
  )
  (criterion): NLLLoss()
)

In [142]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

# NOTE(review): absolute, machine-specific path; consider moving it to a config cell.
CHECKPOINT_DIR = r"C:\Users\tangy\Downloads\DETR-GFTE\checkpoints"

# Keep the 3 checkpoints with the lowest validation loss plus the most recent one (last.ckpt).
checkpoint_callback = ModelCheckpoint(
    monitor="validation/loss",
    # The monitored metric name contains "/". With the default auto_insert_metric_name=True,
    # Lightning would render "validation/loss=..." into the filename and create a nested
    # "validation" folder, so label the fields explicitly and disable auto-insertion.
    filename="detr-epoch={epoch:02d}-val_loss={validation/loss:.2f}",
    auto_insert_metric_name=False,
    save_top_k=3,
    mode="min",
    save_last=True,
    dirpath=CHECKPOINT_DIR,
)

trainer = Trainer(
    max_epochs=50,
    log_every_n_steps=20,
    logger=None,
    callbacks=[checkpoint_callback],
    # A float runs validation after every 30% of the epoch's training batches,
    # i.e. roughly 3 times per epoch (use 0.5 for twice per epoch).
    val_check_interval=0.3,
    devices=1,
    accelerator="gpu",
    accumulate_grad_batches=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [143]:
# Train with the Trainer configured above. No dataloaders are passed, so GNLightning
# presumably provides its own dataloader hooks; confirm in new_model.
trainer.fit(model=model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type         | Params | Mode 
---------------------------------------------------
0 | gnet      | GraphNetwork | 50.9 K | train
1 | criterion | NLLLoss      | 0      | train
---------------------------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Batch 0 - loss: 1.0905755758285522
Batch 1 - loss: 1.07924222946167
Batch 2 - loss: 1.0686794519424438
Batch 3 - loss: 1.0580402612686157
Batch 4 - loss: 1.0471725463867188
Batch 5 - loss: 1.0361846685409546
Batch 6 - loss: 1.0251452922821045
Batch 7 - loss: 1.014190435409546
Batch 8 - loss: 1.003840684890747
Batch 9 - loss: 0.9949822425842285
Batch 10 - loss: 0.9885658621788025
Batch 11 - loss: 0.9852936863899231
Batch 12 - loss: 0.985222339630127
Batch 13 - loss: 0.9871637225151062
Batch 14 - loss: 0.9890945553779602
Batch 15 - loss: 0.9895674586296082
Batch 16 - loss: 0.9884809255599976
Batch 17 - loss: 0.9866090416908264
Batch 18 - loss: 0.9848241806030273
Batch 19 - loss: 0.983613908290863
Batch 20 - loss: 0.9829282760620117
Batch 21 - loss: 0.9824470281600952
Batch 22 - loss: 0.9818964600563049
Batch 23 - loss: 0.9811367988586426
Batch 24 - loss: 0.9801602363586426
Batch 25 - loss: 0.9790391325950623
Batch 26 - loss: 0.9778947234153748
Batch 27 - loss: 0.9768728613853455
Batch 28

In [None]:
from new_data import TrainDataset

# Training split in JSON Lines format; must be run before the inference cell that reads dataset[0].
# NOTE(review): absolute, machine-specific path; consider moving it to a config cell.
TRAIN_JSONL_PATH = r'C:\Users\tangy\Downloads\DETR-GFTE\datasets\gnet_train.jsonl'
dataset = TrainDataset(TRAIN_JSONL_PATH)

In [144]:
import cv2
from utils import draw_bboxes_and_edges
import matplotlib.pyplot as plt

# Put the module in evaluation mode before ad-hoc inference so any dropout or
# batch-norm layers behave deterministically (the training summary shows it in train mode).
model.eval()

with torch.no_grad():
    # NOTE(review): despite the name, `probs` looks like per-edge log-probabilities
    # over 3 classes (all values <= 0 and the criterion is NLLLoss); use probs.exp()
    # for probabilities. Confirm against GraphNetwork.forward.
    probs, _, bbox_pairs = model(dataset[0])

print(probs, bbox_pairs)

tensor([[-4.8754e-01, -9.5227e-01, -1.2747e+01],
        [-2.3956e+01, -1.4563e+01, -4.7684e-07],
        [-6.6687e-01, -7.2014e-01, -1.3942e+01],
        ...,
        [-2.1018e-02, -8.6458e+00, -3.8813e+00],
        [-7.5621e-01, -6.3383e-01, -1.3215e+01],
        [-4.1899e-01, -1.0721e+00, -1.6756e+01]], device='cuda:0') tensor([[  1.,   4.,  26.,  ...,   4.,  41.,   9.],
        [  1.,   4.,  26.,  ...,  17.,  45.,  10.],
        [219.,   4.,  41.,  ...,   4.,  26.,   9.],
        ...,
        [336., 325.,  40.,  ..., 353.,  19.,  10.],
        [  8.,  31.,  15.,  ...,   4.,  41.,   9.],
        [219., 227.,  19.,  ..., 185.,  40.,  10.]], device='cuda:0')


In [145]:
# Dump the bbox pairs, then the per-pair model outputs, as plain nested lists
# (one print per tensor; output can be very long).
for dumped_tensor in (bbox_pairs, probs):
    print(dumped_tensor.tolist())

[[1.0, 4.0, 26.0, 9.0, 219.0, 4.0, 41.0, 9.0], [1.0, 4.0, 26.0, 9.0, 1.0, 17.0, 45.0, 10.0], [219.0, 4.0, 41.0, 9.0, 1.0, 4.0, 26.0, 9.0], [219.0, 4.0, 41.0, 9.0, 336.0, 4.0, 27.0, 9.0], [219.0, 4.0, 41.0, 9.0, 1.0, 17.0, 45.0, 10.0], [219.0, 4.0, 41.0, 9.0, 8.0, 31.0, 15.0, 10.0], [219.0, 4.0, 41.0, 9.0, 219.0, 31.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 4.0, 41.0, 9.0], [336.0, 4.0, 27.0, 9.0, 456.0, 4.0, 28.0, 9.0], [336.0, 4.0, 27.0, 9.0, 456.0, 17.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 31.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 45.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 336.0, 45.0, 40.0, 10.0], [456.0, 4.0, 28.0, 9.0, 336.0, 4.0, 27.0, 9.0], [456.0, 4.0, 28.0, 9.0, 456.0, 17.0, 19.0, 10.0], [1.0, 17.0, 45.0, 10.0, 1.0, 4.0, 26.0, 9.0], [1.0, 17.0, 45.0, 10.0, 219.0, 4.0, 41.0, 9.0], [1.0, 17.0, 45.0, 10.0, 8.0, 31.0, 15.0, 10.0], [1.0, 17.0, 45.0, 10.0, 219.0, 31.0, 19.0, 10.0], [456.0, 17.0, 19.0, 10.0, 336.0, 4.0, 27.0, 9.0], [456.0, 17.0, 19.0, 10.0, 456.0

In [68]:
# NOTE(review): this cell duplicates the previous one. Its saved output comes from an
# earlier kernel state (execution count 68, before 145) and differs from the output of
# that cell, so it is stale; consider deleting it so Restart & Run All is unambiguous.
print(bbox_pairs.tolist())
print(probs.tolist())

[[1.0, 4.0, 26.0, 9.0, 219.0, 4.0, 41.0, 9.0], [1.0, 4.0, 26.0, 9.0, 1.0, 17.0, 45.0, 10.0], [219.0, 4.0, 41.0, 9.0, 1.0, 4.0, 26.0, 9.0], [219.0, 4.0, 41.0, 9.0, 336.0, 4.0, 27.0, 9.0], [219.0, 4.0, 41.0, 9.0, 1.0, 17.0, 45.0, 10.0], [219.0, 4.0, 41.0, 9.0, 8.0, 31.0, 15.0, 10.0], [219.0, 4.0, 41.0, 9.0, 219.0, 31.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 4.0, 41.0, 9.0], [336.0, 4.0, 27.0, 9.0, 456.0, 4.0, 28.0, 9.0], [336.0, 4.0, 27.0, 9.0, 456.0, 17.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 31.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 219.0, 45.0, 19.0, 10.0], [336.0, 4.0, 27.0, 9.0, 336.0, 45.0, 40.0, 10.0], [456.0, 4.0, 28.0, 9.0, 336.0, 4.0, 27.0, 9.0], [456.0, 4.0, 28.0, 9.0, 456.0, 17.0, 19.0, 10.0], [1.0, 17.0, 45.0, 10.0, 1.0, 4.0, 26.0, 9.0], [1.0, 17.0, 45.0, 10.0, 219.0, 4.0, 41.0, 9.0], [1.0, 17.0, 45.0, 10.0, 336.0, 4.0, 27.0, 9.0], [1.0, 17.0, 45.0, 10.0, 456.0, 4.0, 28.0, 9.0], [1.0, 17.0, 45.0, 10.0, 456.0, 17.0, 19.0, 10.0], [1.0, 17.0, 45.0, 10.0, 8.0, 31.0