In [1]:
import os

# Move to the project root so the `utils.*` imports and the relative data /
# output paths used below resolve. Assumes the notebook sits two directories
# below the root, i.e. the directory that contains the `utils/` package.
# Guarded so that re-running this cell (or Run All after a partial run) does
# not keep climbing the directory tree.
if not os.path.isdir("utils"):
    os.chdir("../..")

In [2]:
import torch
from utils.dataset_loader import *
from utils.model_utils import *
from utils.train_utils import *
from utils.metrics import *
from utils.visualization import *

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[INFO] Using device: {device}")

[INFO] Using device: cuda


In [4]:
# Experiment configuration. Paths are relative to the working directory set
# by the os.chdir cell at the top of the notebook.
data_dir = "wildfire_dataset_scaled"  # loader reads train/, val/ and test/ under this root
batch_size = 8
learning_rate = 0.0001  # Adam step size
num_classes = 2  # binary task: fire vs. no fire
seed = 42

# Training is stochastic (weighted random sampling of the training set and any
# randomly initialised trainable layers), so seed PyTorch's RNG before any
# data loading or model construction. torch.manual_seed seeds all devices.
torch.manual_seed(seed)

In [5]:
# Build the train/val/test DataLoaders using the "baseline" augmentation
# (the loader log reports Resize + Normalize only); the log also shows the
# training loader uses a WeightedRandomSampler for class balancing.
print("[INFO] Loading datasets...")
train_loader, val_loader, test_loader = load_datasets(
    data_dir,
    batch_size,
    augmentation="baseline",
)
print("[INFO] Datasets loaded successfully!")



2024-12-24 17:38:53,786 - INFO - Loading datasets from wildfire_dataset_scaled with augmentation type 'baseline'.
2024-12-24 17:38:53,787 - INFO - Applying baseline augmentations: Resize and Normalize.
2024-12-24 17:38:53,806 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/train
2024-12-24 17:38:53,811 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/val
2024-12-24 17:38:53,817 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/test
2024-12-24 17:38:53,818 - INFO - Datasets initialized. Preparing DataLoaders...
2024-12-24 17:38:53,819 - INFO - Using WeightedRandomSampler for class balancing.
2024-12-24 17:38:53,819 - INFO - Computing class weights from directory: wildfire_dataset_scaled/train
2024-12-24 17:38:53,824 - INFO - Class 'fire' has 730 samples.
2024-12-24 17:38:53,828 - INFO - Class 'nofire' has 1157 samples.
2024-12-24 17:38:53,830 - INFO - Computed class weights: {'fire': 1.0, 'nofire': 0.

[INFO] Loading datasets...


2024-12-24 17:38:56,196 - INFO - DataLoaders created successfully.


[INFO] Datasets loaded successfully!


In [6]:
print("\n[INFO] Starting Baseline Training for Model: ConvnextTiny\n")
print("[INFO] Initializing ConvnextTiny model...")

# Baseline setup: pretrained ConvNeXt-Tiny with freeze_all=True.
# NOTE(review): the optimizer cell keeps only parameters with requires_grad,
# so initialize_model presumably leaves a classification head trainable — confirm.
baseline_model_kwargs = {
    "model_name": "convnext_tiny",
    "num_classes": num_classes,
    "pretrained": True,
    "freeze_all": True,
}
model = initialize_model(**baseline_model_kwargs)

print("[INFO] Model initialized successfully!")


[INFO] Starting Baseline Training for Model: ConvnextTiny

[INFO] Initializing ConvnextTiny model...
[INFO] Model initialized successfully!


In [7]:
print("[INFO] Setting up loss function and optimizer...")

# Unweighted cross-entropy; class imbalance is handled by the training sampler.
criterion = torch.nn.CrossEntropyLoss()

# Optimize only the parameters left trainable by initialize_model.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

print("[INFO] Loss function and optimizer set up!")

[INFO] Setting up loss function and optimizer...
[INFO] Loss function and optimizer set up!


In [8]:
import time

# Train the frozen-backbone baseline, checkpointing on validation F1 and
# stopping early after 12 epochs without improvement.
# NOTE: the saved output of this cell ends in a KeyboardInterrupt, so `history`
# was never assigned; re-run this cell before the evaluation/plot cells.

# perf_counter is monotonic, so the measured duration is not affected by
# system clock adjustments during a long run (unlike time.time()).
start_time = time.perf_counter()

print("[INFO] Starting model training...")
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,  # No learning rate scheduler for baseline
    device=device,
    save_path="outputs/models/baseline/convnext_tiny.pt",
    early_stop_patience=12,
    monitor_metric="val_f1",
)
print("[INFO] Training completed successfully!")

# Break the elapsed wall time into whole hours, minutes and seconds.
elapsed_time = time.perf_counter() - start_time
minutes, seconds = divmod(int(elapsed_time), 60)
hours, minutes = divmod(minutes, 60)

print(f"Training completed in {hours} hours, {minutes} minutes, and {seconds} seconds.")

[INFO] Starting model training...

Starting training...



                                                                       

KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out test split and write evaluation plots.
# NOTE(review): this evaluates `model` as left by train_model; confirm whether
# that is the best-val_f1 checkpoint saved to outputs/models/baseline/
# convnext_tiny.pt or the last-epoch weights.
# NOTE(review): `classes` must follow the dataset's class-index order. The
# loader log lists 'fire' before 'nofire'; if indices follow that
# (alphabetical, ImageFolder-style) order, index 0 is fire and these two
# labels are swapped — confirm against the dataset's class mapping.
print("[INFO] Evaluating the model...")
metrics = evaluate_model(
    model=model,
    test_loader=test_loader,
    classes=["No Fire", "Fire"],
    device=device,
    model_name="convnext_tiny",
    save_base_path="outputs/plots/baseline",
)
# Label the printout with the model actually evaluated (ConvNeXt-Tiny).
print(f"\n[INFO] Metrics for ConvnextTiny:\n{metrics}")

In [None]:
# Save diagnostic plots for the baseline ConvNeXt-Tiny run.
# Requires `history` (training cell) and `metrics` (evaluation cell).
plot_dir = "outputs/plots/baseline"
class_labels = ["No Fire", "Fire"]  # keep in sync with the evaluation cell

# Training and validation curves
plot_training(history, f"{plot_dir}/convnext_baseline_training_curve.png")

# Confusion matrix
plot_confusion_matrix(
    cm=metrics["confusion_matrix"],
    classes=class_labels,
    output_path=f"{plot_dir}/convnext_baseline_confusion_matrix.png",
)

# Threshold-based curves need per-sample class probabilities, which are only
# plotted when the evaluation returned them.
if "y_probs" in metrics:
    y_true = metrics["y_true"]
    # Score for the class at index 1 ("Fire" in class_labels).
    # NOTE(review): assumes index 1 is the positive (fire) class — confirm
    # against the dataset's class-to-index mapping.
    y_probs = [prob[1] for prob in metrics["y_probs"]]

    # Precision-Recall curve
    plot_precision_recall(
        y_true=y_true,
        y_scores=y_probs,
        output_path=f"{plot_dir}/convnext_baseline_precision_recall_curve.png",
    )

    # ROC curve
    plot_roc_curve(
        y_true=y_true,
        y_scores=y_probs,
        output_path=f"{plot_dir}/convnext_baseline_roc_curve.png",
    )
else:
    # Make the skip visible instead of implying every plot was written.
    print("[INFO] No 'y_probs' in metrics; skipped precision-recall and ROC curves.")

print("[INFO] All results saved successfully!")