In [1]:
import pandas as pd
import numpy as np
import logging

# Logging: INFO and above, every line prefixed with a timestamp and the level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Dataset to read, and where the relabeled copy is written.
INPUT_CSV: str = "my_master_dataset_RELABELED.csv"
OUTPUT_CSV: str = "my_master_dataset_RELABELED_V2.csv"

# Column that identifies *when* an attack packet was sent, and the value
# it holds for normal (non-attack) flights.
ATTACK_MARKER_COLUMN: str = "attack_type"
NORMAL_FLIGHT_MARKER: str = "none"

# --- Attack signature ---
# A row counts as part of an attack when the autopilot's *desired* roll or
# pitch (nav_roll / nav_pitch) exceeds this magnitude, in radians.
# 0.05 rad is roughly 2.86 degrees, a clear deviation from normal cruising.
ATTACK_DEVIATION_THRESHOLD: float = 0.05

def _normalise_markers(markers: pd.Series) -> pd.Series:
    """Return attack-marker values as stripped, lower-cased strings.

    Missing values are mapped to NORMAL_FLIGHT_MARKER first: ``pd.read_csv``
    parses a literal ``None`` cell as NaN, and NaN compares unequal to the
    normal marker, which would otherwise turn a normal row into an attack row.
    """
    return markers.fillna(NORMAL_FLIGHT_MARKER).astype(str).str.strip().str.lower()


def _attack_window(group: pd.DataFrame, attack_rows: pd.Series):
    """Return the (start, end) timestamps where the autopilot fights the spoof.

    Only rows at or after the first attack-marked row are considered, so
    ordinary manoeuvres before the attack starts cannot open the window.
    Returns None when no such row deviates beyond ATTACK_DEVIATION_THRESHOLD.

    Args:
        group: Rows of a single flight with a numeric 'timestamp' column.
        attack_rows: Boolean mask (aligned with ``group``) of attack-marked
            rows; the caller guarantees at least one True.
    """
    onset = group.loc[attack_rows, 'timestamp'].min()
    # Coerce so a stray non-numeric cell yields NaN (never "deviating")
    # instead of a TypeError from .abs() on an object column.
    roll = pd.to_numeric(group['nav_roll'], errors='coerce').abs()
    pitch = pd.to_numeric(group['nav_pitch'], errors='coerce').abs()
    fighting = (group['timestamp'] >= onset) & (
        (roll > ATTACK_DEVIATION_THRESHOLD) | (pitch > ATTACK_DEVIATION_THRESHOLD)
    )
    if not fighting.any():
        return None
    fighting_ts = group.loc[fighting, 'timestamp']
    return fighting_ts.min(), fighting_ts.max()


def relabel_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """
    Finds attack flights and relabels only the segment where
    the autopilot is *actively fighting* the spoof.

    A flight is an attack flight when any of its rows has an
    ATTACK_MARKER_COLUMN value other than NORMAL_FLIGHT_MARKER (compared
    case-insensitively; missing values count as normal). For such a flight the
    attack window spans the first to the last row, at or after the first
    attack-marked row, whose |nav_roll| or |nav_pitch| exceeds
    ATTACK_DEVIATION_THRESHOLD. Every row of that flight whose timestamp lies
    inside the window (inclusive) gets label 1; all other rows get 0.

    Args:
        df: Dataset with 'flight_id', 'timestamp', 'nav_roll', 'nav_pitch',
            ATTACK_MARKER_COLUMN and, optionally, the old 'label' column.

    Returns:
        A new DataFrame (the input is not modified) in which rows with a
        non-numeric timestamp are dropped, the surviving rows keep their
        original index, and 'label' holds the new 0/1 labels.
    """
    logger.info("Starting relabeling process...")

    # Work on a copy so the caller's DataFrame is left untouched.
    df = df.copy()

    # --- Make sure timestamp is numeric for comparison ---
    df['timestamp'] = pd.to_numeric(df['timestamp'], errors='coerce')
    df = df.dropna(subset=['timestamp'])  # Drop rows where timestamp was invalid

    # Label on a positional index: a master dataset concatenated from
    # per-flight files can repeat index labels, and a label-based .loc write
    # would then also flag rows of *other* flights. Restored before returning.
    original_index = df.index
    df = df.reset_index(drop=True)
    df['new_label'] = 0

    has_old_labels = 'label' in df.columns
    normal_marker = str(NORMAL_FLIGHT_MARKER).strip().lower()
    total_flights = df['flight_id'].nunique()  # same flights groupby visits
    total_rows_reduced = 0

    for processed_count, (flight_id, group) in enumerate(df.groupby('flight_id'), start=1):
        attack_rows = _normalise_markers(group[ATTACK_MARKER_COLUMN]) != normal_marker

        if not attack_rows.any():
            # This is a normal flight, all labels remain 0
            if (processed_count % 5 == 0) or (processed_count == 1):
                logger.info(f"Processing Normal Flight {flight_id} ({processed_count}/{total_flights})...")
            continue

        try:
            window = _attack_window(group, attack_rows)
        except Exception as e:
            # Best effort per flight: record the traceback, keep going.
            logger.exception(f"Error processing flight {flight_id}: {e}")
            continue

        if window is None:
            logger.warning(f"Flight {flight_id} is an attack flight but no NAV deviation was found. Skipping.")
            continue

        min_ts, max_ts = window
        logger.info(f"Processing Attack Flight {flight_id} ({processed_count}/{total_flights})...")
        logger.info(f"  Identified attack window (based on NAV deviation): {min_ts} -> {max_ts}")

        # Set new_label to 1 *only* for this flight's rows inside the window
        in_window = group['timestamp'].between(min_ts, max_ts)
        attack_indices = group.index[in_window.to_numpy()]
        df.loc[attack_indices, 'new_label'] = 1

        # Count old positives the same way main() does (label == 1)
        original_ones = int((group['label'] == 1).sum()) if has_old_labels else 0
        new_ones = len(attack_indices)
        rows_reduced = original_ones - new_ones
        total_rows_reduced += rows_reduced

        logger.info(f"  Relabeled: {original_ones} rows -> {new_ones} rows. (Reduced by {rows_reduced})")

    # Positional writes are done; give the caller back its own index.
    df.index = original_index

    # Drop the old, incorrect label column and rename the new one
    if has_old_labels:
        df = df.drop(columns=['label'])
    df = df.rename(columns={'new_label': 'label'})

    logger.info("Relabeling process complete.")
    logger.info(f"Total attack rows changed: {total_rows_reduced}")
    return df

def main():
    """Relabel INPUT_CSV with relabel_dataset() and save it to OUTPUT_CSV.

    On a load failure or a missing required column, a FATAL message is
    logged and the function returns without writing any output.
    """
    logger.info(f"Loading dataset: {INPUT_CSV}")
    try:
        dataset = pd.read_csv(INPUT_CSV)
    except FileNotFoundError:
        logger.error(f"FATAL: Input file not found at {INPUT_CSV}")
        return
    except Exception as exc:
        logger.error(f"FATAL: Error loading CSV: {exc}")
        return

    # --- Check for required columns (single scan) ---
    needed = ('flight_id', 'timestamp', 'label', ATTACK_MARKER_COLUMN, 'nav_roll', 'nav_pitch')
    absent = [name for name in needed if name not in dataset.columns]
    if absent:
        logger.error(f"FATAL: Input CSV must contain all of these columns. MISSING: {absent}")
        return

    # Count positives via the boolean mask itself; no filtered copy needed.
    old_positive = int((dataset['label'] == 1).sum())
    relabeled = relabel_dataset(dataset)
    new_positive = int((relabeled['label'] == 1).sum())

    logger.info(f"Saving new relabeled dataset to: {OUTPUT_CSV}")
    relabeled.to_csv(OUTPUT_CSV, index=False)

    summary = (
        "\n--- Stats ---",
        f"Old '1' labels: {old_positive}",
        f"New '1' labels: {new_positive}",
        f"Net reduction in '1' labels: {old_positive - new_positive}",
        "Done.",
    )
    for line in summary:
        logger.info(line)

if __name__ == "__main__":
    main()

2025-11-16 20:48:06,482 - INFO - Loading dataset: my_master_dataset_RELABELED.csv
2025-11-16 20:48:06,607 - INFO - Starting relabeling process...
2025-11-16 20:48:06,623 - INFO - Processing Attack Flight attack_corridor_base (1/24)...
2025-11-16 20:48:06,624 - INFO -   Identified attack window (based on NAV deviation): 77107 -> 662357
2025-11-16 20:48:06,626 - INFO -   Relabeled: 2296 rows -> 2296 rows. (Reduced by 0)
2025-11-16 20:48:06,628 - INFO - Processing Attack Flight attack_corridor_base_14 (2/24)...
2025-11-16 20:48:06,628 - INFO -   Identified attack window (based on NAV deviation): 72355 -> 309104
2025-11-16 20:48:06,630 - INFO -   Relabeled: 948 rows -> 948 rows. (Reduced by 0)
2025-11-16 20:48:06,631 - INFO - Processing Attack Flight attack_corridor_base_15 (3/24)...
2025-11-16 20:48:06,632 - INFO -   Identified attack window (based on NAV deviation): 73604 -> 317104
2025-11-16 20:48:06,633 - INFO -   Relabeled: 975 rows -> 975 rows. (Reduced by 0)
2025-11-16 20:48:06,634 