In [None]:
import time
import torch
import sys
import os

sys.path.append('LLaVA')
sys.path.append('LLaVA/llava')

import pandas as pd
from PIL import Image
import numpy as np
import cv2
from torch.nn.functional import cosine_similarity
from pathlib import Path
import json
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model_frame
from tqdm import tqdm
import clip
import warnings
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Root directory under which the video files are looked up.
BASE_DATA_DIR = Path("data")

def normalize_dataset(ds):
    """Map a dataset label to its canonical spelling.

    The comparison ignores surrounding whitespace and letter case. A label
    that matches none of the known spellings is returned exactly as it was
    passed in (neither stripped nor lower-cased).
    """
    # Canonical name -> every lower-case spelling that should map to it.
    aliases_by_name = {
        "Deeperforensics": ("deeperforensics",),
        "DeepfakeDetection": ("deepfakedetection",),
        # "farceforensics++" is a known misspelling of the same dataset.
        "Faceforensics++": ("faceforensics++", "farceforensics++"),
        "Original": ("original",),
    }
    key = ds.strip().lower()
    for canonical_name, aliases in aliases_by_name.items():
        if key in aliases:
            return canonical_name
    return ds

def find_movie_file_by_dataset_and_manipulation(movie_name, dataset_value, manipulation, base_folder):
    """
    Locate the video file for a movie inside one dataset folder.

    The folder base_folder/<normalized dataset_value> is walked recursively.
    The first file (in rglob order) is returned whose file name contains
    movie_name and whose parent directory path contains manipulation; both
    checks are case insensitive.

    Returns the resolved path of that file, or None when the dataset folder
    does not exist (a message is printed) or when no file matches.
    """
    search_root = Path(base_folder) / normalize_dataset(dataset_value)
    if not search_root.exists():
        print(f"Dataset folder {search_root} not found.")
        return None

    matches = (
        candidate.resolve()
        for candidate in search_root.rglob('*')
        if candidate.is_file()
        and movie_name.lower() in candidate.name.lower()
        and manipulation.lower() in str(candidate.parent).lower()
    )
    # Only the first hit is needed; the generator stops walking once it is found.
    return next(matches, None)

class MaskUtils:
    """Helpers for restricting a frame to a circular region around a keypoint."""

    @staticmethod
    def apply_mask(frame, keypoint, use_hard_mask=True, radius=75):
        """Keep the disc of ``radius`` pixels around ``keypoint``, darken the rest.

        ``keypoint["x"]`` and ``keypoint["y"]`` are multiplied by the frame
        width and height to obtain the disc centre in pixels. ``frame`` must
        have a three-element shape (height, width, channels).

        With ``use_hard_mask`` the pixels outside the disc are set to zero and
        the same array object is returned. Otherwise the binary disc is
        Gaussian-blurred, rescaled to [0, 1] and dilated before being
        multiplied into the frame, and a new uint8 array is returned.

        Note that ``frame`` is modified in place in both modes.
        """
        height, width, _ = frame.shape
        center_x = keypoint["x"] * width
        center_y = keypoint["y"] * height

        # Open grids broadcast to the full (height, width) distance map.
        rows, cols = np.ogrid[:height, :width]
        distance = np.sqrt((cols - center_x)**2 + (rows - center_y)**2)
        mask = (distance <= radius).astype(np.uint8)

        if use_hard_mask:
            # Broadcast the 2-D mask over the channel axis, in place.
            frame *= mask[:, :, np.newaxis]
            return frame

        blur_ksize = 83
        soft_mask = cv2.GaussianBlur(mask.astype(np.float32), (blur_ksize, blur_ksize), 83)
        soft_mask = cv2.normalize(soft_mask, None, 0, 1, cv2.NORM_MINMAX)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        soft_mask = cv2.dilate(soft_mask, kernel, iterations=1)
        # Writing back into ``frame`` casts the float product to the frame's dtype.
        frame[...] = frame.astype(np.float32) * soft_mask[:, :, np.newaxis]
        return frame.astype(np.uint8)

def safe_load_model(model_path, max_retries=3, delay=5):
    """Load the pretrained model, retrying when loading raises.

    Args:
        model_path: Path or hub identifier handed to ``load_pretrained_model``.
        max_retries: Maximum number of load attempts.
        delay: Seconds to wait between two attempts.

    Returns:
        The tuple ``(tokenizer, llava_model, image_processor, context_len)``
        returned by ``load_pretrained_model``.

    Raises:
        RuntimeError: If every attempt failed. The exception raised by the
            last attempt is attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Loading model from {model_path} (attempt {attempt})...")
            tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=get_model_name_from_path(model_path)
            )
            print("Model loaded.")
            return tokenizer, llava_model, image_processor, context_len
        except Exception as e:
            # Deliberately broad: any loading failure is worth another attempt.
            last_error = e
            if attempt < max_retries:
                print(f"Error: {e}. Retrying...")
                time.sleep(delay)
            else:
                # No attempt follows, so neither wait nor announce a retry.
                print(f"Error: {e}.")
    # Chain the last failure so the real reason shows up in the traceback.
    raise RuntimeError("Failed to load model.") from last_error

def _run_clip_stem(visual, x):
    """Apply the three conv/bn/relu stages at the start of ``visual`` to ``x``."""
    x = visual.relu1(visual.bn1(visual.conv1(x)))
    x = visual.relu2(visual.bn2(visual.conv2(x)))
    x = visual.relu3(visual.bn3(visual.conv3(x)))
    return x


def get_clip_embedding(image, layer="last"):
    """Embed ``image`` with the module-level CLIP model.

    ``layer`` selects where the features are taken from:

    * ``"first"``  - output of the conv stem, average-pooled to 1x1.
    * ``"middle"`` - output of ``layer3`` (stem, ``avgpool``, ``layer1`` to
      ``layer3``), average-pooled to 1x1.
    * ``"last"``   - ``clip_model.encode_image``.

    Any other value raises ``ValueError``. The result is flattened to
    (batch, features). Reads the module-level ``clip_model``, ``preprocess``
    and ``device``.

    NOTE(review): the "first" and "middle" paths access ``conv1``-``conv3``,
    ``bn*``, ``relu*``, ``avgpool`` and ``layer1``-``layer3`` on
    ``clip_model.visual``, so they presumably only work with CLIP's ResNet
    backbones - confirm before using another architecture.
    """
    pixels = preprocess(image).unsqueeze(0).to(device)
    visual = clip_model.visual
    # Match the input precision to the model weights (fp16 vs fp32).
    if visual.conv1.weight.dtype == torch.half:
        pixels = pixels.half()

    with torch.no_grad():
        if layer == "last":
            features = clip_model.encode_image(pixels)
        elif layer == "first" or layer == "middle":
            features = _run_clip_stem(visual, pixels)
            if layer == "middle":
                features = visual.avgpool(features)
                features = visual.layer1(features)
                features = visual.layer2(features)
                features = visual.layer3(features)
            # Collapse the spatial dimensions to 1x1.
            features = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
        else:
            raise ValueError("Invalid layer specified.")
        embedding = features.view(features.size(0), -1)
    return embedding

def create_custom_prompt(annotations, prompt_version=1):
    """Build the text prompt that embeds the retrieved annotations.

    Args:
        annotations: Iterable of annotation texts. Each one is wrapped in
            double quotes and the results are joined with ", ".
        prompt_version: Which template to use (1-4).

    Returns:
        The selected template with ``{annotations}`` filled in.

    Raises:
        ValueError: If ``prompt_version`` is not one of the known versions.
    """
    templates = {
        1: (
            "Based on the following descriptions: {annotations}, analyze the face in the image and "
            "identify any signs of deepfake artifacts. Provide a detailed description of any anomalies."
        ),
        2: (
            "Analyze the face based on: '{annotations}'. Provide a short explanation highlighting inconsistencies."
        ),
        3: (
            "Based on the annotations: {annotations}, examine the face for any signs of deepfake manipulation. "
            "The provided annotations should serve as possible signs, but remember that the image can also be real. "
            "Keep your answer between 10 and 30 words."
        ),
        # The pieces are joined by implicit concatenation, so every piece
        # that is followed by another must end with a space.
        4: (
            "Analyze the face in the image based on the description: '{annotations}'. Identify any deepfake artifacts, focusing "
            "specifically on the affected parts of the face mentioned. Provide a short and direct explanation, "
            "highlighting the inconsistencies or manipulations."
        ),
    }
    try:
        base_prompt = templates[prompt_version]
    except (KeyError, TypeError):
        # TypeError covers unhashable values; both mean "unknown version".
        raise ValueError("Invalid prompt version.") from None
    quoted = ', '.join(f'"{ann}"' for ann in annotations)
    return base_prompt.format(annotations=quoted)

def detect_deepfake(frame, custom_prompt):
    """Run the loaded LLaVA model on one frame with ``custom_prompt``.

    ``frame`` must be a NumPy array; it is cast to uint8 and converted with
    ``COLOR_BGR2RGB`` (i.e. it is treated as a BGR image) before being wrapped
    in a PIL image. Returns whatever ``eval_model_frame`` returns.

    Reads the module-level ``model_path``, ``tokenizer``, ``llava_model`` and
    ``image_processor``.

    Raises:
        ValueError: If ``frame`` is not a NumPy array.
    """
    if not isinstance(frame, np.ndarray):
        raise ValueError("Expected NumPy array.")

    rgb_pixels = cv2.cvtColor(frame.astype('uint8'), cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(rgb_pixels)

    # eval_model_frame receives its settings as attributes of an args object,
    # so a throw-away class carrying the values is instantiated.
    settings = dict(
        model_path=model_path,
        model_base=None,
        model_name=get_model_name_from_path(model_path),
        query=custom_prompt,
        conv_mode=None,
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )
    local_args = type('Args', (), settings)()

    return eval_model_frame(
        local_args,
        image_pil,
        tokenizer=tokenizer,
        model=llava_model,
        image_processor=image_processor,
    )

def get_training_frames(df):
    """Collect, per training video, the annotated frame indices and their text.

    For every row the video is taken from ``movie_path`` when that column is
    present and set; otherwise it is searched under ``BASE_DATA_DIR`` using
    ``movie_name``, ``dataset`` and ``manipulation``. Rows whose video cannot
    be found, or whose ``click_locations`` is empty/NaN, are skipped.

    ``click_locations`` is parsed as a JSON object keyed by frame number:

    * a single entry ``"0"`` with x == 0 and y == 0 selects frame 0;
    * otherwise every all-digit key ``k`` selects frame ``k - 1``, clamped to
      0 so that a key ``"0"`` cannot yield a negative frame index.

    Args:
        df: DataFrame with at least ``click_locations`` and ``text`` columns.

    Returns:
        dict mapping video path -> list of ``(frame_number, text)`` tuples,
        sorted by frame number.
    """
    print("Processing training frames...")
    training_frames = {}
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Training rows"):
        if 'movie_path' in row and pd.notna(row['movie_path']):
            video_path = Path(row['movie_path']).resolve()
        else:
            movie_name = row['movie_name']
            dataset_val = row['dataset'] if 'dataset' in row and pd.notna(row['dataset']) else ""
            manipulation = row['manipulation'] if 'manipulation' in row and pd.notna(row['manipulation']) else ""
            video_path = find_movie_file_by_dataset_and_manipulation(movie_name, dataset_val, manipulation, BASE_DATA_DIR)
            if video_path is None:
                print(f"[Row {idx}] File for '{movie_name}' with dataset '{dataset_val}' and manipulation '{manipulation}' not found. Skipping.")
                continue

        click_locations = row['click_locations']
        if not click_locations or pd.isna(click_locations):
            continue

        try:
            frame_data = json.loads(click_locations)
            if not isinstance(frame_data, dict):
                raise TypeError("click_locations must decode to a JSON object")
            is_placeholder = (
                len(frame_data) == 1 and "0" in frame_data
                and float(frame_data["0"].get("x", -1)) == 0.0
                and float(frame_data["0"].get("y", -1)) == 0.0
            )
            if is_placeholder:
                frame_numbers = [0]
            else:
                # Keys are shifted down by one; the clamp keeps a stray "0"
                # key from producing frame index -1.
                frame_numbers = [
                    max(int(frame_str) - 1, 0)
                    for frame_str in frame_data
                    if frame_str.isdigit()
                ]
            # For training frames, store the annotation from the "text" field.
            entries = [(frame_num, row['text']) for frame_num in frame_numbers]
        except (ValueError, TypeError, AttributeError, KeyError) as e:
            # Malformed JSON or an unexpected structure: report and skip the row.
            print(f"Error processing row {idx}: {e}")
            continue

        if entries:
            training_frames.setdefault(video_path, []).extend(entries)

    for frames in training_frames.values():
        frames.sort(key=lambda entry: entry[0])
    print(f"Processed training frames for {len(training_frames)} videos.")
    return training_frames

def get_test_frames(df):
    """Pick the middle frame of every test video.

    For every row the video is taken from ``movie_path`` when that column is
    present and set; otherwise it is searched under ``BASE_DATA_DIR`` using
    ``movie_name``, ``dataset`` and ``manipulation``. Videos that cannot be
    found or opened are skipped with a message.

    Args:
        df: DataFrame describing the test videos; ``text`` (optional) is used
            as the ground truth.

    Returns:
        dict mapping video path -> ``[(middle_frame, [], ground_truth)]``,
        where the empty list is the (unused) keypoint list. When several rows
        refer to the same video, the last row wins.
    """
    frames_dict = {}

    for idx, row in df.iterrows():
        if 'movie_path' in row and pd.notna(row['movie_path']):
            video_path = Path(row['movie_path']).resolve()
        else:
            movie_name = row['movie_name']
            dataset_val = row['dataset'] if 'dataset' in row and pd.notna(row['dataset']) else ""
            manipulation = row['manipulation'] if 'manipulation' in row and pd.notna(row['manipulation']) else ""
            video_path = find_movie_file_by_dataset_and_manipulation(movie_name, dataset_val, manipulation, BASE_DATA_DIR)
            if video_path is None:
                print(f"[Row {idx}] File for '{movie_name}' with dataset '{dataset_val}' and manipulation '{manipulation}' not found. Skipping.")
                continue

        if not video_path.exists():
            print(f"Video not found at: {video_path}, skipping.")
            continue

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                print(f"Could not open video: {video_path}")
                continue
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the handle even when the video could not be opened.
            cap.release()

        # The reported frame count can be zero or negative when it is unknown;
        # never hand a negative index to the frame reader.
        middle_frame = max(frame_count // 2, 0)

        ground_truth = row['text'] if 'text' in row and pd.notna(row['text']) else None
        frames_dict[video_path] = [(middle_frame, [], ground_truth)]

    print(f"Prepared middle-frame test data for {len(frames_dict)} videos.")
    return frames_dict

def run_pipeline(rn_model_name, extraction_layer, top_k,
                 mask_on=False, use_hard_mask=True, prompt_version=1):
    """Retrieve similar training frames for every test frame and query LLaVA.

    Steps:
      1. Load the CLIP model and publish it (with its preprocessing function)
         as the module-level ``clip_model`` / ``preprocess``.
      2. Load the cached training embeddings for this model/layer pair, or
         compute them from the annotated training frames and cache them.
      3. For every test frame, rank the training embeddings by cosine
         similarity, keep the best match of each of the ``top_k`` most similar
         distinct training videos, build a prompt from their annotations and
         run ``detect_deepfake`` on the (optionally masked) frame.

    Reads the module-level ``train_df``, ``test_df`` and ``device``.

    Args:
        rn_model_name: CLIP model name passed to ``clip.load``.
        extraction_layer: Layer selector passed to ``get_clip_embedding``.
        top_k: Number of distinct training videos whose annotations are used.
        mask_on: Apply ``MaskUtils.apply_mask`` for every keypoint of a frame.
        use_hard_mask: Forwarded to ``MaskUtils.apply_mask``.
        prompt_version: Forwarded to ``create_custom_prompt``.

    Returns:
        list of dicts, one per analysed test frame.

    Raises:
        RuntimeError: If no training embedding could be computed.
    """
    print(f"\nLoading CLIP model {rn_model_name} on {device} ...")
    # get_clip_embedding reads these two names from module scope.
    global clip_model, preprocess
    clip_model, preprocess = clip.load(rn_model_name, device=device)
    clip_model.eval()
    print("CLIP model loaded.")

    print("Extracting test frames (middle only)...")
    test_frames = get_test_frames(test_df)

    emb_save_path = f"training_embeddings_{rn_model_name}_{extraction_layer}.pt"
    if os.path.exists(emb_save_path):
        print(f"Loading cached training embeddings from {emb_save_path} ...")
        # NOTE: torch.load unpickles the file - only load caches written by
        # this script, never files from an untrusted source.
        emb_data = torch.load(emb_save_path, map_location=device)
        training_embeddings_tensor = emb_data['embeddings']
        training_annotations = emb_data['annotations']
        training_keys = emb_data['keys']
    else:
        # The training frames are only needed when nothing is cached, so the
        # table is not scanned on a cache hit.
        print("Extracting training frames...")
        train_frames = get_training_frames(train_df)

        print(f"Computing training embeddings for [{rn_model_name}, {extraction_layer}] ...")
        training_embeddings_list, training_annotations, training_keys = [], [], []
        for video_path, frames in tqdm(train_frames.items(),
                                       desc=f"Processing training videos [{rn_model_name}, {extraction_layer}]"):
            cap = cv2.VideoCapture(str(video_path))
            try:
                if not cap.isOpened():
                    print(f"Could not open video: {video_path}")
                    continue
                for (frame_number, annotation) in frames:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image_pil = Image.fromarray(frame_rgb)
                    embedding = get_clip_embedding(image_pil, layer=extraction_layer)
                    training_embeddings_list.append(embedding)
                    training_annotations.append(annotation)
                    # Keys are only compared for equality below. Plain strings
                    # keep the cache loadable by torch.load in its restricted
                    # weights-only mode, which rejects pathlib objects.
                    training_keys.append((str(video_path), frame_number))
            finally:
                cap.release()
        if not training_embeddings_list:
            raise RuntimeError("No training embeddings computed.")
        training_embeddings_tensor = torch.cat(training_embeddings_list, dim=0)
        torch.save({
            'embeddings': training_embeddings_tensor,
            'annotations': training_annotations,
            'keys': training_keys
        }, emb_save_path)
        print(f"Saved training embeddings to {emb_save_path}")

    print("Analyzing test videos (middle frame)...")
    results = []
    for video_path, frame_info_list in tqdm(test_frames.items(),
                                            desc=f"Analyzing test videos [{rn_model_name}, {extraction_layer}, top_k={top_k}]"):
        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                print(f"Could not open video: {video_path}")
                continue

            for (frame_number, keypoints, ground_truth) in frame_info_list:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                ret, frame = cap.read()
                if not ret:
                    continue

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_pil = Image.fromarray(frame_rgb)
                test_embedding = get_clip_embedding(image_pil, layer=extraction_layer)

                # One similarity per training embedding. reshape(-1) keeps the
                # result 1-D even when there is a single training embedding;
                # squeezing would turn it into a scalar and break the loop.
                sims = cosine_similarity(test_embedding, training_embeddings_tensor).reshape(-1)
                sorted_indices = torch.argsort(sims, descending=True)

                # Walk the ranking and keep only the best hit of each video
                # until top_k distinct videos have been collected.
                distinct_indices = []
                used_videos = set()
                for idx in sorted_indices.tolist():
                    vid_path, _ = training_keys[idx]
                    if vid_path not in used_videos:
                        used_videos.add(vid_path)
                        distinct_indices.append(idx)
                    if len(distinct_indices) == top_k:
                        break

                chosen_indices = distinct_indices
                chosen_values = sims[chosen_indices]
                chosen_annotations = [training_annotations[i] for i in chosen_indices]

                if chosen_annotations:
                    custom_prompt = create_custom_prompt(chosen_annotations, prompt_version=prompt_version)
                    # Mask a copy so the decoded frame itself stays untouched.
                    frame_for_llava = frame.copy()
                    if mask_on and keypoints:
                        for kp in keypoints:
                            frame_for_llava = MaskUtils.apply_mask(frame_for_llava, kp,
                                                                   use_hard_mask=use_hard_mask,
                                                                   radius=75)
                    test_deepfake_analysis = detect_deepfake(frame_for_llava, custom_prompt)
                else:
                    test_deepfake_analysis = "No analysis available"

                results.append({
                    'rn_model': rn_model_name,
                    'extraction_layer': extraction_layer,
                    'top_k': top_k,
                    'test_video': str(video_path),
                    'test_frame': frame_number,
                    'ground_truth': ground_truth,
                    'closest_train_annotations': chosen_annotations,
                    'test_deepfake_analysis': test_deepfake_analysis,
                    'top_k_similarities': chosen_values.tolist(),
                    'prompt_version_used': prompt_version,
                })
        finally:
            cap.release()
    print(f"Completed analysis for {len(results)} test frames.")
    return results


# Script entry point: load LLaVA, split the dataset, print a summary, then run
# the retrieval + analysis pipeline for every configured combination and write
# the results to CSV files.
#
# Everything below runs at module scope, so the names bound here
# (model_path, tokenizer, llava_model, image_processor, train_df, test_df, ...)
# are module-level globals. detect_deepfake and run_pipeline read several of
# them, so they must keep these exact names.
if __name__ == "__main__":
    print("Starting pipeline execution...")

    # Model path for LLaVA (also read by detect_deepfake).
    model_path = "liuhaotian/llava-v1.5-7b"
    tokenizer, llava_model, image_processor, context_len = safe_load_model(model_path)
    llava_model = llava_model.to(device)
    
    print("Loading dataset...")
    csv_path = 'dataset_last.csv'
    df = pd.read_csv(csv_path)
    # Shuffle all rows with a fixed seed so the order is reproducible.
    df = df.sample(frac=1, random_state=0).reset_index(drop=True)
    
    # At module scope this statement has no effect (the names are global
    # anyway); it only documents that run_pipeline reads train_df / test_df.
    global train_df, val_df, test_df
    train_df = df[df['split'] == 'train']
    val_df = df[df['split'] == 'val']   # Only used for the summary below
    test_df = df[df['split'] == 'test']
    
    # A video is identified by its (dataset, manipulation, movie_name) triple.
    print("Dataset Summary:")
    train_unique = train_df[['dataset', 'manipulation', 'movie_name']].drop_duplicates().shape[0]
    val_unique = val_df[['dataset', 'manipulation', 'movie_name']].drop_duplicates().shape[0]
    test_unique = test_df[['dataset', 'manipulation', 'movie_name']].drop_duplicates().shape[0]
    print("Total Train Videos:", train_unique)
    print("Total Validation Videos:", val_unique)
    print("Total Test Videos:", test_unique)
    
    # Sanity check: resolve one example file for every dataset/manipulation
    # pair so that path problems show up before the long-running part starts.
    print("\nRepresentative file paths per dataset and manipulation:")
    unique_ds_manip = df[['dataset', 'manipulation']].drop_duplicates()
    for ds, manip in unique_ds_manip.values:
        subset = df[(df['dataset'] == ds) & (df['manipulation'] == manip)]
        if len(subset) == 0:
            continue
        sample = subset.iloc[0]
        movie_name = sample['movie_name']
        if 'movie_path' in sample and pd.notna(sample['movie_path']):
            path = Path(sample['movie_path']).resolve()
        else:
            path = find_movie_file_by_dataset_and_manipulation(movie_name, ds, manip, BASE_DATA_DIR)
        print(f"Dataset: {ds} | Manipulation: {manip} -> {path}")
    
    # Preview how many frames are in the train/test sets. run_pipeline performs
    # its own extraction, so these results are only used for the counts below.
    train_frames_info = get_training_frames(train_df)
    test_frames_info = get_test_frames(test_df)
    total_train_frames = sum(len(frames) for frames in train_frames_info.values())
    total_test_frames = sum(len(frames) for frames in test_frames_info.values())
    print("Total training frames of interest:", total_train_frames)
    print("Total test frames of interest:", total_test_frames)
    
    # Configuration grid: every (CLIP model, layer, top_k) combination is run.
    # NOTE(review): get_test_frames returns an empty keypoint list for every
    # frame and run_pipeline only masks when keypoints are present, so the
    # mask settings below appear to have no effect - confirm this is intended.
    rn_models = ["RN101"]
    extraction_layers = ["last"]
    top_k_values = [5]
    mask_on_val = True
    use_hard_mask_val = False
    
    all_results = []
    for rn_model_name in rn_models:
        for extraction_layer in extraction_layers:
            for top_k in top_k_values:
                print(f"\n=== Running: {rn_model_name}, {extraction_layer} extraction, top_k = {top_k} ===")
                comb_results = run_pipeline(
                    rn_model_name=rn_model_name,
                    extraction_layer=extraction_layer,
                    top_k=top_k,
                    mask_on=mask_on_val,
                    use_hard_mask=use_hard_mask_val,
                    prompt_version=4
                )
                # One CSV per combination, plus one combined CSV at the end.
                comb_csv_path = f"results_{rn_model_name}_{extraction_layer}_k{top_k}_mask_{mask_on_val}_hardmask_{use_hard_mask_val}.csv"
                pd.DataFrame(comb_results).to_csv(comb_csv_path, index=False)
                print(f"Saved results to {comb_csv_path}")
                all_results.extend(comb_results)
    
    combined_csv_path = "results_all_combinations.csv"
    pd.DataFrame(all_results).to_csv(combined_csv_path, index=False)
    print(f"\nAll results saved to {combined_csv_path}")
    print("Pipeline execution completed.")
