In [None]:
# Colab setup: install the third-party packages this notebook relies on.
# pandas/pyarrow back the Parquet chunk and split files, torch provides the
# Dataset/DataLoader classes, and gdown is only needed for download option 3.
# NOTE(review): torchvision/torchaudio do not appear to be used in the cells
# shown here, and versions are unpinned — consider pinning for reproducibility.
# fastavro (needed to read the Avro logs) is installed on demand in the next cell.
print("Installing required packages...")
!pip install gdown pandas pyarrow torch torchvision torchaudio --quiet

Installing required packages...


In [None]:
import os
import json
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import Dataset, DataLoader
import gc
import gzip

# For Avro binary format
try:
    import fastavro
    print("‚úì fastavro installed successfully")
except ImportError:
    print("‚ö†Ô∏è  Installing fastavro...")
    !pip install fastavro
    import fastavro

# Set your base directory
BASE_DIR = "/content/drive/MyDrive/mythesis/vicky/darpa_tc"
print(f"‚úì Working directory: {BASE_DIR}")

‚ö†Ô∏è  Installing fastavro...
Collecting fastavro
  Downloading fastavro-1.12.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (5.8 kB)
Downloading fastavro-1.12.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.5 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m3.5/3.5 MB[0m [31m96.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastavro
Successfully installed fastavro-1.12.1
‚úì Working directory: /content/drive/MyDrive/mythesis/vicky/darpa_tc


In [None]:
def setup_directories(base_dir=None):
    """Create the pipeline's working subdirectories and return their paths.

    Args:
        base_dir: Root directory under which the subdirectories are created.
            Defaults to the notebook-level ``BASE_DIR``, so existing calls
            behave as before; pass an explicit path to reuse this elsewhere.

    Returns:
        dict: Maps each logical name ('raw', 'processed', 'features',
        'splits', 'metadata') to its path as a string.
    """
    root = BASE_DIR if base_dir is None else base_dir
    dirs = {
        name: f'{root}/{name}'
        for name in ('raw', 'processed', 'features', 'splits', 'metadata')
    }

    for name, path in dirs.items():
        # exist_ok makes this safe to re-run on an already-initialised folder.
        Path(path).mkdir(parents=True, exist_ok=True)
        print(f"‚úì Created: {name} ‚Üí {path}")

    return dirs

# Create the working folders and keep their paths in the module-level `dirs`,
# which later cells read (e.g. dirs['raw'], dirs['processed'], dirs['splits']).
dirs = setup_directories()

‚úì Created: raw ‚Üí /content/drive/MyDrive/mythesis/vicky/darpa_tc/raw
‚úì Created: processed ‚Üí /content/drive/MyDrive/mythesis/vicky/darpa_tc/processed
‚úì Created: features ‚Üí /content/drive/MyDrive/mythesis/vicky/darpa_tc/features
‚úì Created: splits ‚Üí /content/drive/MyDrive/mythesis/vicky/darpa_tc/splits
‚úì Created: metadata ‚Üí /content/drive/MyDrive/mythesis/vicky/darpa_tc/metadata


In [None]:
def download_darpa_sample(raw_dir=None):
    """
    Print instructions for obtaining a manageable subset of DARPA TC Engagement 5.

    Nothing is downloaded here. The printed options are, in order:
    1. Manual download from the Google Drive link and upload to Colab
    2. Use the shared Drive folder directly (add a shortcut, read via the mount)
    3. Download a specific file with gdown, given its file ID

    The example file names use the Avro binary naming (``.bin.N.gz``), because
    the processing cells read every raw file with the Avro reader; a
    ``.json.gz`` file placed in the raw directory would fail to parse.

    Args:
        raw_dir: Destination shown in the instructions. Defaults to dirs['raw'].
    """
    target = dirs['raw'] if raw_dir is None else raw_dir

    print("=" * 60)
    print("DARPA TC Dataset Download Options")
    print("=" * 60)

    print("\nüì• OPTION 1: Manual Upload (Recommended for testing)")
    print("   1. Go to: https://drive.google.com/drive/folders/1okt4AYElyBohW4XiOBqmsvjwXsnUjLVf")
    print("   2. Download ONE file (e.g., ta1-trace-1-e5-official-1.bin.1.gz)")
    print("   3. Upload to Colab using the file browser")
    print(f"   4. Place in: {target}/")

    print("\nüì• OPTION 2: Access via Shared Drive")
    print("   1. Right-click the folder in Google Drive")
    print("   2. Select 'Add shortcut to Drive'")
    print("   3. Access directly via mounted Drive")

    print("\nüì• OPTION 3: Direct Download (if you have file ID)")
    print("   Use: !gdown <file_id> -O {}/sample.bin.gz".format(target))

    print("\n" + "=" * 60)
    print("For initial testing, download just ONE host file (~5-10GB)")
    print("=" * 60)

# Print the download instructions (this call only prints; no data is fetched).
download_darpa_sample()


DARPA TC Dataset Download Options

üì• OPTION 1: Manual Upload (Recommended for testing)
   1. Go to: https://drive.google.com/drive/folders/1okt4AYElyBohW4XiOBqmsvjwXsnUjLVf
   2. Download ONE file (e.g., ta1-trace-e5-official-1.json.gz)
   3. Upload to Colab using the file browser
   4. Place in: /content/drive/MyDrive/mythesis/vicky/darpa_tc/raw/

üì• OPTION 2: Access via Shared Drive
   1. Right-click the folder in Google Drive
   2. Select 'Add shortcut to Drive'
   3. Access directly via mounted Drive

üì• OPTION 3: Direct Download (if you have file ID)
   Use: !gdown <file_id> -O /content/drive/MyDrive/mythesis/vicky/darpa_tc/raw/sample.json.gz

For initial testing, download just ONE host file (~5-10GB)


In [None]:
def list_raw_files(raw_dir=None):
    """List the regular files in the raw-data directory in a stable, natural order.

    Directories (e.g. Jupyter's ``.ipynb_checkpoints``) are excluded. Names are
    sorted so that embedded numbers compare numerically (``...bin.2.gz`` before
    ``...bin.10.gz``). A deterministic order matters because the processing
    loop assigns chunk IDs file by file.

    Args:
        raw_dir: Directory to scan. Defaults to dirs['raw'].

    Returns:
        list[Path]: The files found, or an empty list when there are none.
    """
    import re

    def natural_key(path):
        # Split digit runs from text so numeric parts sort by value, not lexically.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path.name)]

    raw_dir = Path(dirs['raw'] if raw_dir is None else raw_dir)
    files = sorted((p for p in raw_dir.glob('*') if p.is_file()), key=natural_key)

    if not files:
        print("  No files found in raw directory")
        print(f"   Please download data to: {raw_dir}")
        return []

    print(f"\n‚úì Found {len(files)} file(s) in raw directory:")
    for i, f in enumerate(files, 1):
        size_mb = f.stat().st_size / (1024 * 1024)
        print(f"   {i}. {f.name} ({size_mb:.2f} MB)")

    return files

# Inventory of the raw directory; the processing cell further down calls
# list_raw_files() again before looping, so this listing is informational.
raw_files = list_raw_files()


‚úì Found 6 file(s) in raw directory:
   1. .ipynb_checkpoints (0.00 MB)
   2. ta1-trace-1-e5-official-1.bin.2.gz (270.65 MB)
   3. ta1-trace-1-e5-official-1.bin.1.gz (266.57 MB)
   4. ta1-trace-1-e5-official-1.bin.4.gz (272.90 MB)
   5. ta1-trace-1-e5-official-1.bin.3.gz (273.70 MB)
   6. ta1-trace-1-e5-official-1.bin.5.gz (259.57 MB)


In [None]:
def read_avro_binary(input_file, max_records=None):
    """
    Stream records from a DARPA TC Avro binary file (.bin or .bin.gz).

    Args:
        input_file: Path to a .bin or .bin.gz file. Gzip is detected from the
            '.gz' suffix only, not from the file contents.
        max_records: Maximum number of records to yield (for testing).
            None means no limit; 0 or a negative value yields nothing.

    Yields:
        Parsed Avro records as produced by fastavro.reader (dictionaries).

    Raises:
        Re-raises any error from opening or decoding the file, after printing it.
    """
    from itertools import islice

    is_gzipped = str(input_file).endswith('.gz')

    print(f"Reading Avro file: {Path(input_file).name}")
    print(f"Compressed: {is_gzipped}")

    # Both openers return a binary file object that fastavro can read from.
    opener = gzip.open if is_gzipped else open

    try:
        with opener(input_file, 'rb') as stream:
            records = fastavro.reader(stream)
            if max_records is not None:
                # islice stops without decoding past the limit; max(..., 0)
                # keeps a non-positive limit from raising inside islice.
                records = islice(records, max(int(max_records), 0))
            yield from records

    except Exception as e:
        print(f"‚ùå Error reading Avro file: {e}")
        raise


In [None]:
def process_darpa_avro_streaming(input_file, chunk_size=50000, max_records=None, start_chunk_id=0):
    """
    Stream a DARPA TC Avro binary log, extract features and save them as Parquet chunks.

    Records are read with read_avro_binary and converted with
    extract_apt_features_from_avro. Records for which it returns a falsy value
    are skipped. Extracted records are buffered until `chunk_size` are
    collected, then written with save_processed_chunk.

    Args:
        input_file: Path to DARPA .bin or .bin.gz file
        chunk_size: Number of extracted records per saved chunk
        max_records: Limit on raw records read, for testing (None = process all)
        start_chunk_id: ID of the first chunk written by this call, so several
            files can be processed without overwriting each other's chunks

    Returns:
        tuple: (num_chunks, num_records) actually written by this call.
        Only chunks that save_processed_chunk reports as saved are counted.
        If reading fails part-way, the chunks already written are still
        reported. The caller's chunk counter therefore moves past them instead
        of reusing (and partly overwriting) their IDs. A file that cannot be
        opened yields (0, 0).
    """

    print(f"\n{'='*60}")
    print(f"Processing: {Path(input_file).name}")
    print(f"{'='*60}\n")

    chunk = []
    chunk_counter = start_chunk_id  # ID of the next chunk to write
    total_records = 0
    error_count = 0

    start_time = datetime.now()

    def flush(records):
        """Save one chunk; advance the counters only when the save succeeded."""
        nonlocal chunk_counter, total_records
        if save_processed_chunk(records, chunk_counter):
            chunk_counter += 1
            total_records += len(records)

    try:
        for i, record in enumerate(read_avro_binary(input_file, max_records)):
            try:
                processed = extract_apt_features_from_avro(record)

                if processed:  # Only keep records the extractor could parse
                    chunk.append(processed)

                # The `chunk and` guard avoids writing empty files when chunk_size <= 0.
                if chunk and len(chunk) >= chunk_size:
                    flush(chunk)
                    chunk = []
                    # Memory management: release the saved chunk's objects promptly
                    gc.collect()

                if (i + 1) % 50000 == 0:
                    elapsed = (datetime.now() - start_time).total_seconds()
                    rate = (i + 1) / elapsed if elapsed > 0 else 0
                    print(f"   Processed: {i+1:,} records | Rate: {rate:.0f} records/sec | "
                          f"Chunks saved: {chunk_counter - start_chunk_id}")

            except Exception as e:
                # Best-effort per record: count the failure, show only the first few.
                error_count += 1
                if error_count < 5:
                    print(f"   Warning: Error processing record {i}: {str(e)[:100]}")

        # Save remaining records
        if chunk:
            flush(chunk)
            chunk = []

        elapsed = (datetime.now() - start_time).total_seconds()
        num_chunks_this_file = chunk_counter - start_chunk_id
        print(f"\n{'='*60}")
        print(f"‚úì Processing Complete!")
        print(f"  Total records: {total_records:,}")
        print(f"  Chunks created in this file: {num_chunks_this_file}")
        print(f"  Errors: {error_count:,}")
        print(f"  Time: {elapsed/60:.2f} minutes")
        if elapsed > 0:
            print(f"  Avg rate: {total_records/elapsed:.0f} records/sec")
        print(f"{'='*60}\n")

        return num_chunks_this_file, total_records

    except FileNotFoundError:
        print(f"‚ùå Error: File not found: {input_file}")
    except Exception as e:
        print(f"‚ùå Error during processing: {e}")
        import traceback
        traceback.print_exc()

    # Failure path: report only what reached disk (0, 0 if nothing did), so the
    # caller's next start_chunk_id skips past chunk files already written.
    return chunk_counter - start_chunk_id, total_records

In [None]:
def inspect_avro_records(input_file, num_samples=5):
    """
    Print the structure of the first few Avro records in a file.

    For each sampled record this shows:
    - its top-level keys;
    - a one-line summary per field (dict keys, list length, or a truncated scalar);
    - the first 1000 characters of its JSON rendering.

    It is intended for exploring the schema before writing extractors.

    Args:
        input_file: Path to a .bin or .bin.gz Avro file.
        num_samples: Number of records to show.
    """
    rule = '=' * 60

    def summarize(key, value):
        # One line per top-level field, truncated so large payloads stay readable.
        if isinstance(value, dict):
            return f"  {key}: dict with keys {list(value.keys())[:5]}"
        if isinstance(value, list):
            return f"  {key}: list with {len(value)} items"
        return f"  {key}: {type(value).__name__} = {str(value)[:100]}"

    print(f"\n{rule}")
    print("INSPECTING AVRO RECORD STRUCTURE")
    print(f"{rule}\n")

    samples = read_avro_binary(input_file, max_records=num_samples)
    for position, record in enumerate(samples, start=1):
        print(f"Record #{position}:")
        print(f"  Type: {type(record)}")
        print(f"  Keys: {list(record.keys())}")

        for key, value in record.items():
            print(summarize(key, value))

        print(f"\nFull record structure:")
        print(json.dumps(record, indent=2, default=str)[:1000])
        print("\n" + "-" * 60 + "\n")

        # Also stop explicitly here rather than relying solely on the reader's limit.
        if position >= num_samples:
            break

    print(rule + "\n")


In [None]:
def _avro_uuid_to_str(value):
    """Render a CDM UUID as text: bytes as lowercase hex, None as '', anything else via str()."""
    if value is None:
        return ''
    if isinstance(value, bytes):
        return value.hex()
    return str(value)


def _safe_str(value):
    """str() that maps None to '' instead of the literal 'None'."""
    return '' if value is None else str(value)


def _safe_int(value, default=0):
    """int() that returns `default` for None or unparseable values instead of raising."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _properties_of(obj):
    """Return obj['properties'], or {} when obj or its 'properties' entry is missing/empty."""
    if not obj:
        return {}
    return obj.get('properties') or {}


def _fill_event_fields(features, datum):
    """RECORD_EVENT: event type, subject/object UUIDs, size and selected properties."""
    features['event_type'] = datum.get('type', '')
    # Absent or None UUIDs map to '' (the default), not the string 'None'.
    features['subject_uuid'] = _avro_uuid_to_str(datum.get('subject'))
    features['object_uuid'] = _avro_uuid_to_str(datum.get('predicateObject'))

    size = datum.get('size')
    if size:
        features['file_size'] = size

    props = _properties_of(datum)
    for prop_key, feature_key in (('protection', 'protection'), ('path', 'file_path'),
                                  ('remoteAddress', 'dst_addr'), ('localAddress', 'src_addr')):
        if prop_key in props:
            features[feature_key] = _safe_str(props.get(prop_key))
    for prop_key, feature_key in (('remotePort', 'dst_port'), ('localPort', 'src_port')):
        if prop_key in props:
            features[feature_key] = _safe_int(props.get(prop_key))


def _fill_subject_fields(features, datum):
    """RECORD_SUBJECT: subject type plus pid/ppid/tgid when available."""
    features['event_type'] = datum.get('type', '')

    props = _properties_of(datum.get('baseObject'))
    if 'pid' in props:
        features['process_pid'] = _safe_int(props.get('pid'))
    if 'ppid' in props:
        features['process_ppid'] = _safe_int(props.get('ppid'))
    if 'tgid' in props:
        features['thread_id'] = _safe_int(props.get('tgid'))

    # NOTE(review): CDM Subject records appear to carry the process id in the
    # top-level 'cid' field rather than in baseObject properties — confirm with
    # inspect_avro_records. Used only when no pid was found above.
    if not features['process_pid'] and 'cid' in datum:
        features['process_pid'] = _safe_int(datum.get('cid'))


def _fill_file_fields(features, datum):
    """RECORD_FILE_OBJECT: path (from baseObject properties) and size."""
    props = _properties_of(datum.get('baseObject'))
    if 'path' in props:
        features['file_path'] = _safe_str(props.get('path'))
    if 'size' in datum:
        features['file_size'] = datum.get('size', 0)


def _fill_netflow_fields(features, datum):
    """RECORD_NETFLOW_OBJECT: endpoint addresses/ports and IP protocol."""
    props = _properties_of(datum.get('baseObject'))
    for cdm_key, feature_key, convert in (
        ('localAddress', 'src_addr', _safe_str),
        ('localPort', 'src_port', _safe_int),
        ('remoteAddress', 'dst_addr', _safe_str),
        ('remotePort', 'dst_port', _safe_int),
        ('ipProtocol', 'ip_protocol', _safe_int),
    ):
        # NOTE(review): CDM NetFlowObject appears to store these as top-level
        # datum fields; baseObject properties (the original source) are kept
        # as the fallback.
        if cdm_key in datum:
            features[feature_key] = convert(datum.get(cdm_key))
        elif props:
            features[feature_key] = convert(props.get(cdm_key))


def _fill_memory_fields(features, datum):
    """RECORD_MEMORY_OBJECT: address, size and owning tgid (stored as process_pid)."""
    if 'memoryAddress' in datum:
        features['memory_address'] = datum.get('memoryAddress', 0)
    if 'size' in datum:
        features['file_size'] = datum.get('size', 0)

    props = _properties_of(datum.get('baseObject'))
    if 'tgid' in props:
        features['process_pid'] = _safe_int(props.get('tgid'))


# Record-type-specific field extraction; unknown types keep only the common fields.
_RECORD_FIELD_FILLERS = {
    'RECORD_EVENT': _fill_event_fields,
    'RECORD_SUBJECT': _fill_subject_fields,
    'RECORD_FILE_OBJECT': _fill_file_fields,
    'RECORD_NETFLOW_OBJECT': _fill_netflow_fields,
    'RECORD_MEMORY_OBJECT': _fill_memory_fields,
}


def extract_apt_features_from_avro(record):
    """
    Extract security-relevant features from DARPA TC Avro records.

    Schema: CDM20 with top-level 'type' and nested 'datum'. Common fields
    (uuid, timestamp, sequence, thread id) are read from every datum; the rest
    depend on the record type. Every key of the returned dict is always
    present, with a neutral default, so all chunks share one column layout.

    Args:
        record: Record dict as yielded by read_avro_binary.

    Returns:
        dict | None: The feature dict, including 'is_suspicious' from
        detect_suspicious_behavior. None when the record has no datum or is
        structurally malformed (e.g. a datum/properties value that is not a
        dict). Unparseable numeric property values fall back to 0 instead of
        discarding the whole record.
    """

    try:
        record_type = record.get('type', 'unknown')
        datum = record.get('datum', {})

        if not datum:
            return None

        # Initialize with safe defaults
        features = {
            'record_type': record_type,
            'cdm_version': str(record.get('CDMVersion', '20')),
            'source': record.get('source', ''),
            'session_number': record.get('sessionNumber', 0),
            'timestamp_ns': 0,
            'event_id': '',
            'sequence': 0,
            'thread_id': 0,
            'subject_uuid': '',
            'object_uuid': '',
            'event_type': '',
            'src_addr': '',
            'src_port': 0,
            'dst_addr': '',
            'dst_port': 0,
            'ip_protocol': 0,
            'file_path': '',
            'file_size': 0,
            'process_pid': 0,
            'process_ppid': 0,
            'memory_address': 0,
            'protection': '',
            'is_suspicious': False
        }

        if 'uuid' in datum:
            features['event_id'] = _avro_uuid_to_str(datum.get('uuid'))
        if 'timestampNanos' in datum:
            features['timestamp_ns'] = datum.get('timestampNanos', 0)
        if 'sequence' in datum:
            features['sequence'] = datum.get('sequence', 0)
        if 'threadId' in datum:
            features['thread_id'] = datum.get('threadId', 0)

        filler = _RECORD_FIELD_FILLERS.get(record_type)
        if filler is not None:
            filler(features, datum)

        features['is_suspicious'] = detect_suspicious_behavior(features)

        return features

    except (AttributeError, TypeError, ValueError):
        # Best-effort: skip structurally malformed records. Any other exception
        # propagates to the caller, whose per-record handler counts and reports it.
        return None

def extract_event_features(event_data):
    """Extract timestamp, sequence, UUID and operation fields from an Event record.

    The three base fields are always present. The optional fields
    (subject_uuid, object_uuid, predicate_type, operation) are added only when
    the source value is truthy, and are converted with str(). Bytes UUIDs are
    therefore rendered as their repr, not as hex.
    NOTE(review): process_darpa_avro_streaming calls
    extract_apt_features_from_avro, not this function.
    """
    features = {
        'timestamp_ns': event_data.get('timestampNanos', 0),
        'sequence': event_data.get('sequence', 0),
        'event_id': str(event_data.get('uuid', '')),
    }

    # (source key in the event dict, output feature name), in output order
    optional_fields = (
        ('subject', 'subject_uuid'),
        ('predicateObject', 'object_uuid'),
        ('predicateObjectPath', 'predicate_type'),
        ('type', 'operation'),
    )
    for source_key, target_key in optional_fields:
        value = event_data.get(source_key)
        if value:
            features[target_key] = str(value)

    return features

def extract_subject_features(subject_data):
    """Extract timestamp, UUID and pid/ppid/name from a Subject (process) record.

    Process fields are read from properties['map'] and are only added when a
    non-empty 'properties' entry exists; values are returned unconverted.
    """
    features = {
        'timestamp_ns': subject_data.get('timestampNanos', 0),
        'subject_uuid': str(subject_data.get('uuid', '')),
    }

    properties = subject_data.get('properties', {})
    if not properties:
        return features

    prop_map = properties.get('map', {})
    features.update(
        process_pid=prop_map.get('pid', 0),
        process_ppid=prop_map.get('ppid', 0),
        process_name=prop_map.get('name', ''),
    )
    return features

def extract_file_features(file_data):
    """Extract timestamp, UUID, path and size from a FileObject record.

    Path and size come from properties['map'] and are added only when a
    non-empty 'properties' entry exists. Size is converted with int() when it
    is truthy, otherwise 0.
    """
    record = {
        'timestamp_ns': file_data.get('timestampNanos', 0),
        'object_uuid': str(file_data.get('uuid', '')),
    }

    properties = file_data.get('properties', {})
    if properties:
        prop_map = properties.get('map', {})
        record['file_path'] = prop_map.get('path', '')
        raw_size = prop_map.get('size')
        record['file_size'] = int(raw_size) if raw_size else 0

    return record

def extract_network_features(net_data):
    """Extract timestamp, UUID and endpoint fields from a NetFlowObject record.

    Endpoint fields come from properties['map'] and are added only when a
    non-empty 'properties' entry exists. Port and protocol values are converted
    with int() when truthy, otherwise 0.
    """
    result = {
        'timestamp_ns': net_data.get('timestampNanos', 0),
        'object_uuid': str(net_data.get('uuid', '')),
    }

    properties = net_data.get('properties', {})
    if not properties:
        return result

    prop_map = properties.get('map', {})

    def int_or_zero(key):
        raw = prop_map.get(key)
        return int(raw) if raw else 0

    result['src_addr'] = prop_map.get('srcAddress', '')
    result['src_port'] = int_or_zero('srcPort')
    result['dst_addr'] = prop_map.get('destAddress', '')
    result['dst_port'] = int_or_zero('destPort')
    result['ip_protocol'] = int_or_zero('ipProtocol')
    return result

def detect_suspicious_behavior(features):
    """
    Heuristically flag a feature dict as suspicious (True) or not (False).

    Based on DARPA TC APT detection scenarios. A record is flagged when ANY of
    these rules match (missing keys fall back to 0 or ''):
      * dst_port is one of 4444, 31337, 1337, 8080, 9999, 6666, 1234;
      * the lower-cased file_path contains '/tmp/', '/dev/shm/', 'powershell',
        'wget', 'curl', '.sh', 'base64', '/etc/passwd' or '/etc/shadow';
      * event_type is EVENT_EXECUTE, EVENT_MMAP, EVENT_CLONE,
        EVENT_LOADLIBRARY or EVENT_CREATE_THREAD;
      * protection is the string '7' or '5';
      * process_ppid == 1 and process_pid > 1000.

    NOTE(review): dst_port, process_pid and process_ppid are also default model
    inputs in DARPAAPTDataset, so a model can learn these rules directly.
    """
    flags = []

    # Network: ports commonly associated with backdoors / C2 tooling.
    flags.append(features.get('dst_port', 0) in (4444, 31337, 1337, 8080, 9999, 6666, 1234))

    # File: substring match against the lower-cased path.
    path = features.get('file_path', '').lower()
    path_markers = ('/tmp/', '/dev/shm/', 'powershell', 'wget', 'curl',
                    '.sh', 'base64', '/etc/passwd', '/etc/shadow')
    flags.append(any(marker in path for marker in path_markers))

    # Event: execution, memory mapping, library loading, process/thread creation.
    flags.append(features.get('event_type', '') in (
        'EVENT_EXECUTE', 'EVENT_MMAP', 'EVENT_CLONE',
        'EVENT_LOADLIBRARY', 'EVENT_CREATE_THREAD'))

    # Memory: protection '7' (RWX) or '5' (R-X), compared as strings.
    flags.append(features.get('protection', '') in ('7', '5'))

    # Process: parent pid 1 (reparented to init) with a pid above 1000.
    flags.append(features.get('process_ppid', 0) == 1 and features.get('process_pid', 0) > 1000)

    return any(flags)

In [None]:
def save_processed_chunk(chunk, chunk_id, output_dir=None):
    """
    Save one chunk of extracted feature dicts as a Snappy-compressed Parquet file.

    The function:
    - adds a 'timestamp' datetime column derived from 'timestamp_ns'
      (unparseable values become NaT);
    - fills missing numeric values with 0 and missing object/string values
      with '';
    - writes ``chunk_{chunk_id:04d}.parquet``.

    The file is first written under a temporary '.tmp' name and then renamed.
    An interrupted write therefore never leaves a truncated ``*.parquet`` file
    for later cells to load.

    Args:
        chunk: List of feature dicts (one per record).
        chunk_id: Integer used in the output file name.
        output_dir: Destination directory; defaults to dirs['processed'].

    Returns:
        bool: True if the file was written. False if anything failed; the
        error is printed, not raised, so one bad chunk does not stop a long run.
    """
    tmp_file = None
    try:
        target_dir = dirs['processed'] if output_dir is None else output_dir
        df = pd.DataFrame(chunk)

        # Convert timestamp to datetime
        if 'timestamp_ns' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp_ns'], unit='ns', errors='coerce')

        # Fill NaN values so chunks contain no nulls in numeric/string columns
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        df[numeric_columns] = df[numeric_columns].fillna(0)

        string_columns = df.select_dtypes(include=['object']).columns
        df[string_columns] = df[string_columns].fillna('')

        output_file = f"{target_dir}/chunk_{chunk_id:04d}.parquet"
        tmp_file = output_file + ".tmp"
        df.to_parquet(tmp_file, compression='snappy', index=False)
        os.replace(tmp_file, output_file)  # atomic rename on the same filesystem

        return True
    except Exception as e:
        print(f"   Error saving chunk {chunk_id}: {e}")
        if tmp_file is not None and os.path.exists(tmp_file):
            try:
                os.remove(tmp_file)
            except OSError:
                pass  # best-effort cleanup; the save error is already reported
        return False

In [None]:
def _load_processed_chunks(processed_dir):
    """Concatenate every *.parquet chunk in processed_dir; return None (with a message) if there are none."""
    parquet_files = sorted(Path(processed_dir).glob('*.parquet'))

    if not parquet_files:
        print("‚ùå No processed files found!")
        print(f"   Expected location: {processed_dir}")
        return None

    print(f"‚úì Found {len(parquet_files)} chunk files")

    print("Loading chunks...")
    dfs = []
    for i, f in enumerate(parquet_files):
        dfs.append(pd.read_parquet(f))
        if (i + 1) % 10 == 0:
            print(f"   Loaded {i+1}/{len(parquet_files)} chunks...")

    print("Concatenating data...")
    full_df = pd.concat(dfs, ignore_index=True)
    print(f"‚úì Total records: {len(full_df):,}")
    return full_df


def _stratify_labels(df):
    """Labels to stratify on, or None when some class has < 2 rows (sklearn would raise)."""
    labels = df['is_suspicious'].values
    counts = pd.Series(labels).value_counts()
    if len(counts) > 0 and counts.min() >= 2:
        return labels
    print("   Warning: a class has fewer than 2 records; using a non-stratified split")
    return None


def _pct(part, whole):
    """Percentage of part in whole, 0.0 for an empty whole."""
    return part / whole * 100 if whole else 0.0


def create_temporal_splits(processed_dir, train_ratio=0.7, val_ratio=0.15,
                           splits_dir=None, metadata_dir=None, random_state=42):
    """
    Create stratified train/val/test splits of all processed chunks.

    Despite the name, the split is random (stratified on 'is_suspicious'),
    not chronological.
    NOTE(review): for event-stream data a random split can put neighbouring
    events of the same activity into different splits. Consider a
    timestamp-based split if temporal leakage matters for the evaluation.

    Args:
        processed_dir: Directory containing the chunk *.parquet files.
        train_ratio: Fraction of records used for training.
        val_ratio: Fraction of records used for validation. The test split
            receives the remaining 1 - train_ratio - val_ratio. The held-out
            part is divided between val and test in that proportion.
        splits_dir: Where train/val/test.parquet are written; defaults to dirs['splits'].
        metadata_dir: Where split_info.json is written; defaults to dirs['metadata'].
        random_state: Seed passed to both train_test_split calls.

    Returns:
        tuple: (train_df, val_df, test_df), or (None, None, None) when no
        processed chunks are found.

    Raises:
        ValueError: If the ratios do not leave a positive share for every split.
    """
    test_ratio = 1 - train_ratio - val_ratio
    if not (0 < train_ratio < 1 and val_ratio > 0 and test_ratio > 1e-9):
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio}, "
            f"test={test_ratio:.4f}; each must be positive and they must sum to 1."
        )

    print(f"\n{'='*60}")
    print("Creating Train/Val/Test Splits (STRATIFIED)")
    print(f"{'='*60}\n")

    full_df = _load_processed_chunks(processed_dir)
    if full_df is None:
        return None, None, None

    from sklearn.model_selection import train_test_split

    print("\nüîÑ Creating stratified splits...")
    print("   This ensures all splits have the same threat distribution!")

    # First split: train vs. held-out (val + test)
    first_stratify = _stratify_labels(full_df)
    train_df, temp_df = train_test_split(
        full_df,
        test_size=(1 - train_ratio),
        stratify=first_stratify,
        random_state=random_state,
    )

    # Second split: divide the held-out part in proportion val_ratio : test_ratio.
    # Rounding removes floating-point noise, so the default 0.15/0.15 gives
    # exactly 0.5 and reproduces the earlier split sizes.
    second_test_size = round(test_ratio / (val_ratio + test_ratio), 10)
    second_stratify = _stratify_labels(temp_df)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=second_test_size,
        stratify=second_stratify,
        random_state=random_state,
    )

    print(f"‚úì Stratified splitting complete!")

    splits_dir = dirs['splits'] if splits_dir is None else splits_dir
    metadata_dir = dirs['metadata'] if metadata_dir is None else metadata_dir

    print(f"\nüíæ Saving splits to {splits_dir}...")
    for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
        split_df.to_parquet(f'{splits_dir}/{split_name}.parquet', index=False)

    train_threats = int(train_df['is_suspicious'].sum())
    val_threats = int(val_df['is_suspicious'].sum())
    test_threats = int(test_df['is_suspicious'].sum())

    train_threat_pct = _pct(train_threats, len(train_df))
    val_threat_pct = _pct(val_threats, len(val_df))
    test_threat_pct = _pct(test_threats, len(test_df))

    stratified = first_stratify is not None and second_stratify is not None
    metadata = {
        'total_records': len(full_df),
        'train_records': len(train_df),
        'val_records': len(val_df),
        'test_records': len(test_df),
        'train_threats': train_threats,
        'val_threats': val_threats,
        'test_threats': test_threats,
        'train_threat_pct': float(train_threat_pct),
        'val_threat_pct': float(val_threat_pct),
        'test_threat_pct': float(test_threat_pct),
        'split_method': 'stratified' if stratified else 'random',
        'random_state': random_state,
        'train_ratio': train_ratio,
        'val_ratio': val_ratio,
        'test_ratio': test_ratio,
    }

    with open(f"{metadata_dir}/split_info.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n{'='*60}")
    print("‚úì Splits Created Successfully!")
    print(f"{'='*60}")
    print(f"\nüìä Dataset Statistics:")
    print(f"   Total Records:     {len(full_df):,}")
    print(f"   Train:             {len(train_df):,} ({train_ratio*100:.1f}%)")
    print(f"   Validation:        {len(val_df):,} ({val_ratio*100:.1f}%)")
    print(f"   Test:              {len(test_df):,} ({test_ratio*100:.1f}%)")

    print(f"\nüö® Suspicious Activity Distribution (STRATIFIED):")
    print(f"   Train:     {train_threats:,}/{len(train_df):,} = {train_threat_pct:.2f}%")
    print(f"   Val:       {val_threats:,}/{len(val_df):,} = {val_threat_pct:.2f}%")
    print(f"   Test:      {test_threats:,}/{len(test_df):,} = {test_threat_pct:.2f}%")

    # Verify stratification worked
    print(f"\n‚úÖ Verification:")
    if abs(train_threat_pct - val_threat_pct) < 0.5 and abs(train_threat_pct - test_threat_pct) < 0.5:
        print(f"   ‚úì All splits have consistent threat distribution!")
        print(f"   ‚úì Difference < 0.5% between splits")
    else:
        print(f"   ‚ö†Ô∏è  Warning: Splits have different distributions!")
        print(f"   ‚ö†Ô∏è  Train-Val diff: {abs(train_threat_pct - val_threat_pct):.2f}%")
        print(f"   ‚ö†Ô∏è  Train-Test diff: {abs(train_threat_pct - test_threat_pct):.2f}%")

    print(f"{'='*60}\n")

    return train_df, val_df, test_df

In [None]:
class DARPAAPTDataset(Dataset):
    """PyTorch Dataset for DARPA TC data optimized for MARL.

    Each item is (state, label, metadata):
    - state: a float32 tensor of the max-scaled feature columns;
    - label: the int64 'is_suspicious' flag;
    - metadata: a dict with 'record_type' and 'timestamp' (as a string).

    Args:
        parquet_file: Split file to load (e.g. train.parquet).
        feature_columns: Numeric columns to use; missing ones are added as 0.
        feature_max: Optional {column: max} scaling statistics. Pass the
            training dataset's `feature_max` when building val/test datasets
            so every split is scaled identically. Columns not listed are
            scaled by this file's own maximum.

    Raises:
        KeyError: If the file has no 'is_suspicious' column.
    """

    DEFAULT_FEATURE_COLUMNS = [
        'sequence', 'src_port', 'dst_port',
        'ip_protocol', 'file_size', 'process_pid', 'process_ppid'
    ]

    def __init__(self, parquet_file, feature_columns=None, feature_max=None):
        print(f"Loading dataset from {parquet_file}...")
        self.df = pd.read_parquet(parquet_file)

        # Define features for MARL state representation
        if feature_columns is None:
            self.feature_columns = list(self.DEFAULT_FEATURE_COLUMNS)
        else:
            self.feature_columns = feature_columns

        # Ensure every feature column exists and has no missing values
        for col in self.feature_columns:
            if col not in self.df.columns:
                self.df[col] = 0
            self.df[col] = self.df[col].fillna(0)

        self.normalize_features(feature_max)

        # Extract the backing arrays once; per-row DataFrame.iloc access in
        # __getitem__ is slow for large splits.
        self._features = self.df[self.feature_columns].to_numpy(dtype=np.float32)
        self._labels = self.df['is_suspicious'].to_numpy().astype(np.int64)

        print(f"‚úì Loaded {len(self.df):,} records")
        print(f"‚úì Using {len(self.feature_columns)} features")

    def normalize_features(self, feature_max=None):
        """Divide each feature column by its scale (`feature_max[col]` if given, else the column max).

        Non-positive (or NaN) scales leave the column unchanged. The scales
        used are stored in `self.feature_max` for reuse on other splits.
        """
        scales = {}
        for col in self.feature_columns:
            if feature_max is not None and col in feature_max:
                scale = feature_max[col]
            else:
                scale = self.df[col].max()
            scales[col] = scale
            if scale > 0:
                self.df[col] = self.df[col] / scale
        self.feature_max = scales

    def _column_value(self, col, idx):
        """Value of `col` at row position `idx`, or '' if the column is absent."""
        return self.df[col].iat[idx] if col in self.df.columns else ''

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        state = torch.tensor(self._features[idx], dtype=torch.float32)
        label = torch.tensor(int(self._labels[idx]), dtype=torch.long)

        metadata = {
            'record_type': self._column_value('record_type', idx),
            'timestamp': str(self._column_value('timestamp', idx))
        }

        return state, label, metadata

def create_data_loader(split_file, batch_size=256, shuffle=True, num_workers=2,
                       pin_memory=None, feature_max=None):
    """Create a DataLoader over one split file for MARL training.

    Args:
        split_file: Parquet split to load.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: DataLoader worker processes.
        pin_memory: Defaults to True only when CUDA is available, since pinning
            only speeds up host-to-GPU copies.
        feature_max: Scaling statistics forwarded to DARPAAPTDataset; pass the
            training dataset's `feature_max` for val/test splits.

    Returns:
        tuple: (loader, dataset)
    """
    dataset = DARPAAPTDataset(split_file, feature_max=feature_max)
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers, pin_memory=pin_memory)

    print(f"‚úì DataLoader created: {len(loader)} batches")
    return loader, dataset


## Modify `run_complete_pipeline`

### Subtask:
Adjust the `run_complete_pipeline` function to only perform the processing step using `process_darpa_avro_streaming` and save the processed chunks, without proceeding to create splits or dataloaders.


In [None]:
def run_complete_pipeline(raw_file_path, test_mode=False, start_chunk_id=0,
                          chunk_size=50000, test_max_records=100000):
    """
    Execute the data processing step for a single raw DARPA TC file.

    Only the Avro processing step runs here; splits and DataLoaders are built
    later from the combined contents of the processed directory.

    Args:
        raw_file_path: Path to a DARPA TC .bin or .bin.gz file.
        test_mode: If True, pass ``test_max_records`` to the processor as its
            ``max_records`` cap; otherwise no cap (``None``) is passed.
        start_chunk_id: Starting ID for chunk numbering, so chunks from
            different raw files do not overwrite each other.
        chunk_size: Forwarded as ``chunk_size`` to
            ``process_darpa_avro_streaming``.
        test_max_records: Record cap used when ``test_mode`` is True.

    Returns:
        tuple: (num_chunks, num_records) processed from the file,
               or (0, 0) if no records were processed.
    """
    print("DARPA TC DATA PROCESSING STEP")

    max_records = test_max_records if test_mode else None

    # Step 1: Process raw Avro logs
    print("STEP 1: Processing Avro binary logs...")
    num_chunks, num_records = process_darpa_avro_streaming(
        raw_file_path,
        chunk_size=chunk_size,
        max_records=max_records,
        start_chunk_id=start_chunk_id,
    )

    # A file that yields no records is reported as a failure; callers add
    # (0, 0) to their running totals and move on to the next file.
    if num_records == 0:
        print(" Processing failed for this file.")
        return 0, 0

    print("PROCESSING STEP COMPLETE FOR THIS FILE!")

    return num_chunks, num_records

## Iterate and process

### Subtask:
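Loop through each raw file in the raw directory (skipping directories such as `.ipynb_checkpoints`) and call `run_complete_pipeline` on each file individually, accumulating the number of chunks and records it produces. A running chunk counter is passed as `start_chunk_id` so that chunk files written for different raw files do not overwrite each other.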


In [None]:
import re

# Set to False to process every record of every raw file. True applies
# run_complete_pipeline's test-mode record cap to each file.
TEST_MODE = True


def _natural_sort_key(path):
    """Sort key that orders embedded numbers numerically (bin.2 < bin.10)."""
    return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', path.name)]


# Process files in a deterministic order so chunk IDs map to the same raw
# files on every run (the directory listing is not sorted: .2 came before .1).
raw_files = sorted(list_raw_files() or [], key=_natural_sort_key)

# NOTE(review): assumes chunks are written to dirs['processed'], the directory
# create_temporal_splits reads. Files left from an earlier run that this run
# does not overwrite would be mixed into the splits.
processed_dir = Path(dirs['processed'])
leftover_files = [p for p in processed_dir.glob('*') if p.is_file()] if processed_dir.is_dir() else []
if leftover_files:
    print(f"WARNING: {len(leftover_files)} file(s) already exist in {processed_dir}; "
          "any not overwritten by this run will still be used when creating splits.")

# Initialize variables
total_chunks_processed = 0
total_records_processed = 0

# Running chunk ID handed to each file so chunk files never collide
global_chunk_counter = 0

for raw_file in raw_files:
    # Skip directory entries such as .ipynb_checkpoints
    if raw_file.is_dir():
        print(f"Skipping directory: {raw_file.name}")
        continue

    print(f"\n{'='*80}")
    print(f"PROCESSING FILE: {raw_file.name}")
    print(f"{'='*80}\n")

    num_chunks, num_records = run_complete_pipeline(
        str(raw_file), test_mode=TEST_MODE, start_chunk_id=global_chunk_counter
    )

    global_chunk_counter += num_chunks
    total_chunks_processed += num_chunks
    total_records_processed += num_records

    # Free up memory before the next file
    gc.collect()

print("\nFinished processing all raw files.")
print(f"Total chunks processed across all files: {total_chunks_processed}")
print(f"Total records processed across all files: {total_records_processed:,}")


‚úì Found 6 file(s) in raw directory:
   1. .ipynb_checkpoints (0.00 MB)
   2. ta1-trace-1-e5-official-1.bin.2.gz (270.65 MB)
   3. ta1-trace-1-e5-official-1.bin.1.gz (266.57 MB)
   4. ta1-trace-1-e5-official-1.bin.4.gz (272.90 MB)
   5. ta1-trace-1-e5-official-1.bin.3.gz (273.70 MB)
   6. ta1-trace-1-e5-official-1.bin.5.gz (259.57 MB)
Skipping directory: .ipynb_checkpoints

PROCESSING FILE: ta1-trace-1-e5-official-1.bin.2.gz

DARPA TC DATA PROCESSING STEP
STEP 1: Processing Avro binary logs...

Processing: ta1-trace-1-e5-official-1.bin.2.gz

Reading Avro file: ta1-trace-1-e5-official-1.bin.2.gz
Compressed: True
   Processed: 50,000 records | Rate: 34182 records/sec | Chunks saved: 1
   Processed: 100,000 records | Rate: 34640 records/sec | Chunks saved: 2

‚úì Processing Complete!
  Total records: 100,000
  Chunks created in this file: 2
  Errors: 0
  Time: 0.05 minutes
  Avg rate: 34632 records/sec

PROCESSING STEP COMPLETE FOR THIS FILE!

PROCESSING FILE: ta1-trace-1-e5-official-1.

## Create combined splits

### Subtask:
Create a single set of train, validation, and test splits from the combined processed data stored in the `processed` directory.


In [None]:
# Step 3: Create splits from the combined processed data
print("\nSTEP 3: Creating train/val/test splits from combined data...")
combined_train_df, combined_val_df, combined_test_df = create_temporal_splits(dirs['processed'])

# Fail loudly: the DataLoader cell reads split files from dirs['splits'] on
# disk, so continuing after a failure could silently load splits left over
# from an earlier run.
if any(split_df is None for split_df in (combined_train_df, combined_val_df, combined_test_df)):
    raise RuntimeError("Pipeline failed: Could not create combined splits")



STEP 3: Creating train/val/test splits from combined data...

Creating Train/Val/Test Splits

Found 10 chunk files
Loading chunks...
   Loaded 10/10 chunks...
Concatenating data...
Total records: 500,000
Sorting by timestamp...
Splitting data...
Saving splits to /content/drive/MyDrive/mythesis/vicky/darpa_tc/splits...

‚úì Splits Created Successfully!

üìä Dataset Statistics:
   Total Records:     500,000
   Train:             350,000 (70.0%)
   Validation:        75,000 (15.0%)
   Test:              75,000 (15.0%)

üïê Time Ranges:
   Train:     1970-01-01 00:00:00 ‚Üí 2019-05-07 20:01:38.246000
   Val:       2019-05-07 20:01:38.246000 ‚Üí 2019-05-07 20:01:45.530000
   Test:      2019-05-07 20:01:45.530000 ‚Üí 2019-05-07 21:54:19.668000

üö® Suspicious Activity Rates:
   Train:     9.97%
   Val:       6.09%
   Test:      20.22%



## Create combined dataloaders

### Subtask:
Create the PyTorch DataLoaders from the combined split files.


In [None]:
# Step 4: Build one DataLoader per split file; only the training split is shuffled.
print("\nSTEP 4: Creating PyTorch DataLoaders from combined splits...")

split_settings = [('train', True), ('val', False), ('test', False)]
split_loaders = {}
for split_name, shuffle_split in split_settings:
    split_loaders[split_name], _ = create_data_loader(
        f"{dirs['splits']}/{split_name}.parquet", batch_size=256, shuffle=shuffle_split
    )

combined_train_loader = split_loaders['train']
combined_val_loader = split_loaders['val']
combined_test_loader = split_loaders['test']

print("DATALOADER CREATION COMPLETE!")


STEP 4: Creating PyTorch DataLoaders from combined splits...
Loading dataset from /content/drive/MyDrive/mythesis/vicky/darpa_tc/splits/train.parquet...
‚úì Loaded 350,000 records
‚úì Using 7 features
‚úì DataLoader created: 1368 batches
Loading dataset from /content/drive/MyDrive/mythesis/vicky/darpa_tc/splits/val.parquet...
‚úì Loaded 75,000 records
‚úì Using 7 features
‚úì DataLoader created: 293 batches
Loading dataset from /content/drive/MyDrive/mythesis/vicky/darpa_tc/splits/test.parquet...
‚úì Loaded 75,000 records
‚úì Using 7 features
‚úì DataLoader created: 293 batches
DATALOADER CREATION COMPLETE!


## Summary:

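*   Five raw data files were processed sequentially; the `.ipynb_checkpoints` directory in the raw folder was skipped.
*   The run used test mode, so each file was capped at 100,000 records, producing 2 chunks of 50,000 records per file.
*   A total of 10 chunks and 500,000 records were processed across all files.
*   The combined processed data was sorted by timestamp and split temporally into train (350,000 records, 70%), validation (75,000 records, 15%), and test (75,000 records, 15%) sets.
*   PyTorch DataLoaders were created for the combined train, validation, and test splits with a batch size of 256.

**Caveats to review before training:**

*   The training split starts at 1970-01-01 00:00:00 (the Unix epoch), which suggests some records have a zero or missing timestamp.
*   The sampled data covers only short time windows, likely because each file was capped. For example, the validation split spans about 7 seconds (20:01:38 to 20:01:45 on 2019-05-07).
*   The suspicious-activity rate varies widely between splits (train 9.97%, validation 6.09%, test 20.22%).
*   Each split's features are normalized by that split's own per-column maximum. The same raw value can therefore map to different normalized values in train, validation, and test.

*   The created DataLoaders (`combined_train_loader`, `combined_val_loader`, `combined_test_loader`) are now ready for use in training and evaluating our proposed model.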
