# Imports:

In [None]:
# Makes sure to reload modules when they change
%load_ext autoreload
%autoreload 2

# --- Standard Library Imports ---
import os
import joblib
from PIL import Image, ImageOps

# --- Third-Party Imports ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
import torch

# --- Custom Imports ---
from src.utils import (
    extractCoordinates, aspect_crop, haversine_distance,
    plot_images_from_dataloader, setup_TensorBoard_writers,
    log_error_map, create_interactive_heatmap
)
from src.dataset import GeolocalizationDataset
from src.models import ConvNet, ConvNet2, ConvNet3, MultiTaskDINOGeo

# Enable CUDA optimizations (cuDNN autotuner picks the fastest conv algorithms;
# this trades away bit-exact reproducibility for speed)
torch.backends.cudnn.benchmark = True

# Seed the global RNGs: torch (weight init, DataLoader shuffling) and NumPy.
# sklearn calls later in the notebook pass their own random_state=42.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Image preprocessing:

In [None]:
# Setup paths
RAW_IMAGE_FOLDER = r"data_manual_gps_united"             # Use original images for GPS extraction
PROCESSED_IMAGE_FOLDER = r"data_processed_manual_gps" # Use processed images for training
PROCESSED_SIZE = (192, 256)  # (width, height) passed to PIL's Image.resize

# Process only the raw JPEGs that do not have a processed copy yet. This replaces a
# comparison against a hard-coded file count (1475). That check re-processed every
# image on each run when the raw folder held fewer valid images, and silently
# skipped newly added images otherwise.
os.makedirs(PROCESSED_IMAGE_FOLDER, exist_ok=True)
raw_files = [f for f in os.listdir(RAW_IMAGE_FOLDER) if f.lower().endswith(('.jpg', '.jpeg'))]
already_processed = set(os.listdir(PROCESSED_IMAGE_FOLDER))
files = [f for f in raw_files if f not in already_processed]

if files:
    print(f"Starting Pre-processing of {len(files)} image(s)...")
    for filename in tqdm(files):
        src_path = os.path.join(RAW_IMAGE_FOLDER, filename)
        dst_path = os.path.join(PROCESSED_IMAGE_FOLDER, filename)

        try:
            with Image.open(src_path) as img:
                img = ImageOps.exif_transpose(img)  # honour the EXIF orientation before cropping
                img = img.convert('RGB')
                img = aspect_crop(img)
                img = img.resize(PROCESSED_SIZE, Image.Resampling.LANCZOS)
                img.save(dst_path, quality=95)
        except Exception as e:
            # Best-effort: report and continue. Files with no output are retried next run.
            print(f"Failed {filename}: {e}")
else:
    print("Pre-processed images already exist. Skipping preprocessing step.")

# Data loading:

In [None]:
if __name__ == "__main__":

    SCALER_SAVE_PATH = 'coordinate_scaler.pkl'

    # --- 2. EXTRACTION PHASE ---
    # GPS is read from the ORIGINAL file, while the dataset points at the PROCESSED
    # copy. Images without a processed copy or without GPS coordinates are skipped.
    # NOTE(review): os.listdir order is filesystem-dependent, so the random_state=42
    # split below is only reproducible on the same filesystem. Sorting would make it
    # portable, but it would also change the split that existing checkpoints used.
    processed_data = []

    for filename in os.listdir(RAW_IMAGE_FOLDER):
        if not filename.lower().endswith(('.jpg', '.jpeg')):
            continue

        raw_image_path = os.path.join(RAW_IMAGE_FOLDER, filename)
        processed_image_path = os.path.join(PROCESSED_IMAGE_FOLDER, filename)

        # Check if the processed version actually exists
        if not os.path.exists(processed_image_path):
            continue

        # Extract coordinates from the ORIGINAL file
        coords = extractCoordinates(raw_image_path)

        if coords:
            processed_data.append({
                'path': processed_image_path,
                'lat': coords[0],
                'lon': coords[1]
            })

    # train_test_split needs at least two rows; fail early with an actionable message.
    if len(processed_data) < 2:
        raise ValueError(
            f"Need at least 2 processed images with GPS data, found {len(processed_data)} "
            f"(raw: {RAW_IMAGE_FOLDER!r}, processed: {PROCESSED_IMAGE_FOLDER!r})."
        )

    df = pd.DataFrame(processed_data)

    # Keep 20% of the data for validation
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    print(f"Loaded {len(df)} geotagged images: {len(train_df)} train / {len(val_df)} val")

    # --- 3. SCALING PHASE ---
    # must fit on training data *only* so validation statistics do not leak in
    scaler = MinMaxScaler()
    train_df[['lat', 'lon']] = scaler.fit_transform(train_df[['lat', 'lon']])
    val_df[['lat', 'lon']] = scaler.transform(val_df[['lat', 'lon']])

    # Persist the scaler so predictions can be mapped back to real lat/lon later
    joblib.dump(scaler, SCALER_SAVE_PATH)

## Clustering for multi-task training:

In [None]:
from sklearn.cluster import KMeans

print("Generating Zone Labels on Normalized Data...")

NUM_ZONES = 25

kmeans = KMeans(n_clusters=NUM_ZONES, random_state=42, n_init=10)

# Fit on the normalized training coordinates only; validation points are assigned
# to the nearest training centroid so both splits share one zone definition.
train_df['zone_label'] = kmeans.fit_predict(train_df[['lat', 'lon']])
val_df['zone_label'] = kmeans.predict(val_df[['lat', 'lon']])

# Plot the zones for both splits. A colormap with exactly NUM_ZONES entries avoids
# 'tab20' reusing colours for 25 zones. Fixed vmin/vmax map a zone id to the same
# colour in both panels; otherwise each scatter normalizes to its own label range.
zone_cmap = plt.get_cmap('nipy_spectral', NUM_ZONES)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
panels = ((axes[0], train_df, "Training"), (axes[1], val_df, "Validation"))

for ax, split_df, split_name in panels:
    ax.scatter(
        split_df['lon'], split_df['lat'], c=split_df['zone_label'],
        cmap=zone_cmap, vmin=0, vmax=NUM_ZONES - 1, s=30
    )
    ax.set_title(f"{split_name} Data Divided into {NUM_ZONES} Zones")
    ax.set_xlabel("Longitude (min-max normalized)")
    ax.set_ylabel("Latitude (min-max normalized)")
    # Equal aspect keeps the two panels comparable. Each axis was scaled to [0, 1]
    # independently, so this is not a true map projection.
    ax.axis('equal')

plt.tight_layout()
plt.show()

# Plot heatmap based on training data

In [None]:
# Undo the MinMax scaling so the heatmap is drawn in real lat/lon coordinates.
normalized_train_coords = train_df[['lat', 'lon']].to_numpy()
train_coords = scaler.inverse_transform(normalized_train_coords)
create_interactive_heatmap(train_coords, output_file='train_data_heatmap.html')

In [None]:
if __name__ == "__main__":

    # --- 4. DATASET INITIALIZATION ---
    # One input size shared by both splits. NOTE(review): both values are multiples
    # of 14, presumably to match the DINO patch size. Whether the order is (h, w) or
    # (w, h) is defined by GeolocalizationDataset; confirm there.
    DATASET_TARGET_SIZE = (252, 182)
    # Pinned host memory lets the non_blocking host-to-GPU copies in the training and
    # validation loops overlap with compute. Without CUDA it is useless and warns.
    PIN_MEMORY = torch.cuda.is_available()

    print("Initializing train dataset...")
    train_dataset = GeolocalizationDataset(
        image_paths=train_df['path'].tolist(),
        coordinates=train_df[['lat', 'lon']].values,
        zone_labels=train_df['zone_label'].values,
        is_train=True,
        target_size=DATASET_TARGET_SIZE
    )
    print("Initializing validation dataset...")
    val_dataset = GeolocalizationDataset(
        image_paths=val_df['path'].tolist(),
        coordinates=val_df[['lat', 'lon']].values,
        zone_labels=val_df['zone_label'].values,
        is_train=False,
        target_size=DATASET_TARGET_SIZE
    )

    # --- 5. THE DATALOADER ---
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
        pin_memory=PIN_MEMORY)

    val_loader = DataLoader(
        val_dataset,
        batch_size=128,
        shuffle=False,
        num_workers=0,
        pin_memory=PIN_MEMORY)

    # Presumably shows a sample batch as a visual sanity check (helper in src.utils)
    plot_images_from_dataloader(train_loader)

# Model setup:

In [None]:
# --- 6. INITIALIZE MODEL ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name() raises when CUDA is unavailable, so only query it on
# CUDA machines. An "RTX" device name enables AMP + channels_last in the training loop.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
is_rtx = gpu_name is not None and "RTX" in gpu_name
print("Using device:", device)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", gpu_name)

model = MultiTaskDINOGeo(NUM_ZONES).to(device)

if is_rtx:
    # Optimize for modern GPUs that prefer channels_last
    model = model.to(memory_format=torch.channels_last)
else:
    print("RTX card not detected: Disabling AMP/Channels_Last optimizations")

# --- 7. LOSS & OPTIMIZER ---
UNFREEZE_INTERVAL = 20  # epochs between unfreezing backbone blocks
BLOCKS_PER_STEP = 1     # How many blocks to open at once
victory_lap_started = False

patience_counter = 0
early_stopping_patience = UNFREEZE_INTERVAL - 1  # stop if a whole unfreeze interval brings no improvement
epochs = 500
use_TensorBoard = True  # Set to False to disable TensorBoard logging

# Training objective used throughout: Huber(coords) + 0.5 * CrossEntropy(zones)
criterion_reg = torch.nn.HuberLoss(delta=1.0)
criterion_cls = torch.nn.CrossEntropyLoss(label_smoothing=0.15)

# Learning rates for phase 1 - iterative unfreezing
p1_base_head_lr = 1e-3
p1_backbone_lr = 5e-5

# Learning rates for phase 2 - fine-tuning all layers
p2_backbone_lr = 1e-5
p2_head_lr = 5e-4

# Only currently trainable backbone parameters go into the first optimizer.
# The training loop rebuilds the optimizer whenever more blocks are unfrozen.
optimizer = torch.optim.AdamW(
    [
        {
            "params": filter(lambda p: p.requires_grad, model.backbone.parameters()),
            "lr": p1_backbone_lr,
        },
        {"params": model.shared.parameters(), "lr": p1_base_head_lr},
        {"params": model.reg_head.parameters(), "lr": p1_base_head_lr},
        {"params": model.cls_head.parameters(), "lr": p1_base_head_lr},
    ],
    weight_decay=0.01,
)
# One cosine cycle per unfreeze interval
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=UNFREEZE_INTERVAL, T_mult=1, eta_min=1e-7
)

print(
    f"Training on {len(train_dataset)} images, Validating on {len(val_dataset)} images."
)

# Model training:

In [None]:
# --- 8. TRAINING & VALIDATION LOOP ---
# Phase 1 (progressive unfreezing): every UNFREEZE_INTERVAL epochs model.unfreeze_step()
# opens BLOCKS_PER_STEP more backbone blocks. The optimizer and cosine-restart scheduler
# are then rebuilt with slightly decayed LR ceilings.
# Phase 2 ("victory lap"): once model.backbone.blocks[0] is trainable, switch to a low-LR
# AdamW driven by ReduceLROnPlateau on the mean validation distance.
# Metric histories are always kept in memory (the plotting cell reads them) and are
# also written to TensorBoard when use_TensorBoard is True.
train_losses = []
val_losses = []
val_avg_dist_history = []
val_median_dist_history = []
val_zone_accuracy = []
learning_rates = []
best_dist = float("inf")

if use_TensorBoard:
    writer_train, writer_val = setup_TensorBoard_writers()

print(f"Starting training on {device}...")

# Scaler for mixed precision training; prevents fp16 gradient underflow. It is only
# used on the RTX/AMP path, so keep it disabled elsewhere (no CUDA-unavailable warning).
gradScaler = GradScaler("cuda", enabled=is_rtx)

for epoch in range(epochs):
    # Check for Unfreezing layer blocks
    if epoch > 0 and epoch % UNFREEZE_INTERVAL == 0 and not victory_lap_started:

        model.unfreeze_step(BLOCKS_PER_STEP)

        # Check if all blocks are unfrozen - phase 2 begins
        if next(model.backbone.blocks[0].parameters()).requires_grad:
            print(f"\n🏆 VICTORY LAP DETECTED at Epoch {epoch} 🏆")
            victory_lap_started = True

            # Switch Optimizer to Low & Slow
            optimizer = torch.optim.AdamW([
                {'params': model.backbone.parameters(), 'lr': p2_backbone_lr},
                {'params': model.shared.parameters(),   'lr': p2_head_lr},
                {'params': model.reg_head.parameters(), 'lr': p2_head_lr},
                {'params': model.cls_head.parameters(), 'lr': p2_head_lr}
            ], weight_decay=0.02)

            # Switch Scheduler to Plateau (Patience=3, Factor=0.5). `verbose` is omitted
            # because it is deprecated in recent PyTorch; the LR is printed every epoch.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=3
            )

            # Restart the early-stopping count for phase 2 (same early_stopping_patience)
            patience_counter = 0

        else:
            patience_counter = 0  # reset patience on unfreeze
            # rebind optimizer to capture new blocks that were unfrozen
            num_unfrozen = epoch // UNFREEZE_INTERVAL

            current_backbone_lr = p1_backbone_lr * 0.9**num_unfrozen
            current_head_lr = p1_base_head_lr * 0.95**num_unfrozen

            optimizer = torch.optim.AdamW(
                [
                    {"params": model.backbone.parameters(), "lr": current_backbone_lr},
                    {"params": model.shared.parameters(), "lr": current_head_lr},
                    {"params": model.reg_head.parameters(), "lr": current_head_lr},
                    {"params": model.cls_head.parameters(), "lr": current_head_lr},
                ],
                weight_decay=0.01,
            )

            # Restart scheduler with the new, lower ceiling
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=UNFREEZE_INTERVAL, T_mult=1, eta_min=1e-7
            )
            print("--- Optimizer & Scheduler Reset for New Backbone Blocks ---")

    # --- PHASE 1: TRAINING ---
    model.train()  # Dropout ON
    train_running_loss = 0.0

    pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{epochs}",
        leave=False,
        unit="batch",
        mininterval=0.5,
    )

    for batch_idx, (images, labels_coords, labels_zones) in enumerate(pbar):
        # non_blocking=True speeds up RAM-to-VRAM transfer (with pinned memory)
        images = images.to(device, non_blocking=True)
        labels_coords = labels_coords.to(device, non_blocking=True)
        labels_zones = labels_zones.to(device, non_blocking=True)

        if is_rtx:  # Optimize for RTX GPUs that prefer channels_last
            images = images.to(memory_format=torch.channels_last)

        optimizer.zero_grad(set_to_none=True)

        if is_rtx:  # Use Mixed Precision Training only on RTX cards
            with autocast("cuda", dtype=torch.float16):
                pred_coords, pred_zones = model(images)

                loss_reg = criterion_reg(pred_coords, labels_coords)
                loss_cls = criterion_cls(pred_zones, labels_zones)

                loss = loss_reg + (0.5 * loss_cls)

            gradScaler.scale(loss).backward()
            gradScaler.step(optimizer)
            gradScaler.update()

        else:  # Standard training for non-RTX cards
            pred_coords, pred_zones = model(images)

            loss_reg = criterion_reg(pred_coords, labels_coords)
            loss_cls = criterion_cls(pred_zones, labels_zones)

            loss = loss_reg + (0.5 * loss_cls)

            loss.backward()
            optimizer.step()

        # Cosine restarts are stepped per batch with a fractional epoch; the plateau
        # scheduler (phase 2) is stepped once per epoch after validation instead.
        if not victory_lap_started:
            scheduler.step(epoch + batch_idx / len(train_loader))
        train_running_loss += loss.item()

    # --- PHASE 2: VALIDATION ---
    model.eval()  # Set model to evaluation mode (disables Dropout)
    val_running_loss = 0.0
    correct_zones = 0
    raw_preds_coords = []
    raw_trues_coords = []

    with torch.no_grad():  # Disable gradient calculation for efficiency
        for images, labels_coords, labels_zones in val_loader:
            images = images.to(device, non_blocking=True)
            labels_coords = labels_coords.to(device, non_blocking=True)
            labels_zones = labels_zones.to(device, non_blocking=True)

            if is_rtx:  # match the channels_last layout the model uses on RTX cards
                images = images.to(memory_format=torch.channels_last)

            # Same combined objective as training: Huber + 0.5 * CrossEntropy
            pred_coords, pred_zones = model(images)
            loss_reg = criterion_reg(pred_coords, labels_coords)
            loss_cls = criterion_cls(pred_zones, labels_zones)
            val_running_loss += (loss_reg + 0.5 * loss_cls).item()

            raw_preds_coords.append(pred_coords.cpu().numpy())
            raw_trues_coords.append(labels_coords.cpu().numpy())

            predicted_zones = torch.argmax(pred_zones, dim=1)
            correct_zones += (predicted_zones == labels_zones).sum().item()

    # Map normalized predictions back to real lat/lon before measuring distances
    full_preds_raw = np.vstack(raw_preds_coords)
    full_trues_raw = np.vstack(raw_trues_coords)
    real_preds = scaler.inverse_transform(full_preds_raw)
    real_trues = scaler.inverse_transform(full_trues_raw)

    # --- PHASE 3: METRICS CALCULATION & PRINTING ---
    distances = haversine_distance(real_preds, real_trues)
    avg_dist_error = np.mean(distances)

    if victory_lap_started:
        scheduler.step(avg_dist_error)  # ReduceLROnPlateau step

    median_dist_error = np.median(distances)
    zone_accuracy = correct_zones / len(val_dataset) * 100.0

    avg_train_loss = train_running_loss / len(train_loader)
    avg_val_loss = val_running_loss / len(val_loader)
    if victory_lap_started:
        current_lr = optimizer.param_groups[0]['lr']
    else:
        current_lr = scheduler.get_last_lr()[0]

    if epoch % UNFREEZE_INTERVAL == 0:  # log error map every UNFREEZE_INTERVAL epochs
        log_error_map(
            real_preds,
            real_trues,
            epoch,
            TB_writer=writer_val if use_TensorBoard else None,
        )

    # Always keep the in-memory history (used by the plotting cell)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    val_avg_dist_history.append(avg_dist_error)
    val_median_dist_history.append(median_dist_error)
    val_zone_accuracy.append(zone_accuracy)
    learning_rates.append(current_lr)

    if use_TensorBoard:  # Write to TensorBoard
        # Tag kept as "MSE Loss" for continuity with earlier runs; the value is the
        # combined Huber + 0.5 * CrossEntropy objective.
        writer_train.add_scalar("MSE Loss", avg_train_loss, epoch)
        writer_val.add_scalar("MSE Loss", avg_val_loss, epoch)
        writer_val.add_scalar(
            "Metrics/Avg_distance_Error_Meters", avg_dist_error, epoch
        )
        writer_val.add_scalar(
            "Metrics/Median_distance_Error_Meters", median_dist_error, epoch
        )
        writer_val.add_scalar("Metrics/Zone_Accuracy_Percent", zone_accuracy, epoch)
        writer_train.add_scalar("Hyperparameters/Learning_Rate", current_lr, epoch)

    print(
        f"Epoch {epoch+1}: Train Loss {avg_train_loss:.4f} | Val Loss {avg_val_loss:.4f} | "
        f"Avg Dist Error {avg_dist_error:.1f}m | Median Dist Error {median_dist_error:.1f}m | "
        f"Zone Acc {zone_accuracy:.1f}% | LR {current_lr:.2e}"
    )

    # 2. Save the BEST version of YOUR model (selected on mean distance error)
    if avg_dist_error < best_dist:
        best_dist = avg_dist_error
        torch.save(model.state_dict(), "custom_geo_model_best.pth")
        patience_counter = 0
        print(f"  *** NEW BEST: {best_dist:.1f}m ***")
    else:
        patience_counter += 1
        print(
            f"(No improvement for {patience_counter}/{early_stopping_patience} epochs)"
        )
        # 3. Early Stopping check
        if patience_counter >= early_stopping_patience:
            print("Model stopped improving. Ending training early.")
            break

log_error_map(
    real_preds, real_trues, epoch, TB_writer=writer_val if use_TensorBoard else None
)  # final log

if use_TensorBoard:
    writer_train.close()
    writer_val.close()

In [None]:
# --- 1. LOAD MODEL ---
# Checkpoint to polish. NOTE(review): this file name is not written by the training
# cell above (that cell saves "custom_geo_model_best.pth"); adjust if needed.
VICTORY_START_CHECKPOINT = "dino_geo_model_best_mean_13.1.pth"
VICTORY_SAVE_PATH = "custom_geo_model_victory.pth"
VICTORY_EPOCHS = 50          # usually enough
VICTORY_EARLY_STOPPING = 15  # shorter patience since we are just polishing

print("Loading Best Checkpoint for Victory Lap...")
model = MultiTaskDINOGeo(NUM_ZONES).to(device)

# map_location lets a checkpoint saved on another device load onto the current one
checkpoint = torch.load(VICTORY_START_CHECKPOINT, map_location=device)
model.load_state_dict(checkpoint)

# --- 2. FORCE UNFREEZE EVERYTHING ---
print("🔓 Unfreezing ENTIRE Model...")
for param in model.parameters():
    param.requires_grad = True

# --- 3. VICTORY LAP CONFIGURATION ---
# BACKBONE: 2e-6 (extremely slow polish so the pretrained features are not broken)
# HEADS:    1e-4 (standard fine-tuning speed)
# ReduceLROnPlateau keeps the LR while the mean validation distance improves and
# halves it after 3 epochs without improvement.
victory_optimizer = torch.optim.AdamW([
    {'params': model.backbone.parameters(), 'lr': 2e-6},
    {'params': model.shared.parameters(),   'lr': 1e-4},
    {'params': model.reg_head.parameters(), 'lr': 1e-4},
    {'params': model.cls_head.parameters(), 'lr': 1e-4}
], weight_decay=0.02)

# `verbose` is deprecated in recent PyTorch; the LR is printed every epoch instead.
victory_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    victory_optimizer, mode='min', factor=0.5, patience=3
)


def _victory_validate(net):
    """Evaluate `net` on val_loader.

    Returns (mean distance error, median distance error, mean validation loss).
    Distances come from haversine_distance on de-normalized (real) coordinates.
    """
    net.eval()
    val_loss = 0.0
    raw_preds = []
    raw_trues = []

    with torch.no_grad():
        for images, labels_coords, labels_zones in val_loader:
            images = images.to(device)
            labels_coords = labels_coords.to(device)
            labels_zones = labels_zones.to(device)

            pred_coords, pred_zones = net(images)
            loss_reg = criterion_reg(pred_coords, labels_coords)
            loss_cls = criterion_cls(pred_zones, labels_zones)
            val_loss += (loss_reg + 0.5 * loss_cls).item()

            raw_preds.append(pred_coords.cpu().numpy())
            raw_trues.append(labels_coords.cpu().numpy())

    full_preds = scaler.inverse_transform(np.vstack(raw_preds))
    full_trues = scaler.inverse_transform(np.vstack(raw_trues))
    dist_errors = haversine_distance(full_preds, full_trues)
    return np.mean(dist_errors), np.median(dist_errors), val_loss / len(val_loader)


# --- 4. TRAINING LOOP ---
# Use the loaded checkpoint's own validation error as the bar to beat, so only
# genuinely better weights are saved (instead of a hand-typed value that can go stale).
best_dist, baseline_median, _ = _victory_validate(model)
print(f"Baseline: Avg Dist: {best_dist:.1f}m | Median: {baseline_median:.1f}m")

print("🏆 STARTING VICTORY LAP (Full Fine-Tuning) 🏆")
patience_counter = 0
early_stopping = VICTORY_EARLY_STOPPING

for epoch in range(VICTORY_EPOCHS):
    model.train()
    train_loss = 0.0

    # Training Step
    for images, labels_coords, labels_zones in tqdm(train_loader, desc=f"Victory Epoch {epoch+1}"):
        images = images.to(device)
        labels_coords = labels_coords.to(device)
        labels_zones = labels_zones.to(device)

        victory_optimizer.zero_grad()

        pred_coords, pred_zones = model(images)
        loss_reg = criterion_reg(pred_coords, labels_coords)
        loss_cls = criterion_cls(pred_zones, labels_zones)
        loss = loss_reg + (0.5 * loss_cls)

        loss.backward()
        victory_optimizer.step()
        train_loss += loss.item()

    # Validation Step
    avg_dist, median_dist, avg_val_loss = _victory_validate(model)

    # Scheduler Step (Watch Distance, not Loss)
    victory_scheduler.step(avg_dist)
    backbone_lr = victory_optimizer.param_groups[0]['lr']

    print(
        f"Epoch {epoch+1}: Train Loss {train_loss / len(train_loader):.4f} | "
        f"Val Loss {avg_val_loss:.4f} | Avg Dist: {avg_dist:.1f}m | "
        f"Median: {median_dist:.1f}m | Backbone LR {backbone_lr:.1e}"
    )

    # Save if improved
    if avg_dist < best_dist:
        best_dist = avg_dist
        torch.save(model.state_dict(), VICTORY_SAVE_PATH)
        print(f"✨ NEW BEST: {best_dist:.2f}m ✨")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping:
            print("Victory Lap Finished.")
            break

In [None]:
from sklearn.neighbors import KNeighborsRegressor


def _extract_knn_features(model, loader):
    """Embed every batch of `loader` with model.backbone followed by model.shared.

    Returns (features, coords) as stacked NumPy arrays. The coords stay in the
    MinMax-normalized space yielded by the loader.
    """
    features, coords_all = [], []
    with torch.no_grad():
        for images, coords, _ in tqdm(loader):
            images = images.to(device)
            # Bypass the heads entirely: backbone + shared layer are the "smart" features
            features.append(model.shared(model.backbone(images)).cpu().numpy())
            coords_all.append(coords.numpy())
    return np.vstack(features), np.vstack(coords_all)


def evaluate_with_knn(model, train_loader, val_loader, k=3):
    """Predict validation coordinates with k-NN regression over learned features.

    Features are extracted once per loader. Each requested k is then evaluated on
    the same feature bank, and mean/median haversine errors are printed in meters.

    Args:
        model: network exposing `.backbone` and `.shared`.
        train_loader: loader for the feature bank (ideally deterministic, without
            augmentation).
        val_loader: loader for the query images.
        k: a single neighbour count, or an iterable of counts to sweep.

    Returns:
        dict mapping each k to (mean_error, median_error).
    """
    model.eval()
    print("Building Feature Bank from Training Data...")
    X_train, y_train = _extract_knn_features(model, train_loader)

    print("Querying Validation Data...")
    X_val, y_val = _extract_knn_features(model, val_loader)
    real_trues = scaler.inverse_transform(y_val)

    # Accept a plain int (also NumPy ints) or any iterable of ints
    k_values = [int(k)] if np.ndim(k) == 0 else [int(v) for v in k]

    results = {}
    for k_i in k_values:
        # Cosine distance is usually best for DINO features
        knn = KNeighborsRegressor(n_neighbors=k_i, metric='cosine', weights='distance')
        knn.fit(X_train, y_train)
        real_preds = scaler.inverse_transform(knn.predict(X_val))
        distances = haversine_distance(real_preds, real_trues)

        mean_err, median_err = np.mean(distances), np.median(distances)
        results[k_i] = (mean_err, median_err)
        print(f"=== k-NN Results (k={k_i}) ===")
        print(f"Mean Error: {mean_err:.1f} meters, Median Error: {median_err:.1f} meters")
    return results


model.load_state_dict(torch.load("custom_geo_model_best.pth", map_location=device))

# Feature bank over the TRAINING split, in fixed order and with is_train=False (the
# same flag val_dataset uses, presumably disabling augmentation). Building it from
# train_loader would embed randomly augmented images.
knn_bank_loader = DataLoader(
    GeolocalizationDataset(
        image_paths=train_df['path'].tolist(),
        coordinates=train_df[['lat', 'lon']].values,
        zone_labels=train_df['zone_label'].values,
        is_train=False,
        target_size=(252, 182),  # must match the size used for the training datasets
    ),
    batch_size=128,
    shuffle=False,
    num_workers=0,
)

# Sweep k = 1..9 (previously k=3 was evaluated nine times)
knn_results = evaluate_with_knn(model, knn_bank_loader, val_loader, k=range(1, 10))

# Plotting the results:

In [None]:
# Plot an error map for a sample of validation predictions (num_points=50).
# NOTE(review): real_preds / real_trues are globals left by the last validation pass
# of the main training loop; the victory-lap and k-NN cells keep their predictions in
# other/local names. `epoch` is whatever the most recently run loop left behind (the
# victory-lap cell also uses `epoch`). So this map does not reflect victory-lap or
# k-NN results.
log_error_map(real_preds, real_trues, epoch,num_points= 50)

In [None]:
# --- 9. SAVE THE MODEL & GENERATE PLOTS ---
# The per-epoch history lists come from the main training loop. They may be empty,
# e.g. when that loop only logged to TensorBoard, so only plot when data exists.
if not train_losses:
    print("No in-memory training history to plot; see TensorBoard for the curves.")
else:
    # Create a figure with two subplots (1 row, 2 columns)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Left Plot: training objective = Huber(coords) + 0.5 * CrossEntropy(zones)
    ax1.plot(train_losses, label="Training Loss", color="blue", linewidth=2)
    ax1.plot(val_losses, label="Validation Loss", color="red", linestyle="--", linewidth=2)
    ax1.set_title("Mathematical Performance (Huber + 0.5 x CE Loss)")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right Plot: Physical Distance Error (Meters) on the validation split
    ax2.plot(val_avg_dist_history, label="Avg Distance Error", color="green", linewidth=2)
    ax2.plot(
        val_median_dist_history, label="Median Distance Error",
        color="purple", linestyle="--", linewidth=2
    )
    ax2.set_title("Real-World Performance (Distance Error)")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Error (Meters)")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Save the model weights. NOTE(review): `model` holds whatever the most recently run
# cell left in it (e.g. the checkpoint reloaded by the k-NN cell), not necessarily the best.
torch.save(model.state_dict(), "geo_model.pth")
print(f"\nTraining finished! Best Validation Error: {best_dist:.1f} meters.")
print("Model saved as 'geo_model.pth'")

In [None]:
# NOTE(review): legacy single-image inference for the old ConvNet2 model, disabled by
# wrapping it in a string literal so it never executes. As written it would not work
# in this notebook:
#   - `T` (torchvision.transforms) is never imported.
#   - MultiTaskDINOGeo returns (coords, zone_logits) rather than a single tensor.
#   - Its Resize(256)/CenterCrop(256) differs from the aspect_crop + resize pipeline
#     and the dataset target size used for training.
""" # Check if GPU is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SCALER_PATH = 'coordinate_scaler.pkl'
MODEL_WEIGHTS_PATH = 'geo_model.pth'

# Loading Model and Scaler
# Initialize the model architecture and move to the device (GPU/CPU)
model = ConvNet2().to(device)

# Load the trained weights from the .pth file
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH))

# Set the model to evaluation mode (disables Dropout and Batchnorm layers)
model.eval()

# Load the MinMaxScaler used during training to reverse the normalization
scaler = joblib.load(SCALER_PATH)

# Image Preprocessing Function
def predict_location(image_path):

    # Load the image and ensure it is in RGB format
    img = Image.open(image_path).convert('RGB')
    
    # Apply the same validation transforms (No augmentations here!)
    preprocess = T.Compose([
        T.Resize(256),
        T.CenterCrop(256),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Add a batch dimension (Batch size of 1) and move the tensor to device
    img_tensor = preprocess(img).unsqueeze(0).to(device) 
    
    # Perform inference without calculating gradients
    with torch.no_grad():
        output = model(img_tensor)
    
    # Convert the prediction back to a NumPy array on the CPU
    prediction_normalized = output.cpu().numpy()
    
    # Reverse the scaling to get real-world GPS coordinates
    real_coords = scaler.inverse_transform(prediction_normalized)
    
    # Return the first (and only) result in the batch [Latitude, Longitude]
    return real_coords[0]

# Run Inference on a New Image
# Provide the full path to your local image file
test_path = r"C:\path\to\your\new\image.jpg"
lat, lon = predict_location(test_path)

print(f"Predicted Location: Latitude {lat:.6f}, Longitude {lon:.6f}")
print(f"Google Maps Link: http://maps.google.com/maps?q={lat},{lon}") """