In [1]:
import os
import base64
import json
from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional
import numpy as np
import torch
from PIL import Image
from io import BytesIO
import yaml
from google import genai
from google.genai import types
from dotenv import load_dotenv

# Import required utilities
from utils.G3 import G3
from utils.utils import search_with_image_and_text, extract_image_search_candidates
from utils.prompt import combine_prompts

load_dotenv()

class BatchKeyframePredictor:
    """
    Batch prediction class for directory of keyframes/images.
    
    This class:
    1. Takes a directory of keyframes/images as input
    2. For each image, performs image search and collects the first N result links
    3. Adds all links to a global dictionary (no duplicates)
    4. Processes the links in batches
    5. For each batch, runs LLM prediction using all images and calculates similarity scores
    6. Selects the GPS with the highest average similarity across all images
    """
    
    def __init__(self, checkpoint_path: str, device: str = "cuda", index_path: Optional[str] = None):
        """
        Initialize the BatchKeyframePredictor.
        
        Args:
            checkpoint_path (str): Path to G3 model checkpoint
            device (str): Device to run model on ("cuda" or "cpu")
            index_path (str): Path to FAISS index for RAG (required)
        """
        if index_path is None:
            raise ValueError("index_path is required for batch prediction. FAISS index is mandatory for RAG.")
        
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"FAISS index file not found: {index_path}")
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path
        
        # Initialize G3 model
        base_path = Path(r"C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\G3-Original\g3\example.ipynb").parent
        hparams = yaml.safe_load(open(base_path / "hparams.yaml", "r"))
        pe = "projection_mercator"
        nn = "rffmlp"
        
        self.model = G3(
            device=device,
            positional_encoding_type=pe,
            neural_network_type=nn,
            hparams=hparams[f"{pe}_{nn}"],
        )
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        self.model.to(device)
        self.model.requires_grad_(False)
        self.model.eval()
        
        # Load FAISS index for RAG (required)
        import faiss
        try:
            self.index = faiss.read_index(index_path)
            print(f"‚úÖ Successfully loaded FAISS index from: {index_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load FAISS index from {index_path}: {e}")
        
        # Get API key
        self.api_key = os.getenv("API_KEY")
        if self.api_key is None:
            raise ValueError("API_KEY environment variable is not set.")
        
        # Supported image extensions
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
        
        # Search candidates dictionary: {link: full_candidate_string}
        self.candidates_dict: Dict[str, str] = {}
    
    @staticmethod
    def extract_and_parse_json(raw_text: str) -> dict:
        """
        Extract JSON content between first { and last } and parse it.
        
        Args:
            raw_text (str): Raw response text from LLM
            
        Returns:
            dict: Parsed JSON data
            
        Raises:
            ValueError: If no valid JSON found or parsing fails
        """
        # Find first { and last }
        first_brace = raw_text.find('{')
        last_brace = raw_text.rfind('}')
        
        if first_brace == -1 or last_brace == -1 or first_brace >= last_brace:
            raise ValueError(f"No valid JSON braces found in response: {raw_text}")
        
        # Extract JSON substring
        json_str = raw_text[first_brace:last_brace + 1]
        
        try:
            # Parse JSON
            parsed_data = json.loads(json_str)
            return parsed_data
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON: {json_str}, Error: {e}")

    @staticmethod
    def is_valid_enhanced_gps_dict(gps_data):
        """
        Check if GPS data dict has valid enhanced format with latitude, longitude, location, and reason.
        """
        if not isinstance(gps_data, dict):
            return False
            
        required_fields = ["latitude", "longitude", "location", "reason"]
        if not all(field in gps_data for field in required_fields):
            return False
            
        try:
            lat = float(gps_data["latitude"])
            lon = float(gps_data["longitude"])
            
            # Basic GPS coordinate validation
            if -90 <= lat <= 90 and -180 <= lon <= 180:
                return True
        except (ValueError, TypeError):
            pass
            
        return False

    @staticmethod
    def image_to_base64(image_path: str) -> str:
        """Convert image file to base64 string."""
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
        return encoded_string

    def get_llm_prediction_with_batch(
        self,
        base64_images: List[str],
        batch_links: List[str],
        candidates_gps: Optional[list] = None,
        reverse_gps: Optional[list] = None,
        image_paths: Optional[List[str]] = None,
        transcript_file_path: str = "",
        metadata_file_path: str = "",
        n_coords: int = 15,
        model_name: str = "gemini-2.5-flash"
    ) -> dict:
        """
        Get LLM prediction using multiple images and comprehensive context.
        
        Args:
            base64_images: List of base64 encoded images
            batch_links: List of links for this batch
            candidates_gps: List of candidate GPS coordinates from RAG
            reverse_gps: List of reverse GPS coordinates from RAG
            image_paths: List of image file paths for metadata
            transcript_file_path: Path to transcript file
            metadata_file_path: Path to metadata JSON file
            n_coords: Number of coordinates to include
            model_name: Model to use for prediction
            
        Returns:
            dict: Parsed prediction with latitude, longitude, location, reason
        """
        # Get full candidates from the dictionary using the links
        search_candidates = [self.candidates_dict[link] for link in batch_links if link in self.candidates_dict]
        
        # Create comprehensive prompt using all available context
        combined_prompt = combine_prompts(
            image_path=image_paths[0] if image_paths else "",  # Use first image path for reference
            transcript_file_path=transcript_file_path,
            metadata_file_path=metadata_file_path,
            candidates_gps=candidates_gps[:n_coords] if candidates_gps else [],
            reverse_gps=reverse_gps[:n_coords] if reverse_gps else [],
            search_candidates=search_candidates,
            n_search=len(search_candidates),
            n_coords=n_coords,
        )
        
        client = genai.Client(
            api_key=self.api_key,
            http_options=types.HttpOptions(timeout=1)
        )
        
        # Convert base64 images to Parts
        image_parts = []
        for b64_img in base64_images:
            image = types.Part.from_bytes(
                data=base64.b64decode(b64_img), 
                mime_type="image/jpeg"
            )
            image_parts.append(image)
        
        # Combine images with prompt
        contents = image_parts + [combined_prompt]
        
        tools = [
            types.Tool(google_search=types.GoogleSearch()),
            types.Tool(url_context=types.UrlContext())
        ]

        config = types.GenerateContentConfig(
            tools=tools,
            response_modalities=["TEXT"]
        )

        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=config
        )
        
        raw_text = response.text.strip() if response.text is not None else ""
        
        # Extract and parse JSON from response
        parsed_json = self.extract_and_parse_json(raw_text)
        return parsed_json

    def calculate_similarity_scores(
        self,
        rgb_images: List[Image.Image], 
        predicted_coords: List[Tuple[float, float]]
    ) -> np.ndarray:
        """
        Calculate similarity scores between images and predicted coordinates.
        
        Args:
            rgb_images: List of PIL Images
            predicted_coords: List of (lat, lon) tuples
            
        Returns:
            np.ndarray: Average similarity scores across all images for each coordinate
        """
        all_similarities = []
        
        for rgb_image in rgb_images:
            # Process image
            image = self.model.vision_processor(images=rgb_image, return_tensors="pt")[
                "pixel_values"
            ].reshape(-1, 224, 224)
            image = image.unsqueeze(0).to(self.device)

            with torch.no_grad():
                vision_output = self.model.vision_model(image)[1]

                image_embeds = self.model.vision_projection_else_2(
                    self.model.vision_projection(vision_output)
                )
                image_embeds = image_embeds / image_embeds.norm(
                    p=2, dim=-1, keepdim=True
                )  # b, 768

                # Process coordinates
                gps_batch = torch.tensor(predicted_coords, dtype=torch.float32).to(self.device)
                gps_input = gps_batch.clone().detach().unsqueeze(0)  # Add batch dimension
                b, c, _ = gps_input.shape
                gps_input = gps_input.reshape(b * c, 2)
                location_embeds = self.model.location_encoder(gps_input)
                location_embeds = self.model.location_projection_else(
                    location_embeds.reshape(b * c, -1)
                )
                location_embeds = location_embeds / location_embeds.norm(
                    p=2, dim=-1, keepdim=True
                )
                location_embeds = location_embeds.reshape(b, c, -1)  # b, c, 768

                similarity = torch.matmul(
                    image_embeds.unsqueeze(1), location_embeds.permute(0, 2, 1)
                )  # b, 1, c
                similarity = similarity.squeeze(1).cpu().detach().numpy()
                all_similarities.append(similarity[0])  # Remove batch dimension
        
        # Calculate average similarity across all images
        avg_similarities = np.mean(all_similarities, axis=0)
        return avg_similarities

    def search_index(self, rgb_image, top_k=20):
        """
        Search FAISS index for similar and dissimilar coordinates using image embeddings.
        
        Args:
            rgb_image: PIL RGB Image
            top_k (int): Number of top results to return
            
        Returns:
            tuple: (D, I, D_reverse, I_reverse) - distances and indices for positive and negative embeddings
        """
        print("Searching FAISS index...")
        image = self.model.vision_processor(images=rgb_image, return_tensors="pt")[
            "pixel_values"
        ].reshape(-1, 224, 224)
        image = image.unsqueeze(0).to(self.device)  # Add batch dimension
        
        with torch.no_grad():
            vision_output = self.model.vision_model(image)[1]
            image_embeds = self.model.vision_projection(vision_output)
            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

            image_text_embeds = self.model.vision_projection_else_1(
                self.model.vision_projection(vision_output)
            )
            image_text_embeds = image_text_embeds / image_text_embeds.norm(
                p=2, dim=-1, keepdim=True
            )

            image_location_embeds = self.model.vision_projection_else_2(
                self.model.vision_projection(vision_output)
            )
            image_location_embeds = image_location_embeds / image_location_embeds.norm(
                p=2, dim=-1, keepdim=True
            )

            positive_image_embeds = torch.cat(
                [image_embeds, image_text_embeds, image_location_embeds], dim=1
            )
            positive_image_embeds = (
                positive_image_embeds.cpu().detach().numpy().astype(np.float32)
            )

            negative_image_embeds = positive_image_embeds * (-1.0)

        # Search FAISS index (index is guaranteed to exist)
        D, I = self.index.search(positive_image_embeds, top_k)
        D_reverse, I_reverse = self.index.search(negative_image_embeds, top_k)
        return D, I, D_reverse, I_reverse

    def _get_gps_coordinates(self, I, I_reverse, database_csv_path):
        """
        Helper method to get GPS coordinates from database using FAISS indices.
        
        Args:
            I: FAISS indices for positive embeddings
            I_reverse: FAISS indices for negative embeddings  
            database_csv_path (str): Path to GPS coordinates database CSV
            
        Returns:
            tuple: (candidates_gps, reverse_gps) - lists of (lat, lon) tuples
        """
        if I is None or I_reverse is None:
            return [], []
            
        candidate_indices = I[0]
        reverse_indices = I_reverse[0]
        
        candidates_gps = []
        reverse_gps = []
        
        try:
            import pandas as pd
            for chunk in pd.read_csv(database_csv_path, chunksize=10000, usecols=["LAT", "LON"]):
                for idx in candidate_indices:
                    if idx in chunk.index:
                        lat = float(chunk.loc[idx, "LAT"])
                        lon = float(chunk.loc[idx, "LON"])
                        candidates_gps.append((lat, lon))

                for ridx in reverse_indices:
                    if ridx in chunk.index:
                        lat = float(chunk.loc[ridx, "LAT"])
                        lon = float(chunk.loc[ridx, "LON"])
                        reverse_gps.append((lat, lon))
        except Exception as e:
            print(f"‚ö†Ô∏è Error loading GPS coordinates from database: {e}")
            
        return candidates_gps, reverse_gps

    def predict(
        self,
        keyframes_dir: str,
        database_csv_path: str = "",
        serpapi_key: str = "",
        imgbb_key: str = "",
        transcript_file_path: str = "",
        metadata_file_path: str = "",
        batch_size: int = 10,
        links_per_image: int = 3,
        top_k: int = 20,
        model_name: str = "gemini-2.5-flash"
    ) -> dict:
        """
        Comprehensive batch prediction for directory of keyframes/images.
        
        Args:
            keyframes_dir (str): Directory containing keyframe/image files
            database_csv_path (str): Path to GPS coordinates database CSV for RAG (required)
            serpapi_key (str): SerpAPI key for image search
            imgbb_key (str): imgbb API key for image upload
            transcript_file_path (str): Path to transcript file
            metadata_file_path (str): Path to metadata JSON file
            batch_size (int): Number of links to process in each batch (default: 10)
            links_per_image (int): Number of search result links to extract per image (default: 3)
            top_k (int): Number of top FAISS results for RAG (default: 20)
            model_name (str): LLM model name to use
            
        Returns:
            dict: Best prediction with latitude, longitude, location, reason, and metadata
        """
        
        # Validate required parameters for RAG
        if not database_csv_path:
            raise ValueError("database_csv_path is required for RAG coordinates. This parameter is mandatory.")
        
        if not os.path.exists(database_csv_path):
            raise FileNotFoundError(f"Database CSV file not found: {database_csv_path}")
        
        # Reset candidates dictionary for new prediction
        self.candidates_dict = {}
        
        # Step 1: Collect all image files
        keyframes_path = Path(keyframes_dir)
        if not keyframes_path.exists():
            raise ValueError(f"Directory does not exist: {keyframes_dir}")
        
        image_files = []
        for file_path in keyframes_path.iterdir():
            if file_path.is_file() and file_path.suffix.lower() in self.image_extensions:
                image_files.append(file_path)
        
        if not image_files:
            raise ValueError(f"No image files found in directory: {keyframes_dir}")
        
        print(f"üìÅ Found {len(image_files)} image files in {keyframes_dir}")
        
        # Step 2: For each image, perform search and collect links + RAG coordinates
        base64_images = []
        rgb_images = []
        image_paths = []
        all_candidates_gps = []
        all_reverse_gps = []
        
        for i, image_file in enumerate(image_files):
            print(f"üñºÔ∏è Processing image {i+1}/{len(image_files)}: {image_file.name}")
            
            try:
                # Convert to base64 for later use
                base64_img = self.image_to_base64(str(image_file))
                base64_images.append(base64_img)
                
                # Load RGB image for similarity calculation and FAISS search
                rgb_img = Image.open(image_file).convert("RGB")
                rgb_images.append(rgb_img)
                image_paths.append(str(image_file))
                
                # Perform FAISS search for RAG coordinates
                if self.index is not None and database_csv_path:
                    print(f"üîç Searching FAISS index for RAG coordinates...")
                    D, I, D_reverse, I_reverse = self.search_index(rgb_img, top_k)
                    candidates_gps, reverse_gps = self._get_gps_coordinates(I, I_reverse, database_csv_path)
                    all_candidates_gps.extend(candidates_gps)
                    all_reverse_gps.extend(reverse_gps)
                    print(f"ÔøΩ Found {len(candidates_gps)} candidate GPS and {len(reverse_gps)} reverse GPS coordinates")
                
                # Perform image search for web candidates
                print(f"ÔøΩüîç Searching web for image: {image_file.name}")
                search_results = search_with_image_and_text(
                    image_path=str(image_file),
                    search_text="",  # Image-only search
                    serpapi_key=serpapi_key,
                    imgbb_key=imgbb_key
                )
                
                # Extract candidate links
                search_candidates = extract_image_search_candidates(
                    search_results, 
                    no_results=links_per_image
                )
                
                # Store candidates in dictionary: {link: full_candidate_string}
                for candidate in search_candidates:
                    print(f"üîó Found candidate: {candidate}")
                    if candidate.startswith("Link: ") and " | Title: " in candidate:
                        link = candidate.split(" | Title: ")[0].replace("Link: ", "")
                        if link != "No link" and link.startswith("http"):
                            self.candidates_dict[link] = candidate
                
                print(f"‚úÖ Found {len(search_candidates)} web candidates for {image_file.name}")
                
            except Exception as e:
                print(f"‚ùå Error processing {image_file.name}: {e}")
                continue
        
        # Remove duplicates from RAG coordinates
        all_candidates_gps = list(set(all_candidates_gps))
        all_reverse_gps = list(set(all_reverse_gps))
        
        print(f"üîó Total unique web links collected: {len(self.candidates_dict)}")
        print(f"üìç Total unique RAG candidates GPS: {len(all_candidates_gps)}")
        print(f"üìç Total unique RAG reverse GPS: {len(all_reverse_gps)}")
        
        if not self.candidates_dict and not all_candidates_gps:
            raise ValueError("No valid search result links or RAG coordinates found from any images")
        
        # Step 3: Process links in batches and run LLM predictions with comprehensive context
        links_list = list(self.candidates_dict.keys())
        all_predictions = {}  # {(lat, lon): prediction_dict}
        
        # If we have web links, process them in batches
        if links_list:
            for batch_start in range(0, len(links_list), batch_size):
                batch_end = min(batch_start + batch_size, len(links_list))
                batch_links = links_list[batch_start:batch_end]
                
                print(f"üî• Processing batch {batch_start//batch_size + 1}: {len(batch_links)} links")
                
                # Try to get prediction for this batch
                max_retries = 3
                for retry in range(max_retries):
                    try:
                        prediction = self.get_llm_prediction_with_batch(
                            base64_images=base64_images,
                            batch_links=batch_links,
                            candidates_gps=all_candidates_gps,
                            reverse_gps=all_reverse_gps,
                            image_paths=image_paths,
                            transcript_file_path=transcript_file_path,
                            metadata_file_path=metadata_file_path,
                            n_coords=15,
                            model_name=model_name
                        )
                        print(prediction)
                        
                        if self.is_valid_enhanced_gps_dict(prediction):
                            coords = (prediction["latitude"], prediction["longitude"])
                            all_predictions[coords] = prediction
                            print(f"‚úÖ Batch prediction successful: {coords}")
                            break
                        else:
                            print(f"‚ö†Ô∏è Invalid prediction format in batch, retrying... ({retry+1}/{max_retries})")
                            
                    except Exception as e:
                        print(f"‚ùå Batch prediction failed ({retry+1}/{max_retries}): {e}")
                        if retry == max_retries - 1:
                            print(f"‚ö†Ô∏è Skipping batch after {max_retries} failures")
        
        # If we only have RAG coordinates (no web links), make prediction with RAG only
        elif all_candidates_gps:
            print("üî• No web links found, using RAG coordinates only...")
            max_retries = 3
            for retry in range(max_retries):
                try:
                    prediction = self.get_llm_prediction_with_batch(
                        base64_images=base64_images,
                        batch_links=[],  # No web links
                        candidates_gps=all_candidates_gps,
                        reverse_gps=all_reverse_gps,
                        image_paths=image_paths,
                        transcript_file_path=transcript_file_path,
                        metadata_file_path=metadata_file_path,
                        n_coords=15,
                        model_name=model_name
                    )
                    
                    if self.is_valid_enhanced_gps_dict(prediction):
                        coords = (prediction["latitude"], prediction["longitude"])
                        all_predictions[coords] = prediction
                        print(f"‚úÖ RAG-only prediction successful: {coords}")
                        break
                    else:
                        print(f"‚ö†Ô∏è Invalid RAG prediction format, retrying... ({retry+1}/{max_retries})")
                        
                except Exception as e:
                    print(f"‚ùå RAG prediction failed ({retry+1}/{max_retries}): {e}")
                    if retry == max_retries - 1:
                        print(f"‚ö†Ô∏è RAG prediction failed after {max_retries} attempts")
        
        if not all_predictions:
            raise ValueError("No valid predictions obtained from any batch")
        
        # Step 4: Calculate similarity scores and select best prediction
        predicted_coords = list(all_predictions.keys())
        print(f"üéØ Calculating similarity scores for {len(predicted_coords)} predictions...")
        
        avg_similarities = self.calculate_similarity_scores(
            rgb_images=rgb_images,
            predicted_coords=predicted_coords
        )
        
        # Find best prediction
        best_idx = np.argmax(avg_similarities)
        best_coords = predicted_coords[best_idx]
        best_prediction = all_predictions[best_coords]
        
        # Add metadata to result
        result = best_prediction.copy()
        result["metadata"] = {
            "num_images_processed": len(image_files),
            "num_unique_links": len(self.candidates_dict),
            "num_rag_candidates": len(all_candidates_gps),
            "num_rag_reverse": len(all_reverse_gps),
            "num_predictions": len(all_predictions),
            "similarity_scores": avg_similarities.tolist(),
            "all_predictions": {str(coords): pred for coords, pred in all_predictions.items()},
            "best_similarity_score": float(avg_similarities[best_idx]),
            "batch_size": batch_size,
            "links_per_image": links_per_image,
            "top_k_faiss": top_k,
            "database_csv_path": database_csv_path,
            "transcript_file_path": transcript_file_path,
            "metadata_file_path": metadata_file_path,
            "candidates_dict": self.candidates_dict,
            "rag_coordinates": {
                "candidates_gps": all_candidates_gps,
                "reverse_gps": all_reverse_gps
            },
            "has_faiss_index": self.index is not None,
            "processing_mode": "comprehensive" if self.candidates_dict and all_candidates_gps else 
                             "web_only" if self.candidates_dict else "rag_only"
        }
        
        print(f"üèÜ Best prediction selected: {best_coords}")
        print(f"   Best similarity score: {avg_similarities[best_idx]:.4f}")
        print(f"   All similarity scores: {avg_similarities}")
        
        return result


# Convenience function for backward compatibility
def batch_predict_keyframes(
    keyframes_dir: str,
    checkpoint_path: str,
    device: str = "cuda",
    index_path: Optional[str] = None,
    database_csv_path: str = "",
    serpapi_key: str = "",
    imgbb_key: str = "",
    transcript_file_path: str = "",
    metadata_file_path: str = "",
    batch_size: int = 10,
    links_per_image: int = 3,
    top_k: int = 20,
    model_name: str = "gemini-2.5-flash"
) -> dict:
    """
    One-call wrapper: build a BatchKeyframePredictor and run ``predict`` on a directory.

    Args:
        keyframes_dir (str): Directory containing keyframe/image files
        checkpoint_path (str): Path to G3 model checkpoint
        device (str): Device to run model on ("cuda" or "cpu")
        index_path (str): Path to FAISS index for RAG. The default is None,
            but BatchKeyframePredictor raises ValueError when it is None,
            so a path must be supplied in practice.
        database_csv_path (str): Path to GPS coordinates database CSV for RAG
        serpapi_key (str): SerpAPI key for image search
        imgbb_key (str): imgbb API key for image upload
        transcript_file_path (str): Path to transcript file
        metadata_file_path (str): Path to metadata JSON file
        batch_size (int): Number of links to process in each batch (default: 10)
        links_per_image (int): Number of search result links to extract per image (default: 3)
        top_k (int): Number of top FAISS results for RAG (default: 20)
        model_name (str): LLM model name to use

    Returns:
        dict: Best prediction with latitude, longitude, location, reason, and metadata
    """
    # Constructor arguments first: model, device and index.
    keyframe_predictor = BatchKeyframePredictor(
        checkpoint_path=checkpoint_path,
        device=device,
        index_path=index_path,
    )

    # Everything else is forwarded unchanged to predict().
    prediction_options = {
        "keyframes_dir": keyframes_dir,
        "database_csv_path": database_csv_path,
        "serpapi_key": serpapi_key,
        "imgbb_key": imgbb_key,
        "transcript_file_path": transcript_file_path,
        "metadata_file_path": metadata_file_path,
        "batch_size": batch_size,
        "links_per_image": links_per_image,
        "top_k": top_k,
        "model_name": model_name,
    }
    return keyframe_predictor.predict(**prediction_options)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# Driver cell: run the batch predictor on one folder of keyframes, print a
# summary of the result and save the full result as JSON.
# ---------------------------------------------------------------------------

# Input locations
keyframes_dir = r"C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\G3-Original\g3\data\batch_processing\katna_keyframes\ID115"
checkpoint_path = "checkpoints/mercator_finetune_weight.pth"
index_path = "index/G3.index"  # FAISS index for RAG
database_csv_path = "data/mp16/MP16_Pro_filtered.csv"  # GPS database for RAG

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Optional context files
transcript_file_path = r"C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\G3-Original\g3\data\batch_processing\combined_transcript.txt"  # Audio transcript (optional)
metadata_file_path = r"C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\dataset\validation\ID115\ID115\input\ID115.json"    # Video metadata (optional)

# API keys, read from the environment (empty string when unset)
serpapi_key = os.getenv("SERPAPI_KEY", "")
imgbb_key = os.getenv("IMGBB_KEY", "")

# Processing parameters
batch_size = 10  # links sent to the LLM per request
links_per_image = 3  # search result links kept per image
top_k = 20  # FAISS neighbours retrieved per image

try:
    print("üöÄ Starting comprehensive batch prediction for keyframes...")

    # Build the predictor (loads the model and the FAISS index) ...
    predictor = BatchKeyframePredictor(
        checkpoint_path=checkpoint_path,
        device=device,
        index_path=index_path,
    )

    # ... and run the full pipeline on the keyframe directory.
    result = predictor.predict(
        keyframes_dir=keyframes_dir,
        database_csv_path=database_csv_path,
        serpapi_key=serpapi_key,
        imgbb_key=imgbb_key,
        transcript_file_path=transcript_file_path,
        metadata_file_path=metadata_file_path,
        batch_size=batch_size,
        links_per_image=links_per_image,
        top_k=top_k,
    )

    # Headline result
    print("\nüéâ Comprehensive batch prediction completed successfully!")
    print(f"üìç Predicted Location: {result['latitude']}, {result['longitude']}")
    print(f"üìã Place: {result['location']}")
    print(f"üí≠ Reasoning: {result['reason']}")

    # Run statistics recorded by predict()
    metadata = result['metadata']
    print(f"\nüìä Processing Statistics:")
    print(f"   ‚Ä¢ Images processed: {metadata['num_images_processed']}")
    print(f"   ‚Ä¢ Unique web links: {metadata['num_unique_links']}")
    print(f"   ‚Ä¢ RAG candidate coordinates: {metadata['num_rag_candidates']}")
    print(f"   ‚Ä¢ RAG reverse coordinates: {metadata['num_rag_reverse']}")
    print(f"   ‚Ä¢ Valid predictions: {metadata['num_predictions']}")
    print(f"   ‚Ä¢ Best similarity score: {metadata['best_similarity_score']:.4f}")
    print(f"   ‚Ä¢ Processing mode: {metadata['processing_mode']}")
    print(f"   ‚Ä¢ Has FAISS index: {metadata['has_faiss_index']}")

    # A few sample RAG coordinates, when any were retrieved
    if metadata['num_rag_candidates'] > 0:
        rag_coords = metadata['rag_coordinates']
        print(f"\nüìç RAG Coordinates Summary:")
        print(f"   ‚Ä¢ Sample candidate GPS: {rag_coords['candidates_gps'][:3]}")
        print(f"   ‚Ä¢ Sample reverse GPS: {rag_coords['reverse_gps'][:3]}")

    # The first three web candidates, each cut to 100 characters
    if metadata['num_unique_links'] > 0:
        candidates_dict = metadata['candidates_dict']
        print(f"\nüîó Web Candidates Summary:")
        first_candidates = list(candidates_dict.values())[:3]
        for position, preview in enumerate(first_candidates, start=1):
            print(f"   {position}. {preview[:100]}...")

    # Context files that were passed in
    print(f"\nüìÇ Context Files Used:")
    print(f"   ‚Ä¢ Database CSV: {metadata['database_csv_path'] or 'None'}")
    print(f"   ‚Ä¢ Transcript: {metadata['transcript_file_path'] or 'None'}")
    print(f"   ‚Ä¢ Metadata: {metadata['metadata_file_path'] or 'None'}")

    # Persist the complete result, metadata included
    import json
    output_file = "comprehensive_batch_prediction_result.json"
    with open(output_file, 'w') as out_handle:
        json.dump(result, out_handle, indent=2)
    print(f"üíæ Detailed result saved to: {output_file}")

except Exception as e:
    # Any failure in the pipeline is reported here instead of propagating.
    print(f"‚ùå Error during comprehensive batch prediction: {e}")

üöÄ Starting comprehensive batch prediction for keyframes...


  self.model.load_state_dict(torch.load(checkpoint_path, map_location=device))


‚úÖ Successfully loaded FAISS index from: index/G3.index
üìÅ Found 26 image files in C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\G3-Original\g3\data\batch_processing\katna_keyframes\ID115
üñºÔ∏è Processing image 1/26: 3660eb4aed1548c7bb8a8dc305d9a4c3_kf_0000.jpeg
üîç Searching FAISS index for RAG coordinates...
Searching FAISS index...
ÔøΩ Found 20 candidate GPS and 20 reverse GPS coordinates
ÔøΩüîç Searching web for image: 3660eb4aed1548c7bb8a8dc305d9a4c3_kf_0000.jpeg
üîç Search mode: Image + Text
üöÄ Using engine: google_lens
üì§ Uploading image to imgbb: C:\Users\tungd\OneDrive - MSFT\Second Year\ML\ACMMM25 - Grand Challenge on Multimedia Verification\G3-Original\g3\data\batch_processing\katna_keyframes\ID115\3660eb4aed1548c7bb8a8dc305d9a4c3_kf_0000.jpeg
‚úÖ Image uploaded successfully: https://i.ibb.co/ns7kNW1Y/bac1da703780.jpg
üîç Searching with SerpAPI...
   Engine: google_lens
   Text query: 
   Image URL: https://i.

KeyboardInterrupt: 