# Imports:

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image, ImageOps

# --- CUSTOM IMPORTS ---
from src.utils import extractCoordinates, aspect_crop, haversine_distance, plot_images_from_dataloader, setup_TensorBoard_writers, log_error_map
from src.dataset import GeolocalizationDataset
from src.models import ConvNet, ConvNet2, ConvNet3, HierarchicalLocalizer

# Let cuDNN benchmark convolution algorithms and cache the fastest one per input
# shape; this is worthwhile here because both datasets use one fixed target_size.
torch.backends.cudnn.benchmark = True

# Image preprocessing:

In [None]:
# Setup paths
RAW_IMAGE_FOLDER = r"data_manual_gps_united"             # Use original images for GPS extraction
PROCESSED_IMAGE_FOLDER = r"data_processed_manual_gps" # Use processed images for training
PROCESSED_IMAGE_SIZE = (192, 256)  # (width, height) as passed to PIL's Image.resize
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')


def preprocess_image(src_path, dst_path, size=PROCESSED_IMAGE_SIZE):
    """Write a normalized copy of one raw image to ``dst_path``.

    Applies the EXIF orientation tag, converts to RGB, crops with ``aspect_crop``,
    resizes to ``size`` with Lanczos resampling and saves with quality=95.

    Args:
        src_path: Path of the raw image.
        dst_path: Destination path; the output format follows its extension.
        size: ``(width, height)`` of the saved image.
    """
    with Image.open(src_path) as img:
        img = ImageOps.exif_transpose(img)
        img = img.convert('RGB')
        img = aspect_crop(img)
        img = img.resize(size, Image.Resampling.LANCZOS)
        img.save(dst_path, quality=95)


# Decide what to (re)process by comparing file names rather than a fixed file
# count: newly added raw images are picked up, and only images that are missing
# (e.g. ones that failed on a previous run) are retried.
# NOTE: delete PROCESSED_IMAGE_FOLDER to regenerate everything after changing
# the preprocessing pipeline.
os.makedirs(PROCESSED_IMAGE_FOLDER, exist_ok=True)
raw_files = sorted(
    f for f in os.listdir(RAW_IMAGE_FOLDER) if f.lower().endswith(IMAGE_EXTENSIONS)
)
already_processed = set(os.listdir(PROCESSED_IMAGE_FOLDER))
files = [f for f in raw_files if f not in already_processed]

if files:
    print("Starting Pre-processing...")
    for filename in tqdm(files):
        src_path = os.path.join(RAW_IMAGE_FOLDER, filename)
        dst_path = os.path.join(PROCESSED_IMAGE_FOLDER, filename)
        try:
            preprocess_image(src_path, dst_path)
        except Exception as e:
            # Best-effort: report and continue. Images without a processed copy
            # are skipped when the dataset is built and retried on the next run.
            print(f"Failed {filename}: {e}")
else:
    print("Pre-processed images already exist. Skipping preprocessing step.")

# Data loading:

In [None]:
def collect_gps_records(raw_folder, processed_folder, extract_coords,
                        extensions=('.jpg', '.jpeg')):
    """Pair each processed image with the GPS coordinates read from its raw original.

    Files are visited in sorted order, so the returned list does not depend on
    the arbitrary order in which ``os.listdir`` yields entries. Any seeded
    train/validation split built from it is therefore reproducible.

    Args:
        raw_folder: Folder with the original images; coordinates are read from these.
        processed_folder: Folder with the pre-processed copies (same file names).
        extract_coords: Callable taking a raw image path and returning an
            indexable ``(lat, lon)`` pair, or a falsy value when none is available.
        extensions: Lower-case file suffixes to consider.

    Returns:
        A list of ``{'path': processed_path, 'lat': ..., 'lon': ...}`` dicts.
        Images without a processed copy, or without coordinates, are skipped.
    """
    records = []
    for filename in sorted(os.listdir(raw_folder)):
        if not filename.lower().endswith(extensions):
            continue

        processed_image_path = os.path.join(processed_folder, filename)
        # Check if the processed version actually exists
        if not os.path.exists(processed_image_path):
            continue

        # Extract coordinates from the ORIGINAL file
        coords = extract_coords(os.path.join(raw_folder, filename))
        if coords:
            records.append({
                'path': processed_image_path,
                'lat': coords[0],
                'lon': coords[1],
            })
    return records


if __name__ == "__main__":

    # NOTE(review): not referenced by any visible cell; kept for external use.
    SCALER_SAVE_PATH = 'coordinate_scaler.pkl'

    # --- 2. EXTRACTION PHASE ---
    processed_data = collect_gps_records(
        RAW_IMAGE_FOLDER, PROCESSED_IMAGE_FOLDER, extractCoordinates
    )
    if not processed_data:
        # Fail early with a clear message instead of a KeyError on an empty frame.
        raise RuntimeError(
            f"No processed images with GPS coordinates found "
            f"(raw: {RAW_IMAGE_FOLDER!r}, processed: {PROCESSED_IMAGE_FOLDER!r})."
        )

    images_df = pd.DataFrame(processed_data)

    # Keep 20% of the data for validation (seeded; reproducible because the
    # file listing above is sorted)
    train_df, val_df = train_test_split(images_df, test_size=0.2, random_state=42)

In [None]:
if __name__ == "__main__":

    # --- 4. DATASET INITIALIZATION ---
    # Both splits share one input resolution. Augmentation is disabled on both;
    # per the original note, set is_train=True on the training split to enable it.
    dataset_target_size = (182, 252)

    print("Initializing train dataset...")
    train_dataset = GeolocalizationDataset(
        image_paths=train_df['path'].tolist(),
        coordinates=train_df[['lat', 'lon']].values,
        target_size=dataset_target_size,
        is_train=False,
    )

    print("Initializing validation dataset...")
    val_dataset = GeolocalizationDataset(
        image_paths=val_df['path'].tolist(),
        coordinates=val_df[['lat', 'lon']].values,
        target_size=dataset_target_size,
        is_train=False,
    )

    # --- 5. THE DATALOADER ---
    # Settings shared by both loaders; training is shuffled and uses pinned
    # memory, validation keeps a fixed order without pinning.
    shared_loader_opts = {'batch_size': 32, 'num_workers': 0}
    train_loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, **shared_loader_opts)
    val_loader = DataLoader(val_dataset, shuffle=False, pin_memory=False, **shared_loader_opts)

    # Plot a sample from the training loader (see src.utils.plot_images_from_dataloader).
    plot_images_from_dataloader(train_loader)

# Model setup:

In [None]:
# Build the localizer from the training split. Its predict()/validate() methods
# are called below with DINO / LightGlue top-k and inlier-threshold parameters.
# NOTE(review): presumably train_dataset acts as the reference set that
# validation images are matched against — confirm in src.models.HierarchicalLocalizer.
localizer = HierarchicalLocalizer(train_dataset)

In [None]:
def haversine_distance1(coord1, coord2):
    """Return the great-circle distance in meters between two (lat, lon) points.

    Coordinates are in degrees. The haversine formula is evaluated on a sphere
    of radius 6,371,000 m. ``a`` is clipped to [0, 1] because floating-point
    rounding can push it marginally above 1 (e.g. near-antipodal points). That
    would make ``sqrt(1 - a)`` NaN.

    Args:
        coord1: ``(lat, lon)`` pair in degrees.
        coord2: ``(lat, lon)`` pair in degrees.

    Returns:
        Distance in meters as a NumPy float.
    """
    R = 6371000.0
    lat1, lon1 = np.radians(coord1)
    lat2, lon2 = np.radians(coord2)
    dlat, dlon = lat2 - lat1, lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    a = np.clip(a, 0.0, 1.0)
    return 2 * R * np.arctan2(np.sqrt(a), np.sqrt(1 - a))


if __name__ == "__main__":
    print("\n--- Running Evaluation ---")

    total_error = 0.0
    errors = []

    # Create the tqdm object explicitly so running errors can go in its postfix
    pbar = tqdm(range(len(val_dataset)), desc="Validating")

    for i in pbar:
        img, actual_gps = val_dataset[i]
        true_gps = actual_gps.numpy()

        # Predict
        pred_gps = localizer.predict(
            img,
            top_k_dino=20,
            top_k_lightglue=3,
            MIN_INLIER_THRESHOLD=100,
            debug=True,
            true_coords=true_gps,
        )

        # Measure
        err = haversine_distance1(pred_gps, true_gps)
        errors.append(err)

        total_error += err
        avg_error = total_error / (i + 1)

        pbar.set_postfix({'curr': f"{err:.1f}m", 'avg': f"{avg_error:.1f}m"})

    # Avoid np.mean([]) (NaN plus a RuntimeWarning) on an empty validation set.
    if errors:
        print(f"Final Average Error: {np.mean(errors):.2f} meters")
    else:
        print("No validation samples to evaluate.")

In [None]:
def error_analysis(errors, epoch=1):
    """Print summary statistics of validation errors and log an error map.

    Args:
        errors: Sequence of records indexed as
            ``(error_m, pred_lat, pred_lon, gt_lat, gt_lon)``, where ``error_m``
            is the distance error in meters (e.g. the output of
            ``localizer.validate``).
        epoch: Forwarded to ``log_error_map``. Defaults to 1, the value that was
            previously hard-coded.

    An empty ``errors`` prints a notice and returns without logging, instead of
    raising ``ValueError`` from ``np.min`` on a zero-size array.
    """
    if len(errors) == 0:
        print("No errors to analyse.")
        return

    # Column 0 of every record is the distance error in meters
    errors_only = np.array([e[0] for e in errors])

    # Basic statistics
    print(f"Mean Error: {np.mean(errors_only):.2f} meters")
    print(f"Median Error: {np.median(errors_only):.2f} meters")
    print(f"Standard Deviation: {np.std(errors_only):.2f} meters")
    print(f"Minimum Error: {np.min(errors_only):.2f} meters")
    print(f"Maximum Error: {np.max(errors_only):.2f} meters")

    # Extract predicted and ground truth coordinates
    pred_coords = np.array([[e[1], e[2]] for e in errors])  # Pred_Lat, Pred_Lon
    true_coords = np.array([[e[3], e[4]] for e in errors])  # Gt_Lat, Gt_Lon

    # Log the error map
    log_error_map(pred_coords, true_coords, epoch=epoch)


In [None]:
# Validate with top_k_dino=10 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=10,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)

In [None]:
# Validate with top_k_dino=15 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=15,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)

In [None]:
# Validate with top_k_dino=20 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=20,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)

In [None]:
# Validate with top_k_dino=25 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=25,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)

In [None]:
# Validate with top_k_dino=30 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=30,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)

In [None]:
# Validate with top_k_dino=35 (top_k_lightglue=3, inlier_threshold=100) and summarise.
errors = localizer.validate(
    val_loader,
    top_k_dino=35,
    top_k_lightglue=3,
    inlier_threshold=100,
)
error_analysis(errors)