In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.ops import nms

import gc
import os
from PIL import Image
from torchinfo import summary

from utils import DetectionLoss
from esanet_model import DetectionModel

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tqdm import tqdm
import warnings

In [None]:
# Silence every Python warning for the rest of the session. This stays first so
# that anything emitted while probing CUDA below is suppressed as well.
warnings.filterwarnings('ignore')

# Reproducibility: seed PyTorch's global RNG and pin cuDNN to deterministic
# kernels (autotuning off, deterministic algorithms on).
torch.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class CustomDataset(Dataset):
    """COCO-annotated detection dataset that yields dense, per-scale grid targets.

    Every sample is ``(img_id, image, labels, boxes)`` where ``labels`` and
    ``boxes`` are lists with one tensor per entry of ``scales``.

    Args:
        annotation_file: Path to a COCO-format annotation JSON.
        img_dir: Directory that the ``file_name`` entries are relative to.
        transform: Optional callable applied to the RGB ``PIL.Image``. Targets
            are computed from the annotation file only, so the transform must
            not move pixels relative to the annotated boxes.
        scales: Iterable of ``(grid_width, grid_height)`` pairs, one per
            prediction scale.
    """

    def __init__(self, annotation_file, img_dir, transform=None,
                 scales=((16, 16), (32, 32), (64, 64), (128, 128))):
        self.coco = COCO(annotation_file)
        self.img_dir = img_dir
        self.transform = transform
        self.image_ids = list(self.coco.imgs.keys())
        # Feature-map scales for p1, p2, p3, p4. Copied into a fresh list so the
        # default (now immutable) and caller-owned sequences are never shared.
        self.scales = list(scales)

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its grid targets.

        Returns:
            ``(img_id, image, labels, boxes)``; see :meth:`generate_targets`
            for the layout of ``labels`` and ``boxes``.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        # Image.open is lazy and keeps the file handle open; convert() returns
        # an in-memory copy, so the handle can be released right away.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        labels, boxes = self.generate_targets(annotations, img_info)
        return img_id, image, labels, boxes

    def generate_targets(self, annotations, img_info):
        """Rasterise annotations onto every configured grid.

        Each annotation is assigned to the single cell containing its box
        centre. ``ann['bbox']`` is read as ``[x, y, width, height]`` in the
        pixel space described by ``img_info['width']`` / ``['height']``.

        Args:
            annotations: Iterable of dicts with ``category_id`` and ``bbox``.
            img_info: Dict with the image ``width`` and ``height``.

        Returns:
            ``(labels_list, bbox_list)``. For a scale ``(gw, gh)``:

            * labels: ``int64`` tensor ``(gh, gw)`` holding the raw
              ``category_id`` of the owning annotation, ``0`` where no
              annotation landed.
            * boxes: ``float32`` tensor ``(gh, gw, 4)`` holding
              ``(dx, dy, w, h)``: the centre offset inside the cell and the
              box size, all measured in grid cells.

            Annotations whose centre maps outside the grid are skipped, and a
            later annotation overwrites an earlier one in the same cell.
        """
        w_img, h_img = img_info['width'], img_info['height']
        labels_list = []
        bbox_list = []

        for scale in self.scales:
            grid_w, grid_h = scale[0], scale[1]
            # Pixels -> grid-cell units; depends only on the scale, so it is
            # computed once per scale rather than once per annotation.
            scale_x, scale_y = grid_w / w_img, grid_h / h_img

            cls_targets = torch.zeros((grid_h, grid_w), dtype=torch.int64)  # height, width
            bbox_targets = torch.zeros((grid_h, grid_w, 4), dtype=torch.float32)  # height, width, 4

            for ann in annotations:
                category = ann['category_id']
                bbox = ann['bbox']
                center_x = bbox[0] + bbox[2] / 2
                center_y = bbox[1] + bbox[3] / 2
                grid_x, grid_y = int(center_x * scale_x), int(center_y * scale_y)

                if 0 <= grid_x < grid_w and 0 <= grid_y < grid_h:
                    cls_targets[grid_y, grid_x] = category
                    bbox_targets[grid_y, grid_x] = torch.tensor([
                        (center_x * scale_x - grid_x),  # centre x offset within the cell
                        (center_y * scale_y - grid_y),  # centre y offset within the cell
                        bbox[2] * scale_x,              # width in grid cells
                        bbox[3] * scale_y               # height in grid cells
                    ])

            labels_list.append(cls_targets)
            bbox_list.append(bbox_targets)

        return labels_list, bbox_list


def collate_fn(batch):
    """Merge ``(img_id, image, labels, boxes)`` samples into one batch.

    Images are stacked along a new leading batch dimension. The per-sample
    target lists are regrouped by scale and each group is stacked the same
    way, so element ``k`` of the returned lists holds scale ``k`` for the
    whole batch.

    Returns:
        ``(img_ids, images, labels_batch, boxes_batch)`` where ``img_ids`` is
        a tuple and the last two are lists of tensors, one per scale.
    """
    img_ids, image_tensors, per_sample_labels, per_sample_boxes = zip(*batch)

    stacked_images = torch.stack(image_tensors)

    # zip(*...) transposes "sample -> scales" into "scale -> samples".
    labels_by_scale = []
    for same_scale_labels in zip(*per_sample_labels):
        labels_by_scale.append(torch.stack(same_scale_labels))

    boxes_by_scale = []
    for same_scale_boxes in zip(*per_sample_boxes):
        boxes_by_scale.append(torch.stack(same_scale_boxes))

    return img_ids, stacked_images, labels_by_scale, boxes_by_scale

In [None]:
# Define transforms and initialize DataLoader for training and validation sets.
#
# CustomDataset builds its grid targets from the raw COCO annotations and never
# sees the transform, so any augmentation that moves pixels (flips, rotations,
# crops) would leave the boxes pointing at the wrong place. The training
# pipeline is therefore restricted to label-preserving operations; Resize is
# safe because the targets are expressed relative to the image extent.
# To bring geometric augmentation back, apply it jointly to image and boxes
# (e.g. inside the dataset) instead of adding it here.
transform_train = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
])

transform_val = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
])

# Set dataset paths for training and validation
train_annotation_file = '../IP102_Rice/annotations_train_rice.json'
train_img_dir = '../IP102_Rice/images/train'
val_annotation_file = '../IP102_Rice/annotations_val_rice.json'
val_img_dir = '../IP102_Rice/images/val'

# Initialize datasets and DataLoaders for train and validation sets.
# Only the training loader shuffles; validation keeps annotation-file order.
train_dataset = CustomDataset(train_annotation_file, train_img_dir, transform=transform_train)
train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=collate_fn)
val_dataset = CustomDataset(val_annotation_file, val_img_dir, transform=transform_val)
val_dataloader = DataLoader(val_dataset, batch_size=5, shuffle=False, collate_fn=collate_fn)

In [None]:
# Initialize DetectionModel.
# CustomDataset writes the raw COCO category_id into the class map and uses 0
# for empty cells, so the classifier needs an index for every id that can
# appear. Sizing by the category count alone breaks when ids are sparse, so
# take whichever of (count, largest id) is bigger and add one slot. For
# contiguous ids this is exactly len(category_ids) + 1.
category_ids = train_dataset.coco.getCatIds()
num_classes = max(len(category_ids), max(category_ids, default=0)) + 1
model = DetectionModel(num_classes=num_classes).to(device)
print(summary(model, input_size=(1, 3, 512, 512)))

In [None]:
# Loss used for every prediction scale, and an Adam optimiser over all model
# parameters with a fixed learning rate of 1e-3.
criterion = DetectionLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch the per-scale losses returned by ``criterion`` are summed
    into a single scalar, back-propagated, and followed by one optimiser step.

    Args:
        model: Called as ``model(images)``; must return
            ``(cls_logits_list, bbox_preds_list)`` with one entry per scale.
        dataloader: Yields ``(img_ids, images, labels_list, boxes_list)``.
        criterion: Called as ``criterion(cls_logits, bbox_preds, labels, boxes)``.
        optimizer: Optimiser bound to the model parameters.
        device: Device the images and targets are moved to.

    Returns:
        The summed loss averaged over the number of batches.
    """
    model.train()
    epoch_loss_sum = 0.0

    for _, batch_images, batch_labels, batch_boxes in tqdm(dataloader):
        batch_images = batch_images.to(device)
        optimizer.zero_grad()

        cls_outputs, box_outputs = model(batch_images)

        # Accumulate the loss over every prediction scale.
        batch_loss = 0
        for level, level_logits in enumerate(cls_outputs):
            level_boxes = box_outputs[level]
            target_labels = batch_labels[level].to(device)
            target_boxes = batch_boxes[level].to(device)
            batch_loss += criterion(level_logits, level_boxes, target_labels, target_boxes)

        batch_loss.backward()
        optimizer.step()
        epoch_loss_sum += batch_loss.item()

    mean_loss = epoch_loss_sum / len(dataloader)

    # Release Python garbage and cached CUDA blocks between epochs.
    gc.collect()
    torch.cuda.empty_cache()

    return mean_loss

In [None]:
def _lookup_image_size(coco_index, img_id, images):
    """Return ``(width, height)`` of the pixel space detections are reported in."""
    if coco_index is not None:
        # Ground-truth boxes live in original-image pixels, so use those sizes.
        info = coco_index.loadImgs(int(img_id))[0]
        return float(info['width']), float(info['height'])
    # No COCO index reachable from the loader: fall back to the network input.
    return float(images.shape[-1]), float(images.shape[-2])


def _decode_grid_boxes(cell_preds, grid_h, grid_w, img_w, img_h):
    """Convert per-cell ``(dx, dy, w, h)`` rows into ``(x1, y1, x2, y2)`` pixel boxes.

    ``cell_preds`` is ``(grid_h * grid_w, 4)`` in row-major cell order.
    """
    cols = torch.arange(grid_w, device=cell_preds.device, dtype=cell_preds.dtype).repeat(grid_h)
    rows = torch.arange(grid_h, device=cell_preds.device, dtype=cell_preds.dtype).repeat_interleave(grid_w)

    cell_w = img_w / grid_w  # pixels covered by one grid cell
    cell_h = img_h / grid_h

    center_x = (cols + cell_preds[:, 0]) * cell_w
    center_y = (rows + cell_preds[:, 1]) * cell_h
    # A regression head can emit negative sizes; treat those as empty boxes.
    half_w = cell_preds[:, 2].clamp(min=0) * cell_w / 2
    half_h = cell_preds[:, 3].clamp(min=0) * cell_h / 2

    return torch.stack(
        [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h],
        dim=1,
    )


def validate(model, dataloader, criterion, device):
    """Compute the validation loss and collect COCO-format detections.

    Per batch, the per-scale losses are summed exactly as in training. Then,
    for every scale and image, each grid cell yields one candidate detection
    (best sigmoid score over classes). Candidates are decoded to absolute
    boxes, background cells are dropped, and class-agnostic NMS (IoU 0.5) is
    applied per scale.

    Assumptions (NOTE(review): confirm against the model / loss code):
        * ``cls_logits`` is channels-last ``(B, H, W, C)`` and ``bbox_preds``
          reshapes to ``(B, H*W, 4)``.
        * ``bbox_preds`` uses the same encoding as the training targets that
          ``criterion`` compares it with: ``(dx, dy, w, h)`` in grid cells,
          with ``dx, dy`` the centre offset inside the predicting cell.
        * Class index 0 means "no object" (targets are zero-initialised).

    Args:
        model: Returns ``(cls_logits_list, bbox_preds_list)``, one entry per scale.
        dataloader: Yields ``(img_ids, images, labels_per_scale, boxes_per_scale)``.
            When ``dataloader.dataset.coco`` exists, it supplies the original
            image sizes; otherwise boxes are expressed in input-tensor pixels.
        criterion: Called as ``criterion(cls_logits, bbox_preds, labels, boxes)``.
        device: Device the images and targets are moved to.

    Returns:
        ``(avg_val_loss, results_list)`` where ``results_list`` holds dicts with
        ``image_id``, ``category_id``, ``bbox`` (``[x, y, width, height]``) and
        ``score``, as expected by ``COCO.loadRes``.
    """
    model.eval()
    running_loss = 0.0
    results_list = []
    coco_index = getattr(getattr(dataloader, "dataset", None), "coco", None)

    with torch.no_grad():
        for img_ids, images, labels_per_scale, boxes_per_scale in tqdm(dataloader):
            images = images.to(device)
            cls_logits_list, bbox_preds_list = model(images)
            loss = 0

            # Calculate loss for each scale
            for scale_idx, (cls_logits, bbox_preds) in enumerate(zip(cls_logits_list, bbox_preds_list)):
                labels = labels_per_scale[scale_idx].to(device)
                boxes = boxes_per_scale[scale_idx].to(device)
                loss += criterion(cls_logits, bbox_preds, labels, boxes)

            running_loss += loss.item()

            image_sizes = [_lookup_image_size(coco_index, img_id, images) for img_id in img_ids]

            # Decode, filter and suppress detections scale by scale
            for cls_logits, bbox_preds in zip(cls_logits_list, bbox_preds_list):
                B, H, W, C = cls_logits.shape
                cls_probs = torch.sigmoid(cls_logits).reshape(B, H * W, C)
                cell_preds = bbox_preds.reshape(B, H * W, 4)

                for img_idx in range(B):
                    scores, classes = cls_probs[img_idx].max(dim=1)
                    img_w, img_h = image_sizes[img_idx]
                    boxes_xyxy = _decode_grid_boxes(cell_preds[img_idx], H, W, img_w, img_h)

                    # Cells whose best class is background are not detections.
                    foreground = classes > 0
                    if not bool(foreground.any()):
                        continue
                    scores = scores[foreground]
                    classes = classes[foreground]
                    boxes_xyxy = boxes_xyxy[foreground]

                    # nms expects (x1, y1, x2, y2) boxes
                    keep = nms(boxes_xyxy, scores, iou_threshold=0.5)
                    kept_xyxy = boxes_xyxy[keep]

                    # COCO results use [x, y, width, height]
                    kept_xywh = torch.cat(
                        [kept_xyxy[:, :2], kept_xyxy[:, 2:] - kept_xyxy[:, :2]], dim=1
                    ).tolist()
                    kept_scores = scores[keep].tolist()
                    kept_classes = classes[keep].tolist()

                    image_id = int(img_ids[img_idx])
                    results_list.extend(
                        {
                            "image_id": image_id,
                            "category_id": int(category),
                            "bbox": [float(coord) for coord in box],
                            "score": float(score),
                        }
                        for box, score, category in zip(kept_xywh, kept_scores, kept_classes)
                    )

        gc.collect()
        torch.cuda.empty_cache()
        avg_val_loss = running_loss / len(dataloader)
        return avg_val_loss, results_list

In [None]:
def coco_eval(results_list, val_annotation_file):
    """Print the COCO bounding-box metrics for a list of detections.

    Args:
        results_list: Detections in COCO results format. When empty, a notice
            is printed and nothing is evaluated.
        val_annotation_file: Path to the ground-truth COCO annotation JSON.

    Returns:
        None. The summary is written to standard output by ``COCOeval``.
    """
    if not results_list:
        print("No valid results to evaluate.")
        return

    ground_truth = COCO(val_annotation_file)          # ground-truth annotations
    detections = ground_truth.loadRes(results_list)   # predictions indexed against them

    # Standard three-step COCO protocol on boxes.
    evaluator = COCOeval(ground_truth, detections, iouType='bbox')
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()


In [None]:
# Run the full training schedule and report the mean batch loss after each epoch.
num_epochs = 20
for epoch in range(num_epochs):
    avg_loss_train = train(
        model=model,
        dataloader=train_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {avg_loss_train:.4f}")

In [None]:
# One validation pass with the final weights: mean loss plus the raw detections
# that feed the COCO evaluation.
avg_val_loss, results_list = validate(
    model=model,
    dataloader=val_dataloader,
    criterion=criterion,
    device=device,
)
print(f"Validation Loss after Training: {avg_val_loss:.4f}")

In [None]:
# Score the collected detections against the validation ground truth.
coco_eval(results_list=results_list, val_annotation_file=val_annotation_file)