In [None]:
import faiss
import torch
import json
import os  
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2 
import insightface
from insightface.app import FaceAnalysis
from datasets import Dataset, Image as HFImage
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    DonutProcessor, 
    VisionEncoderDecoderModel, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
)

# --- Setup Device ---
# `device` is "cuda" when torch can see a GPU and "cpu" otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f" Using device: {device}")
print(f"Faiss version: {faiss.__version__}")
print(f"Torch version: {torch.__version__}")

# --- Initialize SOTA Face Model (ArcFace) ---
print(" Initializing SOTA Face Model (ArcFace)...")
# CUDAExecutionProvider requires the GPU build of ONNX Runtime.
# CPUExecutionProvider is listed after it so the CPU fallback described here
# is requested explicitly instead of being left to the runtime's defaults.
face_app = FaceAnalysis(
    name='buffalo_l',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)
# det_size is the detector input resolution used for every later
# face_app.get(...) call in this notebook.
face_app.prepare(ctx_id=0, det_size=(640, 640))

In [None]:
def select_folder(title="Select Folder"):
    """Open a directory-picker dialog and return the chosen path.

    A hidden, always-on-top Tk root is created for the dialog and destroyed
    afterwards, so repeated calls do not accumulate Tk instances.

    Args:
        title: Caption passed to the dialog.

    Returns:
        Whatever ``filedialog.askdirectory`` returns; callers in this notebook
        treat a falsy value as "selection cancelled".
    """
    root = tk.Tk()
    try:
        root.withdraw()                    # hide the empty root window
        root.attributes('-topmost', True)  # keep the dialog above other windows
        return filedialog.askdirectory(title=title)
    finally:
        # Release the Tk interpreter even if the dialog raises.
        root.destroy()

def select_file(title="Select File", filetypes=None):
    """Open a file-picker dialog and return the chosen path.

    A hidden, always-on-top Tk root is created for the dialog and destroyed
    afterwards, so repeated calls do not accumulate Tk instances.

    Args:
        title: Caption passed to the dialog.
        filetypes: Sequence of ``(label, pattern)`` pairs offered as filters.
            ``None`` (the default) means ``[("All Files", "*.*")]``.

    Returns:
        Whatever ``filedialog.askopenfilename`` returns; callers in this
        notebook treat a falsy value as "selection cancelled".
    """
    if filetypes is None:
        # Built per call so no mutable default is shared between calls.
        filetypes = [("All Files", "*.*")]

    root = tk.Tk()
    try:
        root.withdraw()                    # hide the empty root window
        root.attributes('-topmost', True)  # keep the dialog above other windows
        return filedialog.askopenfilename(title=title, filetypes=filetypes)
    finally:
        # Release the Tk interpreter even if the dialog raises.
        root.destroy()


In [None]:
# ==========================================
# DATA LOADING CONFIGURATION
# ==========================================
# Ask for the dataset root once; every later cell builds its paths from
# `root_dir`.
print(" Please select the ROOT folder (containing 'positive', 'meta', 'fraud1...', etc.)")
root_dir = select_folder("Select IDNet Root Folder")

# NOTE(review): a cancelled dialog only prints a message; `root_dir` stays
# falsy and the following cells still run with it.
if root_dir:
    print(f"\n Root: {root_dir}")
else:
    print(" Selection cancelled.")

# Image sub-folders scanned below the root, one per use case. Folders whose
# name does not contain "positive" are labelled as fraud by the scan cell.
target_folders = [
    "positive",
    "fraud1_copy_and_move", "fraud2_face_morphing", "fraud3_face_replacement",
    "fraud4_combined", "fraud5_inpaint_and_rewrite", "fraud6_crop_and_replace",
]


In [None]:
# --- Load Metadata ---
# Builds `metadata_map` (lookup key -> metadata record) from every .json /
# .jsonl file found in <root_dir>/meta/<subfolder>.
meta_base_dir = os.path.join(root_dir, "meta")
meta_subfolders = ["positive_meta", "detailed_with_fraud_info"]
metadata_map = {}


def _read_metadata_file(path):
    """Parse one metadata file and return the decoded JSON content.

    A path ending in ``.jsonl`` (any letter case) is read as one JSON document
    per non-blank line and returned as a list; any other path is parsed as a
    single JSON document.
    """
    with open(path, 'r', encoding='utf-8') as f:
        if path.lower().endswith('.jsonl'):
            # Skip blank lines (e.g. a trailing newline) instead of failing the file.
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)


def _merge_metadata(raw_data, target):
    """Merge decoded metadata into ``target`` in place.

    A list contributes each dict item under its first truthy value among
    'file_name', 'image_id' and 'id'; non-dict items and items without such a
    key are skipped. A dict is merged key-by-key. Other types are ignored.
    """
    if isinstance(raw_data, list):
        for item in raw_data:
            if not isinstance(item, dict):
                continue
            key = item.get('file_name') or item.get('image_id') or item.get('id')
            if key:
                target[key] = item
    elif isinstance(raw_data, dict):
        target.update(raw_data)


if os.path.isdir(meta_base_dir):
    print(" Loading metadata files...")
    for subfolder in meta_subfolders:
        full_sub_path = os.path.join(meta_base_dir, subfolder)
        if not os.path.isdir(full_sub_path):
            continue
        print(f"   Processing: {subfolder}...")
        # Sorted so that, when two files define the same key, the winner does
        # not depend on the directory listing order.
        for filename in sorted(os.listdir(full_sub_path)):
            if not filename.lower().endswith((".json", ".jsonl")):
                continue
            full_path = os.path.join(full_sub_path, filename)
            try:
                _merge_metadata(_read_metadata_file(full_path), metadata_map)
            except Exception as e:
                # Best effort: report the broken file and keep loading the rest.
                print(f" Error reading {filename}: {e}")
    print(f" Loaded combined metadata for {len(metadata_map)} items.")
else:
    print(" 'meta' folder not found!")


In [None]:
# --- Load Images & Match Metadata ---
# Builds `dataset_samples`: one {"image": path, "text": json-string} entry per
# image file that has a metadata record in `metadata_map`.
dataset_samples = []
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff"}

print("  Scanning image folders...")

for folder_name in target_folders:
    folder_path = os.path.join(root_dir, folder_name)
    if not os.path.exists(folder_path):
        print(f" Warning: Folder not found: {folder_name} (Skipping)")
        continue

    print(f"   Scanning: {folder_name}...")
    # Every folder except the "positive" one is treated as a fraud use case,
    # and its name becomes the fraud type.
    is_fraud = "positive" not in folder_name
    count = 0
    for current_dir, subdirs, files in os.walk(folder_path):
        # Sorting directories (in place) and files fixes the traversal order,
        # so sample positions - later used as ids by the search indexes - are
        # the same on every run.
        subdirs.sort()
        for file in sorted(files):
            stem, ext = os.path.splitext(file)
            if ext.lower() not in valid_extensions:
                continue
            file_path = os.path.join(current_dir, file)

            # Match metadata by full file name first, then by name without extension.
            meta = metadata_map.get(file) or metadata_map.get(stem)
            # Dict-form metadata files are merged as-is, so a value is not
            # guaranteed to be a record; skip anything that cannot be tagged.
            if not meta or not isinstance(meta, dict):
                continue

            # Tag the data with the specific Use Case/Fraud Type
            meta_copy = meta.copy()
            meta_copy["is_fraud"] = is_fraud
            meta_copy["fraud_type"] = folder_name if is_fraud else "none"

            dataset_samples.append({
                "image": file_path,
                "text": json.dumps(meta_copy)
            })
            count += 1
    print(f"      -> Found {count} matched images.")

if len(dataset_samples) == 0:
    print(" No matching images found.")


In [None]:
# Convert to HuggingFace Dataset
# The "image" column holds file paths; it is cast to the HF Image feature.
# `hf_dataset` is only defined when at least one sample was matched.
if not dataset_samples:
    print(" Dataset creation failed (0 samples).")
else:
    hf_dataset = Dataset.from_list(dataset_samples).cast_column("image", HFImage())
    print(" Dataset object created successfully.")


In [None]:
# ==========================================
# MODEL TRAINING SETUP
# ==========================================
model_save_path = "custom_trained_donut_model"  # where the fine-tuned model and processor are saved
max_length = 512                                # token length of every target sequence
image_size = [1280, 960]                        # stored reversed on the feature extractor and encoder config

# Initialize Processor
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
processor.tokenizer.chat_template = None
processor.feature_extractor.do_align_long_axis = False
processor.feature_extractor.size = list(reversed(image_size))

# Opening/closing tag pair for each field, in the order name, id, address.
new_special_tokens = [
    tag
    for field in ("name", "id", "address")
    for tag in (f"<s_{field}>", f"</s_{field}>")
]
# The sequence delimiters "<s>" / "</s>" are registered together with the field tags.
processor.tokenizer.add_special_tokens(
    {"additional_special_tokens": [*new_special_tokens, "<s>", "</s>"]}
)

# Generator Function
def transform_generator(sample):
    """Convert one dataset row into the tensors used for training.

    Args:
        sample: Row with an "image" entry (an object offering ``.convert``)
            and a "text" entry (the target string).

    Returns:
        Dict with "pixel_values" (the processor output, squeezed) and
        "labels" (token ids of ``"<s>" + text + "</s>"``, padded/truncated to
        ``max_length``, with padding positions replaced by -100).
    """
    rgb_image = sample["image"].convert("RGB")
    encoded_image = processor(rgb_image, random_padding=True, return_tensors="pt")
    pixel_values = encoded_image.pixel_values.squeeze()

    # The text is already a string, so it is wrapped as-is (no json.dumps here).
    target_sequence = "".join(("<s>", sample["text"], "</s>"))

    tokenized = processor.tokenizer(
        target_sequence,
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized["input_ids"].squeeze(0)

    # Work on a copy so the tokenizer output itself is left untouched.
    labels = token_ids.clone()
    padding_mask = labels == processor.tokenizer.pad_token_id
    labels[padding_mask] = -100

    return {"pixel_values": pixel_values, "labels": labels}

# Prepare Iterable Dataset
# Three stages, each kept under its own name: the iterable view of the
# dataset, the per-row transform, and a seeded shuffle with a 100-row buffer.
iterable_dataset = hf_dataset.to_iterable_dataset()
processed_iterable_dataset = iterable_dataset.map(function=transform_generator)
shuffled_dataset = processed_iterable_dataset.shuffle(buffer_size=100, seed=42)


In [None]:
import traceback  # stdlib; used below to show the full stack when training fails

# Short training run (100 steps) that fine-tunes donut-base on
# `shuffled_dataset` and saves the model and processor to `model_save_path`.
try:
    # Initialize Model
    model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # The tokenizer gained special tokens, so the decoder embeddings must match its size.
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.encoder.image_size = image_size[::-1]
    model.config.decoder.max_length = max_length

    # Training Arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="training_output",
        max_steps=100, # Set low for testing, increase for real training
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_steps=50,
        remove_unused_columns=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=shuffled_dataset,
    )

    print(" Starting Training...")
    trainer.train()
    trainer.save_model(model_save_path)
    processor.save_pretrained(model_save_path)
    print(f" Model saved to {model_save_path}")

except Exception as e:
    # Keep the notebook running, but show where the failure happened: the
    # message alone is rarely enough to diagnose a training error.
    print(f" Training failed: {e}")
    traceback.print_exc()
    if "out of memory" in str(e).lower():
        print(" TIP: Reduce 'per_device_train_batch_size' to 1 and restart.")


#### Load Saved Custom Trained Donut Model

In [None]:
import os
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

def load_custom_donut(model_path="custom_trained_donut_model"):
    """Load a saved Donut model and its processor from ``model_path``.

    The model is moved to "cuda" when torch reports a GPU (otherwise "cpu")
    and switched to eval mode before being returned.

    Args:
        model_path: Folder holding the saved model and processor.

    Returns:
        ``(model, processor)`` on success; ``(None, None)`` when the folder is
        missing or loading raises.
    """
    print(f" Loading Donut model from folder: '{model_path}'...")

    if os.path.exists(model_path):
        try:
            target_device = "cuda" if torch.cuda.is_available() else "cpu"

            # Model weights: moved to the target device, dropout etc. disabled.
            donut_model = VisionEncoderDecoderModel.from_pretrained(model_path)
            donut_model.to(target_device)
            donut_model.eval()

            # Processor (tokenizer + image settings) saved next to the model.
            donut_processor = DonutProcessor.from_pretrained(model_path)

            print(f" Custom Donut Model loaded successfully on {target_device}!")
            return donut_model, donut_processor
        except Exception as load_error:
            print(f" Critical Error loading model: {load_error}")
            return None, None

    # Nothing was saved yet: explain how to create the folder.
    print(f" Error: The folder '{model_path}' does not exist.")
    print("    You must run the 'Training' block at least once to create it.")
    return None, None

# --- HOW TO USE IT ---
# Loading the saved model replaces a fresh training run.
loaded_model, loaded_processor = load_custom_donut()

# On success, rebind the notebook-wide `model` / `processor` so the cells
# below use the loaded versions; on failure both names are left untouched.
if loaded_model:
    model, processor = loaded_model, loaded_processor


### Full-Scale Training on the Entire Dataset (5,000–20,000 Images)

In [None]:
# The long training run is opt-in, so "Run All" does not start a multi-hour job.
RUN_FULL_TRAINING = False


def run_full_training(real_steps=5000):
    """Fine-tune donut-base with the long-run configuration and save it.

    Reads the notebook globals ``processor``, ``image_size``, ``max_length``,
    ``shuffled_dataset`` and ``model_save_path``.

    Args:
        real_steps: Value passed to ``max_steps``. Each step consumes
            ``per_device_train_batch_size * gradient_accumulation_steps``
            (2 * 4 = 8) samples, so 10,000 images are about 1,250 steps per
            epoch and 5,000 steps are roughly 4 epochs.
            NOTE(review): this assumes a single device - confirm for multi-GPU.

    Returns:
        The trained model, or ``None`` when training raised an exception.
    """
    try:
        # Initialize Model
        model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
        model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]
        model.config.pad_token_id = processor.tokenizer.pad_token_id
        model.decoder.resize_token_embeddings(len(processor.tokenizer))
        model.config.encoder.image_size = image_size[::-1]
        model.config.decoder.max_length = max_length

        # --- REAL TRAINING CONFIGURATION ---
        training_args = Seq2SeqTrainingArguments(
            output_dir="training_output",

            # 1. STEPS
            max_steps=real_steps,

            # 2. LEARNING RATE
            learning_rate=1e-5,             # lower than the quick run's 2e-5
            warmup_steps=200,               # ramp the learning rate up at the start

            # 3. BATCH SIZE (Crucial for GPU Memory)
            # If you get "CUDA Out of Memory", change this to 1
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size of 8 (2 * 4)

            # 4. HARDWARE ACCELERATION
            fp16=torch.cuda.is_available(), # mixed precision only when a GPU is present

            # 5. LOGGING & SAVING
            logging_steps=100,              # log progress every 100 steps
            save_steps=1000,                # checkpoint every 1000 steps
            save_total_limit=2,             # keep only the last 2 checkpoints
            remove_unused_columns=True,
        )

        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=shuffled_dataset,
        )

        print(f" Starting REAL Training for {real_steps} steps...")
        print("This may take several hours. Monitor the loss (it should go down).")

        trainer.train()

        # Save final model
        trainer.save_model(model_save_path)
        processor.save_pretrained(model_save_path)
        print(f" Training Complete! Model saved to {model_save_path}")
        return model

    except Exception as e:
        print(f" Training failed: {e}")
        # Troubleshooting tip
        if "out of memory" in str(e).lower():
            print(" TIP: Reduce 'per_device_train_batch_size' to 1 and restart.")
        return None


if RUN_FULL_TRAINING:
    full_model = run_full_training()
    # Only replace the notebook-wide model when the run actually succeeded.
    if full_model is not None:
        model = full_model

### Build the Search Indexes (Face Embeddings + TF-IDF Text)

In [None]:
# ==========================================
# INDEXING (Uses all Folders)
# ==========================================
print("  Building Indexes from loaded dataset...")

d = 512  # face-embedding dimensionality used for the placeholder vector and the FAISS index

face_embeddings = []
text_documents = []
metadata_store = []

# Row position `idx` is the shared id: row idx of the FAISS index, row idx of
# the TF-IDF matrix and metadata_store[idx] all describe hf_dataset[idx].
for idx, item in enumerate(tqdm(hf_dataset)):
    image = item['image'].convert("RGB")
    text_data = item['text'] # Already a string

    # 1. Face Embedding (the face model is fed a BGR array)
    open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(open_cv_image)

    if len(faces) > 0:
        emb = faces[0].normed_embedding.astype('float32')
    else:
        # Placeholder keeps index rows aligned with metadata_store.
        # NOTE(review): an all-zero vector is still searchable and can be
        # returned as a face match; consider tracking "no face" rows separately.
        emb = np.zeros(d, dtype='float32')

    face_embeddings.append(emb)
    text_documents.append(text_data)

    metadata_store.append({
        "id": idx,
        "data": text_data
    })

# FAISS Setup
# reshape pins the (n, d) layout that FAISS expects, including for n == 0.
face_embeddings = np.asarray(face_embeddings, dtype='float32').reshape(-1, d)
# A CUDA-enabled torch does not imply a GPU build of FAISS (e.g. faiss-cpu),
# so the GPU classes are probed before they are used.
use_faiss_gpu = (
    torch.cuda.is_available()
    and hasattr(faiss, "StandardGpuResources")
    and faiss.get_num_gpus() > 0
)
if use_faiss_gpu:
    res = faiss.StandardGpuResources()
    face_index = faiss.GpuIndexFlatL2(res, d)
else:
    face_index = faiss.IndexFlatL2(d)
face_index.add(face_embeddings)

# TF-IDF Setup
tfidf_vectorizer = TfidfVectorizer()
text_matrix = tfidf_vectorizer.fit_transform(text_documents)

print(" Search Indexes Ready!")


## Save and Load Index

In [None]:
import pickle

# Persist everything the search pipeline needs to three files in the current
# working directory: metadata_store.pkl, face_index.bin and tfidf_data.pkl.

# 1. Save Metadata Store
with open("metadata_store.pkl", "wb") as f:
    pickle.dump(metadata_store, f)

# 2. Save FAISS Index
# Note: GPU indexes must be moved to CPU before saving.
# The decision is based on the index object itself rather than on
# torch.cuda.is_available(): the index may live on the CPU even when a GPU is
# present (e.g. a CPU-only FAISS build).
gpu_index_type = getattr(faiss, "GpuIndex", None)
if gpu_index_type is not None and isinstance(face_index, gpu_index_type):
    cpu_index = faiss.index_gpu_to_cpu(face_index)
else:
    cpu_index = face_index
faiss.write_index(cpu_index, "face_index.bin")

# 3. Save TF-IDF Vectorizer and Matrix
with open("tfidf_data.pkl", "wb") as f:
    pickle.dump((tfidf_vectorizer, text_matrix), f)

print(" All indexes and metadata saved to disk!")

In [None]:
import pickle

# Restore the search state written by the save cell. All three files are
# required, so all three are checked before anything is loaded.
required_files = ("metadata_store.pkl", "face_index.bin", "tfidf_data.pkl")
missing_files = [name for name in required_files if not os.path.exists(name)]

if not missing_files:
    print("⏳ Loading indexes from disk...")

    # SECURITY NOTE: pickle.load can execute arbitrary code from the file.
    # Only load .pkl files that this notebook wrote itself.

    # 1. Load Metadata
    with open("metadata_store.pkl", "rb") as f:
        metadata_store = pickle.load(f)

    # 2. Load FAISS Index
    face_index = faiss.read_index("face_index.bin")
    # Optional: Move back to GPU for speed. A CUDA-enabled torch does not
    # imply a GPU build of FAISS, so the GPU API is probed first.
    if (
        torch.cuda.is_available()
        and hasattr(faiss, "StandardGpuResources")
        and faiss.get_num_gpus() > 0
    ):
        res = faiss.StandardGpuResources()
        face_index = faiss.index_cpu_to_gpu(res, 0, face_index)

    # 3. Load TF-IDF
    with open("tfidf_data.pkl", "rb") as f:
        tfidf_vectorizer, text_matrix = pickle.load(f)

    print(f" Loaded {len(metadata_store)} items from disk. Ready to search!")
else:
    print(" No saved indexes found. Please run the 'Indexing' block first.")
    print(f" Missing: {', '.join(missing_files)}")


## Search Pipeline (Text Extraction and Hybrid Retrieval)

In [None]:
def extract_text_donut(image, max_length=512):
    """Generate the Donut output text for one image.

    Uses the notebook-wide ``model`` and ``processor`` instead of reloading
    them on every call.

    Args:
        image: Image handed to the processor. Objects offering ``.convert``
            (e.g. PIL images) are converted to RGB first, as in training.
        max_length: Maximum length passed to ``model.generate``.

    Returns:
        The decoded sequence with the eos, pad and "<s>" tokens removed, or
        an empty string when ``model`` / ``processor`` are not available.
    """
    donut_model = globals().get("model")
    donut_processor = globals().get("processor")
    if donut_model is None or donut_processor is None:
        # There is no alternative extractor: the text part is simply skipped.
        print("⚠️ Model not loaded! Skipping text extraction (returning empty text).")
        return ""

    tokenizer = donut_processor.tokenizer
    device = donut_model.device  # cpu or cuda, taken from the model itself

    # A model that was just trained in this kernel may still be in training
    # mode; generation must run with dropout disabled.
    donut_model.eval()

    # 1. Prepare Image
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    pixel_values = donut_processor(image, return_tensors="pt").pixel_values.to(device)

    # 2. Prepare Prompt (Start Token)
    task_prompt = "<s>"
    decoder_input_ids = tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt"
    ).input_ids.to(device)

    # 3. Generate Output
    with torch.no_grad():
        outputs = donut_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=max_length,
            # Explicit ids so generation stops at "</s>" and padding is
            # well-defined regardless of the model's generation defaults.
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )

    # 4. Decode to String
    sequence = donut_processor.batch_decode(outputs.sequences)[0]

    # 5. Clean up: drop eos / pad, then the start token. Tokens that are not
    # defined on the tokenizer (None) are skipped.
    for special_token in (tokenizer.eos_token, tokenizer.pad_token, "<s>"):
        if special_token:
            sequence = sequence.replace(special_token, "")

    return sequence

def search_pipeline(query_image_path, k=5, alpha=0.5):
    """Hybrid search using face embeddings and extracted text.

    Reads the notebook globals ``metadata_store``, ``face_app``,
    ``face_index``, ``tfidf_vectorizer`` and ``text_matrix``.

    Args:
        query_image_path: Path of the query image.
        k: Number of results to return.
        alpha: Weight of the face score; ``1 - alpha`` weights the text score.
            Face search is skipped when ``alpha <= 0`` and text extraction is
            skipped when ``alpha >= 1``.

    Returns:
        ``(results, query_img)``: ``results`` is a list of dicts with "id"
        (int position in ``metadata_store``), "score" (float) and "metadata",
        ordered by descending score; ``query_img`` is the RGB PIL image.
    """
    query_img = Image.open(query_image_path).convert("RGB")
    # Build the BGR array from the decoded PIL image, exactly as the indexing
    # cell does, instead of re-reading the file with cv2.imread (which returns
    # None for paths it cannot read and may decode the file differently).
    img_cv2 = cv2.cvtColor(np.array(query_img), cv2.COLOR_RGB2BGR)

    n_items = len(metadata_store)

    # Face Scores
    face_scores = np.zeros(n_items)
    if alpha > 0 and n_items > 0:
        faces = face_app.get(img_cv2)
        if len(faces) > 0:
            q_emb = faces[0].normed_embedding.reshape(1, -1).astype('float32')
            # Search wider than k, but never ask for more neighbours than exist.
            n_candidates = max(1, min(k * 10, n_items))
            D, I = face_index.search(q_emb, n_candidates)
            for rank, idx in enumerate(I[0]):
                # -1 marks "no neighbour"; out-of-range ids would mean the
                # index and metadata_store are out of sync.
                if idx < 0 or idx >= n_items:
                    continue
                # Map distance to (0, 1]: distance 0 gives score 1.
                face_scores[idx] = 1 / (1 + D[0][rank])

    # Text Scores
    text_scores = np.zeros(n_items)
    if alpha < 1:
        extracted_text = extract_text_donut(query_img)
        q_vec = tfidf_vectorizer.transform([extracted_text])
        text_scores = cosine_similarity(q_vec, text_matrix).flatten()

    # Combine Scores
    final_scores = (alpha * face_scores) + ((1 - alpha) * text_scores)
    top_indices = np.argsort(final_scores)[::-1][:k]

    # Plain Python numbers keep the results usable as list indices and
    # serialisable without numpy-specific handling.
    results = [
        {
            "id": int(idx),
            "score": float(final_scores[idx]),
            "metadata": metadata_store[int(idx)],
        }
        for idx in top_indices
    ]
    return results, query_img

In [None]:
import math

print(" Select a query ID card image...")
query_path = select_file("Select Query Image")


def _parse_audit_metadata(raw_text):
    """Decode one stored metadata JSON string into the fields the report shows.

    Returns a dict with the keys "is_fraud", "fraud_type", "name", "dob" and
    "file_name". When ``raw_text`` is not valid JSON, or does not decode to a
    dict, every field gets a placeholder and "fraud_type" is "METADATA_ERROR".
    """
    fallback = {
        "is_fraud": False,
        "fraud_type": "METADATA_ERROR",
        "name": "N/A",
        "dob": "N/A",
        "file_name": "Unknown",
    }
    try:
        meta_dict = json.loads(raw_text)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        return fallback
    if not isinstance(meta_dict, dict):
        return fallback
    return {
        "is_fraud": bool(meta_dict.get('is_fraud', False)),
        "fraud_type": str(meta_dict.get('fraud_type', 'None')),
        # NOTE(review): name / dob are printed in clear text below; redact
        # them before sharing the notebook output if they are real PII.
        "name": meta_dict.get('name', 'N/A'),
        "dob": meta_dict.get('dob', 'N/A'),
        "file_name": meta_dict.get('file_name', 'Unknown'),
    }


if query_path:
    print(f"🔎 Searching for: {query_path}")

    # User Input for Search Mode (surrounding whitespace is ignored)
    mode_map = {'1': "Face Only", '2': "Text Only", '3': "Hybrid"}
    mode_input = input("Select Mode: [1] Face Only  [2] Text Only  [3] Hybrid (Default): ").strip()
    alpha = 1.0 if mode_input == '1' else (0.0 if mode_input == '2' else 0.5)
    selected_mode = mode_map.get(mode_input, "Hybrid")

    # Search Pipeline (k=10 for detailed audit)
    K_RESULTS = 10
    results, q_img = search_pipeline(query_path, k=K_RESULTS, alpha=alpha)

    print("\n" + "="*80)
    print(f" AUDIT REPORT | Mode: {selected_mode} | Alpha: {alpha}")
    print("="*80)

    # --- VISUALIZATION SETUP ---
    # One panel for the query plus one per result, laid out 6 per row.
    total_plots = len(results) + 1
    cols = 6
    rows = math.ceil(total_plots / cols)

    plt.figure(figsize=(20, 4 * rows))

    # Plot 1: The Query Image
    plt.subplot(rows, cols, 1)
    plt.imshow(q_img)
    plt.title("QUERY INPUT\n(Source Artifact)", color='blue', weight='bold')
    plt.axis("off")

    # --- DETAILED RESULT LOOP ---
    for i, res in enumerate(results):
        rank = i + 1
        score = res['score']

        # 1. Deep Metadata Extraction
        # The 'data' field is a JSON string; a broken record yields
        # placeholders for every field instead of leaving names undefined.
        audit = _parse_audit_metadata(res['metadata']['data'])
        is_fraud = audit["is_fraud"]
        fraud_type = audit["fraud_type"]

        # 2. Status Logic
        status_label = "🚨 FRAUD DETECTED" if is_fraud else "✅ AUTHENTIC"
        status_color = 'red' if is_fraud else 'green'

        # 3. Console Audit Log (Extreme Detail)
        print(f"\n[RANK #{rank}] Score: {score:.4f} | Status: {status_label}")
        print(f"   ├─  Database ID: {res['id']}")
        print(f"   ├─  Fraud Type:  {fraud_type.upper()}")
        print(f"   ├─  Document Metadata:")
        print(f"   │    ├─ Name: {audit['name']}")
        print(f"   │    ├─ DOB:  {audit['dob']}")
        print(f"   │    └─ File: {audit['file_name']}")
        print(f"   └─  Match Logic: (Alpha={alpha})")

        # 4. Visualization Plot
        # The result id is the row position in hf_dataset, so the dataset must
        # be loaded in this kernel even when the indexes came from disk.
        match_img = hf_dataset[int(res['id'])]['image']

        ax = plt.subplot(rows, cols, i + 2)
        ax.imshow(match_img)

        # Detailed Plot Title
        title_text = (
            f"#{rank} {status_label}\n"
            f"Score: {score:.2f}\n"
            f"Type: {fraud_type}"
        )
        ax.set_title(title_text, color=status_color, fontsize=9)

        # Coloured border based on status. Only the ticks are hidden:
        # plt.axis("off") would hide the spines too and erase the border.
        for spine in ax.spines.values():
            spine.set_edgecolor(status_color)
            spine.set_linewidth(3)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.show()
    print("\n" + "="*80)
    print(" End of Audit Report")
    print("="*80)

else:
    print(" Selection Cancelled: No file selected.")
