In [218]:
import os
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import yaml
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as F
from ultralytics import YOLO
from ultralytics.utils import ops
from ultralytics.utils.loss import v8OBBLoss

In [219]:
class YOLODataset(Dataset):
    """Map-style dataset that pairs images with YOLO OBB label files.

    The dataset is described by a YAML file that provides ``path`` (dataset
    root), ``names`` (class names) and one entry per split. A split entry is
    either a directory of images or a ``.txt`` file listing image paths.
    Label files are looked up by swapping ``images`` for ``labels`` in the
    directory part of the image path and using a ``.txt`` extension.

    Args:
        data_yaml: Path to the dataset YAML file.
        transforms: Optional callable taking and returning a dict with the keys
            ``image``, ``bboxes`` and ``labels``.
        mode: Which split to load: ``'train'``, ``'val'`` or ``'test'``.

    Each item is ``(image, labels)`` where ``labels`` has one row per object:
    the class id followed by the coordinates read from the label file
    (``(0, 9)`` when the image has no objects).
    """

    # Image extensions picked up when a split is given as a directory.
    _IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp')
    # Split name -> YAML keys that may define it, in lookup order.
    _SPLIT_KEYS = {'train': ('train',), 'val': ('validation', 'val'), 'test': ('test',)}

    def __init__(self, data_yaml, transforms=None, mode='train'):
        super().__init__()
        self.data_yaml = data_yaml
        self.transforms = transforms
        self.mode = mode

        # Load the YAML and resolve the image/label file lists for this split.
        self.data_dict, self.img_files, self.label_files = self._load_and_prepare_data(data_yaml)
        self.class_names = self.data_dict['names']
        self.num_classes = len(self.class_names)

    @staticmethod
    def _label_path_for(image_path):
        """Return the label path for ``image_path``.

        Only the directory part has ``images`` replaced by ``labels``; the file
        name keeps its stem and gets a ``.txt`` suffix. This avoids corrupting
        names that contain ``images`` or repeat their own extension.
        """
        image_path = Path(image_path)
        label_parent = Path(str(image_path.parent).replace('images', 'labels'))
        return label_parent / image_path.with_suffix('.txt').name

    def _load_and_prepare_data(self, data_yaml_path):
        """Load the YAML, validate it and build the image and label file lists.

        Returns:
            ``(data_dict, img_files, label_files)``; both lists hold strings and
            are index-aligned.

        Raises:
            ValueError: A required key is missing, ``mode`` is unknown, or the
                split entry is neither a directory nor a ``.txt`` list.
            KeyError: ``mode`` is ``'test'`` but the YAML has no ``test`` entry.
            FileNotFoundError: The image or label directory does not exist.
        """
        with open(data_yaml_path, 'r', errors='ignore') as f:
            data_dict = yaml.safe_load(f)

        # Ensure 'path' is an absolute path and a Path object.
        root = Path(os.path.abspath(data_dict['path']))
        data_dict['path'] = root

        # Required keys. The validation split may be spelled 'validation' or 'val'.
        for key in ('train', 'validation', 'names'):
            if key == 'validation' and 'val' in data_dict:
                continue
            if key not in data_dict:
                raise ValueError(f"'{key}' is missing from {data_yaml_path}")

        if self.mode not in self._SPLIT_KEYS:
            raise ValueError("mode must be 'train', 'val', or 'test'")
        split_key = next((k for k in self._SPLIT_KEYS[self.mode] if k in data_dict), None)
        if split_key is None:
            # Only reachable for 'test'; KeyError keeps the original exception type.
            raise KeyError(f"'{self.mode}' is missing from {data_yaml_path}")
        image_set = data_dict[split_key]

        if isinstance(image_set, str) and image_set.endswith('.txt'):  # File list
            with open(root / image_set, 'r') as f:
                listed = [line.strip() for line in f if line.strip()]
            img_files = [str(root / rel) for rel in listed]
            # Derive the label path from the listed (relative) path so that the
            # dataset root itself is never rewritten.
            label_files = [str(root / self._label_path_for(rel)) for rel in listed]

        elif isinstance(image_set, str):  # Directory
            img_dir = root / image_set
            label_dir = root / image_set.replace('images', 'labels')

            if not os.path.exists(img_dir):
                raise FileNotFoundError(f"Image directory does not exist: {img_dir}")
            if not os.path.exists(label_dir):
                raise FileNotFoundError(f"Label directory does not exist: {label_dir}")

            img_files = sorted(
                str(f) for f in img_dir.glob('*') if f.suffix.lower() in self._IMAGE_SUFFIXES
            )
            label_files = [str(label_dir / Path(f).with_suffix('.txt').name) for f in img_files]

        else:
            raise ValueError(
                "'train', 'validation' and 'test' must be a path to directory, or a path to .txt file"
            )

        return data_dict, img_files, label_files

    @staticmethod
    def _read_labels(label_path):
        """Parse one label file into a tensor with one ``[class_id, *coords]`` row per object.

        A missing file or a file with no objects yields a ``(0, 9)`` tensor.
        Blank lines are skipped.
        """
        rows = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for raw_line in f:
                    fields = raw_line.split()
                    if not fields:
                        continue
                    # Always treat as OBB: class id followed by the coordinates.
                    rows.append([int(fields[0])] + [float(c) for c in fields[1:]])
        if rows:
            return torch.tensor(rows)
        return torch.zeros((0, 9))  # Always 9 for OBB (cls, x1, y1, ..., x4, y4)

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the sample at ``idx``."""
        # convert() returns a loaded copy, so the file handle can be closed here.
        with Image.open(self.img_files[idx]) as raw_image:
            image = raw_image.convert("RGB")

        labels = self._read_labels(self.label_files[idx])

        if self.transforms:
            # Transforms work on a dict so image and boxes can be changed together.
            sample = {'image': image, 'bboxes': labels[:, 1:], 'labels': labels[:, 0]}
            sample = self.transforms(sample)
            image = sample['image']
            if len(sample['bboxes']) > 0:
                labels = torch.cat((sample['labels'].unsqueeze(1), sample['bboxes']), dim=1)
            else:
                labels = torch.zeros((0, 9))  # Keep consistent OBB format
        else:
            image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0

        return image, labels


In [220]:
class ResizeAndPad:
    """Letterbox transform: aspect-preserving resize plus centred padding.

    The image is scaled to fit inside ``target_size`` and padded on all sides to
    reach exactly that size. Box coordinates are remapped to match.

    Args:
        target_size: ``(width, height)`` of the output image.

    The sample dict must hold ``image`` (an object exposing ``.size`` as
    ``(width, height)``, e.g. a PIL image), ``bboxes`` and ``labels``.
    ``bboxes`` is read as interleaved ``x, y`` columns normalised to the input
    image. The result is normalised to the padded output image.
    """

    def __init__(self, target_size):
        self.target_size = target_size  # (width, height)

    def __call__(self, sample):
        """Return a new sample dict with the letterboxed image and remapped boxes.

        Raises:
            ValueError: ``bboxes`` is non-empty and has an odd number of columns.
        """
        image, bboxes, labels = sample['image'], sample['bboxes'], sample['labels']
        target_w, target_h = self.target_size[0], self.target_size[1]

        # Resize while keeping the aspect ratio.
        w, h = image.size
        scale = min(target_w / w, target_h / h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        image = F.resize(image, (new_h, new_w))

        # Split the leftover space evenly; the odd pixel goes right/bottom.
        pad_w = target_w - new_w
        pad_h = target_h - new_h
        pad_left = pad_w // 2
        pad_top = pad_h // 2
        image = F.pad(image, (pad_left, pad_top, pad_w - pad_left, pad_h - pad_top))

        if len(bboxes) > 0:
            if bboxes.shape[1] % 2:
                raise ValueError(
                    f"bboxes must hold (x, y) pairs, got {bboxes.shape[1]} columns"
                )
            # Work on a copy so the caller's tensor is left untouched.
            bboxes = bboxes.clone()
            # Use the actual resized size (after int truncation) so boxes stay
            # aligned with the pixels that were really produced.
            bboxes[:, 0::2] = (bboxes[:, 0::2] * new_w + pad_left) / target_w
            bboxes[:, 1::2] = (bboxes[:, 1::2] * new_h + pad_top) / target_h

        return {'image': image, 'bboxes': bboxes, 'labels': labels}

In [221]:
class ComposeTransforms:
    """Chain transforms over a sample dict with ``image``, ``bboxes`` and ``labels``.

    Dispatch per transform type:

    * ``RandomHorizontalFlip`` / ``RandomVerticalFlip``: the flip decision is
      drawn here using the transform's probability ``p`` so that the image and
      the boxes are mirrored together.
    * ``ToTensor`` / ``Normalize``: applied to ``sample['image']`` only.
    * anything else (e.g. ``ResizeAndPad``): called with the whole sample dict.

    Args:
        transforms: Iterable of transform callables, applied in order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    @staticmethod
    def _random_flip(sample, p, flip_fn, axis):
        """Flip image and boxes together with probability ``p``.

        ``axis`` is 0 to mirror x coordinates and 1 to mirror y coordinates.
        Assumes ``bboxes`` holds interleaved ``x, y`` columns normalised to
        [0, 1], which is what ``ResizeAndPad`` produces.
        """
        if torch.rand(1).item() >= p:
            return sample
        sample['image'] = flip_fn(sample['image'])
        bboxes = sample.get('bboxes')
        if bboxes is not None and len(bboxes) > 0:
            mirrored = bboxes.clone()
            mirrored[:, axis::2] = 1.0 - mirrored[:, axis::2]
            sample['bboxes'] = mirrored
        return sample

    def __call__(self, sample):
        """Apply every transform in order and return the resulting sample dict."""
        for t in self.transforms:
            if isinstance(t, transforms.RandomHorizontalFlip):
                # An image-only flip would leave the boxes pointing at the wrong pixels.
                sample = self._random_flip(sample, t.p, F.hflip, 0)
            elif isinstance(t, transforms.RandomVerticalFlip):
                sample = self._random_flip(sample, t.p, F.vflip, 1)
            elif isinstance(t, (transforms.ToTensor, transforms.Normalize)):
                # Pixel-value transforms do not move anything; image only.
                sample['image'] = t(sample['image'])
            else:
                # Geometry-aware transforms receive the whole sample.
                sample = t(sample)
        return sample

In [222]:
# Input locations, relative to the notebook's working directory.
data_yaml_path = './datasets/yolo_obb/data.yaml'  # Dataset description read by YOLODataset
model_path = "./yolo11n-obb.pt"  # Checkpoint file passed to torch.load below

In [223]:
# Training and validation use the same preprocessing: letterbox to 1920x1088,
# convert to a tensor, then normalise with the ImageNet mean/std. Each split
# gets its own pipeline instance.
# NOTE(review): Ultralytics models are normally fed plain [0, 1] images; confirm
# that the extra Normalize step is intended for this checkpoint.
train_transforms, val_transforms = (
    ComposeTransforms([
        ResizeAndPad((1920, 1088)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    for _split in range(2)
)

In [224]:
# One dataset per split, both described by the same YAML file.
train_dataset = YOLODataset(
    data_yaml=data_yaml_path,
    transforms=train_transforms,
    mode='train',
)
val_dataset = YOLODataset(
    data_yaml=data_yaml_path,
    transforms=val_transforms,
    mode='val',
)

In [225]:
def _collate_as_tuples(samples):
    """Turn ``[(image, labels), ...]`` into ``(images, labels)`` tuples.

    Nothing is stacked here, so label tensors with different row counts can
    share a batch.
    """
    return tuple(zip(*samples))


train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_as_tuples,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=_collate_as_tuples,
)

In [226]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# SECURITY: the checkpoint stores a pickled model object, so loading it executes
# code from the file. Only load checkpoints from a trusted source.
# weights_only=False is passed explicitly because PyTorch 2.6+ defaults to
# weights_only=True, which refuses to unpickle full model objects.
loaded_state = torch.load(model_path, map_location=device, weights_only=False)

# Some training checkpoints keep the weights under 'ema' and leave 'model' empty.
model = loaded_state.get('model')
if model is None:
    model = loaded_state.get('ema')
if model is None:
    raise KeyError(f"No 'model' or 'ema' entry found in checkpoint {model_path}")
model = model.float().to(device)  # Checkpoints may be stored in half precision

# Released Ultralytics checkpoints are typically saved with gradients disabled;
# re-enable them so the optimizer has something to update. The '.dfl' layer is
# kept frozen, mirroring the Ultralytics trainer — TODO confirm for this version.
for _name, _param in model.named_parameters():
    _param.requires_grad_('.dfl' not in _name and _param.dtype.is_floating_point)

In [227]:
# Start from the hyperparameters stored in the checkpoint, turning the strings
# 'True' / 'False' into real booleans.
hyp = {
    name: (True if value == 'True' else False if value == 'False' else value)
    for name, value in loaded_state.get('train_args', {}).items()
}

# The checkpoint carried no training arguments: fall back to defaults.
# Adjust these as needed!
if not hyp:
    hyp = {
        # optimisation
        'lr0': 0.01,
        'lrf': 0.01,
        'momentum': 0.937,
        'weight_decay': 0.0005,
        'warmup_epochs': 3.0,
        'warmup_momentum': 0.8,
        # loss gains
        'box': 7.5,
        'cls': 0.5,
        'dfl': 1.5,
        # augmentation
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'degrees': 0.0,
        'translate': 0.1,
        'scale': 0.5,
        'shear': 0.0,
        'perspective': 0.0,
        'flipud': 0.0,
        'fliplr': 0.5,
        'mosaic': 1.0,
        'mixup': 0.0,
        'copy_paste': 0.0,
    }

# Class metadata: the count comes from the model's last layer, the names from
# the dataset YAML.
num_classes = model.model[-1].nc
model.names = train_dataset.class_names
model.nc = num_classes

In [228]:
def train_one_epoch(model, optimizer, train_loader, device, criterion):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Network to train; called as ``model(images)``.
        optimizer: Optimizer holding the model's parameters.
        train_loader: Iterable of ``(images, labels)``; ``images`` is a sequence
            of ``(C, H, W)`` tensors of equal size and ``labels`` a sequence of
            per-image tensors whose rows are ``[cls, x1, y1, ..., x4, y4]``
            (normalised corners) or ``[cls, x, y, w, h, r]``.
        device: Device the batch is moved to.
        criterion: Loss callable taking ``(preds, batch)``. It may return a
            tensor or a ``(loss, loss_items)`` tuple, as Ultralytics losses do.

    Returns:
        Mean of the per-batch scalar losses, or ``0.0`` for an empty loader.

    Raises:
        ValueError: A label tensor has neither 8 nor 5 coordinate columns.
    """

    def _to_loss_batch(images, labels):
        """Build the batch dict for the OBB loss: flat batch_idx / cls / xywhr boxes."""
        img_h, img_w = images.shape[2:]
        corner_scale = torch.tensor([img_w, img_h] * 4, dtype=torch.float32, device=device)
        idx_parts, cls_parts, box_parts = [], [], []
        for i, label in enumerate(labels):
            if len(label) == 0:
                continue
            label = label.to(device=device, dtype=torch.float32)
            n_coords = label.shape[1] - 1
            if n_coords == 8:
                # The rotated rectangle is fitted in pixel space so a non-square
                # image does not skew the angle, then xywh is normalised again.
                boxes = ops.xyxyxyxy2xywhr(label[:, 1:] * corner_scale)
                boxes[:, [0, 2]] /= img_w
                boxes[:, [1, 3]] /= img_h
            elif n_coords == 5:
                boxes = label[:, 1:]  # already xywhr
            else:
                raise ValueError(
                    f"Expected 8 corner values or 5 xywhr values per box, got {n_coords}"
                )
            idx_parts.append(torch.full((label.shape[0],), float(i), device=device))
            cls_parts.append(label[:, 0])
            box_parts.append(boxes)

        if box_parts:
            batch_idx, cls, bboxes = torch.cat(idx_parts), torch.cat(cls_parts), torch.cat(box_parts)
        else:
            batch_idx = torch.zeros((0,), device=device)
            cls = torch.zeros((0,), device=device)
            bboxes = torch.zeros((0, 5), device=device)
        return {
            'batch_idx': batch_idx,       # image index of every box
            'cls': cls,                   # class id of every box
            'bboxes': bboxes,             # normalised xywhr
            'imgsz': images.shape[2:],    # (H, W) of the batch
        }

    model.train()
    total_loss = 0.0
    num_batches = 0

    for images, labels in train_loader:
        images = torch.stack([image.to(device) for image in images])
        batch = _to_loss_batch(images, labels)

        preds = model(images)

        loss = criterion(preds, batch)
        if isinstance(loss, (tuple, list)):
            loss = loss[0]  # (loss, detached loss items) -> keep the differentiable part
        loss = loss.sum()   # collapse per-component losses into one scalar

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


In [229]:
def validate(model, val_loader, device, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients and return the mean batch loss.

    Args:
        model: Network to evaluate; switched to eval mode and called as ``model(images)``.
        val_loader: Iterable of ``(images, labels)``; ``images`` is a sequence
            of ``(C, H, W)`` tensors of equal size and ``labels`` a sequence of
            per-image tensors whose rows are ``[cls, x1, y1, ..., x4, y4]``
            (normalised corners) or ``[cls, x, y, w, h, r]``.
        device: Device the batch is moved to.
        criterion: Loss callable taking ``(preds, batch)``. It may return a
            tensor or a ``(loss, loss_items)`` tuple, as Ultralytics losses do.

    Returns:
        Mean of the per-batch scalar losses, or ``0.0`` for an empty loader.

    Raises:
        ValueError: A label tensor has neither 8 nor 5 coordinate columns.
    """

    def _to_loss_batch(images, labels):
        """Build the batch dict for the OBB loss: flat batch_idx / cls / xywhr boxes."""
        img_h, img_w = images.shape[2:]
        corner_scale = torch.tensor([img_w, img_h] * 4, dtype=torch.float32, device=device)
        idx_parts, cls_parts, box_parts = [], [], []
        for i, label in enumerate(labels):
            if len(label) == 0:
                continue
            label = label.to(device=device, dtype=torch.float32)
            n_coords = label.shape[1] - 1
            if n_coords == 8:
                # The rotated rectangle is fitted in pixel space so a non-square
                # image does not skew the angle, then xywh is normalised again.
                boxes = ops.xyxyxyxy2xywhr(label[:, 1:] * corner_scale)
                boxes[:, [0, 2]] /= img_w
                boxes[:, [1, 3]] /= img_h
            elif n_coords == 5:
                boxes = label[:, 1:]  # already xywhr
            else:
                raise ValueError(
                    f"Expected 8 corner values or 5 xywhr values per box, got {n_coords}"
                )
            idx_parts.append(torch.full((label.shape[0],), float(i), device=device))
            cls_parts.append(label[:, 0])
            box_parts.append(boxes)

        if box_parts:
            batch_idx, cls, bboxes = torch.cat(idx_parts), torch.cat(cls_parts), torch.cat(box_parts)
        else:
            batch_idx = torch.zeros((0,), device=device)
            cls = torch.zeros((0,), device=device)
            bboxes = torch.zeros((0, 5), device=device)
        return {
            'batch_idx': batch_idx,       # image index of every box
            'cls': cls,                   # class id of every box
            'bboxes': bboxes,             # normalised xywhr
            'imgsz': images.shape[2:],    # (H, W) of the batch
        }

    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = torch.stack([image.to(device) for image in images])
            batch = _to_loss_batch(images, labels)

            preds = model(images)

            loss = criterion(preds, batch)
            if isinstance(loss, (tuple, list)):
                loss = loss[0]  # (loss, detached loss items) -> keep the first part
            total_loss += loss.sum().item()
            num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

In [230]:
# Adam over every model parameter; the learning rate is taken from hyp['lr0']
# and falls back to 0.01 when that key is absent.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=float(hyp.get('lr0', 0.01)),
)

In [231]:
# v8OBBLoss reads its gains as attributes (model.args.box / .cls / .dfl), but a
# model unpickled straight from a released checkpoint usually carries
# model.args as a plain dict. Rebuild it as a namespace, layering the
# hyperparameters collected in `hyp` on top of whatever the model already has.
_model_args = getattr(model, 'args', None) or {}
if not isinstance(_model_args, dict):
    _model_args = vars(_model_args)
model.args = SimpleNamespace(**{'box': 7.5, 'cls': 0.5, 'dfl': 1.5, **_model_args, **hyp})

# v8OBBLoss itself is already imported in the notebook's import cell.
criterion = v8OBBLoss(model)

In [232]:
# Number of full train + validate passes run by the training loop below.
num_epochs = 1

In [None]:
# The checkpoint directory only needs to be created once, not every epoch.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, optimizer, train_loader, device, criterion)
    val_loss = validate(model, val_loader, device, criterion)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save one checkpoint per epoch (adapt as needed). The whole model object is
    # pickled, so the file can only be reloaded where its classes are importable.
    checkpoint_path = f'{checkpoint_dir}/yolov11_obb_epoch_{epoch+1}.pt'
    torch.save({
        'epoch': epoch + 1,
        'model': model,
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, checkpoint_path)

print("Training finished!")
