<a href="https://colab.research.google.com/github/rezahamzeh69/unsw-nb15-ids-bilstm-transformer/blob/main/robust_transformer_bilstm_ids.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import math
import os
import time
import io
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from contextlib import contextmanager
import kagglehub

# =========================
# Hyperparameters
# =========================
# Model architecture and optimisation settings consumed by the pipeline below.
LSTM_HIDDEN_DIM = 128
LSTM_LAYERS = 2
TRANSFORMER_DIM = 128
NHEAD = 8
NUM_TRANSFORMER_LAYERS = 2
NUM_CLASSES = 2
DROPOUT = 0.3
SEQUENCE_LENGTH = 15  # number of consecutive rows per sliding window
BATCH_SIZE = 256
LEARNING_RATE = 5e-4
NUM_EPOCHS = 30
VALIDATION_SPLIT_FROM_TRAIN = 0.15  # fraction of training sequences held out for validation
RANDOM_SEED = 42
EARLY_STOPPING_PATIENCE = 5  # epochs without validation-loss improvement before stopping
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Adversarial Training
ADV_TRAINING = True          # True/False
ADV_METHOD = 'pgd'           # 'fgsm' or 'pgd'
ADV_EPS = 0.03               # perturbation budget passed to the attacks as `eps`
ADV_ALPHA = 0.007            # per-step size for PGD
ADV_STEPS = 7                # number of PGD iterations
ADV_LAMBDA = 0.5             # weight of the adversarial loss; the clean loss gets (1 - ADV_LAMBDA)

# Interpretability
IG_STEPS = 32                # number of steps along the Integrated Gradients path
IG_BASELINE = 'zeros'        # 'zeros' or 'random'
SAVE_INTERPRETABILITY = True # write interpretability outputs to disk

# Checkpoint/Resume
AUTO_RESUME = True           # if True, automatically continue from the latest checkpoint at start-up
PROJECT_NAME = "unsw_bilstm_trf_adv_ig"
def _default_ckpt_dir():
    """Choose where checkpoints are stored for the current runtime.

    Candidate roots are probed in priority order: a mounted Google Drive
    first, then the Kaggle working directory. If neither exists the
    checkpoints go to a folder under the current working directory.

    Returns:
        Path string of the form ``<root>/<PROJECT_NAME>_checkpoints``.
    """
    folder = f"{PROJECT_NAME}_checkpoints"
    # Ordered by preference: Drive if it is mounted, otherwise Kaggle/Colab scratch space.
    for root in ("/content/drive/MyDrive", "/kaggle/working"):
        if os.path.exists(root):
            return f"{root}/{folder}"
    return f"./{folder}"

# Checkpoint locations: "latest" is overwritten after every epoch, "best" only on improvement.
CKPT_DIR = _default_ckpt_dir()
CKPT_LATEST = os.path.join(CKPT_DIR, "ckpt_latest.pt")
CKPT_BEST = os.path.join(CKPT_DIR, "ckpt_best.pt")

# File names expected inside the downloaded dataset directory.
TRAIN_FILE_NAME = 'UNSW_NB15_training-set.parquet'
TEST_FILE_NAME = 'UNSW_NB15_testing-set.parquet'

# Globals for interpretability/masking
FEATURE_NAMES = None              # list of feature columns after one-hot encoding
PERTURB_MASK = None              # mask of perturbable features (1=numeric, 0=categorical one-hot)

# =========================
# Reproducibility
# =========================
# Seed every RNG used below and force deterministic cuDNN kernels when a GPU is present.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# =========================
# Utils: Checkpointing
# =========================
def ensure_dir(path):
    """Create directory ``path`` including missing parents; do nothing if it already exists."""
    os.makedirs(name=path, exist_ok=True)

def save_checkpoint(state, is_best=False):
    """Persist ``state`` as the latest checkpoint and optionally as the best one.

    Each file is first written under a temporary name and then moved into
    place with ``os.replace``. An interrupted run therefore leaves either the
    previous complete checkpoint or the new complete one, never a truncated
    file that would break auto-resume.

    Args:
        state: Picklable object handed to ``torch.save``.
        is_best: When True, the freshly written checkpoint is also copied to
            ``CKPT_BEST``.
    """
    ensure_dir(CKPT_DIR)

    tmp_latest = CKPT_LATEST + ".tmp"
    torch.save(state, tmp_latest)
    os.replace(tmp_latest, CKPT_LATEST)

    if is_best:
        tmp_best = CKPT_BEST + ".tmp"
        shutil.copyfile(CKPT_LATEST, tmp_best)
        os.replace(tmp_best, CKPT_BEST)

def try_load_checkpoint(model, optimizer):
    """Restore training state from the latest checkpoint when resuming is enabled.

    If ``AUTO_RESUME`` is set and ``CKPT_LATEST`` exists, the model and
    optimizer state dicts are loaded in place.

    Returns:
        ``(start_epoch, best_val_loss, epochs_no_improve, history)``.
        ``start_epoch`` is the zero-based index of the next epoch to run.
        Without a usable checkpoint this is ``(0, inf, 0, <empty history>)``.
    """
    def _blank_history():
        return {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1_score': []}

    resume = AUTO_RESUME and os.path.exists(CKPT_LATEST)
    if not resume:
        return 0, float('inf'), 0, _blank_history()

    ckpt = torch.load(CKPT_LATEST, map_location=DEVICE)
    model.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['optimizer_state'])

    # The stored epoch is the last finished one, so training continues one past it.
    start_epoch = ckpt.get('epoch', 0) + 1
    best_val_loss = ckpt.get('best_val_loss', float('inf'))
    epochs_no_improve = ckpt.get('epochs_no_improve', 0)
    history = ckpt['history'] if 'history' in ckpt else _blank_history()
    print(f"[Resume] Loaded checkpoint from '{CKPT_LATEST}' at epoch {start_epoch}")
    return start_epoch, best_val_loss, epochs_no_improve, history

# =========================
# Data Loading & Preprocess
# =========================
def load_and_preprocess_unsw(dataset_dir):
    """Load the UNSW-NB15 parquet splits and convert them to model-ready arrays.

    Pipeline: concatenate both splits, normalise column names, drop ``id`` and
    ``attack_cat``, impute missing values, one-hot encode categorical columns,
    split back into train/test by row count, then min-max scale the numeric
    columns with a scaler fitted on the training rows only.

    Side effects:
        Sets the module globals ``FEATURE_NAMES`` (feature column names after
        encoding) and ``PERTURB_MASK`` (float32 tensor of shape [D]; 1.0 for
        columns that may be adversarially perturbed, 0.0 for one-hot columns).

    Args:
        dataset_dir: Directory containing ``TRAIN_FILE_NAME`` and
            ``TEST_FILE_NAME``.

    Returns:
        ``(X_train, y_train, X_test, y_test, input_dim)`` with float32
        features and int64 labels. On any failure the reason is printed and
        ``(None, None, None, None, -1)`` is returned.
    """
    global FEATURE_NAMES, PERTURB_MASK
    train_file_path = os.path.join(dataset_dir, TRAIN_FILE_NAME)
    test_file_path = os.path.join(dataset_dir, TEST_FILE_NAME)
    print(f"Attempting to load training data from: {train_file_path}")
    print(f"Attempting to load testing data from: {test_file_path}")

    if not os.path.exists(train_file_path) or not os.path.exists(test_file_path):
        print("Error: Training or Testing file not found in the downloaded dataset directory.")
        print(f"Contents of {dataset_dir}: {os.listdir(dataset_dir)}")
        return None, None, None, None, -1

    try:
        df_train = pd.read_parquet(train_file_path)
        df_test = pd.read_parquet(test_file_path)
        print("Datasets loaded successfully.")
    except Exception as e:
        print(f"Error loading Parquet files: {e}")
        return None, None, None, None, -1

    print("Preprocessing datasets...")
    df = pd.concat([df_train, df_test], ignore_index=True)
    # Normalise the names first so the drop below also catches differently
    # cased or padded variants such as "ID" or " attack_cat".
    df.columns = df.columns.str.strip().str.lower().str.replace(' ', '_').str.replace(r'[^a-z0-9_]', '', regex=True)
    df = df.drop(['id', 'attack_cat'], axis=1, errors='ignore')

    categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
    numerical_cols = df.select_dtypes(include=np.number).columns.tolist()

    # NOTE(review): imputation statistics are computed on train+test combined.
    for col in numerical_cols:
        if df[col].isnull().any():
            df[col] = df[col].fillna(df[col].median() if df[col].nunique() > 10 else 0)
    for col in categorical_cols:
        if df[col].isnull().any():
            # Convert Categorical columns to object before filling.
            # `is_categorical_dtype` is deprecated, so test the dtype directly.
            if isinstance(df[col].dtype, pd.CategoricalDtype):
                df[col] = df[col].astype('object')
            df[col] = df[col].fillna(df[col].mode()[0])

    # The label must not be treated as a feature in either column list.
    if 'label' in numerical_cols:
        numerical_cols.remove('label')
    elif 'label' in categorical_cols:
        categorical_cols.remove('label')
        try:
            df['label'] = pd.to_numeric(df['label'])
        except ValueError:
            print("Error: Could not convert 'label' column to numeric.")
            return None, None, None, None, -1

    if 'label' not in df.columns or not pd.api.types.is_numeric_dtype(df['label']):
        print("Error: 'label' column is missing or not numeric after cleaning.")
        return None, None, None, None, -1

    # One-hot
    df_encoded = pd.get_dummies(df, columns=categorical_cols, dummy_na=False)

    # The concat above kept the training rows first, so a positional split restores the two sets.
    train_len = len(df_train)
    df_train_processed = df_encoded.iloc[:train_len].copy()
    df_test_processed = df_encoded.iloc[train_len:].copy()

    # Scaling only numeric parts (after encoding); the scaler never sees test rows.
    scaler = MinMaxScaler()
    numerical_cols_encoded = df_train_processed.drop('label', axis=1).select_dtypes(include=np.number).columns.tolist()
    if numerical_cols_encoded:
        for col in numerical_cols_encoded:
            df_train_processed[col] = df_train_processed[col].astype(np.float32)
            df_test_processed[col] = df_test_processed[col].astype(np.float32)
        scaler.fit(df_train_processed[numerical_cols_encoded])
        df_train_processed.loc[:, numerical_cols_encoded] = scaler.transform(df_train_processed[numerical_cols_encoded])
        df_test_processed.loc[:, numerical_cols_encoded] = scaler.transform(df_test_processed[numerical_cols_encoded])

    feature_cols = df_train_processed.columns.drop('label')
    FEATURE_NAMES = list(feature_cols)

    # Perturbation mask: only numeric (non one-hot) columns may be perturbed.
    # get_dummies names its outputs "<categorical column>_<value>", so a
    # feature whose name starts with such a prefix is a one-hot column.
    onehot_prefixes = tuple(f"{c}_" for c in categorical_cols)
    numeric_mask = [0.0 if name.startswith(onehot_prefixes) else 1.0 for name in FEATURE_NAMES]
    PERTURB_MASK = torch.tensor(numeric_mask, dtype=torch.float32)  # [D]

    try:
        X_train = df_train_processed[feature_cols].astype(np.float32).values
        X_test = df_test_processed[feature_cols].astype(np.float32).values
    except Exception as e:
        print(f"Error converting feature columns to float32 before .values: {e}")
        # Diagnose which column is responsible before giving up.
        for col in feature_cols:
            try:
                df_train_processed[col].astype(np.float32)
            except Exception as col_e:
                print(f"  - Column '{col}' failed conversion: {col_e}, dtype: {df_train_processed[col].dtype}")
        return None, None, None, None, -1

    y_train = df_train_processed['label'].astype(np.int64).values
    y_test = df_test_processed['label'].astype(np.int64).values

    input_dim = X_train.shape[1]
    print(f"Preprocessing finished. Input dimension: {input_dim}")
    return X_train, y_train, X_test, y_test, input_dim

def create_sequences(features, labels, sequence_length):
    """Build stride-1 sliding windows over ``features`` with one label per window.

    Window ``i`` covers rows ``i .. i + sequence_length - 1`` and is labelled
    with the label of its last row.

    Args:
        features: Array-like whose first axis indexes rows.
        labels: Array-like with one label per row.
        sequence_length: Number of rows per window; must be at least 1.

    Returns:
        ``(sequences, sequence_labels)``: a float32 array of shape
        ``[N - L + 1, L, ...]`` and an int64 array of shape ``[N - L + 1]``.
        Two empty arrays are returned when there are fewer rows than
        ``sequence_length``.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be a positive integer")

    features = np.asarray(features)
    labels = np.asarray(labels)
    n_rows = len(features)
    if n_rows < sequence_length:
        return np.array([]), np.array([])

    # sliding_window_view appends the window axis last ([n, ..., L]); move it
    # next to the batch axis so each window reads [L, ...] like a slice would.
    windows = np.lib.stride_tricks.sliding_window_view(features, sequence_length, axis=0)
    sequences = np.ascontiguousarray(np.moveaxis(windows, -1, 1), dtype=np.float32)
    sequence_labels = labels[sequence_length - 1:n_rows].astype(np.int64)
    return sequences, sequence_labels

# =========================
# Positional Encoding
# =========================
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings to a batch-first sequence.

    Even feature indices carry ``sin(pos * w_k)`` and odd indices
    ``cos(pos * w_k)`` with ``w_k = 10000^(-2k / d_model)``. The table is
    stored as the buffer ``pe`` of shape ``[1, max_len, d_model]``.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd slot than even slot, so trim
        # the frequency axis; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1), :]

# =========================
# Transformer with Attention Output
# =========================
class TransformerEncoderLayerWithAttn(nn.Module):
    """Transformer encoder layer that can hand back its per-head attention weights.

    Layout: self-attention, residual add and LayerNorm, then a two-layer GELU
    feed-forward block, residual add and LayerNorm. Inputs are batch-first.

    Args:
        d_model: Embedding width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability used in attention and in the layer.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(in_features=d_model, out_features=dim_feedforward)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(in_features=dim_feedforward, out_features=d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.activation = nn.GELU()

    def forward(self, src, need_weights=False):
        """Run the layer.

        Returns:
            ``(output, attn)`` where ``attn`` is the un-averaged attention of
            shape ``[B, H, S, S]`` when ``need_weights`` is True, else None.
        """
        attended, weights = self.self_attn(
            query=src,
            key=src,
            value=src,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        hidden = self.norm1(src + self.dropout1(attended))

        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        result = self.norm2(hidden + self.dropout2(projected))

        return result, (weights if need_weights else None)

class TransformerEncoderWithAttn(nn.Module):
    """Stack of encoder layers that optionally collects each layer's attention.

    The supplied ``layer`` becomes the first element of the stack. Every
    further layer is a freshly constructed ``TransformerEncoderLayerWithAttn``
    with the same hyperparameters, read back from ``layer``.

    Args:
        layer: The first encoder layer.
        num_layers: Total number of layers in the stack.
    """

    def __init__(self, layer, num_layers):
        super().__init__()
        stack = []
        for idx in range(num_layers):
            if idx == 0:
                # Reuse the caller's instance as the bottom layer.
                stack.append(layer)
                continue
            stack.append(
                TransformerEncoderLayerWithAttn(
                    d_model=layer.self_attn.embed_dim,
                    nhead=layer.self_attn.num_heads,
                    dim_feedforward=layer.linear1.out_features,
                    dropout=layer.dropout.p,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, src, need_weights=False):
        """Apply the layers in order.

        Returns:
            ``(output, attentions)``; ``attentions`` holds one entry per layer
            when ``need_weights`` is True and is an empty list otherwise.
        """
        collected = []
        hidden = src
        for block in self.layers:
            hidden, weights = block(hidden, need_weights=need_weights)
            if need_weights:
                collected.append(weights)
        return hidden, collected

# =========================
# Main Model: Transformer -> BiLSTM -> FC
# =========================
class TransformerBiLSTMModel(nn.Module):
    """Sequence classifier: linear projection, Transformer encoder, BiLSTM, linear head.

    Args:
        input_dim: Number of features per time step.
        lstm_hidden_dim: Hidden size of each LSTM direction.
        lstm_layers: Number of stacked LSTM layers.
        transformer_dim: Width of the Transformer embedding.
        nhead: Number of attention heads.
        num_transformer_layers: Number of encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout probability shared by all parts of the model.
    """

    def __init__(self, input_dim, lstm_hidden_dim, lstm_layers, transformer_dim, nhead, num_transformer_layers, num_classes, dropout):
        super().__init__()
        self.transformer_dim = transformer_dim
        self.input_proj = nn.Linear(input_dim, transformer_dim)
        self.pos_encoder = SinusoidalPositionalEncoding(transformer_dim)

        first_layer = TransformerEncoderLayerWithAttn(
            d_model=transformer_dim,
            nhead=nhead,
            dim_feedforward=transformer_dim * 4,
            dropout=dropout,
        )
        self.transformer_encoder = TransformerEncoderWithAttn(first_layer, num_layers=num_transformer_layers)

        # Inter-layer LSTM dropout is only requested when there is more than one layer.
        inter_layer_dropout = dropout if lstm_layers > 1 else 0
        self.bilstm = nn.LSTM(
            transformer_dim,
            lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=True,
        )
        self.dropout_layer = nn.Dropout(dropout)
        self.classifier = nn.Linear(lstm_hidden_dim * 2, num_classes)

    def forward(self, x, return_attn=False):
        """Compute class logits for a batch of sequences.

        Args:
            x: Batch-first input whose last dimension is ``input_dim``.
            return_attn: When True, also return the attention weights of the
                last encoder layer (None if the encoder has no layers).

        Returns:
            ``logits``, or ``(logits, last_attn)`` when ``return_attn`` is True.
        """
        embedded = self.input_proj(x) * math.sqrt(self.transformer_dim)
        embedded = self.pos_encoder(embedded)
        encoded, attn_list = self.transformer_encoder(embedded, need_weights=return_attn)

        _, (h_n, _) = self.bilstm(encoded)
        # The last two entries of h_n are the final hidden states of the top
        # layer: forward direction first, then backward.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        logits = self.classifier(self.dropout_layer(summary))

        if not return_attn:
            return logits
        last_attn = attn_list[-1] if len(attn_list) > 0 else None
        return logits, last_attn

# =========================
# Evaluation
# =========================
def evaluate_model(model, data_loader, criterion, device, is_test_set=False):
    """Evaluate ``model`` on ``data_loader`` and return loss, accuracy and binary F1.

    Args:
        model: Module mapping a batch of inputs to class logits.
        data_loader: Loader yielding ``(inputs, labels)`` batches; may be None.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        is_test_set: When True, a report (metrics, confusion matrix,
            classification report) is printed.

    Returns:
        ``(avg_loss, accuracy, f1)``. For a missing or empty loader the result
        is three NaNs, or three zeros plus a warning when ``is_test_set`` is
        True.
    """
    model.eval()
    if not data_loader or len(data_loader.dataset) == 0:
        if not is_test_set:
            return np.nan, np.nan, np.nan
        print("Warning: Cannot evaluate on empty or invalid dataloader.")
        return 0.0, 0.0, 0.0

    total_loss = 0.0
    correct_predictions = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for batch_data, batch_labels in data_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss.item() * batch_data.size(0)
            predicted = outputs.argmax(dim=1)
            correct_predictions += (predicted == batch_labels).sum().item()
            all_labels.extend(batch_labels.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    n_samples = len(data_loader.dataset)
    avg_loss = total_loss / n_samples
    accuracy = correct_predictions / n_samples
    f1 = 0.0
    have_results = len(all_labels) > 0 and len(all_predictions) > 0
    if have_results:
        f1 = f1_score(all_labels, all_predictions, average='binary', zero_division=0)

    if is_test_set:
        print("\n--- Final Test Set Evaluation ---")
        print(f"Test Loss: {avg_loss:.4f}")
        print(f"Test Accuracy: {accuracy:.4f}")
        print(f"Test F1 Score (Binary): {f1:.4f}")
        if have_results:
            target_names = ['Normal (0)', 'Attack (1)']
            # Pin the class ids. Otherwise, when only one class occurs in both
            # labels and predictions, the report raises because the inferred
            # class count no longer matches target_names.
            class_ids = list(range(len(target_names)))
            print("\nConfusion Matrix:")
            print(confusion_matrix(all_labels, all_predictions, labels=class_ids))
            print("\nClassification Report:")
            print(classification_report(all_labels, all_predictions, labels=class_ids, target_names=target_names, zero_division=0))
        else:
            print("\nNo labels/predictions available for Confusion Matrix/Classification Report.")
        print("--------------------------------")
    return avg_loss, accuracy, f1

# =========================
# Adversarial Attacks (FGSM/PGD) — keep TRAIN mode, disable Dropout; mask categorical one-hot dims
# =========================
@contextmanager
def train_no_dropout(model: nn.Module):
    """Run ``model`` in train mode with every source of dropout switched off.

    Train mode is kept so the cuDNN LSTM backward pass is available, while
    randomness is removed from three places:

    * ``nn.Dropout`` modules are put into eval mode,
    * ``nn.LSTM.dropout`` (inter-layer dropout) is set to 0,
    * ``nn.MultiheadAttention.dropout`` is set to 0. This attention-weight
      dropout is a plain float applied whenever the module is in train mode,
      so toggling ``nn.Dropout`` modules alone does not disable it.

    On exit the previous train/eval mode of the model and all saved dropout
    settings are restored, even if the body raised.
    """
    was_training = model.training
    dropout_states = []
    lstm_dropouts = []
    attn_dropouts = []
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            dropout_states.append((m, m.training))
        elif isinstance(m, nn.LSTM):
            lstm_dropouts.append((m, m.dropout))
        elif isinstance(m, nn.MultiheadAttention):
            attn_dropouts.append((m, m.dropout))
    try:
        model.train()
        for m, _ in dropout_states:
            m.train(False)
        for lstm, _ in lstm_dropouts:
            lstm.dropout = 0.0
        for attn, _ in attn_dropouts:
            attn.dropout = 0.0
        yield
    finally:
        if not was_training:
            model.eval()
        # Restore after the mode switch above so the saved per-module flags win.
        for m, was_tr in dropout_states:
            m.train(was_tr)
        for lstm, d in lstm_dropouts:
            lstm.dropout = d
        for attn, d in attn_dropouts:
            attn.dropout = d

def _mask_grad(grad):
    """Zero the gradient on features that must not be perturbed.

    ``PERTURB_MASK`` (shape [D]) is broadcast over the two leading axes of
    ``grad``. While the mask is still unset, ``grad`` is returned untouched.
    """
    if PERTURB_MASK is not None:
        keep = PERTURB_MASK.to(grad.device).view(1, 1, -1)  # [1,1,D]
        grad = grad * keep
    return grad

def _mask_step(step):
    """Zero an update step on features that must not be perturbed.

    ``PERTURB_MASK`` (shape [D]) is broadcast over the two leading axes of
    ``step``. While the mask is still unset, ``step`` is returned untouched.
    """
    if PERTURB_MASK is not None:
        keep = PERTURB_MASK.to(step.device).view(1, 1, -1)
        step = step * keep
    return step

def fgsm_attack(model, x, y, loss_fn, eps, clamp=(0.0, 1.0)):
    """Craft single-step FGSM adversarial examples.

    The input is moved by ``eps`` along the sign of the masked loss gradient,
    then clipped to ``clamp``. Gradients are taken with dropout disabled.

    Args:
        model: Classifier producing logits.
        x: Clean inputs.
        y: Target labels passed to ``loss_fn``.
        loss_fn: Loss called as ``loss_fn(logits, y)``.
        eps: Step size of the perturbation.
        clamp: ``(min, max)`` range the result is clipped to.

    Returns:
        Detached adversarial inputs with the same shape as ``x``.
    """
    lower, upper = clamp[0], clamp[1]
    with train_no_dropout(model):
        x_in = x.detach().clone().requires_grad_(True)
        model.zero_grad(set_to_none=True)
        loss = loss_fn(model(x_in), y)
        (grad,) = torch.autograd.grad(loss, x_in, retain_graph=False, create_graph=False)
        # A masked gradient of zero has sign zero, so masked features do not move.
        direction = _mask_grad(grad).sign()
        perturbed = torch.clamp(x_in + eps * direction, lower, upper).detach()
    return perturbed

def pgd_attack(model, x, y, loss_fn, eps, alpha, steps, clamp=(0.0, 1.0)):
    """Craft PGD adversarial examples inside an L-infinity ball of radius ``eps``.

    Starts from a random point in the ball, then repeats ``steps`` times:
    move by ``alpha`` along the sign of the masked loss gradient, project back
    into the ball around ``x`` and clip to ``clamp``. Gradients are taken with
    dropout disabled.

    The random start is masked in the same way as the gradient steps.
    Without that, the initial noise would perturb the one-hot categorical
    features, and since every later step is masked the noise would remain in
    the final result.

    Args:
        model: Classifier producing logits.
        x: Clean inputs.
        y: Target labels passed to ``loss_fn``.
        loss_fn: Loss called as ``loss_fn(logits, y)``.
        eps: Radius of the perturbation ball.
        alpha: Step size per iteration.
        steps: Number of iterations.
        clamp: ``(min, max)`` range the result is clipped to.

    Returns:
        Detached adversarial inputs with the same shape as ``x``.
    """
    lower, upper = clamp[0], clamp[1]
    with train_no_dropout(model):
        x_orig = x.detach()
        start_noise = _mask_step(torch.empty_like(x_orig).uniform_(-eps, eps))
        x_adv = torch.clamp(x_orig + start_noise, lower, upper).detach()
        for _ in range(steps):
            x_adv = x_adv.clone().detach().requires_grad_(True)
            model.zero_grad(set_to_none=True)
            logits = model(x_adv)
            loss = loss_fn(logits, y)
            grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
            grad = _mask_grad(grad)
            x_adv = x_adv + _mask_step(alpha * grad.sign())
            # Project onto the eps-ball around the clean input, then onto the valid range.
            eta = torch.clamp(x_adv - x_orig, min=-eps, max=eps)
            x_adv = torch.clamp(x_orig + eta, lower, upper).detach()
    return x_adv

# =========================
# Training (with optional adversarial) + Checkpointing/Resume
# =========================
def _build_checkpoint_state(epoch, model, optimizer, best_val_loss, epochs_no_improve, history):
    """Assemble the dict that is persisted after epoch ``epoch`` has finished."""
    return {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'epochs_no_improve': epochs_no_improve,
        'history': history,
        'hparams': {
            'LEARNING_RATE': LEARNING_RATE,
            'BATCH_SIZE': BATCH_SIZE,
            'ADV_TRAINING': ADV_TRAINING,
            'ADV_METHOD': ADV_METHOD,
            'ADV_EPS': ADV_EPS,
            'ADV_ALPHA': ADV_ALPHA,
            'ADV_STEPS': ADV_STEPS,
            'ADV_LAMBDA': ADV_LAMBDA,
        },
    }


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, patience):
    """Train ``model`` with optional adversarial training, checkpointing and early stopping.

    Training resumes from the latest checkpoint when one is available. With
    ``ADV_TRAINING`` enabled each batch is optimised on
    ``(1 - ADV_LAMBDA) * clean_loss + ADV_LAMBDA * adversarial_loss``, where
    the adversarial batch is produced by FGSM or PGD according to
    ``ADV_METHOD``. A checkpoint is written after every epoch. After training,
    the weights of the best checkpoint are loaded if validation was used.

    Args:
        model: Network to train; updated in place.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Validation loader, or None to train without validation.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Total number of epochs, counting epochs done before a resume.
        device: Device the batches are moved to.
        patience: Epochs without validation-loss improvement before stopping.

    Returns:
        History dict with the keys ``train_loss``, ``val_loss``,
        ``val_accuracy`` and ``val_f1_score``, one entry per epoch.
    """
    # Try to resume from a previous run.
    start_epoch, best_val_loss, epochs_no_improve, history = try_load_checkpoint(model, optimizer)

    print(f"\n--- Starting Training on {device} ---")
    total_start_time = time.time()

    # A resumed run whose patience was already exhausted must not train
    # further; otherwise every re-execution would add one more epoch.
    if val_loader and start_epoch > 0 and epochs_no_improve >= patience:
        print(f"[Resume] Early stopping already triggered in a previous run "
              f"({epochs_no_improve} epochs without improvement); skipping further training.")
        start_epoch = num_epochs

    use_fgsm = ADV_METHOD.lower() == 'fgsm'

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for batch_data, batch_labels in train_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            optimizer.zero_grad(set_to_none=True)

            if ADV_TRAINING:
                if use_fgsm:
                    x_adv = fgsm_attack(model, batch_data, batch_labels, criterion, eps=ADV_EPS, clamp=(0.0, 1.0))
                else:
                    x_adv = pgd_attack(model, batch_data, batch_labels, criterion, eps=ADV_EPS, alpha=ADV_ALPHA, steps=ADV_STEPS, clamp=(0.0, 1.0))

                logits_clean = model(batch_data)
                logits_adv = model(x_adv)
                loss = (1.0 - ADV_LAMBDA) * criterion(logits_clean, batch_labels) + ADV_LAMBDA * criterion(logits_adv, batch_labels)
            else:
                logits_clean = model(batch_data)
                loss = criterion(logits_clean, batch_labels)

            loss.backward()
            optimizer.step()
            running_loss += loss.item() * batch_data.size(0)
            samples_seen += batch_data.size(0)

        # Average over the samples actually processed. The loader may drop the
        # last partial batch, so the dataset size would underestimate the loss.
        train_loss = running_loss / samples_seen if samples_seen else 0.0

        # Validation
        if val_loader:
            val_loss, val_accuracy, val_f1_score = evaluate_model(model, val_loader, criterion, device)
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)
            history['val_f1_score'].append(val_f1_score)
            epoch_duration = time.time() - epoch_start_time
            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val F1: {val_f1_score:.4f} | Time: {epoch_duration:.2f}s")

            improved = val_loss < best_val_loss
            if improved:
                best_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            # Save checkpoint (latest, plus best on improvement).
            state = _build_checkpoint_state(epoch, model, optimizer, best_val_loss, epochs_no_improve, history)
            save_checkpoint(state, is_best=improved)

            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after epoch {epoch+1}!")
                break

        else:
            # No validation
            history['train_loss'].append(train_loss)
            history['val_loss'].append(np.nan)
            history['val_accuracy'].append(np.nan)
            history['val_f1_score'].append(np.nan)
            epoch_duration = time.time() - epoch_start_time
            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Time: {epoch_duration:.2f}s")

            # Save the latest checkpoint only; there is no validation metric to rank by.
            state = _build_checkpoint_state(epoch, model, optimizer, best_val_loss, epochs_no_improve, history)
            save_checkpoint(state, is_best=False)

    total_training_time = time.time() - total_start_time
    print(f"--- Training Finished --- Total time: {total_training_time // 60:.0f}m {total_training_time % 60:.0f}s")

    # Load best (if exists)
    if val_loader and os.path.exists(CKPT_BEST):
        try:
            ckpt = torch.load(CKPT_BEST, map_location=device)
            model.load_state_dict(ckpt['model_state'])
            print(f"Loaded best model weights from '{CKPT_BEST}' for final evaluation.")
        except Exception as e:
            print(f"Warning: Could not load best model weights ({e}). Using last epoch model.")

    return history

# =========================
# Integrated Gradients (IG)
# =========================
def integrated_gradients(model, x, y, steps=50, baseline='zeros', clamp=(0.0, 1.0)):
    """Approximate Integrated Gradients attributions for the target-class logit.

    Gradients of the logit selected by ``y`` are averaged over ``steps``
    points on the straight line from the baseline to ``x`` (end point
    included) and multiplied by ``x - baseline``.

    Args:
        model: Classifier producing logits of shape [B, C].
        x: Inputs of shape [B, T, D].
        y: Target class indices of shape [B].
        steps: Number of interpolation points.
        baseline: ``'random'`` for a uniform random baseline; any other value
            selects an all-zeros baseline.
        clamp: Accepted for interface compatibility; not used here.

    Returns:
        Attributions with the same shape as ``x``.
    """
    x = x.detach()
    reference = torch.rand_like(x) if baseline == 'random' else torch.zeros_like(x)
    span = x - reference
    grad_sum = torch.zeros_like(x)

    with train_no_dropout(model):
        for k in range(1, steps + 1):
            # Build each interpolant when it is needed instead of holding all of them at once.
            point = (reference + (float(k) / steps) * span).detach().requires_grad_(True)
            model.zero_grad(set_to_none=True)
            scores = model(point)
            target_score = scores.gather(1, y.view(-1, 1)).sum()
            (grad,) = torch.autograd.grad(target_score, point, retain_graph=False, create_graph=False)
            grad_sum += grad.detach()

    mean_grad = grad_sum / steps
    return span * mean_grad

def summarize_attributions(attr: torch.Tensor):
    """Collapse per-element attributions into per-time-step and per-feature scores.

    Args:
        attr: Attributions of shape [B, T, D].

    Returns:
        ``(time_importance, feature_importance)`` with shapes [T] and [D]:
        absolute attributions summed over the other axis, then averaged over
        the batch.
    """
    magnitude = torch.abs(attr)
    per_step = torch.sum(magnitude, dim=2)      # [B, T]: total over features
    per_feature = torch.sum(magnitude, dim=1)   # [B, D]: total over time
    return torch.mean(per_step, dim=0), torch.mean(per_feature, dim=0)

# =========================
# Robust Evaluation (optional) — FIXED
# =========================
def robust_eval(model, loader, criterion, eps=0.03, steps=7, alpha=0.007, method='pgd'):
    """Report clean and adversarial accuracy of ``model`` on ``loader``.

    Predictions run under ``no_grad``; the adversarial examples are crafted
    under ``enable_grad`` because the attacks need input gradients.

    Args:
        model: Classifier producing logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss used by the attack.
        eps: Perturbation budget.
        steps: Number of PGD iterations (ignored for FGSM).
        alpha: PGD step size (ignored for FGSM).
        method: ``'fgsm'`` (case-insensitive) for FGSM; anything else uses PGD.

    Returns:
        ``(clean_accuracy, adversarial_accuracy)``. Both are NaN, and a
        warning is printed, when the loader yields no samples.
    """
    model.eval()
    use_fgsm = method.lower() == 'fgsm'
    tot, cor_clean, cor_adv = 0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # Clean accuracy (no gradients needed).
        with torch.no_grad():
            pred = model(xb).argmax(1)
            cor_clean += (pred == yb).sum().item()

        # Craft the adversarial batch (needs gradients).
        with torch.enable_grad():
            if use_fgsm:
                xa = fgsm_attack(model, xb, yb, criterion, eps, (0.0, 1.0))
            else:
                xa = pgd_attack(model, xb, yb, criterion, eps, alpha, steps, (0.0, 1.0))

        # Accuracy on the adversarial batch (no gradients needed).
        with torch.no_grad():
            pred_a = model(xa).argmax(1)
            cor_adv += (pred_a == yb).sum().item()

        tot += xb.size(0)

    if tot == 0:
        # Avoid dividing by zero on an empty loader.
        print("[RobustEval] Warning: loader yielded no samples; nothing to evaluate.")
        return float('nan'), float('nan')

    clean_acc = cor_clean / tot
    adv_acc = cor_adv / tot
    print(f"[RobustEval] Clean Acc: {clean_acc:.4f} | Adv Acc ({method}, eps={eps}): {adv_acc:.4f}")
    return clean_acc, adv_acc

# =========================
# Pipeline
# =========================
# Fetch the dataset through kagglehub. On failure dataset_path is set to None,
# which makes the pipeline below print an abort message instead of running.
print("Downloading UNSW-NB15 dataset using kagglehub...")
try:
    dataset_path = kagglehub.dataset_download("dhoogla/unswnb15")
    print(f"Dataset downloaded to: {dataset_path}")
except Exception as e:
    print(f"Error downloading dataset: {e}")
    print("Please ensure Kaggle API credentials are set up correctly in your environment.")
    dataset_path = None

# End-to-end run: preprocess -> windowing -> split -> loaders -> model -> train
# -> clean and robust evaluation -> interpretability artefacts.
if dataset_path and os.path.isdir(dataset_path):
    X_train_raw, y_train_raw, X_test_raw, y_test_raw, input_dim = load_and_preprocess_unsw(dataset_path)

    if X_train_raw is not None and input_dim != -1:
        INPUT_DIM = input_dim
        print("Creating sequences...")
        X_train_seq, y_train_seq = create_sequences(X_train_raw, y_train_raw, SEQUENCE_LENGTH)
        X_test_seq, y_test_seq = create_sequences(X_test_raw, y_test_raw, SEQUENCE_LENGTH)

        if len(X_train_seq) == 0 or len(X_test_seq) == 0:
            print("Error: Not enough data to create sequences with the specified length.")
        else:
            print(f"Raw training sequences: {len(X_train_seq)}")
            print(f"Raw testing sequences: {len(X_test_seq)}")

            # Train/Val split
            # NOTE(review): the windows are built with stride 1 and therefore overlap,
            # so a random split puts near-duplicate windows on both sides — confirm
            # that this is acceptable for the validation metric.
            if len(X_train_seq) > 1 and VALIDATION_SPLIT_FROM_TRAIN > 0:
                try:
                    X_train_final_seq, X_val_seq, y_train_final_seq, y_val_seq = train_test_split(
                        X_train_seq, y_train_seq,
                        test_size=VALIDATION_SPLIT_FROM_TRAIN,
                        random_state=RANDOM_SEED, stratify=y_train_seq
                    )
                    print(f"Splitting {VALIDATION_SPLIT_FROM_TRAIN*100:.1f}% of training sequences for validation.")
                except ValueError as e:
                    # Fall back to a plain random split when stratification is rejected.
                    print(f"Warning: Could not stratify split ({e}). Performing non-stratified split.")
                    X_train_final_seq, X_val_seq, y_train_final_seq, y_val_seq = train_test_split(
                        X_train_seq, y_train_seq,
                        test_size=VALIDATION_SPLIT_FROM_TRAIN,
                        random_state=RANDOM_SEED
                    )
                print(f"Final Train sequences: {len(X_train_final_seq)}")
                print(f"Validation sequences: {len(X_val_seq)}")
                print(f"Final Test sequences: {len(X_test_seq)}")
            else:
                print("Using all training sequences for training, no validation split performed.")
                X_train_final_seq, y_train_final_seq = X_train_seq, y_train_seq
                X_val_seq, y_val_seq = np.array([]), np.array([])
                print(f"Final Train sequences: {len(X_train_final_seq)}")
                print("Validation sequences: 0")
                print(f"Final Test sequences: {len(X_test_seq)}")

            if len(X_train_final_seq) > 0:
                X_train_tensor = torch.from_numpy(X_train_final_seq)
                y_train_tensor = torch.from_numpy(y_train_final_seq)
                X_val_tensor = torch.from_numpy(X_val_seq)
                y_val_tensor = torch.from_numpy(y_val_seq)
                X_test_tensor = torch.from_numpy(X_test_seq)
                y_test_tensor = torch.from_numpy(y_test_seq)

                # Worker processes and pinned memory are only enabled when running on a GPU.
                num_workers = 2 if DEVICE.type == 'cuda' else 0
                pin_memory_flag = True if DEVICE.type == 'cuda' else False

                train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
                val_dataset = TensorDataset(X_val_tensor, y_val_tensor) if len(X_val_seq) > 0 else None
                test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

                # Only the training loader shuffles and drops the last partial batch.
                train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory_flag, drop_last=True)
                val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory_flag) if val_dataset else None
                test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory_flag)

                model = TransformerBiLSTMModel(INPUT_DIM, LSTM_HIDDEN_DIM, LSTM_LAYERS,
                                               TRANSFORMER_DIM, NHEAD, NUM_TRANSFORMER_LAYERS,
                                               NUM_CLASSES, DROPOUT).to(DEVICE)

                # Class weights (Normalized: N / (K * count_c))
                # Weights are only used when every expected class occurs at least once;
                # otherwise the loss is left unweighted.
                if len(y_train_final_seq) > 0:
                    counts = np.bincount(y_train_final_seq)
                    N, K = counts.sum(), len(counts)
                    if len(counts) == NUM_CLASSES and 0 not in counts:
                        weights = torch.tensor([N / (K * c) for c in counts], dtype=torch.float32).to(DEVICE)
                        criterion = nn.CrossEntropyLoss(weight=weights)
                        print(f"Using normalized class weights for CrossEntropyLoss: {weights.cpu().numpy()}")
                    else:
                        criterion = nn.CrossEntropyLoss()
                else:
                    criterion = nn.CrossEntropyLoss()

                optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

                # Train (+ Auto-Resume)
                history = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, DEVICE, EARLY_STOPPING_PATIENCE)

                # Final clean test evaluation
                evaluate_model(model, test_loader, criterion, DEVICE, is_test_set=True)

                # Optional robust evaluation
                robust_eval(model, test_loader, criterion, eps=0.01, steps=7, alpha=0.003, method='pgd')
                robust_eval(model, test_loader, criterion, eps=0.03, steps=7, alpha=0.007, method='pgd')

                # =========================
                # Interpretability Demo (small batch)
                # =========================
                if SAVE_INTERPRETABILITY:
                    model.eval()  # eval mode for stable attention maps

                    # Take a small sample (the first batch) from the test set.
                    Xb, yb = next(iter(test_loader))
                    Xb = Xb.to(DEVICE)
                    yb = yb.to(DEVICE)

                    # 1) Attention maps from the last Transformer layer
                    with torch.no_grad():
                        logits, attn = model(Xb[:1], return_attn=True)  # a single sample only, for display
                    if attn is not None:
                        np.save(os.path.join(CKPT_DIR, 'attn_weights.npy'), attn[0].cpu().numpy())
                        print(f"Saved attention weights to {os.path.join(CKPT_DIR, 'attn_weights.npy')} (shape: H x T x T).")

                    # 2) Integrated Gradients for N samples
                    N_IG = min(32, Xb.size(0))
                    attr = integrated_gradients(model, Xb[:N_IG], yb[:N_IG], steps=IG_STEPS, baseline=IG_BASELINE, clamp=(0.0, 1.0))
                    np.save(os.path.join(CKPT_DIR, 'ig_attributions.npy'), attr.detach().cpu().numpy())
                    print(f"Saved Integrated Gradients attributions to {os.path.join(CKPT_DIR, 'ig_attributions.npy')} (shape: B x T x D).")

                    # Summaries
                    t_imp, f_imp = summarize_attributions(attr.detach())
                    t_imp_np = t_imp.cpu().numpy()
                    f_imp_np = f_imp.cpu().numpy()
                    np.save(os.path.join(CKPT_DIR, 'ig_time_importance.npy'), t_imp_np)
                    np.save(os.path.join(CKPT_DIR, 'ig_feature_importance.npy'), f_imp_np)
                    print("Saved IG time and feature importance.")

                    # CSV of feature importances together with the feature names
                    if FEATURE_NAMES is not None and len(FEATURE_NAMES) == f_imp_np.shape[0]:
                        df_feat = pd.DataFrame({
                            'feature': FEATURE_NAMES,
                            'ig_importance': f_imp_np
                        }).sort_values('ig_importance', ascending=False)
                        df_feat.to_csv(os.path.join(CKPT_DIR, 'ig_feature_importance.csv'), index=False)
                        print(f"Saved IG feature importance CSV to {os.path.join(CKPT_DIR, 'ig_feature_importance.csv')}.")

                print("\nExecution finished successfully.")
            else:
                print("\nExecution aborted: No training data available after sequencing/splitting.")
    else:
        print("\nExecution aborted due to data loading/preprocessing errors.")
else:
    print("\nExecution aborted: Dataset download failed or directory not found.")


Downloading UNSW-NB15 dataset using kagglehub...
Using Colab cache for faster access to the 'unswnb15' dataset.
Dataset downloaded to: /kaggle/input/unswnb15
Attempting to load training data from: /kaggle/input/unswnb15/UNSW_NB15_training-set.parquet
Attempting to load testing data from: /kaggle/input/unswnb15/UNSW_NB15_testing-set.parquet
Datasets loaded successfully.
Preprocessing datasets...
Preprocessing finished. Input dimension: 188
Creating sequences...
Raw training sequences: 175327
Raw testing sequences: 82318
Splitting 15.0% of training sequences for validation.
Final Train sequences: 149027
Validation sequences: 26300
Final Test sequences: 82318
Using normalized class weights for CrossEntropyLoss: [1.5658044 0.7345646]
[Resume] Loaded checkpoint from './unsw_bilstm_trf_adv_ig_checkpoints/ckpt_latest.pt' at epoch 24

--- Starting Training on cuda ---
Epoch 25/30 | Train Loss: 0.0124 | Val Loss: 0.0609 | Val Acc: 0.9846 | Val F1: 0.9887 | Time: 87.28s
Early stopping triggered 