In [1]:
import os
# Switch to the directory two levels above this notebook (presumably the
# repository root, where `utils/` and the dataset folder live).
# The notebook's own directory is captured only on the first execution, so
# re-running this cell is idempotent instead of climbing two more levels.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))

In [2]:
import random

import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR

from utils.dataset_loader import *
from utils.model_utils import *
from utils.train_utils import *
from utils.metrics import *
from utils.visualization import *

In [3]:
# Use the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[INFO] Using device: {device}")

[INFO] Using device: cuda


In [4]:
# Experiment configuration. Paths are relative to the working directory set
# in the first cell.
data_dir = "wildfire_dataset_scaled"
batch_size = 16
learning_rate = 0.0004
num_classes = 2
model_name = "resnet18"
output_dir = "outputs/models/augmented"
os.makedirs(output_dir, exist_ok=True)

# Seed every RNG this run draws from (Python `random`, NumPy, and PyTorch on
# all devices) BEFORE the loaders and model are built. This makes the weighted
# sampling, random augmentations, and model initialization repeatable across
# Restart & Run All. It does not force deterministic cuDNN kernels.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


In [5]:
print("[INFO] Loading augmented datasets...")
train_loader, val_loader, test_loader = load_datasets(
    data_dir=data_dir, batch_size=batch_size, augmentation="augmented"
)
print("[INFO] Augmented datasets loaded successfully!")

2024-12-21 22:42:22,927 - INFO - Loading datasets from wildfire_dataset_scaled with augmentation type 'augmented'.
2024-12-21 22:42:22,927 - INFO - Applying moderate augmentations: Flip, Rotate, and ColorJitter.
2024-12-21 22:42:22,936 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/train
2024-12-21 22:42:22,939 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/val
2024-12-21 22:42:22,941 - INFO - Initialized AlbumentationsDataset with root: wildfire_dataset_scaled/test
2024-12-21 22:42:22,941 - INFO - Datasets initialized. Preparing DataLoaders...
2024-12-21 22:42:22,942 - INFO - Using WeightedRandomSampler for class balancing.
2024-12-21 22:42:22,942 - INFO - Computing class weights from directory: wildfire_dataset_scaled/train
2024-12-21 22:42:22,943 - INFO - Class 'fire' has 730 samples.
2024-12-21 22:42:22,945 - INFO - Class 'nofire' has 1157 samples.
2024-12-21 22:42:22,945 - INFO - Computed class weights: {'fire': 2.58493

[INFO] Loading augmented datasets...


2024-12-21 22:42:24,758 - INFO - DataLoaders created successfully.


[INFO] Augmented datasets loaded successfully!


In [6]:
# Initialize `model_name` with pretrained weights for `num_classes` classes.
# NOTE(review): `freeze_all=True` reads as "freeze every layer". However, the
# optimizer cell only trains parameters with requires_grad=True, and training
# did run, so some parameters (presumably a new classifier head) stay
# trainable. Confirm this in `initialize_model`.
print(f"\n[INFO] Starting Training for Model: {model_name} with Augmented Data\n")
print(f"[INFO] Initializing {model_name} model...")
model = initialize_model(
    model_name=model_name,
    pretrained=True,
    num_classes=num_classes,
    freeze_all=True,
)
print("[INFO] Model initialized successfully!")


[INFO] Starting Training for Model: resnet18 with Augmented Data

[INFO] Initializing resnet18 model...
[INFO] Model initialized successfully!


In [7]:
# Cross-entropy loss. Adam receives only the parameters that still require
# gradients, so anything frozen when the model was built is left untouched.
print("[INFO] Setting up loss function and optimizer...")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=learning_rate,
)
print("[INFO] Loss function and optimizer set up!")

[INFO] Setting up loss function and optimizer...
[INFO] Loss function and optimizer set up!


In [8]:
# Train the model with a step-decay learning-rate schedule.
# StepLR multiplies the LR by gamma=0.5 every step_size=5 scheduler steps.
# That is every 5 epochs, assuming train_model calls scheduler.step() once per
# epoch, which matches the per-epoch "Learning rate adjusted" log lines.
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
print("[INFO] Starting model training with augmented data...")
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,  # step-decay LR schedule defined above
    device=device,
    # The best checkpoint (by monitor_metric) is written here.
    save_path=os.path.join(output_dir, f"{model_name}_augmented.pt"),
    early_stop_patience=12,  # presumably epochs without val_f1 improvement before stopping
    monitor_metric="val_f1",  # selects the saved checkpoint (and, presumably, drives early stopping)
)
print("[INFO] Training completed successfully!")

[INFO] Starting model training with augmented data...

Starting training...



                                                                       

[INFO] Learning rate adjusted to: 0.000400
[INFO] Best model saved with val_f1: 0.7918
Epoch [1]: Train Loss: 0.6949, Train Acc: 0.5861 | Val Loss: 0.5406, Val Acc: 0.7463, Val Recall: 0.7886, Val F1: 0.7918


                                                                       

[INFO] Learning rate adjusted to: 0.000400
[INFO] Best model saved with val_f1: 0.8159
Epoch [2]: Train Loss: 0.5823, Train Acc: 0.6937 | Val Loss: 0.4864, Val Acc: 0.7811, Val Recall: 0.7927, Val F1: 0.8159


                                                                      

KeyboardInterrupt: 

In [None]:
# Score the model on the held-out test loader.
# NOTE(review): the loader log lists class 'fire' before 'nofire'. If label
# indices follow that (sorted-folder) order, index 0 is 'fire' and the label
# list below is reversed. Confirm against the dataset's class mapping.
# NOTE(review): this evaluates `model` as it is in memory after training. The
# best-val_f1 checkpoint written to `save_path` is not reloaded here unless
# `train_model` restores it itself. Confirm which weights should be reported.
print("[INFO] Evaluating the model...")
metrics = evaluate_model(
    device=device,
    model=model,
    test_loader=test_loader,
    classes=["No Fire", "Fire"],
)
print(f"\n[INFO] Metrics for {model_name} with Augmented Data:\n{metrics}")

In [None]:
print("[INFO] Saving training and evaluation results...")

# Training curve
plot_training(
    history, output_path=os.path.join(output_dir, f"{model_name}_augmented_training_curve.png")
)

# Confusion matrix
plot_confusion_matrix(
    cm=metrics["confusion_matrix"],
    classes=["No Fire", "Fire"],
    output_path=os.path.join(output_dir, f"{model_name}_augmented_confusion_matrix.png")
)

# ROC Curve
plot_roc_curve(
    y_true=metrics["y_true"],
    y_scores=metrics["y_scores"],
    output_path=os.path.join(output_dir, f"{model_name}_augmented_roc_curve.png")
)

# Precision-Recall Curve
plot_precision_recall(
    y_true=metrics["y_true"],
    y_scores=metrics["y_scores"],
    output_path=os.path.join(output_dir, f"{model_name}_augmented_precision_recall.png")
)

print("[INFO] All results saved successfully!")