# In [None]:
def generate_final_results_table(self):
        """Print and save a comparison table covering every evaluated model.

        Reads ``self.results`` as ``{config_name: {model_key: metrics}}`` and
        builds one row per model whose key appears in ``MODELS_CONFIG``. It then:

        * prints the table, with rows without additional features first and each
          group ordered by FPR+FNR;
        * prints the best model per group and a few derived insights;
        * writes the table to ``final_results_table_<timestamp>.csv`` in the
          current working directory.

        NOTE(review): the "Avg F1" column shows the ``f1_score`` metric. The
        baseline and BiLSTM trainers in this file compute it with
        ``average='binary'``, i.e. the phishing-class F1, not an average.

        Returns:
            pandas.DataFrame: the displayed table with string-formatted metrics.
            When there are no results it is empty but still has the columns.
        """
        print("\n📊 FINAL COMPREHENSIVE RESULTS TABLE")
        print("=" * 120)

        columns = ['Features', 'Model', 'Accuracy', 'Legit F1', 'Phish F1', 'Avg F1',
                   'FPR', 'FNR', 'FPR+FNR', 'Time(s)']

        def uses_additional_features(config_name):
            # "Without" contains "With", so a plain substring test would label
            # "Without ..." configurations as using features. Rule it out first.
            lowered = str(config_name).lower()
            return 'with' in lowered and 'without' not in lowered

        # Collect all results in a structured format
        all_results = []
        for config_name, models in self.results.items():
            features_used = "Yes" if uses_additional_features(config_name) else "No"

            for model_name, metrics in models.items():
                if model_name not in MODELS_CONFIG:
                    continue
                fpr = metrics['false_positive_rate']
                fnr = metrics['false_negative_rate']
                all_results.append({
                    'Features': features_used,
                    'Model': MODELS_CONFIG[model_name]['name'],
                    'Accuracy': f"{metrics['accuracy']:.3f}",
                    'Legit F1': f"{metrics['legitimate_f1']:.3f}",
                    'Phish F1': f"{metrics['phishing_f1']:.3f}",
                    'Avg F1': f"{metrics['f1_score']:.3f}",
                    'FPR': f"{fpr:.3f}",
                    'FNR': f"{fnr:.3f}",
                    'FPR+FNR': f"{fpr + fnr:.3f}",
                    'Time(s)': f"{metrics['train_time']:.1f}",
                })

        # Explicit columns keep every column lookup below valid even when no
        # results were collected.
        df_final = pd.DataFrame(all_results, columns=columns)

        # Sort by Features (No first) and then by combined error rate
        df_final['sort_features'] = df_final['Features'].map({'No': 0, 'Yes': 1})
        df_final['sort_error'] = df_final['FPR+FNR'].astype(float)
        df_final = df_final.sort_values(['sort_features', 'sort_error']).drop(['sort_features', 'sort_error'], axis=1)

        print(df_final.to_string(index=False))

        # Best models summary, one section per feature setting
        print("\n🏆 BEST MODELS BY METRIC")
        print("-" * 80)

        for flag, heading in (('No', "\nWithout Additional Features:"),
                              ('Yes', "\nWith Additional Features:")):
            group = df_final[df_final['Features'] == flag]
            if group.empty:
                continue
            print(heading)
            best_f1 = group.loc[group['Avg F1'].astype(float).idxmax()]
            print(f"  Best F1-Score: {best_f1['Model']} ({best_f1['Avg F1']})")

            best_balanced = group.loc[group['FPR+FNR'].astype(float).idxmin()]
            print(f"  Best Balanced: {best_balanced['Model']} (FPR+FNR: {best_balanced['FPR+FNR']})")

        # Key insights
        print("\n📈 KEY PERFORMANCE INSIGHTS")
        print("-" * 80)

        # Compare feature impact for every configured model present in both settings
        print("\n1. Impact of Additional Features:")
        for model_cfg in MODELS_CONFIG.values():
            model = model_cfg['name']
            no_feat = df_final[(df_final['Features'] == 'No') & (df_final['Model'] == model)]
            yes_feat = df_final[(df_final['Features'] == 'Yes') & (df_final['Model'] == model)]
            if no_feat.empty or yes_feat.empty:
                continue

            diff = float(yes_feat.iloc[0]['Avg F1']) - float(no_feat.iloc[0]['Avg F1'])
            if diff != 0:  # Only show if there's a difference
                change = "improvement" if diff > 0 else "decrease"
                print(f"   {model}: {diff:+.3f} F1 {change}")

        # Speed vs Performance trade-off
        print("\n2. Speed vs Performance Trade-off:")
        print("   Fast (<1s) + High Performance (F1>0.97):")
        fast_good = df_final[(df_final['Time(s)'].astype(float) < 1) & (df_final['Avg F1'].astype(float) > 0.97)]
        for _, row in fast_good.iterrows():
            print(f"   - {row['Model']} with {row['Features']} features: F1={row['Avg F1']}, Time={row['Time(s)']}s")

        # Class balance analysis
        print("\n3. Class Balance Analysis (FPR vs FNR):")
        for _, row in df_final.iterrows():
            if abs(float(row['FPR']) - float(row['FNR'])) < 0.01:  # Well balanced
                print(f"   Well balanced: {row['Model']} with {row['Features']} features (FPR={row['FPR']}, FNR={row['FNR']})")

        # Save enhanced results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_csv = f"final_results_table_{timestamp}.csv"
        df_final.to_csv(final_csv, index=False)
        print(f"\n💾 Final results table saved to: {final_csv}")

        return df_final
#!/usr/bin/env python3
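"""
Multi-Model Phishing Detection Evaluation Framework

Mirrors the sampling strategy used in fine-tuning:
- 3,000 examples per class (6,000 total)
- a 1,000-example stratified validation split; the remaining examples (5,000
  when both classes have at least 3,000 rows) are used for training.
  The fine-tuning run itself reported 4,730 training examples, so the training
  set here is not an exact match.

Uses a CUDA GPU when one is available, for faster processing.
"""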

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizer, BertForSequenceClassification,
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
import xgboost as xgb
import time
from datetime import datetime
import json
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings('ignore')

# ==========================================
# CONFIGURATION - MATCHING FINE-TUNING
# ==========================================

# Location of the pre-processed phishing dataset (CSV).
DATASET_PATH = "/content/drive/MyDrive/phishing_detection_final/output/final_datasets/o3_mini_optimized_dataset_20250709_185906.csv"

# Sampling parameters mirroring the fine-tuning run.
SAMPLES_PER_CLASS = 3000   # examples drawn from each class
VALIDATION_SIZE = 1000     # held-out validation examples
MAX_TOKEN_LENGTH = 1591    # 95th-percentile token length from the dataset statistics (informational)

# Registry of every model this framework knows about. Each entry has:
#   name       - display name used in reports
#   type       - model family: 'sklearn', 'xgboost', 'transformer' or 'lstm'
#   model_name - (transformers only) checkpoint passed to from_pretrained
MODELS_CONFIG = {
    'baseline_lr': dict(name='Logistic Regression Baseline', type='sklearn'),
    'baseline_rf': dict(name='Random Forest Baseline', type='sklearn'),
    'baseline_xgb': dict(name='XGBoost Baseline', type='xgboost'),
    'bert': dict(name='BERT Base Uncased', model_name='bert-base-uncased', type='transformer'),
    'secbert': dict(name='SecBERT (jackaduma)', model_name='jackaduma/SecBERT', type='transformer'),
    'securebert': dict(name='SecureBERT (ehsanaghaei)', model_name='ehsanaghaei/SecureBERT', type='transformer'),
    'cysecbert': dict(name='CySecBERT (markusbayer)', model_name='markusbayer/CySecBERT', type='transformer'),
    'lstm': dict(name='BiLSTM', type='lstm'),
}

# Cybersecurity-domain transformers (MODELS_CONFIG keys) to train alongside
# plain BERT; extend the list to evaluate more of them.
CYBER_MODELS_TO_USE = ['secbert']

# Training hyper-parameters
BATCH_SIZE = 16
MAX_LENGTH = 512  # sequence cap for transformer and BiLSTM inputs
EPOCHS = 3
LEARNING_RATE = 2e-5
RANDOM_STATE = 42

# Use a CUDA GPU when one is visible, otherwise the CPU; report the choice.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Using device: {DEVICE}")
if DEVICE.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# ==========================================
# DATASET HANDLING WITH MATCHED SAMPLING
# ==========================================

class PhishingDataset(Dataset):
    """Map-style dataset that tokenizes one email per item.

    Each item is a dict holding:

    * ``input_ids`` and ``attention_mask``: flattened tensors from the
      tokenizer, called with truncation and ``padding='max_length'``;
    * ``labels``: the label as a long tensor;
    * ``additional_features``: a float tensor, present only when
      ``additional_features`` was given.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512, additional_features=None):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.additional_features = additional_features

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        enc = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        sample = dict(
            input_ids=enc['input_ids'].flatten(),
            attention_mask=enc['attention_mask'].flatten(),
            labels=torch.tensor(target, dtype=torch.long),
        )

        extra = self.additional_features
        if extra is not None:
            sample['additional_features'] = torch.tensor(extra[idx], dtype=torch.float)

        return sample

# ==========================================
# MODEL DEFINITIONS (Same as before)
# ==========================================

class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM text classifier with optional extra dense features.

    Token ids are embedded and run through a stacked bidirectional LSTM. The
    sequence representation is the final hidden state of the top layer in
    each direction, computed over the non-padding tokens only. It is
    optionally concatenated with ``additional_features``, then passed through
    dropout and a linear output layer.

    Args:
        vocab_size: number of token ids (embedding rows).
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (only when num_layers > 1) and
            before the output layer.
        num_classes: number of output logits.
        additional_features_dim: width of the optional extra feature vector.
        pad_token_id: id used for padding. Sequence lengths are inferred from
            it when no ``attention_mask`` is given.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 num_layers=2, dropout=0.3, num_classes=2,
                 additional_features_dim=0, pad_token_id=0):
        super(BiLSTMClassifier, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

        lstm_output_dim = hidden_dim * 2
        final_dim = lstm_output_dim + additional_features_dim

        self.fc = nn.Linear(final_dim, num_classes)
        self.additional_features_dim = additional_features_dim
        self.pad_token_id = pad_token_id

    def forward(self, input_ids, attention_mask=None, additional_features=None):
        """Return ``(batch, num_classes)`` logits.

        Args:
            input_ids: ``(batch, seq_len)`` long tensor, right-padded: each
                row's real tokens come first, followed by padding.
            attention_mask: optional ``(batch, seq_len)`` tensor with 1 for real
                tokens. When omitted, tokens equal to ``pad_token_id`` are
                treated as padding.
            additional_features: optional ``(batch, additional_features_dim)``
                float tensor; ignored when ``additional_features_dim`` is 0.
        """
        if attention_mask is None:
            attention_mask = input_ids != self.pad_token_id
        # Packing needs at least one step per row; an all-padding row is read
        # as a single token.
        lengths = attention_mask.long().sum(dim=1).clamp(min=1)

        embedded = self.embedding(input_ids)

        # Pack so both directions stop at each sequence's real end. Taking
        # lstm_out[:, -1, :] instead would read the last *padded* position:
        # the forward state would have run over the padding, and the backward
        # state would have seen only padding.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)

        # h_n is (num_layers * 2, batch, hidden_dim), ordered layer by layer
        # (forward, backward). The last two entries are the top layer's final states.
        last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)

        if additional_features is not None and self.additional_features_dim > 0:
            last_hidden = torch.cat([last_hidden, additional_features], dim=1)

        dropped = self.dropout(last_hidden)
        output = self.fc(dropped)

        return output

class EnhancedBERTClassifier(nn.Module):
    """BERT sequence classifier whose head can also see extra numeric features.

    Wraps ``AutoModelForSequenceClassification``. When
    ``additional_features_dim > 0``, the wrapped model's ``classifier`` layer is
    replaced by a ``Linear`` over ``[pooled_output, additional_features]``.

    NOTE(review): ``forward`` reads the encoder through ``self.bert.bert``, so
    this class assumes a BERT-architecture checkpoint.

    Args:
        model_name: checkpoint name/path passed to ``from_pretrained``.
        num_classes: number of output logits.
        additional_features_dim: width of the extra feature vector (0 = none).
    """

    def __init__(self, model_name, num_classes=2, additional_features_dim=0):
        super(EnhancedBERTClassifier, self).__init__()

        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_classes
        )

        self.additional_features_dim = additional_features_dim

        if additional_features_dim > 0:
            bert_hidden_size = self.bert.config.hidden_size
            self.bert.classifier = nn.Linear(
                bert_hidden_size + additional_features_dim,
                num_classes
            )

    def forward(self, input_ids, attention_mask=None, additional_features=None):
        """Return ``(batch, num_classes)`` logits.

        Raises:
            ValueError: if the model was built with ``additional_features_dim > 0``
                but ``additional_features`` is not supplied. The classifier's
                input width would not match the pooled output alone.
        """
        if self.additional_features_dim > 0 and additional_features is None:
            raise ValueError(
                f"EnhancedBERTClassifier was built with additional_features_dim="
                f"{self.additional_features_dim} but forward() received no additional_features"
            )

        outputs = self.bert.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        pooled_output = outputs.pooler_output

        # BertForSequenceClassification applies its `dropout` to the pooled
        # output before `classifier`. Calling the encoder and classifier
        # directly skipped it, so the head was trained without dropout.
        head_dropout = getattr(self.bert, 'dropout', None)
        if head_dropout is not None:
            pooled_output = head_dropout(pooled_output)

        if self.additional_features_dim > 0:
            pooled_output = torch.cat([pooled_output, additional_features], dim=1)

        logits = self.bert.classifier(pooled_output)

        return logits

# ==========================================
# EVALUATION FRAMEWORK WITH MATCHED SAMPLING
# ==========================================

class MatchedSamplingEvaluator:
    """Evaluation framework matching fine-tuning sampling strategy"""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.df = None
        self.results = {}
        self.additional_feature_columns = []

    def load_and_prepare_data(self):
        """Load the CSV and draw the same class-balanced sample as fine-tuning.

        Steps:

        1. Read ``self.dataset_path``.
        2. Record every column other than ``subject``/``body``/
           ``original_label``/``Unnamed: 0`` as a candidate additional feature.
        3. Build ``combined_text`` = subject + ' ' + body.
        4. Drop rows with empty text or a missing label.
        5. Sample up to ``SAMPLES_PER_CLASS`` rows per label (0 = legitimate,
           1 = phishing) with ``RANDOM_STATE``, then shuffle.

        The result replaces ``self.df``. Text lengths are reported for
        information only; no rows are filtered on length.

        Returns:
            bool: True once the data is ready.
        """
        print("📂 LOADING AND SAMPLING DATASET (Matching Fine-Tuning)")
        print("-" * 60)

        # Load full dataset
        self.df = pd.read_csv(self.dataset_path)
        print(f"✅ Dataset loaded: {len(self.df):,} total rows")

        # Identify columns
        base_columns = ['subject', 'body', 'original_label']
        all_columns = list(self.df.columns)

        # Every other column is a candidate extra feature. 'Unnamed: 0' is
        # excluded; it is typically a saved index column.
        self.additional_feature_columns = [
            col for col in all_columns
            if col not in base_columns and col != 'Unnamed: 0'
        ]

        print(f"\n📊 Column Analysis:")
        print(f"   Base columns: {base_columns}")
        print(f"   Additional columns: {len(self.additional_feature_columns)}")

        # Create combined text
        self.df['combined_text'] = self.df['subject'].fillna('') + ' ' + self.df['body'].fillna('')

        # Clean data
        self.df = self.df[self.df['combined_text'].str.strip() != '']
        self.df = self.df[self.df['original_label'].notna()]

        print(f"\n📊 Clean examples: {len(self.df):,}")

        # Split by class
        legitimate_df = self.df[self.df['original_label'] == 0]
        phishing_df = self.df[self.df['original_label'] == 1]

        print(f"   Legitimate: {len(legitimate_df):,}")
        print(f"   Phishing: {len(phishing_df):,}")

        # Sample same as fine-tuning: 3,000 per class
        print(f"\n📝 Sampling {SAMPLES_PER_CLASS:,} examples per class")
        for class_name, class_df in (("legitimate", legitimate_df), ("phishing", phishing_df)):
            if len(class_df) < SAMPLES_PER_CLASS:
                print(f"   ⚠️ Only {len(class_df):,} {class_name} examples available; "
                      f"using all of them (fewer than the {SAMPLES_PER_CLASS:,} used in fine-tuning)")

        # Sample with same random state for reproducibility
        legitimate_sample = legitimate_df.sample(
            n=min(SAMPLES_PER_CLASS, len(legitimate_df)),
            random_state=RANDOM_STATE
        )
        phishing_sample = phishing_df.sample(
            n=min(SAMPLES_PER_CLASS, len(phishing_df)),
            random_state=RANDOM_STATE
        )

        # Combine samples and shuffle
        self.df = pd.concat([legitimate_sample, phishing_sample])
        self.df = self.df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

        print(f"\n✅ Total examples: {len(self.df):,}")

        # Show final distribution
        class_dist = self.df['original_label'].value_counts().sort_index()
        print(f"\n📊 Final Class Distribution:")
        print(f"   Legitimate (0): {class_dist.get(0, 0):,}")
        print(f"   Phishing (1):   {class_dist.get(1, 0):,}")

        # Informational length report. No rows are filtered. Whitespace word
        # counts are only a rough proxy for tokenizer lengths.
        print(f"\n🔍 Checking text lengths (whitespace words, for information only)")
        word_counts = self.df['combined_text'].str.split().str.len()
        if not word_counts.empty:
            print(f"   Words per example - median: {word_counts.median():.0f}, "
                  f"95th percentile: {word_counts.quantile(0.95):.0f} "
                  f"(reference 95th-percentile token length: {MAX_TOKEN_LENGTH:,})")
            over_limit = int((word_counts > MAX_LENGTH).sum())
            print(f"   Examples with more than {MAX_LENGTH} words: {over_limit:,} "
                  f"({over_limit / len(word_counts) * 100:.1f}%)")

        return True

    def prepare_features(self, use_additional_features=True):
        """Return the texts, optional standardized numeric features and labels.

        Args:
            use_additional_features: when True, the numeric columns among
                ``self.additional_feature_columns`` are returned as a
                standardized 2-D array. Otherwise, or when there are no such
                columns, the second element is ``None``.

        Returns:
            tuple: ``(X_text, X_additional, y)``, where:

            * ``X_text`` is an array of combined texts;
            * ``X_additional`` is an ``(n_samples, n_features)`` float array
              or ``None``;
            * ``y`` is an int label array.
        """
        X_text = self.df['combined_text'].values
        y = self.df['original_label'].values.astype(int)

        X_additional = None
        if use_additional_features and self.additional_feature_columns:
            # Select numeric columns only
            numeric_features = self.df[self.additional_feature_columns].select_dtypes(
                include=[np.number]
            ).columns.tolist()

            if numeric_features:
                # Cast to float, with pandas NA mapped to NaN, so the cleaning
                # below also works for integer and nullable columns.
                X_additional = self.df[numeric_features].to_numpy(dtype=float, na_value=np.nan)

                # Replace missing values with 0. np.nan_to_num's second
                # positional parameter is `copy`, not the fill value, so the
                # fills are passed by keyword. Infinities are zeroed too;
                # otherwise they become the largest finite float and overflow
                # StandardScaler's variance.
                X_additional = np.nan_to_num(X_additional, nan=0.0, posinf=0.0, neginf=0.0)

                # NOTE(review): the scaler is fit on all rows before the
                # train/validation split, so validation statistics leak into
                # the scaling; consider fitting on the training split only.
                scaler = StandardScaler()
                X_additional = scaler.fit_transform(X_additional)

                print(f"   Using {X_additional.shape[1]} additional features")

        return X_text, X_additional, y

    def create_train_val_split(self, X_text, X_additional, y):
        """Split the data into a stratified train/validation pair.

        Exactly ``VALIDATION_SIZE`` examples go to validation and the rest go to
        training, stratified on ``y`` with ``RANDOM_STATE``.

        Args:
            X_text: array of texts, indexable by an integer index array.
            X_additional: 2-D feature array aligned with ``X_text``, or ``None``.
            y: label array.

        Returns:
            tuple: ``(X_train_text, X_val_text, y_train, y_val, X_train_add,
            X_val_add)``. The last two are ``None`` when ``X_additional`` is
            ``None``.

        Raises:
            ValueError: if there are not more samples than ``VALIDATION_SIZE``.
        """
        total_samples = len(y)
        val_size = VALIDATION_SIZE
        train_size = total_samples - val_size

        if train_size <= 0:
            raise ValueError(
                f"Need more than {val_size:,} samples for a {val_size:,}-example "
                f"validation split, got {total_samples:,}"
            )

        print(f"\n📊 CREATING TRAIN/VALIDATION SPLIT:")
        print(f"   Total samples: {total_samples:,}")
        print(f"   Training samples: {train_size:,}")
        print(f"   Validation samples: {val_size:,}")

        indices = np.arange(total_samples)
        train_idx, val_idx = train_test_split(
            indices,
            # An absolute count yields exactly val_size validation rows. A
            # fraction is rounded up (ceil) by scikit-learn and can be off by
            # one due to floating-point error.
            test_size=val_size,
            random_state=RANDOM_STATE,
            stratify=y
        )

        X_train_text = X_text[train_idx]
        X_val_text = X_text[val_idx]

        y_train = y[train_idx]
        y_val = y[val_idx]

        if X_additional is not None:
            X_train_add = X_additional[train_idx]
            X_val_add = X_additional[val_idx]
        else:
            X_train_add = None
            X_val_add = None

        # Verify class distribution of each split
        for split_name, labels in (("Training", y_train), ("Validation", y_val)):
            print(f"\n   {split_name} set class distribution:")
            unique, counts = np.unique(labels, return_counts=True)
            for label, count in zip(unique, counts):
                print(f"      Class {label}: {count:,} ({count/len(labels)*100:.1f}%)")

        return (X_train_text, X_val_text, y_train, y_val, X_train_add, X_val_add)

    def train_baseline_models(self, X_train_text, X_val_text, y_train, y_val,
                            X_train_add, X_val_add, use_additional_features):
        """Train and evaluate the TF-IDF baselines (Logistic Regression, Random Forest, XGBoost).

        A ``TfidfVectorizer`` (up to 5,000 uni/bi-gram features) is fit on the
        training texts only. When requested and available, the extra numeric
        features are appended column-wise. Each baseline that also appears in
        ``MODELS_CONFIG`` is fit on the training split and scored on the
        validation split.

        Returns:
            tuple: ``(results, vectorizer)``. ``results`` maps model key to a
            metrics dict; ``vectorizer`` is the fitted TF-IDF vectorizer.
        """
        print("\n🤖 TRAINING BASELINE MODELS")
        print("-" * 50)

        # Create TF-IDF features
        from sklearn.feature_extraction.text import TfidfVectorizer

        vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2))
        X_train_tfidf = vectorizer.fit_transform(X_train_text)
        X_val_tfidf = vectorizer.transform(X_val_text)

        # Combine features if needed
        if use_additional_features and X_train_add is not None:
            from scipy.sparse import hstack
            # hstack returns COO by default. Ask for CSR so both branches hand
            # the models the same format TfidfVectorizer returns.
            X_train_combined = hstack([X_train_tfidf, X_train_add], format='csr')
            X_val_combined = hstack([X_val_tfidf, X_val_add], format='csr')
        else:
            X_train_combined = X_train_tfidf
            X_val_combined = X_val_tfidf

        results = {}

        models = {
            'baseline_lr': LogisticRegression(max_iter=1000, random_state=RANDOM_STATE, class_weight='balanced'),
            'baseline_rf': RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE, class_weight='balanced'),
            'baseline_xgb': xgb.XGBClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=RANDOM_STATE,
                # NOTE(review): XGBoost 2.0 deprecates 'gpu_hist' in favour of
                # tree_method='hist', device='cuda'; check the installed version.
                tree_method='gpu_hist' if torch.cuda.is_available() else 'auto',
                scale_pos_weight=1,
                objective='binary:logistic',
                use_label_encoder=False,
                eval_metric='logloss'
            )
        }

        for model_key, model in models.items():
            if model_key not in MODELS_CONFIG:
                continue

            print(f"\n📊 Training {MODELS_CONFIG[model_key]['name']}...")

            start_time = time.time()
            model.fit(X_train_combined, y_train)
            train_time = time.time() - start_time

            # Predict on validation set
            y_pred = model.predict(X_val_combined)

            # Calculate metrics (per class, then for the positive/phishing class)
            accuracy = accuracy_score(y_val, y_pred)
            precision, recall, f1, support = precision_recall_fscore_support(
                y_val, y_pred, average=None, labels=[0, 1]
            )
            precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(
                y_val, y_pred, average='binary'
            )
            # Explicit labels guarantee a 2x2 matrix, so the 4-way unpack below
            # holds even if a class is missing from both y_val and y_pred.
            cm = confusion_matrix(y_val, y_pred, labels=[0, 1])

            tn, fp, fn, tp = cm.ravel()
            fpr = fp / (fp + tn) if (fp + tn) > 0 else 0  # False Positive Rate
            fnr = fn / (fn + tp) if (fn + tp) > 0 else 0  # False Negative Rate

            results[model_key] = {
                'accuracy': accuracy,
                'precision': precision_avg,
                'recall': recall_avg,
                'f1_score': f1_avg,
                'confusion_matrix': cm.tolist(),
                # Class-wise metrics
                'legitimate_precision': precision[0],
                'legitimate_recall': recall[0],
                'legitimate_f1': f1[0],
                'legitimate_support': int(support[0]),
                'phishing_precision': precision[1],
                'phishing_recall': recall[1],
                'phishing_f1': f1[1],
                'phishing_support': int(support[1]),
                # Error rates
                'false_positive_rate': fpr,
                'false_negative_rate': fnr,
                'true_negatives': int(tn),
                'false_positives': int(fp),
                'false_negatives': int(fn),
                'true_positives': int(tp),
                'train_time': train_time,
                'val_size': len(y_val),
                'train_size': len(y_train)
            }

            print(f"   Accuracy: {accuracy:.3f}")
            print(f"   F1-Score: {f1_avg:.3f}")
            print(f"   Training time: {train_time:.1f}s")

            # Print class-wise summary
            print(f"   Legitimate - Precision: {precision[0]:.3f}, Recall: {recall[0]:.3f}")
            print(f"   Phishing   - Precision: {precision[1]:.3f}, Recall: {recall[1]:.3f}")
            print(f"   FPR: {fpr:.3f}, FNR: {fnr:.3f}")

            # Flag a suspiciously weak XGBoost run
            if model_key == 'baseline_xgb' and accuracy < 0.6:
                print(f"   ⚠️ WARNING: XGBoost accuracy is suspiciously low ({accuracy:.3f})")
                print("      This might indicate a configuration issue.")
                print("      Check: learning rate, n_estimators, or data preprocessing")

        return results, vectorizer

    def train_lstm_model(self, X_train_text, X_val_text, y_train, y_val,
                        X_train_add, X_val_add, use_additional_features):
        """Train and evaluate the BiLSTM baseline.

        Steps:

        1. Build a lowercase whitespace-token vocabulary from the training
           texts only: ``<PAD>``=0, ``<UNK>``=1, plus the 10,000 most common
           words.
        2. Truncate and right-pad every text to exactly ``MAX_LENGTH`` ids.
        3. Train ``BiLSTMClassifier`` for ``EPOCHS`` epochs with Adam.
        4. Evaluate on the validation split in mini-batches.

        Args:
            X_train_text, X_val_text: sequences of raw texts.
            y_train, y_val: integer label arrays (0 = legitimate, 1 = phishing).
            X_train_add, X_val_add: optional 2-D extra-feature arrays aligned
                with the texts, or ``None``.
            use_additional_features: feed the extra features to the model when
                True and they are available.

        Returns:
            dict: ``{'lstm': metrics}``, with the same metric keys as
            ``train_baseline_models``.
        """
        print("\n🔤 TRAINING BiLSTM MODEL")
        print("-" * 50)

        from collections import Counter

        def simple_tokenize(text):
            return text.lower().split()

        # Vocabulary comes from the training texts only.
        word_counts = Counter()
        for text in X_train_text:
            word_counts.update(simple_tokenize(str(text)))

        vocab = ['<PAD>', '<UNK>'] + [word for word, _ in word_counts.most_common(10000)]
        word_to_idx = {word: idx for idx, word in enumerate(vocab)}

        def encode(text, max_length=MAX_LENGTH):
            """Map a text to exactly ``max_length`` ids: truncate, then right-pad with 0."""
            ids = [word_to_idx.get(token, 1) for token in simple_tokenize(str(text))[:max_length]]
            return ids + [0] * (max_length - len(ids))

        X_train_tensor = torch.tensor(np.array([encode(text) for text in X_train_text]), dtype=torch.long)
        X_val_tensor = torch.tensor(np.array([encode(text) for text in X_val_text]), dtype=torch.long)
        y_train_tensor = torch.tensor(y_train, dtype=torch.long)

        use_extra = use_additional_features and X_train_add is not None
        additional_features_dim = X_train_add.shape[1] if use_extra else 0

        model = BiLSTMClassifier(
            vocab_size=len(vocab),
            additional_features_dim=additional_features_dim
        ).to(DEVICE)

        criterion = nn.CrossEntropyLoss()
        # NOTE(review): LEARNING_RATE is shared with transformer fine-tuning
        # and is small for an LSTM trained from scratch; consider a separate rate.
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

        train_tensors = [X_train_tensor, y_train_tensor]
        if use_extra:
            train_tensors.append(torch.tensor(X_train_add, dtype=torch.float))
        train_loader = DataLoader(
            torch.utils.data.TensorDataset(*train_tensors),
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        print(f"   Training on {DEVICE}...")
        start_time = time.time()

        model.train()
        for epoch in range(EPOCHS):
            total_loss = 0
            for batch in train_loader:
                batch_x = batch[0].to(DEVICE)
                batch_y = batch[1].to(DEVICE)
                batch_add = batch[2].to(DEVICE) if use_extra else None

                optimizer.zero_grad()
                outputs = model(batch_x, additional_features=batch_add)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"   Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

        train_time = time.time() - start_time

        # Evaluate in mini-batches rather than one full-set forward pass, which
        # must hold activations for every validation example at once.
        val_tensors = [X_val_tensor]
        if use_extra:
            val_tensors.append(torch.tensor(X_val_add, dtype=torch.float))
        val_loader = DataLoader(torch.utils.data.TensorDataset(*val_tensors), batch_size=BATCH_SIZE)

        model.eval()
        batch_predictions = []
        with torch.no_grad():
            for batch in val_loader:
                batch_x = batch[0].to(DEVICE)
                batch_add = batch[1].to(DEVICE) if use_extra else None
                logits = model(batch_x, additional_features=batch_add)
                batch_predictions.append(logits.argmax(dim=1).cpu())

        if batch_predictions:
            y_pred = torch.cat(batch_predictions).numpy()
        else:
            y_pred = np.array([], dtype=int)

        # Calculate metrics (per class, then for the positive/phishing class)
        accuracy = accuracy_score(y_val, y_pred)
        precision, recall, f1, support = precision_recall_fscore_support(
            y_val, y_pred, average=None, labels=[0, 1]
        )
        precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(
            y_val, y_pred, average='binary'
        )
        # Explicit labels guarantee a 2x2 matrix for the unpack below.
        cm = confusion_matrix(y_val, y_pred, labels=[0, 1])

        tn, fp, fn, tp = cm.ravel()
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0  # False Positive Rate
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0  # False Negative Rate

        results = {
            'lstm': {
                'accuracy': accuracy,
                'precision': precision_avg,
                'recall': recall_avg,
                'f1_score': f1_avg,
                'confusion_matrix': cm.tolist(),
                # Class-wise metrics
                'legitimate_precision': precision[0],
                'legitimate_recall': recall[0],
                'legitimate_f1': f1[0],
                'legitimate_support': int(support[0]),
                'phishing_precision': precision[1],
                'phishing_recall': recall[1],
                'phishing_f1': f1[1],
                'phishing_support': int(support[1]),
                # Error rates
                'false_positive_rate': fpr,
                'false_negative_rate': fnr,
                'true_negatives': int(tn),
                'false_positives': int(fp),
                'false_negatives': int(fn),
                'true_positives': int(tp),
                'train_time': train_time,
                'val_size': len(y_val),
                'train_size': len(y_train)
            }
        }

        print(f"   Accuracy: {accuracy:.3f}")
        print(f"   F1-Score: {f1_avg:.3f}")
        print(f"   Training time: {train_time:.1f}s")

        # Print class-wise summary
        print(f"   Legitimate - Precision: {precision[0]:.3f}, Recall: {recall[0]:.3f}")
        print(f"   Phishing   - Precision: {precision[1]:.3f}, Recall: {recall[1]:.3f}")
        print(f"   FPR: {fpr:.3f}, FNR: {fnr:.3f}")

        return results

    def train_transformer_models(self, X_train_text, X_val_text, y_train, y_val,
                               X_train_add, X_val_add, use_additional_features):
        """Train BERT-based models"""
        print("\n🤗 TRAINING TRANSFORMER MODELS")
        print("-" * 50)

        results = {}

        # Determine which models to train
        transformer_models = ['bert'] + CYBER_MODELS_TO_USE

        for model_key in transformer_models:
            if model_key not in MODELS_CONFIG:
                continue

            config = MODELS_CONFIG[model_key]
            print(f"\n📊 Training {config['name']}...")

            try:
                # Load tokenizer and model
                print(f"   Loading model: {config['model_name']}")

                # Handle different tokenizer types
                if model_key in ['securebert']:
                    from transformers import RobertaTokenizer
                    tokenizer = RobertaTokenizer.from_pretrained(config['model_name'])
                else:
                    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

                additional_features_dim = 0
                if use_additional_features and X_train_add is not None:
                    additional_features_dim = X_train_add.shape[1]

                # Create model
                if model_key in ['bert', 'cysecbert', 'secbert']:
                    model = EnhancedBERTClassifier(
                        config['model_name'],
                        additional_features_dim=additional_features_dim
                    ).to(DEVICE)
                else:
                    # For RoBERTa-based models
                    from transformers import RobertaForSequenceClassification
                    model = RobertaForSequenceClassification.from_pretrained(
                        config['model_name'],
                        num_labels=2
                    ).to(DEVICE)

                # Create datasets
                train_dataset = PhishingDataset(
                    X_train_text, y_train, tokenizer, MAX_LENGTH,
                    additional_features=X_train_add
                )
                val_dataset = PhishingDataset(
                    X_val_text, y_val, tokenizer, MAX_LENGTH,
                    additional_features=X_val_add
                )

                # Training setup
                optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
                train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

                print(f"   Training on {DEVICE}...")
                start_time = time.time()

                model.train()
                for epoch in range(EPOCHS):
                    total_loss = 0
                    for batch in train_loader:
                        batch = {k: v.to(DEVICE) for k, v in batch.items()}

                        optimizer.zero_grad()

                        if hasattr(model, 'forward') and model_key in ['bert', 'cysecbert', 'secbert']:
                            outputs = model(
                                input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                additional_features=batch.get('additional_features')
                            )
                        else:
                            outputs = model(
                                input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask']
                            ).logits

                        loss = nn.CrossEntropyLoss()(outputs, batch['labels'])
                        loss.backward()
                        optimizer.step()

                        total_loss += loss.item()

                    print(f"   Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

                train_time = time.time() - start_time

                # Evaluation on validation set
                model.eval()
                val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

                all_predictions = []
                all_labels = []

                with torch.no_grad():
                    for batch in val_loader:
                        batch = {k: v.to(DEVICE) for k, v in batch.items()}

                        if hasattr(model, 'forward') and model_key in ['bert', 'cysecbert', 'secbert']:
                            outputs = model(
                                input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                additional_features=batch.get('additional_features')
                            )
                        else:
                            outputs = model(
                                input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask']
                            ).logits

                        predictions = torch.argmax(outputs, dim=-1)
                        all_predictions.extend(predictions.cpu().numpy())
                        all_labels.extend(batch['labels'].cpu().numpy())

                # Calculate metrics
                accuracy = accuracy_score(all_labels, all_predictions)
                precision, recall, f1, support = precision_recall_fscore_support(
                    all_labels, all_predictions, average=None, labels=[0, 1]
                )
                precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(
                    all_labels, all_predictions, average='binary'
                )
                cm = confusion_matrix(all_labels, all_predictions)

                # Calculate additional metrics
                tn, fp, fn, tp = cm.ravel()
                fpr = fp / (fp + tn) if (fp + tn) > 0 else 0  # False Positive Rate
                fnr = fn / (fn + tp) if (fn + tp) > 0 else 0  # False Negative Rate

                results[model_key] = {
                    'accuracy': accuracy,
                    'precision': precision_avg,
                    'recall': recall_avg,
                    'f1_score': f1_avg,
                    'confusion_matrix': cm.tolist(),
                    # Class-wise metrics
                    'legitimate_precision': precision[0],
                    'legitimate_recall': recall[0],
                    'legitimate_f1': f1[0],
                    'legitimate_support': int(support[0]),
                    'phishing_precision': precision[1],
                    'phishing_recall': recall[1],
                    'phishing_f1': f1[1],
                    'phishing_support': int(support[1]),
                    # Error rates
                    'false_positive_rate': fpr,
                    'false_negative_rate': fnr,
                    'true_negatives': int(tn),
                    'false_positives': int(fp),
                    'false_negatives': int(fn),
                    'true_positives': int(tp),
                    'train_time': train_time,
                    'val_size': len(all_labels),
                    'train_size': len(y_train)
                }

                print(f"   Accuracy: {accuracy:.3f}")
                print(f"   F1-Score: {f1_avg:.3f}")
                print(f"   Training time: {train_time:.1f}s")

                # Print class-wise summary
                print(f"   Legitimate - Precision: {precision[0]:.3f}, Recall: {recall[0]:.3f}")
                print(f"   Phishing   - Precision: {precision[1]:.3f}, Recall: {recall[1]:.3f}")
                print(f"   FPR: {fpr:.3f}, FNR: {fnr:.3f}")

            except Exception as e:
                print(f"   ⚠️ Error training {config['name']}: {str(e)}")
                print(f"   Skipping this model...")
                continue

            finally:
                # Clear GPU memory
                if 'model' in locals():
                    del model
                torch.cuda.empty_cache()

        return results

    def evaluate_all_models(self):
        """Train and evaluate every model family under each feature setting.

        Two settings run in turn: text only, then text plus the additional
        features. For each setting the features are prepared and split into
        train/validation sets. They are then handed to the baseline, LSTM and
        transformer trainers. The per-model result dicts are merged in that
        order (a later trainer wins on a duplicate key) and stored under the
        setting's name.

        Returns:
            dict: configuration name -> {model key -> metrics dict}. The same
            object is also stored on ``self.results``.
        """
        print("\n🚀 MULTI-MODEL EVALUATION WITH MATCHED SAMPLING")
        print("=" * 70)

        banner_rule = "=" * 70
        # (configuration display name, whether additional features are used)
        feature_settings = (
            ('Without Additional Features', False),
            ('With Additional Features', True),
        )

        results_by_config = {}
        for setting_name, with_features in feature_settings:
            print(f"\n{banner_rule}")
            print(f"📊 CONFIGURATION: {setting_name}")
            print(banner_rule)

            texts, extra_features, labels = self.prepare_features(with_features)

            # Train/validation split chosen to match the fine-tuning setup.
            (train_texts, val_texts, train_labels, val_labels,
             train_extra, val_extra) = self.create_train_val_split(texts, extra_features, labels)

            # All three trainers share the same positional signature.
            trainer_args = (train_texts, val_texts, train_labels, val_labels,
                            train_extra, val_extra, with_features)

            baseline_metrics, _ = self.train_baseline_models(*trainer_args)
            lstm_metrics = self.train_lstm_model(*trainer_args)
            transformer_metrics = self.train_transformer_models(*trainer_args)

            merged = {}
            for partial in (baseline_metrics, lstm_metrics, transformer_metrics):
                merged.update(partial)
            results_by_config[setting_name] = merged

        self.results = results_by_config
        return results_by_config

    def create_class_wise_f1_comparison(self):
        """Print legitimate- vs. phishing-class F1 for every model, per configuration.

        Within each configuration, models are listed by ``f1_score`` (the
        "Avg F1" column), highest first. Each row is tagged with a marker for
        the absolute gap between its two class F1 scores: ✓ below 0.05,
        ⚠ below 0.1, ✗ otherwise. Keys absent from ``MODELS_CONFIG`` are
        left out.

        NOTE(review): the transformer trainer fills ``f1_score`` with
        ``average='binary'`` (positive-class F1), so "Avg F1" may not be a
        true average of the two class scores. Confirm this for the other
        trainers.
        """
        print("\n📊 CLASS-WISE F1 SCORE COMPARISON")
        print("=" * 80)
        print("Critical for understanding model performance on each class:")
        print("-" * 80)

        for config_name, config_models in self.results.items():
            print(f"\n{config_name}:")
            print(f"{'Model':<30} {'Legit F1':>10} {'Phish F1':>10} {'Avg F1':>10} {'Difference':>12}")
            print("-" * 75)

            # (display name, legit F1, phish F1, "avg" F1, |legit - phish|)
            rows = [
                (
                    MODELS_CONFIG[key]['name'],
                    m['legitimate_f1'],
                    m['phishing_f1'],
                    m['f1_score'],
                    abs(m['legitimate_f1'] - m['phishing_f1']),
                )
                for key, m in config_models.items()
                if key in MODELS_CONFIG
            ]
            # Highest "Avg F1" first; sorted() is stable, so ties keep insertion order.
            rows = sorted(rows, key=lambda row: row[3], reverse=True)

            for name, legit, phish, avg, gap in rows:
                if gap < 0.05:
                    marker = "✓"
                elif gap < 0.1:
                    marker = "⚠"
                else:
                    marker = "✗"
                print(f"{name:<30} {legit:>10.3f} {phish:>10.3f} "
                      f"{avg:>10.3f} {gap:>10.3f} {marker}")

        print("\n Legend: ✓ Well balanced (<0.05 diff), ⚠ Some imbalance (0.05-0.1), ✗ High imbalance (>0.1)")

    def generate_comparison_report(self):
        """Print per-configuration comparison tables and save the results.

        For every configuration in ``self.results`` this prints two things: a
        class-wise metrics table and a compact confusion-matrix summary. It
        then prints the class-wise F1 comparison and, per configuration, the
        model with the lowest FPR+FNR. Finally it writes:

        - the full results to a timestamped JSON file,
        - a summary to a timestamped CSV file,

        and prints the final results table. Only model keys present in
        ``MODELS_CONFIG`` are reported.

        Returns:
            pd.DataFrame: one row per (configuration, model), with metrics
            formatted as strings. This is the content of the CSV file.
        """

        def _json_default(obj):
            # Metric values may be NumPy scalars/arrays (e.g. np.int64), which
            # ``json`` cannot encode natively. ``tolist()`` converts both to
            # plain Python numbers/lists.
            if hasattr(obj, 'tolist'):
                return obj.tolist()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        print("\n📊 COMPREHENSIVE COMPARISON REPORT WITH CLASS-WISE METRICS")
        print("=" * 70)

        # Create comparison tables for each configuration
        for config_name, models in self.results.items():
            print(f"\n{'='*70}")
            print(f"📊 {config_name}")
            print(f"{'='*70}")

            # Detailed class-wise metrics table (one row per known model)
            detailed_data = []
            for model_name, metrics in models.items():
                if model_name in MODELS_CONFIG:
                    detailed_data.append({
                        'Model': MODELS_CONFIG[model_name]['name'],
                        'Overall Acc': f"{metrics['accuracy']:.3f}",
                        'Legit Prec': f"{metrics['legitimate_precision']:.3f}",
                        'Legit Rec': f"{metrics['legitimate_recall']:.3f}",
                        'Legit F1': f"{metrics['legitimate_f1']:.3f}",
                        'Phish Prec': f"{metrics['phishing_precision']:.3f}",
                        'Phish Rec': f"{metrics['phishing_recall']:.3f}",
                        'Phish F1': f"{metrics['phishing_f1']:.3f}",
                        'FPR': f"{metrics['false_positive_rate']:.3f}",
                        'FNR': f"{metrics['false_negative_rate']:.3f}"
                    })

            df_detailed = pd.DataFrame(detailed_data)
            print("\n" + df_detailed.to_string(index=False))

            # Confusion matrix counts (shortened version)
            print(f"\n📊 CONFUSION MATRICES - {config_name}")
            print("-" * 50)

            for model_name, metrics in models.items():
                if model_name in MODELS_CONFIG:
                    print(f"\n{MODELS_CONFIG[model_name]['name']}:")
                    print(f"   TN: {metrics['true_negatives']:4d}  FP: {metrics['false_positives']:4d}  |  FPR: {metrics['false_positive_rate']:.3f}")
                    print(f"   FN: {metrics['false_negatives']:4d}  TP: {metrics['true_positives']:4d}  |  FNR: {metrics['false_negative_rate']:.3f}")

        # Class-wise F1 comparison across all configurations
        self.create_class_wise_f1_comparison()

        # Critical metrics summary: the lowest FPR+FNR per configuration
        print("\n🚨 CRITICAL METRICS SUMMARY")
        print("=" * 70)

        for config_name, models in self.results.items():
            print(f"\n{config_name}:")

            # Rank only models with a display name, as every other table here
            # does; otherwise MODELS_CONFIG[...] below raises KeyError.
            known_models = [
                (name, metrics) for name, metrics in models.items()
                if name in MODELS_CONFIG
            ]
            if not known_models:
                # Nothing to rank; min() on an empty sequence would raise ValueError.
                print("   No evaluated models to rank")
                continue

            best_name, best_metrics = min(
                known_models,
                key=lambda item: item[1]['false_positive_rate'] + item[1]['false_negative_rate'],
            )
            combined_error = best_metrics['false_positive_rate'] + best_metrics['false_negative_rate']
            print(f"   Best Overall: {MODELS_CONFIG[best_name]['name']} (FPR+FNR: {combined_error:.3f})")

        # Save results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Save full results JSON
        results_file = f"detailed_evaluation_results_{timestamp}.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=_json_default)
        print(f"\n💾 Detailed results saved to: {results_file}")

        # Save summary CSV
        summary_data = []
        for config_name, models in self.results.items():
            for model_name, metrics in models.items():
                if model_name in MODELS_CONFIG:
                    summary_data.append({
                        'Configuration': config_name,
                        'Model': MODELS_CONFIG[model_name]['name'],
                        'Accuracy': f"{metrics['accuracy']:.3f}",
                        'Legit_F1': f"{metrics['legitimate_f1']:.3f}",
                        'Phish_F1': f"{metrics['phishing_f1']:.3f}",
                        'Avg_F1': f"{metrics['f1_score']:.3f}",
                        'FPR': f"{metrics['false_positive_rate']:.3f}",
                        'FNR': f"{metrics['false_negative_rate']:.3f}",
                        'Train_Time_sec': f"{metrics['train_time']:.1f}"
                    })

        df_summary = pd.DataFrame(summary_data)
        csv_file = f"class_wise_evaluation_{timestamp}.csv"
        df_summary.to_csv(csv_file, index=False)
        print(f"📊 Class-wise metrics saved to: {csv_file}")

        # Generate final results table
        self.generate_final_results_table()

        return df_summary

# ==========================================
# MAIN EXECUTION
# ==========================================

def analyze_suspicious_results(results):
    """Scan evaluation results for patterns that suggest a broken model.

    Prints every warning found, followed by remediation hints. Each hint names
    the models that actually triggered it, rather than assuming which model
    is at fault.

    Args:
        results: mapping of configuration name -> {model key -> metrics dict}.
            Each metrics dict must contain ``phishing_recall``,
            ``legitimate_recall``, ``false_positive_rate``,
            ``false_negative_rate`` and ``accuracy``.

    Returns:
        list[str]: the warnings in detection order (empty when none).
    """
    print("\n⚠️  RESULTS VALIDATION")
    print("-" * 50)

    warnings = []
    # Display names of the models behind each advice-worthy issue.
    no_phishing_detected = []
    near_random = []

    for config_name, models in results.items():
        for model_name, metrics in models.items():
            model_display_name = MODELS_CONFIG.get(model_name, {}).get('name', model_name)
            label = f"{config_name} - {model_display_name}"
            fpr = metrics['false_positive_rate']
            fnr = metrics['false_negative_rate']

            # Check for models predicting only one class
            if metrics['phishing_recall'] == 0:
                warnings.append(f"{label}: Not detecting ANY phishing emails (100% FNR)!")
                no_phishing_detected.append(model_display_name)
            elif metrics['legitimate_recall'] == 0:
                warnings.append(f"{label}: Not detecting ANY legitimate emails (100% FPR)!")

            # Check for extreme imbalances between the two error rates
            if fpr == 0 and fnr > 0.5:
                warnings.append(f"{label}: Zero FPR but very high FNR ({fnr:.1%})")
            elif fnr == 0 and fpr > 0.5:
                warnings.append(f"{label}: Zero FNR but very high FPR ({fpr:.1%})")

            # Accuracy near 50% is roughly chance level for a binary task
            # (assuming balanced classes).
            if 0.45 < metrics['accuracy'] < 0.55:
                warnings.append(f"{label}: Near random performance (accuracy: {metrics['accuracy']:.1%})")
                near_random.append(model_display_name)

    if warnings:
        print("🚨 Issues detected:")
        for warning in warnings:
            print(f"   - {warning}")

        # dict.fromkeys de-duplicates while keeping order, since the same model
        # may be flagged in several configurations.
        if no_phishing_detected:
            names = ", ".join(map(str, dict.fromkeys(no_phishing_detected)))
            print(f"\n💡 Fix ({names}): Try class weighting (e.g. scale_pos_weight) or using sample weights")
        if near_random:
            names = ", ".join(map(str, dict.fromkeys(near_random)))
            print(f"\n💡 Fix ({names}): May need more epochs, different architecture, or better hyperparameters")
    else:
        print("✅ No critical issues detected")
        print("   Note: High accuracy (>99%) can be legitimate with good features")

    return warnings

def run_matched_evaluation():
    """Run the full evaluation using sampling that matches fine-tuning.

    Steps:
    1. Load the dataset at ``DATASET_PATH``.
    2. Evaluate every model in both feature configurations.
    3. Print and save the comparison report.
    4. Check the results for suspicious patterns.

    If the dataset cannot be loaded, an error is printed and the function
    returns early. Returns ``None`` in every case.
    """
    evaluator = MatchedSamplingEvaluator(DATASET_PATH)

    data_ready = evaluator.load_and_prepare_data()
    if not data_ready:
        print("❌ Failed to load dataset")
        return

    # evaluate_all_models() fills evaluator.results, which the report and the
    # sanity checks below read.
    evaluator.evaluate_all_models()
    evaluator.generate_comparison_report()
    analyze_suspicious_results(evaluator.results)

    print("\n✅ EVALUATION COMPLETED!")
    # NOTE(review): the split sizes below are a hard-coded string, not derived
    # from SAMPLES_PER_CLASS or the actual split. Confirm they still match.
    closing_lines = (
        f"   Matched fine-tuning sampling: {SAMPLES_PER_CLASS} per class",
        "   Train/Val split preserved: ~4,730/1,000",
    )
    print("\n".join(closing_lines))

if __name__ == "__main__":
    # Release cached CUDA memory before starting; skipped when no GPU is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    startup_banner = "\n".join((
        "🚀 Starting Multi-Model Evaluation with Matched Sampling",
        f"   Matching fine-tuning setup: {SAMPLES_PER_CLASS} samples per class",
        f"   Device: {DEVICE}",
    ))
    print(startup_banner)

    run_matched_evaluation()