In [4]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import textwrap
import re

# ------------------------------
# DEBUGGING & CONFIGURATION
# ------------------------------
# Synchronous CUDA launches make a device-side error surface at the call that caused it.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

# Single source of truth for every tunable used by this script.
CONFIG = dict(
    # Backbone architecture and pretrained weight tag handed to open_clip.
    model_name="ViT-B-16",
    pretrained="laion2b_s34b_b88k",
    # Annotation CSVs; the preprocessing step reads their 'filepath' and 'prompt' columns.
    csv_train="C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.csv",
    csv_val="C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.csv",
    csv_test="C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.csv",
    # Where the best checkpoint and the per-image result plots are written.
    save_path="finetuned_gender_only_best.pt",
    plot_save_dir="test_results_gender_only",
    num_plots=100,
    # Optimisation settings; 'patience' is the early-stopping budget in epochs.
    batch_size=32,
    epochs=100,
    lr=1e-6,
    patience=5,
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Text prompts scored against every image, one per class.
    gender_prompts=[
        "a photo of a male person.",
        "a photo of a female person.",
    ],
    # Short class labels, index-aligned with gender_prompts.
    gender_keys=["male", "female"],
)

# ------------------------------
# Helper Function for Parsing
# ------------------------------
def parse_gender_from_prompt(prompt_text):
    """
    Return the first gender key that occurs as a whole word in the prompt.

    Keys are tried in the order they appear in CONFIG["gender_keys"] and the
    match is case-insensitive. The word boundaries keep "male" from matching
    inside "female".

    Args:
        prompt_text: Prompt string to inspect. Any non-string value (for
            example the float NaN pandas produces for an empty CSV cell) is
            treated as containing no gender instead of raising TypeError.

    Returns:
        The matching entry of CONFIG["gender_keys"], or None when no key is
        found or the input is not a string.
    """
    if not isinstance(prompt_text, str):
        return None

    for gender_key in CONFIG["gender_keys"]:
        # re.escape keeps a key containing regex metacharacters literal.
        if re.search(r'\b' + re.escape(gender_key) + r'\b', prompt_text, re.IGNORECASE):
            return gender_key
    return None

# def parse_gender_from_prompt(prompt_text):
#     """
#     A simple parser to find 'male' or 'female' in the prompt.
#     """
#     prompt_lower = prompt_text.lower()
#     if 'male' in prompt_lower:
#         return 'male'
#     elif 'female' in prompt_lower:
#         return 'female'
#     return None

# ------------------------------
# Preprocessing and Caching
# ------------------------------
def _gender_cache_path(csv_path):
    """Return the sidecar cache path for csv_path: '<path without extension>.gender_cached.csv'."""
    # splitext only touches the final extension, so a ".csv" elsewhere in the
    # path is left alone and a path without ".csv" can never map onto itself.
    stem, _ext = os.path.splitext(csv_path)
    return stem + ".gender_cached.csv"


def _cache_is_fresh(cache_path, csv_path):
    """Return True when the cache file exists and is not older than its source CSV."""
    try:
        cache_mtime = os.path.getmtime(cache_path)
    except OSError:
        return False
    try:
        return cache_mtime >= os.path.getmtime(csv_path)
    except OSError:
        # The source cannot be inspected; the cache is the only data available.
        return True


def preprocess_and_cache_csv(csv_path):
    """
    Parse the gender label out of every prompt in a CSV and cache the result.

    The source CSV must provide 'filepath' and 'prompt' columns. Rows whose
    prompt yields no gender key are dropped. The surviving rows are written
    next to the source as '<name>.gender_cached.csv' and that cache is reused
    on later calls as long as it is not older than the source file.

    Args:
        csv_path: Path of the annotation CSV.

    Returns:
        A DataFrame with columns 'filepath', 'prompt' and 'gender_idx' (the
        index of the label in CONFIG["gender_keys"]), or None when the source
        is missing, empty, lacks a required column, or contains no usable row.
    """
    cache_path = _gender_cache_path(csv_path)
    if _cache_is_fresh(cache_path, csv_path):
        print(f"Loading preprocessed gender data from cache: {cache_path}")
        try:
            df_cache = pd.read_csv(cache_path)
        except pd.errors.EmptyDataError:
            df_cache = None
        if df_cache is not None and not df_cache.empty:
            return df_cache
        print("Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing gender data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    missing_columns = {"filepath", "prompt"} - set(df.columns)
    if missing_columns:
        print(f"Error: {csv_path} is missing required column(s): {sorted(missing_columns)}")
        return None

    gender_keys = CONFIG["gender_keys"]
    new_data = []
    rows = zip(df["filepath"], df["prompt"])
    for filepath, prompt in tqdm(rows, total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        # An empty prompt cell is read back as NaN (a float); skip it rather than parse it.
        gender = parse_gender_from_prompt(prompt) if isinstance(prompt, str) else None
        if gender:
            new_data.append({
                'filepath': filepath,
                'prompt': prompt,
                'gender_idx': gender_keys.index(gender),
            })

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset for gender.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

# ------------------------------
# Dataset Class
# ------------------------------
class GenderDataset(Dataset):
    """
    Map-style dataset yielding (image_tensor, gender_idx, filepath) per row.

    Args:
        df: DataFrame with 'filepath' and 'gender_idx' columns. None is
            accepted and behaves like an empty dataset.
        preprocess: Callable applied to the RGB PIL image to produce the
            tensor returned for that sample.
    """

    # Shape of the all-zero stand-in returned for an unreadable image.
    # NOTE(review): presumably the input size of the configured model; confirm
    # if the backbone is changed.
    PLACEHOLDER_SHAPE = (3, 224, 224)

    def __init__(self, df, preprocess):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess
        # Paths already reported as unreadable, so each is announced only once.
        self._warned_paths = set()

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load and preprocess the image of row idx.

        Returns:
            (image, gender_idx, image_path) where gender_idx is a 0-dim int64
            tensor. When the image cannot be opened or decoded, image is a
            zero tensor of PLACEHOLDER_SHAPE and a one-time warning is printed.
        """
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(image_path) as raw_image:
                image = self.preprocess(raw_image.convert("RGB"))
        except OSError as err:
            # OSError covers a missing file as well as permission and decode
            # failures, so one bad file cannot abort a whole epoch.
            if image_path not in self._warned_paths:
                self._warned_paths.add(image_path)
                print(f"Warning: could not read image {image_path} ({err}); using a zero tensor.")
            image = torch.zeros(self.PLACEHOLDER_SHAPE)

        # CrossEntropyLoss expects integer (long) class indices.
        gender_idx = torch.tensor(int(row["gender_idx"]), dtype=torch.long)
        return image, gender_idx, image_path

# ------------------------------
# Training and Evaluation Functions
# ------------------------------
def _gender_batch_loss(model, images, gt_indices, gender_text_features, loss_fn):
    """Return the classification loss of one batch against the fixed gender text features."""
    images = images.to(CONFIG["device"])
    gt_indices = gt_indices.to(CONFIG["device"])

    image_features = model.encode_image(images)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # The scaled similarity of each image to each prompt serves as the class logit.
    logits_per_image = 100.0 * image_features @ gender_text_features.T
    return loss_fn(logits_per_image, gt_indices)


def train_and_validate(model, train_loader, val_loader, gender_text_features):
    """
    Fine-tune the model on gender classification with early stopping.

    Each epoch trains over train_loader, then evaluates on val_loader. Whenever
    the validation loss improves, the model's state_dict is written to
    CONFIG["save_path"]. Training stops after CONFIG["patience"] consecutive
    epochs without improvement, or after CONFIG["epochs"] epochs.

    Reported losses are per-sample means over the whole split, so a short
    final batch does not get extra weight in the figure that drives checkpoint
    selection.

    Args:
        model: Model exposing encode_image(images).
        train_loader: Loader yielding (images, gender_indices, paths) batches.
        val_loader: Loader of the same form, used for validation.
        gender_text_features: Text embeddings, one row per gender prompt.
            Assumed to be L2-normalised and on CONFIG["device"] already;
            confirm against the caller.

    Returns:
        True if at least one checkpoint was saved, otherwise False (including
        when either loader is empty).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        print("Warning: empty training or validation loader; nothing to train.")
        return False

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    loss_fn = torch.nn.CrossEntropyLoss()
    best_val_loss = float("inf")
    patience_counter = 0
    model_saved = False

    for epoch in range(CONFIG["epochs"]):
        model.train()
        train_loss_sum, train_seen = 0.0, 0
        for images, gt_indices, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]"):
            loss = _gender_batch_loss(model, images, gt_indices, gender_text_features, loss_fn)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by batch size to get a true dataset mean.
            batch_size = len(gt_indices)
            train_loss_sum += loss.item() * batch_size
            train_seen += batch_size

        avg_train_loss = train_loss_sum / train_seen

        model.eval()
        val_loss_sum, val_seen = 0.0, 0
        with torch.no_grad():
            for images, gt_indices, _ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]"):
                loss = _gender_batch_loss(model, images, gt_indices, gender_text_features, loss_fn)
                batch_size = len(gt_indices)
                val_loss_sum += loss.item() * batch_size
                val_seen += batch_size

        avg_val_loss = val_loss_sum / val_seen
        print(f"✅ Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), CONFIG["save_path"])
            print(f"🎉 Saved best model to {CONFIG['save_path']}")
            patience_counter = 0
            model_saved = True
        else:
            patience_counter += 1
            print(f"⚠️ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")

        if patience_counter >= CONFIG["patience"]:
            print("🛑 Early stopping triggered.")
            break

    return model_saved


def test_and_plot(model, gender_text_features, preprocess):
    """
    Evaluate the best checkpoint on the test CSV and save per-image result plots.

    Loads the state_dict stored at CONFIG["save_path"] into model, classifies
    every test image against the gender prompts, prints the accuracy, and
    writes up to CONFIG["num_plots"] annotated PNGs into
    CONFIG["plot_save_dir"]. Images that cannot be opened for plotting are
    skipped.

    Args:
        model: Model exposing encode_image(images); receives the checkpoint weights.
        gender_text_features: Text embeddings, one row per gender prompt.
            Assumed to be L2-normalised and on CONFIG["device"]; confirm
            against the caller.
        preprocess: Image transform passed on to GenderDataset.

    Returns:
        None. Returns early when the test CSV yields no usable rows.
    """
    print("\n--- Starting Final Testing and Plotting Phase for Gender ---")
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    state_dict = torch.load(CONFIG["save_path"], map_location=CONFIG["device"])
    model.load_state_dict(state_dict)
    model.to(CONFIG["device"]).eval()
    print("Best model loaded.")

    test_df = preprocess_and_cache_csv(CONFIG["csv_test"])
    if test_df is None or test_df.empty:
        print("Error or empty data in test CSV. Halting testing.")
        return

    test_dataset = GenderDataset(test_df, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    gender_keys = CONFIG["gender_keys"]
    all_results = []

    with torch.no_grad():
        for images, gt_indices, image_paths in tqdm(test_loader, desc="[Testing Gender]"):
            images = images.to(CONFIG["device"])
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ gender_text_features.T)
            # Bring the whole batch back as plain ints once, instead of
            # touching device tensors element by element.
            preds = logits.argmax(dim=-1).cpu().tolist()
            truths = gt_indices.tolist()

            for image_path, truth, pred in zip(image_paths, truths, preds):
                all_results.append({
                    "image_path": image_path,
                    "ground_truth": gender_keys[truth],
                    "prediction": gender_keys[pred],
                    "is_correct": pred == truth
                })

    total_samples = len(all_results)
    correct_predictions = sum(1 for result in all_results if result["is_correct"])

    if total_samples > 0:
        accuracy = (correct_predictions / total_samples) * 100
        print(f"\n📊 Gender Test Accuracy: {accuracy:.2f}%")
    else:
        print("No valid samples were processed in the test set.")

    print(f"\n--- Plotting up to {CONFIG['num_plots']} gender results ---")
    os.makedirs(CONFIG["plot_save_dir"], exist_ok=True)

    for i, result in enumerate(all_results[:CONFIG["num_plots"]]):
        try:
            # copy() forces the pixel data to load so the file can be closed right away.
            with Image.open(result["image_path"]) as handle:
                img = handle.copy()
        except OSError:
            # Plots are best-effort: skip files that are missing or unreadable.
            continue

        outcome = 'correct' if result['is_correct'] else 'incorrect'
        fig, ax = plt.subplots(figsize=(8, 10))
        try:
            ax.imshow(img)
            ax.axis("off")

            title = f"Result {i+1}: {outcome.upper()}"
            color = 'green' if result['is_correct'] else 'red'
            fig.suptitle(title, fontsize=16, color=color)

            text = (f"Ground Truth: {result['ground_truth']}\n"
                    f"Prediction:   {result['prediction']}")
            ax.set_title(text, fontsize=12, pad=10)

            save_name = f"result_{i+1}_{outcome}.png"
            fig.savefig(os.path.join(CONFIG["plot_save_dir"], save_name), bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    print("--- Plotting complete ---")

# ==============================
#      MAIN EXECUTION BLOCK
# ==============================
if __name__ == '__main__':
    # Report the runtime environment before any heavy work starts.
    print(f"Using device: {CONFIG['device']}")
    if torch.cuda.is_available():
        print(f"PyTorch Version: {torch.__version__}, CUDA Version: {torch.version.cuda}")

    # Build the backbone plus its image transform and the matching tokenizer.
    model, _, preprocess = open_clip.create_model_and_transforms(
        CONFIG["model_name"],
        pretrained=CONFIG["pretrained"],
        device=CONFIG["device"],
    )
    tokenizer = open_clip.get_tokenizer(CONFIG["model_name"])

    # Encode the class prompts once, without gradients, and L2-normalise them.
    with torch.no_grad():
        gender_text_tokens = tokenizer(CONFIG["gender_prompts"]).to(CONFIG["device"])
        gender_text_features = model.encode_text(gender_text_tokens)
        gender_text_features = gender_text_features / gender_text_features.norm(dim=-1, keepdim=True)

    train_df, val_df = (
        preprocess_and_cache_csv(CONFIG[csv_key]) for csv_key in ("csv_train", "csv_val")
    )

    if any(split is None or split.empty for split in (train_df, val_df)):
        print("\nSkipping training: Training/validation datasets are empty or could not be loaded.")
    else:
        train_dataset = GenderDataset(train_df, preprocess)
        val_dataset = GenderDataset(val_df, preprocess)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

        training_successful = train_and_validate(model, train_loader, val_loader, gender_text_features)

        # Only evaluate when training actually produced a checkpoint to load.
        if not training_successful:
            print("\nSkipping testing phase: No model was saved during training.")
        else:
            test_and_plot(model, gender_text_features, preprocess)


Using device: cuda
PyTorch Version: 2.8.0.dev20250507+cu128, CUDA Version: 12.8
Preprocessing gender data and caching: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.csv


Preprocessing train.csv: 100%|██████████| 700/700 [00:00<00:00, 59903.96it/s]


Preprocessing gender data and caching: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.csv


Preprocessing valid.csv: 100%|██████████| 200/200 [00:00<00:00, 76405.94it/s]
Epoch 1/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.06it/s]
Epoch 1/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  3.47it/s]


✅ Epoch 1: Train Loss = 0.0673, Val Loss = 0.0293
🎉 Saved best model to finetuned_gender_only_best.pt


Epoch 2/100 [Train]: 100%|██████████| 22/22 [00:15<00:00,  1.40it/s]
Epoch 2/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  2.45it/s]


✅ Epoch 2: Train Loss = 0.0052, Val Loss = 0.0227
🎉 Saved best model to finetuned_gender_only_best.pt


Epoch 3/100 [Train]: 100%|██████████| 22/22 [00:15<00:00,  1.41it/s]
Epoch 3/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  2.46it/s]


✅ Epoch 3: Train Loss = 0.0021, Val Loss = 0.0182
🎉 Saved best model to finetuned_gender_only_best.pt


Epoch 4/100 [Train]: 100%|██████████| 22/22 [00:14<00:00,  1.57it/s]
Epoch 4/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  2.72it/s]


✅ Epoch 4: Train Loss = 0.0012, Val Loss = 0.0188
⚠️ No improvement. Patience: 1/5


Epoch 5/100 [Train]: 100%|██████████| 22/22 [00:12<00:00,  1.73it/s]
Epoch 5/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  2.78it/s]


✅ Epoch 5: Train Loss = 0.0009, Val Loss = 0.0203
⚠️ No improvement. Patience: 2/5


Epoch 6/100 [Train]: 100%|██████████| 22/22 [00:12<00:00,  1.77it/s]
Epoch 6/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.55it/s]


✅ Epoch 6: Train Loss = 0.0007, Val Loss = 0.0210
⚠️ No improvement. Patience: 3/5


Epoch 7/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 7/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.53it/s]


✅ Epoch 7: Train Loss = 0.0006, Val Loss = 0.0215
⚠️ No improvement. Patience: 4/5


Epoch 8/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.06it/s]
Epoch 8/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.55it/s]


✅ Epoch 8: Train Loss = 0.0006, Val Loss = 0.0218
⚠️ No improvement. Patience: 5/5
🛑 Early stopping triggered.

--- Starting Final Testing and Plotting Phase for Gender ---
Best model loaded.
Preprocessing gender data and caching: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.csv


Preprocessing test.csv: 100%|██████████| 100/100 [00:00<00:00, 56496.55it/s]
[Testing Gender]: 100%|██████████| 4/4 [00:00<00:00,  4.08it/s]



📊 Gender Test Accuracy: 100.00%

--- Plotting up to 100 gender results ---
--- Plotting complete ---
