In [56]:
import json
import os
import random
import shutil
from typing import List, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


In [57]:
class LargeRocksDataset:
    """Convert the large-rocks JSON annotations and RGB/DSM/hillshade patches into a YOLOv8 folder layout.

    Layout written under ``output_path``::

        images/<split>/<file_name>       RGB patch (copied)
        dsm/<split>/<file_name>          DSM patch (copied)
        hillshade/<split>/<file_name>    Hillshade patch (copied)
        labels/<split>/<stem>.txt        one "class_id x_center y_center width height" line per rock
    """

    # (width, height) used when an annotation has no "bbox_size"; same units as the tile width/height.
    DEFAULT_BBOX_SIZE = (30, 30)

    def __init__(self, image_folder: str, json_dataset: str, output_path: str, dsm_folder: str, hillshade_folder: str):
        """
        Initialize the dataloader with additional modalities.

        Args:
            image_folder (str): Path to folder containing RGB images.
            json_dataset (str): Path to JSON dataset file.
            output_path (str): Path to save YOLOv8 formatted dataset.
            dsm_folder (str): Path to folder containing DSM images.
            hillshade_folder (str): Path to folder containing Hillshade images.
        """
        self.image_folder = image_folder
        self.dsm_folder = dsm_folder
        self.hillshade_folder = hillshade_folder
        self.label_file = json_dataset
        self.output_path = output_path

        # Output roots; each gets one sub-directory per split.
        self.splits = ["train", "val", "test"]
        self.image_dir = os.path.join(output_path, "images")
        self.dsm_dir = os.path.join(output_path, "dsm")
        self.hillshade_dir = os.path.join(output_path, "hillshade")
        self.label_dir = os.path.join(output_path, "labels")

        # Pre-create the standard splits; any other split named in the JSON is created on demand.
        for split in self.splits:
            self._split_dirs(split)

    def _split_dirs(self, split: str) -> Tuple[str, str, str, str]:
        """Return the (image, dsm, hillshade, label) output directories for ``split``, creating them if needed."""
        dirs = tuple(
            os.path.join(root, split)
            for root in (self.image_dir, self.dsm_dir, self.hillshade_dir, self.label_dir)
        )
        for directory in dirs:
            os.makedirs(directory, exist_ok=True)
        return dirs

    def _convert_bbox(self, rel_loc: Tuple[float, float], bbox_size: Tuple[int, int], img_size: Tuple[int, int]) -> List[float]:
        """
        Convert bounding box info to YOLO format: [class_id, x_center, y_center, width, height]. (Only one class for rocks)

        Args:
            rel_loc: Box centre (x, y), passed through unchanged -- assumed to be already
                normalized to the patch (TODO confirm against the JSON producer).
            bbox_size: Box (width, height), in the same units as ``img_size``.
            img_size: Patch (width, height).

        Returns:
            [0, x_center, y_center, width / img_width, height / img_height].
        """
        x_center, y_center = rel_loc
        width = bbox_size[0] / img_size[0]
        height = bbox_size[1] / img_size[1]
        return [0, x_center, y_center, width, height]  # class_id = 0 for rocks

    def _label_lines(self, tile: dict) -> List[str]:
        """Build the YOLO label lines (one per rock annotation) for a single tile entry."""
        img_size = (tile['width'], tile['height'])
        lines = []
        for annotation in tile.get('rocks_annotations', []):
            bbox_size = annotation.get('bbox_size', self.DEFAULT_BBOX_SIZE)
            yolo_bbox = self._convert_bbox(annotation['relative_within_patch_location'], bbox_size, img_size)
            lines.append(" ".join(map(str, yolo_bbox)))
        return lines

    def process_dataset(self):
        """
        Process the dataset + convert it to YOLOv8 format with train/val/test splits.

        Tiles whose RGB, DSM or Hillshade file is missing are reported and skipped. A tile
        without a ``split`` (or with a null/empty one) goes to "train"; any other split name
        gets its own output directories.
        """
        with open(self.label_file, 'r') as f:
            data = json.load(f)

        for tile in data['dataset']:
            file_name = tile['file_name']
            # Insertion order fixes the order in which missing files are reported.
            sources = {
                "Image": os.path.join(self.image_folder, file_name),
                "DSM": os.path.join(self.dsm_folder, file_name),
                "Hillshade": os.path.join(self.hillshade_folder, file_name),
            }
            missing = next(((kind, path) for kind, path in sources.items() if not os.path.exists(path)), None)
            if missing is not None:
                print(f"{missing[0]} {missing[1]} not found. Skipping.")
                continue

            # `or` also covers an explicit null/empty split, which os.path.join cannot use.
            split = tile.get('split') or "train"
            image_dir, dsm_dir, hillshade_dir, label_dir = self._split_dirs(split)

            shutil.copy(sources["Image"], os.path.join(image_dir, file_name))
            shutil.copy(sources["DSM"], os.path.join(dsm_dir, file_name))
            shutil.copy(sources["Hillshade"], os.path.join(hillshade_dir, file_name))

            # An empty label file is written for tiles without rocks (background sample).
            label_path = os.path.join(label_dir, f"{os.path.splitext(file_name)[0]}.txt")
            with open(label_path, 'w') as lf:
                lf.write("\n".join(self._label_lines(tile)))

        print(f"Dataset processed and saved to {self.output_path}")


In [58]:
# Inputs: the RGB / DSM / hillshade patch folders, the annotation JSON, and the YOLO output root
rgb_folder = "swissImage_50cm_patches"
dsm_folder = "swissSURFACE3D_patches"
hillshade_folder = "swissSURFACE3D_hillshade_patches"
label_file = "large_rock_dataset.json"
output_path = "YOLO"

# Copy every patch into output_path/{images,dsm,hillshade}/<split> and write labels/<split>/*.txt
rocks_dataset = LargeRocksDataset(
    image_folder=rgb_folder,
    json_dataset=label_file,
    output_path=output_path,
    dsm_folder=dsm_folder,
    hillshade_folder=hillshade_folder,
)
rocks_dataset.process_dataset()


Dataset processed and saved to YOLO


In [59]:
class YOLODataset(Dataset):
    """Multi-modal detection dataset pairing RGB, DSM and hillshade patches with YOLO-format labels.

    Each item is ``(image, labels)``:
      * ``image`` is a float32 tensor of shape ``(5, H, W)``: RGB channels, then DSM, then hillshade.
      * ``labels`` is a float32 tensor of shape ``(num_boxes, 5)`` with rows
        ``[class_id, x_center, y_center, width, height]``; ``(0, 5)`` when the image has no boxes.
    """

    def __init__(self, rgb_dir, dsm_dir, hillshade_dir, label_dir, geom_transform=None, rgb_transform=None, all_transform=None):
        """
        Args:
            rgb_dir (str): Directory containing RGB images.
            dsm_dir (str): Directory containing DSM images (same file names as ``rgb_dir``).
            hillshade_dir (str): Directory containing Hillshade images (same file names as ``rgb_dir``).
            label_dir (str): Directory containing YOLO-style labels named ``<image stem>.txt``.
            geom_transform (callable): Geometric transforms applied to all modalities (e.g., flipping, rotation).
                Called as ``geom_transform(images, labels)`` where ``images`` is the dict
                ``{"rgb": ..., "dsm": ..., "hillshade": ...}``; must return the same pair.
            rgb_transform (callable): Pixel-value transforms applied only to RGB images (e.g., brightness, contrast).
            all_transform (callable): Transform applied to each modality separately (e.g., resizing).
                It may return a PIL image / 8-bit array (scaled to [0, 1] here) or a channel-first
                tensor (used as-is, e.g. the output of ``transforms.ToTensor``).
        """
        self.rgb_dir = rgb_dir
        self.dsm_dir = dsm_dir
        self.hillshade_dir = hillshade_dir
        self.label_dir = label_dir
        self.geom_transform = geom_transform
        self.rgb_transform = rgb_transform
        self.all_transform = all_transform
        # Only regular, non-hidden files are samples (skips e.g. ".DS_Store" or sub-directories).
        self.image_files = sorted(
            name for name in os.listdir(rgb_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(rgb_dir, name))
        )
        # Kept for backward compatibility. Labels are looked up by image stem in __getitem__, not by
        # position, so a stray or missing label file cannot shift every later image onto the wrong labels.
        self.label_files = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _open_image(path, mode):
        """Load the image at ``path`` converted to ``mode``, closing the file handle right away."""
        with Image.open(path) as img:
            return img.convert(mode)

    @staticmethod
    def _load_labels(label_path):
        """Read a YOLO label file into a float32 tensor of shape ``(num_boxes, 5)``.

        A missing file, an empty file, or one containing only blank lines yields a ``(0, 5)`` tensor,
        so every sample has the same label rank.
        """
        if not os.path.exists(label_path):
            return torch.zeros((0, 5), dtype=torch.float32)
        with open(label_path, 'r') as f:
            rows = [[float(value) for value in line.split()] for line in f if line.strip()]
        if not rows:
            return torch.zeros((0, 5), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    @staticmethod
    def _to_chw(img):
        """Convert one modality to a float32 channel-first ``(C, H, W)`` tensor.

        Tensors are assumed to be channel-first and already scaled (as produced by
        ``transforms.ToTensor``), so they are only cast. PIL images and arrays are assumed to be
        8-bit ``(H, W)`` or ``(H, W, C)`` and are scaled to [0, 1].
        """
        if isinstance(img, torch.Tensor):
            tensor = img.to(torch.float32)
            return tensor.unsqueeze(0) if tensor.ndim == 2 else tensor
        array = np.asarray(img, dtype=np.float32) / 255.0
        if array.ndim == 2:
            array = array[..., np.newaxis]
        return torch.from_numpy(np.ascontiguousarray(array)).permute(2, 0, 1)

    def __getitem__(self, index):
        file_name = self.image_files[index]
        rgb = self._open_image(os.path.join(self.rgb_dir, file_name), "RGB")
        dsm = self._open_image(os.path.join(self.dsm_dir, file_name), "L")
        hillshade = self._open_image(os.path.join(self.hillshade_dir, file_name), "L")

        # Label files share the image's stem (LargeRocksDataset writes "<stem>.txt").
        label_path = os.path.join(self.label_dir, f"{os.path.splitext(file_name)[0]}.txt")
        labels = self._load_labels(label_path)

        # Geometric transforms see all modalities at once so they can stay aligned with each other.
        if self.geom_transform:
            combined_image = {"rgb": rgb, "dsm": dsm, "hillshade": hillshade}
            combined_image, labels = self.geom_transform(combined_image, labels)
            rgb, dsm, hillshade = combined_image["rgb"], combined_image["dsm"], combined_image["hillshade"]

        # Apply RGB-specific pixel-value transforms
        if self.rgb_transform:
            rgb = self.rgb_transform(rgb)

        # Apply general transforms to each modality (e.g., resizing)
        if self.all_transform:
            rgb = self.all_transform(rgb)
            dsm = self.all_transform(dsm)
            hillshade = self.all_transform(hillshade)

        combined = torch.cat([self._to_chw(rgb), self._to_chw(dsm), self._to_chw(hillshade)], dim=0)  # (5, H, W)
        return combined, labels


In [60]:
class RandomHorizontalFlipWithBBox:
    """Randomly mirror an image (or a dict of aligned modality images) left-right, keeping YOLO boxes in sync."""

    def __init__(self, flip_prob=0.5):
        """
        Initialize the transform with a probability of flipping the image and bounding boxes.
        Args:
            flip_prob (float): Probability of applying the horizontal flip.
        """
        self.flip_prob = flip_prob

    def __call__(self, image, labels):
        """
        Apply the transformation.
        Args:
            image (PIL.Image or dict[str, PIL.Image]): The input image, or a mapping of modality
                name -> image (e.g. {"rgb": ..., "dsm": ..., "hillshade": ...}). For a mapping,
                every image is flipped under the same random draw so the modalities stay aligned.
            labels (torch.Tensor): The bounding box labels in YOLO format (class, x_center, y_center, width, height).
                The tensor passed in is not modified.
        Returns:
            image: Transformed image (same container type as the input).
            labels (torch.Tensor): Adjusted bounding box labels.
        """
        if random.random() < self.flip_prob:
            if isinstance(image, dict):
                image = {key: img.transpose(Image.FLIP_LEFT_RIGHT) for key, img in image.items()}
            else:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)

            if len(labels) > 0:
                # Copy first so the caller's tensor is left untouched.
                labels = labels.clone()
                labels[:, 1] = 1 - labels[:, 1]  # Mirror x_center (assumes normalized coordinates)

        return image, labels
    

class RandomVerticalFlipWithBBox:
    """Randomly mirror an image (or a dict of aligned modality images) top-bottom, keeping YOLO boxes in sync."""

    def __init__(self, flip_prob=0.5):
        """
        Initialize the transformation with a probability of flipping vertically.
        Args:
            flip_prob (float): Probability of applying the vertical flip.
        """
        self.flip_prob = flip_prob

    def __call__(self, image, labels):
        """
        Apply the transformation.
        Args:
            image (PIL.Image or dict[str, PIL.Image]): The input image, or a mapping of modality
                name -> image (e.g. {"rgb": ..., "dsm": ..., "hillshade": ...}). For a mapping,
                every image is flipped under the same random draw so the modalities stay aligned.
            labels (torch.Tensor): Bounding box labels in YOLO format
                                   (class, x_center, y_center, width, height).
                                   The tensor passed in is not modified.
        Returns:
            image: Transformed image (same container type as the input).
            labels (torch.Tensor): Adjusted labels after flipping.
        """
        if random.random() < self.flip_prob:
            if isinstance(image, dict):
                image = {key: img.transpose(Image.FLIP_TOP_BOTTOM) for key, img in image.items()}
            else:
                image = image.transpose(Image.FLIP_TOP_BOTTOM)

            if len(labels) > 0:
                # Copy first so the caller's tensor is left untouched.
                labels = labels.clone()
                labels[:, 2] = 1 - labels[:, 2]  # Mirror y_center (assumes normalized coordinates)

        return image, labels

class ComposeCustomTransforms:
    """Chain ``(image, labels) -> (image, labels)`` transforms, applying them in list order."""

    def __init__(self, transforms):
        """
        Args:
            transforms (list): Ordered callables, each taking and returning an ``(image, labels)`` pair.
        """
        self.transforms = transforms

    def __call__(self, image, labels):
        """
        Run the pair through every stored transform in turn.
        Args:
            image: Input image (or whatever the individual transforms accept).
            labels (torch.Tensor): YOLO-style labels.
        Returns:
            tuple: The ``(image, labels)`` produced by the last transform (the inputs if the list is empty).
        """
        current_image, current_labels = image, labels
        for step in self.transforms:
            current_image, current_labels = step(current_image, current_labels)
        return current_image, current_labels

def custom_collate_fn(batch):
    """
    Collate ``(image, labels)`` samples whose label tensors have different numbers of boxes.

    Args:
        batch (list): Samples as produced by the dataset; element 0 is the image, element 1 the labels.
    Returns:
        images (torch.Tensor): Images stacked along a new leading batch dimension.
        labels (list): The per-sample label tensors, in batch order, left unstacked.
    """
    image_list, label_list = [], []
    for sample in batch:
        image_list.append(sample[0])
        label_list.append(sample[1])
    return torch.stack(image_list), label_list

In [61]:
# Training configuration
IMG_SIZE = 416
BATCH_SIZE = 4

# Geometric augmentations (image + boxes). YOLODataset calls these with the dict
# {"rgb": ..., "dsm": ..., "hillshade": ...} so one random draw can cover every modality.
custom_geom_transform = ComposeCustomTransforms([
    RandomHorizontalFlipWithBBox(flip_prob=0.5),
    RandomVerticalFlipWithBBox(flip_prob=0.5),
])

# Photometric augmentation, applied to the RGB image only.
rgb_only_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, contrast=0.5),
])

# Applied to each modality separately (3-channel RGB and 1-channel DSM / hillshade), so it must
# not contain channel-count-specific ops. YOLODataset already converts each modality to a float
# tensor scaled to [0, 1] and stacks them into 5 channels. The previous ToTensor +
# Normalize(mean=[0.5, 0.5, 0.5], ...) failed on the 1-channel images. If inputs in [-1, 1] are
# wanted, apply (x - 0.5) / 0.5 to the batch after loading.
all_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
])


# Create dataset
train_dataset = YOLODataset(
    rgb_dir="YOLO/images/train",
    dsm_dir="YOLO/dsm/train",
    hillshade_dir="YOLO/hillshade/train",
    label_dir="YOLO/labels/train",
    geom_transform=custom_geom_transform,
    rgb_transform=rgb_only_transform,
    all_transform=all_transform
)


# Create data loader; custom_collate_fn keeps per-image label tensors as a list (box counts differ).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate_fn
)