# In [None]:
import torch
from torch.utils.data import DataLoader
from detector import SSD_CBAM_MNV3
from ssd_head import SSDLoss
from transforms_lowlight import get_train_transforms, get_val_transforms
from train import collate, HazardDataset

# Settings shared by the detector and the data transforms.
img_size = 320
num_classes = 4 + 1  # 4 hazard categories plus the background class
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint_path = "ssd_cbam_mnv3_lowlight.pt"  # weights from the earlier low-light training run

# Build the detector on the target device, then restore the saved weights.
# The checkpoint is a dict whose 'model' entry holds the state_dict.
model = SSD_CBAM_MNV3(num_classes=num_classes, img_size=img_size)
model.to(device)  # nn.Module.to moves parameters in place and returns self
ckpt = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(ckpt['model'])
print("Loaded weights from", checkpoint_path)

# Merged dataset: each split directory holds its images together with a
# COCO-format "_annotations.coco.json" file.
train_transforms = get_train_transforms(img_size)
train_ds = HazardDataset(
    img_dir="merged/train",
    ann_file="merged/train/_annotations.coco.json",
    transforms=train_transforms,
)

val_transforms = get_val_transforms(img_size)
val_ds = HazardDataset(
    img_dir="merged/valid",
    ann_file="merged/valid/_annotations.coco.json",
    transforms=val_transforms,
)

# Both loaders share batch size, worker count and the custom collate function;
# only the training loader reshuffles its samples each epoch.
loader_opts = dict(batch_size=16, num_workers=2, collate_fn=collate)
train_loader = DataLoader(train_ds, shuffle=True, **loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **loader_opts)

loss_fn = SSDLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

num_epochs = 10  # length of this fine-tuning run
finetuned_path = "ssd_cbam_mnv3_lowlight_finetuned_merged.pt"


def _targets_to_device(targets):
    """Return *targets* with each dict's 'boxes' and 'labels' tensors moved to ``device``."""
    return [{'boxes': t['boxes'].to(device), 'labels': t['labels'].to(device)} for t in targets]


def _run_epoch(loader, train):
    """Make one pass over *loader* and return the mean per-batch loss.

    Args:
        loader: DataLoader yielding ``(images, targets)`` batches.
        train: if True, put the model in training mode and take an optimizer
            step after every batch; if False, put it in eval mode and run
            with autograd disabled.

    Returns:
        Sum of the per-batch ``loss.item()`` values divided by ``len(loader)``.
    """
    model.train(train)  # model.train(False) is exactly model.eval()
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for images, targets in loader:
            images = images.to(device)
            targets = _targets_to_device(targets)
            cls_logits, box_deltas, anchors = model(images)
            loss = loss_fn(cls_logits, box_deltas, anchors, targets)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / len(loader)


# Fail fast: an empty loader would otherwise only surface as a
# ZeroDivisionError after a whole training epoch had already been spent.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        f"Empty data loader (train batches={len(train_loader)}, "
        f"val batches={len(val_loader)}); check the dataset paths and annotation files."
    )

best_val = float('inf')
for epoch in range(num_epochs):
    avg_train_loss = _run_epoch(train_loader, train=True)
    avg_val_loss = _run_epoch(val_loader, train=False)

    print(f"Epoch {epoch+1}: Train {avg_train_loss:.4f} | Val {avg_val_loss:.4f}")

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if avg_val_loss < best_val:
        best_val = avg_val_loss
        # 'model' stays the top-level key so the file loads the same way as the
        # input checkpoint (ckpt['model']); epoch/val_loss record which epoch won.
        torch.save(
            {'model': model.state_dict(), 'epoch': epoch + 1, 'val_loss': avg_val_loss},
            finetuned_path,
        )
        print("Saved improved model")