In [1]:
import os

# Later cells import the local `utils` package and read the dataset through a
# relative path, so they must run from the directory that contains `utils/`
# (presumably the project root, one level above this notebook). Only step up
# when we are not already there: an unconditional os.chdir("..") climbs one
# more level on every re-run of this cell and breaks those imports.
if not os.path.isdir("utils"):
    os.chdir("..")

In [2]:
import random

import numpy as np
import torch

from utils.dataset_loader import load_datasets
from utils.model_utils import initialize_model, save_model, load_model
from utils.train_utils import train_model
from utils.metrics import evaluate_model
from utils.visualization import (
    plot_training,
    plot_confusion_matrix,
    plot_roc_curve,
    plot_precision_recall_curve,
    plot_per_class_performance
)

# Fix the RNG seeds so that any stochastic step drawing from these generators
# (e.g. weight initialisation, data shuffling/augmentation, if used) repeats
# across runs. cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Dataset folder, resolved relative to the current working directory.
data_dir = "wildfire_dataset_scaled"

# Build the train / validation / test loaders (32 images per batch) and
# recover the class names.
train_loader, val_loader, test_loader, classes = load_datasets(
    data_dir,
    batch_size=32,
)
print("Classes:", classes)

Classes: ['fire', 'nofire']


In [4]:
def _confusion_matrix(labels, preds, num_classes):
    """
    Build a confusion matrix from true and predicted class indices.

    Entry [i, j] counts the samples whose true class index is i and whose
    predicted class index is j (rows = true class, columns = predicted class).

    Args:
        labels: 1-D sequence/array of true class indices in [0, num_classes).
        preds: 1-D sequence/array of predicted class indices, same length as labels.
        num_classes (int): Number of classes. It fixes the matrix size so the
            matrix always lines up with the class-name list, even when some
            class never occurs in the data.

    Returns:
        np.ndarray: Integer array of shape (num_classes, num_classes).

    Raises:
        ValueError: If labels and preds differ in length, or either contains
            an index outside [0, num_classes).
    """
    y_true = np.asarray(labels, dtype=np.int64).ravel()
    y_pred = np.asarray(preds, dtype=np.int64).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"labels and preds differ in length: {y_true.size} vs {y_pred.size}"
        )
    for name, arr in (("labels", y_true), ("preds", y_pred)):
        if arr.size and (arr.min() < 0 or arr.max() >= num_classes):
            raise ValueError(
                f"{name} contain class indices outside [0, {num_classes})"
            )
    # Encode each (true, pred) pair as a single flat index, count all pairs in
    # one pass, then fold the counts back into a square matrix.
    counts = np.bincount(
        y_true * num_classes + y_pred, minlength=num_classes * num_classes
    )
    return counts.reshape(num_classes, num_classes)


def train_and_evaluate_model(model_name, num_epochs=10, lr=0.001, step_size=7, gamma=0.1):
    """
    Train, evaluate, and visualize results for a given model.

    Reads the notebook-level `classes`, `device`, `train_loader`, `val_loader`
    and `test_loader` defined in earlier cells.

    Args:
        model_name (str): Name of the model to train (e.g., 'resnet18', 'efficientnet_b0').
        num_epochs (int): Number of epochs for training.
        lr (float): Adam learning rate.
        step_size (int): StepLR period, counted in scheduler.step() calls
            (presumably one per epoch inside train_model; confirm there).
        gamma (float): Factor StepLR multiplies the learning rate by each period.

    Returns:
        None
    """
    # Step 1: Initialize the model
    model = initialize_model(model_name, num_classes=len(classes), pretrained=True)
    model = model.to(device)

    # Step 2: Define training components
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # Step 3: Train the model
    print(f"\nTraining {model_name}...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        num_epochs=num_epochs
    )

    # Step 4: Save the trained model. Create the target folder first so a fresh
    # checkout does not lose a finished training run to a missing directory.
    model_path = f"outputs/trained_models/{model_name}.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    save_model(model, model_path)
    print(f"Model saved at {model_path}")

    # Step 5: Evaluate the model
    print(f"\nEvaluating {model_name}...")
    metrics = evaluate_model(model, test_loader, classes, device)

    # Step 6: Visualize results
    plot_training(history)
    # Assumes all_labels / all_preds hold integer class indices (CPU values);
    # TODO confirm against utils.metrics.evaluate_model.
    cm = _confusion_matrix(metrics["all_labels"], metrics["all_preds"], len(classes))
    plot_confusion_matrix(cm, classes)
    plot_roc_curve(model, test_loader, classes, device)
    plot_precision_recall_curve(model, test_loader, classes, device)
    plot_per_class_performance(metrics["classification_report"], classes)


In [None]:
# Train, save, evaluate and plot a pretrained ResNet-18 for 10 epochs.
train_and_evaluate_model("resnet18", num_epochs=10)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Yash/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth

00%|███████████| 44.7M/44.7M [00:03<00:00, 14.1MB/s]


Training resnet18...
Epoch 1/10
----------



00%|████████████████| 59/59 [00:25<00:00,  2.27it/s]

Train Loss: 0.6175 Acc: 0.7271
Val Loss: 0.6894 Acc: 0.6468
Epoch 2/10
----------



00%|████████████████| 59/59 [00:23<00:00,  2.47it/s]

Train Loss: 0.4652 Acc: 0.7801
Val Loss: 0.4892 Acc: 0.7935
Epoch 3/10
----------



00%|████████████████| 59/59 [00:24<00:00,  2.42it/s]