![Astute Notebooks](../images/notebooks-banner.png)

# How to Train a Drone Detector

---

# Install and Import Required Libraries

This notebook sets up a deep learning workflow for training a YOLO12 model on a drone dataset. It first installs and imports the required libraries (PyTorch, Torchvision, NumPy, Matplotlib, and OpenCV) and checks whether a GPU is available, falling back to the CPU otherwise. The dataset is then loaded and preprocessed with resizing, tensor conversion, and normalization, split into training and validation sets, and wrapped in data loaders.

The model, defined in the code as the `YOLO12` class, consists of convolutional layers, ReLU activations, max-pooling layers, and fully connected layers for classification. It is trained with the Adam optimizer and `CrossEntropyLoss` over multiple epochs, reporting the loss after each epoch. After training, the model is evaluated on the validation set and its weights are saved to a specified path for future use.

In [1]:
# Install and Import Required Libraries
!pip install torch torchvision numpy matplotlib opencv-python

# Import libraries
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Collecting torch
  Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting numpy
  Downloading numpy-2.3.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting opencv-python
  Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.

Matplotlib is building the font cache; this may take a moment.


Using device: cpu


# Load and preprocess drone dataset

This section prepares the dataset for training the YOLO12 model. It defines the dataset path and a transformation pipeline that resizes each image to 416x416 (the YOLO input size used here), converts it to a tensor, and normalizes each channel by a fixed mean and standard deviation. The dataset is loaded with `torchvision.datasets.ImageFolder`, which assigns each image a class label based on the name of the folder that contains it.

The dataset is then split into training and validation sets with `torch.utils.data.random_split`, using 80% for training and 20% for validation. Finally, `torch.utils.data.DataLoader` batches both subsets and shuffles the training data at each epoch.
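To make the normalization step concrete, the sketch below works out the value range each channel ends up in. It assumes the inputs have already been scaled to [0, 1] by `ToTensor`, and it uses the mean and standard deviation that this notebook passes to `transforms.Normalize` (the commonly used ImageNet statistics). The helper name `normalized_range` is illustrative only.

```python
# Per-channel statistics passed to transforms.Normalize in this notebook
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalized_range(mean, std):
    """Return (low, high) after normalization for inputs in [0, 1]."""
    low = (0.0 - mean) / std
    high = (1.0 - mean) / std
    return low, high


for name, m, s in zip("RGB", MEAN, STD):
    low, high = normalized_range(m, s)
    print(f"{name}: [{low:.3f}, {high:.3f}]")
```

Each channel ends up roughly centered on zero with a spread of a few units, which is the range the rest of the pipeline sees.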

In [None]:
# Load and preprocess drone dataset
import os
from torchvision import transforms, datasets

# Define dataset path
dataset_path = '/path/to/drone/dataset'  # Replace with the actual path to your dataset

# Define transformations
transform = transforms.Compose([
    transforms.Resize((416, 416)),  # Resize images to YOLO input size
    transforms.ToTensor(),         # Convert images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
print(f"Loaded {len(dataset)} images from {dataset_path}")

# Split dataset into training and validation sets
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
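# Seed the split so the same images land in each subset on every run (42 is an arbitrary choice)
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)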
print(f"Training set: {len(train_dataset)} images, Validation set: {len(val_dataset)} images")

# Create data loaders
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
print("Data loaders created.")
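`ImageFolder` expects one subdirectory per class under the dataset root and assigns label indices to the class names in sorted order. Before loading, it can help to confirm that the layout is what you expect. The following is a minimal sketch using only the standard library; the helper names `count_images_per_class` and `class_to_index` and the extension list are assumptions for illustration, not part of torchvision.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(root):
    """Map each class folder under `root` to its image count, in sorted order."""
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


def class_to_index(counts):
    """Label indices follow the sorted class names, as ImageFolder assigns them."""
    return {name: idx for idx, name in enumerate(sorted(counts))}
```

Calling `count_images_per_class(dataset_path)` on a root with, say, hypothetical `fixed_wing/` and `quadcopter/` folders returns the image count for each, which also makes class imbalance visible before training starts.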

# Define YOLO12 model architecture

This section defines the `YOLO12` network used for drone classification. Three convolutional layers with increasing numbers of filters (32, 64, 128) extract features from the input image. Each is followed by a ReLU activation, which introduces non-linearity, and a max-pooling layer, which halves the spatial dimensions.

Two fully connected layers perform the classification: the first reduces the flattened feature map to 256 neurons, and the second outputs scores for two classes (fixed-wing and quadcopter drones). The `forward` method defines the order in which data passes through these layers. Finally, the model is instantiated and moved to the selected device (GPU or CPU) for training.
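The input width of the first fully connected layer has to match the flattened output of the last pooling layer, so it depends on the image size. As a rough sketch, the arithmetic can be checked in plain Python; the helper names `conv_output_size` and `flattened_features` are made up for this illustration, and the layer settings mirror the ones used in this notebook (3x3 convolutions with padding 1, 2x2 max pooling with stride 2).

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(input_size, channels=(32, 64, 128)):
    """Flattened feature count after three conv (3x3, pad 1) + max-pool (2x2) stages."""
    size = input_size
    for _ in channels:
        # A 3x3 convolution with padding 1 and stride 1 keeps the spatial size
        size = conv_output_size(size, kernel_size=3, stride=1, padding=1)
        # A 2x2 max pool with stride 2 halves it
        size = conv_output_size(size, kernel_size=2, stride=2, padding=0)
    return channels[-1] * size * size


print(flattened_features(416))  # 128 * 52 * 52
```

If the resize target changes (for example to 224x224), recomputing this value gives the new input width for the first linear layer.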



In [None]:
# Define YOLO12 model architecture
from torch import nn

class YOLO12(nn.Module):
    def __init__(self):
        super().__init__()
        # Define convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        
        # Define activation and pooling layers
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Fully connected layers for classification
        self.fc1 = nn.Linear(128 * 52 * 52, 256)  # 416 / 2**3 = 52 after three 2x2 poolings; adjust if the input size changes
        self.fc2 = nn.Linear(256, 2)  # Output layer for 2 classes (fixed-wing and quadcopter)

    def forward(self, x):
        # Forward pass through convolutional layers
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        
        x = self.conv3(x)
        x = self.relu(x)
        x = self.pool(x)
        
        # Flatten the tensor for fully connected layers
        x = x.view(x.size(0), -1)
        
        # Forward pass through fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        
        return x

# Initialize model
model = YOLO12().to(device)
print(model)

# Train YOLO12 Model

In [None]:
# Train YOLO12 Model
epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss so the value is comparable across dataset sizes
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")

# Evaluate Model Performance

In [None]:
# Evaluate Model Performance
model.eval()  # Switch off training-specific behavior
correct = 0
total = 0
val_loss = 0.0
loss_function = nn.CrossEntropyLoss()

with torch.no_grad():  # No gradients are needed for evaluation
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        val_loss += loss_function(outputs, labels).item()

        # The predicted class is the index of the largest logit
        predictions = outputs.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

print(f"Validation loss: {val_loss / len(val_loader):.4f}, Accuracy: {correct / total:.2%}")
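
Overall accuracy can hide a weak class when the dataset is imbalanced. As a hedged sketch, the helper below computes per-class accuracy from plain lists of predicted and true label indices (for example, collected batch by batch from `val_loader`). The function name `per_class_accuracy` and the sample labels are made up for illustration.

```python
from collections import Counter


def per_class_accuracy(predictions, labels):
    """Return {class_index: accuracy} for every class present in `labels`."""
    totals = Counter(labels)
    hits = Counter(t for p, t in zip(predictions, labels) if p == t)
    return {cls: hits[cls] / n for cls, n in sorted(totals.items())}


# Made-up example: class 0 is recognized more reliably than class 1
predictions = [0, 0, 1, 0, 1, 0]
labels = [0, 0, 1, 1, 1, 0]
print(per_class_accuracy(predictions, labels))
```

In the notebook, the two lists could be filled by extending them with `predictions.tolist()` and `labels.tolist()` inside the validation loop.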

# Save Trained Model

In [None]:
# Save Trained Model
model_path = '/path/to/save/yolo12_model.pth'  # Replace with the desired save path
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")
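
To reuse the saved weights later, the state dict is loaded into a freshly constructed model of the same architecture. The sketch below demonstrates the round trip with a tiny stand-in `nn.Linear` model and a temporary file, so it runs without the dataset; in this notebook you would construct `YOLO12()` and pass `model_path` instead. It assumes a PyTorch release recent enough to support the `weights_only` argument of `torch.load`.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in model for illustration; in the notebook this would be YOLO12()
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it
    restored = nn.Linear(4, 2)
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state_dict)

restored.eval()  # Switch to inference mode before making predictions
weights_match = torch.equal(model.weight, restored.weight)
print(f"Weights restored: {weights_match}")
```

Passing `map_location` lets weights saved on a GPU machine load on a CPU-only one, which matters here because the notebook's recorded run used the CPU.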