# Malware Classification Training Notebook

This notebook provides an interactive way to train, validate, and test malware classification models.

**For Google Colab:**
1. Clone your repository
2. Install dependencies
3. Run the cells sequentially

## 1. Setup and Imports

In [None]:
# For Google Colab: clone the repository and change into it.
# Replace YOUR_USERNAME/YOUR_REPO with the real repository first; skip (or
# comment out) this cell when running from an existing local checkout.
!git clone https://github.com/YOUR_USERNAME/YOUR_REPO.git
%cd YOUR_REPO

In [None]:
# Install dependencies
!pip install -r requirements.txt

In [None]:
import torch
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from omegaconf import OmegaConf

# Import project modules
from src.data import MalwareDataModule
from src.models import MLP
from src.training import Trainer
from src.utils import setup_logger, set_seed, get_optimizer, get_criterion, get_device, MetricsTracker

# Report the runtime environment so GPU availability is visible up front
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    # Name of the first visible CUDA device (index 0)
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Configuration dictionary: every tunable value of the run lives here
config = {
    'experiment_name': 'malware_classification',
    'seed': 42,
    # Use the GPU when one is visible and fall back to CPU otherwise, so the
    # notebook runs top-to-bottom without manual edits. Replace with an
    # explicit 'cpu' / 'cuda' string to force a device.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Data configuration
    'data_path': 'data/cybersequrity.csv',
    'target_column': 'Class',
    'batch_size': 64,
    'val_split': 0.20,
    'test_split': 0.20,
    'num_workers': 2,

    # Model configuration
    'hidden_dims': [128, 64, 32],
    'dropout': 0.0,
    'activation': 'relu',
    'batch_norm': False,
    'binary_classification': True,

    # Training configuration
    'max_epochs': 50,
    'early_stopping_patience': 10,
    'optimizer': 'adam',
    'lr': 0.001,
    'weight_decay': 0.0001,
    'criterion': 'bce_with_logits',

    # Output paths; `${experiment_name}` is an OmegaConf interpolation that
    # is resolved against the 'experiment_name' key above when accessed.
    'output_dir': 'outputs/${experiment_name}',
    'checkpoint_dir': 'outputs/${experiment_name}/checkpoints',
    'tensorboard_dir': 'outputs/${experiment_name}/tensorboard',
}

# Wrap in OmegaConf for attribute-style access (cfg.lr, cfg.seed, ...)
cfg = OmegaConf.create(config)
print("Configuration:")
print(OmegaConf.to_yaml(cfg))

## 3. Set Random Seed

In [None]:
# Apply the configured seed via the project helper before any data or model
# work. NOTE(review): which RNGs set_seed covers is defined in src.utils.
seed = cfg.seed
set_seed(seed)
print(f"Random seed set to: {seed}")

## 4. Load and Prepare Data

In [None]:
# Resolve the compute device requested in the configuration
device = get_device(cfg.device)
print(f"Using device: {device}")

# Build the data module from the data-related configuration entries
print("\nLoading data...")
data_module_kwargs = dict(
    data_path=cfg.data_path,
    target_column=cfg.target_column,
    batch_size=cfg.batch_size,
    val_split=cfg.val_split,
    test_split=cfg.test_split,
    num_workers=cfg.num_workers,
    random_seed=cfg.seed,
)
data_module = MalwareDataModule(**data_module_kwargs)
data_module.setup()

# Ask the prepared data module for the sizes the model is built from
input_dim, num_classes = data_module.get_feature_dim(), data_module.get_num_classes()
print(f"\nInput dimension: {input_dim}")
print(f"Number of classes: {num_classes}")

## 5. Create Model

In [None]:
# Instantiate the MLP from the data dimensions and the model configuration
print("Creating model...")
model_kwargs = dict(
    input_dim=input_dim,
    hidden_dims=cfg.hidden_dims,
    num_classes=num_classes,
    dropout=cfg.dropout,
    activation=cfg.activation,
    batch_norm=cfg.batch_norm,
    binary_classification=cfg.binary_classification,
)
model = MLP(**model_kwargs)

# Show the module layout followed by the value reported by get_num_params()
print("\nModel architecture:")
print(model)
n_params = model.get_num_params()
print(f"\nTotal parameters: {n_params:,}")

## 6. Setup Training

In [None]:
# Bundle the optimizer settings into their own config node for get_optimizer
optimizer_cfg = OmegaConf.create(
    dict(name=cfg.optimizer, lr=cfg.lr, weight_decay=cfg.weight_decay)
)
optimizer = get_optimizer(model, optimizer_cfg)

# The loss function is looked up by its configured name
criterion = get_criterion(cfg.criterion)

# Echo the resulting training setup
training_summary = (
    ("Optimizer", type(optimizer).__name__),
    ("Learning rate", cfg.lr),
    ("Weight decay", cfg.weight_decay),
    ("Loss function", type(criterion).__name__),
)
for label, value in training_summary:
    print(f"{label}: {value}")

## 7. Create Trainer

In [None]:
# Logger handed to the trainer
logger = setup_logger()

# Output locations come from the configuration and are passed as Path objects
checkpoint_dir = Path(cfg.checkpoint_dir)
tensorboard_dir = Path(cfg.tensorboard_dir)

# Wire model, optimizer, loss, device and run limits into the trainer
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    max_epochs=cfg.max_epochs,
    checkpoint_dir=checkpoint_dir,
    tensorboard_dir=tensorboard_dir,
    early_stopping_patience=cfg.early_stopping_patience,
    logger=logger,
)

print("Trainer created successfully!")

## 8. Train Model

In [None]:
# Fetch the train/validation loaders from the data module
train_loader, val_loader = data_module.train_dataloader(), data_module.val_dataloader()

# Run the training loop (limits such as max_epochs were given to the trainer)
print(f"\nStarting training for {cfg.max_epochs} epochs...\n")
trainer.fit(train_loader, val_loader)

print("\nTraining completed!")

## 9. Plot Training History

In [None]:
# Plot training curves: one panel per metric, train vs. validation
metrics_to_plot = ['loss', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc']

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Training and Validation Metrics', fontsize=16, fontweight='bold')

# (legend label, marker, per-epoch history) for each curve drawn in a panel
history_series = (
    ('Train', 'o', trainer.train_history),
    ('Val', 's', trainer.val_history),
)

# axes.flat walks the 2x3 grid row by row, matching the metric order
for ax, metric in zip(axes.flat, metrics_to_plot):
    for label, marker, history in history_series:
        # Metrics missing from an epoch's record are plotted as 0
        values = [epoch_metrics.get(metric, 0) for epoch_metrics in history]
        # Each curve gets its own epoch axis: reusing the train-length range
        # for the validation curve raises a shape mismatch whenever the two
        # histories differ in length.
        epochs = range(1, len(values) + 1)
        ax.plot(epochs, values, label=label, linewidth=2, marker=marker, markersize=4)

    pretty_name = metric.replace('_', ' ').title()
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(pretty_name, fontsize=12)
    ax.set_title(pretty_name, fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print best metrics tracked by the trainer
print(f"\nBest Validation Loss: {trainer.best_val_loss:.4f}")
print(f"Best Validation F1: {trainer.best_val_f1:.4f}")

## 10. Load Best Model and Test

In [None]:
# Restore the best checkpoint when one was written; otherwise keep the
# weights currently held by the trainer
best_checkpoint = Path(cfg.checkpoint_dir) / 'best_model.pt'
if not best_checkpoint.exists():
    print("Best checkpoint not found, using current model")
else:
    print(f"Loading best model from: {best_checkpoint}")
    trainer.load_checkpoint(best_checkpoint)

# Evaluate on the held-out test loader
test_loader = data_module.test_dataloader()
print("\nTesting model...")
test_metrics = trainer.test(test_loader)

# Tabulate the test metrics between two horizontal rules
rule = "=" * 80
print("\n" + rule)
print("Test Results:")
print(rule)
for key, value in test_metrics.items():
    # The confusion matrix is not a single number, so it is left out here
    if key == 'confusion_matrix':
        continue
    print(f"{key:15s}: {value:.4f}")
print(rule)

## 11. Confusion Matrix

In [None]:
# Draw the confusion matrix as an annotated heatmap when the test metrics
# include one
if 'confusion_matrix' in test_metrics:
    cm = test_metrics['confusion_matrix']
    # Tick labels for both axes; assumes a 2x2 (binary) matrix
    class_names = ['Benign', 'Malicious']

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Count'},
        ax=ax,
    )
    ax.set_title('Confusion Matrix', fontsize=16, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    fig.tight_layout()
    plt.show()

    print("\nConfusion Matrix:")
    print(cm)

## 12. Make Predictions on Random Samples

In [None]:
# Make predictions on random test samples
import random

model.eval()

# Get random samples from test dataset
test_dataset = data_module.test_dataset
# Never ask for more samples than the test set holds: random.sample raises
# ValueError when the requested count exceeds the population size.
num_samples = min(10, len(test_dataset))

print(f"\nPredictions on {num_samples} random test samples:")
print("=" * 80)

indices = random.sample(range(len(test_dataset)), num_samples)

correct = 0
for idx in indices:
    x, y_true = test_dataset[idx]
    # Add a batch dimension and move the sample to the compute device
    x = x.unsqueeze(0).to(device)
    true_label = y_true.item()

    with torch.no_grad():
        # Assumes the model emits a single logit per sample (binary setup);
        # .item() requires a one-element tensor.
        logit = model(x)
        prob = torch.sigmoid(logit).item()
    y_pred = 1 if prob > 0.5 else 0

    is_correct = y_pred == true_label
    correct += int(is_correct)

    print(f"Sample {idx}:")
    print(f"  True:      {true_label} ({'Malicious' if true_label == 1 else 'Benign'})")
    print(f"  Predicted: {y_pred} ({'Malicious' if y_pred == 1 else 'Benign'})")
    # Confidence is the probability assigned to the predicted class
    print(f"  Confidence: {max(prob, 1-prob):.4f}")
    print(f"  Correct: {'✓' if is_correct else '✗'}")
    print()

# Guard the division so an empty test set reports 0% instead of crashing
accuracy = correct / num_samples if num_samples else 0.0
print("=" * 80)
print(f"Sample Accuracy: {accuracy:.2%} ({correct}/{num_samples})")
print("=" * 80)

## 13. Save Configuration (Optional)

In [None]:
# Write the run configuration into the output directory for reproducibility
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
config_save_path = output_dir / 'config.yaml'

with config_save_path.open('w') as f:
    OmegaConf.save(cfg, f)

print(f"Configuration saved to: {config_save_path}")

## 14. TensorBoard (Optional)

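To view TensorBoard logs, run the following in a notebook cell (from a terminal, run `tensorboard --logdir outputs/malware_classification/tensorboard` instead). The path follows `experiment_name` from the configuration cell:
```python
%load_ext tensorboard
%tensorboard --logdir outputs/malware_classification/tensorboard
```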