# Phishing URL Detection - Complete Training Pipeline

## 🎯 Overview
This notebook trains a Random Forest classifier to detect phishing URLs using 19 URL-based features.

## ⚡ Key Features:
- **No webpage fetching** - works with just the URL
- **19 engineered features** - domain, TLD, path, protocol analysis
- **Random Forest model** - balanced class weights, high accuracy
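- **Fixed TrustedBrandOnHTTP logic** - only brand names in the *subdomain* of a non-HTTPS URL are flagged, so legitimate brand domains such as `http://www.amazon.com` no longer trigger it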

---

## 1. Import Libraries

In [None]:
import pandas as pd
import numpy as np
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (
    classification_report, 
    roc_auc_score, 
    confusion_matrix,
    roc_curve,
    precision_recall_curve,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)
from pathlib import Path
import re
from difflib import SequenceMatcher
from urllib.parse import urlparse
import warnings
warnings.filterwarnings('ignore')

# Try to import tldextract. When it is missing, feature extraction falls back
# to simple_tld_extract(), which treats only the last dot-separated label of
# the host as the suffix.
try:
    import tldextract
    HAS_TLDEXTRACT = True
    print("✅ tldextract library found")
except ImportError:
    HAS_TLDEXTRACT = False
    print("⚠️  tldextract not found - using fallback TLD extraction")

# Set visualization style. The "seaborn-v0_8-*" names only exist in
# matplotlib >= 3.6 (older releases call the same style "seaborn-darkgrid"),
# so use whichever this installation provides instead of raising OSError.
for _style in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid"):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
sns.set_palette("husl")

print("✅ All libraries imported successfully")

## 2. Define Feature Extraction Functions

In [None]:
# Feature extraction constants

# Heuristic legitimacy prior per TLD; extract_features() falls back to 0.05
# for any suffix not listed here.
# NOTE(review): tldextract reports the full public suffix (e.g. "co.uk"),
# which never matches these single-label keys, while the fallback extractor
# returns only the last label ("uk") — confirm this asymmetry is intended.
COMMON_LEGIT_TLDS = {
    "com": 0.9,
    "org": 0.85,
    "net": 0.8,
    "edu": 0.95,
    "gov": 0.97,
    "co": 0.7,
    "uk": 0.8,
    "de": 0.75,
    "fr": 0.75,
    "ca": 0.8,
}

# TLDs whose presence sets the SuspiciousTLD feature.
SUSPICIOUS_TLDS = {
    "xyz", "tk", "ml", "ga", "cf",
    "ru", "cn", "top", "gq", "pw",
}

# Brand names searched for as substrings of the registered domain
# (HasBrandName) and of the subdomain (TrustedBrandOnHTTP).
BRANDS = [
    "paypal", "google", "facebook", "amazon", "instagram",
    "bank", "sbi", "hdfc", "icici", "apple",
    "microsoft", "netflix", "ebay", "myntra", "flipkart",
    "wikipedia", "github", "linkedin", "twitter",
]

# Keywords compared against the whole URL to compute URLSimilarityIndex.
PHISHING_KEYWORDS = [
    "login", "secure", "account", "verify", "bank", "update", "confirm",
]


def simple_tld_extract(url):
    """Split the host of *url* into subdomain / domain / suffix without tldextract.

    Fallback used when the ``tldextract`` package is unavailable. Only the
    last dot-separated label is treated as the suffix, so multi-label public
    suffixes such as ``co.uk`` are not recognised.

    Args:
        url: A URL string, with or without a scheme.

    Returns:
        A dict with keys ``'domain'``, ``'suffix'`` and ``'subdomain'``.
        An IPv4 host is returned whole as ``'domain'`` with empty
        ``'suffix'``/``'subdomain'`` so that callers can still recognise it
        as an IP address.
    """
    parsed = urlparse(url)
    # Without a scheme, urlparse leaves the host at the start of .path.
    host = parsed.netloc or parsed.path.split('/')[0]
    # Drop any "user:pass@" credentials, then any ":port".
    host = host.rpartition('@')[2]
    host = host.split(':')[0]

    # An IPv4 address has no domain/suffix structure; splitting it on dots
    # would report domain "1" for 192.168.1.1 and hide the IP entirely.
    if re.fullmatch(r"\d{1,3}(?:\.\d{1,3}){3}", host):
        return {'domain': host, 'suffix': '', 'subdomain': ''}

    parts = host.split('.')
    if len(parts) >= 2:
        return {
            'domain': parts[-2],
            'suffix': parts[-1],
            # Empty string when there are exactly two labels.
            'subdomain': '.'.join(parts[:-2]),
        }
    return {'domain': host, 'suffix': '', 'subdomain': ''}


def _normalize_url(url: str) -> str:
    """Normalize URL so trailing slash doesn't change features."""
    url = url.strip()
    parsed = urlparse(url)
    if parsed.path == "/" and not parsed.query and not parsed.fragment:
        url = url.rstrip("/")
    return url


def extract_features(url: str) -> dict:
    """Compute the 19 URL-only phishing features for *url*.

    Nothing is fetched over the network. Every feature is derived from:
    - the URL string itself (after ``_normalize_url``);
    - its host, split into subdomain / registered domain / suffix
      (via ``tldextract`` when available, otherwise ``simple_tld_extract``).

    NOTE(review): section 11 scores URLs with ``extract_features`` imported
    from a separate ``features`` module; this copy presumably has to stay
    identical to it, or the saved model sees different inputs at serving time.

    Args:
        url: The URL to featurize. Non-string values are coerced with
            ``str()``; falsy values become the empty string.

    Returns:
        A dict mapping feature name to value. Its key order becomes the
        column order of the training DataFrame, so keep it stable.
    """
    if not (url and isinstance(url, str)):
        url = str(url) if url else ""

    url = _normalize_url(url)
    lowered = url.lower()

    # Host split into subdomain / registered domain / public suffix.
    if HAS_TLDEXTRACT:
        split = tldextract.extract(url)
        subdomain, domain, tld = split.subdomain, split.domain, split.suffix.lower()
    else:
        split = simple_tld_extract(url)
        subdomain, domain, tld = split['subdomain'], split['domain'], split['suffix'].lower()

    parsed = urlparse(url)
    path = parsed.path

    # Number of positions where a character repeats the one before it.
    repeated_chars = sum(prev == cur for prev, cur in zip(url, url[1:]))

    # Best SequenceMatcher ratio between the whole URL and any keyword, in percent.
    url_similarity_index = 100 * max(
        SequenceMatcher(None, lowered, keyword).ratio()
        for keyword in PHISHING_KEYWORDS
    )

    has_https = int(parsed.scheme == "https")

    # "a.b" -> 2 levels; no subdomain -> 0.
    subdomain_count = subdomain.count('.') + 1 if subdomain else 0

    digit_count = sum(map(str.isdigit, url))
    special_char_count = sum(url.count(symbol) for symbol in "@?=-_&")

    is_ip = bool(re.match(r"^\d{1,3}(\.\d{1,3}){3}$", domain))

    path_length = len(path)
    path_depth = path.count('/')

    has_at_symbol = int('@' in url)
    # Any "//" beyond the one following the scheme.
    double_slash_redirecting = int(url.count('//') > 1)

    hyphen_count = domain.count("-")

    domain_lc = domain.lower()
    has_brand = int(any(brand in domain_lc for brand in BRANDS))

    # TrustedBrandOnHTTP fires only for a brand name placed in the *subdomain*
    # of a non-HTTPS URL (e.g. http://paypal-secure.verify-account.ml); brands
    # in the registered domain itself, like http://www.amazon.com, do not count.
    brand_in_subdomain = bool(subdomain) and any(brand in subdomain.lower() for brand in BRANDS)
    trusted_brand_on_http = int(brand_in_subdomain and not has_https)

    return {
        "URLLength": len(url),
        "DomainLength": len(domain),
        "IsDomainIP": int(is_ip),
        "URLSimilarityIndex": url_similarity_index,
        "CharContinuationRate": repeated_chars / max(len(url), 1),
        "TLDLegitimateProb": COMMON_LEGIT_TLDS.get(tld, 0.05),
        "HasBrandName": has_brand,
        "HyphenCount": hyphen_count,
        "SuspiciousTLD": int(tld in SUSPICIOUS_TLDS),
        "HasHTTPS": has_https,
        "TrustedBrandOnHTTP": trusted_brand_on_http,
        "SubdomainLevel": subdomain_count,
        "PathLength": path_length,
        "PathDepth": path_depth,
        "DigitCount": digit_count,
        "SpecialCharCount": special_char_count,
        "HasAtSymbol": has_at_symbol,
        "DoubleSlashRedirecting": double_slash_redirecting,
        "PrefixSuffix": int(hyphen_count > 0),
    }

print("✅ Feature extraction functions defined")
# Derive the count from a real extraction so this message cannot drift from
# the dictionary that extract_features() actually returns.
_FEATURE_COUNT = len(extract_features("https://example.com"))
print(f"   Total features: {_FEATURE_COUNT}")

## 3. Load Dataset

Make sure the dataset is available at `../data/raw/PhiUSIIL_Phishing_URL_Dataset.csv` (relative to the notebook's working directory).

In [None]:
# Load dataset
DATA_PATH = Path("../data/raw/PhiUSIIL_Phishing_URL_Dataset.csv")

if not DATA_PATH.exists():
    # Stop here: every later cell needs `df`, so carrying on would only
    # surface as a confusing NameError further down the notebook.
    raise FileNotFoundError(
        f"❌ Dataset not found at: {DATA_PATH.absolute()}\n"
        "   Please place the dataset file there and try again."
    )

df = pd.read_csv(DATA_PATH)
print(f"✅ Dataset loaded from: {DATA_PATH}")
print(f"\nDataset shape: {df.shape}")
print(f"Columns: {list(df.columns[:10])}...")
print("\nClass distribution:")
print(df['label'].value_counts())

## 4. Extract Features from URLs

In [None]:
# Extract features from all URLs
print("Extracting features from URLs...")
feature_list = [extract_features(url) for url in df['URL']]

# Create feature DataFrame; column order follows extract_features()' key order.
X = pd.DataFrame(feature_list)

# Target: every later cell treats class 1 as "Phishing"
# (predict_proba(...)[:, 1] is reported as risk; target_names=['Legitimate', 'Phishing']).
# NOTE(review): the UCI description of PhiUSIIL reportedly encodes legitimate
# URLs as 1 and phishing URLs as 0 — check the dataset documentation and the
# class counts printed above; if so, set PHISHING_LABEL = 0.
PHISHING_LABEL = 1
y = (df['label'] == PHISHING_LABEL).astype('int64')

print(f"\n✅ Features extracted: {X.shape}")
print(f"   Features: {list(X.columns)}")
print("\nFirst few rows:")
print(X.head())

## 5. Train-Test Split

In [None]:
# Split data: hold out 20% for testing, stratified on y so both splits keep
# the same class ratio; random_state pins the shuffle for reproducibility.
split_kwargs = dict(test_size=0.2, random_state=42, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_kwargs)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")

## 6. Train Random Forest Model

In [None]:
# Train Random Forest
print("Training Random Forest model...")

# 100 trees capped at depth 10; class_weight='balanced' weights samples
# inversely to class frequency, and n_jobs=-1 uses all available cores.
rf_params = {
    "n_estimators": 100,
    "max_depth": 10,
    "random_state": 42,
    "class_weight": "balanced",
    "n_jobs": -1,
}
rf_model = RandomForestClassifier(**rf_params).fit(X_train, y_train)

print("✅ Model trained successfully!")

## 7. Evaluate Model

In [None]:
# Hard predictions plus the predicted probability of class 1.
y_pred = rf_model.predict(X_test)
y_pred_proba = rf_model.predict_proba(X_test)[:, 1]

# Threshold metrics (precision/recall/F1 use sklearn's default pos_label=1);
# ROC-AUC is computed from the class-1 probabilities.
accuracy, precision, recall, f1 = (
    scorer(y_test, y_pred)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)
roc_auc = roc_auc_score(y_test, y_pred_proba)

print("=" * 70)
print("MODEL PERFORMANCE")
print("=" * 70)
print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-Score:  {f1:.4f}")
print(f"ROC-AUC:   {roc_auc:.4f}")
print("=" * 70)

# Per-class breakdown; names map to classes 0 and 1 in sorted order.
print("\nDetailed Classification Report:")
print(classification_report(y_test, y_pred, target_names=['Legitimate', 'Phishing']))

## 8. Feature Importance

In [None]:
# Feature importance from the fitted forest, highest first.
feature_importance = pd.DataFrame({
    'feature': X.columns,
    'importance': rf_model.feature_importances_
}).sort_values('importance', ascending=False)

# After sort_values the integer index labels are out of order, so take the
# top rows explicitly by position with .head() rather than slicing each
# column with [:10] (which depends on pandas' positional-slice rule).
top_features = feature_importance.head(10)

print("\nTop 10 Most Important Features:")
print(top_features)

# Plot
plt.figure(figsize=(10, 6))
plt.barh(top_features['feature'], top_features['importance'])
plt.xlabel('Importance')
plt.title('Top 10 Feature Importance')
plt.gca().invert_yaxis()  # most important feature at the top
plt.tight_layout()
plt.show()

## 9. Save Model

In [None]:
# Target file lives in the project-level models/ directory (one level above
# this notebook's working directory); create the directory if needed.
MODEL_PATH = Path("../models/url_rf_model.pkl")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

# Store the estimator together with its training column order and headline
# test metrics, so a loader can rebuild inputs in the same column order.
model_bundle = dict(
    model=rf_model,
    features=list(X.columns),
    accuracy=accuracy,
    f1_score=f1,
    roc_auc=roc_auc,
)

joblib.dump(model_bundle, MODEL_PATH)

print("=" * 70)
print("MODEL SAVED")
print("=" * 70)
print(f"Location: {MODEL_PATH.absolute()}")
print(f"File size: {MODEL_PATH.stat().st_size / 1024 / 1024:.2f} MB")
print("=" * 70)

## 10. Test on Sample URLs

In [None]:
# Hand-picked URLs with a known expected class: real sites first, then
# phishing-style patterns (lookalike domain, raw IP, brand in subdomain).
test_urls = [
    ("https://www.google.com", "LEGITIMATE"),
    ("https://www.github.com", "LEGITIMATE"),
    ("http://www.amazon.com", "LEGITIMATE"),
    ("http://secure-paypal-login.xyz", "PHISHING"),
    ("https://192.168.1.1/login", "PHISHING"),
    ("http://paypal-secure.verify-account.ml", "PHISHING"),
]

print("=" * 70)
print("TESTING ON SAMPLE URLs")
print("=" * 70)

# Column order the model was trained with; reused for every single-row frame.
feature_columns = list(X.columns)

results = []
for url, expected in test_urls:
    X_url = pd.DataFrame([extract_features(url)])[feature_columns]

    # Probability of class 1, thresholded at 0.5.
    pred_proba = rf_model.predict_proba(X_url)[0][1]
    pred = "PHISHING" if pred_proba >= 0.5 else "LEGITIMATE"

    status = "✅" if pred == expected else "❌"
    results.append((status, url, expected, pred, pred_proba))

    print(f"\n{status} URL: {url}")
    print(f"   Expected: {expected}, Predicted: {pred} (Risk: {pred_proba:.4f})")

correct = sum(mark == "✅" for mark, *_ in results)
print(f"\n{'=' * 70}")
print(f"Test Results: {correct}/{len(results)} correct ({correct/len(results)*100:.1f}%)")
print(f"{'=' * 70}")

## 11. Use the Model for Predictions

In [None]:
# Load the saved bundle and score a URL the way the production code does:
# with extract_features from the project's `features` module rather than
# the copy defined earlier in this notebook.
import sys
from pathlib import Path

import joblib
import pandas as pd

# features.py lives in the project's src/ directory (the CLI hint below runs
# from there), which is not on sys.path when the kernel starts in this
# notebook's directory.
# NOTE(review): assumes src/ sits next to this notebook's folder, like
# ../data and ../models — confirm.
SRC_DIR = Path("../src").resolve()
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

from features import extract_features  # noqa: E402  (requires the sys.path tweak above)

# Load the model (joblib uses pickle: only load files you trust).
MODEL_PATH = Path("../models/url_rf_model.pkl")
model_bundle = joblib.load(MODEL_PATH)

# Test it
url = "https://www.google.com"
features = extract_features(url)
# Reorder to the training column order stored in the bundle. A separate name
# keeps the training feature matrix `X` from section 4 intact.
X_single = pd.DataFrame([features])[model_bundle["features"]]
risk_score = model_bundle["model"].predict_proba(X_single)[0][1]
prediction = "PHISHING" if risk_score >= 0.5 else "LEGITIMATE"

print(f"URL: {url}")
print(f"Prediction: {prediction}")
print(f"Risk Score: {risk_score:.2%}")

print("\n" + "=" * 70)
print("For interactive prediction, run from terminal:")
print("  cd src && python predict_url.py -i")
print("=" * 70)