In [14]:
import torch
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

In [15]:
class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, root, annFile, transform=None):
        self.coco = CocoDetection(root, annFile)
        self.transform = transform
    
    def __getitem__(self, index):
        img, target = self.coco[index]

        boxes = []
        labels = []

        for obj in target:
            x,y,w,h = obj['bbox']
            boxes.append([x,y,x+w,y+h])
            labels.append(obj['category_id'])
        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        if boxes.numel() == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels
        }

        if self.transform:
            img = self.transform(img)

        return img, target
    
    def __len__(self):
        return len(self.coco)

In [16]:
def get_transform():
    return transforms.Compose([
        transforms.ToTensor()
    ])

In [17]:
train_root = 'dataset/train'
train_ann = 'dataset/train/_annotations.coco.json'
val_root = 'dataset/valid'
val_ann = 'dataset/valid/_annotations.coco.json'
test_root = 'dataset/test'
test_ann = 'dataset/test/_annotations.coco.json'

train_dataset = CocoDataset(train_root, train_ann, transform=get_transform())
val_dataset = CocoDataset(val_root, val_ann, transform=get_transform())
test_dataset = CocoDataset(test_root, test_ann, transform=get_transform())

print(f"Train Samples: {len(train_dataset)}")
print(f"Val Samples: {len(val_dataset)}")
print(f"Test Samples: {len(test_dataset)}")

loading annotations into memory...
Done (t=0.57s)
creating index...
index created!
loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
Train Samples: 2634
Val Samples: 966
Test Samples: 458


In [18]:
def collate_batch(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_batch, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_batch, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_batch, pin_memory=True)

In [19]:
import torchvision

num_classes = 13+1

from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)


In [20]:
in_features = model.roi_heads.box_predictor.cls_score.in_features

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
for params in model.parameters():
    params.requires_grad = False

for params in model.roi_heads.parameters():
    params.requires_grad = True

for param in model.rpn.parameters():
    param.requires_grad = True


In [None]:
import torch.optim as optim 

params = [p for p in model.parameters() if p.requires_grad]
print(len(params))

14


In [23]:
import time
import torch
from torch.amp import autocast, GradScaler
from tqdm import tqdm

# Proper device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)   # make sure optimizer exists

# Enable GradScaler only if CUDA is available
scaler = GradScaler(enabled=torch.cuda.is_available())

num_epochs = 10

for epoch in range(num_epochs):
    print("\n===========================================")
    print(f"Starting Epoch {epoch+1}/{num_epochs}")
    print("===========================================\n")

    model.train()

    total_loss = 0
    epoch_start = time.time()

    data_load_time = 0
    forward_time = 0
    backward_time = 0

    loop = tqdm(train_loader)

    for images, targets in loop:
        batch_start = time.time()

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        data_load_time += time.time() - batch_start
        fw_start = time.time()

        # Correct autocast usage
        with autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        forward_time += time.time() - fw_start
        bw_start = time.time()

        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        backward_time += time.time() - bw_start

        total_loss += losses.item()

        loop.set_description(f"Epoch {epoch+1}")
        loop.set_postfix(loss=losses.item())

    # ----- END OF EPOCH SUMMARY -----
    epoch_time = time.time() - epoch_start
    avg_loss = total_loss / len(train_loader)

    print("\n-------------------------------------------")
    print(f"Epoch {epoch+1} finished!")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Epoch Time: {epoch_time:.2f}s")
    print(f"  Data Loading: {data_load_time:.2f}s ({data_load_time/epoch_time*100:.1f}%)")
    print(f"  Forward Pass: {forward_time:.2f}s ({forward_time/epoch_time*100:.1f}%)")
    print(f"  Backward Pass: {backward_time:.2f}s ({backward_time/epoch_time*100:.1f}%)")
    print("-------------------------------------------\n")


Using device: cuda

Starting Epoch 1/10



Epoch 1: 100%|██████████| 659/659 [07:32<00:00,  1.46it/s, loss=0.707]



-------------------------------------------
Epoch 1 finished!
Average Loss: 0.9720
Epoch Time: 452.17s
  Data Loading: 1.85s (0.4%)
  Forward Pass: 353.06s (78.1%)
  Backward Pass: 40.09s (8.9%)
-------------------------------------------


Starting Epoch 2/10



Epoch 2: 100%|██████████| 659/659 [20:16<00:00,  1.85s/it, loss=0.736]  



-------------------------------------------
Epoch 2 finished!
Average Loss: 0.7946
Epoch Time: 1216.56s
  Data Loading: 2.73s (0.2%)
  Forward Pass: 1057.47s (86.9%)
  Backward Pass: 104.29s (8.6%)
-------------------------------------------


Starting Epoch 3/10



Epoch 3: 100%|██████████| 659/659 [07:18<00:00,  1.50it/s, loss=0.567]



-------------------------------------------
Epoch 3 finished!
Average Loss: 0.7429
Epoch Time: 438.80s
  Data Loading: 1.70s (0.4%)
  Forward Pass: 356.84s (81.3%)
  Backward Pass: 40.51s (9.2%)
-------------------------------------------


Starting Epoch 4/10



Epoch 4: 100%|██████████| 659/659 [41:56<00:00,  3.82s/it, loss=0.569]



-------------------------------------------
Epoch 4 finished!
Average Loss: 0.7048
Epoch Time: 2516.36s
  Data Loading: 4.52s (0.2%)
  Forward Pass: 2214.28s (88.0%)
  Backward Pass: 211.56s (8.4%)
-------------------------------------------


Starting Epoch 5/10



Epoch 5: 100%|██████████| 659/659 [07:05<00:00,  1.55it/s, loss=0.694]



-------------------------------------------
Epoch 5 finished!
Average Loss: 0.6831
Epoch Time: 425.31s
  Data Loading: 1.71s (0.4%)
  Forward Pass: 343.87s (80.9%)
  Backward Pass: 38.82s (9.1%)
-------------------------------------------


Starting Epoch 6/10



Epoch 6: 100%|██████████| 659/659 [10:10<00:00,  1.08it/s, loss=0.263]



-------------------------------------------
Epoch 6 finished!
Average Loss: 0.6587
Epoch Time: 610.90s
  Data Loading: 2.05s (0.3%)
  Forward Pass: 506.28s (82.9%)
  Backward Pass: 58.87s (9.6%)
-------------------------------------------


Starting Epoch 7/10



Epoch 7: 100%|██████████| 659/659 [28:50<00:00,  2.63s/it, loss=1.22]   



-------------------------------------------
Epoch 7 finished!
Average Loss: 0.6468
Epoch Time: 1730.45s
  Data Loading: 3.39s (0.2%)
  Forward Pass: 1509.20s (87.2%)
  Backward Pass: 150.51s (8.7%)
-------------------------------------------


Starting Epoch 8/10



Epoch 8: 100%|██████████| 659/659 [44:29<00:00,  4.05s/it, loss=0.551]  



-------------------------------------------
Epoch 8 finished!
Average Loss: 0.6311
Epoch Time: 2669.55s
  Data Loading: 4.05s (0.2%)
  Forward Pass: 2362.72s (88.5%)
  Backward Pass: 224.23s (8.4%)
-------------------------------------------


Starting Epoch 9/10



Epoch 9: 100%|██████████| 659/659 [06:58<00:00,  1.57it/s, loss=0.583]



-------------------------------------------
Epoch 9 finished!
Average Loss: 0.6220
Epoch Time: 418.80s
  Data Loading: 1.55s (0.4%)
  Forward Pass: 343.87s (82.1%)
  Backward Pass: 39.63s (9.5%)
-------------------------------------------


Starting Epoch 10/10



Epoch 10: 100%|██████████| 659/659 [10:29<00:00,  1.05it/s, loss=0.489]  


-------------------------------------------
Epoch 10 finished!
Average Loss: 0.6089
Epoch Time: 629.28s
  Data Loading: 2.07s (0.3%)
  Forward Pass: 519.69s (82.6%)
  Backward Pass: 62.87s (10.0%)
-------------------------------------------






In [24]:
torch.save(model.state_dict(), "fasterrcnn_resnet50_fpn.pth")