In [33]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import textwrap
import re
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from pathlib import Path

# ------------------------------
# DEBUGGING & CONFIGURATION
# ------------------------------
# Force CUDA kernel launches to run synchronously so device-side errors are
# reported at the Python line that triggered them.
# NOTE(review): synchronous launches slow inference; consider removing this
# once debugging is finished.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

CONFIG = {
    "model_name": "ViT-B-16-quickgelu",
    "pretrained": "openai",
    "model_path": "finetuned_identity_only_best_ViT-B-16_openai_7_11.pt",
    "csv_test": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/all_dataset.csv",
    "plot_save_dir": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/zero_shot_face_recognition/fourth_discussion/finetune_all_dataset_identity_7_11",
    "num_plots": 1000,
    "batch_size": 32,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # --- Prompts for Identity ---
    # Each prompt must follow "A face of <Name>, ..." -- the name before the
    # first comma becomes the identity key.
    "identity_prompts": [
        "A face of Soekarno, a male First President (1945-1967) from Indonesia.",
        "A face of Suharto, a male Second President (1967-1998) from Indonesia.",
        "A face of Baharuddin Jusuf Habibie, a male Third President (1998-1999) from Indonesia.",
        "A face of Abdurrahman Wahid, a male Fourth President (1999-2001) from Indonesia.",
        "A face of Megawati Sukarnoputri, a female Fifth President (2001-2004) from Indonesia.",
        "A face of Susilo Bambang Yudhoyono, a male Sixth President (2004-2014) from Indonesia.",
        "A face of Joko Widodo, a male Seventh President (2014-2024) from Indonesia.",
        # NOTE(review): "Eight" should read "Eighth". Left as-is because these
        # prompts are tokenized and encoded at evaluation time; confirm which
        # wording the fine-tuning used before changing it.
        "A face of Prabowo Subianto, a male Eight President (2024-Present) from Indonesia.",
        "A face of Anies Rasyid Baswedan, a male Governor of Jakarta (2017-2022) and Presidential Candidate Election (2024) from Indonesia.",
        "A face of Ganjar Pranowo, a male Governor of Central Java (2013-2023) and Presidential Candidate Election (2024) from Indonesia.",
        "A face of Gibran Rakabuming Raka, a male Vice President (2024-2029) from Indonesia.",
        "A face of Maruf Amin, a male Vice President (2019-2024) from Indonesia.",
        "A face of Airlangga Hartarto, a male Coordinating Minister of Economic Affairs (2024-2029) from Indonesia.",
        "A face of Sri Mulyani Indrawati, a female Minister of Finance (2024-2029) from Indonesia.",
        "A face of Erick Thohir, a male Minister of State Owned Enterprises (2024-2029) from Indonesia.",
        "A face of Agus Harimurti Yudhoyono, a male Coordinating Minister of Agrarian Affairs and Spatial Planning (2024-2029) and Chairman of Democratic Party from Indonesia.",
        "A face of Muhaimin Iskandar, a male Coordinating Minister of Social Empowerment (2024-2029) and Chairman of National Awakening Party from Indonesia.",
        "A face of Mahfud MD, a male Coordinating Minister of Political, Legal, and Security Affairs (2019-2024) from Indonesia.",
        "A face of Boediono, a male Vice President (2009-2014) from Indonesia.",
        "A face of Jusuf Kalla, a male Vice President (2004-2009) and Vice President (2014-2019) from Indonesia."
    ],
    # --- Prompts for Gender ---
    "gender_prompts": [
        "a photo of a male person.",
        "a photo of a female person."
    ],
    # --- Prompts for Age Group ---
    # Template: "a photo of a/an <key>[ person]."
    "age_prompts": [
        "a photo of a teenager.", "a photo of a young adult.", "a photo of a middle-aged person.",
        "a photo of a late adult.", "a photo of an elderly person."
    ],
    # --- Prompts for Expression ---
    # Template: "... with a/an <key> expression."
    "expression_prompts": [
        "a photo of a person with an anger expression.", "a photo of a person with a contempt expression.",
        "a photo of a person with a disgust expression.", "a photo of a person with a happiness expression.",
        "a photo of a person with a fear expression.", "a photo of a person with a sadness expression.",
        "a photo of a person with a surprise expression.", "a photo of a person with a neutral expression."
    ]
}


def extract_prompt_key(pattern, prompt):
    """Return the ``key`` named group of ``pattern`` matched at the start of ``prompt``.

    Args:
        pattern: Regex with a ``(?P<key>...)`` group describing the prompt template.
        prompt: The prompt to extract the short label from.

    Returns:
        The captured ``key`` text.

    Raises:
        ValueError: If ``prompt`` does not follow the template, so a malformed
            prompt fails loudly instead of yielding a silently wrong key.
    """
    match = re.match(pattern, prompt)
    if match is None:
        raise ValueError(f"Prompt {prompt!r} does not match template {pattern!r}")
    return match.group("key")


# --- Create simplified keys from prompts for mapping ---
# Anchored templates strip only the fixed prompt scaffolding. Unanchored
# str.replace("a ", "") would also eat the end of any word ending in "a"
# (e.g. "drama expression." -> "dramexpression.").
CONFIG["identity_keys"] = [
    extract_prompt_key(r"^A face of (?P<key>[^,]+)", p) for p in CONFIG["identity_prompts"]
]
print(f"Identity Keys: {CONFIG['identity_keys']}")

CONFIG["gender_keys"] = ["male", "female"]
print(f"Gender Keys: {CONFIG['gender_keys']}")

CONFIG["age_keys"] = [
    extract_prompt_key(r"^a photo of (?:an? )?(?P<key>.+?)(?: person)?\.$", p)
    for p in CONFIG["age_prompts"]
]
print(CONFIG["age_keys"])

CONFIG["expression_keys"] = [
    extract_prompt_key(r"^.* with (?:an? )?(?P<key>.+?) expression\.$", p)
    for p in CONFIG["expression_prompts"]
]
print(CONFIG["expression_keys"])

# ------------------------------
# Helper Function for Parsing
# ------------------------------
def parse_identity_from_prompt(prompt_text, keys=None):
    """Return the first identity key mentioned in ``prompt_text``, or ``None``.

    Each key is matched case-insensitively as a whole phrase (bounded by regex
    word boundaries), so "Soekarno" does not match inside "Soekarnoputri".
    Keys are tried in list order and the first hit wins.

    Args:
        prompt_text: Prompt to search. Non-string values (e.g. the NaN that
            ``pd.read_csv`` produces for a blank cell) yield ``None`` instead
            of making ``re.search`` raise ``TypeError``.
        keys: Candidate identity names; defaults to ``CONFIG["identity_keys"]``.

    Returns:
        The matching key exactly as spelled in ``keys``, or ``None``.
    """
    if not isinstance(prompt_text, str):
        return None
    candidates = CONFIG["identity_keys"] if keys is None else keys
    for name in candidates:
        if re.search(r'\b' + re.escape(name) + r'\b', prompt_text, re.IGNORECASE):
            return name
    return None

def parse_gender_from_prompt(prompt_text, keys=None):
    """Return the first gender key mentioned in ``prompt_text``, or ``None``.

    Matching is case-insensitive and whole-word: ``\\bmale\\b`` does not hit
    inside "female" because there is no word boundary between "e" and "m".
    Keys are tried in list order and the first hit wins.

    Args:
        prompt_text: Prompt to search. Non-string values (e.g. NaN from a
            blank CSV cell) yield ``None`` instead of raising ``TypeError``.
        keys: Candidate gender labels; defaults to ``CONFIG["gender_keys"]``.

    Returns:
        The matching key exactly as spelled in ``keys``, or ``None``.
    """
    if not isinstance(prompt_text, str):
        return None
    candidates = CONFIG["gender_keys"] if keys is None else keys
    for gender_key in candidates:
        if re.search(r'\b' + re.escape(gender_key) + r'\b', prompt_text, re.IGNORECASE):
            return gender_key
    return None

def parse_age_from_prompt(prompt_text, keys=None):
    """Return the first age-group key mentioned in ``prompt_text``, or ``None``.

    Each key (possibly multi-word, e.g. "young adult") is matched
    case-insensitively as a whole phrase; keys are tried in list order.

    Args:
        prompt_text: Prompt to search. Non-string values (e.g. NaN from a
            blank CSV cell) yield ``None`` instead of raising ``TypeError``.
        keys: Candidate age labels; defaults to ``CONFIG["age_keys"]``.

    Returns:
        The matching key exactly as spelled in ``keys``, or ``None``.
    """
    if not isinstance(prompt_text, str):
        return None
    candidates = CONFIG["age_keys"] if keys is None else keys
    for age_key in candidates:
        if re.search(r'\b' + re.escape(age_key) + r'\b', prompt_text, re.IGNORECASE):
            return age_key
    return None

def parse_expression_from_prompt(prompt_text, keys=None):
    """Return the first expression key mentioned in ``prompt_text``, or ``None``.

    Each key is matched case-insensitively as a whole word, so "anger" does
    not hit inside "angered"; keys are tried in list order.

    Args:
        prompt_text: Prompt to search. Non-string values (e.g. NaN from a
            blank CSV cell) yield ``None`` instead of raising ``TypeError``.
        keys: Candidate expression labels; defaults to
            ``CONFIG["expression_keys"]``.

    Returns:
        The matching key exactly as spelled in ``keys``, or ``None``.
    """
    if not isinstance(prompt_text, str):
        return None
    candidates = CONFIG["expression_keys"] if keys is None else keys
    for expression_key in candidates:
        if re.search(r'\b' + re.escape(expression_key) + r'\b', prompt_text, re.IGNORECASE):
            return expression_key
    return None

# ------------------------------
# Preprocessing and Caching
# ------------------------------
def preprocess_and_cache_csv_identity(csv_path):
    """Parse identity labels from a CSV's prompts and cache the result.

    The cache is written beside the source as ``<name>.identity_cached.csv``.
    It is reused when it is non-empty and not older than ``csv_path``;
    otherwise the source is re-parsed and the cache rewritten.

    NOTE(review): ``identity_idx`` values index ``CONFIG["identity_keys"]``;
    a cache written with a different key order is not detected here.

    Args:
        csv_path: Path to a CSV with at least ``filepath`` and ``prompt`` columns.

    Returns:
        DataFrame with ``filepath``, ``prompt`` and ``identity_idx`` columns, or
        ``None`` if the source CSV is missing/empty or no prompt names a known
        identity.
    """
    # Derive the cache name from the extension only. A plain
    # str.replace(".csv", ...) would also rewrite ".csv" inside directory
    # names, and for a path without ".csv" would return csv_path itself,
    # making the source file double as (and be overwritten by) the cache.
    root, _ = os.path.splitext(csv_path)
    cache_path = f"{root}.identity_cached.csv"

    if os.path.exists(cache_path):
        # A cache older than its source no longer reflects the source's rows.
        source_is_newer = (os.path.exists(csv_path)
                           and os.path.getmtime(csv_path) > os.path.getmtime(cache_path))
        if source_is_newer:
            print(f"Cache {cache_path} is older than {csv_path}. Reprocessing.")
        else:
            print(f"Loading preprocessed identity data from cache: {cache_path}")
            try:
                df_cache = pd.read_csv(cache_path)
            except pd.errors.EmptyDataError:
                df_cache = None
            if df_cache is not None and not df_cache.empty:
                return df_cache
            print("Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing identity data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    keys = CONFIG["identity_keys"]
    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        prompt = row['prompt']
        # Blank prompt cells load as NaN; skip them rather than let the regex
        # parser raise TypeError on a non-string.
        if not isinstance(prompt, str):
            continue
        name = parse_identity_from_prompt(prompt)
        if name:
            new_data.append({
                'filepath': row['filepath'],
                'prompt': prompt,
                'identity_idx': keys.index(name),
            })

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

def preprocess_and_cache_csv_gender(csv_path):
    """Parse gender labels from a CSV's prompts and cache the result.

    The cache is written beside the source as ``<name>.gender_cached.csv``.
    It is reused when it is non-empty and not older than ``csv_path``;
    otherwise the source is re-parsed and the cache rewritten.

    NOTE(review): ``gender_idx`` values index ``CONFIG["gender_keys"]``; a
    cache written with a different key order is not detected here.

    Args:
        csv_path: Path to a CSV with at least ``filepath`` and ``prompt`` columns.

    Returns:
        DataFrame with ``filepath``, ``prompt`` and ``gender_idx`` columns, or
        ``None`` if the source CSV is missing/empty or no prompt names a gender.
    """
    # Derive the cache name from the extension only, so ".csv" elsewhere in
    # the path is untouched and the cache can never coincide with the source.
    root, _ = os.path.splitext(csv_path)
    cache_path = f"{root}.gender_cached.csv"

    if os.path.exists(cache_path):
        # A cache older than its source no longer reflects the source's rows.
        source_is_newer = (os.path.exists(csv_path)
                           and os.path.getmtime(csv_path) > os.path.getmtime(cache_path))
        if source_is_newer:
            print(f"Cache {cache_path} is older than {csv_path}. Reprocessing.")
        else:
            print(f"Loading preprocessed gender data from cache: {cache_path}")
            try:
                df_cache = pd.read_csv(cache_path)
            except pd.errors.EmptyDataError:
                df_cache = None
            if df_cache is not None and not df_cache.empty:
                return df_cache
            print("Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing gender data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    keys = CONFIG["gender_keys"]
    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        prompt = row['prompt']
        # Blank prompt cells load as NaN; skip them rather than let the regex
        # parser raise TypeError on a non-string.
        if not isinstance(prompt, str):
            continue
        gender = parse_gender_from_prompt(prompt)
        if gender:
            new_data.append({
                'filepath': row['filepath'],
                'prompt': prompt,
                'gender_idx': keys.index(gender),
            })

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset for gender.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

def preprocess_and_cache_csv_age(csv_path):
    """Parse age-group labels from a CSV's prompts and cache the result.

    The cache is written beside the source as ``<name>.age_cached.csv``.
    It is reused when it is non-empty and not older than ``csv_path``;
    otherwise the source is re-parsed and the cache rewritten.

    NOTE(review): ``age_idx`` values index ``CONFIG["age_keys"]``; a cache
    written with a different key order is not detected here.

    Args:
        csv_path: Path to a CSV with at least ``filepath`` and ``prompt`` columns.

    Returns:
        DataFrame with ``filepath``, ``prompt`` and ``age_idx`` columns, or
        ``None`` if the source CSV is missing/empty or no prompt names an age
        group.
    """
    # Derive the cache name from the extension only, so ".csv" elsewhere in
    # the path is untouched and the cache can never coincide with the source.
    root, _ = os.path.splitext(csv_path)
    cache_path = f"{root}.age_cached.csv"

    if os.path.exists(cache_path):
        # A cache older than its source no longer reflects the source's rows.
        source_is_newer = (os.path.exists(csv_path)
                           and os.path.getmtime(csv_path) > os.path.getmtime(cache_path))
        if source_is_newer:
            print(f"Cache {cache_path} is older than {csv_path}. Reprocessing.")
        else:
            print(f"Loading preprocessed age data from cache: {cache_path}")
            try:
                df_cache = pd.read_csv(cache_path)
            except pd.errors.EmptyDataError:
                df_cache = None
            if df_cache is not None and not df_cache.empty:
                return df_cache
            print("Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing age data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    keys = CONFIG["age_keys"]
    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        prompt = row['prompt']
        # Blank prompt cells load as NaN; skip them rather than let the regex
        # parser raise TypeError on a non-string.
        if not isinstance(prompt, str):
            continue
        age_group = parse_age_from_prompt(prompt)
        if age_group:
            new_data.append({
                'filepath': row['filepath'],
                'prompt': prompt,
                'age_idx': keys.index(age_group),
            })

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset for age.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

def preprocess_and_cache_csv_expression(csv_path):
    """Parse expression labels from a CSV's prompts and cache the result.

    The cache is written beside the source as ``<name>.expression_cached.csv``.
    It is reused when it is non-empty and not older than ``csv_path``;
    otherwise the source is re-parsed and the cache rewritten.

    NOTE(review): ``expression_idx`` values index ``CONFIG["expression_keys"]``;
    a cache written with a different key order is not detected here.

    Args:
        csv_path: Path to a CSV with at least ``filepath`` and ``prompt`` columns.

    Returns:
        DataFrame with ``filepath``, ``prompt`` and ``expression_idx`` columns,
        or ``None`` if the source CSV is missing/empty or no prompt names a
        known expression.
    """
    # Derive the cache name from the extension only, so ".csv" elsewhere in
    # the path is untouched and the cache can never coincide with the source.
    root, _ = os.path.splitext(csv_path)
    cache_path = f"{root}.expression_cached.csv"

    if os.path.exists(cache_path):
        # A cache older than its source no longer reflects the source's rows.
        source_is_newer = (os.path.exists(csv_path)
                           and os.path.getmtime(csv_path) > os.path.getmtime(cache_path))
        if source_is_newer:
            print(f"Cache {cache_path} is older than {csv_path}. Reprocessing.")
        else:
            print(f"Loading preprocessed expression data from cache: {cache_path}")
            try:
                df_cache = pd.read_csv(cache_path)
            except pd.errors.EmptyDataError:
                df_cache = None
            if df_cache is not None and not df_cache.empty:
                return df_cache
            print("Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing expression data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    keys = CONFIG["expression_keys"]
    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        prompt = row['prompt']
        # Blank prompt cells load as NaN; skip them rather than let the regex
        # parser raise TypeError on a non-string.
        if not isinstance(prompt, str):
            continue
        expression = parse_expression_from_prompt(prompt)
        if expression:
            new_data.append({
                'filepath': row['filepath'],
                'prompt': prompt,
                'expression_idx': keys.index(expression),
            })

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset for expression.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

# ------------------------------
# Dataset Class
# ------------------------------
class IdentityDataset(Dataset):
    """Map-style dataset yielding ``(image, identity_idx, image_path)`` triples.

    Args:
        df: DataFrame with ``filepath`` and ``identity_idx`` columns; ``None``
            is treated as an empty dataset.
        preprocess: Callable turning an RGB PIL image into a tensor (here, the
            transform returned by ``open_clip.create_model_and_transforms``).
        fallback_size: Side length of the all-zero ``(3, H, W)`` placeholder
            returned for a missing image file. It should match the size
            ``preprocess`` produces, otherwise the DataLoader cannot stack the
            batch.
    """

    def __init__(self, df, preprocess, fallback_size=224):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess
        self.fallback_size = fallback_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(image_path) as img:
                image = self.preprocess(img.convert("RGB"))
        except FileNotFoundError:
            # Best effort: a black placeholder keeps the batch intact, but the
            # sample is still classified and counted by the caller.
            image = torch.zeros((3, self.fallback_size, self.fallback_size))

        # Explicit long dtype so the label can index Python lists downstream.
        identity_idx = torch.tensor(row["identity_idx"], dtype=torch.long)
        return image, identity_idx, image_path

class GenderDataset(Dataset):
    """Map-style dataset yielding ``(image, gender_idx, image_path)`` triples.

    Args:
        df: DataFrame with ``filepath`` and ``gender_idx`` columns; ``None``
            is treated as an empty dataset.
        preprocess: Callable turning an RGB PIL image into a tensor.
        fallback_size: Side length of the all-zero ``(3, H, W)`` placeholder
            returned for a missing image file. It should match the size
            ``preprocess`` produces, otherwise the DataLoader cannot stack the
            batch.
    """

    def __init__(self, df, preprocess, fallback_size=224):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess
        self.fallback_size = fallback_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(image_path) as img:
                image = self.preprocess(img.convert("RGB"))
        except FileNotFoundError:
            # Best effort: a black placeholder keeps the batch intact, but the
            # sample is still classified and counted by the caller.
            image = torch.zeros((3, self.fallback_size, self.fallback_size))

        # Explicit long dtype so the label can index Python lists downstream.
        gender_idx = torch.tensor(row["gender_idx"], dtype=torch.long)
        return image, gender_idx, image_path

class AgeDataset(Dataset):
    """Map-style dataset yielding ``(image, age_idx, image_path)`` triples.

    Args:
        df: DataFrame with ``filepath`` and ``age_idx`` columns; ``None`` is
            treated as an empty dataset.
        preprocess: Callable turning an RGB PIL image into a tensor.
        fallback_size: Side length of the all-zero ``(3, H, W)`` placeholder
            returned for a missing image file. It should match the size
            ``preprocess`` produces, otherwise the DataLoader cannot stack the
            batch.
    """

    def __init__(self, df, preprocess, fallback_size=224):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess
        self.fallback_size = fallback_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(image_path) as img:
                image = self.preprocess(img.convert("RGB"))
        except FileNotFoundError:
            # Best effort: a black placeholder keeps the batch intact, but the
            # sample is still classified and counted by the caller.
            image = torch.zeros((3, self.fallback_size, self.fallback_size))

        # Explicit long dtype so the label can index Python lists downstream.
        age_idx = torch.tensor(row["age_idx"], dtype=torch.long)
        return image, age_idx, image_path

class ExpressionDataset(Dataset):
    """Map-style dataset yielding ``(image, expression_idx, image_path)`` triples.

    Args:
        df: DataFrame with ``filepath`` and ``expression_idx`` columns; ``None``
            is treated as an empty dataset.
        preprocess: Callable turning an RGB PIL image into a tensor.
        fallback_size: Side length of the all-zero ``(3, H, W)`` placeholder
            returned for a missing image file. It should match the size
            ``preprocess`` produces, otherwise the DataLoader cannot stack the
            batch.
    """

    def __init__(self, df, preprocess, fallback_size=224):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess
        self.fallback_size = fallback_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(image_path) as img:
                image = self.preprocess(img.convert("RGB"))
        except FileNotFoundError:
            # Best effort: a black placeholder keeps the batch intact, but the
            # sample is still classified and counted by the caller.
            image = torch.zeros((3, self.fallback_size, self.fallback_size))

        # Explicit long dtype so the label can index Python lists downstream.
        expression_idx = torch.tensor(row["expression_idx"], dtype=torch.long)
        return image, expression_idx, image_path

# ------------------------------
# Main Testing and Plotting Function
# ------------------------------
def _evaluate_identity(model, text_features, test_loader):
    """Classify every batch in ``test_loader`` against the identity text features.

    Returns:
        ``(all_results, y_true, y_pred)``: one dict per sample with
        ``image_path`` / ``ground_truth`` / ``prediction`` / ``is_correct``,
        plus flat lists of ground-truth and predicted class indices.
    """
    keys = CONFIG["identity_keys"]
    all_results, y_true, y_pred = [], [], []
    with torch.no_grad():
        for images, gt_indices, image_paths in tqdm(test_loader, desc="[Testing]"):
            images = images.to(CONFIG["device"])
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = 100.0 * image_features @ text_features.T
            # One device->host copy per batch instead of a sync per sample.
            preds = logits.argmax(dim=-1).cpu().tolist()
            gts = gt_indices.cpu().tolist()

            y_true.extend(gts)
            y_pred.extend(preds)
            for path, gt, pred in zip(image_paths, gts, preds):
                all_results.append({
                    "image_path": path,
                    "ground_truth": keys[gt],
                    "prediction": keys[pred],
                    "is_correct": gt == pred,
                })
    return all_results, y_true, y_pred


def _save_confusion_matrix(y_true, y_pred):
    """Render the identity confusion matrix to ``<plot_save_dir>/confusion_matrix.png``."""
    keys = CONFIG["identity_keys"]
    # Fixed label list keeps every identity on both axes, even if unseen.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(keys))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=keys)

    fig, ax = plt.subplots(figsize=(20, 20))
    disp.plot(ax=ax, xticks_rotation='vertical', cmap='Blues')
    ax.set_title("Identity Classification Confusion Matrix")
    fig.tight_layout()
    cm_save_path = os.path.join(CONFIG["plot_save_dir"], "confusion_matrix.png")
    fig.savefig(cm_save_path)
    plt.close(fig)
    print(f"Confusion matrix saved to: {cm_save_path}")


def _save_individual_plots(all_results):
    """Save one annotated figure per result, stopping after ``CONFIG["num_plots"]`` results."""
    used_names = set()
    for i, result in enumerate(all_results):
        if i >= CONFIG["num_plots"]:
            break
        try:
            img = Image.open(result["image_path"])
        except FileNotFoundError:
            continue

        # imshow copies the pixel data, so the file can be closed right after.
        with img:
            fig, ax = plt.subplots(figsize=(8, 10))
            ax.imshow(img)
        ax.axis("off")

        status = "CORRECT" if result['is_correct'] else "INCORRECT"
        color = 'green' if result['is_correct'] else 'red'
        fig.suptitle(f"Result {i+1}: {status}", fontsize=16, color=color)

        text = (f"Ground Truth: {result['ground_truth']}\n"
                f"Prediction:   {result['prediction']}")
        ax.set_title(text, fontsize=12, pad=10)

        # Name plots after the source image; images from different folders
        # can share a stem, so disambiguate with the result number rather
        # than silently overwriting an earlier plot.
        stem = Path(result["image_path"]).stem
        save_name = f"{stem}_result_{status}.png"
        if save_name in used_names:
            save_name = f"{stem}_result_{status}_{i + 1}.png"
        used_names.add(save_name)

        fig.savefig(os.path.join(CONFIG["plot_save_dir"], save_name), bbox_inches='tight')
        plt.close(fig)


def test_and_plot(model, text_features, preprocess):
    """Evaluate identity classification on ``CONFIG["csv_test"]`` and save plots.

    Prints overall accuracy, writes a confusion matrix, and saves up to
    ``CONFIG["num_plots"]`` per-image result figures into
    ``CONFIG["plot_save_dir"]``.

    Args:
        model: open_clip model providing ``encode_image``.
        text_features: L2-normalised text embeddings, one row per entry of
            ``CONFIG["identity_prompts"]``.
        preprocess: Image transform passed to ``IdentityDataset``.
    """
    print("\n--- Starting Testing and Plotting Phase ---")
    all_df = preprocess_and_cache_csv_identity(CONFIG["csv_test"])
    if all_df is None or all_df.empty:
        print("Halting due to empty or missing test dataset.")
        return

    test_dataset = IdentityDataset(all_df, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    all_results, y_true, y_pred = _evaluate_identity(model, text_features, test_loader)

    total_samples = len(all_results)
    if total_samples == 0:
        print("No valid samples were processed in the test set.")
        return  # Nothing to plot.
    correct_predictions = sum(result["is_correct"] for result in all_results)
    accuracy = (correct_predictions / total_samples) * 100
    print(f"\n📊 Test Accuracy: {accuracy:.2f}%")

    print("\n--- Generating Confusion Matrix ---")
    os.makedirs(CONFIG["plot_save_dir"], exist_ok=True)
    _save_confusion_matrix(y_true, y_pred)

    print(f"\n--- Plotting up to {CONFIG['num_plots']} individual results ---")
    _save_individual_plots(all_results)
    print("--- Plotting complete ---")

# ==============================
#      MAIN EXECUTION BLOCK
# ==============================
if __name__ == '__main__':
    print(f"Using device: {CONFIG['device']}")
    if torch.cuda.is_available():
        print(f"PyTorch Version: {torch.__version__}, CUDA Version: {torch.version.cuda}")

    # open_clip returns (model, train_transform, eval_transform); evaluation
    # uses the eval transform.
    model, _, preprocess = open_clip.create_model_and_transforms(
        CONFIG["model_name"], pretrained=CONFIG["pretrained"], device=CONFIG["device"]
    )

    # Load the fine-tuned weights if the checkpoint exists.
    if os.path.exists(CONFIG["model_path"]):
        print(f"Loading fine-tuned model from: {CONFIG['model_path']}")
        # map_location remaps tensors saved on another device (e.g. a CUDA
        # checkpoint opened on a CPU-only machine) onto the active device.
        # torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(CONFIG["model_path"], map_location=CONFIG["device"])
        model.load_state_dict(state_dict)
    else:
        print(f"Warning: No fine-tuned model found at {CONFIG['model_path']}. Using the pretrained model directly.")

    model.to(CONFIG["device"]).eval()

    tokenizer = open_clip.get_tokenizer(CONFIG["model_name"])

    # Encode every identity prompt once. Rows are L2-normalised so the
    # image/text dot product in test_and_plot is a cosine similarity.
    with torch.no_grad():
        text_tokens = tokenizer(CONFIG["identity_prompts"]).to(CONFIG["device"])
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    test_and_plot(model, text_features, preprocess)


Identity Keys: ['Soekarno', 'Suharto', 'Baharuddin Jusuf Habibie', 'Abdurrahman Wahid', 'Megawati Sukarnoputri', 'Susilo Bambang Yudhoyono', 'Joko Widodo', 'Prabowo Subianto', 'Anies Rasyid Baswedan', 'Ganjar Pranowo', 'Gibran Rakabuming Raka', 'Maruf Amin', 'Airlangga Hartarto', 'Sri Mulyani Indrawati', 'Erick Thohir', 'Agus Harimurti Yudhoyono', 'Muhaimin Iskandar', 'Mahfud MD', 'Boediono', 'Jusuf Kalla']
Gender Keys: ['male', 'female']
['teenager', 'young adult', 'middle-aged', 'late adult', 'elderly']
['anger', 'contempt', 'disgust', 'happiness', 'fear', 'sadness', 'surprise', 'neutral']
Using device: cuda
PyTorch Version: 2.8.0.dev20250507+cu128, CUDA Version: 12.8
Loading fine-tuned model from: finetuned_identity_only_best_ViT-B-16_openai_7_11.pt

--- Starting Testing and Plotting Phase ---
Loading preprocessed identity data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/all_dataset.identity_cached.csv


[Testing]: 100%|██████████| 32/32 [00:10<00:00,  3.07it/s]



📊 Test Accuracy: 98.80%

--- Generating Confusion Matrix ---
Confusion matrix saved to: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/zero_shot_face_recognition/fourth_discussion/finetune_all_dataset_identity_7_11\confusion_matrix.png

--- Plotting up to 1000 individual results ---
--- Plotting complete ---
