In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import textwrap
import re

# ------------------------------
# DEBUGGING & CONFIGURATION
# ------------------------------
# IMPORTANT: This forces CUDA operations to be synchronous.
# It makes the code slower but provides a correct stack trace if a CUDA error occurs.
# This is essential for debugging the "unknown error".
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

CONFIG = {
    "model_name": "ViT-H-14",
    "pretrained": "laion2b_s32b_b79k",
    "csv_train": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.csv",
    "csv_val": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.csv",
    "csv_test": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.csv",
    "save_path": "finetuned_multi_attribute_best.pt",
    "plot_save_dir": "test_results_multi_attribute",
    "num_plots": 20,
    "batch_size": 4, # Adjust based on your VRAM
    "epochs": 5,
    "lr": 1e-6, # Lower LR is often better for fine-tuning large models
    "patience": 3,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # --- Prompts must match your final inference script ---
    "identity_prompts": [
        "A face of Soekarno, a male First President (1945-1967) from Indonesia.",
        "A face of Suharto, a male Second President (1967-1998) from Indonesia.",
        "A face of Baharuddin Jusuf Habibie, a male Third President (1998-1999) from Indonesia.",
        "A face of Abdurrahman Wahid, a male Fourth President (1999-2001) from Indonesia.",
        "A face of Megawati Sukarnoputri, a female Fifth President (2001-2004) from Indonesia.",
        "A face of Susilo Bambang Yudhoyono, a male Sixth President (2004-2014) from Indonesia.",
        "A face of Joko Widodo, a male Seventh President (2014-2024) from Indonesia.",
        "A face of Prabowo Subianto, a male Eight President (2024-Present) from Indonesia.",
        "A face of Anies Rasyid Baswedan, a male Governor of Jakarta (2017-2022) and Presidential Candidate Election (2024) from Indonesia.",
        "A face of Ganjar Pranowo, a male Governor of Central Java (2013-2023) and Presidential Candidate Election (2024) from Indonesia.",
        "A face of Gibran Rakabuming Raka, a male Vice President (2024-2029) from Indonesia.",
        "A face of Maruf Amin, a male Vice President (2019-2024) from Indonesia.",
        "A face of Airlangga Hartarto, a male Coordinating Minister of Economic Affairs (2024-2029) from Indonesia.",
        "A face of Sri Mulyani Indrawati, a female Minister of Finance (2024-2029) from Indonesia.",
        "A face of Erick Thohir, a male Minister of State Owned Entreprises (2024-2029) from Indonesia.",
        "A face of Agus Harimurti Yudhoyono, a male Coordinating Minister of Agrarian Affairs and Spatial Planning (2024-2029) and Chairman of Democratic Party from Indonesia.",
        "A face of Muhaimin Iskandar, a male Coordinating Minister of Social Empowrement (2024-2029) and Chairman of National Awakening Party from Indonesia.",
        "A face of Mahfud MD, a male Coordinating Minister of Political, Legal, and Security Affairs (2019-2024) from Indonesia.",
        "A face of Boediono, a male Vice President (2009-2014) from Indonesia.",
        "A face of Jusuf Kalla, a male Vice President (2004-2009) and Vice President (2014-2019) from Indonesia.",
    ],
    "age_prompts": ["a teenager", "a young adult", "a middle-aged person", "a late adult", "an elderly person"],
    "gender_prompts": ["a male", "a female"],
    "expression_prompts": [
        "a person showing anger", "a person showing contempt", "a person showing disgust",
        "a person showing happiness", "a person showing fear", "a person showing sadness",
        "a person showing surprise", "a person with a neutral expression"
    ]
}

# --- Create simplified keys from prompts for mapping ---
CONFIG["identity_keys"] = [re.search(r"of (.*?),", p).group(1) for p in CONFIG["identity_prompts"]]
CONFIG["age_keys"] = [p.replace("a ", "").replace("an ", "") for p in CONFIG["age_prompts"]]
CONFIG["gender_keys"] = ["male", "female"]
CONFIG["expression_keys"] = [p.split(" showing ")[-1].split(" with a ")[-1].replace(" expression", "") for p in CONFIG["expression_prompts"]]

# ------------------------------
# Helper Functions
# ------------------------------
def parse_training_prompt(prompt_text):
    """Parses various prompt templates to extract name, age_group, and expression."""
    patterns = [
        (r"(.*?) (?:male|female) named (.*?) with (.*?) expression\.", ("age_group", "name", "expression")),
        (r"(.*?) is (.*?) (?:male|female) showing (.*?) face\.", ("name", "age_group", "expression")),
        (r"portrait of (.*?), (.*?) (?:male|female) who looks (.*?)\.", ("name", "age_group", "expression")),
        (r"face of (.*?), (.*?) (?:male|female), expressing (.*?)\.", ("name", "age_group", "expression")),
        (r"(.*?), (.*?) (?:male|female), with (.*?) look\.", ("name", "age_group", "expression")),
        (r"the (.*?) face of (.*?), (.*?) (?:male|female)\.", ("expression", "name", "age_group")),
        (r"(.*?) looks (.*?), is (.*?) (?:male|female)\.", ("name", "expression", "age_group")),
        (r"a photo of (.*?), a (.*?) year old person, with a (.*?) expression", ("name", "age_group", "expression")),
    ]
    for pattern, keys in patterns:
        match = re.search(pattern, prompt_text, re.IGNORECASE)
        if match:
            return {keys[i]: match.group(i + 1).strip() for i in range(len(keys))}
    return None

def preprocess_and_cache_csv(csv_path):
    """
    Parses prompts in a CSV to get ground truth indices and saves a cached version.
    If the cached version exists, it loads it directly.
    """
    cache_path = csv_path.replace(".csv", ".cached.csv")
    if os.path.exists(cache_path):
        print(f"Loading preprocessed data from cache: {cache_path}")
        try:
            # Also check if cached file is empty
            df_cache = pd.read_csv(cache_path)
            if df_cache.empty:
                print(f"Warning: Cached file is empty. Reprocessing: {csv_path}")
            else:
                return df_cache
        except pd.errors.EmptyDataError:
            print(f"Warning: Cached file is empty. Reprocessing: {csv_path}")

    print(f"Preprocessing and caching data from: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
        if df.empty:
            print(f"Warning: Original CSV file is empty: {csv_path}")
            return None
    except FileNotFoundError:
        print(f"Error: Original CSV not found at {csv_path}")
        return None
    except pd.errors.EmptyDataError:
        print(f"Error: No columns to parse from file. It is empty: {csv_path}")
        return None

    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        attrs = parse_training_prompt(row['prompt'])
        if attrs:
            row_data = {
                'filepath': row['filepath'],
                'prompt': row['prompt'],
                'identity_idx': CONFIG["identity_keys"].index(attrs["name"]) if attrs.get("name") in CONFIG["identity_keys"] else -1,
                'gender_idx': 0 if 'male' in row['prompt'].lower() else 1,
                'age_idx': CONFIG["age_keys"].index(attrs["age_group"]) if attrs.get("age_group") in CONFIG["age_keys"] else -1,
                'expression_idx': CONFIG["expression_keys"].index(attrs["expression"]) if attrs.get("expression") in CONFIG["expression_keys"] else -1
            }
            new_data.append(row_data)

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset. Check your prompts and templates.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    print(f"Saved cached data to: {cache_path}")
    return cached_df

# ------------------------------
# Dataset Class
# ------------------------------
class FaceAttributeDataset(Dataset):
    def __init__(self, df, preprocess):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.preprocess(image)
        except FileNotFoundError:
            image = torch.zeros((3, 224, 224))

        gt_indices = torch.tensor([
            row["identity_idx"], row["gender_idx"],
            row["age_idx"], row["expression_idx"]
        ])

        return image, gt_indices, image_path, row["prompt"]

# ------------------------------
# Main Training & Testing Functions
# ------------------------------
def train_and_validate(model, train_loader, val_loader, all_text_features):
    """Encapsulates the entire training and validation loop."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) # Ignore samples where an attribute wasn't found
    best_val_loss = float("inf")
    patience_counter = 0
    model_saved = False

    for epoch in range(CONFIG["epochs"]):
        # --- TRAINING ---
        model.train()
        total_loss_sum = 0
        for images, gt_indices, _, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]"):
            images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])

            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ all_text_features.T)

            offset = 0
            identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
            gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
            age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
            expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

            loss_identity = loss_fn(identity_logits, gt_indices[:, 0])
            loss_gender = loss_fn(gender_logits, gt_indices[:, 1])
            loss_age = loss_fn(age_logits, gt_indices[:, 2])
            loss_expression = loss_fn(expression_logits, gt_indices[:, 3])

            total_loss = loss_identity + loss_gender + loss_age + loss_expression

            optimizer.zero_grad(); total_loss.backward(); optimizer.step()
            total_loss_sum += total_loss.item()

        avg_train_loss = total_loss_sum / len(train_loader)

        # --- VALIDATION ---
        model.eval()
        total_val_loss_sum = 0
        with torch.no_grad():
            for images, gt_indices, _, _ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]"):
                images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits = (100.0 * image_features @ all_text_features.T)

                offset = 0
                identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
                gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
                age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
                expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

                loss_identity = loss_fn(identity_logits, gt_indices[:, 0])
                loss_gender = loss_fn(gender_logits, gt_indices[:, 1])
                loss_age = loss_fn(age_logits, gt_indices[:, 2])
                loss_expression = loss_fn(expression_logits, gt_indices[:, 3])
                total_loss = loss_identity + loss_gender + loss_age + loss_expression
                total_val_loss_sum += total_loss.item()

        avg_val_loss = total_val_loss_sum / len(val_loader)
        print(f"✅ Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), CONFIG["save_path"])
            print(f"🎉 Saved best model to {CONFIG['save_path']}")
            patience_counter = 0
            model_saved = True # Flag that we have successfully saved a model
        else:
            patience_counter += 1
            print(f"⚠️ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
        if patience_counter >= CONFIG["patience"]:
            print("🛑 Early stopping triggered."); break

    return model_saved


def test_and_plot(model, all_text_features, preprocess):
    """Encapsulates the testing and plotting logic."""
    print("\n--- Starting Final Testing and Plotting Phase ---")
    model.load_state_dict(torch.load(CONFIG["save_path"])); model.to(CONFIG["device"]).eval()
    print("Best model loaded.")

    test_df = preprocess_and_cache_csv(CONFIG["csv_test"])
    if test_df is None or test_df.empty:
        print(f"Error or empty data in test CSV: {CONFIG['csv_test']}. Halting testing.")
        return

    test_dataset = FaceAttributeDataset(test_df, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=0)
    print(f"Found {len(test_dataset)} images in the test set.")

    attribute_correct, total_samples = 0, 0
    all_results = []

    with torch.no_grad():
        for images, _, image_paths, texts in tqdm(test_loader, desc="[Testing]"):
            images = images.to(CONFIG["device"])
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ all_text_features.T)
            offset = 0
            identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
            gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
            age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
            expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

            pred_identity_indices = identity_logits.argmax(dim=-1)
            pred_gender_indices = gender_logits.argmax(dim=-1)
            pred_age_indices = age_logits.argmax(dim=-1)
            pred_expression_indices = expression_logits.argmax(dim=-1)

            for i in range(len(images)):
                gt_attrs = parse_training_prompt(texts[i])
                if not gt_attrs: continue

                pred_attrs = {
                    "name": CONFIG["identity_keys"][pred_identity_indices[i]],
                    "gender": CONFIG["gender_keys"][pred_gender_indices[i]],
                    "age_group": CONFIG["age_keys"][pred_age_indices[i]],
                    "expression": CONFIG["expression_keys"][pred_expression_indices[i]],
                }

                is_correct = (gt_attrs.get("name") == pred_attrs.get("name") and
                              gt_attrs.get("age_group") == pred_attrs.get("age_group") and
                              gt_attrs.get("expression") == pred_attrs.get("expression"))

                if is_correct: attribute_correct += 1

                all_results.append({
                    "image_path": image_paths[i], "gt_attrs": gt_attrs,
                    "pred_attrs": pred_attrs, "is_correct": is_correct
                })
                total_samples += 1

    if total_samples > 0:
        attr_accuracy = (attribute_correct / total_samples) * 100
        print("\n--- Test Results ---")
        print(f"📊 Attribute Accuracy: {attr_accuracy:.2f}% (Correct if Name, Age, and Expression all match)")
        print("--------------------")
    else:
        print("No valid samples were processed in the test set.")

    # --- PLOTTING ---
    print(f"\n--- Plotting up to {CONFIG['num_plots']} results ---")
    os.makedirs(CONFIG["plot_save_dir"], exist_ok=True)
    print(f"Saving plots to '{os.path.abspath(CONFIG['plot_save_dir'])}'")

    for i, result in enumerate(all_results):
        if i >= CONFIG["num_plots"]:
            print(f"Reached plot limit of {CONFIG['num_plots']}."); break
        try:
            img = Image.open(result["image_path"])
        except FileNotFoundError:
            print(f"Warning: Could not find image {result['image_path']} for plotting. Skipping."); continue

        fig, ax = plt.subplots(figsize=(10, 12))
        ax.imshow(img); ax.axis("off")

        title = f"Result {i+1}: {'CORRECT' if result['is_correct'] else 'INCORRECT'}"
        fig.suptitle(title, fontsize=18, color='green' if result['is_correct'] else 'red', y=0.95)

        gt_attrs_str = (f"Ground Truth:\n"
                        f"  - Name: {result['gt_attrs'].get('name', 'N/A')}\n"
                        f"  - Age: {result['gt_attrs'].get('age_group', 'N/A')}\n"
                        f"  - Expression: {result['gt_attrs'].get('expression', 'N/A')}")

        pred_attrs_str = (f"Prediction:\n"
                          f"  - Name: {result['pred_attrs'].get('name', 'N/A')}\n"
                          f"  - Age: {result['pred_attrs'].get('age_group', 'N/A')}\n"
                          f"  - Expression: {result['pred_attrs'].get('expression', 'N/A')}")

        plt.figtext(0.1, 0.02, gt_attrs_str, ha="left", fontsize=12, wrap=True, va="bottom")
        plt.figtext(0.9, 0.02, pred_attrs_str, ha="right", fontsize=12, wrap=True, va="bottom",
                    color='green' if result['is_correct'] else 'red')

        plt.tight_layout(rect=[0, 0.1, 1, 0.9])

        save_name = f"result_{i+1}_{'correct' if result['is_correct'] else 'incorrect'}.png"
        plt.savefig(os.path.join(CONFIG["plot_save_dir"], save_name), bbox_inches='tight')
        plt.close(fig)

    print("--- Plotting complete ---")

# ==============================
#      MAIN EXECUTION BLOCK
# ==============================
if __name__ == '__main__':
    print(f"Using device: {CONFIG['device']}")
    # --- ADDED: Version checking for easier debugging ---
    print(f"PyTorch Version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"cuDNN Version: {torch.backends.cudnn.version()}")

    # --- Setup Model and Pre-compute Text Features ---
    model, _, preprocess = open_clip.create_model_and_transforms(
        CONFIG["model_name"], pretrained=CONFIG["pretrained"], device=CONFIG["device"]
    )
    tokenizer = open_clip.get_tokenizer(CONFIG["model_name"])
    all_prompts = CONFIG["identity_prompts"] + CONFIG["gender_prompts"] + CONFIG["age_prompts"] + CONFIG["expression_prompts"]
    all_text_tokens = tokenizer(all_prompts).to(CONFIG["device"])
    with torch.no_grad():
        all_text_features = model.encode_text(all_text_tokens)
        all_text_features = all_text_features / all_text_features.norm(dim=-1, keepdim=True)
    print("Model and text features loaded.")

    # --- Preprocess Data and Create Dataloaders ---
    train_df = preprocess_and_cache_csv(CONFIG["csv_train"])
    val_df = preprocess_and_cache_csv(CONFIG["csv_val"])

    if train_df is not None and val_df is not None and not train_df.empty and not val_df.empty:
        train_dataset = FaceAttributeDataset(train_df, preprocess)
        val_dataset = FaceAttributeDataset(val_df, preprocess)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=0)

        # --- Run Training ---
        training_successful = train_and_validate(model, train_loader, val_loader, all_text_features)

        # --- Run Testing only if Training was Successful ---
        if training_successful:
            test_and_plot(model, all_text_features, preprocess)
        else:
            print("\nSkipping testing phase because no model was saved during training.")

    else:
        print("\nTraining/validation datasets are empty or could not be loaded. Skipping training and testing.")


Using device: cuda
PyTorch Version: 2.8.0.dev20250507+cu128
CUDA Version: 12.8
cuDNN Version: 90701
Model and text features loaded.
Loading preprocessed data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.cached.csv
Loading preprocessed data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.cached.csv


Epoch 1/5 [Train]:   2%|▏         | 4/175 [09:10<6:32:07, 137.59s/it]


KeyboardInterrupt: 