In [None]:
import os

# Run from the project root (two levels above the kernel's starting directory,
# which Jupyter sets to the notebook's folder by default) so that the `utils.*`
# imports and relative paths such as "wildfire_dataset_scaled" and "outputs/"
# resolve. The absolute root is computed only on the first run and reused
# afterwards, so re-running this cell is idempotent instead of climbing two
# more directory levels each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils.dataset_loader import load_datasets
from utils.model_utils import initialize_model
from utils.train_utils import train_model
from utils.metrics import evaluate_model
from utils.visualization import plot_training, plot_confusion_matrix

In [None]:
# Seed PyTorch before the data loaders and the model are created. This makes
# randomness drawn from torch's generators reproducible across
# Restart & Run All, e.g. freshly initialised layers and DataLoader shuffling.
# torch.manual_seed seeds the CPU and all CUDA generators.
# NOTE(review): if utils.* draws from Python's `random` or NumPy, seed those
# here as well. Bitwise-identical GPU runs additionally need deterministic
# cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Build the train / validation / test loaders from the "wildfire_dataset_scaled"
# directory, using mini-batches of 32 and the helper's "baseline" augmentation
# setting. The helper returns the three loaders in that order.
loaders = load_datasets(
    batch_size=32,
    augmentation="baseline",
    data_dir="wildfire_dataset_scaled",
)
train_loader, val_loader, test_loader = loaders

In [None]:
# Build the network through the project helper. It uses pretrained weights and
# a two-class output head. freeze_all=True freezes the feature-extractor layers,
# so only the layers the helper leaves trainable are optimised.
# Change model_name (e.g. "resnet18") to try another architecture.
model = initialize_model(
    pretrained=True,
    freeze_all=True,
    num_classes=2,
    model_name="vgg16",
)

In [None]:
# Show every named parameter together with its requires_grad flag, so the
# effect of freeze_all can be checked by eye before training.
print("\nTrainable Parameters:")
for param_name, tensor in model.named_parameters():
    print(f"{param_name}: requires_grad={tensor.requires_grad}")

In [None]:
# Fail fast: if freezing left nothing trainable, the optimizer cell below would
# be handed an empty parameter list.
has_trainable = any(p.requires_grad for p in model.parameters())
if not has_trainable:
    raise ValueError("No trainable parameters found! Ensure the classifier layers are not frozen.")


In [None]:
# Switch the model to training mode (affects layers such as dropout and
# batch-norm, if the architecture has them). nn.Module.train() returns the
# module itself. The trailing `;` stops Jupyter from echoing that return value,
# which would otherwise print the whole layer-by-layer repr of the network.
# NOTE(review): if train_model already calls model.train() each epoch, this
# cell is redundant but harmless.
model.train();

In [None]:
# Cross-entropy loss over the class logits. Adam is given only the parameters
# that are still trainable after freezing, so frozen layers are never updated.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

In [None]:
# Report the learning rate of each optimizer parameter group. There is a
# single group here, because one iterable of tensors was passed to Adam.
print("\nOptimizer Parameters:")
for group in optimizer.param_groups:
    lr = group["lr"]
    print(f"Learning Rate: {lr}")


In [None]:
# Train for 10 epochs with the optimizer configured above. The returned
# `history` feeds plot_training in the final cell.
print("\nStarting Training...")
history = train_model(
    device=device,
    num_epochs=10,
    scheduler=None,  # baseline run: no learning-rate scheduler
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
)

In [None]:
# Score the trained model on the held-out test loader and print the result.
# The class names are listed in label-index order (0 = "No Fire", 1 = "Fire").
# This presumably matches the dataset's label encoding; verify against
# load_datasets.
print("\nEvaluating the model...")
class_names = ["No Fire", "Fire"]
metrics = evaluate_model(
    model=model,
    test_loader=test_loader,
    classes=class_names,
    device=device,
)
print(f"\nMetrics:\n{metrics}")

In [None]:
# Save the training curves and the test-set confusion matrix as PNG files
# under outputs/.
# Create the directory first: on a fresh checkout "outputs/" may not exist.
# Unless the plotting helpers create it themselves, saving would then fail
# here, after training has already finished. exist_ok=True keeps this cell
# safe to re-run.
os.makedirs("outputs", exist_ok=True)
plot_training(history, "outputs/baseline_training_curve.png")
plot_confusion_matrix(
    metrics["confusion_matrix"],
    classes=["No Fire", "Fire"],
    output_path="outputs/baseline_confusion_matrix.png",
)