In [33]:
import torch
from torch import nn
import torch.optim as optim
from train.train import train
from train.dataloader import GeoDataLoader, img_train_transform
from torch.utils.data import DataLoader
from geoclip import GeoCLIP
import os
import wandb
import random
from datetime import datetime

### Eval Utils ###

In [34]:
import torch
import importlib
import json
import math
import numpy as np
from transformers.models.layoutlmv3.configuration_layoutlmv3 import LayoutLMv3OnnxConfig

def count_total_imgs(json_filepath, dataset_name):
    """
    Count the images described by a dataset's JSON metadata file.

    param json_filepath: path to a JSON file whose top level is an object
        (one entry per key)
    param dataset_name: "googlestreetview" counts one image per entry;
        "eyesonrussia" counts every name listed under an entry's "img_names"
    return: the total number of images (0 for any other dataset name)
    """
    with open(json_filepath, 'r') as handle:
        records = json.load(handle).values()

    if dataset_name == "googlestreetview":
        # Each metadata entry stands for exactly one image.
        return len(records)

    if dataset_name == "eyesonrussia":
        # Each metadata entry may list several image names.
        return sum(len(record["img_names"]) for record in records)

    return 0

def tensor_to_python_type(tensor):
    """
    Convert a 0-D or 1-D tensor into a plain Python value.

    param tensor: object exposing `ndim`, `item()` and `tolist()`
    return: `tensor.item()` for a scalar (0-D) tensor, `tensor.tolist()` for a 1-D tensor
    raises ValueError: if the tensor has any other number of dimensions
    """
    # Reject everything that is neither a scalar nor a vector up front.
    if tensor.ndim not in (0, 1):
        raise ValueError("Unsupported tensor dimension for conversion.")

    return tensor.tolist() if tensor.ndim else tensor.item()

def haversine_distance(latlon1, latlon2):
    """
    Great-circle distance between two points, computed with the haversine formula.

    param latlon1: (latitude, longitude) of the first point, in degrees
    param latlon2: (latitude, longitude) of the second point, in degrees
    return: The distance between the two points in kilometers
    """
    earth_radius_km = 6371

    # Unpack each pair and switch from degrees to radians in one step.
    phi1, lam1 = (math.radians(degrees) for degrees in latlon1)
    phi2, lam2 = (math.radians(degrees) for degrees in latlon2)

    delta_phi = phi2 - phi1
    delta_lam = lam2 - lam1

    # Haversine term for the central angle between the two points.
    hav = math.sin(delta_phi/2)**2 + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lam/2)**2
    central_angle = 2 * math.atan2(math.sqrt(hav), math.sqrt(1-hav))

    return earth_radius_km * central_angle

def _print_summary_stats(values):
    """Print and return count/mean/std/min/max/quartiles for a non-empty list of numbers."""
    stats = {
        "count": len(values),
        "mean": np.mean(values),
        # Sample standard deviation (ddof=1); undefined for a single sample,
        # so report NaN directly instead of triggering a NumPy warning.
        "std": np.std(values, ddof=1) if len(values) > 1 else float("nan"),
        "min": np.min(values),
        "max": np.max(values),
        "25%": np.percentile(values, 25),
        "50%": np.percentile(values, 50),  # Median
        "75%": np.percentile(values, 75),
    }

    for key, value in stats.items():
        print(f"{key}: {value:.2f}")

    return stats


def run_eval(imgs_filepath, json_filepath, output_filename, timestamp, epoch_num, dataset_name="eyesonrussia", streetview_only=True):
    """
    Run top-1 GeoCLIP predictions on a dataset's TEST images and report the errors.

    For every selected image the predicted coordinate is compared with the
    entry's "coordinates" using haversine_distance; the per-image results are
    written to output_filename as JSON and summary statistics are printed.

    imgs_filepath: path to images folder; image names are appended to it
        directly, so it must end with a path separator
    json_filepath: path to json metadata
    output_filename: path of the JSON file the per-image results are written to
    timestamp: forwarded to GeoCLIP(timestamp=...)
    epoch_num: forwarded to GeoCLIP(epoch_num=...)
    dataset_name: "eyesonrussia" (entries list "img_names") or
        "googlestreetview" (entries hold a single "image_name")
    streetview_only: when True, only images flagged "YES" in the entry's
        "streetview" list are evaluated; when False, every TEST image is

    returns: tuple of (acc_1km, acc_25km, acc_200km, acc_750km, acc_2500km,
        probability median, probability min, probability max). The accuracies
        are the fraction of evaluated images whose error is below that
        distance. If no image was evaluated, the accuracies are 0.0 and the
        probability statistics are NaN.
    raises ValueError: if dataset_name is not one of the supported names
    """
    if dataset_name not in ("googlestreetview", "eyesonrussia"):
        # Fail early with a clear message instead of a NameError inside the loop.
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    model = GeoCLIP(timestamp=timestamp, epoch_num=epoch_num)
    if torch.cuda.is_available():
        model.to("cuda")
    else:
        print("CUDA NOT AVAILABLE")

    pred_accuracies = {}
    errors = []

    with open(json_filepath, 'r', encoding="utf-8") as file:
        metadata = json.load(file)

    for img_metadata in metadata.values():

        if dataset_name == "googlestreetview":
            images = [img_metadata["image_name"]]
        else:
            images = img_metadata["img_names"]

        # Entries without a "streetview" list simply have no streetview images.
        streetview_flags = img_metadata.get("streetview", [])

        for idx, img_name in enumerate(images):
            is_streetview = idx < len(streetview_flags) and streetview_flags[idx] == "YES"

            # Non-streetview images only qualify when streetview_only is off.
            if streetview_only and not is_streetview:
                continue
            if img_metadata["train_or_test"][idx] != "TEST":
                continue

            try:
                top_pred_gps, top_pred_prob = model.predict(imgs_filepath + img_name, top_k=1)
            except FileNotFoundError:
                print(f"Error: The file '{img_name}' was not found in the specified directory '{imgs_filepath}'.")
                continue

            ground_truth = img_metadata["coordinates"]
            pred_latlon = tensor_to_python_type(top_pred_gps[0])
            probability = tensor_to_python_type(top_pred_prob[0])

            dist = haversine_distance(pred_latlon, ground_truth)
            pred_accuracies[img_name] = {
                "prediction": pred_latlon,
                "error_km": dist,
                "probability": probability,
                "ground_truth": ground_truth,
                "country": img_metadata["country"],
                "city": img_metadata["city"]
            }
            if dataset_name == "eyesonrussia":
                pred_accuracies[img_name]["url"] = img_metadata["urls"]

            errors.append(dist)

    print(f"Average Error (km): {sum(errors) / len(errors) if errors else 0}")

    with open(output_filename, 'w', encoding="utf-8") as file:
        json.dump(pred_accuracies, file, indent=4)
    print(f"Prediction accuracies saved to {output_filename}")

    # Results are keyed by image name, so the statistics cover one value per name.
    error_list = [item["error_km"] for item in pred_accuracies.values()]

    if not error_list:
        # Nothing was evaluated: min/percentiles are undefined and the
        # accuracies would divide by zero, so return neutral values instead.
        print("No images were evaluated; accuracy and probability statistics are unavailable.")
        nan = float("nan")
        return 0.0, 0.0, 0.0, 0.0, 0.0, nan, nan, nan

    _print_summary_stats(error_list)

    # Fraction of images whose error falls strictly below each distance.
    thresholds_km = (1, 25, 200, 750, 2500)
    accuracies = [
        sum(error < limit for error in error_list) / len(error_list)
        for limit in thresholds_km
    ]

    print()
    for limit, accuracy in zip(thresholds_km, accuracies):
        print(f"{limit}km Accuracy: {accuracy:.2f}")

    # Log probabilities
    probs_list = [item["probability"] for item in pred_accuracies.values()]
    prob_stats = _print_summary_stats(probs_list)

    acc_1km, acc_25km, acc_200km, acc_750km, acc_2500km = accuracies
    return acc_1km, acc_25km, acc_200km, acc_750km, acc_2500km, prob_stats["50%"], prob_stats["min"], prob_stats["max"]

### Fine Tuning ###

In [35]:
# Expand "~" once and derive both paths from it, so the image folder is not
# left as a literal "~/..." string (which file APIs do not expand themselves).
dataset_folder = os.path.expanduser("~/mnt/cluster_storage/ai_geolocation")
dataset_file = os.path.join(dataset_folder, "combined_train_geolocations.csv")
batch_size = 32

# Training samples come from the CSV and are loaded with the training image transform.
train_dataset = GeoDataLoader(dataset_file, dataset_folder, transform=img_train_transform())
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

Loading image paths and coordinates: 0it [00:00, ?it/s]

Loading image paths and coordinates: 15236it [00:00, 22330.51it/s]

Total images found: 15236





In [36]:
# Initialize model
# One timestamp labels this run: it is passed to run_eval() and embedded in
# the prediction and weight file names below.
current_time = datetime.now().strftime("%m-%d-%H:%M")
os.makedirs("snapshots", exist_ok=True)

step_size = 30
lr = 0.0001
num_epochs = 10
gamma = 0.1
model = GeoCLIP(timestamp=None, from_pretrained=True)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# wandb setup
# Log the plain numbers: `{lr}` is a one-element set literal, not the value.
wandb.init(
    project="ai-geolocation",
    config={
        "learning_rate": lr,
        "step_size": step_size,
        "epochs": num_epochs,
        "optimizer_gamma": gamma,
    }
)

# Test-set locations do not change between epochs, so build them once.
# EoR
eor_imgs_path = "/home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia_imgs/"
eor_json_path = "/home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia.json"
eor_output_path = f"/home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia_finetuned_geoclip_predictions_{current_time}"
# GSV
gsv_imgs_path = "/home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/"
gsv_json_path = "/home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview.json"
gsv_output_path = f"/home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_finetuned_predictions_{current_time}"

# Train
for epoch in range(num_epochs):
    train(train_dataloader, model, optimizer, epoch, batch_size, device, scheduler=scheduler)

    # NOTE(review): nothing in this cell writes weights into snapshot_dir;
    # presumably train() does - confirm, because run_eval() builds a fresh
    # GeoCLIP from (timestamp, epoch_num) rather than using `model` directly.
    snapshot_dir = f"snapshots/epoch_{epoch}"
    os.makedirs(snapshot_dir, exist_ok=True)

    # Run test set on this epoch's weights

    # EoR
    acc_1km, acc_25km, acc_200km, acc_750km, acc_2500km, prob_50, prob_min, prob_max = run_eval(epoch_num=epoch, timestamp=current_time, imgs_filepath=eor_imgs_path, json_filepath=eor_json_path, output_filename=eor_output_path, dataset_name="eyesonrussia")

    wandb.log({
        "epoch": epoch,
        "eor_accuracy_1km": acc_1km,
        "eor_accuracy_25km": acc_25km,
        "eor_accuracy_200km": acc_200km,
        "eor_accuracy_750km": acc_750km,
        "eor_accuracy_2500km": acc_2500km,
        "eor_probability_50%": prob_50,
        "eor_probability_min": prob_min,
        "eor_probability_max": prob_max
    })

    # GSV
    acc_1km, acc_25km, acc_200km, acc_750km, acc_2500km, prob_50, prob_min, prob_max = run_eval(epoch_num=epoch, timestamp=current_time, imgs_filepath=gsv_imgs_path, json_filepath=gsv_json_path, output_filename=gsv_output_path, dataset_name="googlestreetview")
    wandb.log({
        "epoch": epoch,
        "gsv_accuracy_1km": acc_1km,
        "gsv_accuracy_25km": acc_25km,
        "gsv_accuracy_200km": acc_200km,
        "gsv_accuracy_750km": acc_750km,
        "gsv_accuracy_2500km": acc_2500km,
        "gsv_probability_50%": prob_50,
        "gsv_probability_min": prob_min,
        "gsv_probability_max": prob_max
    })

    print(f"Saved snapshot for epoch {epoch}")

# Save fine-tuned weights
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs("model/weights", exist_ok=True)
torch.save(model.image_encoder.mlp.state_dict(), f"model/weights/fine_tuned_image_encoder_mlp_weights_{current_time}.pth")
torch.save(model.location_encoder.state_dict(), f"model/weights/fine_tuned_location_encoder_weights_{current_time}.pth")
torch.save(model.logit_scale, f"model/weights/fine_tuned_logit_scale_weights_{current_time}.pth")
wandb.finish()

  self.load_state_dict(torch.load(f"{file_dir}/weights/{location_encoder_path}"))


gps_gallery: /home/ray/mnt/cluster_storage/ai_geolocation/geo-clip/geoclip/model/gps_gallery/coordinates_ukraine_russia.csv


  self.image_encoder.mlp.load_state_dict(torch.load(f"/home/ray/mnt/cluster_storage/ai_geolocation/geo-clip/geoclip/snapshots/{image_encoder_path}"))


FileNotFoundError: [Errno 2] No such file or directory: '/home/ray/mnt/cluster_storage/ai_geolocation/geo-clip/geoclip/snapshots/fine_tuned_image_encoder_mlp_weights.pth'