In [None]:
import sys
sys.path.insert(0, './src')

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from config import load_config
from dataset_coco import COCODetectionDataset
from models import DETR
from lightning_module import DetectionModule


In [None]:
# Load the experiment configuration once and expose its values as
# UPPER_SNAKE_CASE constants so later cells never index `config` for them.
config = load_config('./configs/config_coco_detr_person.json')

# Views onto the three configuration sections read in this cell.
data_cfg = config['data']
model_cfg = config['model']
training_cfg = config['training']

# --- Data ---------------------------------------------------------------
DATA_ROOT = data_cfg['root_dir']
TRAIN_SPLIT = data_cfg['split']
VAL_SPLIT = data_cfg['val_split']
TEST_SPLIT = data_cfg['test_split']
# Stored as a list in the JSON file; converted to a tuple here.
TARGET_SIZE = tuple(data_cfg['target_size'])
BATCH_SIZE = data_cfg['batch_size']
NUM_WORKERS = data_cfg['num_workers']
MIN_AREA = data_cfg['min_area']
FILTER_CATEGORIES = data_cfg['filter_categories']

# --- Model --------------------------------------------------------------
NUM_CLASSES = model_cfg['num_classes']
NUM_QUERIES = model_cfg['num_queries']
EMB_DIM = model_cfg['emb_dim']
NHEAD = model_cfg['nhead']
ENC_LAYERS = model_cfg['enc_layers']
DEC_LAYERS = model_cfg['dec_layers']
PRETRAINED = model_cfg['pretrained']

# --- Optimisation -------------------------------------------------------
MAX_EPOCHS = training_cfg['max_epochs']
LEARNING_RATE = training_cfg['learning_rate']
WEIGHT_DECAY = training_cfg['weight_decay']
OPTIMIZER = training_cfg['optimizer']

# Short summary so the reader can confirm which experiment was loaded.
print("Model: DETR")
print(f"Classes: {NUM_CLASSES} ({', '.join(FILTER_CATEGORIES)})")
print(f"Num queries: {NUM_QUERIES}")
print(f"Encoder layers: {ENC_LAYERS}, Decoder layers: {DEC_LAYERS}")


In [None]:
def collate_fn(batch):
    """Collate detection samples into one batch.

    Each sample is indexed positionally: element 0 is the image tensor,
    element 1 its boxes and element 2 its labels.

    Args:
        batch: Sequence of samples as described above.

    Returns:
        A tuple ``(images, boxes, labels)`` where ``images`` is the image
        tensors stacked along a new leading dimension, and ``boxes`` /
        ``labels`` are plain Python lists with one entry per sample (they
        are not stacked, so the per-image entries may differ in length).
    """
    image_list, boxes, labels = [], [], []
    for sample in batch:
        image_list.append(sample[0])
        boxes.append(sample[1])
        labels.append(sample[2])
    images = torch.stack(image_list)
    return images, boxes, labels

# Every split is built with the same settings; only `split` differs.
dataset_kwargs = dict(
    root_dir=DATA_ROOT,
    target_size=TARGET_SIZE,
    min_area=MIN_AREA,
    filter_categories=FILTER_CATEGORIES,
)

train_dataset = COCODetectionDataset(split=TRAIN_SPLIT, **dataset_kwargs)
val_dataset = COCODetectionDataset(split=VAL_SPLIT, **dataset_kwargs)
test_dataset = COCODetectionDataset(split=TEST_SPLIT, **dataset_kwargs)

# Report how many samples each split contains.
print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')


In [None]:
# Loader settings shared by all splits; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Number of batches per epoch for each split.
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')


In [None]:
# Architecture hyper-parameters, all taken from the configuration cell.
detr_kwargs = {
    'num_classes': NUM_CLASSES,
    'emb_dim': EMB_DIM,
    'num_queries': NUM_QUERIES,
    'nhead': NHEAD,
    'enc_layers': ENC_LAYERS,
    'dec_layers': DEC_LAYERS,
    'pretrained': PRETRAINED,
}
model = DETR(**detr_kwargs)

# Wrap the network in the Lightning module together with the optimiser settings.
module_kwargs = {
    'model': model,
    'num_classes': NUM_CLASSES,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'optimizer': OPTIMIZER,
}
lightning_module = DetectionModule(**module_kwargs)

# Total element count over all parameters of the network.
total_params = sum(p.numel() for p in model.parameters())

print("Model initialized")
print(f"Total parameters: {total_params / 1e6:.2f}M")


In [None]:
# Views onto the configuration sections used to assemble the trainer.
train_cfg = config['training']
hw_cfg = config['hardware']

# Checkpointing and early stopping watch the same metric in the same direction.
monitored_metric = train_cfg['checkpoint_monitor']
metric_mode = train_cfg['checkpoint_mode']

checkpoint_callback = ModelCheckpoint(
    dirpath=train_cfg['checkpoint_dirpath'],
    filename=train_cfg['checkpoint_filename'],
    monitor=monitored_metric,
    mode=metric_mode,
    save_top_k=train_cfg['checkpoint_save_top_k'],
    verbose=True,
)

early_stopping_callback = EarlyStopping(
    monitor=monitored_metric,
    patience=train_cfg['early_stopping_patience'],
    mode=metric_mode,
    verbose=True,
)

logger = TensorBoardLogger(
    save_dir=train_cfg['log_dir'],
    name=train_cfg['experiment_name'],
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator=hw_cfg['accelerator'],
    devices=hw_cfg['devices'],
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=logger,
    log_every_n_steps=train_cfg['log_every_n_steps'],
    deterministic=False,
)

print("Trainer configured")
print(f"Max epochs: {MAX_EPOCHS}")
print(f"Device: {hw_cfg['accelerator']}")


In [None]:
# Train `lightning_module` on `train_loader`, passing `val_loader` as the
# validation data. The checkpoint and early-stopping callbacks attached to
# `trainer` above take effect during this call.
trainer.fit(lightning_module, train_loader, val_loader)


In [None]:
# Path recorded by the checkpoint callback during training.
best_model_path = checkpoint_callback.best_model_path
print(f"Best model path: {best_model_path}")

# An empty path is passed on as None instead of as an empty string.
ckpt_for_test = best_model_path or None
trainer.test(lightning_module, test_loader, ckpt_path=ckpt_for_test)


## Prediction Visualization


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Upper bound on the number of test images drawn (the figure is a 2x2 grid).
MAX_VIS_IMAGES = 4
# Score cut-off forwarded to the model's postprocess step.
VIS_CONF_THRESHOLD = 0.8


def box_to_pixel_rect(box):
    """Convert one (cx, cy, w, h) box into (x1, y1, width, height) for plotting.

    The values are scaled by TARGET_SIZE, indexed here as (height, width).
    NOTE(review): this presumes boxes are normalised to [0, 1] -- confirm
    against the dataset and the model's postprocess output.
    """
    cx, cy, w, h = box
    x1 = (cx - w / 2) * TARGET_SIZE[1]
    y1 = (cy - h / 2) * TARGET_SIZE[0]
    return x1, y1, w * TARGET_SIZE[1], h * TARGET_SIZE[0]


# Inference must not run in training mode: dropout would randomise the
# predictions and BatchNorm layers would update their running statistics.
lightning_module.eval()

images_vis, boxes_vis, labels_vis = next(iter(test_loader))
images_vis = images_vis[:MAX_VIS_IMAGES].to(lightning_module.device)
# A batch can hold fewer than MAX_VIS_IMAGES images (small BATCH_SIZE or a
# short test split), so every loop below is bounded by the real count.
num_vis = len(images_vis)

with torch.no_grad():
    outputs = lightning_module(images_vis)
    preds = lightning_module.model.postprocess(outputs, conf_threshold=VIS_CONF_THRESHOLD)

# Per-channel statistics used to undo the input normalisation for display.
# NOTE(review): these are the usual ImageNet values -- confirm the dataset
# normalises with the same numbers.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 2, figsize=(16, 16))
axes = axes.flatten()

for i in range(num_vis):
    # CHW tensor -> HWC array, then de-normalise and clamp to the displayable range.
    img = images_vis[i].cpu().numpy().transpose(1, 2, 0)
    img = img * std + mean
    img = np.clip(img, 0, 1)

    axes[i].imshow(img)

    # Ground-truth boxes: dashed lime rectangles tagged 'GT'.
    for j in range(len(boxes_vis[i])):
        x1, y1, w_px, h_px = box_to_pixel_rect(boxes_vis[i][j].cpu().numpy())

        rect = patches.Rectangle((x1, y1), w_px, h_px, linewidth=2, edgecolor='lime', facecolor='none', linestyle='--')
        axes[i].add_patch(rect)
        axes[i].text(x1, y1-5, 'GT', color='lime', fontsize=10, weight='bold',
                     bbox=dict(facecolor='black', alpha=0.7))

    # Predicted boxes: solid rectangles coloured by confidence.
    pred_boxes = preds[i]['boxes']
    pred_scores = preds[i]['scores']
    for j in range(len(pred_boxes)):
        x1, y1, w_px, h_px = box_to_pixel_rect(pred_boxes[j].cpu().numpy())
        score = pred_scores[j].cpu().item()

        # NOTE(review): if postprocess drops everything below
        # VIS_CONF_THRESHOLD (0.8), only the 'cyan' branch can fire.
        color = 'red' if score < 0.5 else 'yellow' if score < 0.7 else 'cyan'
        rect = patches.Rectangle((x1, y1), w_px, h_px, linewidth=2, edgecolor=color, facecolor='none')
        axes[i].add_patch(rect)
        axes[i].text(x1+w_px, y1, f'{score:.2f}', color=color, fontsize=8, weight='bold',
                     bbox=dict(facecolor='black', alpha=0.7))

    axes[i].set_title(f'Image {i+1}: GT={len(boxes_vis[i])} (green dashed), Pred={len(pred_boxes)} (colored)')
    axes[i].axis('off')

# Blank out grid cells that received no image.
for unused_ax in axes[num_vis:]:
    unused_ax.axis('off')

plt.suptitle('Detection Results: GT (green dashed) vs Predictions (colored by confidence)', fontsize=14)
plt.tight_layout()
plt.show()

print(f"\nDetection summary:")
for i in range(num_vis):
    print(f"Image {i+1}: GT={len(boxes_vis[i])}, Predictions={len(preds[i]['boxes'])}")


## Internal Components


### 1. Position Embedding: Spatial Patterns


In [None]:
import torch.nn.functional as F

# Number of embedding channels shown (the figure is a 2x4 grid) and the
# largest gap allowed between two displayed channels.
NUM_POS_PANELS = 8
MAX_CHANNEL_STRIDE = 32

# Run the backbone in inference mode: in training mode BatchNorm layers would
# update their running statistics even inside torch.no_grad().
lightning_module.eval()

# A single test image is enough to inspect the internal tensors.
images_int, _, _ = next(iter(test_loader))
images_int = images_int[:1].to(lightning_module.device)

with torch.no_grad():
    features = lightning_module.model.backbone(images_int)
    pos_embed = lightning_module.model.pos_emb(features)

# Pick evenly spaced channels that always stay inside the embedding. When the
# embedding has at least 225 channels this reproduces the fixed
# 0, 32, ..., 224 selection; narrower embeddings use a smaller stride.
num_channels = pos_embed.shape[1]
num_panels = min(NUM_POS_PANELS, num_channels)
channel_stride = min(MAX_CHANNEL_STRIDE, max(1, (num_channels - 1) // max(1, num_panels - 1)))

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= num_panels:
        # Fewer channels than grid cells: leave the remaining cells blank.
        continue
    channel = i * channel_stride
    ax.imshow(pos_embed[0, channel].cpu(), cmap='RdBu')
    ax.set_title(f'Channel {channel}')
plt.suptitle('Position Embedding Channels (sin/cos patterns)', fontsize=14)
plt.tight_layout()
plt.show()
print(f'Position embedding shape: {pos_embed.shape}')


### 2. Query Attention: What Each Query Sees


In [None]:
# Number of highest-scoring queries shown (the figure is a 2x3 grid).
NUM_TOP_QUERIES = 6

# Inference mode: keeps dropout from randomising the outputs computed below.
lightning_module.eval()

with torch.no_grad():
    outputs = lightning_module(images_int)

    # Flatten the spatial grid: (B, C, H, W) -> (B, H*W, C).
    mem = features.flatten(2).permute(0, 2, 1)
    mem_pos = pos_embed.flatten(2).permute(0, 2, 1)

    mem_encoded = lightning_module.model.transformer_encoder(mem + mem_pos)

    queries = lightning_module.model.queries.weight.unsqueeze(0).expand(1, -1, -1)
    q_pos = lightning_module.model.query_pos.weight.unsqueeze(0).expand(1, -1, -1)

    # Scaled dot-product between the raw query embeddings and the encoded
    # memory. No decoder layer is called here, so this is a proxy for the
    # decoder's cross-attention rather than the attention weights themselves.
    Q = queries + q_pos
    K = mem_encoded + mem_pos
    attention = torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
    attention = torch.softmax(attention, dim=-1)

H, W = features.shape[2:]
# One H x W map per query. The query count is taken from the tensor itself
# rather than hard-coded, so any `num_queries` setting works.
attention_maps = attention[0].reshape(-1, H, W)

class_probs = torch.softmax(outputs['class_logits'][0], dim=-1)
# NOTE(review): column 0 is treated as the 'person' class -- confirm against
# the dataset's label mapping.
person_probs = class_probs[:, 0]
top_queries = torch.argsort(person_probs, descending=True)[:NUM_TOP_QUERIES]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for idx, q in enumerate(top_queries):
    ax = axes[idx // 3, idx % 3]
    attn_map = attention_maps[q].cpu()
    # Upsample the coarse feature-grid map to the input resolution for display.
    attn_map = F.interpolate(attn_map.unsqueeze(0).unsqueeze(0),
                             size=(TARGET_SIZE[0], TARGET_SIZE[1]), mode='bilinear')[0, 0]
    ax.imshow(attn_map, cmap='hot')
    ax.set_title(f'Query {q.item()}, confidence={person_probs[q]:.3f}')
    ax.axis('off')
plt.suptitle('Query-to-Pixel Attention Maps (top 6 confident queries)', fontsize=14)
plt.tight_layout()
plt.show()


### 3. BBox Predictions: Visualization


In [None]:
import matplotlib.patches as patches

# Box predictions for the first image of the batch, one row per query.
bbox_preds = outputs['bbox_pred'][0]

# CHW tensor -> HWC array, de-normalised with the `mean`/`std` defined in the
# prediction-visualisation cell and clamped to the displayable range.
img = np.clip(images_int[0].cpu().numpy().transpose(1, 2, 0) * std + mean, 0, 1)

# TARGET_SIZE is indexed as (height, width) when scaling box coordinates.
img_h, img_w = TARGET_SIZE[0], TARGET_SIZE[1]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# One panel per selected query, filled in row-major order.
for ax, q in zip(axes.flat, top_queries):
    ax.imshow(img)

    # (cx, cy, w, h) -> top-left corner plus size, scaled to pixels.
    cx, cy, w, h = bbox_preds[q].cpu().numpy()
    top_left = ((cx - w / 2) * img_w, (cy - h / 2) * img_h)

    rect = patches.Rectangle(top_left, w * img_w, h * img_h, linewidth=3, edgecolor='red', facecolor='none')
    ax.add_patch(rect)
    ax.set_title(f'Query {q.item()}, conf={person_probs[q]:.3f}')
    ax.axis('off')
plt.suptitle('BBox Predictions (top 6 confident queries)', fontsize=14)
plt.tight_layout()
plt.show()
print(f'BBox predictions shape: {bbox_preds.shape}')
print(f'BBox coords range: [{bbox_preds.min():.3f}, {bbox_preds.max():.3f}]')


### 4. Hungarian Matching: Cost Matrix


In [None]:
from scipy.optimize import linear_sum_assignment

# Relative weights of the three terms of the matching cost.
COST_WEIGHT_CLASS = 2.0
COST_WEIGHT_BBOX = 5.0
COST_WEIGHT_GIOU = 2.0
# Guards the IoU / GIoU divisions against zero-area boxes.
IOU_EPS = 1e-6


def box_cxcywh_to_xyxy(boxes):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last dim."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)


def _pairwise_iou_and_union(boxes1, boxes2):
    """Return the (N, M) IoU matrix and the matching (N, M) union areas.

    Both inputs are (x1, y1, x2, y2) boxes of shape (N, 4) and (M, 4).
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # Corners of the intersection rectangle; clamping handles disjoint boxes.
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    union = area1[:, None] + area2 - inter
    return inter / (union + IOU_EPS), union


def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of (x1, y1, x2, y2) boxes, shape (N, M)."""
    return _pairwise_iou_and_union(boxes1, boxes2)[0]


def generalized_box_iou(boxes1, boxes2):
    """Pairwise generalized IoU between two sets of (x1, y1, x2, y2) boxes.

    GIoU = IoU - (hull - union) / hull, where `hull` is the area of the
    smallest box enclosing both inputs. The union is the true
    area1 + area2 - intersection, so identical boxes score ~1 and boxes
    that are far apart approach -1.
    """
    iou, union = _pairwise_iou_and_union(boxes1, boxes2)
    # Corners of the smallest enclosing box.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    hull_area = wh[:, :, 0] * wh[:, :, 1]
    return iou - (hull_area - union) / (hull_area + IOU_EPS)


# Inference mode: keeps dropout from randomising the costs computed below.
lightning_module.eval()

test_img, test_boxes, test_labels = next(iter(test_loader))
test_img = test_img[:1].to(lightning_module.device)

# Ground truth of the first image, moved next to the predictions.
targets = [{
    'boxes': test_boxes[0].to(lightning_module.device),
    'labels': test_labels[0].to(lightning_module.device)
}]
gt_labels = targets[0]['labels']
gt_boxes = targets[0]['boxes']

with torch.no_grad():
    test_outputs = lightning_module(test_img)

    class_probs = torch.softmax(test_outputs['class_logits'][0], dim=-1)
    bbox_pred = test_outputs['bbox_pred'][0]

    if len(gt_labels) > 0:
        # Each cost is a (num_queries, num_gt) matrix; lower is better.
        cost_class = -class_probs[:, gt_labels]
        cost_bbox = torch.cdist(bbox_pred, gt_boxes, p=1)

        pred_boxes_xyxy = box_cxcywh_to_xyxy(bbox_pred)
        gt_boxes_xyxy = box_cxcywh_to_xyxy(gt_boxes)
        cost_giou = -generalized_box_iou(pred_boxes_xyxy, gt_boxes_xyxy)

        cost_matrix = (
            COST_WEIGHT_CLASS * cost_class
            + COST_WEIGHT_BBOX * cost_bbox
            + COST_WEIGHT_GIOU * cost_giou
        )

if len(gt_labels) > 0:
    # Minimum-cost one-to-one assignment between queries (rows) and GT boxes (columns).
    pred_idx, gt_idx = linear_sum_assignment(cost_matrix.cpu().numpy())

    # Note: the first panel's title says 'CE', but the term plotted is the
    # negative class probability computed above.
    panels = [
        ('Cost: Class CE', cost_class),
        ('Cost: L1 BBox', cost_bbox),
        ('Cost: GIoU', cost_giou),
        ('Total Cost + Matches', cost_matrix),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (panel_title, panel_cost) in zip(axes, panels):
        im = ax.imshow(panel_cost.cpu(), aspect='auto', cmap='RdYlGn_r')
        ax.set_title(panel_title)
        ax.set_xlabel('GT instances')
        ax.set_ylabel('Queries')
        plt.colorbar(im, ax=ax)

    # Mark the matched (GT, query) pairs on the total-cost panel.
    for p, g in zip(pred_idx, gt_idx):
        axes[3].plot(g, p, 'ro', markersize=10)

    plt.suptitle('Hungarian Matching: Cost Matrix Visualization', fontsize=14)
    plt.tight_layout()
    plt.show()

    print(f'Cost matrix shape: {cost_matrix.shape}')
    print(f'Num matches: {len(pred_idx)}')
    print(f'Matched queries: {pred_idx.tolist()}')
    print(f'Matched GT indices: {gt_idx.tolist()}')
else:
    print('No GT instances in this image')
