# DroneDetect V2 - RF-UAVNet Training

**Reference Paper:** Huynh-The et al. (2022), "RF-UAVNet: High-Performance Convolutional Network for RF-Based Drone Surveillance Systems", *IEEE Access*, Vol. 10, pp. 49143-49154.  
DOI: [10.1109/ACCESS.2022.3173181](https://doi.org/10.1109/ACCESS.2022.3173181)

## Overview

RF-UAVNet is a lightweight 1D CNN architecture designed for RF-based drone surveillance. This implementation adapts the IEEE paper's approach to the DroneDetect V2 dataset.

### Architecture Summary
- **Input**: Raw IQ signals (2 channels: real + imaginary, 10,000 samples)
- **R-Unit**: Initial feature extraction (Conv1d 2→64, k=5, s=5)
- **G-Units**: 4x grouped convolutions (64→64, k=3, s=2, groups=8) with skip connections
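- **Multi-GAP**: Multi-scale global average pooling over five branches (the four G-unit outputs with kernels 1000, 500, 250, 125, plus the final residual sum with kernel 125), giving 5 x 64 = 320 features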
- **Classifier**: Fully connected layer (320 → num_classes)
- **Total parameters**: 9,991 for 7 classes (roughly 13,800x fewer than a standard VGG16 with ~138M)
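The GAP kernel sizes and the 320-wide classifier input follow from the sequence lengths at each stage. The plain-Python sketch below (helper names are hypothetical, no PyTorch needed) mirrors the layer settings of the `RFUAVNet` class defined later: the R-unit, then four G-units whose left-padded conv output must match the max-pooled skip branch so the two can be summed.

```python
def conv1d_out_len(length: int, kernel: int, stride: int,
                   pad_left: int = 0, pad_right: int = 0) -> int:
    """Output length of Conv1d / MaxPool1d with dilation=1."""
    return (length + pad_left + pad_right - kernel) // stride + 1


def rfuavnet_lengths(input_len: int = 10_000, n_g_units: int = 4) -> dict:
    """Trace sequence lengths through RF-UAVNet as implemented in this notebook."""
    r_len = conv1d_out_len(input_len, kernel=5, stride=5)  # R-unit: k=5, s=5
    g_lens = []
    x_len = r_len
    for _ in range(n_g_units):
        # G-unit conv applied to F.pad(x, (1, 0)): one zero on the left
        g_len = conv1d_out_len(x_len, kernel=3, stride=2, pad_left=1)
        # Skip branch: MaxPool1d(kernel_size=2, stride=2)
        skip_len = conv1d_out_len(x_len, kernel=2, stride=2)
        assert g_len == skip_len, "residual sum needs matching lengths"
        g_lens.append(g_len)
        x_len = g_len
    channels = 64
    n_gap_branches = n_g_units + 1  # four G-unit outputs + final residual sum
    return {"r_unit": r_len, "g_units": g_lens, "fc_in": channels * n_gap_branches}


print(rfuavnet_lengths())
# {'r_unit': 2000, 'g_units': [1000, 500, 250, 125], 'fc_in': 320}
```

Each G-unit output length equals its `AvgPool1d` kernel, which is why every branch pools down to a single value per channel. Calling the helper with another input length shows which kernel sizes the pooling layers would need.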

### Dataset Comparison: DroneRF (paper) vs DroneDetect V2 (ours)

| Aspect | DroneRF (paper) | DroneDetect V2 (ours) | Impact |
|--------|-----------------|----------------------|--------|
| **Samples** | ~50,000 | 19,478 | **2.5x fewer samples** → higher overfitting risk |
| **Drones** | 10 classes | 7 classes | Simpler task |
| **Interference** | Unknown | 4 conditions (CLEAN/WIFI/BLUE/BOTH) | More realistic variability |
| **Split method** | Segment-level (with leakage) | File-level stratified | Scientifically valid but harder |

**Key trade-off:** RF-UAVNet has only **9,991 parameters**, roughly 13,800x fewer than the ~138M of a standard VGG16, which makes it well suited to edge deployment, but our **smaller dataset** (about 2.5x less data) limits generalization. The paper's reported 98.53% accuracy was obtained with a segment-level split that suffers from **data leakage** (see RFClassification analysis); our file-level split gives a valid generalization estimate.
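The parameter count can be checked by hand: a `Conv1d` has (in/groups) x out x kernel weights plus one bias per output channel, a `BatchNorm1d` has a learnable scale and shift per channel (running statistics are buffers, not parameters), and the `Linear` layer has 320 x num_classes weights plus biases. The sketch below uses hypothetical helper names; the VGG16 reference is torchvision's standard model with the 1000-class ImageNet head, and a VGG16 with a smaller classification head would have fewer parameters.

```python
def conv1d_params(in_ch: int, out_ch: int, kernel: int, groups: int = 1) -> int:
    return (in_ch // groups) * out_ch * kernel + out_ch  # weights + bias


def batchnorm_params(ch: int) -> int:
    return 2 * ch  # learnable weight (scale) and bias (shift)


def linear_params(in_f: int, out_f: int) -> int:
    return in_f * out_f + out_f


def rfuavnet_param_count(num_classes: int) -> int:
    r_unit = conv1d_params(2, 64, 5) + batchnorm_params(64)
    g_unit = conv1d_params(64, 64, 3, groups=8) + batchnorm_params(64)
    classifier = linear_params(320, num_classes)
    return r_unit + 4 * g_unit + classifier


VGG16_PARAMS = 138_357_544  # torchvision VGG16, 1000-class ImageNet head

n_params = rfuavnet_param_count(num_classes=7)
print(n_params)                           # 9991
print(f"{VGG16_PARAMS / n_params:,.0f}x")  # 13,848x
```

Grouping (`groups=8`) is what keeps each G-unit small: it cuts the conv weights from 64 x 64 x 3 to 8 x 64 x 3.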

### Hyperparameters: Aligned with IEEE Paper

Training configuration follows Huynh-The et al. (2022):
- **Optimizer**: SGD with momentum (0.95) and weight decay (1e-4)
- **Learning rate**: 0.01 (10x the Adam default of 0.001 often used for image CNNs)
- **Batch size**: 512 (as in the paper)
- **Epochs**: 120 (as in the paper; set via `CONFIG['epochs']`)
- **Scheduler**: ReduceLROnPlateau (factor=0.5, patience=5)

The paper's SGD setup (higher learning rate, strong momentum) differs from the Adam configuration (lr=0.001) commonly used for image CNNs; we keep it unchanged for comparability with the reported results.
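To make `factor=0.5, patience=5` concrete, the sketch below feeds a flat validation-loss curve to PyTorch's `ReduceLROnPlateau` with the notebook's SGD settings. With the default threshold and no cooldown, the learning rate is halved once the loss has failed to improve for more than 5 consecutive epochs (here at epoch 7), and again 6 epochs later. The helper `lr_trace` and its dummy parameter are hypothetical, used only for illustration.

```python
import torch
import torch.optim as optim


def lr_trace(val_losses, lr=0.01, momentum=0.95, weight_decay=1e-4,
             factor=0.5, patience=5):
    """Return the learning rate after each epoch for a given val-loss sequence."""
    # A single dummy parameter stands in for the model's weights
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = optim.SGD([param], lr=lr, momentum=momentum,
                          weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=factor, patience=patience)

    lrs = []
    for val_loss in val_losses:
        # One trivial optimizer step per "epoch", then step the scheduler
        # on the validation loss, as train_model does
        optimizer.zero_grad()
        (param ** 2).sum().backward()
        optimizer.step()
        scheduler.step(val_loss)
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs


# A validation loss that never improves after epoch 1
print(lr_trace([1.0] * 13))
```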

### Future Improvements

**1. Transfer Learning** (highest priority)
- Pre-train on DroneRF dataset (50k samples) if accessible
- Fine-tune final layers on DroneDetect V2
- Estimated gain (untested): +10-20% accuracy (mitigates the small-dataset issue)

**2. Data Augmentation**
- Time shifting (circular roll): `np.roll(iq, shift, axis=-1)`
- Noise injection (AWGN): `iq + noise`
- Estimated gain (untested): +5-10% accuracy
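The two augmentations above can be sketched as follows, assuming segments stored as float arrays of shape (2, n_samples) as in `iq_features.npz`. Scaling the AWGN to a target SNR in dB is our own choice for illustration, not something the paper specifies, and the function names are hypothetical.

```python
import numpy as np


def random_time_shift(iq: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Circularly roll a (2, n_samples) IQ segment by a random number of samples."""
    shift = int(rng.integers(0, iq.shape[-1]))
    return np.roll(iq, shift, axis=-1)  # same shift for I and Q


def add_awgn(iq: np.ndarray, snr_db: float, rng: np.random.Generator) -> np.ndarray:
    """Add white Gaussian noise at a target SNR (dB) to a (2, n_samples) IQ array.

    Signal power is the mean of I^2 + Q^2; the noise power is split equally
    between the I and Q channels.
    """
    signal_power = np.mean(np.sum(iq.astype(np.float64) ** 2, axis=0))
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = rng.normal(0.0, np.sqrt(noise_power / 2), size=iq.shape)
    return (iq + noise).astype(iq.dtype)


rng = np.random.default_rng(0)
segment = rng.normal(size=(2, 10_000)).astype(np.float32)
augmented = add_awgn(random_time_shift(segment, rng), snr_db=10.0, rng=rng)
print(augmented.shape, augmented.dtype)  # (2, 10000) float32
```

Augmentation would be applied to training segments only, after the file-level split, so the test set stays untouched.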

**3. Ensemble Methods**
- Combine RF-UAVNet + SVM (PSD) + VGG16 (spectrograms)
- Majority voting or stacking
- Estimated gain (untested): +3-5% accuracy
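Hard majority voting can be sketched in a few lines, assuming each model outputs class indices for the same test samples in the same order (for example, the same file-level split). The function name is hypothetical; stacking would instead train a second-level classifier on the models' outputs.

```python
import numpy as np


def majority_vote(predictions, num_classes):
    """Hard majority vote across models.

    predictions: int array of shape (n_models, n_samples) with class indices,
    one row per model, all rows over the same samples in the same order.
    Ties go to the lowest class index (np.argmax returns the first maximum).
    """
    predictions = np.asarray(predictions)
    n_samples = predictions.shape[1]
    votes = np.zeros((n_samples, num_classes), dtype=int)
    for model_preds in predictions:
        votes[np.arange(n_samples), model_preds] += 1
    return votes.argmax(axis=1)


# Illustrative rows, e.g. RF-UAVNet, SVM (PSD), VGG16 (spectrogram) predictions
example = np.array([[0, 1, 2, 2],
                    [0, 2, 2, 1],
                    [1, 1, 2, 0]])
print(majority_vote(example, num_classes=3))  # [0 1 2 0]
```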

**4. Segment Duration Optimization**
- Test 50ms segments (paper: 10ms→20ms→50ms: 76.9%→83.6%→89.4%)
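- Modify `config.DEFAULT_SEGMENT_MS` from 20 to 50. If this changes the number of IQ samples per segment (currently 10,000), the `AvgPool1d` kernel sizes in `RFUAVNet` (1000/500/250/125) must be recomputed, since they are tied to the input length.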

## Mount Google Drive

In [None]:
from google.colab import drive

drive.mount('/content/drive')

In [None]:
!ls drive/MyDrive/DroneDetect_V2/output/features/iq_features.npz

## Imports

In [None]:
!pip install -U kaleido==0.2.1

In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score, precision_recall_fscore_support
from tqdm import tqdm
import os
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Setup figure saving
import re
from pathlib import Path

NOTEBOOK_NAME = "training_rfuavnet_COLAB"
FIGURES_DIR = Path("figures") / NOTEBOOK_NAME

def save_figure(fig) -> None:
    """Save plotly figure to PNG file using the figure's title as filename."""
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    title = fig.layout.title.text if fig.layout.title.text else "untitled"
    filename = re.sub(r'[^\w\s-]', '', title).strip()
    filename = re.sub(r'[\s-]+', '_', filename)
    filepath = FIGURES_DIR / f"{filename}.png"
    try:
        fig.write_image(str(filepath), width=1200, height=800)
        print(f"Saved: {filepath}")
    except Exception as e:
        print(f"Warning: Could not save figure (kaleido required): {e}")
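`save_figure` derives the PNG name from the figure title. The standalone helper below (hypothetical name) applies the same two regular expressions. It shows that punctuation, including the decimal point in the confusion-matrix title, is dropped, so each run's accuracy ends up in the filename. The accuracy value in the example is made up.

```python
import re


def title_to_filename(title: str) -> str:
    """Same slug rules as save_figure: drop punctuation, join words with '_'."""
    name = re.sub(r'[^\w\s-]', '', title).strip()
    return re.sub(r'[\s-]+', '_', name)


print(title_to_filename("RF-UAVNet Precision per Class"))
# RF_UAVNet_Precision_per_Class
print(title_to_filename("RF-UAVNet Confusion Matrix - Accuracy: 0.8734"))
# RF_UAVNet_Confusion_Matrix_Accuracy_08734
```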

## Configuration

In [None]:
CONFIG = {
    # Paths
    'features_path': 'drive/MyDrive/DroneDetect_V2/output/features/iq_features.npz',
    'models_dir': 'drive/MyDrive/DroneDetect_V2/output/models/',
    'test_data_dir': 'drive/MyDrive/DroneDetect_V2/output/sample/test_data/',

    # Split parameters
    'test_size': 0.2,
    'random_state': 42,

    # Training parameters (aligned with Huynh-The et al. 2022, IEEE Access)
    'batch_size': 512,          # Paper specification
    'epochs': 120,              # Paper specification
    'learning_rate': 0.01,      # Paper: 0.01 (SGD with momentum)
    'momentum': 0.95,           # SGD momentum (paper specification)
    'weight_decay': 1e-4,       # L2 regularization (paper specification)
    'scheduler_factor': 0.5,    # LR reduction factor
    'scheduler_patience': 5,    # Epochs before LR reduction

    # Device
    'device': device
}

print(f"Configuration (aligned with RF-UAVNet IEEE paper): {CONFIG}")

## RF-UAV-Net Model Definition

In [None]:
class RFUAVNet(nn.Module):
    """RF-UAVNet: 1D CNN architecture for RF-based drone classification.

    This model is based on the architecture proposed by Huynh-The et al. (2022) in
    "RF-UAVNet: High-Performance Convolutional Network for RF-Based Drone Surveillance Systems".
    It processes raw IQ signals (Real/Imaginary) through a series of specialized units
    (R-Unit, G-Units) and multi-scale pooling to classify drone signals.

    Attributes:
        conv_r (nn.Conv1d): Initial convolution layer (R-Unit).
        bn_r (nn.BatchNorm1d): Batch normalization for R-Unit.
        elu_r (nn.ELU): ELU activation function.
        g_convs (nn.ModuleList): List of 4 grouped convolutional layers (G-Units).
        g_bns (nn.ModuleList): List of batch normalization layers for G-Units.
        g_elus (nn.ModuleList): List of ELU activations for G-Units.
        pool (nn.MaxPool1d): Max pooling layer used in skip connections.
        gap1000 (nn.AvgPool1d): Global Average Pooling with kernel size 1000.
        gap500 (nn.AvgPool1d): Global Average Pooling with kernel size 500.
        gap250 (nn.AvgPool1d): Global Average Pooling with kernel size 250.
        gap125 (nn.AvgPool1d): Global Average Pooling with kernel size 125.
        fc (nn.Linear): Fully connected output layer.

    Args:
        num_classes (int): The number of output classes.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # R-unit
        self.conv_r = nn.Conv1d(2, 64, kernel_size=5, stride=5)
        self.bn_r = nn.BatchNorm1d(64)
        self.elu_r = nn.ELU()

        # G-units (4x)
        self.g_convs = nn.ModuleList([
            nn.Conv1d(64, 64, kernel_size=3, stride=2, groups=8)
            for _ in range(4)
        ])
        self.g_bns = nn.ModuleList([nn.BatchNorm1d(64) for _ in range(4)])
        self.g_elus = nn.ModuleList([nn.ELU() for _ in range(4)])

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Multi-scale GAP
        self.gap1000 = nn.AvgPool1d(1000)
        self.gap500 = nn.AvgPool1d(500)
        self.gap250 = nn.AvgPool1d(250)
        self.gap125 = nn.AvgPool1d(125)

        # Classifier
        self.fc = nn.Linear(320, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 2, 10000).
                Channels are [Real, Imaginary].

        Returns:
            torch.Tensor: Raw logits of shape (batch_size, num_classes).
        """
        # R-unit
        x = self.elu_r(self.bn_r(self.conv_r(x)))

        # G-units with residual connections
        g_outputs = []
        for i in range(4):
            g_out = self.g_elus[i](self.g_bns[i](self.g_convs[i](F.pad(x, (1, 0)))))
            g_outputs.append(g_out)
            x = g_out + self.pool(x)

        # Multi-scale GAP
        gaps = [
            self.gap1000(g_outputs[0]),
            self.gap500(g_outputs[1]),
            self.gap250(g_outputs[2]),
            self.gap125(g_outputs[3]),
            self.gap125(x)
        ]

        x = torch.cat(gaps, dim=1).flatten(start_dim=1)
        return self.fc(x)

    def reset_weights(self):
        """Resets the model weights using the default initialization."""
        for m in self.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

print("RF-UAV-Net model class defined")

## File-Level Stratified Split Function

In [None]:
def get_stratified_file_split(X, y, file_ids, test_size=0.2, random_state=42):
    """
    Split data at FILE level to prevent data leakage.

    Segments from the same .dat file (~100 segments) will never appear
    in both train and test sets.

    Parameters
    ----------
    X : array-like
        Features (n_samples, ...)
    y : array-like
        Labels for stratification (n_samples,)
    file_ids : array-like
        Source file ID for each sample (n_samples,)
    test_size : float
        Approximate test set proportion (actual may vary due to file grouping)
    random_state : int
        Random seed for reproducibility

    Returns
    -------
    train_idx, test_idx : arrays
        Indices for train/test split
    """
    n_splits = max(2, round(1 / test_size))  # e.g., test_size=0.2 -> 5 splits -> 1 fold ~ 20%

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Take first fold as train/test split
    train_idx, test_idx = next(sgkf.split(X, y, groups=file_ids))

    # Verify no file leakage
    train_files = set(file_ids[train_idx])
    test_files = set(file_ids[test_idx])
    assert len(train_files & test_files) == 0, "Data leakage detected: files in both splits"

    return train_idx, test_idx

print("Stratified file split function defined")

## 1. Load IQ Features

In [None]:
# Pattern 1: np.load returns a lazy NpzFile for .npz archives, but mmap_mode is
# ignored for them: each array is read fully into RAM when it is accessed.
data = np.load(CONFIG['features_path'], mmap_mode='r',
               allow_pickle=True)

X = data['X']  # Shape: (N, 2, 10000), read into memory on this access
y_drone = data['y_drone']
y_interference = data['y_interference']
y_state = data['y_state']
file_ids = data['file_ids']  # For stratified file-level splitting

drone_classes = data['drone_classes']
interference_classes = data['interference_classes']
state_classes = data['state_classes']

print(f"IQ data shape: {X.shape}")
print(f"Drone labels shape: {y_drone.shape}")
print(f"File IDs shape: {file_ids.shape} (unique files: {len(np.unique(file_ids))})")
print(f"Drone classes: {drone_classes}")
print(f"Interference classes: {interference_classes}")
print(f"State classes: {state_classes}")
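Because `mmap_mode` has no effect on `.npz` archives, `X` above occupies RAM in full. If true memory mapping is needed, one option (an assumption, not part of the current pipeline) is to store `X` as a standalone `.npy` file. The sketch below, with a hypothetical helper name and a tiny array, checks which of the two formats actually yields an `np.memmap`.

```python
import os
import tempfile

import numpy as np


def check_mmap_support(tmp_dir: str) -> tuple:
    """Return (npy_is_memmap, npz_is_memmap) for the same small array."""
    arr = np.arange(12, dtype=np.float32).reshape(3, 4)
    npy_path = os.path.join(tmp_dir, "X.npy")
    npz_path = os.path.join(tmp_dir, "features.npz")
    np.save(npy_path, arr)
    np.savez(npz_path, X=arr)

    from_npy = np.load(npy_path, mmap_mode="r")      # memory-mapped view on disk
    with np.load(npz_path, mmap_mode="r") as archive:
        from_npz = archive["X"]                       # fully loaded ndarray
    return isinstance(from_npy, np.memmap), isinstance(from_npz, np.memmap)


with tempfile.TemporaryDirectory() as d:
    print(check_mmap_support(d))  # (True, False)
```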

## 2. Train/Test Split

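We use an approximately 80/20 split with file-level stratification to prevent data leakage; whole files are assigned to one side, so the exact proportion varies. Note that this test set also serves as the validation set during training (its loss drives `ReduceLROnPlateau`), so the reported test metrics may be somewhat optimistic.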

In [None]:
# Split for drone classification using file-level stratification
train_idx, test_idx = get_stratified_file_split(
    X, y_drone, file_ids,
    test_size=CONFIG['test_size'],
    random_state=CONFIG['random_state']
)

# Pattern 2: Index-array (fancy) indexing returns copies, not views
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y_drone[train_idx], y_drone[test_idx]
y_interference_test = y_interference[test_idx]
y_state_test = y_state[test_idx]

# Verify no file leakage
train_files = set(file_ids[train_idx])
test_files = set(file_ids[test_idx])
print(f"Training files: {len(train_files)}")
print(f"Test files: {len(test_files)}")
print(f"File overlap: {len(train_files & test_files)} (should be 0)")

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")

# Save test data for reuse
os.makedirs(CONFIG['test_data_dir'], exist_ok=True)

# Save full test data with interference and state metadata
test_data_path = os.path.join(CONFIG['test_data_dir'], 'rfuavnet_test_data.npz')
np.savez(
    test_data_path,
    X_test=X_test,
    y_test=y_test,
    y_interference_test=y_interference_test,
    y_state_test=y_state_test,
    test_idx=test_idx,
    file_ids_test=file_ids[test_idx],
    drone_classes=drone_classes,
    interference_classes=interference_classes,
    state_classes=state_classes
)
print(f"\nFull test data saved to {test_data_path}")

# Save separated files per Drone and Interference (Hierarchical)
print("\nGenerating separated test files (structure: test_data/{INT}/iq_{INT}_{DRONE}_20.npz)...")

for d_idx, drone_class in enumerate(drone_classes):
    for i_idx, int_class in enumerate(interference_classes):
        # Filter for specific drone and interference
        mask = (y_test == d_idx) & (y_interference_test == i_idx)

        if not np.any(mask):
            continue

        X_sub = X_test[mask]
        y_sub = y_test[mask]
        y_int_sub = y_interference_test[mask]

        # Define components for hierarchy and filename
        data_type = 'iq'
        int_name = str(int_class)
        drone_name = str(drone_class)
        duration = '20' # 20ms fixed duration

        # Create directory structure: output/sample/test_data/{INT}/
        save_dir = os.path.join(CONFIG['test_data_dir'], int_name)
        os.makedirs(save_dir, exist_ok=True)

        # Construct filename: iq_{INT}_{DRONE}_20.npz
        filename = f"{data_type}_{int_name}_{drone_name}_{duration}.npz"
        file_path = os.path.join(save_dir, filename)

        np.savez(
            file_path,
            X=X_sub,
            y=y_sub,
            y_interference=y_int_sub,
            drone_class=drone_class,
            interference_class=int_class
        )
        print(f"  Saved {filename} in {save_dir} ({len(X_sub)} samples)")

# Cleanup: close the .npz archive and delete references to the full array and indices
data.close()
del X, data, train_idx, test_idx
gc.collect()
print("\nMemory cleanup: X, data, indices deleted")

## 3. Prepare PyTorch Datasets

In [None]:
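# Pattern 3: PyTorch conversion with immediate cleanup.
# torch.from_numpy shares memory with the NumPy array and .float() does not copy
# float32 data, so the `del` calls below only drop the NumPy names; RAM is freed
# only when a dtype conversion made a copy.
# Convert train set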
X_train_t = torch.from_numpy(X_train).float()
y_train_t = torch.from_numpy(y_train).long()
del X_train, y_train  # Delete numpy arrays immediately
gc.collect()

# Convert test set
X_test_t = torch.from_numpy(X_test).float()
y_test_t = torch.from_numpy(y_test).long()
del X_test, y_test  # Delete numpy arrays immediately
gc.collect()

print("PyTorch tensors created and NumPy arrays deleted")

# Create datasets
train_dataset = TensorDataset(X_train_t, y_train_t)
test_dataset = TensorDataset(X_test_t, y_test_t)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False
)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

## 4. Training Function

In [None]:
def train_model(model, train_loader, test_loader, config):
    """
    Train RF-UAVNet with IEEE paper hyperparameters.

    Optimizer: SGD with momentum (0.95) and weight decay (1e-4)
    Scheduler: ReduceLROnPlateau (factor=0.5, patience=5)
    """
    model = model.to(config['device'])
    criterion = nn.CrossEntropyLoss()

    # SGD optimizer (as per IEEE paper)
    optimizer = optim.SGD(
        model.parameters(),
        lr=config['learning_rate'],
        momentum=config['momentum'],
        weight_decay=config['weight_decay']
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['scheduler_factor'],
        patience=config['scheduler_patience'],
    )

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}
    best_val_acc = 0.0
    best_epoch = 0

    for epoch in range(config['epochs']):
        # Training
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} - Train"):
            batch_x, batch_y = batch_x.to(config['device']), batch_y.to(config['device'])

            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += batch_y.size(0)
            train_correct += (predicted == batch_y).sum().item()

        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x, batch_y = batch_x.to(config['device']), batch_y.to(config['device'])
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += batch_y.size(0)
                val_correct += (predicted == batch_y).sum().item()

        val_loss /= len(test_loader)
        val_acc = val_correct / val_total

        # Learning rate scheduling
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, LR={current_lr:.6f}")

        # Periodic cleanup
        if (epoch + 1) % 10 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"  Memory cleanup at epoch {epoch+1}")

    print(f"\nBest validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
    return model, history

print("Training function defined (SGD + ReduceLROnPlateau)")
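`train_model` records the best validation accuracy but returns the last-epoch weights. If the best-epoch weights are wanted, a small tracker like the hypothetical one below could be called after each validation pass (`tracker.update(model, val_acc, epoch + 1)`) and `tracker.restore(model)` before returning. Because the validation data here is the test set, selecting weights this way would make test metrics optimistic unless a separate validation split is used.

```python
import copy

import torch
import torch.nn as nn


class BestWeightsTracker:
    """Keeps an in-memory copy of the weights from the best-scoring epoch."""

    def __init__(self):
        self.best_score = float("-inf")
        self.best_epoch = 0
        self.best_state = None

    def update(self, model: nn.Module, score: float, epoch: int) -> bool:
        """Snapshot the weights if `score` beats the best so far."""
        if score > self.best_score:
            self.best_score, self.best_epoch = score, epoch
            # deepcopy so later optimizer steps do not mutate the snapshot
            self.best_state = copy.deepcopy(model.state_dict())
            return True
        return False

    def restore(self, model: nn.Module) -> None:
        """Load the best snapshot back into `model` (no-op if none was taken)."""
        if self.best_state is not None:
            model.load_state_dict(self.best_state)
```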

## 5. Train RF-UAV-Net

In [None]:
num_classes = len(drone_classes)
rfuavnet = RFUAVNet(num_classes=num_classes)
print(rfuavnet)

In [None]:
print(f"Training RF-UAVNet with {num_classes} classes...")
print(f"Model parameters: {sum(p.numel() for p in rfuavnet.parameters()):,}")
print("Paper comparison: a standard VGG16 has ~138M parameters (~13,800x more)\n")

rfuavnet, history = train_model(rfuavnet, train_loader, test_loader, CONFIG)

## 6. Plot Training History

In [None]:
# Training history visualization with plotly
fig = make_subplots(rows=1, cols=2, subplot_titles=('RF-UAVNet Loss', 'RF-UAVNet Accuracy'))

epochs = list(range(1, len(history['train_loss']) + 1))

# Loss subplot
fig.add_trace(go.Scatter(x=epochs, y=history['train_loss'], mode='lines+markers', name='Train Loss', line=dict(color='blue')), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs, y=history['val_loss'], mode='lines+markers', name='Val Loss', line=dict(color='orange')), row=1, col=1)

# Accuracy subplot
fig.add_trace(go.Scatter(x=epochs, y=history['train_acc'], mode='lines+markers', name='Train Acc', line=dict(color='blue')), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs, y=history['val_acc'], mode='lines+markers', name='Val Acc', line=dict(color='orange')), row=1, col=2)

fig.update_layout(
    title='RF-UAVNet Training History',
    height=500,
    width=1200
)
fig.update_xaxes(title_text='Epoch', row=1, col=1)
fig.update_yaxes(title_text='Loss', row=1, col=1)
fig.update_xaxes(title_text='Epoch', row=1, col=2)
fig.update_yaxes(title_text='Accuracy', row=1, col=2)

fig.show()
save_figure(fig)

## 7. Evaluate on Test Set

In [None]:
def evaluate_model(model, test_loader):
    """
    Evaluate model on test set with efficient memory usage.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(CONFIG['device'])
            outputs = model(batch_x)
            _, predicted = torch.max(outputs, 1)

            # Pattern 6: keep one array per batch and concatenate once at the end
            all_preds.append(predicted.cpu().numpy())
            all_labels.append(batch_y.numpy())

    return np.concatenate(all_preds), np.concatenate(all_labels)

preds, labels = evaluate_model(rfuavnet, test_loader)
accuracy = accuracy_score(labels, preds)
f1 = f1_score(labels, preds, average='weighted')

print(f"RF-UAV-Net Test Accuracy: {accuracy:.4f}")
print(f"RF-UAV-Net Test F1-Score (weighted): {f1:.4f}")
print("Classification Report:")
print(classification_report(labels, preds, labels=list(range(len(drone_classes))),
                            target_names=[str(c) for c in drone_classes], zero_division=0))
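The split also stored `y_interference_test`, which the evaluation above does not use. Because `test_loader` does not shuffle, `preds` and `labels` are in the same order as `y_interference_test`, so accuracy can be broken down per interference condition. A minimal sketch with a hypothetical helper name; in this notebook it would be called as `accuracy_by_condition(preds, labels, y_interference_test, interference_classes)`.

```python
import numpy as np


def accuracy_by_condition(preds, labels, conditions, condition_names):
    """Accuracy restricted to the samples of each condition index."""
    preds, labels, conditions = map(np.asarray, (preds, labels, conditions))
    results = {}
    for idx, name in enumerate(condition_names):
        mask = conditions == idx
        if mask.any():  # skip conditions absent from the test set
            results[str(name)] = float((preds[mask] == labels[mask]).mean())
    return results


# Toy example with two of the four interference conditions
print(accuracy_by_condition([0, 1, 1, 2], [0, 1, 0, 2], [0, 0, 1, 1],
                            ["CLEAN", "WIFI"]))
# {'CLEAN': 1.0, 'WIFI': 0.5}
```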

## 8. Confusion Matrix

In [None]:
cm = confusion_matrix(labels, preds, labels=list(range(len(drone_classes))))

# Create confusion matrix heatmap with plotly
fig = go.Figure(data=go.Heatmap(
    z=cm,
    x=list(drone_classes),
    y=list(drone_classes),
    colorscale='Purples',
    text=cm,
    texttemplate='%{text}',
    textfont={'size': 12},
    hoverongaps=False
))

fig.update_layout(
    title=f'RF-UAVNet Confusion Matrix - Accuracy: {accuracy:.4f}',
    xaxis_title='Predicted',
    yaxis_title='True',
    xaxis={'side': 'bottom'},
    yaxis={'autorange': 'reversed'},
    width=800,
    height=700
)
fig.show()
save_figure(fig)
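With unequal class sizes, raw counts can hide weak classes. Dividing each row by its total gives a row-normalized matrix whose diagonal equals per-class recall (the Recall column in section 9). The helper name below is hypothetical; scikit-learn's `confusion_matrix(..., normalize='true')` produces the same normalization directly.

```python
import numpy as np


def row_normalize(cm):
    """Divide each row by its sum; rows with no samples stay all-zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


print(row_normalize([[3, 1], [0, 0]]))
# [[0.75 0.25]
#  [0.   0.  ]]
```

Passing `z=row_normalize(cm)` to the heatmap (with a rounded `text`) would plot recall-style proportions instead of counts.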

## 9. Per-Class Performance

In [None]:
# Calculate per-class metrics
precision, recall, f1_per_class, support = precision_recall_fscore_support(
    labels, preds, labels=range(len(drone_classes)), zero_division=0
)

import pandas as pd

# Create DataFrame for display
metrics_df = pd.DataFrame({
    'Class': drone_classes,
    'Precision': precision,
    'Recall': recall,
    'F1-Score': f1_per_class,
    'Support': support
})

print("\nPer-Class Performance:")
print(metrics_df.to_string(index=False))

# Precision plot
fig_precision = px.bar(metrics_df, x='Class', y='Precision', title='RF-UAVNet Precision per Class',
                       color='Precision', range_y=[0, 1.05])
fig_precision.update_layout(xaxis_title="Class", yaxis_title="Precision", height=400)
fig_precision.show()
save_figure(fig_precision)

# Recall plot
fig_recall = px.bar(metrics_df, x='Class', y='Recall', title='RF-UAVNet Recall per Class',
                    color='Recall', color_continuous_scale=px.colors.sequential.Oranges, range_y=[0, 1.05])
fig_recall.update_layout(xaxis_title="Class", yaxis_title="Recall", height=400)
fig_recall.show()
save_figure(fig_recall)

# F1-Score plot
fig_f1 = px.bar(metrics_df, x='Class', y='F1-Score', title='RF-UAVNet F1-Score per Class',
                color='F1-Score', color_continuous_scale=px.colors.sequential.Greens, range_y=[0, 1.05])
fig_f1.update_layout(xaxis_title="Class", yaxis_title="F1-Score", height=400)
fig_f1.show()
save_figure(fig_f1)

## 10. Save Model

In [None]:
# Ensure the directory exists
os.makedirs(CONFIG['models_dir'], exist_ok=True)

model_path = os.path.join(CONFIG['models_dir'], 'rfuavnet_iq.pth')

torch.save({
    'model_state_dict': rfuavnet.state_dict(),
    'classes': drone_classes,
    'accuracy': accuracy,
    'f1': f1,
    'history': history,
    'config': CONFIG
}, model_path)

print(f"Model saved to {model_path}")
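The checkpoint above stores `drone_classes` as a NumPy array alongside the state dict. Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses to unpickle NumPy arrays unless they are explicitly allowlisted, so reloading this file needs `weights_only=False` (only for trusted files) or metadata converted to plain Python types before saving. A sketch of the second option, with hypothetical helper names and placeholder class names:

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_portable_checkpoint(model, classes, metrics, path):
    """Save weights plus metadata as plain Python types only."""
    torch.save({
        "model_state_dict": model.state_dict(),
        "classes": [str(c) for c in classes],
        "metrics": {k: float(v) for k, v in metrics.items()},
    }, path)


def load_portable_checkpoint(model, path):
    """Load with weights_only=True, which accepts tensors, dicts, lists, str, float."""
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    return ckpt


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "demo.pth")
    save_portable_checkpoint(nn.Linear(3, 2), ["drone_a", "drone_b"], {"accuracy": 0.9}, path)
    print(load_portable_checkpoint(nn.Linear(3, 2), path)["classes"])
```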

## 11. Summary

Key takeaways:
- RF-UAV-Net processes raw IQ data (2 channels, 10000 samples)
- Architecture: R-Unit (Conv1d 2->64) + 4 G-Units (grouped convolutions)
- Multi-scale GAP for feature aggregation
- File-level stratified split prevents data leakage
- SGD with momentum (0.95) and lr=0.01, following the paper, rather than the Adam (lr=0.001) setup common for image CNNs

In [None]:
print("=== RF-UAV-Net Training Summary ===")
print(f"Dataset:")
print(f"  Total samples: {len(X_train_t) + len(X_test_t)}")
print(f"  Training samples: {len(X_train_t)}")
print(f"  Test samples: {len(X_test_t)}")
print(f"  Number of classes: {num_classes}")

print(f"Training Configuration:")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Device: {CONFIG['device']}")

print(f"Performance:")
print(f"  Test Accuracy: {accuracy:.4f}")
print(f"  Test F1-Score: {f1:.4f}")
print(f"  Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"  Final Val Loss: {history['val_loss'][-1]:.4f}")

print(f"Model saved to: {model_path}")