In [1]:
from arcgis.learn import UnetClassifier, DeepLab, PSPNetClassifier, MultiTaskRoadExtractor, prepare_data
import torch
import os

# ---- Training settings ----
# These module-level names are read by the training cell further down.
epochsNum = 10   # Upper bound on the number of training epochs.
batchNum = 4     # Batch size. Typical values are 2, 4, 8 (odd numbers also work); 8 overran memory on the author's PC.
chipSize = 1024  # Chip size handed to prepare_data.
# NOTE(review): the training-data folder and the output folder below are both
# named "512" while chipSize is 1024 -- confirm which chip size is intended.
numWorkers = 4   # Passed to prepare_data as num_workers; adjust based on your CPU cores.

# ---- Paths ----
training_data_path = r"C:\Users\ss2596\Documents\njoko training\LabeledObjects\512"
model_output_path = r"C:\Users\ss2596\Documents\Njoko_model\Njoko_model_512"

# ---- Prepare data ----
print("📦 Preparing data...", flush=True)
data = prepare_data(
    path=training_data_path,
    batch_size=batchNum,
    chip_size=chipSize,
    num_workers=numWorkers
)

# ---- Initialize model ----
# Uncomment exactly one model_class line to choose the architecture.
# The progress message below is derived from the chosen class, so it always
# names the model that is actually being built.
model_class = UnetClassifier
#model_class = DeepLab
#model_class = PSPNetClassifier
#model_class = MultiTaskRoadExtractor
print(f"🧠 Initializing {model_class.__name__} model...", flush=True)
model = model_class(data)

# ---- Create output folder if needed ----
os.makedirs(model_output_path, exist_ok=True)

# ---- Detect and display GPU device ----
# Informational only: this reports what torch sees, it does not select a device.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"⚙️ Using GPU: {gpu_name}", flush=True)
else:
    print("⚙️ Using CPU", flush=True)



📦 Preparing data...
🧠 Initializing UnetClassifier model...
⚙️ Using GPU: NVIDIA GeForce RTX 3060


In [2]:
import os
import time
import csv


# Define path for saving training metrics
metrics_file = os.path.join(model_output_path, "training_metrics.csv")

# Write header to metrics CSV if it doesn't already exist.
# An existing file is kept, so rows from a re-run are appended after the
# previous run's rows (epoch numbers start again at 1).
if not os.path.exists(metrics_file):
    with open(metrics_file, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "valid_loss", "accuracy", "dice", "duration_mins"])

# ---- Early stopping setup ----
best_loss = float('inf')     # Lowest validation loss seen so far
best_epoch = None            # 1-based epoch number that produced best_loss
best_checkpoint_path = None  # Checkpoint saved for best_epoch
patience = 3                 # Stop training if no improvement after this many epochs
no_improve_epochs = 0        # Counter for consecutive non-improving epochs

# ---- Begin training loop ----
# NOTE(review): arcgis.learn's fit() documents its own early_stopping and
# checkpoint options, and an automatic learning-rate choice when lr is not
# given -- confirm that calling fit(1) once per epoch is intended rather than
# a single fit(epochsNum, ...) call.
for epoch in range(epochsNum):
    epoch_num = epoch + 1    # Human-friendly, 1-based epoch number
    print(f"\n🔁 Starting epoch {epoch_num}/{epochsNum}...", flush=True)
    start_time = time.time()

    # Train for one epoch
    model.fit(1)

    # ---- Collect training and validation metrics ----
    learner = model.learn
    train_loss = learner.recorder.losses[-1].item()               # Last recorded training loss
    valid_loss, *metrics = learner.validate()                     # Validation loss and metrics
    valid_loss = float(valid_loss)                                # Plain float, same treatment as the metrics below
    accuracy = float(metrics[0]) if len(metrics) > 0 else None    # First metric, if any
    dice = float(metrics[1]) if len(metrics) > 1 else None        # Second metric, if any
    duration = round((time.time() - start_time) / 60, 2)          # Minutes spent on fit + validate

    # ---- Append metrics to CSV ----
    with open(metrics_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            epoch_num,
            round(train_loss, 4),
            round(valid_loss, 4),
            round(accuracy, 4) if accuracy is not None else None,
            round(dice, 4) if dice is not None else None,
            duration
        ])

    # ---- Save checkpoint for this epoch ----
    checkpoint_path = os.path.join(model_output_path, f"checkpoint_epoch_{epoch_num}.dlpk")
    model.save(checkpoint_path, framework='PyTorch')
    print(f"💾 Checkpoint saved: {checkpoint_path}", flush=True)

    # ---- Early stopping check ----
    if valid_loss < best_loss:
        best_loss = valid_loss                     # New best model
        best_epoch = epoch_num                     # Remember which epoch it was ...
        best_checkpoint_path = checkpoint_path     # ... and where its checkpoint lives
        no_improve_epochs = 0                      # Reset counter
    else:
        no_improve_epochs += 1                     # Increment counter
        print(f"📉 No improvement. {no_improve_epochs} consecutive epochs without improvement.", flush=True)

    # Trigger early stop if patience limit is reached
    if no_improve_epochs >= patience:
        print(f"🛑 Early stopping triggered after {patience} epochs without improvement.", flush=True)
        break

# ---- Save final model after training completes or early stopping ----
# The model is saved in whatever state the last executed epoch left it in.
# After early stopping that is a non-improving epoch, so the best checkpoint
# is reported separately below.
final_model_path = os.path.join(model_output_path, "final_model.dlpk")
model.save(final_model_path, framework='PyTorch')
print(f"\n🎯 Final model saved: {final_model_path}", flush=True)

# best_checkpoint_path stays None when no epoch ran or no loss ever compared
# lower than infinity (e.g. NaN losses), so guard the report.
if best_checkpoint_path is not None:
    print(
        f"🏆 Lowest validation loss {best_loss:.4f} at epoch {best_epoch}: {best_checkpoint_path}",
        flush=True,
    )

print("\n✅ Training complete.", flush=True)


epoch,train_loss,valid_loss,accuracy,dice,time
0,0.42393,0.368557,0.84694,0.537068,01:28


Computing model metrics...
💾 Checkpoint saved: C:\Users\ss2596\Documents\Njoko_model\Njoko_model_512\checkpoint_epoch_10.dlpk

🎯 Final model saved: C:\Users\ss2596\Documents\Njoko_model\Njoko_model_512\final_model.dlpk

✅ Training complete.


| Model                 | Key Features                             | Strengths                       | Typical Use Cases                 | Training Speed     | Model Size / Complexity | Notes                                |
|-----------------------|----------------------------------------|--------------------------------|---------------------------------|--------------------|-------------------------|-------------------------------------|
| **UNet**              | Encoder-decoder with skip connections  | Good balance of accuracy & speed | General-purpose segmentation    | Fast to Moderate   | Medium                  | Very popular, simple architecture   |
| **DeepLab (DeepLabV3)** | Atrous convolutions, ResNet backbone options | High accuracy, good boundary detection | Complex scenes with fine details | Moderate to Slow    | Large                   | High performance, slower to train   |
| **PSPNet**            | Pyramid pooling for global context     | Excellent for large-scale context | Scenes with varied background/classes | Moderate          | Large                   | Great for capturing global context  |
| **LinkNet**           | Lightweight encoder-decoder architecture | Fast inference, smaller size    | Real-time or resource-constrained use | Fast              | Small                   | Good for speed-critical apps        |
| **MultiTaskRoadExtractor** | Specialized for road extraction        | Accurate for roads, edges       | Road/transportation mapping      | Moderate          | Medium                  | Focused use case, less general      |
