In [3]:
INPUT_DIR = "dataset"
OUTPUT_DIR = "processed_dataset"
TARGET_SIZE = (640, 360)  # Width, Height
DRONE_BOX_SIZE = 20       # Bounding box size in original resolution
FRAME_NAME_PREFIX = "frame"

In [4]:
os.makedirs(f"{OUTPUT_DIR}/images", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/labels", exist_ok=True)


In [5]:
def process_video_pair(video_path, txt_path, global_frame_counter):
    # Read coordinate data from txt file
    coord_data = {}
    with open(txt_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            frame_num = int(float(parts[0]))  # Handle 1.000000 format
            x, y = map(float, parts[1:3])
            coord_data[frame_num] = (x, y)

    # Process video
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Calculate scaling factors
    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    scale_x = TARGET_SIZE[0] / orig_width
    scale_y = TARGET_SIZE[1] / orig_height

    with tqdm(total=total_frames, desc=os.path.basename(video_path)) as pbar:
        video_frame_counter = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            video_frame_counter += 1  # Video frames are 1-indexed

            # Resize frame
            resized_frame = cv2.resize(frame, TARGET_SIZE)

            # Get coordinates for this frame (0,0 means no drone)
            x, y = coord_data.get(video_frame_counter, (0, 0))

            # Save frame with global numbering
            frame_filename = f"{FRAME_NAME_PREFIX}_{global_frame_counter:06d}.jpg"
            cv2.imwrite(f"{OUTPUT_DIR}/images/{frame_filename}", resized_frame)

            # Create label file only if coordinates are not (0,0)
            if x != 0 or y != 0:
                # Scale coordinates to target size
                x_scaled = x * scale_x
                y_scaled = y * scale_y
                box_w = DRONE_BOX_SIZE * scale_x
                box_h = DRONE_BOX_SIZE * scale_y

                # Convert to YOLO format (normalized)
                x_center = x_scaled / TARGET_SIZE[0]
                y_center = y_scaled / TARGET_SIZE[1]
                w = box_w / TARGET_SIZE[0]
                h = box_h / TARGET_SIZE[1]

                # Write label file
                label_filename = f"{FRAME_NAME_PREFIX}_{global_frame_counter:06d}.txt"
                with open(f"{OUTPUT_DIR}/labels/{label_filename}", "w") as f:
                    f.write(f"0 {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n")

            global_frame_counter += 1
            pbar.update(1)

    cap.release()
    return global_frame_counter




In [6]:
# Process all video files
video_files = [f for f in os.listdir(INPUT_DIR) if f.endswith(".mp4")]
global_counter = 0

In [None]:
import re
for video_file in video_files:
    cam_match = re.search(r'cam(\d+)', video_file)
    if not cam_match:
        print(f"Skipping {video_file} (couldn't extract cam number)")
        continue

    cam_number = cam_match.group(1)
    txt_file = f"cam{cam_number}.txt"

    video_path = os.path.join(INPUT_DIR, video_file)
    txt_path = os.path.join(INPUT_DIR, txt_file)

    if not os.path.exists(txt_path):
        print(f"Skipping {video_file} (missing {txt_file})")
        continue

    print(f"\nProcessing {video_file} with {txt_file}")
    global_counter = process_video_pair(video_path, txt_path, global_counter)

print("\nPreprocessing completed successfully!")
print(f"Total frames processed: {global_counter}")


Processing cam1.mp4 with cam1.txt


cam1.mp4: 100%|██████████| 4941/4941 [00:18<00:00, 271.65it/s]



Processing cam3.mp4 with cam3.txt


cam3.mp4: 100%|██████████| 4080/4080 [00:14<00:00, 279.04it/s]



Processing cam2.mp4 with cam2.txt


cam2.mp4: 100%|██████████| 8016/8016 [00:30<00:00, 264.04it/s]



Processing cam6.mp4 with cam6.txt


cam6.mp4:  26%|██▌       | 4419/17166 [00:54<02:25, 87.79it/s]

In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2

# Configuration
DATA_DIR = "processed_dataset"
NUM_SAMPLES = 10
PADDING = 10  # Pixels to add around each bounding box
IMG_WIDTH = 20  # Inches (full notebook width)
IMG_HEIGHT_PER_SAMPLE = 5  # Inches per image

def plot_samples():
    # Get all image files
    image_files = [f for f in os.listdir(f"{DATA_DIR}/images") if f.endswith(".jpg")]
    selected_files = random.sample(image_files, min(NUM_SAMPLES, len(image_files)))
    
    # Create one figure per sample with full width
    fig, axes = plt.subplots(
        NUM_SAMPLES, 1,
        figsize=(IMG_WIDTH, IMG_HEIGHT_PER_SAMPLE * NUM_SAMPLES)
    )
    
    for i, img_file in enumerate(selected_files):
        ax = axes[i] if NUM_SAMPLES > 1 else axes
        
        # Read image
        img_path = os.path.join(DATA_DIR, "images", img_file)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Display image
        ax.imshow(img)
        ax.axis('off')
        
        # Set title
        frame_num = img_file.split('_')[1].split('.')[0]
        ax.set_title(f"Frame {frame_num}", pad=10)
        
        # Check for corresponding label
        label_path = os.path.join(DATA_DIR, "labels", img_file.replace(".jpg", ".txt"))
        has_drone = os.path.exists(label_path)
        
        # Plot bounding box if drone exists
        if has_drone:
            with open(label_path, 'r') as f:
                line = f.readline().strip()
                cls, x_center, y_center, w, h = map(float, line.split())
                
                # Convert from YOLO to pixel coordinates
                img_h, img_w = img.shape[:2]
                x = (x_center - w/2) * img_w
                y = (y_center - h/2) * img_h
                width = w * img_w
                height = h * img_h
                
                # Create padded rectangle
                rect = plt.Rectangle(
                    (x-PADDING, y-PADDING),
                    width + 2*PADDING,
                    height + 2*PADDING,
                    linewidth=2,
                    edgecolor='r',
                    facecolor='none',
                    linestyle='-'
                )
                ax.add_patch(rect)
                
                # Add center marker
                # ax.scatter(
                #     [x_center * img_w], [y_center * img_h],
                #     color='lime', s=40, marker='+'
                # )
        else:
            ax.text(
                img.shape[1]//2, img.shape[0]//2,
                "NO DRONE", color='white',
                ha='center', va='center',
                bbox=dict(facecolor='red', alpha=0.5))
    
    plt.tight_layout()
    plt.show()

plot_samples()

FileNotFoundError: [Errno 2] No such file or directory: 'processed_dataset/images'

In [None]:
import os

# Configuration
DATA_DIR = "processed_dataset/labels"  # Path to your label files

# Initialize counters
drone_frames = 0
no_drone_frames = 0

# Count all label files (each represents a drone frame)
label_files = [f for f in os.listdir(DATA_DIR) if f.endswith('.txt')]
drone_frames = len(label_files)

# Count total frames (image files)
image_files = [f for f in os.listdir(DATA_DIR.replace('labels', 'images')) if f.endswith('.jpg')]
total_frames = len(image_files)
no_drone_frames = total_frames - drone_frames

# Print results
print(f"Drone frames: {drone_frames} ({drone_frames/total_frames:.1%})")
print(f"No-drone frames: {no_drone_frames} ({no_drone_frames/total_frames:.1%})")

Drone frames: 61284 (69.6%)
No-drone frames: 26754 (30.4%)


In [None]:
import os
import cv2
import torch
import numpy as np
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from sklearn.model_selection import train_test_split
from collections import defaultdict

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.11). Upgrade using: pip install --upgrade albumentations


In [None]:
# Device configuration for Apple Silicon
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

Using device: mps


In [None]:
DATA_DIR = "processed_dataset"
BATCH_SIZE = 8  # Reduced for MPS memory constraints
NUM_EPOCHS = 50
IMG_SIZE = (640, 360)
NUM_WORKERS = 0 if DEVICE.type == 'mps' else 4

In [None]:
no_drone_aug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.RandomBrightnessContrast(p=0.3),
    A.Rotate(limit=30, p=0.4),
    A.GaussianBlur(blur_limit=(3, 7), p=0.2),
    A.CLAHE(p=0.3),
    A.ChannelShuffle(p=0.1)
])

In [None]:
class DroneDataset(Dataset):
    def __init__(self, image_paths, label_paths, augment=False):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.augment = augment
        self.class_weights = [0.7, 1.3]
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        
        # Load image
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Load targets
        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    cls, x_center, y_center, width, height = map(float, line.strip().split())
                    
                    # Convert to xyxy format
                    x_min = (x_center - width/2) * IMG_SIZE[0]
                    y_min = (y_center - height/2) * IMG_SIZE[1]
                    x_max = x_min + width * IMG_SIZE[0]
                    y_max = y_min + height * IMG_SIZE[1]
                    
                    boxes.append([x_min, y_min, x_max, y_max])
                    labels.append(int(cls))
        
        # Apply augmentations for no-drone images
        if self.augment and len(boxes) == 0:
            aug = no_drone_aug(image=image)
            image = aug['image']

        # Convert to tensors with proper dimensionality
        image = F.to_tensor(image)
        
        # Handle empty boxes case
        if len(boxes) == 0:
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)
        else:
            boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32)
        
        target = {
            'boxes': boxes_tensor,
            'labels': torch.as_tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([idx]),
            'area': (boxes_tensor[:, 3] - boxes_tensor[:, 1]) * 
                    (boxes_tensor[:, 2] - boxes_tensor[:, 0]),
            'iscrowd': torch.zeros(len(boxes), dtype=torch.int64)
        }
        
        return image, target

In [None]:
#%% [markdown]
### Step 3: Prepare Dataset with PROPER Label Handling
# Get all image/label pairs with explicit pairing
# Get all image/label pairs with explicit pairing
image_label_pairs = [
    (
        os.path.join(DATA_DIR, 'images', f), 
        os.path.join(DATA_DIR, 'labels', f.replace('.jpg', '.txt'))
    ) for f in os.listdir(os.path.join(DATA_DIR, 'images'))
]

# Split dataset with paired items
train_pairs, val_pairs = train_test_split(
    image_label_pairs, 
    test_size=0.2, 
    stratify=[int(os.path.exists(lbl)) for img, lbl in image_label_pairs]
)

# Separate drone/no-drone PAIRS
train_drone_pairs = [(img, lbl) for img, lbl in train_pairs if os.path.exists(lbl)]
train_no_drone_pairs = [(img, lbl) for img, lbl in train_pairs if not os.path.exists(lbl)]

# Oversample no-drone PAIRS (images + labels together)
num_to_generate = max(0, len(train_drone_pairs) - len(train_no_drone_pairs))

if num_to_generate > 0 and len(train_no_drone_pairs) > 0:
    new_indices = np.random.choice(range(len(train_no_drone_pairs)), num_to_generate, replace=True)
    train_no_drone_pairs += [train_no_drone_pairs[i] for i in new_indices]

# Create balanced dataset with CORRECT PAIRS
balanced_pairs = train_drone_pairs + train_no_drone_pairs

# Separate into images and labels lists
balanced_train_images = [img for img, lbl in balanced_pairs]
balanced_train_labels = [lbl for img, lbl in balanced_pairs]

# Create datasets
train_dataset = DroneDataset(balanced_train_images, balanced_train_labels, augment=True)
val_dataset = DroneDataset([img for img, lbl in val_pairs], [lbl for img, lbl in val_pairs])

In [None]:
# Check pairing consistency
for img, lbl in zip(balanced_train_images, balanced_train_labels):
    assert os.path.exists(lbl) == os.path.exists(lbl), "Mismatched pair!"

In [None]:
def visualize_samples(dataset, num_samples=10):
    # Select random indices
    indices = random.sample(range(len(dataset)), num_samples)
    
    # Create one row per sample with full width
    fig, axes = plt.subplots(num_samples, 1, figsize=(20, 5*num_samples))
    
    for i, idx in enumerate(indices):
        image, target = dataset[idx]
        ax = axes[i] if num_samples > 1 else axes
        
        # Convert tensor to numpy array
        img = image.permute(1, 2, 0).numpy()
        
        # Display image
        ax.imshow(img)
        ax.axis('off')
        
        # Get filename
        img_path = dataset.image_paths[idx]
        frame_num = os.path.basename(img_path).split('_')[1].split('.')[0]
        ax.set_title(f"Frame {frame_num}", pad=10)
        
        # Draw bounding boxes if present
        if len(target['boxes']) > 0:
            for box in target['boxes']:
                xmin, ymin, xmax, ymax = box.numpy()
                
                # Create rectangle with padding
                padding = 5
                rect = plt.Rectangle(
                    (xmin-padding, ymin-padding),
                    (xmax-xmin)+2*padding,
                    (ymax-ymin)+2*padding,
                    linewidth=2,
                    edgecolor='r',
                    facecolor='none'
                )
                ax.add_patch(rect)
                
                # Mark center
                # ax.scatter(
                #     [(xmin + xmax)/2], [(ymin + ymax)/2],
                #     color='lime', s=40, marker='+'
                # )
        else:
            ax.text(
                img.shape[1]//2, img.shape[0]//2,
                "NO DRONE", color='white',
                ha='center', va='center',
                bbox=dict(facecolor='red', alpha=0.5))
    
    plt.tight_layout()
    plt.show()

# Visualize samples (will now scroll vertically)
visualize_samples(train_dataset, num_samples=10)


In [None]:
import os
import cv2
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F  # Add this import at the top
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from sklearn.model_selection import train_test_split
from collections import defaultdict

In [None]:
DATA_DIR = "processed_dataset"
BATCH_SIZE = 8
NUM_EPOCHS = 50
IMG_SIZE = (640, 360)
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [None]:
### Step 1: Dynamic Bounding Box Handling

class DroneDataset(Dataset):
    def __init__(self, image_paths, label_paths, augment=False):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.augment = augment
        
        # Calculate dynamic box statistics
        self.box_sizes = self._calculate_box_stats()
        self.mean_size = np.mean(self.box_sizes) if self.box_sizes else 30
        
    def _calculate_box_stats(self):
        sizes = []
        for lbl_path in self.label_paths:
            if os.path.exists(lbl_path):
                with open(lbl_path, 'r') as f:
                    line = f.readline().strip()
                    if line:
                        _, _, _, w, h = map(float, line.split())
                        sizes.append((w*IMG_SIZE[0], h*IMG_SIZE[1]))
        return sizes
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        
        # Load image
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    cls, x_center, y_center, w, h = map(float, line.strip().split())
                    
                    # Dynamic box scaling based on dataset statistics
                    scale_factor = 1 + random.uniform(-0.3, 0.3)  # ±30% variation
                    w *= scale_factor
                    h *= scale_factor
                    
                    # Convert to xyxy format
                    x_min = (x_center - w/2) * IMG_SIZE[0]
                    y_min = (y_center - h/2) * IMG_SIZE[1]
                    x_max = x_min + w * IMG_SIZE[0]
                    y_max = y_min + h * IMG_SIZE[1]
                    
                    boxes.append([x_min, y_min, x_max, y_max])
                    labels.append(int(cls))
        
        # Convert to tensors
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([idx]),
            'area': (torch.tensor(boxes)[:, 3] - torch.tensor(boxes)[:, 1]) * 
                    (torch.tensor(boxes)[:, 2] - torch.tensor(boxes)[:, 0]) if boxes else torch.zeros(0),
            'iscrowd': torch.zeros(len(boxes), dtype=torch.int64) if boxes else torch.zeros(0)
        }
        
        return img, target

In [None]:
# Create image-label pairs
image_label_pairs = []
for img_file in os.listdir(f"{DATA_DIR}/images"):
    if img_file.endswith(".jpg"):
        img_path = os.path.join(DATA_DIR, "images", img_file)
        label_path = os.path.join(DATA_DIR, "labels", img_file.replace(".jpg", ".txt"))
        image_label_pairs.append((img_path, label_path))

# Split dataset
train_pairs, val_pairs = train_test_split(
    image_label_pairs,
    test_size=0.2,
    stratify=[int(os.path.exists(pair[1])) for pair in image_label_pairs],
    random_state=42
)

In [None]:
train_drone = [pair for pair in train_pairs if os.path.exists(pair[1])]
train_no_drone = [pair for pair in train_pairs if not os.path.exists(pair[1])]

In [None]:
# Dynamic oversampling based on imbalance ratio
oversample_ratio = max(1, len(train_drone) // max(len(train_no_drone), 1))
train_no_drone = train_no_drone * oversample_ratio
random.shuffle(train_no_drone)

balanced_train = train_drone + train_no_drone[:len(train_drone)]

In [None]:
# Create datasets
train_dataset = DroneDataset(
    [pair[0] for pair in balanced_train],
    [pair[1] for pair in balanced_train],
    augment=True
)
val_dataset = DroneDataset(
    [pair[0] for pair in val_pairs],
    [pair[1] for pair in val_pairs]
)

In [None]:
def check_balance(dataset, name="Dataset"):
    drone_count = 0
    no_drone_count = 0
    
    for img_path, lbl_path in zip(dataset.image_paths, dataset.label_paths):
        if os.path.exists(lbl_path):
            drone_count += 1
        else:
            no_drone_count += 1
    
    total = drone_count + no_drone_count
    print(f"\n{name} Balance:")
    print(f"Total samples: {total}")
    print(f"Drone samples: {drone_count} ({drone_count/total:.1%})")
    print(f"No drone samples: {no_drone_count} ({no_drone_count/total:.1%})")
    print("Status: ", "✅ Balanced" if abs(drone_count - no_drone_count)/total < 0.1 else "❌ Imbalanced")

# Check training set balance
check_balance(train_dataset, "Training Set")

# Check validation set balance
check_balance(val_dataset, "Validation Set")


Training Set Balance:
Total samples: 91833
Drone samples: 49027 (53.4%)
No drone samples: 42806 (46.6%)
Status:  ✅ Balanced

Validation Set Balance:
Total samples: 17608
Drone samples: 12257 (69.6%)
No drone samples: 5351 (30.4%)
Status:  ❌ Imbalanced


In [None]:
def custom_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # Convert labels list to tensor
    labels = torch.cat(labels, dim=0) if isinstance(labels, list) else labels
    
    # Classification loss
    classification_loss = F.cross_entropy(class_logits, labels)

    # Box regression loss with MPS workaround
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    
    if labels_pos.numel() == 0:
        box_loss = torch.tensor(0.0, device=class_logits.device)
    else:
        # Ensure regression targets are concatenated
        regression_targets = torch.cat(regression_targets, dim=0)
        box_regression = box_regression.reshape(-1, box_regression.size(-1) // 4, 4)
        
        box_loss = F.smooth_l1_loss(
            box_regression[sampled_pos_inds_subset, labels_pos],
            regression_targets[sampled_pos_inds_subset],
            beta=1/9,
            reduction="sum"
        ) / max(1, labels.numel())

    return classification_loss, box_loss

In [None]:
from torchvision.models.detection.roi_heads import RoIHeads

class CustomRoIHeads(RoIHeads):
    def forward(self, features, proposals, image_shapes, targets=None):
        # Original forward pass implementation
        if targets is not None:
            for t in targets:
                floating_point_types = (torch.float, torch.double, torch.half)
                assert t["boxes"].dtype in floating_point_types
                assert t["labels"].dtype == torch.int64
        
        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

        # Box head forward pass
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        # Calculate losses
        if self.training:
            loss_classifier, loss_box_reg = custom_fastrcnn_loss(
                class_logits, 
                box_regression, 
                labels, 
                regression_targets
            )
            losses = {
                "loss_classifier": loss_classifier,
                "loss_box_reg": loss_box_reg
            }
            return proposals, losses
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            return boxes, scores, labels

In [None]:
// ... existing code ...
def create_model():
    # Check for cached weights
    weights_path = os.path.expanduser('~/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth')
    if os.path.exists(weights_path):
        print("Using cached ResNet50 weights")
        weights = torch.load(weights_path, map_location=DEVICE)
    else:
        print("Downloading ResNet50 weights...")
        weights = None

    # Anchor configuration
    anchor_sizes = ((8,), (16,), (32,), (64,), (128,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    print("Creating model architecture...")
    # Create base model
    model = fasterrcnn_resnet50_fpn(
        weights=weights,
        min_size=IMG_SIZE[1],
        max_size=IMG_SIZE[0],
        rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
        box_roi_pool=MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3', 'pool'],
            output_size=7,
            sampling_ratio=2
        )
    )

    print("Configuring custom ROI heads...")
    # Configure custom ROI heads with explicit parameters
    model.roi_heads = CustomRoIHeads(
        box_roi_pool=model.roi_heads.box_roi_pool,
        box_head=model.roi_heads.box_head,
        box_predictor=model.roi_heads.box_predictor,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.5,
        batch_size_per_image=512,
        positive_fraction=0.25,
        bbox_reg_weights=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=100
    )

    print("Initializing model weights...")
    # Initialize weights
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)
    
    print("Moving model to device...")
    model = model.to(DEVICE)
    print("Model initialization complete!")
    
    return model


In [None]:
// ... existing code ...
def train_model():
    print("Starting model initialization...")
    start_time = time.time()
    
    model = create_model()
    print(f"Model initialization took {time.time() - start_time:.2f} seconds")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0005)
    
    print("Preparing data loaders...")
    def collate_fn(batch):
        return tuple(zip(*batch))

    # Create weighted sampler
    weights = [2.0 if os.path.exists(lbl) else 1.0 for (img, lbl) in balanced_train]
    sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=0  # Required for MPS
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=0
    )
    
    print(f"Starting training for {NUM_EPOCHS} epochs...")
    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        model.train()
        total_loss = 0
        
        for batch_idx, (images, targets) in enumerate(train_loader):
            images = [img.to(DEVICE) for img in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            
            optimizer.zero_grad()
            losses.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            
            total_loss += losses.item()
            
            if batch_idx % 10 == 0:  # Print progress every 10 batches
                print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Batch {batch_idx}/{len(train_loader)} | Loss: {losses.item():.4f}")

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, targets in val_loader:
                images = [img.to(DEVICE) for img in images]
                targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
                try:
                    loss_dict = model(images, targets)
                    val_loss += sum(loss for loss in loss_dict.values()).item()
                    val_batches += 1
                except ValueError:
                    # Skip this batch if there's a value error due to ROI heads mismatch
                    print("Skipping validation batch due to ROI heads mismatch")
                    continue
        
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss/len(val_loader):.4f} | Time: {epoch_time:.2f}s")
    
    print("Saving model...")
    torch.save(model.state_dict(), 'drone_model.pth')
    print("Training complete!")
    return model
// ... existing code ...

<!--visualize tracking results and logs: -->

In [None]:
#%% [markdown]
### Step 4: Visualization and Analysis of Tracking Results

import json
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from datetime import datetime
import seaborn as sns

def visualize_tracking_results(results_dir="results"):
    # Find all result files
    import glob
    metric_files = glob.glob(f"{results_dir}/*_metrics.json")
    
    if not metric_files:
        print("No results found in the results directory")
        return
    
    # Create visualization directory
    os.makedirs(f"{results_dir}/visualizations", exist_ok=True)
    
    # Process each result file
    for metric_file in metric_files:
        base_name = metric_file.replace("_metrics.json", "")
        video_name = base_name.split("/")[-1].split("_")[0]
        
        print(f"\nVisualizing results for {video_name}")
        
        # Load metrics
        with open(metric_file) as f:
            metrics = json.load(f)
        
        # Load visibility log
        visibility_file = f"{base_name}_visibility.log"
        visibility_data = []
        current_id = None
        with open(visibility_file) as f:
            for line in f:
                if "Drone ID" in line:
                    if "Visible from" in line:
                        parts = line.strip().split()
                        start = datetime.strptime(" ".join(parts[3:5]), "%Y-%m-%d %H:%M:%S.%f")
                        end = datetime.strptime(" ".join(parts[7:9]), "%Y-%m-%d %H:%M:%S.%f") if parts[7] != "present" else None
                        duration = (end - start).total_seconds() if end else None
                        visibility_data.append({
                            'drone_id': current_id,
                            'start': start,
                            'end': end,
                            'duration': duration
                        })
                    else:
                        current_id = int(line.split()[2].replace(":", ""))
        
        # Create DataFrame
        df = pd.DataFrame(visibility_data)
        
        # Convert timestamps to relative seconds
        min_time = df['start'].min()
        df['start_rel'] = df['start'].apply(lambda x: (x - min_time).total_seconds())
        df['end_rel'] = df.apply(lambda row: (row['end'] - min_time).total_seconds() 
                                if row['end'] else (datetime.now() - min_time).total_seconds(), axis=1)
        
        # Plot 1: Visibility Timeline
        plt.figure(figsize=(15, 6))
        for _, row in df.iterrows():
            plt.plot([row['start_rel'], row['end_rel']], [row['drone_id'], row['drone_id']], 
                    linewidth=5, marker="|", markersize=10)
        plt.title(f"Drone Visibility Timeline - {video_name}")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Drone ID")
        plt.yticks(df['drone_id'].unique())
        plt.grid(True)
        plt.savefig(f"{results_dir}/visualizations/{video_name}_timeline.png")
        plt.show()
        
        # Plot 2: Visibility Duration
        plt.figure(figsize=(10, 6))
        duration_df = df.groupby('drone_id')['duration'].sum().reset_index()
        sns.barplot(data=duration_df, x='drone_id', y='duration')
        plt.title(f"Total Visibility Time by Drone - {video_name}")
        plt.xlabel("Drone ID")
        plt.ylabel("Total Visibility (seconds)")
        plt.savefig(f"{results_dir}/visualizations/{video_name}_duration.png")
        plt.show()
        
        # Plot 3: Reappearances
        reappearances = df.groupby('drone_id').size() - 1
        plt.figure(figsize=(10, 6))
        sns.barplot(x=reappearances.index, y=reappearances.values)
        plt.title(f"Number of Reappearances by Drone - {video_name}")
        plt.xlabel("Drone ID")
        plt.ylabel("Number of Reappearances")
        plt.savefig(f"{results_dir}/visualizations/{video_name}_reappearances.png")
        plt.show()
        
        # Display metrics
        print("\nPerformance Metrics:")
        print(f"Processing Time: {metrics['processing_time']:.2f} seconds")
        print(f"Average FPS: {metrics['avg_fps']:.2f}")
        print(f"Redetection Count: {metrics['redetection_count']}")
        print(f"Average Detection Time: {metrics['avg_detection_time']:.4f} seconds")
        print(f"Average Tracking Time: {metrics['avg_tracking_time']:.4f} seconds")
        
        # Save combined report
        report = {
            'video': metrics['video'],
            'processing_time': metrics['processing_time'],
            'avg_fps': metrics['avg_fps'],
            'redetection_count': metrics['redetection_count'],
            'drones_tracked': len(df['drone_id'].unique()),
            'total_visibility_time': duration_df['duration'].sum(),
            'avg_detection_time': metrics['avg_detection_time'],
            'avg_tracking_time': metrics['avg_tracking_time']
        }
        
        with open(f"{results_dir}/visualizations/{video_name}_report.json", 'w') as f:
            json.dump(report, f, indent=4)
        
        print(f"\nVisualizations saved to {results_dir}/visualizations/")

# Run visualization
visualize_tracking_results()

<!-- Enhanced Visualization Cell -->

In [None]:
#%% [markdown]
### Step 4: Enhanced Tracking Visualization with Flight Paths

def visualize_flight_paths(results_dir="results"):
    # Find all tracking videos
    import glob
    video_files = glob.glob(f"{results_dir}/*_tracked.mp4")
    
    if not video_files:
        print("No tracked videos found in the results directory")
        return
    
    for video_file in video_files:
        video_name = os.path.basename(video_file).replace("_tracked.mp4", "")
        print(f"\nVisualizing flight paths for {video_name}")
        
        # Create output directory
        os.makedirs(f"{results_dir}/flight_paths", exist_ok=True)
        
        # Load corresponding metrics file
        metrics_file = f"{results_dir}/{video_name}_metrics.json"
        with open(metrics_file) as f:
            metrics = json.load(f)
        
        # Extract trajectory data
        trajectories = {}
        for drone_id, data in metrics['visibility_stats'].items():
            trajectories[drone_id] = {
                'positions': [],
                'color': plt.cm.tab10(int(drone_id) % 10)
            }
        
        # Read video to get flight paths
        cap = cv2.VideoCapture(video_file)
        frame_count = 0
        
        # We'll sample every nth frame to reduce processing
        sample_rate = 10  
        
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            
            if frame_count % sample_rate == 0:
                # Convert to RGB for visualization
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                # Find all drone markers in the frame
                for drone_id in trajectories.keys():
                    # Look for drone ID text in the frame
                    text = f"ID {drone_id}"
                    result = cv2.matchTemplate(
                        frame_rgb, 
                        np.zeros((30, 100, 3), dtype=np.uint8),  # Dummy template
                        cv2.TM_CCOEFF_NORMED
                    )
                    # (In practice, you'd need a better way to extract positions)
                    # This is simplified - actual implementation would track positions
                    
            frame_count += 1
        
        cap.release()
        
        # Alternative: Use the trajectory data from the tracker
        # (This assumes you saved the trajectory data in metrics)
        if 'trajectory_data' in metrics:
            plt.figure(figsize=(12, 8))
            
            for drone_id, path in metrics['trajectory_data'].items():
                if len(path) > 1:
                    x = [p[0] for p in path]
                    y = [p[1] for p in path]
                    color = plt.cm.tab10(int(drone_id) % 10)
                    
                    # Plot the flight path
                    plt.plot(x, y, '.-', color=color, label=f'Drone {drone_id}', 
                            alpha=0.6, linewidth=2, markersize=8)
                    
                    # Add start and end markers
                    plt.scatter(x[0], y[0], color='green', s=100, marker='o', edgecolors='white')
                    plt.scatter(x[-1], y[-1], color='red', s=100, marker='X', edgecolors='white')
            
            plt.title(f"Drone Flight Paths - {video_name}")
            plt.xlabel("X Position (pixels)")
            plt.ylabel("Y Position (pixels)")
            plt.gca().invert_yaxis()  # Match image coordinate system
            plt.grid(True)
            plt.legend()
            
            # Add frame reference
            if 'frame_size' in metrics:
                w, h = metrics['frame_size']
                plt.gca().add_patch(Rectangle((0, 0), w, h, fill=False, edgecolor='gray', linestyle='--'))
            
            plt.savefig(f"{results_dir}/flight_paths/{video_name}_paths.png", dpi=300, bbox_inches='tight')
            plt.show()
            
            print(f"Flight path visualization saved to {results_dir}/flight_paths/{video_name}_paths.png")
        else:
            print("No trajectory data found in metrics")

# First modify your tracking code to save trajectory data
def add_trajectory_saving():
    # Add this to your DroneDetectionAndTracking class in track_drone.py
    def _save_results(self, video_path, start_time, processed_frames):
        # ... existing code ...
        
        # Add trajectory data to metrics
        metrics['trajectory_data'] = {
            str(k): v for k, v in self.tracker.trajectory.items()
        }
        metrics['frame_size'] = (int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        
        # ... rest of existing code ...

# Then run visualizations
print("Visualizing flight paths...")
visualize_flight_paths()

print("\nVisualizing tracking metrics...")
visualize_tracking_results()

In [None]:
if __name__ == "__main__":
    model = train_model()
    
    # Cleanup
    if DEVICE.type == 'mps':
        torch.mps.empty_cache()

KeyboardInterrupt: 