In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import textwrap
import re

# ------------------------------
# DEBUGGING & CONFIGURATION
# ------------------------------
# This forces CUDA operations to be synchronous for accurate error reporting.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

CONFIG = {
    "model_name": "ViT-B-16",
    "pretrained": "datacomp_xl_s13b_b90k",
    "csv_train": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.csv",
    "csv_val": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.csv",
    "csv_test": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.csv",
    "save_path": "finetuned_multi_attribute_final_consistent.pt",
    "plot_save_dir": "test_results_multi_attribute_final",
    "num_plots": 100,
    "batch_size": 8, # Start small for large models to avoid CUDA OOM errors
    "epochs": 100,
    "lr": 1e-6,
    "patience": 5,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    "identity_prompts": [
        "A photo of soekarno.", "A photo of suharto.", "A photo of baharuddin jusuf habibie.",
        "A photo of abdurrahman wahid.", "A photo of megawati sukarnoputri.",
        "A photo of susilo bambang yudhoyono.", "A photo of joko widodo.",
        "A photo of prabowo subianto.", "A photo of anies rasyid baswedan.",
        "A photo of ganjar pranowo.", "A photo of gibran rakabuming raka.",
        "A photo of maruf amin.", "A photo of airlangga hartarto.",
        "A photo of sri mulyani indrawati.", "A photo of erick thohir.",
        "A photo of agus harimurti yudhoyono.", "A photo of muhaimin iskandar.",
        "A photo of mahfud md.", "A photo of boediono", "A photo of jusuf kalla"
    ],
    "age_prompts": ["a photo of a teenager.", "a photo of a young adult.", "a photo of a middle-aged person.", "a photo of a late adult.", "a photo of an elderly person."],
    "gender_prompts": ["a photo of a male person.", "a photo of a female person."],
    "expression_prompts": [
        "a photo of a person with an anger expression.", "a photo of a person with a contempt expression.",
        "a photo of a person with a disgust expression.", "a photo of a person with a happiness expression.",
        "a photo of a person with a fear expression.", "a photo of a person with a sadness expression.",
        "a photo of a person with a surprise expression.", "a photo of a person with a neutral expression."
    ]
}

# --- Create simplified keys from prompts for mapping ---
CONFIG["identity_keys"] = [p.replace("A photo of ", "").replace(".", "") for p in CONFIG["identity_prompts"]]
CONFIG["age_keys"] = [p.replace("a photo of ", "").replace("a ", "").replace("an ", "").replace(" person", "").replace(".", "") for p in CONFIG["age_prompts"]]
CONFIG["gender_keys"] = ["male", "female"]
CONFIG["expression_keys"] = [p.split(" with ")[-1].replace("a ", "").replace("an ", "").replace(" expression.", "") for p in CONFIG["expression_prompts"]]
# print(CONFIG["identity_keys"], CONFIG["age_keys"], CONFIG["gender_keys"], CONFIG["expression_keys"])
# ------------------------------
# Helper Functions
# ------------------------------
def parse_training_prompt(prompt_text):
    """Parses a consistent prompt template to extract all attributes."""
    # This regex is specifically for the format: "Name, gender, age_group, expression."
    pattern = r"^(.*?),\s*(male|female),\s*(.*?),\s*(.*?)\.$"
    match = re.search(pattern, prompt_text.strip(), re.IGNORECASE)

    if match:
        attrs = {
            "name": match.group(1).strip().lower(),
            "gender": match.group(2).strip().lower(),
            "age_group": match.group(3).strip().lower(),
            "expression": match.group(4).strip().lower()
        }
        return attrs
    return None

def preprocess_and_cache_csv(csv_path):
    """Parses all attributes and saves a single cached version."""
    cache_path = csv_path.replace(".csv", ".multi_attribute_cached.csv")
    if os.path.exists(cache_path):
        print(f"Loading preprocessed data from cache: {cache_path}")
        try:
            df = pd.read_csv(cache_path)
            if df.empty:
                 print("Warning: Cached file is empty. Reprocessing.")
            else:
                return df
        except pd.errors.EmptyDataError:
            print(f"Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing and caching data from: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
        if df.empty:
            print(f"Warning: Original CSV file is empty: {csv_path}")
            return None
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        attrs = parse_training_prompt(row['prompt'])
        if attrs:
            row_data = {
                'filepath': row['filepath'],
                'prompt': row['prompt'],
                'identity_idx': CONFIG["identity_keys"].index(attrs["name"]) if attrs.get("name") in CONFIG["identity_keys"] else -1,
                'gender_idx': CONFIG["gender_keys"].index(attrs["gender"]) if attrs.get("gender") in CONFIG["gender_keys"] else -1,
                'age_idx': CONFIG["age_keys"].index(attrs["age_group"]) if attrs.get("age_group") in CONFIG["age_keys"] else -1,
                'expression_idx': CONFIG["expression_keys"].index(attrs["expression"]) if attrs.get("expression") in CONFIG["expression_keys"] else -1
            }
            new_data.append(row_data)

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset. Check prompt templates and keys.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df


# ------------------------------
# Dataset Class
# ------------------------------
class MultiAttributeDataset(Dataset):
    def __init__(self, df, preprocess):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.preprocess(image)
        except FileNotFoundError:
            image = torch.zeros((3, 224, 224))

        gt_indices = torch.tensor([
            row["identity_idx"], row["gender_idx"],
            row["age_idx"], row["expression_idx"]
        ])

        return image, gt_indices, image_path

# ------------------------------
# Main Training & Testing Functions
# ------------------------------
def train_and_validate(model, train_loader, val_loader, all_text_features):
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    # --- ADDED: ignore_index=-1 tells the loss function to skip samples that couldn't be parsed.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
    best_val_loss = float("inf")
    patience_counter = 0
    model_saved = False

    for epoch in range(CONFIG["epochs"]):
        model.train()
        total_loss_sum = 0
        for images, gt_indices, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]"):
            images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])

            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ all_text_features.T)

            offset = 0
            identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
            gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
            age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
            expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

            loss_identity = loss_fn(identity_logits, gt_indices[:, 0])
            loss_gender = loss_fn(gender_logits, gt_indices[:, 1])
            loss_age = loss_fn(age_logits, gt_indices[:, 2])
            loss_expression = loss_fn(expression_logits, gt_indices[:, 3])

            total_loss = loss_identity + loss_gender + loss_age + loss_expression

            optimizer.zero_grad()
            total_loss.backward()

            # --- ADDED: Gradient Clipping to prevent exploding gradients and NaN loss ---
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            total_loss_sum += total_loss.item()

        avg_train_loss = total_loss_sum / len(train_loader)

        model.eval()
        total_val_loss_sum = 0
        with torch.no_grad():
            for images, gt_indices, _ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]"):
                images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits = (100.0 * image_features @ all_text_features.T)

                offset = 0
                identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
                gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
                age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
                expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

                loss_identity = loss_fn(identity_logits, gt_indices[:, 0])
                loss_gender = loss_fn(gender_logits, gt_indices[:, 1])
                loss_age = loss_fn(age_logits, gt_indices[:, 2])
                loss_expression = loss_fn(expression_logits, gt_indices[:, 3])
                total_loss = loss_identity + loss_gender + loss_age + loss_expression
                total_val_loss_sum += total_loss.item()

        avg_val_loss = total_val_loss_sum / len(val_loader)
        print(f"✅ Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), CONFIG["save_path"])
            print(f"🎉 Saved best model to {CONFIG['save_path']}")
            patience_counter = 0
            model_saved = True
        else:
            patience_counter += 1
            print(f"⚠️ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
        if patience_counter >= CONFIG["patience"]:
            print("🛑 Early stopping triggered."); break

    return model_saved

def test_and_plot(model, all_text_features, preprocess):
    print("\n--- Starting Final Testing and Plotting Phase ---")
    model.load_state_dict(torch.load(CONFIG["save_path"])); model.to(CONFIG["device"]).eval()
    print("Best model loaded.")

    test_df = preprocess_and_cache_csv(CONFIG["csv_test"])
    if test_df is None or test_df.empty:
        print(f"Error or empty data in test CSV. Halting testing.")
        return

    test_dataset = MultiAttributeDataset(test_df, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    attribute_correct, total_samples = 0, 0
    all_results = []

    with torch.no_grad():
        for images, gt_indices, image_paths in tqdm(test_loader, desc="[Testing]"):
            images = images.to(CONFIG["device"])
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ all_text_features.T)
            offset = 0
            identity_logits = logits[:, offset:offset+len(CONFIG["identity_prompts"])]; offset += len(CONFIG["identity_prompts"])
            gender_logits = logits[:, offset:offset+len(CONFIG["gender_prompts"])]; offset += len(CONFIG["gender_prompts"])
            age_logits = logits[:, offset:offset+len(CONFIG["age_prompts"])]; offset += len(CONFIG["age_prompts"])
            expression_logits = logits[:, offset:offset+len(CONFIG["expression_prompts"])]

            pred_identity_indices = identity_logits.argmax(dim=-1)
            pred_gender_indices = gender_logits.argmax(dim=-1)
            pred_age_indices = age_logits.argmax(dim=-1)
            pred_expression_indices = expression_logits.argmax(dim=-1)

            for i in range(len(images)):
                gt_attrs = {
                    "name": CONFIG["identity_keys"][gt_indices[i, 0]],
                    "gender": CONFIG["gender_keys"][gt_indices[i, 1]],
                    "age_group": CONFIG["age_keys"][gt_indices[i, 2]],
                    "expression": CONFIG["expression_keys"][gt_indices[i, 3]],
                }

                pred_attrs = {
                    "name": CONFIG["identity_keys"][pred_identity_indices[i]],
                    "gender": CONFIG["gender_keys"][pred_gender_indices[i]],
                    "age_group": CONFIG["age_keys"][pred_age_indices[i]],
                    "expression": CONFIG["expression_keys"][pred_expression_indices[i]],
                }

                is_correct = (gt_attrs["name"] == pred_attrs["name"] and
                              gt_attrs["age_group"] == pred_attrs["age_group"] and
                              gt_attrs["expression"] == pred_attrs["expression"])

                if is_correct: attribute_correct += 1

                all_results.append({
                    "image_path": image_paths[i], "gt_attrs": gt_attrs,
                    "pred_attrs": pred_attrs, "is_correct": is_correct
                })
                total_samples += 1

    if total_samples > 0:
        attr_accuracy = (attribute_correct / total_samples) * 100
        print("\n--- Test Results ---")
        print(f"📊 Attribute Accuracy: {attr_accuracy:.2f}% (Correct if Name, Age, and Expression all match)")
    else:
        print("No valid samples were processed in the test set.")

    print(f"\n--- Plotting up to {CONFIG['num_plots']} results ---")
    os.makedirs(CONFIG["plot_save_dir"], exist_ok=True)

    for i, result in enumerate(all_results):
        if i >= CONFIG["num_plots"]: break
        try:
            img = Image.open(result["image_path"])
        except FileNotFoundError: continue

        fig, ax = plt.subplots(figsize=(10, 12))
        ax.imshow(img); ax.axis("off")

        title = f"Result {i+1}: {'CORRECT' if result['is_correct'] else 'INCORRECT'}"
        fig.suptitle(title, fontsize=18, color='green' if result['is_correct'] else 'red', y=0.95)

        gt_attrs_str = (f"Ground Truth:\n"
                        f"  - Name: {result['gt_attrs'].get('name', 'N/A')}\n"
                        f"  - Age: {result['gt_attrs'].get('age_group', 'N/A')}\n"
                        f"  - Expression: {result['gt_attrs'].get('expression', 'N/A')}")

        pred_attrs_str = (f"Prediction:\n"
                          f"  - Name: {result['pred_attrs'].get('name', 'N/A')}\n"
                          f"  - Age: {result['pred_attrs'].get('age_group', 'N/A')}\n"
                          f"  - Expression: {result['pred_attrs'].get('expression', 'N/A')}")

        plt.figtext(0.1, 0.02, gt_attrs_str, ha="left", fontsize=12, wrap=True, va="bottom")
        plt.figtext(0.9, 0.02, pred_attrs_str, ha="right", fontsize=12, wrap=True, va="bottom",
                    color='green' if result['is_correct'] else 'red')

        plt.tight_layout(rect=[0, 0.1, 1, 0.9])

        save_name = f"result_{i+1}_{'correct' if result['is_correct'] else 'incorrect'}.png"
        plt.savefig(os.path.join(CONFIG["plot_save_dir"], save_name), bbox_inches='tight')
        plt.close(fig)

    print("--- Plotting complete ---")

# ==============================
#      MAIN EXECUTION BLOCK
# ==============================
if __name__ == '__main__':
    print(f"Using device: {CONFIG['device']}")
    if torch.cuda.is_available():
        print(f"PyTorch Version: {torch.__version__}, CUDA Version: {torch.version.cuda}")

    model, _, preprocess = open_clip.create_model_and_transforms(
        CONFIG["model_name"], pretrained=CONFIG["pretrained"], device=CONFIG["device"]
    )
    tokenizer = open_clip.get_tokenizer(CONFIG["model_name"])

    with torch.no_grad():
        all_prompts = CONFIG["identity_prompts"] + CONFIG["gender_prompts"] + CONFIG["age_prompts"] + CONFIG["expression_prompts"]
        all_text_tokens = tokenizer(all_prompts).to(CONFIG["device"])
        all_text_features = model.encode_text(all_text_tokens)
        all_text_features = all_text_features / all_text_features.norm(dim=-1, keepdim=True)

    train_df = preprocess_and_cache_csv(CONFIG["csv_train"])
    val_df = preprocess_and_cache_csv(CONFIG["csv_val"])

    if train_df is not None and val_df is not None and not train_df.empty and not val_df.empty:
        train_dataset = MultiAttributeDataset(train_df, preprocess)
        val_dataset = MultiAttributeDataset(val_df, preprocess)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

        training_successful = train_and_validate(model, train_loader, val_loader, all_text_features)

        if training_successful:
            test_and_plot(model, all_text_features, preprocess)
        else:
            print("\nSkipping testing phase: No model was saved during training.")
    else:
        print("\nSkipping training: Training/validation datasets are empty or could not be loaded.")


Using device: cuda
PyTorch Version: 2.8.0.dev20250507+cu128, CUDA Version: 12.8
Loading preprocessed data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.multi_attribute_cached.csv
Loading preprocessed data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.multi_attribute_cached.csv


Epoch 1/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  5.88it/s]
Epoch 1/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.04it/s]


✅ Epoch 1: Train Loss = 5.1747, Val Loss = 4.3018
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 2/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.00it/s]
Epoch 2/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.16it/s]


✅ Epoch 2: Train Loss = 3.5331, Val Loss = 3.7214
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 3/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.08it/s]
Epoch 3/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.05it/s]


✅ Epoch 3: Train Loss = 2.6807, Val Loss = 3.4568
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 4/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.08it/s]
Epoch 4/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.16it/s]


✅ Epoch 4: Train Loss = 2.0048, Val Loss = 3.1154
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 5/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.13it/s]
Epoch 5/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.31it/s]


✅ Epoch 5: Train Loss = 1.4758, Val Loss = 3.0992
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 6/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.16it/s]
Epoch 6/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.31it/s]


✅ Epoch 6: Train Loss = 1.0434, Val Loss = 3.0526
🎉 Saved best model to finetuned_multi_attribute_final_consistent.pt


Epoch 7/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.12it/s]
Epoch 7/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.05it/s]


✅ Epoch 7: Train Loss = 0.7179, Val Loss = 3.1969
⚠️ No improvement. Patience: 1/5


Epoch 8/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.17it/s]
Epoch 8/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.12it/s]


✅ Epoch 8: Train Loss = 0.4566, Val Loss = 3.1324
⚠️ No improvement. Patience: 2/5


Epoch 9/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.11it/s]
Epoch 9/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.52it/s]


✅ Epoch 9: Train Loss = 0.2906, Val Loss = 3.3067
⚠️ No improvement. Patience: 3/5


Epoch 10/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.13it/s]
Epoch 10/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.13it/s]


✅ Epoch 10: Train Loss = 0.1748, Val Loss = 3.3325
⚠️ No improvement. Patience: 4/5


Epoch 11/100 [Train]: 100%|██████████| 88/88 [00:14<00:00,  6.11it/s]
Epoch 11/100 [Val]: 100%|██████████| 25/25 [00:02<00:00, 11.18it/s]


✅ Epoch 11: Train Loss = 0.0985, Val Loss = 3.4741
⚠️ No improvement. Patience: 5/5
🛑 Early stopping triggered.

--- Starting Final Testing and Plotting Phase ---
Best model loaded.
Loading preprocessed data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.multi_attribute_cached.csv


[Testing]: 100%|██████████| 13/13 [00:01<00:00, 11.98it/s]



--- Test Results ---
📊 Attribute Accuracy: 43.00% (Correct if Name, Age, and Expression all match)

--- Plotting up to 100 results ---
--- Plotting complete ---
