In [1]:
import os
os.chdir("../..")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils.dataset_loader import load_datasets
from utils.model_utils import initialize_model
from utils.train_utils import train_model
from utils.metrics import evaluate_model
from utils.visualization import plot_training, plot_confusion_matrix
import optuna

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# Load dataset
print("[INFO] Loading datasets...")
train_loader, val_loader, test_loader = load_datasets(
    data_dir="wildfire_dataset_scaled",
    batch_size=32,
    augmentation="augmented"  # Switch to "augmented" for better generalization
)
print("[INFO] Datasets loaded successfully!")

# Objective function for Optuna
def objective(trial):
    print(f"[INFO] Starting trial {trial.number}")

    # Initialize VGG16 with the last 2 layers unfrozen
    print("[DEBUG] Initializing model with last 2 layers unfrozen...")
    model = initialize_model("vgg16", num_classes=2, pretrained=True, freeze_all=False, unfreeze_last_n=2)
    model.to(device)
    print("[DEBUG] Model initialized successfully!")

    # Define hyperparameters to tune
    print("[DEBUG] Suggesting hyperparameters...")
    lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"])
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-5, 1e-3)
    print(f"[DEBUG] Suggested hyperparameters: lr={lr}, optimizer={optimizer_name}, weight_decay={weight_decay}")

    # Set optimizer
    print("[DEBUG] Setting up optimizer...")
    if optimizer_name == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_name == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    elif optimizer_name == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    print("[DEBUG] Optimizer set up successfully!")

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    # Train the model
    print("[INFO] Starting model training...")
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=None,
        num_epochs=5,  # Use fewer epochs for hyperparameter search
        device=device
    )
    print("[INFO] Training completed!")

    # Evaluate the model on validation data
    val_acc = history["val_acc"][-1]
    print(f"[INFO] Trial {trial.number} - Validation Accuracy: {val_acc:.4f}")
    return val_acc

# Create and run Optuna study
print("[INFO] Creating Optuna study...")
study = optuna.create_study(direction="maximize")
print("[INFO] Starting hyperparameter optimization...")
study.optimize(objective, n_trials=20)
print("[INFO] Hyperparameter optimization completed!")

# Retrieve best hyperparameters
best_params = study.best_params
print(f"[INFO] Best Hyperparameters: {best_params}")

# Train the final model with best hyperparameters
print("[INFO] Initializing final model with best hyperparameters...")
final_model = initialize_model("vgg16", num_classes=2, pretrained=True, freeze_all=False, unfreeze_last_n=2)
final_model.to(device)

print("[DEBUG] Setting up optimizer for final training...")
if best_params["optimizer"] == "Adam":
    optimizer = optim.Adam(final_model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
elif best_params["optimizer"] == "SGD":
    optimizer = optim.SGD(final_model.parameters(), lr=best_params["lr"], momentum=0.9, weight_decay=best_params["weight_decay"])
elif best_params["optimizer"] == "AdamW":
    optimizer = optim.AdamW(final_model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
print("[DEBUG] Optimizer setup for final model completed!")

criterion = nn.CrossEntropyLoss()

# Train the final model
print("[INFO] Starting final model training...")
history = train_model(
    model=final_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,
    num_epochs=15,  # Full training with best parameters
    device=device
)
print("[INFO] Final model training completed!")

# Evaluate and visualize results
print("[INFO] Evaluating final model...")
metrics = evaluate_model(final_model, test_loader, ["No Fire", "Fire"], device)
print("[INFO] Evaluation completed!")

print("[INFO] Saving training and evaluation results...")
plot_training(history, "outputs/tuned_training_curve.png")
plot_confusion_matrix(metrics["confusion_matrix"], ["No Fire", "Fire"], "outputs/tuned_confusion_matrix.png")
print("[INFO] Results saved successfully!")

# Visualize Optuna study results
try:
    import optuna.visualization as vis
    print("[INFO] Generating Optuna visualizations...")
    vis.plot_optimization_history(study).show()
    vis.plot_param_importances(study).show()
    print("[INFO] Optuna visualizations generated successfully!")
except ImportError:
    print("[WARNING] Optuna visualization library is not installed. Skipping visualizations.")


[I 2024-12-12 13:26:43,844] A new study created in memory with name: no-name-82ef76ec-bd0e-4e3a-9e0e-5471495518b0


[INFO] Using device: cuda
[INFO] Loading datasets...
[INFO] Datasets loaded successfully!
[INFO] Creating Optuna study...
[INFO] Starting hyperparameter optimization...
[INFO] Starting trial 0
[DEBUG] Initializing model with last 2 layers unfrozen...


  lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-5, 1e-3)


[DEBUG] Model initialized successfully!
[DEBUG] Suggesting hyperparameters...
[DEBUG] Suggested hyperparameters: lr=0.0005029584234920532, optimizer=Adam, weight_decay=0.0006944720003900249
[DEBUG] Setting up optimizer...
[DEBUG] Optimizer set up successfully!
[INFO] Starting model training...

Starting training...



Epoch [1/5] - Training:  69%|██████▉   | 41/59 [05:40<02:30,  8.37s/it, Batch Loss=0.6857]